package auth

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"

	postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
)

// postgresStore is the Store implementation backed by a postgresplatform.DBTX.
// The handle is either the connection pool (NewPostgresStore) or an open
// transaction (PostgresTransactor.WithinTransaction).
type postgresStore struct {
	db postgresplatform.DBTX
}

// PostgresTransactor runs Store callbacks inside a database transaction taken
// from its pool.
type PostgresTransactor struct {
	pool *pgxpool.Pool
}

// NewPostgresStore returns a Store whose repositories execute directly against
// pool, i.e. outside of any explicit transaction.
func NewPostgresStore(pool *pgxpool.Pool) Store {
	return &postgresStore{db: pool}
}

// NewPostgresTransactor returns a PostgresTransactor that opens its
// transactions on pool.
func NewPostgresTransactor(pool *pgxpool.Pool) *PostgresTransactor {
	return &PostgresTransactor{pool: pool}
}

// WithinTransaction invokes fn with a Store bound to a transaction obtained
// through postgresplatform.WithTransaction and returns that helper's result.
//
// NOTE(review): commit/rollback semantics are defined by
// postgresplatform.WithTransaction, which is not visible in this file.
func (t *PostgresTransactor) WithinTransaction(ctx context.Context, fn func(store Store) error) error {
	return postgresplatform.WithTransaction(ctx, t.pool, func(tx pgx.Tx) error {
		return fn(&postgresStore{db: tx})
	})
}

// Users returns a UserRepository sharing this store's database handle.
func (s *postgresStore) Users() UserRepository {
	return &postgresUserRepository{db: s.db}
}

// Devices returns a DeviceRepository sharing this store's database handle.
func (s *postgresStore) Devices() DeviceRepository {
	return &postgresDeviceRepository{db: s.db}
}

// AuthSessions returns an AuthSessionRepository sharing this store's database
// handle.
func (s *postgresStore) AuthSessions() AuthSessionRepository {
	return &postgresAuthSessionRepository{db: s.db}
}

// Installation returns an InstallationRepository sharing this store's database
// handle.
func (s *postgresStore) Installation() InstallationRepository {
	return &postgresInstallationRepository{db: s.db}
}

// postgresUserRepository reads rows of the users table.
type postgresUserRepository struct {
	db postgresplatform.DBTX
}

// postgresDeviceRepository reads and writes rows of the devices table.
type postgresDeviceRepository struct {
	db postgresplatform.DBTX
}

// postgresAuthSessionRepository reads and writes rows of the auth_sessions
// table.
type postgresAuthSessionRepository struct {
	db postgresplatform.DBTX
}

// postgresInstallationRepository manages the single installation_authority row
// and the records created when the installation is bootstrapped.
type postgresInstallationRepository struct {
	db postgresplatform.DBTX
}

// GetByEmail returns the user whose email column equals email, or (nil, nil)
// when no such row exists. The value is compared as given; no trimming or
// case folding is applied here.
func (r *postgresUserRepository) GetByEmail(ctx context.Context, email string) (*User, error) {
	const query = `
		SELECT id::text, email, password_hash, mfa_enabled, created_at, updated_at
		FROM users
		WHERE email = $1
	`
	return scanOptionalUser(r.db.QueryRow(ctx, query, email))
}

// GetByID returns the user with the given UUID (passed as text and cast in
// SQL), or (nil, nil) when no such row exists.
func (r *postgresUserRepository) GetByID(ctx context.Context, userID string) (*User, error) {
	const query = `
		SELECT id::text, email, password_hash, mfa_enabled, created_at, updated_at
		FROM users
		WHERE id = $1::uuid
	`
	return scanOptionalUser(r.db.QueryRow(ctx, query, userID))
}

