567 lines
15 KiB
Go
567 lines
15 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
|
|
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
|
|
)
|
|
|
|
type postgresStore struct {
|
|
db postgresplatform.DBTX
|
|
}
|
|
|
|
type PostgresTransactor struct {
|
|
pool *pgxpool.Pool
|
|
}
|
|
|
|
func NewPostgresStore(pool *pgxpool.Pool) Store {
|
|
return &postgresStore{db: pool}
|
|
}
|
|
|
|
func NewPostgresTransactor(pool *pgxpool.Pool) *PostgresTransactor {
|
|
return &PostgresTransactor{pool: pool}
|
|
}
|
|
|
|
func (t *PostgresTransactor) WithinTransaction(ctx context.Context, fn func(store Store) error) error {
|
|
return postgresplatform.WithTransaction(ctx, t.pool, func(tx pgx.Tx) error {
|
|
return fn(&postgresStore{db: tx})
|
|
})
|
|
}
|
|
|
|
func (s *postgresStore) Users() UserRepository {
|
|
return &postgresUserRepository{db: s.db}
|
|
}
|
|
|
|
func (s *postgresStore) Devices() DeviceRepository {
|
|
return &postgresDeviceRepository{db: s.db}
|
|
}
|
|
|
|
func (s *postgresStore) AuthSessions() AuthSessionRepository {
|
|
return &postgresAuthSessionRepository{db: s.db}
|
|
}
|
|
|
|
func (s *postgresStore) Installation() InstallationRepository {
|
|
return &postgresInstallationRepository{db: s.db}
|
|
}
|
|
|
|
type postgresUserRepository struct {
|
|
db postgresplatform.DBTX
|
|
}
|
|
|
|
type postgresDeviceRepository struct {
|
|
db postgresplatform.DBTX
|
|
}
|
|
|
|
type postgresAuthSessionRepository struct {
|
|
db postgresplatform.DBTX
|
|
}
|
|
|
|
type postgresInstallationRepository struct {
|
|
db postgresplatform.DBTX
|
|
}
|
|
|
|
func (r *postgresUserRepository) GetByEmail(ctx context.Context, email string) (*User, error) {
|
|
const query = `
|
|
SELECT id::text, email, password_hash, mfa_enabled, platform_role, created_at, updated_at
|
|
FROM users
|
|
WHERE email = $1
|
|
`
|
|
return scanOptionalUser(r.db.QueryRow(ctx, query, email))
|
|
}
|
|
|
|
func (r *postgresUserRepository) GetByID(ctx context.Context, userID string) (*User, error) {
|
|
const query = `
|
|
SELECT id::text, email, password_hash, mfa_enabled, platform_role, created_at, updated_at
|
|
FROM users
|
|
WHERE id = $1::uuid
|
|
`
|
|
return scanOptionalUser(r.db.QueryRow(ctx, query, userID))
|
|
}
|
|
|
|
func (r *postgresUserRepository) List(ctx context.Context) ([]User, error) {
|
|
const query = `
|
|
SELECT id::text, email, password_hash, mfa_enabled, platform_role, created_at, updated_at
|
|
FROM users
|
|
ORDER BY created_at DESC
|
|
`
|
|
rows, err := r.db.Query(ctx, query)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query users: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
var users []User
|
|
for rows.Next() {
|
|
user, err := scanOptionalUser(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if user != nil {
|
|
users = append(users, *user)
|
|
}
|
|
}
|
|
return users, rows.Err()
|
|
}
|
|
|
|
func (r *postgresUserRepository) Create(ctx context.Context, user User) (*User, error) {
|
|
const query = `
|
|
INSERT INTO users (email, password_hash, mfa_enabled, platform_role, created_at, updated_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6)
|
|
RETURNING id::text, email, password_hash, mfa_enabled, platform_role, created_at, updated_at
|
|
`
|
|
return scanOptionalUser(r.db.QueryRow(ctx, query,
|
|
user.Email,
|
|
user.PasswordHash,
|
|
user.MFAEnabled,
|
|
user.PlatformRole,
|
|
user.CreatedAt,
|
|
user.UpdatedAt,
|
|
))
|
|
}
|
|
|
|
func (r *postgresDeviceRepository) Upsert(ctx context.Context, params UpsertDeviceParams) (*Device, error) {
|
|
const query = `
|
|
INSERT INTO devices (
|
|
user_id,
|
|
device_fingerprint,
|
|
device_label,
|
|
trust_status,
|
|
trusted_at,
|
|
last_seen_at,
|
|
created_at,
|
|
updated_at
|
|
) VALUES (
|
|
$1::uuid,
|
|
$2,
|
|
$3,
|
|
CASE WHEN $4 THEN 'trusted' ELSE 'pending' END,
|
|
CASE WHEN $4 THEN $5::timestamptz ELSE NULL::timestamptz END,
|
|
$5::timestamptz,
|
|
$5::timestamptz,
|
|
$5::timestamptz
|
|
)
|
|
ON CONFLICT (user_id, device_fingerprint) DO UPDATE SET
|
|
device_label = EXCLUDED.device_label,
|
|
last_seen_at = EXCLUDED.last_seen_at,
|
|
updated_at = EXCLUDED.updated_at,
|
|
trust_status = CASE
|
|
WHEN devices.trust_status = 'revoked' THEN devices.trust_status
|
|
WHEN devices.trust_status = 'trusted' THEN devices.trust_status
|
|
WHEN EXCLUDED.trust_status = 'trusted' THEN 'trusted'
|
|
ELSE devices.trust_status
|
|
END,
|
|
trusted_at = CASE
|
|
WHEN devices.trust_status = 'trusted' THEN devices.trusted_at
|
|
WHEN EXCLUDED.trust_status = 'trusted' THEN EXCLUDED.trusted_at
|
|
ELSE devices.trusted_at
|
|
END
|
|
RETURNING
|
|
id::text, user_id::text, device_fingerprint, COALESCE(device_label, ''),
|
|
trust_status, trusted_at, last_seen_at, revoked_at, revoked_reason, created_at, updated_at
|
|
`
|
|
return scanDevice(r.db.QueryRow(ctx, query,
|
|
params.UserID,
|
|
params.Fingerprint,
|
|
params.Label,
|
|
params.TrustRequested,
|
|
params.SeenAt,
|
|
))
|
|
}
|
|
|
|
func (r *postgresDeviceRepository) GetByIDForUser(ctx context.Context, userID, deviceID string) (*Device, error) {
|
|
const query = `
|
|
SELECT id::text, user_id::text, device_fingerprint, COALESCE(device_label, ''),
|
|
trust_status, trusted_at, last_seen_at, revoked_at, revoked_reason, created_at, updated_at
|
|
FROM devices
|
|
WHERE id = $1::uuid AND user_id = $2::uuid
|
|
`
|
|
device, err := scanDevice(r.db.QueryRow(ctx, query, deviceID, userID))
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
return device, err
|
|
}
|
|
|
|
func (r *postgresDeviceRepository) ListTrustedByUser(ctx context.Context, userID string) ([]Device, error) {
|
|
const query = `
|
|
SELECT id::text, user_id::text, device_fingerprint, COALESCE(device_label, ''),
|
|
trust_status, trusted_at, last_seen_at, revoked_at, revoked_reason, created_at, updated_at
|
|
FROM devices
|
|
WHERE user_id = $1::uuid AND trust_status = 'trusted' AND revoked_at IS NULL
|
|
ORDER BY created_at DESC
|
|
`
|
|
rows, err := r.db.Query(ctx, query, userID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query trusted devices: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var devices []Device
|
|
for rows.Next() {
|
|
device, err := scanDevice(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
devices = append(devices, *device)
|
|
}
|
|
return devices, rows.Err()
|
|
}
|
|
|
|
func (r *postgresDeviceRepository) Revoke(ctx context.Context, params RevokeDeviceParams) error {
|
|
const query = `
|
|
UPDATE devices
|
|
SET trust_status = 'revoked',
|
|
revoked_at = $3,
|
|
revoked_reason = $4,
|
|
updated_at = $3
|
|
WHERE id = $1::uuid AND user_id = $2::uuid
|
|
`
|
|
if _, err := r.db.Exec(ctx, query, params.DeviceID, params.UserID, params.RevokedAt, params.Reason); err != nil {
|
|
return fmt.Errorf("revoke device: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *postgresAuthSessionRepository) Create(ctx context.Context, session AuthSession) error {
|
|
const query = `
|
|
INSERT INTO auth_sessions (
|
|
id,
|
|
user_id,
|
|
device_id,
|
|
refresh_token_hash,
|
|
refresh_expires_at,
|
|
last_seen_at,
|
|
created_at,
|
|
updated_at
|
|
) VALUES ($1::uuid, $2::uuid, $3::uuid, $4, $5, $6, $7, $8)
|
|
`
|
|
if _, err := r.db.Exec(ctx, query,
|
|
session.ID,
|
|
session.UserID,
|
|
session.DeviceID,
|
|
session.RefreshTokenHash,
|
|
session.RefreshExpiresAt,
|
|
session.LastSeenAt,
|
|
session.CreatedAt,
|
|
session.UpdatedAt,
|
|
); err != nil {
|
|
return fmt.Errorf("create auth session: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *postgresAuthSessionRepository) GetByID(ctx context.Context, authSessionID string) (*AuthSession, error) {
|
|
return r.getByID(ctx, authSessionID, "")
|
|
}
|
|
|
|
func (r *postgresAuthSessionRepository) GetByIDForUpdate(ctx context.Context, authSessionID string) (*AuthSession, error) {
|
|
return r.getByID(ctx, authSessionID, " FOR UPDATE")
|
|
}
|
|
|
|
func (r *postgresAuthSessionRepository) getByID(ctx context.Context, authSessionID string, suffix string) (*AuthSession, error) {
|
|
query := `
|
|
SELECT id::text, user_id::text, device_id::text, refresh_token_hash, refresh_expires_at,
|
|
last_seen_at, last_rotated_at, revoked_at, revoked_reason, created_at, updated_at
|
|
FROM auth_sessions
|
|
WHERE id = $1::uuid` + suffix
|
|
session, err := scanAuthSession(r.db.QueryRow(ctx, query, authSessionID))
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
return session, err
|
|
}
|
|
|
|
func (r *postgresAuthSessionRepository) Rotate(ctx context.Context, params RotateAuthSessionParams) error {
|
|
const query = `
|
|
UPDATE auth_sessions
|
|
SET refresh_token_hash = $2,
|
|
refresh_expires_at = $3,
|
|
last_seen_at = $4,
|
|
last_rotated_at = $5,
|
|
updated_at = $5
|
|
WHERE id = $1::uuid AND revoked_at IS NULL
|
|
`
|
|
if _, err := r.db.Exec(ctx, query,
|
|
params.AuthSessionID,
|
|
params.RefreshTokenHash,
|
|
params.RefreshExpiresAt,
|
|
params.LastSeenAt,
|
|
params.LastRotatedAt,
|
|
); err != nil {
|
|
return fmt.Errorf("rotate auth session: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *postgresAuthSessionRepository) Touch(ctx context.Context, authSessionID string, seenAt time.Time) error {
|
|
const query = `
|
|
UPDATE auth_sessions
|
|
SET last_seen_at = $2, updated_at = $2
|
|
WHERE id = $1::uuid AND revoked_at IS NULL
|
|
`
|
|
if _, err := r.db.Exec(ctx, query, authSessionID, seenAt); err != nil {
|
|
return fmt.Errorf("touch auth session: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *postgresAuthSessionRepository) Revoke(ctx context.Context, params RevokeAuthSessionParams) error {
|
|
const query = `
|
|
UPDATE auth_sessions
|
|
SET revoked_at = $3,
|
|
revoked_reason = $4,
|
|
updated_at = $3
|
|
WHERE id = $1::uuid AND user_id = $2::uuid AND revoked_at IS NULL
|
|
`
|
|
if _, err := r.db.Exec(ctx, query, params.AuthSessionID, params.UserID, params.RevokedAt, params.Reason); err != nil {
|
|
return fmt.Errorf("revoke auth session: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *postgresAuthSessionRepository) RevokeByDevice(ctx context.Context, userID, deviceID, reason string, revokedAt time.Time) error {
|
|
const query = `
|
|
UPDATE auth_sessions
|
|
SET revoked_at = $3,
|
|
revoked_reason = $4,
|
|
updated_at = $3
|
|
WHERE user_id = $1::uuid AND device_id = $2::uuid AND revoked_at IS NULL
|
|
`
|
|
if _, err := r.db.Exec(ctx, query, userID, deviceID, revokedAt, reason); err != nil {
|
|
return fmt.Errorf("revoke auth sessions by device: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *postgresInstallationRepository) GetStatus(ctx context.Context) (*InstallationAuthorityState, error) {
|
|
const query = `
|
|
SELECT install_id, authority_state, product_root_key_fingerprint, bootstrapped_owner_email, bootstrapped_at
|
|
FROM installation_authority
|
|
WHERE id = 1
|
|
`
|
|
status := &InstallationAuthorityState{}
|
|
if err := r.db.QueryRow(ctx, query).Scan(
|
|
&status.InstallID,
|
|
&status.AuthorityState,
|
|
&status.ProductRootFingerprint,
|
|
&status.BootstrappedOwnerEmail,
|
|
&status.BootstrappedAt,
|
|
); err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return &InstallationAuthorityState{
|
|
Bootstrapped: false,
|
|
AuthorityState: "unbootstrapped",
|
|
}, nil
|
|
}
|
|
return nil, fmt.Errorf("get installation status: %w", err)
|
|
}
|
|
status.Bootstrapped = true
|
|
return status, nil
|
|
}
|
|
|
|
func (r *postgresInstallationRepository) BootstrapOwner(ctx context.Context, params BootstrapOwnerParams) (*User, error) {
|
|
var existingInstallID string
|
|
if err := r.db.QueryRow(ctx, `
|
|
SELECT install_id
|
|
FROM installation_authority
|
|
WHERE id = 1
|
|
FOR UPDATE
|
|
`).Scan(&existingInstallID); err != nil && !errors.Is(err, pgx.ErrNoRows) {
|
|
return nil, fmt.Errorf("lock installation authority: %w", err)
|
|
} else if err == nil {
|
|
return nil, ErrInstallationAlreadyBootstrapped
|
|
}
|
|
|
|
email := strings.ToLower(strings.TrimSpace(params.Email))
|
|
now := params.Now.UTC()
|
|
user, err := scanOptionalUser(r.db.QueryRow(ctx, `
|
|
INSERT INTO users (email, password_hash, mfa_enabled, platform_role, created_at, updated_at)
|
|
VALUES ($1, $2, FALSE, $3, $4, $4)
|
|
ON CONFLICT (email) DO UPDATE SET
|
|
password_hash = EXCLUDED.password_hash,
|
|
platform_role = EXCLUDED.platform_role,
|
|
updated_at = EXCLUDED.updated_at
|
|
RETURNING id::text, email, password_hash, mfa_enabled, platform_role, created_at, updated_at
|
|
`, email, params.PasswordHash, params.Role, now))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("upsert bootstrap owner: %w", err)
|
|
}
|
|
if user == nil {
|
|
return nil, fmt.Errorf("upsert bootstrap owner returned no user")
|
|
}
|
|
|
|
payload := json.RawMessage(`{}`)
|
|
if len(params.ActivationPayload) > 0 {
|
|
payload = params.ActivationPayload
|
|
}
|
|
if _, err := r.db.Exec(ctx, `
|
|
INSERT INTO installation_authority (
|
|
id,
|
|
install_id,
|
|
authority_state,
|
|
product_root_key_fingerprint,
|
|
activation_payload,
|
|
activation_signature,
|
|
bootstrapped_owner_email,
|
|
bootstrapped_at,
|
|
created_at,
|
|
updated_at
|
|
) VALUES (
|
|
1,
|
|
$1,
|
|
'active',
|
|
$2,
|
|
$3::jsonb,
|
|
$4,
|
|
$5,
|
|
$6,
|
|
$6,
|
|
$6
|
|
)
|
|
`, params.InstallID, params.ProductRootKeyFingerprint, []byte(payload), params.ActivationSignature, email, now); err != nil {
|
|
return nil, fmt.Errorf("insert installation authority: %w", err)
|
|
}
|
|
|
|
if _, err := r.db.Exec(ctx, `
|
|
UPDATE platform_role_grants
|
|
SET revoked_at = $4
|
|
WHERE user_id = $1::uuid
|
|
AND role = $2
|
|
AND install_id = $3
|
|
AND revoked_at IS NULL
|
|
`, user.ID, params.Role, params.InstallID, now); err != nil {
|
|
return nil, fmt.Errorf("revoke superseded platform role grants: %w", err)
|
|
}
|
|
if _, err := r.db.Exec(ctx, `
|
|
INSERT INTO platform_role_grants (
|
|
user_id,
|
|
role,
|
|
install_id,
|
|
grant_payload,
|
|
grant_signature,
|
|
grant_source,
|
|
granted_at,
|
|
expires_at,
|
|
metadata
|
|
) VALUES (
|
|
$1::uuid,
|
|
$2,
|
|
$3,
|
|
$4::jsonb,
|
|
$5,
|
|
$6,
|
|
$7,
|
|
$8,
|
|
'{"bootstrap_owner":true}'::jsonb
|
|
)
|
|
`, user.ID, params.Role, params.InstallID, []byte(payload), params.ActivationSignature, params.GrantSource, now, params.ExpiresAt); err != nil {
|
|
return nil, fmt.Errorf("insert platform role grant: %w", err)
|
|
}
|
|
|
|
if _, err := r.db.Exec(ctx, `
|
|
INSERT INTO organization_memberships (
|
|
organization_id,
|
|
user_id,
|
|
role_id,
|
|
status,
|
|
invited_by_user_id,
|
|
created_at,
|
|
updated_at
|
|
)
|
|
SELECT id, $1::uuid, 'org_owner', 'active', $1::uuid, $2, $2
|
|
FROM organizations
|
|
WHERE slug = 'default'
|
|
ON CONFLICT (organization_id, user_id) DO UPDATE SET
|
|
role_id = 'org_owner',
|
|
status = 'active',
|
|
invited_by_user_id = EXCLUDED.invited_by_user_id,
|
|
updated_at = EXCLUDED.updated_at
|
|
`, user.ID, now); err != nil {
|
|
return nil, fmt.Errorf("upsert default organization owner membership: %w", err)
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// scanner is the one method the scan helpers need from a query result. It is
// satisfied both by the single row returned from QueryRow and by the row
// cursor returned from Query. That lets one helper decode a single-row lookup
// or each row of a result set.
type scanner interface {
	Scan(targets ...any) error
}
|
|
|
|
func scanOptionalUser(row scanner) (*User, error) {
|
|
user := &User{}
|
|
if err := row.Scan(
|
|
&user.ID,
|
|
&user.Email,
|
|
&user.PasswordHash,
|
|
&user.MFAEnabled,
|
|
&user.PlatformRole,
|
|
&user.CreatedAt,
|
|
&user.UpdatedAt,
|
|
); err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("scan user: %w", err)
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
func scanDevice(row scanner) (*Device, error) {
|
|
device := &Device{}
|
|
var trustedAt, lastSeenAt, revokedAt *time.Time
|
|
var revokedReason *string
|
|
if err := row.Scan(
|
|
&device.ID,
|
|
&device.UserID,
|
|
&device.Fingerprint,
|
|
&device.Label,
|
|
&device.TrustStatus,
|
|
&trustedAt,
|
|
&lastSeenAt,
|
|
&revokedAt,
|
|
&revokedReason,
|
|
&device.CreatedAt,
|
|
&device.UpdatedAt,
|
|
); err != nil {
|
|
return nil, fmt.Errorf("scan device: %w", err)
|
|
}
|
|
device.TrustedAt = trustedAt
|
|
device.LastSeenAt = lastSeenAt
|
|
device.RevokedAt = revokedAt
|
|
device.RevokedReason = revokedReason
|
|
return device, nil
|
|
}
|
|
|
|
func scanAuthSession(row scanner) (*AuthSession, error) {
|
|
session := &AuthSession{}
|
|
var lastSeenAt, lastRotatedAt, revokedAt *time.Time
|
|
var revokedReason *string
|
|
if err := row.Scan(
|
|
&session.ID,
|
|
&session.UserID,
|
|
&session.DeviceID,
|
|
&session.RefreshTokenHash,
|
|
&session.RefreshExpiresAt,
|
|
&lastSeenAt,
|
|
&lastRotatedAt,
|
|
&revokedAt,
|
|
&revokedReason,
|
|
&session.CreatedAt,
|
|
&session.UpdatedAt,
|
|
); err != nil {
|
|
return nil, fmt.Errorf("scan auth session: %w", err)
|
|
}
|
|
session.LastSeenAt = lastSeenAt
|
|
session.LastRotatedAt = lastRotatedAt
|
|
session.RevokedAt = revokedAt
|
|
session.RevokedReason = revokedReason
|
|
return session, nil
|
|
}
|