package sessionbroker

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"

	"github.com/example/remote-access-platform/backend/internal/platform/authority"
	postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
)

// remoteSessionSelectColumns is the remote_sessions projection consumed by
// scanRemoteSession. The column order must match that function's Scan
// destinations exactly; every remote-session SELECT in this file uses it.
const remoteSessionSelectColumns = `
	id::text, organization_id::text, resource_id::text, protocol, state,
	COALESCE(worker_id, ''), controller_user_id::text, detach_deadline_at,
	last_heartbeat_at, takeover_version, metadata, created_at, updated_at`

// sessionAttachmentSelectColumns is the session_attachments projection
// consumed by scanSessionAttachment. The column order must match that
// function's Scan destinations exactly.
const sessionAttachmentSelectColumns = `
	id::text, remote_session_id::text, user_id::text, device_id::text, role,
	state, superseded_by::text, takeover_of::text, attached_at, detached_at,
	last_input_at, metadata, created_at, updated_at`

// sessionAttachmentListOrder gives attachment listings a deterministic order.
// It matters most for the FOR UPDATE variant: PostgreSQL locks rows as the
// sorted result is produced, so a fixed order keeps two transactions that lock
// the same session's attachments from acquiring row locks in opposite orders
// (a classic deadlock source).
const sessionAttachmentListOrder = ` ORDER BY created_at ASC, id ASC`

// postgresStore implements Store on top of either a connection pool or an open
// transaction; every repository it hands out shares the same db handle.
type postgresStore struct {
	db        postgresplatform.DBTX
	authority *authority.Verifier
}

// PostgresTransactor opens transactions on a pool and exposes each one to
// callers as a Store.
type PostgresTransactor struct {
	pool      *pgxpool.Pool
	authority *authority.Verifier
}

// NewPostgresStore returns a Store whose repositories run directly against
// pool (outside any explicit transaction). verifiers is optional: only the
// first element is used — it is handed to the Access repository — and any
// further elements are ignored.
func NewPostgresStore(pool *pgxpool.Pool, verifiers ...*authority.Verifier) Store {
	var authorityVerifier *authority.Verifier
	if len(verifiers) > 0 {
		authorityVerifier = verifiers[0]
	}
	return &postgresStore{db: pool, authority: authorityVerifier}
}

// NewPostgresTransactor returns a transactor that opens transactions on pool.
// As with NewPostgresStore, only the first of verifiers (if any) is used.
func NewPostgresTransactor(pool *pgxpool.Pool, verifiers ...*authority.Verifier) *PostgresTransactor {
	var authorityVerifier *authority.Verifier
	if len(verifiers) > 0 {
		authorityVerifier = verifiers[0]
	}
	return &PostgresTransactor{pool: pool, authority: authorityVerifier}
}

// WithinTransaction runs fn with a Store whose repositories all share one
// pgx transaction. Beginning and finishing the transaction is delegated to
// postgresplatform.WithTransaction, and whatever that helper returns is passed
// through to the caller.
func (t *PostgresTransactor) WithinTransaction(ctx context.Context, fn func(store Store) error) error {
	return postgresplatform.WithTransaction(ctx, t.pool, func(tx pgx.Tx) error {
		return fn(&postgresStore{db: tx, authority: t.authority})
	})
}

// RemoteSessions returns the remote_sessions repository bound to s.db.
func (s *postgresStore) RemoteSessions() RemoteSessionRepository {
	return &postgresRemoteSessionRepository{db: s.db}
}

// SessionAttachments returns the session_attachments repository bound to s.db.
func (s *postgresStore) SessionAttachments() SessionAttachmentRepository {
	return &postgresSessionAttachmentRepository{db: s.db}
}

// ResourcePolicies returns the resource_policies repository bound to s.db.
func (s *postgresStore) ResourcePolicies() ResourcePolicyRepository {
	return &postgresResourcePolicyRepository{db: s.db}
}