// Upsert inserts a device for (user_id, device_fingerprint) or refreshes the
// existing row, and returns the resulting device.
//
// On insert the device is 'trusted' (with trusted_at = SeenAt) when
// params.TrustRequested is true and 'pending' otherwise. On conflict the
// label, last_seen_at and updated_at are always refreshed, while the trust
// state only ever moves from a non-trusted, non-revoked state to 'trusted':
// a 'revoked' or already 'trusted' device keeps both its trust_status and its
// trusted_at.
func (r *postgresDeviceRepository) Upsert(ctx context.Context, params UpsertDeviceParams) (*Device, error) {
	// The trusted_at CASE guards 'revoked' explicitly so that a trust request
	// against a revoked device cannot rewrite its original trusted_at while
	// trust_status stays 'revoked'.
	const query = `
		INSERT INTO devices (
			user_id,
			device_fingerprint,
			device_label,
			trust_status,
			trusted_at,
			last_seen_at,
			created_at,
			updated_at
		)
		VALUES (
			$1::uuid,
			$2,
			$3,
			CASE WHEN $4 THEN 'trusted' ELSE 'pending' END,
			CASE WHEN $4 THEN $5::timestamptz ELSE NULL::timestamptz END,
			$5::timestamptz,
			$5::timestamptz,
			$5::timestamptz
		)
		ON CONFLICT (user_id, device_fingerprint) DO UPDATE
		SET device_label = EXCLUDED.device_label,
			last_seen_at = EXCLUDED.last_seen_at,
			updated_at = EXCLUDED.updated_at,
			trust_status = CASE
				WHEN devices.trust_status = 'revoked' THEN devices.trust_status
				WHEN devices.trust_status = 'trusted' THEN devices.trust_status
				WHEN EXCLUDED.trust_status = 'trusted' THEN 'trusted'
				ELSE devices.trust_status
			END,
			trusted_at = CASE
				WHEN devices.trust_status = 'revoked' THEN devices.trusted_at
				WHEN devices.trust_status = 'trusted' THEN devices.trusted_at
				WHEN EXCLUDED.trust_status = 'trusted' THEN EXCLUDED.trusted_at
				ELSE devices.trusted_at
			END
		RETURNING id::text, user_id::text, device_fingerprint, COALESCE(device_label, ''), trust_status,
			trusted_at, last_seen_at, revoked_at, revoked_reason, created_at, updated_at
	`
	return scanDevice(r.db.QueryRow(ctx, query,
		params.UserID,
		params.Fingerprint,
		params.Label,
		params.TrustRequested,
		params.SeenAt,
	))
}

// GetByIDForUser returns the device with the given ID provided it belongs to
// userID, or (nil, nil) when no matching row exists.
func (r *postgresDeviceRepository) GetByIDForUser(ctx context.Context, userID, deviceID string) (*Device, error) {
	const query = `
		SELECT id::text, user_id::text, device_fingerprint, COALESCE(device_label, ''), trust_status,
			trusted_at, last_seen_at, revoked_at, revoked_reason, created_at, updated_at
		FROM devices
		WHERE id = $1::uuid
			AND user_id = $2::uuid
	`
	device, err := scanDevice(r.db.QueryRow(ctx, query, deviceID, userID))
	// scanDevice wraps with %w, so the pgx sentinel is still detectable.
	if errors.Is(err, pgx.ErrNoRows) {
		return nil, nil
	}
	return device, err
}

// ListTrustedByUser returns the user's devices whose trust_status is 'trusted'
// and whose revoked_at is NULL, newest first by created_at. The result is a
// nil slice when there are none; on any error the slice is nil.
func (r *postgresDeviceRepository) ListTrustedByUser(ctx context.Context, userID string) ([]Device, error) {
	const query = `
		SELECT id::text, user_id::text, device_fingerprint, COALESCE(device_label, ''), trust_status,
			trusted_at, last_seen_at, revoked_at, revoked_reason, created_at, updated_at
		FROM devices
		WHERE user_id = $1::uuid
			AND trust_status = 'trusted'
			AND revoked_at IS NULL
		ORDER BY created_at DESC
	`
	rows, err := r.db.Query(ctx, query, userID)
	if err != nil {
		return nil, fmt.Errorf("query trusted devices: %w", err)
	}
	defer rows.Close()

	var devices []Device
	for rows.Next() {
		device, err := scanDevice(rows)
		if err != nil {
			return nil, err
		}
		devices = append(devices, *device)
	}
	// An iteration failure must not leak a partially filled slice to callers.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate trusted devices: %w", err)
	}
	return devices, nil
}

