Files
rdp-proxy/backend/internal/modules/auth/postgres_store.go
T
2026-05-12 21:02:29 +03:00

567 lines
15 KiB
Go

package auth
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
)
type (
	// postgresStore is the PostgreSQL-backed Store. Its db handle is either the
	// connection pool (NewPostgresStore) or an open transaction
	// (PostgresTransactor.WithinTransaction); every repository it hands out
	// shares that same handle.
	postgresStore struct {
		db postgresplatform.DBTX
	}

	// PostgresTransactor opens transactions on pool and exposes them as a Store.
	PostgresTransactor struct {
		pool *pgxpool.Pool
	}
)

// NewPostgresStore returns a Store whose repositories query pool directly,
// outside of any transaction opened by this package.
func NewPostgresStore(pool *pgxpool.Pool) Store {
	var store Store = &postgresStore{db: pool}
	return store
}

// NewPostgresTransactor returns a PostgresTransactor that draws its
// transactions from pool.
func NewPostgresTransactor(pool *pgxpool.Pool) *PostgresTransactor {
	transactor := &PostgresTransactor{}
	transactor.pool = pool
	return transactor
}
// WithinTransaction runs fn with a Store whose repositories all share one
// transaction obtained through postgresplatform.WithTransaction on t.pool.
// fn's error is handed back to WithTransaction, and its result is returned.
// NOTE(review): commit/rollback semantics live in
// postgresplatform.WithTransaction (not visible here); presumably a non-nil
// error from fn rolls the transaction back — confirm.
func (t *PostgresTransactor) WithinTransaction(ctx context.Context, fn func(store Store) error) error {
	runInTx := func(tx pgx.Tx) error {
		txStore := &postgresStore{db: tx}
		return fn(txStore)
	}
	return postgresplatform.WithTransaction(ctx, t.pool, runInTx)
}
// Users returns a UserRepository that uses this store's db handle (pool or
// transaction).
func (s *postgresStore) Users() UserRepository { return &postgresUserRepository{db: s.db} }

// Devices returns a DeviceRepository that uses this store's db handle.
func (s *postgresStore) Devices() DeviceRepository { return &postgresDeviceRepository{db: s.db} }

// AuthSessions returns an AuthSessionRepository that uses this store's db
// handle.
func (s *postgresStore) AuthSessions() AuthSessionRepository {
	return &postgresAuthSessionRepository{db: s.db}
}