// ResourceRuntime returns the read-only resources repository bound to s.db.
func (s *postgresStore) ResourceRuntime() ResourceRuntimeRepository {
	return &postgresResourceRuntimeRepository{db: s.db}
}

// AuditEvents returns the audit_events repository bound to s.db.
func (s *postgresStore) AuditEvents() AuditEventRepository {
	return &postgresAuditEventRepository{db: s.db}
}

// Access returns the access-check repository bound to s.db and s.authority.
func (s *postgresStore) Access() AccessRepository {
	return &postgresAccessRepository{db: s.db, authority: s.authority}
}

type postgresRemoteSessionRepository struct {
	db postgresplatform.DBTX
}

type postgresSessionAttachmentRepository struct {
	db postgresplatform.DBTX
}

type postgresResourcePolicyRepository struct {
	db postgresplatform.DBTX
}

type postgresResourceRuntimeRepository struct {
	db postgresplatform.DBTX
}

type postgresAuditEventRepository struct {
	db postgresplatform.DBTX
}

type postgresAccessRepository struct {
	db        postgresplatform.DBTX
	authority *authority.Verifier
}

// Create inserts session. An empty WorkerID is stored as NULL and Metadata is
// normalised by jsonPayload before the ::jsonb cast.
func (r *postgresRemoteSessionRepository) Create(ctx context.Context, session RemoteSession) error {
	const query = `
		INSERT INTO remote_sessions (
			id, organization_id, resource_id, protocol, state, worker_id,
			controller_user_id, detach_deadline_at, last_heartbeat_at,
			takeover_version, metadata, created_at, updated_at
		) VALUES (
			$1::uuid, $2::uuid, $3::uuid, $4, $5, NULLIF($6, ''), $7::uuid,
			$8, $9, $10, $11::jsonb, $12, $13
		)
	`
	if _, err := r.db.Exec(ctx, query,
		session.ID,
		session.OrganizationID,
		session.ResourceID,
		session.Protocol,
		session.State,
		session.WorkerID,
		session.ControllerUserID,
		session.DetachDeadlineAt,
		session.LastHeartbeatAt,
		session.TakeoverVersion,
		jsonPayload(session.Metadata),
		session.CreatedAt,
		session.UpdatedAt,
	); err != nil {
		return fmt.Errorf("create remote session: %w", err)
	}
	return nil
}

// GetByID loads a remote session; it returns (nil, nil) when no row matches.
func (r *postgresRemoteSessionRepository) GetByID(ctx context.Context, sessionID string) (*RemoteSession, error) {
	return r.getByID(ctx, sessionID, "")
}

// GetByIDForUpdate is GetByID with a FOR UPDATE row lock; it is only
// meaningful when the repository is bound to a transaction.
func (r *postgresRemoteSessionRepository) GetByIDForUpdate(ctx context.Context, sessionID string) (*RemoteSession, error) {
	return r.getByID(ctx, sessionID, " FOR UPDATE")
}

// getByID selects one remote session by id, appending suffix (e.g. a locking
// clause) to the query. A missing row yields (nil, nil).
func (r *postgresRemoteSessionRepository) getByID(ctx context.Context, sessionID string, suffix string) (*RemoteSession, error) {
	query := `SELECT` + remoteSessionSelectColumns + `
		FROM remote_sessions
		WHERE id = $1::uuid` + suffix
	remoteSession, err := scanRemoteSession(r.db.QueryRow(ctx, query, sessionID))
	// scanRemoteSession wraps with %w, so pgx.ErrNoRows is still detectable.
	if errors.Is(err, pgx.ErrNoRows) {
		return nil, nil
	}
	return remoteSession, err
}

// ListByController returns every session controlled by userID, most recently
// updated first.
func (r *postgresRemoteSessionRepository) ListByController(ctx context.Context, userID string) ([]RemoteSession, error) {
	const query = `SELECT` + remoteSessionSelectColumns + `
		FROM remote_sessions
		WHERE controller_user_id = $1::uuid
		ORDER BY updated_at DESC`
	rows, err := r.db.Query(ctx, query, userID)
	if err != nil {
		return nil, fmt.Errorf("list remote sessions: %w", err)
	}
	return collectScannedRows(rows, "list remote sessions", scanRemoteSession)
}