// Revoke marks the user's device as 'revoked', recording params.RevokedAt and
// params.Reason. The number of affected rows is not inspected, so a device
// that does not exist (or belongs to another user) yields a nil error.
func (r *postgresDeviceRepository) Revoke(ctx context.Context, params RevokeDeviceParams) error {
	const query = `
		UPDATE devices
		SET trust_status = 'revoked',
			revoked_at = $3,
			revoked_reason = $4,
			updated_at = $3
		WHERE id = $1::uuid
			AND user_id = $2::uuid
	`
	if _, err := r.db.Exec(ctx, query, params.DeviceID, params.UserID, params.RevokedAt, params.Reason); err != nil {
		return fmt.Errorf("revoke device: %w", err)
	}
	return nil
}

// Create inserts session as a new auth_sessions row.
func (r *postgresAuthSessionRepository) Create(ctx context.Context, session AuthSession) error {
	const query = `
		INSERT INTO auth_sessions (
			id,
			user_id,
			device_id,
			refresh_token_hash,
			refresh_expires_at,
			last_seen_at,
			created_at,
			updated_at
		)
		VALUES ($1::uuid, $2::uuid, $3::uuid, $4, $5, $6, $7, $8)
	`
	if _, err := r.db.Exec(ctx, query,
		session.ID,
		session.UserID,
		session.DeviceID,
		session.RefreshTokenHash,
		session.RefreshExpiresAt,
		session.LastSeenAt,
		session.CreatedAt,
		session.UpdatedAt,
	); err != nil {
		return fmt.Errorf("create auth session: %w", err)
	}
	return nil
}

// GetByID returns the auth session with the given ID, or (nil, nil) when it
// does not exist.
func (r *postgresAuthSessionRepository) GetByID(ctx context.Context, authSessionID string) (*AuthSession, error) {
	return r.getByID(ctx, authSessionID, "")
}

// GetByIDForUpdate behaves like GetByID but appends FOR UPDATE to the query so
// the row is locked; this is only meaningful when the repository is bound to a
// transaction.
func (r *postgresAuthSessionRepository) GetByIDForUpdate(ctx context.Context, authSessionID string) (*AuthSession, error) {
	return r.getByID(ctx, authSessionID, " FOR UPDATE")
}

// getByID loads one auth session by ID. suffix is appended verbatim to the SQL
// text, so it must only ever be a trusted constant ("" or " FOR UPDATE"),
// never caller-supplied data.
func (r *postgresAuthSessionRepository) getByID(ctx context.Context, authSessionID string, suffix string) (*AuthSession, error) {
	query := `
		SELECT id::text, user_id::text, device_id::text, refresh_token_hash, refresh_expires_at,
			last_seen_at, last_rotated_at, revoked_at, revoked_reason, created_at, updated_at
		FROM auth_sessions
		WHERE id = $1::uuid` + suffix

	session, err := scanAuthSession(r.db.QueryRow(ctx, query, authSessionID))
	// scanAuthSession wraps with %w, so the pgx sentinel is still detectable.
	if errors.Is(err, pgx.ErrNoRows) {
		return nil, nil
	}
	return session, err
}