// Installation returns an InstallationRepository that uses this store's db
// handle.
func (s *postgresStore) Installation() InstallationRepository {
	return &postgresInstallationRepository{db: s.db}
}
// Repository implementations. Each one only holds the db handle it was created
// with, so whether it runs inside a transaction is decided by the owning
// postgresStore.
type (
	// postgresUserRepository reads and writes the users table.
	postgresUserRepository struct {
		db postgresplatform.DBTX
	}

	// postgresDeviceRepository reads and writes the devices table.
	postgresDeviceRepository struct {
		db postgresplatform.DBTX
	}

	// postgresAuthSessionRepository reads and writes the auth_sessions table.
	postgresAuthSessionRepository struct {
		db postgresplatform.DBTX
	}

	// postgresInstallationRepository manages the installation_authority row and
	// the bootstrap-owner side effects.
	postgresInstallationRepository struct {
		db postgresplatform.DBTX
	}
)
// GetByEmail looks up a user by exact email match. It returns (nil, nil) when
// no user matches.
// NOTE(review): email is compared as given (no trimming or lower-casing here),
// while BootstrapOwner stores emails trimmed and lower-cased; confirm callers
// normalize before calling.
func (r *postgresUserRepository) GetByEmail(ctx context.Context, email string) (*User, error) {
	const selectByEmail = `
SELECT id::text, email, password_hash, mfa_enabled, platform_role, created_at, updated_at
FROM users
WHERE email = $1
`
	row := r.db.QueryRow(ctx, selectByEmail, email)
	return scanOptionalUser(row)
}
// GetByID looks up a user by its UUID (passed as text and cast in SQL). It
// returns (nil, nil) when no user has that id.
func (r *postgresUserRepository) GetByID(ctx context.Context, userID string) (*User, error) {
	const selectByID = `
SELECT id::text, email, password_hash, mfa_enabled, platform_role, created_at, updated_at
FROM users
WHERE id = $1::uuid
`
	row := r.db.QueryRow(ctx, selectByID, userID)
	return scanOptionalUser(row)
}
// List returns every user ordered by created_at, newest first. The slice is nil
// when there are no users. An error raised during row iteration is returned
// together with the users read before it.
func (r *postgresUserRepository) List(ctx context.Context) ([]User, error) {
	const selectAll = `
SELECT id::text, email, password_hash, mfa_enabled, platform_role, created_at, updated_at
FROM users
ORDER BY created_at DESC
`
	rows, err := r.db.Query(ctx, selectAll)
	if err != nil {
		return nil, fmt.Errorf("query users: %w", err)
	}
	defer rows.Close()

	var users []User
	for rows.Next() {
		u, scanErr := scanOptionalUser(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		// scanOptionalUser yields nil only for pgx.ErrNoRows, which row
		// iteration should not produce; skip defensively.
		if u == nil {
			continue
		}
		users = append(users, *u)
	}
	return users, rows.Err()
}
// Create inserts user and returns the stored row. The id column is not part of
// the INSERT, so the returned ID is the one assigned by the database and the
// argument's ID field is ignored.
func (r *postgresUserRepository) Create(ctx context.Context, user User) (*User, error) {
	const insertUser = `
INSERT INTO users (email, password_hash, mfa_enabled, platform_role, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id::text, email, password_hash, mfa_enabled, platform_role, created_at, updated_at
`
	args := []any{
		user.Email,
		user.PasswordHash,
		user.MFAEnabled,
		user.PlatformRole,
		user.CreatedAt,
		user.UpdatedAt,
	}
	row := r.db.QueryRow(ctx, insertUser, args...)
	return scanOptionalUser(row)
}
// Upsert records a sighting of a device for a user, keyed by
// (user_id, device_fingerprint), and returns the resulting row.
//
// A new device is inserted as 'trusted' (with trusted_at = SeenAt) when
// params.TrustRequested is set, otherwise as 'pending'. For an existing device
// the label, last_seen_at and updated_at are always refreshed. Trust only
// moves forward:
//   - a 'revoked' device stays revoked and keeps its trusted_at;
//   - a 'trusted' device stays trusted and keeps its original trusted_at;
//   - any other status is promoted to 'trusted' (stamping trusted_at) only when
//     trust was requested.
//
// trusted_at is guarded by the same conditions as trust_status. Without the
// 'revoked' guard, a trust request on a revoked device would overwrite
// trusted_at while the status stayed 'revoked'.
func (r *postgresDeviceRepository) Upsert(ctx context.Context, params UpsertDeviceParams) (*Device, error) {
	const query = `
INSERT INTO devices (
user_id,
device_fingerprint,
device_label,
trust_status,
trusted_at,
last_seen_at,
created_at,
updated_at
) VALUES (
$1::uuid,
$2,
$3,
CASE WHEN $4::boolean THEN 'trusted' ELSE 'pending' END,
CASE WHEN $4::boolean THEN $5::timestamptz ELSE NULL::timestamptz END,
$5::timestamptz,
$5::timestamptz,
$5::timestamptz
)
ON CONFLICT (user_id, device_fingerprint) DO UPDATE SET
device_label = EXCLUDED.device_label,
last_seen_at = EXCLUDED.last_seen_at,
updated_at = EXCLUDED.updated_at,
trust_status = CASE
WHEN devices.trust_status IN ('revoked', 'trusted') THEN devices.trust_status
WHEN EXCLUDED.trust_status = 'trusted' THEN 'trusted'
ELSE devices.trust_status
END,
trusted_at = CASE
WHEN devices.trust_status IN ('revoked', 'trusted') THEN devices.trusted_at
WHEN EXCLUDED.trust_status = 'trusted' THEN EXCLUDED.trusted_at
ELSE devices.trusted_at
END
RETURNING
id::text, user_id::text, device_fingerprint, COALESCE(device_label, ''),
trust_status, trusted_at, last_seen_at, revoked_at, revoked_reason, created_at, updated_at
`
	return scanDevice(r.db.QueryRow(ctx, query,
		params.UserID,
		params.Fingerprint,
		params.Label,
		params.TrustRequested,
		params.SeenAt,
	))
}
// GetByIDForUser loads a device by id, but only if it belongs to userID. It
// returns (nil, nil) when no such device exists for that user.
func (r *postgresDeviceRepository) GetByIDForUser(ctx context.Context, userID, deviceID string) (*Device, error) {
	const selectOwnedDevice = `
SELECT id::text, user_id::text, device_fingerprint, COALESCE(device_label, ''),
trust_status, trusted_at, last_seen_at, revoked_at, revoked_reason, created_at, updated_at
FROM devices
WHERE id = $1::uuid AND user_id = $2::uuid
`
	device, err := scanDevice(r.db.QueryRow(ctx, selectOwnedDevice, deviceID, userID))
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		// scanDevice wraps with %w, so the not-found sentinel is still visible.
		return nil, nil
	case err != nil:
		return nil, err
	}
	return device, nil
}
// ListTrustedByUser returns the user's devices whose trust_status is 'trusted'
// and that have no revoked_at, newest first. The slice is nil when there are
// none. An error raised during row iteration is returned together with the
// devices read before it.
func (r *postgresDeviceRepository) ListTrustedByUser(ctx context.Context, userID string) ([]Device, error) {
	const selectTrusted = `
SELECT id::text, user_id::text, device_fingerprint, COALESCE(device_label, ''),
trust_status, trusted_at, last_seen_at, revoked_at, revoked_reason, created_at, updated_at
FROM devices
WHERE user_id = $1::uuid AND trust_status = 'trusted' AND revoked_at IS NULL
ORDER BY created_at DESC
`
	rows, err := r.db.Query(ctx, selectTrusted, userID)
	if err != nil {
		return nil, fmt.Errorf("query trusted devices: %w", err)
	}
	defer rows.Close()

	var trusted []Device
	for rows.Next() {
		d, scanErr := scanDevice(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		trusted = append(trusted, *d)
	}
	return trusted, rows.Err()
}
// Revoke marks the user's device as 'revoked', recording the time and reason.
// It does not check how many rows were updated, so revoking an unknown device,
// or a device owned by a different user, silently succeeds.
func (r *postgresDeviceRepository) Revoke(ctx context.Context, params RevokeDeviceParams) error {
	const revokeDevice = `
UPDATE devices
SET trust_status = 'revoked',
revoked_at = $3,
revoked_reason = $4,
updated_at = $3
WHERE id = $1::uuid AND user_id = $2::uuid
`
	_, err := r.db.Exec(ctx, revokeDevice, params.DeviceID, params.UserID, params.RevokedAt, params.Reason)
	if err != nil {
		return fmt.Errorf("revoke device: %w", err)
	}
	return nil
}
// Create inserts a new auth session using the caller-supplied id and
// timestamps. last_rotated_at and the revocation columns are not written here.
func (r *postgresAuthSessionRepository) Create(ctx context.Context, session AuthSession) error {
	const insertSession = `
INSERT INTO auth_sessions (
id,
user_id,
device_id,
refresh_token_hash,
refresh_expires_at,
last_seen_at,
created_at,
updated_at
) VALUES ($1::uuid, $2::uuid, $3::uuid, $4, $5, $6, $7, $8)
`
	args := []any{
		session.ID,
		session.UserID,
		session.DeviceID,
		session.RefreshTokenHash,
		session.RefreshExpiresAt,
		session.LastSeenAt,
		session.CreatedAt,
		session.UpdatedAt,
	}
	_, err := r.db.Exec(ctx, insertSession, args...)
	if err != nil {
		return fmt.Errorf("create auth session: %w", err)
	}
	return nil
}
// GetByID loads an auth session by id without taking a row lock. It returns
// (nil, nil) when no session has that id.
func (r *postgresAuthSessionRepository) GetByID(ctx context.Context, authSessionID string) (*AuthSession, error) {
	const noLock = ""
	return r.getByID(ctx, authSessionID, noLock)
}

// GetByIDForUpdate is GetByID with a FOR UPDATE row lock. The lock is held until
// the surrounding transaction ends. Outside WithinTransaction it is released
// as soon as the statement completes.
func (r *postgresAuthSessionRepository) GetByIDForUpdate(ctx context.Context, authSessionID string) (*AuthSession, error) {
	const rowLock = " FOR UPDATE"
	return r.getByID(ctx, authSessionID, rowLock)
}

// getByID runs the shared auth-session lookup. suffix is appended verbatim to
// the SQL, so it must only ever be a trusted constant (e.g. " FOR UPDATE").
// pgx.ErrNoRows is mapped to (nil, nil).
func (r *postgresAuthSessionRepository) getByID(ctx context.Context, authSessionID string, suffix string) (*AuthSession, error) {
	const selectByID = `
SELECT id::text, user_id::text, device_id::text, refresh_token_hash, refresh_expires_at,
last_seen_at, last_rotated_at, revoked_at, revoked_reason, created_at, updated_at
FROM auth_sessions
WHERE id = $1::uuid`
	row := r.db.QueryRow(ctx, selectByID+suffix, authSessionID)
	session, err := scanAuthSession(row)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return nil, nil
		}
		return nil, err
	}
	return session, nil
}
// Rotate swaps in a new refresh token hash and expiry for a non-revoked session
// and stamps last_seen_at, last_rotated_at and updated_at
// (updated_at = LastRotatedAt).
// NOTE(review): rows affected are not checked, so rotating a revoked or unknown
// session is a silent no-op. Callers presumably verify the session first
// (e.g. via GetByIDForUpdate) — confirm.
func (r *postgresAuthSessionRepository) Rotate(ctx context.Context, params RotateAuthSessionParams) error {
	const rotateSession = `
UPDATE auth_sessions
SET refresh_token_hash = $2,
refresh_expires_at = $3,
last_seen_at = $4,
last_rotated_at = $5,
updated_at = $5
WHERE id = $1::uuid AND revoked_at IS NULL
`
	_, err := r.db.Exec(ctx, rotateSession,
		params.AuthSessionID,
		params.RefreshTokenHash,
		params.RefreshExpiresAt,
		params.LastSeenAt,
		params.LastRotatedAt,
	)
	if err != nil {
		return fmt.Errorf("rotate auth session: %w", err)
	}
	return nil
}
// Touch sets last_seen_at and updated_at to seenAt on a non-revoked session.
// Revoked or unknown sessions are left untouched without error.
func (r *postgresAuthSessionRepository) Touch(ctx context.Context, authSessionID string, seenAt time.Time) error {
	const touchSession = `
UPDATE auth_sessions
SET last_seen_at = $2, updated_at = $2
WHERE id = $1::uuid AND revoked_at IS NULL
`
	_, err := r.db.Exec(ctx, touchSession, authSessionID, seenAt)
	if err != nil {
		return fmt.Errorf("touch auth session: %w", err)
	}
	return nil
}
// Revoke revokes one of the user's sessions, recording time and reason. Sessions
// that are already revoked, unknown, or owned by another user are left
// untouched without error.
func (r *postgresAuthSessionRepository) Revoke(ctx context.Context, params RevokeAuthSessionParams) error {
	const revokeSession = `
UPDATE auth_sessions
SET revoked_at = $3,
revoked_reason = $4,
updated_at = $3
WHERE id = $1::uuid AND user_id = $2::uuid AND revoked_at IS NULL
`
	_, err := r.db.Exec(ctx, revokeSession, params.AuthSessionID, params.UserID, params.RevokedAt, params.Reason)
	if err != nil {
		return fmt.Errorf("revoke auth session: %w", err)
	}
	return nil
}
// RevokeByDevice revokes every still-active session the user has on deviceID,
// recording revokedAt and reason. Already-revoked sessions keep their original
// revocation data.
func (r *postgresAuthSessionRepository) RevokeByDevice(ctx context.Context, userID, deviceID, reason string, revokedAt time.Time) error {
	const revokeDeviceSessions = `
UPDATE auth_sessions
SET revoked_at = $3,
revoked_reason = $4,
updated_at = $3
WHERE user_id = $1::uuid AND device_id = $2::uuid AND revoked_at IS NULL
`
	_, err := r.db.Exec(ctx, revokeDeviceSessions, userID, deviceID, revokedAt, reason)
	if err != nil {
		return fmt.Errorf("revoke auth sessions by device: %w", err)
	}
	return nil
}
// GetStatus reads the singleton installation_authority row (id = 1).
//
// A missing row is not an error: it yields Bootstrapped=false with
// AuthorityState "unbootstrapped". If the row exists, Bootstrapped is reported
// as true whatever its authority_state column says.
func (r *postgresInstallationRepository) GetStatus(ctx context.Context) (*InstallationAuthorityState, error) {
	const selectAuthority = `
SELECT install_id, authority_state, product_root_key_fingerprint, bootstrapped_owner_email, bootstrapped_at
FROM installation_authority
WHERE id = 1
`
	var status InstallationAuthorityState
	err := r.db.QueryRow(ctx, selectAuthority).Scan(
		&status.InstallID,
		&status.AuthorityState,
		&status.ProductRootFingerprint,
		&status.BootstrappedOwnerEmail,
		&status.BootstrappedAt,
	)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return &InstallationAuthorityState{
			Bootstrapped:   false,
			AuthorityState: "unbootstrapped",
		}, nil
	case err != nil:
		return nil, fmt.Errorf("get installation status: %w", err)
	}
	status.Bootstrapped = true
	return &status, nil
}
// BootstrapOwner performs the one-time installation bootstrap. In order, it:
//  1. upserts the owner user (email trimmed and lower-cased; an existing user
//     with that email has its password hash and platform role overwritten);
//  2. inserts the singleton installation_authority row (id = 1, state 'active');
//  3. revokes any active grant of params.Role for this user and install, then
//     inserts a fresh signed platform_role_grants row;
//  4. makes the user an active org_owner of the organization with slug
//     'default' (a no-op if that organization does not exist).
//
// It returns ErrInstallationAlreadyBootstrapped if the authority row already
// exists, including when a concurrent bootstrap inserted it first.
//
// NOTE(review): the steps are separate statements. Run this through
// PostgresTransactor.WithinTransaction so that an error part-way through
// (including the already-bootstrapped case detected at step 2) discards the
// user upsert from step 1.
func (r *postgresInstallationRepository) BootstrapOwner(ctx context.Context, params BootstrapOwnerParams) (*User, error) {
	// Fast path: bail out if the authority row already exists. FOR UPDATE
	// cannot lock a row that does not exist yet, so this alone does not stop
	// two concurrent first-time bootstraps. The ON CONFLICT guard on the
	// authority insert below handles that race.
	var existingInstallID string
	err := r.db.QueryRow(ctx, `
SELECT install_id
FROM installation_authority
WHERE id = 1
FOR UPDATE
`).Scan(&existingInstallID)
	switch {
	case err == nil:
		return nil, ErrInstallationAlreadyBootstrapped
	case !errors.Is(err, pgx.ErrNoRows):
		return nil, fmt.Errorf("lock installation authority: %w", err)
	}

	email := strings.ToLower(strings.TrimSpace(params.Email))
	now := params.Now.UTC()

	// Step 1: create or take over the owner account.
	user, err := scanOptionalUser(r.db.QueryRow(ctx, `
INSERT INTO users (email, password_hash, mfa_enabled, platform_role, created_at, updated_at)
VALUES ($1, $2, FALSE, $3, $4, $4)
ON CONFLICT (email) DO UPDATE SET
password_hash = EXCLUDED.password_hash,
platform_role = EXCLUDED.platform_role,
updated_at = EXCLUDED.updated_at
RETURNING id::text, email, password_hash, mfa_enabled, platform_role, created_at, updated_at
`, email, params.PasswordHash, params.Role, now))
	if err != nil {
		return nil, fmt.Errorf("upsert bootstrap owner: %w", err)
	}
	if user == nil {
		return nil, errors.New("upsert bootstrap owner returned no user")
	}

	// An absent activation payload is stored as an empty JSON object.
	payload := json.RawMessage(`{}`)
	if len(params.ActivationPayload) > 0 {
		payload = params.ActivationPayload
	}

	// Step 2: claim the singleton authority row. ON CONFLICT DO NOTHING turns a
	// losing concurrent insert (or one that slipped past the check above) into
	// "0 rows affected" instead of a raw unique-violation error. This assumes
	// installation_authority.id is unique, as its use as a singleton key
	// suggests — confirm against the schema.
	tag, err := r.db.Exec(ctx, `
INSERT INTO installation_authority (
id,
install_id,
authority_state,
product_root_key_fingerprint,
activation_payload,
activation_signature,
bootstrapped_owner_email,
bootstrapped_at,
created_at,
updated_at
) VALUES (
1,
$1,
'active',
$2,
$3::jsonb,
$4,
$5,
$6,
$6,
$6
)
ON CONFLICT DO NOTHING
`, params.InstallID, params.ProductRootKeyFingerprint, []byte(payload), params.ActivationSignature, email, now)
	if err != nil {
		return nil, fmt.Errorf("insert installation authority: %w", err)
	}
	if tag.RowsAffected() == 0 {
		return nil, ErrInstallationAlreadyBootstrapped
	}

	// Step 3: retire any active grant of this role for the user and install,
	// then record the new signed grant.
	if _, err := r.db.Exec(ctx, `
UPDATE platform_role_grants
SET revoked_at = $4
WHERE user_id = $1::uuid
AND role = $2
AND install_id = $3
AND revoked_at IS NULL
`, user.ID, params.Role, params.InstallID, now); err != nil {
		return nil, fmt.Errorf("revoke superseded platform role grants: %w", err)
	}

	if _, err := r.db.Exec(ctx, `
INSERT INTO platform_role_grants (
user_id,
role,
install_id,
grant_payload,
grant_signature,
grant_source,
granted_at,
expires_at,
metadata
) VALUES (
$1::uuid,
$2,
$3,
$4::jsonb,
$5,
$6,
$7,
$8,
'{"bootstrap_owner":true}'::jsonb
)
`, user.ID, params.Role, params.InstallID, []byte(payload), params.ActivationSignature, params.GrantSource, now, params.ExpiresAt); err != nil {
		return nil, fmt.Errorf("insert platform role grant: %w", err)
	}

	// Step 4: make the owner an active org_owner of the 'default' organization.
	// The INSERT ... SELECT inserts nothing if that organization is missing.
	if _, err := r.db.Exec(ctx, `
INSERT INTO organization_memberships (
organization_id,
user_id,
role_id,
status,
invited_by_user_id,
created_at,
updated_at
)
SELECT id, $1::uuid, 'org_owner', 'active', $1::uuid, $2, $2
FROM organizations
WHERE slug = 'default'
ON CONFLICT (organization_id, user_id) DO UPDATE SET
role_id = 'org_owner',
status = 'active',
invited_by_user_id = EXCLUDED.invited_by_user_id,
updated_at = EXCLUDED.updated_at
`, user.ID, now); err != nil {
		return nil, fmt.Errorf("upsert default organization owner membership: %w", err)
	}

	return user, nil
}
// scanner is the single-method view shared by a one-row query result and a
// multi-row cursor positioned on a row. It lets the scan helpers below serve
// both single-row lookups and list iteration.
type scanner interface {
	Scan(targets ...any) error
}
// scanOptionalUser reads one user row with the column order id, email,
// password_hash, mfa_enabled, platform_role, created_at, updated_at.
// pgx.ErrNoRows becomes (nil, nil). Any other scan failure is wrapped with
// "scan user".
func scanOptionalUser(row scanner) (*User, error) {
	var u User
	err := row.Scan(
		&u.ID,
		&u.Email,
		&u.PasswordHash,
		&u.MFAEnabled,
		&u.PlatformRole,
		&u.CreatedAt,
		&u.UpdatedAt,
	)
	switch {
	case err == nil:
		return &u, nil
	case errors.Is(err, pgx.ErrNoRows):
		return nil, nil
	default:
		return nil, fmt.Errorf("scan user: %w", err)
	}
}
// scanDevice reads one device row in the column order used by the device
// queries: id, user_id, device_fingerprint, label, trust_status, trusted_at,
// last_seen_at, revoked_at, revoked_reason, created_at, updated_at. Nullable
// columns are scanned through pointer temporaries.
// Every error, including pgx.ErrNoRows, is wrapped with %w as "scan device",
// so callers can still test for ErrNoRows with errors.Is.
func scanDevice(row scanner) (*Device, error) {
	var (
		d                     Device
		trustedAt, lastSeenAt *time.Time
		revokedAt             *time.Time
		revokedReason         *string
	)
	err := row.Scan(
		&d.ID,
		&d.UserID,
		&d.Fingerprint,
		&d.Label,
		&d.TrustStatus,
		&trustedAt,
		&lastSeenAt,
		&revokedAt,
		&revokedReason,
		&d.CreatedAt,
		&d.UpdatedAt,
	)
	if err != nil {
		return nil, fmt.Errorf("scan device: %w", err)
	}
	d.TrustedAt, d.LastSeenAt, d.RevokedAt, d.RevokedReason = trustedAt, lastSeenAt, revokedAt, revokedReason
	return &d, nil
}
// scanAuthSession reads one auth_sessions row in the column order id, user_id,
// device_id, refresh_token_hash, refresh_expires_at, last_seen_at,
// last_rotated_at, revoked_at, revoked_reason, created_at, updated_at.
// Nullable columns go through pointer temporaries. Every error, including
// pgx.ErrNoRows, is wrapped with %w as "scan auth session".
func scanAuthSession(row scanner) (*AuthSession, error) {
	var (
		s                         AuthSession
		lastSeenAt, lastRotatedAt *time.Time
		revokedAt                 *time.Time
		revokedReason             *string
	)
	err := row.Scan(
		&s.ID,
		&s.UserID,
		&s.DeviceID,
		&s.RefreshTokenHash,
		&s.RefreshExpiresAt,
		&lastSeenAt,
		&lastRotatedAt,
		&revokedAt,
		&revokedReason,
		&s.CreatedAt,
		&s.UpdatedAt,
	)
	if err != nil {
		return nil, fmt.Errorf("scan auth session: %w", err)
	}
	s.LastSeenAt, s.LastRotatedAt, s.RevokedAt, s.RevokedReason = lastSeenAt, lastRotatedAt, revokedAt, revokedReason
	return &s, nil
}