// CountLiveByResource counts sessions on resourceID whose state is one of
// starting, active, detached or reconnecting.
func (r *postgresRemoteSessionRepository) CountLiveByResource(ctx context.Context, resourceID string) (int, error) {
	const query = `
		SELECT COUNT(*)
		FROM remote_sessions
		WHERE resource_id = $1::uuid
		  AND state IN ('starting', 'active', 'detached', 'reconnecting')
	`
	var count int
	if err := r.db.QueryRow(ctx, query, resourceID).Scan(&count); err != nil {
		return 0, fmt.Errorf("count live remote sessions: %w", err)
	}
	return count, nil
}

// ListDetachedExpired returns up to limit detached sessions whose detach
// deadline is at or before before, earliest deadline first.
func (r *postgresRemoteSessionRepository) ListDetachedExpired(ctx context.Context, before time.Time, limit int) ([]RemoteSession, error) {
	const query = `SELECT` + remoteSessionSelectColumns + `
		FROM remote_sessions
		WHERE state = 'detached'
		  AND detach_deadline_at IS NOT NULL
		  AND detach_deadline_at <= $1
		ORDER BY detach_deadline_at ASC
		LIMIT $2`
	rows, err := r.db.Query(ctx, query, before, limit)
	if err != nil {
		return nil, fmt.Errorf("list detached expired sessions: %w", err)
	}
	return collectScannedRows(rows, "list detached expired sessions", scanRemoteSession)
}

// UpdateState overwrites the mutable lifecycle columns of one session. An
// empty WorkerID is stored as NULL. Matching zero rows is not reported.
func (r *postgresRemoteSessionRepository) UpdateState(ctx context.Context, params UpdateRemoteSessionStateParams) error {
	const query = `
		UPDATE remote_sessions
		SET state = $2,
		    worker_id = NULLIF($3, ''),
		    detach_deadline_at = $4,
		    last_heartbeat_at = $5,
		    takeover_version = $6,
		    updated_at = $7
		WHERE id = $1::uuid
	`
	if _, err := r.db.Exec(ctx, query,
		params.RemoteSessionID,
		params.State,
		params.WorkerID,
		params.DetachDeadlineAt,
		params.LastHeartbeatAt,
		params.TakeoverVersion,
		params.UpdatedAt,
	); err != nil {
		return fmt.Errorf("update remote session state: %w", err)
	}
	return nil
}

// Create inserts attachment. Nil SupersededBy/TakeoverOf become NULL via
// stringValue + NULLIF.
func (r *postgresSessionAttachmentRepository) Create(ctx context.Context, attachment SessionAttachment) error {
	const query = `
		INSERT INTO session_attachments (
			id, remote_session_id, user_id, device_id, role, state,
			superseded_by, takeover_of, attached_at, detached_at,
			last_input_at, metadata, created_at, updated_at
		) VALUES (
			$1::uuid, $2::uuid, $3::uuid, $4::uuid, $5, $6,
			NULLIF($7, '')::uuid, NULLIF($8, '')::uuid, $9, $10, $11,
			$12::jsonb, $13, $14
		)
	`
	if _, err := r.db.Exec(ctx, query,
		attachment.ID,
		attachment.RemoteSessionID,
		attachment.UserID,
		attachment.DeviceID,
		attachment.Role,
		attachment.State,
		stringValue(attachment.SupersededBy),
		stringValue(attachment.TakeoverOf),
		attachment.AttachedAt,
		attachment.DetachedAt,
		attachment.LastInputAt,
		jsonPayload(attachment.Metadata),
		attachment.CreatedAt,
		attachment.UpdatedAt,
	); err != nil {
		return fmt.Errorf("create session attachment: %w", err)
	}
	return nil
}