// Rotate stores a new refresh token hash and expiry on a non-revoked session
// and stamps last_seen_at, last_rotated_at and updated_at. A missing or
// already revoked session updates nothing and still returns nil.
func (r *postgresAuthSessionRepository) Rotate(ctx context.Context, params RotateAuthSessionParams) error {
	const query = `
		UPDATE auth_sessions
		SET refresh_token_hash = $2,
			refresh_expires_at = $3,
			last_seen_at = $4,
			last_rotated_at = $5,
			updated_at = $5
		WHERE id = $1::uuid
			AND revoked_at IS NULL
	`
	if _, err := r.db.Exec(ctx, query,
		params.AuthSessionID,
		params.RefreshTokenHash,
		params.RefreshExpiresAt,
		params.LastSeenAt,
		params.LastRotatedAt,
	); err != nil {
		return fmt.Errorf("rotate auth session: %w", err)
	}
	return nil
}

// Touch sets last_seen_at and updated_at to seenAt on a non-revoked session.
// A missing or revoked session updates nothing and still returns nil.
func (r *postgresAuthSessionRepository) Touch(ctx context.Context, authSessionID string, seenAt time.Time) error {
	const query = `
		UPDATE auth_sessions
		SET last_seen_at = $2,
			updated_at = $2
		WHERE id = $1::uuid
			AND revoked_at IS NULL
	`
	if _, err := r.db.Exec(ctx, query, authSessionID, seenAt); err != nil {
		return fmt.Errorf("touch auth session: %w", err)
	}
	return nil
}

// Revoke marks one of the user's sessions as revoked with the given time and
// reason. Sessions that are already revoked are left untouched, so their
// original revoked_at and reason are preserved.
func (r *postgresAuthSessionRepository) Revoke(ctx context.Context, params RevokeAuthSessionParams) error {
	const query = `
		UPDATE auth_sessions
		SET revoked_at = $3,
			revoked_reason = $4,
			updated_at = $3
		WHERE id = $1::uuid
			AND user_id = $2::uuid
			AND revoked_at IS NULL
	`
	if _, err := r.db.Exec(ctx, query, params.AuthSessionID, params.UserID, params.RevokedAt, params.Reason); err != nil {
		return fmt.Errorf("revoke auth session: %w", err)
	}
	return nil
}

// RevokeByDevice revokes every not-yet-revoked session that the user holds on
// the given device, recording revokedAt and reason on each.
func (r *postgresAuthSessionRepository) RevokeByDevice(ctx context.Context, userID, deviceID, reason string, revokedAt time.Time) error {
	const query = `
		UPDATE auth_sessions
		SET revoked_at = $3,
			revoked_reason = $4,
			updated_at = $3
		WHERE user_id = $1::uuid
			AND device_id = $2::uuid
			AND revoked_at IS NULL
	`
	if _, err := r.db.Exec(ctx, query, userID, deviceID, revokedAt, reason); err != nil {
		return fmt.Errorf("revoke auth sessions by device: %w", err)
	}
	return nil
}

// GetStatus reads the installation_authority row with id = 1. When the row is
// absent it returns a state with Bootstrapped = false and AuthorityState
// "unbootstrapped"; otherwise it returns the stored values with
// Bootstrapped = true.
func (r *postgresInstallationRepository) GetStatus(ctx context.Context) (*InstallationAuthorityState, error) {
	const query = `
		SELECT install_id, authority_state, product_root_key_fingerprint, bootstrapped_owner_email, bootstrapped_at
		FROM installation_authority
		WHERE id = 1
	`
	status := &InstallationAuthorityState{}
	err := r.db.QueryRow(ctx, query).Scan(
		&status.InstallID,
		&status.AuthorityState,
		&status.ProductRootFingerprint,
		&status.BootstrappedOwnerEmail,
		&status.BootstrappedAt,
	)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return &InstallationAuthorityState{
			Bootstrapped:   false,
			AuthorityState: "unbootstrapped",
		}, nil
	case err != nil:
		return nil, fmt.Errorf("get installation status: %w", err)
	}

	status.Bootstrapped = true
	return status, nil
}