// GetByID loads an attachment; it returns (nil, nil) when no row matches.
func (r *postgresSessionAttachmentRepository) GetByID(ctx context.Context, attachmentID string) (*SessionAttachment, error) {
	return r.getByID(ctx, attachmentID, "")
}

// GetByIDForUpdate is GetByID with a FOR UPDATE row lock.
func (r *postgresSessionAttachmentRepository) GetByIDForUpdate(ctx context.Context, attachmentID string) (*SessionAttachment, error) {
	return r.getByID(ctx, attachmentID, " FOR UPDATE")
}

// getByID selects one attachment by id, appending suffix to the query. A
// missing row yields (nil, nil).
func (r *postgresSessionAttachmentRepository) getByID(ctx context.Context, attachmentID string, suffix string) (*SessionAttachment, error) {
	query := `SELECT` + sessionAttachmentSelectColumns + `
		FROM session_attachments
		WHERE id = $1::uuid` + suffix
	attachment, err := scanSessionAttachment(r.db.QueryRow(ctx, query, attachmentID))
	if errors.Is(err, pgx.ErrNoRows) {
		return nil, nil
	}
	return attachment, err
}

// ListByRemoteSession returns every attachment of remoteSessionID, oldest
// first (ties broken by id).
func (r *postgresSessionAttachmentRepository) ListByRemoteSession(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error) {
	return r.listByRemoteSession(ctx, remoteSessionID, sessionAttachmentListOrder)
}

// ListActiveByRemoteSessionForUpdate returns and row-locks the attachments of
// remoteSessionID in state attaching, active or reconnecting. The ORDER BY
// precedes FOR UPDATE so locks are always taken in the same order.
func (r *postgresSessionAttachmentRepository) ListActiveByRemoteSessionForUpdate(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error) {
	return r.listByRemoteSession(ctx, remoteSessionID,
		` AND state IN ('attaching', 'active', 'reconnecting')`+sessionAttachmentListOrder+` FOR UPDATE`)
}

// listByRemoteSession selects the attachments of remoteSessionID, appending
// suffix (extra predicates, ordering and/or a locking clause) to the query.
func (r *postgresSessionAttachmentRepository) listByRemoteSession(ctx context.Context, remoteSessionID string, suffix string) ([]SessionAttachment, error) {
	query := `SELECT` + sessionAttachmentSelectColumns + `
		FROM session_attachments
		WHERE remote_session_id = $1::uuid` + suffix
	rows, err := r.db.Query(ctx, query, remoteSessionID)
	if err != nil {
		return nil, fmt.Errorf("list session attachments: %w", err)
	}
	return collectScannedRows(rows, "list session attachments", scanSessionAttachment)
}

// UpdateState overwrites the state and timing columns of one attachment.
// Matching zero rows is not reported.
func (r *postgresSessionAttachmentRepository) UpdateState(ctx context.Context, params UpdateSessionAttachmentStateParams) error {
	const query = `
		UPDATE session_attachments
		SET state = $2,
		    detached_at = $3,
		    last_input_at = $4,
		    updated_at = $5
		WHERE id = $1::uuid
	`
	if _, err := r.db.Exec(ctx, query,
		params.AttachmentID,
		params.State,
		params.DetachedAt,
		params.LastInputAt,
		params.UpdatedAt,
	); err != nil {
		return fmt.Errorf("update session attachment state: %w", err)
	}
	return nil
}

// Supersede marks PreviousAttachmentID as 'superseded' and points its
// superseded_by at NextAttachmentID. Matching zero rows is not reported.
func (r *postgresSessionAttachmentRepository) Supersede(ctx context.Context, params SupersedeAttachmentParams) error {
	const query = `
		UPDATE session_attachments
		SET state = 'superseded',
		    superseded_by = $2::uuid,
		    detached_at = $3,
		    updated_at = $4
		WHERE id = $1::uuid
	`
	if _, err := r.db.Exec(ctx, query,
		params.PreviousAttachmentID,
		params.NextAttachmentID,
		params.DetachedAt,
		params.UpdatedAt,
	); err != nil {
		return fmt.Errorf("supersede attachment: %w", err)
	}
	return nil
}

// GetByResourceID loads the policy of resourceID; it returns (nil, nil) when
// the resource has no policy row. A NULL file_transfer_mode reads as
// 'disabled'.
func (r *postgresResourcePolicyRepository) GetByResourceID(ctx context.Context, resourceID string) (*ResourcePolicy, error) {
	const query = `
		SELECT resource_id::text, max_concurrent_sessions, takeover_policy,
		       require_trusted_device, detach_grace_period_seconds,
		       clipboard_enabled, clipboard_mode, file_transfer_enabled,
		       COALESCE(file_transfer_mode, 'disabled'), created_at, updated_at
		FROM resource_policies
		WHERE resource_id = $1::uuid
	`
	policy, err := scanResourcePolicy(r.db.QueryRow(ctx, query, resourceID))
	if errors.Is(err, pgx.ErrNoRows) {
		return nil, nil
	}
	return policy, err
}

// Upsert inserts or replaces the policy of policy.ResourceID. The legacy
// boolean columns clipboard_enabled / file_transfer_enabled are derived from
// the normalised modes rather than taken from the struct, and created_at is
// left untouched on conflict.
func (r *postgresResourcePolicyRepository) Upsert(ctx context.Context, policy ResourcePolicy) error {
	const query = `
		INSERT INTO resource_policies (
			resource_id, max_concurrent_sessions, takeover_policy,
			require_trusted_device, detach_grace_period_seconds,
			clipboard_enabled, clipboard_mode, file_transfer_enabled,
			file_transfer_mode, created_at, updated_at
		) VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
		ON CONFLICT (resource_id) DO UPDATE
		SET max_concurrent_sessions = EXCLUDED.max_concurrent_sessions,
		    takeover_policy = EXCLUDED.takeover_policy,
		    require_trusted_device = EXCLUDED.require_trusted_device,
		    detach_grace_period_seconds = EXCLUDED.detach_grace_period_seconds,
		    clipboard_enabled = EXCLUDED.clipboard_enabled,
		    clipboard_mode = EXCLUDED.clipboard_mode,
		    file_transfer_enabled = EXCLUDED.file_transfer_enabled,
		    file_transfer_mode = EXCLUDED.file_transfer_mode,
		    updated_at = EXCLUDED.updated_at
	`
	clipboardMode := normalizeClipboardMode(policy.ClipboardMode)
	fileTransferMode := normalizeFileTransferMode(policy.FileTransferMode)
	if _, err := r.db.Exec(ctx, query,
		policy.ResourceID,
		policy.MaxConcurrentSessions,
		policy.TakeoverPolicy,
		policy.RequireTrustedDevice,
		// Sub-second precision is truncated; the column stores whole seconds.
		int(policy.DetachGracePeriod.Seconds()),
		clipboardMode != ResourceClipboardModeDisabled,
		clipboardMode,
		fileTransferAllowsClientToServer(fileTransferMode),
		fileTransferMode,
		policy.CreatedAt,
		policy.UpdatedAt,
	); err != nil {
		return fmt.Errorf("upsert resource policy: %w", err)
	}
	return nil
}

// GetByID loads the connection details of resource resourceID; it returns
// (nil, nil) when the resource does not exist.
func (r *postgresResourceRuntimeRepository) GetByID(ctx context.Context, resourceID string) (*ResourceRuntimeSpec, error) {
	const query = `
		SELECT id::text, organization_id::text, name, address, protocol,
		       secret_ref, certificate_verification_mode, metadata
		FROM resources
		WHERE id = $1::uuid
	`
	item := &ResourceRuntimeSpec{}
	var secretRef *string
	var metadata []byte
	if err := r.db.QueryRow(ctx, query, resourceID).Scan(
		&item.ID,
		&item.OrganizationID,
		&item.Name,
		&item.Address,
		&item.Protocol,
		&secretRef,
		&item.CertificateVerificationMode,
		&metadata,
	); err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return nil, nil
		}
		return nil, fmt.Errorf("get resource runtime spec: %w", err)
	}
	item.SecretRef = secretRef
	item.Metadata = metadata
	return item, nil
}