// BootstrapOwner performs the one-time installation bootstrap and returns the
// owner user. It returns ErrInstallationAlreadyBootstrapped when the
// installation_authority row (id = 1) already exists.
//
// Steps, in order:
//  1. check (and lock) the installation_authority row;
//  2. upsert the owner in users, keyed by the trimmed, lower-cased email;
//  3. insert the installation_authority row in state 'active';
//  4. revoke any active platform_role_grants for the same user/role/install;
//  5. insert the new platform role grant;
//  6. upsert an active 'org_owner' membership in the organization whose slug
//     is 'default' (nothing is written if that organization does not exist).
//
// NOTE(review): the statements are issued one by one on r.db; they are only
// atomic if this repository is bound to a transaction — confirm callers use
// PostgresTransactor.WithinTransaction.
func (r *postgresInstallationRepository) BootstrapOwner(ctx context.Context, params BootstrapOwnerParams) (*User, error) {
	// NOTE(review): FOR UPDATE locks nothing when the row is absent, so two
	// concurrent bootstraps can both pass this check; the later INSERT with
	// id = 1 presumably fails on a uniqueness constraint — confirm schema.
	var existingInstallID string
	lockErr := r.db.QueryRow(ctx, `
		SELECT install_id
		FROM installation_authority
		WHERE id = 1
		FOR UPDATE
	`).Scan(&existingInstallID)
	switch {
	case lockErr == nil:
		return nil, ErrInstallationAlreadyBootstrapped
	case !errors.Is(lockErr, pgx.ErrNoRows):
		return nil, fmt.Errorf("lock installation authority: %w", lockErr)
	}

	email := strings.ToLower(strings.TrimSpace(params.Email))
	now := params.Now.UTC()

	// An existing user with this email is taken over: password hash and
	// platform role are replaced by the bootstrap values.
	user, err := scanOptionalUser(r.db.QueryRow(ctx, `
		INSERT INTO users (email, password_hash, mfa_enabled, platform_role, created_at, updated_at)
		VALUES ($1, $2, FALSE, $3, $4, $4)
		ON CONFLICT (email) DO UPDATE
		SET password_hash = EXCLUDED.password_hash,
			platform_role = EXCLUDED.platform_role,
			updated_at = EXCLUDED.updated_at
		RETURNING id::text, email, password_hash, mfa_enabled, created_at, updated_at
	`, email, params.PasswordHash, params.Role, now))
	if err != nil {
		return nil, fmt.Errorf("upsert bootstrap owner: %w", err)
	}
	if user == nil {
		return nil, errors.New("upsert bootstrap owner returned no user")
	}

	// Fall back to an empty JSON object so the jsonb casts below always
	// receive valid JSON.
	payload := json.RawMessage(`{}`)
	if len(params.ActivationPayload) > 0 {
		payload = params.ActivationPayload
	}

	if _, err := r.db.Exec(ctx, `
		INSERT INTO installation_authority (
			id,
			install_id,
			authority_state,
			product_root_key_fingerprint,
			activation_payload,
			activation_signature,
			bootstrapped_owner_email,
			bootstrapped_at,
			created_at,
			updated_at
		)
		VALUES (
			1,
			$1,
			'active',
			$2,
			$3::jsonb,
			$4,
			$5,
			$6,
			$6,
			$6
		)
	`, params.InstallID, params.ProductRootKeyFingerprint, []byte(payload), params.ActivationSignature, email, now); err != nil {
		return nil, fmt.Errorf("insert installation authority: %w", err)
	}

	if _, err := r.db.Exec(ctx, `
		UPDATE platform_role_grants
		SET revoked_at = $4
		WHERE user_id = $1::uuid
			AND role = $2
			AND install_id = $3
			AND revoked_at IS NULL
	`, user.ID, params.Role, params.InstallID, now); err != nil {
		return nil, fmt.Errorf("revoke superseded platform role grants: %w", err)
	}

	if _, err := r.db.Exec(ctx, `
		INSERT INTO platform_role_grants (
			user_id,
			role,
			install_id,
			grant_payload,
			grant_signature,
			grant_source,
			granted_at,
			expires_at,
			metadata
		)
		VALUES (
			$1::uuid,
			$2,
			$3,
			$4::jsonb,
			$5,
			$6,
			$7,
			$8,
			'{"bootstrap_owner":true}'::jsonb
		)
	`, user.ID, params.Role, params.InstallID, []byte(payload), params.ActivationSignature, params.GrantSource, now, params.ExpiresAt); err != nil {
		return nil, fmt.Errorf("insert platform role grant: %w", err)
	}

	if _, err := r.db.Exec(ctx, `
		INSERT INTO organization_memberships (
			organization_id,
			user_id,
			role_id,
			status,
			invited_by_user_id,
			created_at,
			updated_at
		)
		SELECT id, $1::uuid, 'org_owner', 'active', $1::uuid, $2, $2
		FROM organizations
		WHERE slug = 'default'
		ON CONFLICT (organization_id, user_id) DO UPDATE
		SET role_id = 'org_owner',
			status = 'active',
			invited_by_user_id = EXCLUDED.invited_by_user_id,
			updated_at = EXCLUDED.updated_at
	`, user.ID, now); err != nil {
		return nil, fmt.Errorf("upsert default organization owner membership: %w", err)
	}

	return user, nil
}