// Create inserts an audit event. Nil optional ids are stored as NULL and the
// payload is normalised by jsonPayload.
func (r *postgresAuditEventRepository) Create(ctx context.Context, event AuditEvent) error {
	const query = `
		INSERT INTO audit_events (
			id, actor_user_id, actor_device_id, event_type, target_type,
			target_id, remote_session_id, payload, created_at
		) VALUES (
			$1::uuid, NULLIF($2, '')::uuid, NULLIF($3, '')::uuid, $4, $5, $6,
			NULLIF($7, '')::uuid, $8::jsonb, $9
		)
	`
	if _, err := r.db.Exec(ctx, query,
		event.ID,
		stringValue(event.ActorUserID),
		stringValue(event.ActorDeviceID),
		event.EventType,
		event.TargetType,
		event.TargetID,
		stringValue(event.RemoteSessionID),
		jsonPayload(event.Payload),
		event.CreatedAt,
	); err != nil {
		return fmt.Errorf("create audit event: %w", err)
	}
	return nil
}

// IsTrustedDevice reports whether deviceID belongs to userID, has
// trust_status 'trusted' and has not been revoked.
func (r *postgresAccessRepository) IsTrustedDevice(ctx context.Context, userID, deviceID string) (bool, error) {
	const query = `
		SELECT EXISTS(
			SELECT 1
			FROM devices
			WHERE id = $1::uuid
			  AND user_id = $2::uuid
			  AND trust_status = 'trusted'
			  AND revoked_at IS NULL
		)
	`
	var trusted bool
	if err := r.db.QueryRow(ctx, query, deviceID, userID).Scan(&trusted); err != nil {
		return false, fmt.Errorf("check trusted device: %w", err)
	}
	return trusted, nil
}

// GetPlatformRole delegates to authority.EffectivePlatformRole using this
// repository's db handle and verifier (which may be nil).
func (r *postgresAccessRepository) GetPlatformRole(ctx context.Context, userID string) (string, error) {
	return authority.EffectivePlatformRole(ctx, r.db, r.authority, userID)
}

// GetOrganizationRole returns the role of userID's active membership in
// organizationID. The bool is false (with a nil error) when no active
// membership exists.
func (r *postgresAccessRepository) GetOrganizationRole(ctx context.Context, organizationID, userID string) (string, bool, error) {
	const query = `
		SELECT role_id
		FROM organization_memberships
		WHERE organization_id = $1::uuid
		  AND user_id = $2::uuid
		  AND status = 'active'
	`
	var role string
	if err := r.db.QueryRow(ctx, query, organizationID, userID).Scan(&role); err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return "", false, nil
		}
		return "", false, fmt.Errorf("get organization role: %w", err)
	}
	return role, true, nil
}

// scanner abstracts over a single-row result (QueryRow) and a multi-row
// cursor (Query) so the scan helpers below serve both.
type scanner interface {
	Scan(dest ...any) error
}

// collectScannedRows drains rows through scan and always closes rows. On any
// scan or iteration error it returns a nil slice plus the error, so callers
// never see a partially populated result; iteration errors are wrapped with op.
func collectScannedRows[T any](rows pgx.Rows, op string, scan func(scanner) (*T, error)) ([]T, error) {
	defer rows.Close()
	var items []T
	for rows.Next() {
		item, err := scan(rows)
		if err != nil {
			return nil, err
		}
		items = append(items, *item)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
	return items, nil
}

// scanRemoteSession reads one row shaped like remoteSessionSelectColumns and
// derives RenderQualityProfile from the raw metadata.
func scanRemoteSession(row scanner) (*RemoteSession, error) {
	item := &RemoteSession{}
	var detachDeadlineAt, lastHeartbeatAt *time.Time
	var metadata []byte
	if err := row.Scan(
		&item.ID,
		&item.OrganizationID,
		&item.ResourceID,
		&item.Protocol,
		&item.State,
		&item.WorkerID,
		&item.ControllerUserID,
		&detachDeadlineAt,
		&lastHeartbeatAt,
		&item.TakeoverVersion,
		&metadata,
		&item.CreatedAt,
		&item.UpdatedAt,
	); err != nil {
		return nil, fmt.Errorf("scan remote session: %w", err)
	}
	item.DetachDeadlineAt = detachDeadlineAt
	item.LastHeartbeatAt = lastHeartbeatAt
	item.Metadata = metadata
	item.RenderQualityProfile = renderQualityProfileFromSessionMetadata(metadata)
	return item, nil
}

// scanSessionAttachment reads one row shaped like
// sessionAttachmentSelectColumns; nullable columns map to nil pointers.
func scanSessionAttachment(row scanner) (*SessionAttachment, error) {
	item := &SessionAttachment{}
	var supersededBy, takeoverOf *string
	var attachedAt, detachedAt, lastInputAt *time.Time
	var metadata []byte
	if err := row.Scan(
		&item.ID,
		&item.RemoteSessionID,
		&item.UserID,
		&item.DeviceID,
		&item.Role,
		&item.State,
		&supersededBy,
		&takeoverOf,
		&attachedAt,
		&detachedAt,
		&lastInputAt,
		&metadata,
		&item.CreatedAt,
		&item.UpdatedAt,
	); err != nil {
		return nil, fmt.Errorf("scan session attachment: %w", err)
	}
	item.SupersededBy = supersededBy
	item.TakeoverOf = takeoverOf
	item.AttachedAt = attachedAt
	item.DetachedAt = detachedAt
	item.LastInputAt = lastInputAt
	item.Metadata = metadata
	return item, nil
}

// scanResourcePolicy reads one resource_policies row. The stored boolean
// columns are scanned but then overwritten: ClipboardEnabled and
// FileTransferEnabled are re-derived from the normalised modes.
func scanResourcePolicy(row scanner) (*ResourcePolicy, error) {
	item := &ResourcePolicy{}
	var detachGraceSeconds int
	if err := row.Scan(
		&item.ResourceID,
		&item.MaxConcurrentSessions,
		&item.TakeoverPolicy,
		&item.RequireTrustedDevice,
		&detachGraceSeconds,
		&item.ClipboardEnabled,
		&item.ClipboardMode,
		&item.FileTransferEnabled,
		&item.FileTransferMode,
		&item.CreatedAt,
		&item.UpdatedAt,
	); err != nil {
		return nil, fmt.Errorf("scan resource policy: %w", err)
	}
	item.DetachGracePeriod = time.Duration(detachGraceSeconds) * time.Second
	item.ClipboardMode = normalizeClipboardMode(item.ClipboardMode)
	item.ClipboardEnabled = item.ClipboardMode != ResourceClipboardModeDisabled
	item.FileTransferMode = normalizeFileTransferMode(item.FileTransferMode)
	item.FileTransferEnabled = fileTransferAllowsClientToServer(item.FileTransferMode)
	return item, nil
}

// jsonPayload returns payload when it is non-empty, syntactically valid JSON,
// and `{}` otherwise, so the ::jsonb casts in the INSERTs above cannot fail.
// Note that an invalid payload is silently replaced, not reported.
func jsonPayload(payload []byte) []byte {
	if len(payload) > 0 && json.Valid(payload) {
		return payload
	}
	return []byte(`{}`)
}

// stringValue dereferences value, mapping nil to "" (which the queries above
// turn back into NULL via NULLIF).
func stringValue(value *string) string {
	if value == nil {
		return ""
	}
	return *value
}