// scanner is the common Scan method of the single-row and multi-row query
// results used in this file, letting the scan helpers accept either.
type scanner interface {
	Scan(dest ...any) error
}

// scanOptionalUser scans one users row in the column order
// (id, email, password_hash, mfa_enabled, created_at, updated_at).
// It returns (nil, nil) when the row is absent (pgx.ErrNoRows).
func scanOptionalUser(row scanner) (*User, error) {
	user := &User{}
	err := row.Scan(
		&user.ID,
		&user.Email,
		&user.PasswordHash,
		&user.MFAEnabled,
		&user.CreatedAt,
		&user.UpdatedAt,
	)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("scan user: %w", err)
	}
	return user, nil
}

// scanDevice scans one devices row in the column order used by every device
// query in this file. Unlike scanOptionalUser it does not translate
// pgx.ErrNoRows; the error is wrapped with %w so callers can test for it.
func scanDevice(row scanner) (*Device, error) {
	device := &Device{}

	// Nullable columns are scanned through local pointers and copied over
	// only after the whole row has scanned successfully.
	var (
		trustedAt, lastSeenAt, revokedAt *time.Time
		revokedReason                    *string
	)
	if err := row.Scan(
		&device.ID,
		&device.UserID,
		&device.Fingerprint,
		&device.Label,
		&device.TrustStatus,
		&trustedAt,
		&lastSeenAt,
		&revokedAt,
		&revokedReason,
		&device.CreatedAt,
		&device.UpdatedAt,
	); err != nil {
		return nil, fmt.Errorf("scan device: %w", err)
	}

	device.TrustedAt = trustedAt
	device.LastSeenAt = lastSeenAt
	device.RevokedAt = revokedAt
	device.RevokedReason = revokedReason
	return device, nil
}

// scanAuthSession scans one auth_sessions row in the column order used by
// getByID. pgx.ErrNoRows is not translated; the error is wrapped with %w so
// callers can test for it.
func scanAuthSession(row scanner) (*AuthSession, error) {
	session := &AuthSession{}

	// Nullable columns are scanned through local pointers and copied over
	// only after the whole row has scanned successfully.
	var (
		lastSeenAt, lastRotatedAt, revokedAt *time.Time
		revokedReason                        *string
	)
	if err := row.Scan(
		&session.ID,
		&session.UserID,
		&session.DeviceID,
		&session.RefreshTokenHash,
		&session.RefreshExpiresAt,
		&lastSeenAt,
		&lastRotatedAt,
		&revokedAt,
		&revokedReason,
		&session.CreatedAt,
		&session.UpdatedAt,
	); err != nil {
		return nil, fmt.Errorf("scan auth session: %w", err)
	}

	session.LastSeenAt = lastSeenAt
	session.LastRotatedAt = lastRotatedAt
	session.RevokedAt = revokedAt
	session.RevokedReason = revokedReason
	return session, nil
}