Initial project snapshot

This commit is contained in:
2026-04-28 22:29:50 +03:00
commit 8ba0561f4f
365 changed files with 91832 additions and 0 deletions
+19
View File
@@ -0,0 +1,19 @@
package auth
import "errors"
var (
ErrInvalidCredentials = errors.New("invalid credentials")
ErrInvalidRefreshToken = errors.New("invalid refresh token")
ErrAuthSessionRevoked = errors.New("auth session revoked")
ErrDeviceRevoked = errors.New("device revoked")
ErrDeviceNotTrusted = errors.New("device not trusted")
ErrAuthSessionNotFound = errors.New("auth session not found")
ErrTrustedDeviceMissing = errors.New("trusted device not found")
ErrInstallationAlreadyBootstrapped = errors.New("installation is already bootstrapped")
ErrInstallationActivationRequired = errors.New("signed installation activation is required")
ErrInvalidInstallationActivation = errors.New("invalid installation activation")
ErrInsecureBootstrapDisabled = errors.New("insecure installation bootstrap is disabled")
ErrInvalidBootstrapOwner = errors.New("invalid bootstrap owner")
)
+114
View File
@@ -0,0 +1,114 @@
package auth
import (
"encoding/json"
"time"
)
type DeviceTrustStatus string
const (
DeviceTrustStatusPending DeviceTrustStatus = "pending"
DeviceTrustStatusTrusted DeviceTrustStatus = "trusted"
DeviceTrustStatusRevoked DeviceTrustStatus = "revoked"
)
type User struct {
ID string
Email string
PasswordHash string
MFAEnabled bool
CreatedAt time.Time
UpdatedAt time.Time
}
type Device struct {
ID string
UserID string
Fingerprint string
Label string
TrustStatus DeviceTrustStatus
TrustedAt *time.Time
LastSeenAt *time.Time
RevokedAt *time.Time
RevokedReason *string
CreatedAt time.Time
UpdatedAt time.Time
}
type AuthSession struct {
ID string
UserID string
DeviceID string
RefreshTokenHash string
RefreshExpiresAt time.Time
LastSeenAt *time.Time
LastRotatedAt *time.Time
RevokedAt *time.Time
RevokedReason *string
CreatedAt time.Time
UpdatedAt time.Time
}
type LoginCommand struct {
Email string `json:"email"`
Password string `json:"password"`
DeviceFingerprint string `json:"device_fingerprint"`
DeviceLabel string `json:"device_label"`
TrustDevice bool `json:"trust_device"`
}
type RefreshCommand struct {
RefreshToken string `json:"refresh_token"`
}
type BootstrapOwnerCommand struct {
Email string `json:"email"`
Password string `json:"password"`
ActivationPayload json.RawMessage `json:"activation_payload"`
ActivationSignature string `json:"activation_signature"`
}
type RevokeAuthSessionCommand struct {
UserID string `json:"user_id"`
AuthSessionID string `json:"auth_session_id"`
Reason string `json:"reason"`
}
type RevokeDeviceCommand struct {
UserID string `json:"user_id"`
DeviceID string `json:"device_id"`
Reason string `json:"reason"`
}
type TokenPair struct {
AccessToken string `json:"access_token"`
AccessTokenExpiresAt time.Time `json:"access_token_expires_at"`
RefreshToken string `json:"refresh_token"`
RefreshTokenExpiresAt time.Time `json:"refresh_token_expires_at"`
}
type AuthResult struct {
User User `json:"user"`
Device Device `json:"device"`
AuthSession AuthSession `json:"auth_session"`
Tokens TokenPair `json:"tokens"`
}
type InstallationStatus struct {
Bootstrapped bool `json:"bootstrapped"`
AuthorityState string `json:"authority_state"`
InstallID string `json:"install_id,omitempty"`
BootstrappedOwnerEmail string `json:"bootstrapped_owner_email,omitempty"`
BootstrappedAt *time.Time `json:"bootstrapped_at,omitempty"`
AuthorityMode string `json:"authority_mode"`
StrictAuthority bool `json:"strict_authority"`
RootFingerprint string `json:"root_fingerprint,omitempty"`
InsecureBootstrapAllowed bool `json:"insecure_bootstrap_allowed"`
}
type BootstrapOwnerResult struct {
Installation InstallationStatus `json:"installation"`
User User `json:"user"`
PlatformRole string `json:"platform_role"`
}
+173
View File
@@ -0,0 +1,173 @@
package auth
import (
"encoding/json"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
"github.com/example/remote-access-platform/backend/internal/platform/module"
)
type Module struct {
service *Service
}
func NewModule(deps module.Dependencies, service *Service) *Module {
return &Module{service: service}
}
func (m *Module) Name() string {
return "auth"
}
func (m *Module) RegisterRoutes(router chi.Router) {
router.Route("/installation", func(r chi.Router) {
r.Get("/status", m.handleInstallationStatus)
r.Post("/bootstrap-owner", m.handleBootstrapOwner)
})
router.Route("/auth", func(r chi.Router) {
r.Post("/login", m.handleLogin)
r.Post("/refresh", m.handleRefresh)
r.Post("/sessions/revoke", m.handleRevokeAuthSession)
r.Get("/devices", m.handleTrustedDevices)
r.Post("/devices/{deviceID}/revoke", m.handleRevokeTrustedDevice)
})
}
func (m *Module) handleInstallationStatus(w http.ResponseWriter, r *http.Request) {
status, err := m.service.InstallationStatus(r.Context())
if err != nil {
statusCode, message := m.service.MapError(err)
httpx.WriteError(w, statusCode, message)
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"installation": status})
}
func (m *Module) handleBootstrapOwner(w http.ResponseWriter, r *http.Request) {
var cmd BootstrapOwnerCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid installation bootstrap payload")
return
}
result, err := m.service.BootstrapOwner(r.Context(), cmd)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusCreated, result)
}
func (m *Module) handleLogin(w http.ResponseWriter, r *http.Request) {
var cmd LoginCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid login payload")
return
}
result, err := m.service.Login(r.Context(), cmd)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusOK, result)
}
func (m *Module) handleRefresh(w http.ResponseWriter, r *http.Request) {
var cmd RefreshCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid refresh payload")
return
}
result, err := m.service.Refresh(r.Context(), cmd)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusOK, result)
}
func (m *Module) handleRevokeAuthSession(w http.ResponseWriter, r *http.Request) {
var cmd RevokeAuthSessionCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid auth session revoke payload")
return
}
if err := m.service.RevokeAuthSession(r.Context(), cmd); err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
"status": "revoked",
"message": httpx.NewMessage(
"auth.session.revoked",
"status.auth.session.revoked",
"Auth session revoked.",
nil,
"",
),
})
}
func (m *Module) handleTrustedDevices(w http.ResponseWriter, r *http.Request) {
userID := r.URL.Query().Get("user_id")
if userID == "" {
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
return
}
devices, err := m.service.ListTrustedDevices(r.Context(), userID)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{
"devices": devices,
})
}
func (m *Module) handleRevokeTrustedDevice(w http.ResponseWriter, r *http.Request) {
var payload struct {
UserID string `json:"user_id"`
Reason string `json:"reason"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid device revoke payload")
return
}
err := m.service.RevokeTrustedDevice(r.Context(), RevokeDeviceCommand{
UserID: payload.UserID,
DeviceID: chi.URLParam(r, "deviceID"),
Reason: payload.Reason,
})
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
"status": "revoked",
"message": httpx.NewMessage(
"auth.device.revoked",
"status.auth.device.revoked",
"Trusted device revoked.",
nil,
"",
),
})
}
@@ -0,0 +1,525 @@
package auth
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
)
type postgresStore struct {
db postgresplatform.DBTX
}
type PostgresTransactor struct {
pool *pgxpool.Pool
}
func NewPostgresStore(pool *pgxpool.Pool) Store {
return &postgresStore{db: pool}
}
func NewPostgresTransactor(pool *pgxpool.Pool) *PostgresTransactor {
return &PostgresTransactor{pool: pool}
}
func (t *PostgresTransactor) WithinTransaction(ctx context.Context, fn func(store Store) error) error {
return postgresplatform.WithTransaction(ctx, t.pool, func(tx pgx.Tx) error {
return fn(&postgresStore{db: tx})
})
}
func (s *postgresStore) Users() UserRepository {
return &postgresUserRepository{db: s.db}
}
func (s *postgresStore) Devices() DeviceRepository {
return &postgresDeviceRepository{db: s.db}
}
func (s *postgresStore) AuthSessions() AuthSessionRepository {
return &postgresAuthSessionRepository{db: s.db}
}
func (s *postgresStore) Installation() InstallationRepository {
return &postgresInstallationRepository{db: s.db}
}
type postgresUserRepository struct {
db postgresplatform.DBTX
}
type postgresDeviceRepository struct {
db postgresplatform.DBTX
}
type postgresAuthSessionRepository struct {
db postgresplatform.DBTX
}
type postgresInstallationRepository struct {
db postgresplatform.DBTX
}
func (r *postgresUserRepository) GetByEmail(ctx context.Context, email string) (*User, error) {
const query = `
SELECT id::text, email, password_hash, mfa_enabled, created_at, updated_at
FROM users
WHERE email = $1
`
return scanOptionalUser(r.db.QueryRow(ctx, query, email))
}
func (r *postgresUserRepository) GetByID(ctx context.Context, userID string) (*User, error) {
const query = `
SELECT id::text, email, password_hash, mfa_enabled, created_at, updated_at
FROM users
WHERE id = $1::uuid
`
return scanOptionalUser(r.db.QueryRow(ctx, query, userID))
}
func (r *postgresDeviceRepository) Upsert(ctx context.Context, params UpsertDeviceParams) (*Device, error) {
const query = `
INSERT INTO devices (
user_id,
device_fingerprint,
device_label,
trust_status,
trusted_at,
last_seen_at,
created_at,
updated_at
) VALUES (
$1::uuid,
$2,
$3,
CASE WHEN $4 THEN 'trusted' ELSE 'pending' END,
CASE WHEN $4 THEN $5::timestamptz ELSE NULL::timestamptz END,
$5::timestamptz,
$5::timestamptz,
$5::timestamptz
)
ON CONFLICT (user_id, device_fingerprint) DO UPDATE SET
device_label = EXCLUDED.device_label,
last_seen_at = EXCLUDED.last_seen_at,
updated_at = EXCLUDED.updated_at,
trust_status = CASE
WHEN devices.trust_status = 'revoked' THEN devices.trust_status
WHEN devices.trust_status = 'trusted' THEN devices.trust_status
WHEN EXCLUDED.trust_status = 'trusted' THEN 'trusted'
ELSE devices.trust_status
END,
trusted_at = CASE
WHEN devices.trust_status = 'trusted' THEN devices.trusted_at
WHEN EXCLUDED.trust_status = 'trusted' THEN EXCLUDED.trusted_at
ELSE devices.trusted_at
END
RETURNING
id::text, user_id::text, device_fingerprint, COALESCE(device_label, ''),
trust_status, trusted_at, last_seen_at, revoked_at, revoked_reason, created_at, updated_at
`
return scanDevice(r.db.QueryRow(ctx, query,
params.UserID,
params.Fingerprint,
params.Label,
params.TrustRequested,
params.SeenAt,
))
}
func (r *postgresDeviceRepository) GetByIDForUser(ctx context.Context, userID, deviceID string) (*Device, error) {
const query = `
SELECT id::text, user_id::text, device_fingerprint, COALESCE(device_label, ''),
trust_status, trusted_at, last_seen_at, revoked_at, revoked_reason, created_at, updated_at
FROM devices
WHERE id = $1::uuid AND user_id = $2::uuid
`
device, err := scanDevice(r.db.QueryRow(ctx, query, deviceID, userID))
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return device, err
}
func (r *postgresDeviceRepository) ListTrustedByUser(ctx context.Context, userID string) ([]Device, error) {
const query = `
SELECT id::text, user_id::text, device_fingerprint, COALESCE(device_label, ''),
trust_status, trusted_at, last_seen_at, revoked_at, revoked_reason, created_at, updated_at
FROM devices
WHERE user_id = $1::uuid AND trust_status = 'trusted' AND revoked_at IS NULL
ORDER BY created_at DESC
`
rows, err := r.db.Query(ctx, query, userID)
if err != nil {
return nil, fmt.Errorf("query trusted devices: %w", err)
}
defer rows.Close()
var devices []Device
for rows.Next() {
device, err := scanDevice(rows)
if err != nil {
return nil, err
}
devices = append(devices, *device)
}
return devices, rows.Err()
}
func (r *postgresDeviceRepository) Revoke(ctx context.Context, params RevokeDeviceParams) error {
const query = `
UPDATE devices
SET trust_status = 'revoked',
revoked_at = $3,
revoked_reason = $4,
updated_at = $3
WHERE id = $1::uuid AND user_id = $2::uuid
`
if _, err := r.db.Exec(ctx, query, params.DeviceID, params.UserID, params.RevokedAt, params.Reason); err != nil {
return fmt.Errorf("revoke device: %w", err)
}
return nil
}
func (r *postgresAuthSessionRepository) Create(ctx context.Context, session AuthSession) error {
const query = `
INSERT INTO auth_sessions (
id,
user_id,
device_id,
refresh_token_hash,
refresh_expires_at,
last_seen_at,
created_at,
updated_at
) VALUES ($1::uuid, $2::uuid, $3::uuid, $4, $5, $6, $7, $8)
`
if _, err := r.db.Exec(ctx, query,
session.ID,
session.UserID,
session.DeviceID,
session.RefreshTokenHash,
session.RefreshExpiresAt,
session.LastSeenAt,
session.CreatedAt,
session.UpdatedAt,
); err != nil {
return fmt.Errorf("create auth session: %w", err)
}
return nil
}
func (r *postgresAuthSessionRepository) GetByID(ctx context.Context, authSessionID string) (*AuthSession, error) {
return r.getByID(ctx, authSessionID, "")
}
func (r *postgresAuthSessionRepository) GetByIDForUpdate(ctx context.Context, authSessionID string) (*AuthSession, error) {
return r.getByID(ctx, authSessionID, " FOR UPDATE")
}
func (r *postgresAuthSessionRepository) getByID(ctx context.Context, authSessionID string, suffix string) (*AuthSession, error) {
query := `
SELECT id::text, user_id::text, device_id::text, refresh_token_hash, refresh_expires_at,
last_seen_at, last_rotated_at, revoked_at, revoked_reason, created_at, updated_at
FROM auth_sessions
WHERE id = $1::uuid` + suffix
session, err := scanAuthSession(r.db.QueryRow(ctx, query, authSessionID))
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return session, err
}
func (r *postgresAuthSessionRepository) Rotate(ctx context.Context, params RotateAuthSessionParams) error {
const query = `
UPDATE auth_sessions
SET refresh_token_hash = $2,
refresh_expires_at = $3,
last_seen_at = $4,
last_rotated_at = $5,
updated_at = $5
WHERE id = $1::uuid AND revoked_at IS NULL
`
if _, err := r.db.Exec(ctx, query,
params.AuthSessionID,
params.RefreshTokenHash,
params.RefreshExpiresAt,
params.LastSeenAt,
params.LastRotatedAt,
); err != nil {
return fmt.Errorf("rotate auth session: %w", err)
}
return nil
}
func (r *postgresAuthSessionRepository) Touch(ctx context.Context, authSessionID string, seenAt time.Time) error {
const query = `
UPDATE auth_sessions
SET last_seen_at = $2, updated_at = $2
WHERE id = $1::uuid AND revoked_at IS NULL
`
if _, err := r.db.Exec(ctx, query, authSessionID, seenAt); err != nil {
return fmt.Errorf("touch auth session: %w", err)
}
return nil
}
func (r *postgresAuthSessionRepository) Revoke(ctx context.Context, params RevokeAuthSessionParams) error {
const query = `
UPDATE auth_sessions
SET revoked_at = $3,
revoked_reason = $4,
updated_at = $3
WHERE id = $1::uuid AND user_id = $2::uuid AND revoked_at IS NULL
`
if _, err := r.db.Exec(ctx, query, params.AuthSessionID, params.UserID, params.RevokedAt, params.Reason); err != nil {
return fmt.Errorf("revoke auth session: %w", err)
}
return nil
}
func (r *postgresAuthSessionRepository) RevokeByDevice(ctx context.Context, userID, deviceID, reason string, revokedAt time.Time) error {
const query = `
UPDATE auth_sessions
SET revoked_at = $3,
revoked_reason = $4,
updated_at = $3
WHERE user_id = $1::uuid AND device_id = $2::uuid AND revoked_at IS NULL
`
if _, err := r.db.Exec(ctx, query, userID, deviceID, revokedAt, reason); err != nil {
return fmt.Errorf("revoke auth sessions by device: %w", err)
}
return nil
}
func (r *postgresInstallationRepository) GetStatus(ctx context.Context) (*InstallationAuthorityState, error) {
const query = `
SELECT install_id, authority_state, product_root_key_fingerprint, bootstrapped_owner_email, bootstrapped_at
FROM installation_authority
WHERE id = 1
`
status := &InstallationAuthorityState{}
if err := r.db.QueryRow(ctx, query).Scan(
&status.InstallID,
&status.AuthorityState,
&status.ProductRootFingerprint,
&status.BootstrappedOwnerEmail,
&status.BootstrappedAt,
); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return &InstallationAuthorityState{
Bootstrapped: false,
AuthorityState: "unbootstrapped",
}, nil
}
return nil, fmt.Errorf("get installation status: %w", err)
}
status.Bootstrapped = true
return status, nil
}
func (r *postgresInstallationRepository) BootstrapOwner(ctx context.Context, params BootstrapOwnerParams) (*User, error) {
var existingInstallID string
if err := r.db.QueryRow(ctx, `
SELECT install_id
FROM installation_authority
WHERE id = 1
FOR UPDATE
`).Scan(&existingInstallID); err != nil && !errors.Is(err, pgx.ErrNoRows) {
return nil, fmt.Errorf("lock installation authority: %w", err)
} else if err == nil {
return nil, ErrInstallationAlreadyBootstrapped
}
email := strings.ToLower(strings.TrimSpace(params.Email))
now := params.Now.UTC()
user, err := scanOptionalUser(r.db.QueryRow(ctx, `
INSERT INTO users (email, password_hash, mfa_enabled, platform_role, created_at, updated_at)
VALUES ($1, $2, FALSE, $3, $4, $4)
ON CONFLICT (email) DO UPDATE SET
password_hash = EXCLUDED.password_hash,
platform_role = EXCLUDED.platform_role,
updated_at = EXCLUDED.updated_at
RETURNING id::text, email, password_hash, mfa_enabled, created_at, updated_at
`, email, params.PasswordHash, params.Role, now))
if err != nil {
return nil, fmt.Errorf("upsert bootstrap owner: %w", err)
}
if user == nil {
return nil, fmt.Errorf("upsert bootstrap owner returned no user")
}
payload := json.RawMessage(`{}`)
if len(params.ActivationPayload) > 0 {
payload = params.ActivationPayload
}
if _, err := r.db.Exec(ctx, `
INSERT INTO installation_authority (
id,
install_id,
authority_state,
product_root_key_fingerprint,
activation_payload,
activation_signature,
bootstrapped_owner_email,
bootstrapped_at,
created_at,
updated_at
) VALUES (
1,
$1,
'active',
$2,
$3::jsonb,
$4,
$5,
$6,
$6,
$6
)
`, params.InstallID, params.ProductRootKeyFingerprint, []byte(payload), params.ActivationSignature, email, now); err != nil {
return nil, fmt.Errorf("insert installation authority: %w", err)
}
if _, err := r.db.Exec(ctx, `
UPDATE platform_role_grants
SET revoked_at = $4
WHERE user_id = $1::uuid
AND role = $2
AND install_id = $3
AND revoked_at IS NULL
`, user.ID, params.Role, params.InstallID, now); err != nil {
return nil, fmt.Errorf("revoke superseded platform role grants: %w", err)
}
if _, err := r.db.Exec(ctx, `
INSERT INTO platform_role_grants (
user_id,
role,
install_id,
grant_payload,
grant_signature,
grant_source,
granted_at,
expires_at,
metadata
) VALUES (
$1::uuid,
$2,
$3,
$4::jsonb,
$5,
$6,
$7,
$8,
'{"bootstrap_owner":true}'::jsonb
)
`, user.ID, params.Role, params.InstallID, []byte(payload), params.ActivationSignature, params.GrantSource, now, params.ExpiresAt); err != nil {
return nil, fmt.Errorf("insert platform role grant: %w", err)
}
if _, err := r.db.Exec(ctx, `
INSERT INTO organization_memberships (
organization_id,
user_id,
role_id,
status,
invited_by_user_id,
created_at,
updated_at
)
SELECT id, $1::uuid, 'org_owner', 'active', $1::uuid, $2, $2
FROM organizations
WHERE slug = 'default'
ON CONFLICT (organization_id, user_id) DO UPDATE SET
role_id = 'org_owner',
status = 'active',
invited_by_user_id = EXCLUDED.invited_by_user_id,
updated_at = EXCLUDED.updated_at
`, user.ID, now); err != nil {
return nil, fmt.Errorf("upsert default organization owner membership: %w", err)
}
return user, nil
}
type scanner interface {
Scan(dest ...any) error
}
func scanOptionalUser(row scanner) (*User, error) {
user := &User{}
if err := row.Scan(
&user.ID,
&user.Email,
&user.PasswordHash,
&user.MFAEnabled,
&user.CreatedAt,
&user.UpdatedAt,
); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return nil, fmt.Errorf("scan user: %w", err)
}
return user, nil
}
func scanDevice(row scanner) (*Device, error) {
device := &Device{}
var trustedAt, lastSeenAt, revokedAt *time.Time
var revokedReason *string
if err := row.Scan(
&device.ID,
&device.UserID,
&device.Fingerprint,
&device.Label,
&device.TrustStatus,
&trustedAt,
&lastSeenAt,
&revokedAt,
&revokedReason,
&device.CreatedAt,
&device.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan device: %w", err)
}
device.TrustedAt = trustedAt
device.LastSeenAt = lastSeenAt
device.RevokedAt = revokedAt
device.RevokedReason = revokedReason
return device, nil
}
func scanAuthSession(row scanner) (*AuthSession, error) {
session := &AuthSession{}
var lastSeenAt, lastRotatedAt, revokedAt *time.Time
var revokedReason *string
if err := row.Scan(
&session.ID,
&session.UserID,
&session.DeviceID,
&session.RefreshTokenHash,
&session.RefreshExpiresAt,
&lastSeenAt,
&lastRotatedAt,
&revokedAt,
&revokedReason,
&session.CreatedAt,
&session.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan auth session: %w", err)
}
session.LastSeenAt = lastSeenAt
session.LastRotatedAt = lastRotatedAt
session.RevokedAt = revokedAt
session.RevokedReason = revokedReason
return session, nil
}
@@ -0,0 +1,97 @@
package auth
import (
"context"
"encoding/json"
"time"
)
type UserRepository interface {
GetByEmail(ctx context.Context, email string) (*User, error)
GetByID(ctx context.Context, userID string) (*User, error)
}
type DeviceRepository interface {
Upsert(ctx context.Context, params UpsertDeviceParams) (*Device, error)
GetByIDForUser(ctx context.Context, userID, deviceID string) (*Device, error)
ListTrustedByUser(ctx context.Context, userID string) ([]Device, error)
Revoke(ctx context.Context, params RevokeDeviceParams) error
}
type AuthSessionRepository interface {
Create(ctx context.Context, session AuthSession) error
GetByID(ctx context.Context, authSessionID string) (*AuthSession, error)
GetByIDForUpdate(ctx context.Context, authSessionID string) (*AuthSession, error)
Rotate(ctx context.Context, params RotateAuthSessionParams) error
Touch(ctx context.Context, authSessionID string, seenAt time.Time) error
Revoke(ctx context.Context, params RevokeAuthSessionParams) error
RevokeByDevice(ctx context.Context, userID, deviceID, reason string, revokedAt time.Time) error
}
type InstallationRepository interface {
GetStatus(ctx context.Context) (*InstallationAuthorityState, error)
BootstrapOwner(ctx context.Context, params BootstrapOwnerParams) (*User, error)
}
type Store interface {
Users() UserRepository
Devices() DeviceRepository
AuthSessions() AuthSessionRepository
Installation() InstallationRepository
}
type Transactor interface {
WithinTransaction(ctx context.Context, fn func(store Store) error) error
}
type UpsertDeviceParams struct {
UserID string
Fingerprint string
Label string
TrustRequested bool
SeenAt time.Time
}
type RotateAuthSessionParams struct {
AuthSessionID string
RefreshTokenHash string
RefreshExpiresAt time.Time
LastSeenAt time.Time
LastRotatedAt time.Time
}
type RevokeAuthSessionParams struct {
AuthSessionID string
UserID string
Reason string
RevokedAt time.Time
}
type RevokeDeviceParams struct {
UserID string
DeviceID string
Reason string
RevokedAt time.Time
}
type InstallationAuthorityState struct {
Bootstrapped bool
AuthorityState string
InstallID string
ProductRootFingerprint string
BootstrappedOwnerEmail string
BootstrappedAt *time.Time
}
type BootstrapOwnerParams struct {
Email string
PasswordHash string
Role string
InstallID string
ProductRootKeyFingerprint string
ActivationPayload json.RawMessage
ActivationSignature string
GrantSource string
ExpiresAt *time.Time
Now time.Time
}
+440
View File
@@ -0,0 +1,440 @@
package auth
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"github.com/example/remote-access-platform/backend/internal/platform/authority"
"github.com/example/remote-access-platform/backend/internal/platform/module"
)
type Service struct {
cfg module.Config
store Store
transactor Transactor
tokenManager *TokenManager
authority *authority.Verifier
now func() time.Time
}
func NewService(deps module.Dependencies, store Store, transactor Transactor, verifiers ...*authority.Verifier) *Service {
var authorityVerifier *authority.Verifier
if len(verifiers) > 0 {
authorityVerifier = verifiers[0]
} else if verifier, err := authority.NewVerifier(deps.Config.Installation); err == nil {
authorityVerifier = verifier
}
return &Service{
cfg: deps.Config,
store: store,
transactor: transactor,
tokenManager: NewTokenManager(TokenConfig{
Issuer: deps.Config.Auth.Issuer,
AccessTokenSecret: deps.Config.Auth.AccessTokenSecret,
RefreshHashSecret: deps.Config.Auth.RefreshHashSecret,
AccessTokenTTL: deps.Config.Auth.AccessTokenTTL,
RefreshTokenTTL: deps.Config.Auth.RefreshTokenTTL,
}),
authority: authorityVerifier,
now: time.Now,
}
}
func (s *Service) Login(ctx context.Context, cmd LoginCommand) (*AuthResult, error) {
user, err := s.store.Users().GetByEmail(ctx, cmd.Email)
if err != nil {
return nil, err
}
if user == nil {
return nil, ErrInvalidCredentials
}
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(cmd.Password)); err != nil {
return nil, ErrInvalidCredentials
}
var result AuthResult
now := s.now().UTC()
if err := s.transactor.WithinTransaction(ctx, func(store Store) error {
device, err := store.Devices().Upsert(ctx, UpsertDeviceParams{
UserID: user.ID,
Fingerprint: cmd.DeviceFingerprint,
Label: cmd.DeviceLabel,
TrustRequested: cmd.TrustDevice,
SeenAt: now,
})
if err != nil {
return err
}
if device.TrustStatus == DeviceTrustStatusRevoked {
return ErrDeviceRevoked
}
authSessionID := uuid.NewString()
refreshToken, refreshHash, refreshExpiresAt, err := s.tokenManager.IssueRefreshToken(authSessionID, now)
if err != nil {
return err
}
accessToken, accessExpiresAt, err := s.tokenManager.IssueAccessToken(user.ID, authSessionID, device.ID, now)
if err != nil {
return err
}
session := AuthSession{
ID: authSessionID,
UserID: user.ID,
DeviceID: device.ID,
RefreshTokenHash: refreshHash,
RefreshExpiresAt: refreshExpiresAt,
CreatedAt: now,
UpdatedAt: now,
LastSeenAt: &now,
}
if err := store.AuthSessions().Create(ctx, session); err != nil {
return err
}
result = AuthResult{
User: *user,
Device: *device,
AuthSession: session,
Tokens: TokenPair{
AccessToken: accessToken,
AccessTokenExpiresAt: accessExpiresAt,
RefreshToken: refreshToken,
RefreshTokenExpiresAt: refreshExpiresAt,
},
}
return nil
}); err != nil {
return nil, err
}
return &result, nil
}
func (s *Service) Refresh(ctx context.Context, cmd RefreshCommand) (*AuthResult, error) {
authSessionID, err := s.tokenManager.ParseRefreshToken(cmd.RefreshToken)
if err != nil {
return nil, err
}
var result AuthResult
now := s.now().UTC()
if err := s.transactor.WithinTransaction(ctx, func(store Store) error {
session, err := store.AuthSessions().GetByIDForUpdate(ctx, authSessionID)
if err != nil {
return err
}
if session == nil {
return ErrInvalidRefreshToken
}
if session.RevokedAt != nil {
return ErrAuthSessionRevoked
}
if now.After(session.RefreshExpiresAt) {
if revokeErr := store.AuthSessions().Revoke(ctx, RevokeAuthSessionParams{
AuthSessionID: session.ID,
UserID: session.UserID,
Reason: "refresh_token_expired",
RevokedAt: now,
}); revokeErr != nil {
return revokeErr
}
return ErrInvalidRefreshToken
}
expectedHash := s.tokenManager.HashRefreshToken(cmd.RefreshToken)
if expectedHash != session.RefreshTokenHash {
if revokeErr := store.AuthSessions().Revoke(ctx, RevokeAuthSessionParams{
AuthSessionID: session.ID,
UserID: session.UserID,
Reason: "refresh_rotation_reuse_detected",
RevokedAt: now,
}); revokeErr != nil {
return revokeErr
}
return ErrInvalidRefreshToken
}
user, err := store.Users().GetByID(ctx, session.UserID)
if err != nil {
return err
}
if user == nil {
return ErrInvalidCredentials
}
device, err := store.Devices().GetByIDForUser(ctx, session.UserID, session.DeviceID)
if err != nil {
return err
}
if device == nil {
return ErrTrustedDeviceMissing
}
if device.TrustStatus == DeviceTrustStatusRevoked {
return ErrDeviceRevoked
}
refreshToken, refreshHash, refreshExpiresAt, err := s.tokenManager.IssueRefreshToken(session.ID, now)
if err != nil {
return err
}
accessToken, accessExpiresAt, err := s.tokenManager.IssueAccessToken(user.ID, session.ID, device.ID, now)
if err != nil {
return err
}
if err := store.AuthSessions().Rotate(ctx, RotateAuthSessionParams{
AuthSessionID: session.ID,
RefreshTokenHash: refreshHash,
RefreshExpiresAt: refreshExpiresAt,
LastSeenAt: now,
LastRotatedAt: now,
}); err != nil {
return err
}
result = AuthResult{
User: *user,
Device: *device,
AuthSession: AuthSession{
ID: session.ID,
UserID: session.UserID,
DeviceID: session.DeviceID,
RefreshTokenHash: refreshHash,
RefreshExpiresAt: refreshExpiresAt,
LastSeenAt: &now,
LastRotatedAt: &now,
},
Tokens: TokenPair{
AccessToken: accessToken,
AccessTokenExpiresAt: accessExpiresAt,
RefreshToken: refreshToken,
RefreshTokenExpiresAt: refreshExpiresAt,
},
}
return nil
}); err != nil {
return nil, err
}
return &result, nil
}
func (s *Service) InstallationStatus(ctx context.Context) (*InstallationStatus, error) {
record, err := s.store.Installation().GetStatus(ctx)
if err != nil {
return nil, err
}
return s.installationStatusFromRecord(record), nil
}
func (s *Service) BootstrapOwner(ctx context.Context, cmd BootstrapOwnerCommand) (*BootstrapOwnerResult, error) {
email := strings.ToLower(strings.TrimSpace(cmd.Email))
password := strings.TrimSpace(cmd.Password)
if email == "" || !strings.Contains(email, "@") || len(password) < 12 {
return nil, ErrInvalidBootstrapOwner
}
now := s.now().UTC()
role := authority.PlatformRoleAdmin
installID := ""
grantSource := "installation_activation"
rootFingerprint := ""
activationPayload := cmd.ActivationPayload
activationSignature := strings.TrimSpace(cmd.ActivationSignature)
var expiresAt *time.Time
if s.strictAuthority() {
if len(activationPayload) == 0 || activationSignature == "" {
return nil, ErrInstallationActivationRequired
}
activation, err := s.authority.VerifyActivation(activationPayload, activationSignature)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrInvalidInstallationActivation, err)
}
if !strings.EqualFold(activation.OwnerEmail, email) {
return nil, ErrInvalidInstallationActivation
}
role = activation.PlatformRole
installID = activation.InstallID
expiresAt = activation.ExpiresAt
rootFingerprint = s.authority.RootFingerprint()
} else {
if s.authority == nil || !s.authority.AllowInsecureBootstrap() {
return nil, ErrInsecureBootstrapDisabled
}
installID = uuid.NewString()
grantSource = "dev_insecure"
rootFingerprint = "dev-insecure"
devPayload, err := json.Marshal(authority.ActivationPayload{
SchemaVersion: authority.ActivationSchemaVersion,
InstallID: installID,
OwnerEmail: email,
PlatformRole: role,
IssuedAt: now,
Environment: s.cfg.App.Env,
})
if err != nil {
return nil, err
}
activationPayload = json.RawMessage(devPayload)
activationSignature = "dev-insecure"
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("hash bootstrap owner password: %w", err)
}
var user *User
if err := s.transactor.WithinTransaction(ctx, func(store Store) error {
created, err := store.Installation().BootstrapOwner(ctx, BootstrapOwnerParams{
Email: email,
PasswordHash: string(passwordHash),
Role: role,
InstallID: installID,
ProductRootKeyFingerprint: rootFingerprint,
ActivationPayload: activationPayload,
ActivationSignature: activationSignature,
GrantSource: grantSource,
ExpiresAt: expiresAt,
Now: now,
})
if err != nil {
return err
}
user = created
return nil
}); err != nil {
return nil, err
}
status, err := s.InstallationStatus(ctx)
if err != nil {
return nil, err
}
return &BootstrapOwnerResult{
Installation: *status,
User: *user,
PlatformRole: role,
}, nil
}
func (s *Service) RevokeAuthSession(ctx context.Context, cmd RevokeAuthSessionCommand) error {
return s.transactor.WithinTransaction(ctx, func(store Store) error {
session, err := store.AuthSessions().GetByIDForUpdate(ctx, cmd.AuthSessionID)
if err != nil {
return err
}
if session == nil || session.UserID != cmd.UserID {
return ErrAuthSessionNotFound
}
return store.AuthSessions().Revoke(ctx, RevokeAuthSessionParams{
AuthSessionID: cmd.AuthSessionID,
UserID: cmd.UserID,
Reason: cmd.Reason,
RevokedAt: s.now().UTC(),
})
})
}
func (s *Service) RevokeTrustedDevice(ctx context.Context, cmd RevokeDeviceCommand) error {
return s.transactor.WithinTransaction(ctx, func(store Store) error {
device, err := store.Devices().GetByIDForUser(ctx, cmd.UserID, cmd.DeviceID)
if err != nil {
return err
}
if device == nil {
return ErrTrustedDeviceMissing
}
if device.TrustStatus != DeviceTrustStatusTrusted {
return ErrDeviceNotTrusted
}
now := s.now().UTC()
if err := store.Devices().Revoke(ctx, RevokeDeviceParams{
UserID: cmd.UserID,
DeviceID: cmd.DeviceID,
Reason: cmd.Reason,
RevokedAt: now,
}); err != nil {
return err
}
return store.AuthSessions().RevokeByDevice(ctx, cmd.UserID, cmd.DeviceID, "device_revoked:"+cmd.Reason, now)
})
}
func (s *Service) ListTrustedDevices(ctx context.Context, userID string) ([]Device, error) {
return s.store.Devices().ListTrustedByUser(ctx, userID)
}
func (s *Service) MapError(err error) (int, string) {
switch {
case err == nil:
return 0, ""
case errors.Is(err, ErrInvalidCredentials):
return 401, "invalid credentials"
case errors.Is(err, ErrInvalidRefreshToken):
return 401, "invalid refresh token"
case errors.Is(err, ErrAuthSessionRevoked):
return 401, "auth session revoked"
case errors.Is(err, ErrDeviceRevoked):
return 403, "device revoked"
case errors.Is(err, ErrDeviceNotTrusted):
return 409, "device is not trusted"
case errors.Is(err, ErrAuthSessionNotFound), errors.Is(err, ErrTrustedDeviceMissing):
return 404, err.Error()
case errors.Is(err, ErrInstallationActivationRequired), errors.Is(err, ErrInvalidInstallationActivation), errors.Is(err, ErrInvalidBootstrapOwner):
return 400, err.Error()
case errors.Is(err, ErrInsecureBootstrapDisabled):
return 403, err.Error()
case errors.Is(err, ErrInstallationAlreadyBootstrapped):
return 409, err.Error()
default:
return 500, fmt.Sprintf("internal error: %v", err)
}
}
func (s *Service) installationStatusFromRecord(record *InstallationAuthorityState) *InstallationStatus {
if record == nil {
record = &InstallationAuthorityState{AuthorityState: "unbootstrapped"}
}
mode := authority.ModeLegacy
strict := false
rootFingerprint := ""
insecureAllowed := false
if s.authority != nil {
mode = s.authority.Mode()
strict = s.authority.Strict()
rootFingerprint = s.authority.RootFingerprint()
insecureAllowed = s.authority.AllowInsecureBootstrap()
}
if record.ProductRootFingerprint != "" {
rootFingerprint = record.ProductRootFingerprint
}
return &InstallationStatus{
Bootstrapped: record.Bootstrapped,
AuthorityState: record.AuthorityState,
InstallID: record.InstallID,
BootstrappedOwnerEmail: record.BootstrappedOwnerEmail,
BootstrappedAt: record.BootstrappedAt,
AuthorityMode: mode,
StrictAuthority: strict,
RootFingerprint: rootFingerprint,
InsecureBootstrapAllowed: insecureAllowed,
}
}
func (s *Service) strictAuthority() bool {
return s.authority != nil && s.authority.Strict()
}
+95
View File
@@ -0,0 +1,95 @@
package auth
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
)
type TokenManager struct {
issuer string
accessSecret []byte
refreshHashSecret []byte
accessTTL time.Duration
refreshTTL time.Duration
}
type AccessClaims struct {
AuthSessionID string `json:"sid"`
DeviceID string `json:"did"`
jwt.RegisteredClaims
}
func NewTokenManager(cfg TokenConfig) *TokenManager {
return &TokenManager{
issuer: cfg.Issuer,
accessSecret: []byte(cfg.AccessTokenSecret),
refreshHashSecret: []byte(cfg.RefreshHashSecret),
accessTTL: cfg.AccessTokenTTL,
refreshTTL: cfg.RefreshTokenTTL,
}
}
type TokenConfig struct {
Issuer string
AccessTokenSecret string
RefreshHashSecret string
AccessTokenTTL time.Duration
RefreshTokenTTL time.Duration
}
func (m *TokenManager) IssueAccessToken(userID, authSessionID, deviceID string, now time.Time) (string, time.Time, error) {
expiresAt := now.Add(m.accessTTL)
claims := AccessClaims{
AuthSessionID: authSessionID,
DeviceID: deviceID,
RegisteredClaims: jwt.RegisteredClaims{
Issuer: m.issuer,
Subject: userID,
ExpiresAt: jwt.NewNumericDate(expiresAt),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signed, err := token.SignedString(m.accessSecret)
if err != nil {
return "", time.Time{}, fmt.Errorf("sign access token: %w", err)
}
return signed, expiresAt, nil
}
func (m *TokenManager) IssueRefreshToken(authSessionID string, now time.Time) (raw string, hash string, expiresAt time.Time, err error) {
secret := make([]byte, 32)
if _, err = rand.Read(secret); err != nil {
return "", "", time.Time{}, fmt.Errorf("read random refresh secret: %w", err)
}
encodedSecret := base64.RawURLEncoding.EncodeToString(secret)
raw = authSessionID + "." + encodedSecret
hash = m.HashRefreshToken(raw)
expiresAt = now.Add(m.refreshTTL)
return raw, hash, expiresAt, nil
}
func (m *TokenManager) HashRefreshToken(token string) string {
mac := hmac.New(sha256.New, m.refreshHashSecret)
_, _ = mac.Write([]byte(token))
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}
func (m *TokenManager) ParseRefreshToken(token string) (string, error) {
sessionID, _, ok := strings.Cut(token, ".")
if !ok || sessionID == "" {
return "", ErrInvalidRefreshToken
}
return sessionID, nil
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,34 @@
package cluster
import (
"encoding/json"
"testing"
)
func TestMeshLatestObservationKeySeparatesRouteHealthByRoute(t *testing.T) {
key := meshLatestObservationKey(json.RawMessage(`{
"observation_type":"synthetic_route_health",
"route_id":"route-1"
}`))
if key != "synthetic_route_health:route-1" {
t.Fatalf("key = %q", key)
}
}
func TestMeshLatestObservationKeySeparatesConnectionManagerMode(t *testing.T) {
key := meshLatestObservationKey(json.RawMessage(`{
"observation_type":"peer_connection_manager",
"transport_mode":"relay_control",
"relay_node_id":"node-r"
}`))
if key != "peer_connection_manager:relay_control:node-r" {
t.Fatalf("key = %q", key)
}
}
func TestMeshLatestObservationKeyDefaults(t *testing.T) {
key := meshLatestObservationKey(json.RawMessage(`{}`))
if key != "default" {
t.Fatalf("key = %q", key)
}
}
@@ -0,0 +1,91 @@
package cluster
import (
"context"
"encoding/json"
"time"
)
type Repository interface {
GetPlatformRole(ctx context.Context, userID string) (string, error)
ListClusters(ctx context.Context) ([]Cluster, error)
GetCluster(ctx context.Context, clusterID string) (Cluster, error)
CreateCluster(ctx context.Context, input CreateClusterInput) (Cluster, error)
UpdateCluster(ctx context.Context, input UpdateClusterInput) (Cluster, error)
GetClusterAuthority(ctx context.Context, clusterID string) (ClusterAuthorityKey, error)
EnsureClusterAuthority(ctx context.Context, clusterID string, actorUserID *string) (ClusterAuthorityKey, error)
ListClusterNodes(ctx context.Context, clusterID string) ([]ClusterNode, error)
ListNodeGroups(ctx context.Context, clusterID string) ([]ClusterNodeGroup, error)
CreateNodeGroup(ctx context.Context, input CreateNodeGroupInput) (ClusterNodeGroup, error)
AssignNodeToGroup(ctx context.Context, input AssignNodeGroupInput) (ClusterNode, error)
CreateJoinToken(ctx context.Context, input CreateJoinTokenInput, tokenHash string) (NodeJoinToken, error)
SetJoinTokenAuthority(ctx context.Context, clusterID, tokenID string, payload json.RawMessage, signature ClusterSignature) (NodeJoinToken, error)
GetValidJoinTokenByHash(ctx context.Context, clusterID, tokenHash string) (NodeJoinToken, error)
RevokeJoinToken(ctx context.Context, input RevokeJoinTokenInput) (NodeJoinToken, error)
ExpireJoinTokens(ctx context.Context, clusterID string) error
CreateJoinRequest(ctx context.Context, input CreateJoinRequestInput, joinTokenID string) (NodeJoinRequest, error)
GetJoinRequestForBootstrap(ctx context.Context, input GetJoinRequestBootstrapInput) (NodeJoinRequest, error)
ListJoinRequests(ctx context.Context, clusterID string) ([]NodeJoinRequest, error)
ApproveJoinRequest(ctx context.Context, input ApproveJoinRequestInput) (ApprovedJoinRequest, error)
SetJoinRequestApprovalAuthority(ctx context.Context, clusterID, joinRequestID string, payload json.RawMessage, signature ClusterSignature) (NodeJoinRequest, error)
RejectJoinRequest(ctx context.Context, input RejectJoinRequestInput) (NodeJoinRequest, error)
AssignNodeRole(ctx context.Context, input AssignNodeRoleInput) (NodeRoleAssignment, error)
ListNodeRoleAssignments(ctx context.Context, clusterID, nodeID string) ([]NodeRoleAssignment, error)
AttachExistingNodeToCluster(ctx context.Context, input AttachExistingNodeInput) (ClusterNode, error)
RecordHeartbeat(ctx context.Context, input RecordHeartbeatInput) (NodeHeartbeat, error)
ListNodeHeartbeats(ctx context.Context, clusterID, nodeID string, limit int) ([]NodeHeartbeat, error)
RevokeNodeIdentity(ctx context.Context, input RevokeNodeIdentityInput) error
DisableClusterMembership(ctx context.Context, input DisableMembershipInput) error
UpsertFabricTestingFlag(ctx context.Context, input UpsertFabricTestingFlagInput) (FabricTestingFlag, error)
ListFabricTestingFlags(ctx context.Context) ([]FabricTestingFlag, error)
GetEffectiveNodeTestingFlags(ctx context.Context, clusterID, nodeID string) (EffectiveNodeTestingFlags, error)
RecordNodeTelemetry(ctx context.Context, input RecordNodeTelemetryInput) (NodeTelemetryObservation, error)
ListNodeTelemetry(ctx context.Context, clusterID, nodeID string, limit int) ([]NodeTelemetryObservation, error)
SetDesiredWorkload(ctx context.Context, input SetDesiredWorkloadInput) (NodeWorkloadDesiredState, error)
ListDesiredWorkloads(ctx context.Context, clusterID, nodeID string) ([]NodeWorkloadDesiredState, error)
ReportWorkloadStatus(ctx context.Context, input ReportWorkloadStatusInput) (NodeWorkloadStatus, error)
ListLatestWorkloadStatuses(ctx context.Context, clusterID, nodeID string) ([]NodeWorkloadStatus, error)
ReportMeshLink(ctx context.Context, input ReportMeshLinkInput) (MeshLinkObservation, error)
ListMeshLinks(ctx context.Context, clusterID string) ([]MeshLinkObservation, error)
CreateRouteIntent(ctx context.Context, input CreateRouteIntentInput) (MeshRouteIntent, error)
ListRouteIntents(ctx context.Context, clusterID string) ([]MeshRouteIntent, error)
ListQoSPolicies(ctx context.Context, clusterID string) ([]MeshQoSPolicy, error)
ListFabricEntryPoints(ctx context.Context, clusterID string) ([]FabricEntryPoint, error)
CreateFabricEntryPoint(ctx context.Context, input CreateFabricEntryPointInput) (FabricEntryPoint, error)
SetFabricEntryPointNode(ctx context.Context, input SetFabricEntryPointNodeInput) (FabricEntryPointNode, error)
ListFabricEntryPointNodes(ctx context.Context, clusterID, entryPointID string) ([]FabricEntryPointNode, error)
ListFabricEgressPools(ctx context.Context, clusterID string) ([]FabricEgressPool, error)
CreateFabricEgressPool(ctx context.Context, input CreateFabricEgressPoolInput) (FabricEgressPool, error)
SetFabricEgressPoolNode(ctx context.Context, input SetFabricEgressPoolNodeInput) (FabricEgressPoolNode, error)
ListFabricEgressPoolNodes(ctx context.Context, clusterID, egressPoolID string) ([]FabricEgressPoolNode, error)
GetClusterAuthorityState(ctx context.Context, clusterID string) (ClusterAuthorityState, error)
UpdateClusterAuthorityState(ctx context.Context, input UpdateClusterAuthorityInput) (ClusterAuthorityState, error)
ListClusterAdminSummaries(ctx context.Context) ([]ClusterAdminSummary, error)
CreateVPNConnection(ctx context.Context, input CreateVPNConnectionInput) (VPNConnection, error)
ListVPNConnections(ctx context.Context, clusterID string) ([]VPNConnection, error)
GetVPNConnection(ctx context.Context, clusterID, vpnConnectionID string) (VPNConnection, error)
UpdateVPNConnectionDesiredState(ctx context.Context, input UpdateVPNConnectionDesiredStateInput) (VPNConnection, error)
UpsertVPNConnectionRoutePolicy(ctx context.Context, input UpsertVPNConnectionRoutePolicyInput) (VPNConnectionRoutePolicy, error)
ListVPNConnectionRoutePolicies(ctx context.Context, clusterID, vpnConnectionID string) ([]VPNConnectionRoutePolicy, error)
SetVPNConnectionAllowedNodes(ctx context.Context, input SetVPNConnectionAllowedNodesInput) ([]VPNConnectionAllowedNode, error)
ListVPNConnectionAllowedNodes(ctx context.Context, clusterID, vpnConnectionID string) ([]VPNConnectionAllowedNode, error)
AcquireVPNConnectionLease(ctx context.Context, input AcquireVPNConnectionLeaseInput, expiresAt time.Time, fencingToken string) (VPNConnectionLease, error)
RenewVPNConnectionLease(ctx context.Context, input RenewVPNConnectionLeaseInput, expiresAt time.Time) (VPNConnectionLease, error)
ReleaseVPNConnectionLease(ctx context.Context, input ReleaseVPNConnectionLeaseInput) (VPNConnectionLease, error)
FenceVPNConnectionLease(ctx context.Context, input FenceVPNConnectionLeaseInput) (VPNConnectionLease, error)
GetActiveVPNConnectionLease(ctx context.Context, clusterID, vpnConnectionID string) (VPNConnectionLease, error)
CheckVPNLeaseOwnerEligibility(ctx context.Context, clusterID, vpnConnectionID, ownerNodeID string) (VPNLeaseOwnerEligibility, error)
ExpireStaleVPNConnectionLeases(ctx context.Context, clusterID string, now time.Time) ([]VPNConnectionLease, error)
ListNodeVPNAssignments(ctx context.Context, clusterID, nodeID string) ([]NodeVPNAssignment, error)
ReportNodeVPNAssignmentStatus(ctx context.Context, input ReportNodeVPNAssignmentStatusInput) (NodeVPNAssignmentStatus, error)
RecordAudit(ctx context.Context, event ClusterAuditEvent) error
ListAuditEvents(ctx context.Context, clusterID string, limit int) ([]ClusterAuditEvent, error)
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+43
View File
@@ -0,0 +1,43 @@
package cluster
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"strings"
"time"
)
const joinTokenHashPrefix = "sha256:"
func generateJoinToken() (string, error) {
var random [32]byte
if _, err := rand.Read(random[:]); err != nil {
return "", err
}
return "rap_join_" + base64.RawURLEncoding.EncodeToString(random[:]), nil
}
func hashJoinToken(token string) (string, error) {
trimmed := strings.TrimSpace(token)
if trimmed == "" {
return "", errors.New("join token is required")
}
sum := sha256.Sum256([]byte(trimmed))
return joinTokenHashPrefix + hex.EncodeToString(sum[:]), nil
}
func isPlatformAdminRole(role string) bool {
return role == PlatformRoleAdmin || role == PlatformRoleRecoveryAdmin
}
func isAllowedNodeRole(role string) bool {
_, ok := allowedNodeRoles[role]
return ok
}
func defaultJoinTokenExpiry(now time.Time) time.Time {
return now.Add(30 * time.Minute)
}
@@ -0,0 +1,344 @@
package identitysource
import (
"context"
"encoding/json"
"errors"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
"github.com/example/remote-access-platform/backend/internal/platform/module"
)
type Module struct {
db *pgxpool.Pool
}
type IdentitySource struct {
ID string `json:"id"`
OrganizationID string `json:"organization_id"`
Kind string `json:"kind"`
Name string `json:"name"`
Status string `json:"status"`
Config json.RawMessage `json:"config"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type IdentityMapping struct {
ID string `json:"id"`
IdentitySourceID string `json:"identity_source_id"`
MappingType string `json:"mapping_type"`
ExternalSelector json.RawMessage `json:"external_selector"`
InternalTarget json.RawMessage `json:"internal_target"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type upsertIdentitySourceRequest struct {
ActorUserID string `json:"actor_user_id"`
OrganizationID string `json:"organization_id"`
Kind string `json:"kind"`
Name string `json:"name"`
Status string `json:"status"`
Config json.RawMessage `json:"config"`
IdentityMappings []struct {
MappingType string `json:"mapping_type"`
ExternalSelector json.RawMessage `json:"external_selector"`
InternalTarget json.RawMessage `json:"internal_target"`
} `json:"identity_mappings"`
}
func NewModule(deps module.Dependencies) *Module {
return &Module{db: deps.Infra.DB}
}
func (m *Module) Name() string {
return "identitysource"
}
func (m *Module) RegisterRoutes(router chi.Router) {
router.Route("/identity-sources", func(r chi.Router) {
r.Get("/", m.listIdentitySources)
r.Post("/", m.createIdentitySource)
r.Get("/{identitySourceID}", m.getIdentitySource)
r.Put("/{identitySourceID}", m.updateIdentitySource)
})
}
func (m *Module) listIdentitySources(w http.ResponseWriter, r *http.Request) {
orgID := r.URL.Query().Get("organization_id")
if orgID == "" {
httpx.WriteError(w, http.StatusBadRequest, "organization_id is required")
return
}
rows, err := m.db.Query(r.Context(), `
SELECT id, organization_id, kind, name, status, config, created_at, updated_at
FROM identity_sources
WHERE organization_id = $1
ORDER BY created_at DESC
`, orgID)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer rows.Close()
var items []IdentitySource
for rows.Next() {
item, err := scanIdentitySource(rows)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
items = append(items, item)
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"identity_sources": items})
}
func (m *Module) getIdentitySource(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "identitySourceID")
item, err := m.getByID(r.Context(), id)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
httpx.WriteError(w, http.StatusNotFound, "identity source not found")
return
}
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
mappings, err := m.listMappings(r.Context(), id)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{
"identity_source": item,
"identity_mappings": mappings,
})
}
func (m *Module) createIdentitySource(w http.ResponseWriter, r *http.Request) {
req, err := decodeRequest(r)
if err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
now := time.Now().UTC()
item := IdentitySource{
ID: uuid.NewString(),
OrganizationID: req.OrganizationID,
Kind: req.Kind,
Name: req.Name,
Status: req.Status,
Config: req.Config,
CreatedAt: now,
UpdatedAt: now,
}
tx, err := m.db.Begin(r.Context())
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer tx.Rollback(r.Context())
if _, err := tx.Exec(r.Context(), `
INSERT INTO identity_sources (id, organization_id, kind, name, status, config, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6::jsonb, $7, $8)
`, item.ID, item.OrganizationID, item.Kind, item.Name, item.Status, []byte(item.Config), item.CreatedAt, item.UpdatedAt); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
mappings, err := upsertMappings(r.Context(), tx, item.ID, req.IdentityMappings)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := tx.Commit(r.Context()); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusCreated, map[string]any{
"identity_source": item,
"identity_mappings": mappings,
})
}
func (m *Module) updateIdentitySource(w http.ResponseWriter, r *http.Request) {
req, err := decodeRequest(r)
if err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
id := chi.URLParam(r, "identitySourceID")
tx, err := m.db.Begin(r.Context())
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer tx.Rollback(r.Context())
tag, err := tx.Exec(r.Context(), `
UPDATE identity_sources
SET organization_id = $2, kind = $3, name = $4, status = $5, config = $6::jsonb, updated_at = $7
WHERE id = $1
`, id, req.OrganizationID, req.Kind, req.Name, req.Status, []byte(req.Config), time.Now().UTC())
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if tag.RowsAffected() == 0 {
httpx.WriteError(w, http.StatusNotFound, "identity source not found")
return
}
if _, err := tx.Exec(r.Context(), `DELETE FROM identity_mappings WHERE identity_source_id = $1`, id); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
mappings, err := upsertMappings(r.Context(), tx, id, req.IdentityMappings)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := tx.Commit(r.Context()); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
item, err := m.getByID(r.Context(), id)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{
"identity_source": item,
"identity_mappings": mappings,
})
}
func decodeRequest(r *http.Request) (*upsertIdentitySourceRequest, error) {
var req upsertIdentitySourceRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.New("invalid identity source payload")
}
if req.ActorUserID == "" || req.OrganizationID == "" || req.Kind == "" || req.Name == "" {
return nil, errors.New("actor_user_id, organization_id, kind, and name are required")
}
if req.Status == "" {
req.Status = "active"
}
if len(req.Config) == 0 {
req.Config = json.RawMessage(`{}`)
}
if !json.Valid(req.Config) {
return nil, errors.New("config must be valid json")
}
for _, mapping := range req.IdentityMappings {
if len(mapping.ExternalSelector) == 0 {
mapping.ExternalSelector = json.RawMessage(`{}`)
}
if len(mapping.InternalTarget) == 0 {
mapping.InternalTarget = json.RawMessage(`{}`)
}
}
return &req, nil
}
func (m *Module) getByID(ctx context.Context, id string) (IdentitySource, error) {
row := m.db.QueryRow(ctx, `
SELECT id, organization_id, kind, name, status, config, created_at, updated_at
FROM identity_sources
WHERE id = $1
`, id)
return scanIdentitySource(row)
}
func (m *Module) listMappings(ctx context.Context, sourceID string) ([]IdentityMapping, error) {
rows, err := m.db.Query(ctx, `
SELECT id, identity_source_id, mapping_type, external_selector, internal_target, created_at, updated_at
FROM identity_mappings
WHERE identity_source_id = $1
ORDER BY created_at ASC
`, sourceID)
if err != nil {
return nil, err
}
defer rows.Close()
var mappings []IdentityMapping
for rows.Next() {
item, err := scanIdentityMapping(rows)
if err != nil {
return nil, err
}
mappings = append(mappings, item)
}
return mappings, rows.Err()
}
func upsertMappings(ctx context.Context, tx pgx.Tx, sourceID string, requested []struct {
MappingType string `json:"mapping_type"`
ExternalSelector json.RawMessage `json:"external_selector"`
InternalTarget json.RawMessage `json:"internal_target"`
}) ([]IdentityMapping, error) {
now := time.Now().UTC()
items := make([]IdentityMapping, 0, len(requested))
for _, mapping := range requested {
external := mapping.ExternalSelector
if len(external) == 0 {
external = json.RawMessage(`{}`)
}
internal := mapping.InternalTarget
if len(internal) == 0 {
internal = json.RawMessage(`{}`)
}
item := IdentityMapping{
ID: uuid.NewString(),
IdentitySourceID: sourceID,
MappingType: mapping.MappingType,
ExternalSelector: external,
InternalTarget: internal,
CreatedAt: now,
UpdatedAt: now,
}
if _, err := tx.Exec(ctx, `
INSERT INTO identity_mappings (
id, identity_source_id, mapping_type, external_selector, internal_target, created_at, updated_at
) VALUES ($1, $2, $3, $4::jsonb, $5::jsonb, $6, $7)
`, item.ID, item.IdentitySourceID, item.MappingType, []byte(item.ExternalSelector), []byte(item.InternalTarget), item.CreatedAt, item.UpdatedAt); err != nil {
return nil, err
}
items = append(items, item)
}
return items, nil
}
type rowScanner interface {
Scan(dest ...any) error
}
func scanIdentitySource(row rowScanner) (IdentitySource, error) {
var item IdentitySource
if err := row.Scan(&item.ID, &item.OrganizationID, &item.Kind, &item.Name, &item.Status, &item.Config, &item.CreatedAt, &item.UpdatedAt); err != nil {
return IdentitySource{}, err
}
if len(item.Config) == 0 {
item.Config = json.RawMessage(`{}`)
}
return item, nil
}
func scanIdentityMapping(row rowScanner) (IdentityMapping, error) {
var item IdentityMapping
if err := row.Scan(&item.ID, &item.IdentitySourceID, &item.MappingType, &item.ExternalSelector, &item.InternalTarget, &item.CreatedAt, &item.UpdatedAt); err != nil {
return IdentityMapping{}, err
}
if len(item.ExternalSelector) == 0 {
item.ExternalSelector = json.RawMessage(`{}`)
}
if len(item.InternalTarget) == 0 {
item.InternalTarget = json.RawMessage(`{}`)
}
return item, nil
}
+458
View File
@@ -0,0 +1,458 @@
package node
import (
"context"
"encoding/json"
"errors"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
"github.com/example/remote-access-platform/backend/internal/platform/module"
)
type Module struct {
db *pgxpool.Pool
}
type Node struct {
ID string `json:"id"`
OwnerOrganizationID *string `json:"owner_organization_id,omitempty"`
NodeKey string `json:"node_key"`
Name string `json:"name"`
OwnershipType string `json:"ownership_type"`
RegistrationStatus string `json:"registration_status"`
HealthStatus string `json:"health_status"`
VersionState string `json:"version_state"`
PartitionState string `json:"partition_state"`
DesiredVersion *string `json:"desired_version,omitempty"`
ReportedVersion *string `json:"reported_version,omitempty"`
LastSeenAt *time.Time `json:"last_seen_at,omitempty"`
Metadata json.RawMessage `json:"metadata"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type NodeCapability struct {
NodeID string `json:"node_id"`
Capability string `json:"capability"`
Value json.RawMessage `json:"value"`
UpdatedAt time.Time `json:"updated_at"`
}
type NodeService struct {
NodeID string `json:"node_id"`
ServiceType string `json:"service_type"`
Enabled bool `json:"enabled"`
DesiredState string `json:"desired_state"`
ReportedState string `json:"reported_state"`
LastReportedAt *time.Time `json:"last_reported_at,omitempty"`
Metadata json.RawMessage `json:"metadata"`
UpdatedAt time.Time `json:"updated_at"`
}
type NodeUpdatePolicy struct {
NodeID string `json:"node_id"`
Mode string `json:"mode"`
Channel string `json:"channel"`
MaintenanceWindow json.RawMessage `json:"maintenance_window"`
Canary bool `json:"canary"`
AutomaticRollout bool `json:"automatic_rollout"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type NodePartitionState struct {
NodeID string `json:"node_id"`
ClusterState string `json:"cluster_state"`
RecoveryMode string `json:"recovery_mode"`
Notes *string `json:"notes,omitempty"`
UpdatedAt time.Time `json:"updated_at"`
}
type upsertNodeRequest struct {
ActorUserID string `json:"actor_user_id"`
OwnerOrganizationID *string `json:"owner_organization_id"`
NodeKey string `json:"node_key"`
Name string `json:"name"`
OwnershipType string `json:"ownership_type"`
DesiredVersion *string `json:"desired_version"`
Metadata json.RawMessage `json:"metadata"`
Capabilities []struct {
Capability string `json:"capability"`
Value json.RawMessage `json:"value"`
} `json:"capabilities"`
Services []struct {
ServiceType string `json:"service_type"`
Enabled bool `json:"enabled"`
DesiredState string `json:"desired_state"`
Metadata json.RawMessage `json:"metadata"`
} `json:"services"`
UpdatePolicy struct {
Mode string `json:"mode"`
Channel string `json:"channel"`
MaintenanceWindow json.RawMessage `json:"maintenance_window"`
Canary bool `json:"canary"`
AutomaticRollout bool `json:"automatic_rollout"`
} `json:"update_policy"`
PartitionState struct {
ClusterState string `json:"cluster_state"`
RecoveryMode string `json:"recovery_mode"`
Notes *string `json:"notes"`
} `json:"partition_state"`
}
func NewModule(deps module.Dependencies) *Module {
return &Module{db: deps.Infra.DB}
}
func (m *Module) Name() string {
return "node"
}
func (m *Module) RegisterRoutes(router chi.Router) {
router.Route("/nodes", func(r chi.Router) {
r.Get("/", m.listNodes)
r.Post("/", m.createNode)
r.Get("/{nodeID}", m.getNode)
r.Put("/{nodeID}", m.updateNode)
})
}
func (m *Module) listNodes(w http.ResponseWriter, r *http.Request) {
rows, err := m.db.Query(r.Context(), `
SELECT id, owner_organization_id, node_key, name, ownership_type, registration_status, health_status,
version_state, partition_state, desired_version, reported_version, last_seen_at, metadata, created_at, updated_at
FROM nodes
ORDER BY created_at DESC
`)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer rows.Close()
var items []Node
for rows.Next() {
item, err := scanNode(rows)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
items = append(items, item)
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"nodes": items})
}
func (m *Module) getNode(w http.ResponseWriter, r *http.Request) {
nodeID := chi.URLParam(r, "nodeID")
item, err := m.getNodeByID(r.Context(), nodeID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
httpx.WriteError(w, http.StatusNotFound, "node not found")
return
}
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
caps, _ := m.listCapabilities(r.Context(), nodeID)
services, _ := m.listServices(r.Context(), nodeID)
updatePolicy, _ := m.getUpdatePolicy(r.Context(), nodeID)
partitionState, _ := m.getPartitionState(r.Context(), nodeID)
httpx.WriteJSON(w, http.StatusOK, map[string]any{
"node": item,
"capabilities": caps,
"services": services,
"update_policy": updatePolicy,
"partition_state": partitionState,
})
}
func (m *Module) createNode(w http.ResponseWriter, r *http.Request) {
req, err := decodeNodeRequest(r)
if err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
item := Node{
ID: uuid.NewString(),
OwnerOrganizationID: req.OwnerOrganizationID,
NodeKey: req.NodeKey,
Name: req.Name,
OwnershipType: req.OwnershipType,
RegistrationStatus: "pending",
HealthStatus: "unknown",
VersionState: "unknown",
PartitionState: "healthy",
DesiredVersion: req.DesiredVersion,
Metadata: req.Metadata,
CreatedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(),
}
if err := m.persistNode(r.Context(), item, req, true); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusCreated, map[string]any{"node": item})
}
func (m *Module) updateNode(w http.ResponseWriter, r *http.Request) {
req, err := decodeNodeRequest(r)
if err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
nodeID := chi.URLParam(r, "nodeID")
item, err := m.getNodeByID(r.Context(), nodeID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
httpx.WriteError(w, http.StatusNotFound, "node not found")
return
}
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
item.OwnerOrganizationID = req.OwnerOrganizationID
item.NodeKey = req.NodeKey
item.Name = req.Name
item.OwnershipType = req.OwnershipType
item.DesiredVersion = req.DesiredVersion
item.Metadata = req.Metadata
item.UpdatedAt = time.Now().UTC()
if err := m.persistNode(r.Context(), item, req, false); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"node": item})
}
func decodeNodeRequest(r *http.Request) (*upsertNodeRequest, error) {
var req upsertNodeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.New("invalid node payload")
}
if req.ActorUserID == "" || req.NodeKey == "" || req.Name == "" || req.OwnershipType == "" {
return nil, errors.New("actor_user_id, node_key, name, and ownership_type are required")
}
if len(req.Metadata) == 0 {
req.Metadata = json.RawMessage(`{}`)
}
if !json.Valid(req.Metadata) {
return nil, errors.New("metadata must be valid json")
}
if req.UpdatePolicy.Mode == "" {
req.UpdatePolicy.Mode = "manual"
}
if req.UpdatePolicy.Channel == "" {
req.UpdatePolicy.Channel = "stable"
}
if len(req.UpdatePolicy.MaintenanceWindow) == 0 {
req.UpdatePolicy.MaintenanceWindow = json.RawMessage(`{}`)
}
if req.PartitionState.ClusterState == "" {
req.PartitionState.ClusterState = "healthy"
}
if req.PartitionState.RecoveryMode == "" {
req.PartitionState.RecoveryMode = "normal"
}
for i := range req.Capabilities {
if len(req.Capabilities[i].Value) == 0 {
req.Capabilities[i].Value = json.RawMessage(`{}`)
}
}
for i := range req.Services {
if req.Services[i].DesiredState == "" {
req.Services[i].DesiredState = "disabled"
}
if len(req.Services[i].Metadata) == 0 {
req.Services[i].Metadata = json.RawMessage(`{}`)
}
}
return &req, nil
}
func (m *Module) persistNode(ctx context.Context, item Node, req *upsertNodeRequest, create bool) error {
tx, err := m.db.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)
if create {
_, err = tx.Exec(ctx, `
INSERT INTO nodes (
id, owner_organization_id, node_key, name, ownership_type, registration_status, health_status,
version_state, partition_state, desired_version, reported_version, last_seen_at, metadata, created_at, updated_at
) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13::jsonb,$14,$15)
`, item.ID, item.OwnerOrganizationID, item.NodeKey, item.Name, item.OwnershipType, item.RegistrationStatus, item.HealthStatus, item.VersionState, item.PartitionState, item.DesiredVersion, item.ReportedVersion, item.LastSeenAt, []byte(item.Metadata), item.CreatedAt, item.UpdatedAt)
} else {
_, err = tx.Exec(ctx, `
UPDATE nodes
SET owner_organization_id=$2, node_key=$3, name=$4, ownership_type=$5, desired_version=$6, metadata=$7::jsonb, updated_at=$8
WHERE id=$1
`, item.ID, item.OwnerOrganizationID, item.NodeKey, item.Name, item.OwnershipType, item.DesiredVersion, []byte(item.Metadata), item.UpdatedAt)
}
if err != nil {
return err
}
if _, err := tx.Exec(ctx, `DELETE FROM node_capabilities WHERE node_id = $1`, item.ID); err != nil {
return err
}
for _, capability := range req.Capabilities {
if _, err := tx.Exec(ctx, `
INSERT INTO node_capabilities (node_id, capability, value, updated_at)
VALUES ($1, $2, $3::jsonb, $4)
`, item.ID, capability.Capability, []byte(capability.Value), time.Now().UTC()); err != nil {
return err
}
}
if _, err := tx.Exec(ctx, `DELETE FROM node_services WHERE node_id = $1`, item.ID); err != nil {
return err
}
for _, service := range req.Services {
if _, err := tx.Exec(ctx, `
INSERT INTO node_services (
node_id, service_type, enabled, desired_state, reported_state, last_reported_at, metadata, updated_at
) VALUES ($1, $2, $3, $4, 'unknown', NULL, $5::jsonb, $6)
`, item.ID, service.ServiceType, service.Enabled, service.DesiredState, []byte(service.Metadata), time.Now().UTC()); err != nil {
return err
}
}
if _, err := tx.Exec(ctx, `
INSERT INTO node_update_policies (
node_id, mode, channel, maintenance_window, canary, automatic_rollout, created_at, updated_at
) VALUES ($1,$2,$3,$4::jsonb,$5,$6,$7,$8)
ON CONFLICT (node_id) DO UPDATE SET
mode = EXCLUDED.mode,
channel = EXCLUDED.channel,
maintenance_window = EXCLUDED.maintenance_window,
canary = EXCLUDED.canary,
automatic_rollout = EXCLUDED.automatic_rollout,
updated_at = EXCLUDED.updated_at
`, item.ID, req.UpdatePolicy.Mode, req.UpdatePolicy.Channel, []byte(req.UpdatePolicy.MaintenanceWindow), req.UpdatePolicy.Canary, req.UpdatePolicy.AutomaticRollout, time.Now().UTC(), time.Now().UTC()); err != nil {
return err
}
if _, err := tx.Exec(ctx, `
INSERT INTO node_partition_states (node_id, cluster_state, recovery_mode, notes, updated_at)
VALUES ($1,$2,$3,$4,$5)
ON CONFLICT (node_id) DO UPDATE SET
cluster_state = EXCLUDED.cluster_state,
recovery_mode = EXCLUDED.recovery_mode,
notes = EXCLUDED.notes,
updated_at = EXCLUDED.updated_at
`, item.ID, req.PartitionState.ClusterState, req.PartitionState.RecoveryMode, req.PartitionState.Notes, time.Now().UTC()); err != nil {
return err
}
return tx.Commit(ctx)
}
func (m *Module) getNodeByID(ctx context.Context, nodeID string) (Node, error) {
row := m.db.QueryRow(ctx, `
SELECT id, owner_organization_id, node_key, name, ownership_type, registration_status, health_status,
version_state, partition_state, desired_version, reported_version, last_seen_at, metadata, created_at, updated_at
FROM nodes
WHERE id = $1
`, nodeID)
return scanNode(row)
}
func (m *Module) listCapabilities(ctx context.Context, nodeID string) ([]NodeCapability, error) {
rows, err := m.db.Query(ctx, `SELECT node_id, capability, value, updated_at FROM node_capabilities WHERE node_id = $1 ORDER BY capability`, nodeID)
if err != nil {
return nil, err
}
defer rows.Close()
var out []NodeCapability
for rows.Next() {
var item NodeCapability
if err := rows.Scan(&item.NodeID, &item.Capability, &item.Value, &item.UpdatedAt); err != nil {
return nil, err
}
out = append(out, item)
}
return out, rows.Err()
}
func (m *Module) listServices(ctx context.Context, nodeID string) ([]NodeService, error) {
rows, err := m.db.Query(ctx, `
SELECT node_id, service_type, enabled, desired_state, reported_state, last_reported_at, metadata, updated_at
FROM node_services WHERE node_id = $1 ORDER BY service_type
`, nodeID)
if err != nil {
return nil, err
}
defer rows.Close()
var out []NodeService
for rows.Next() {
var item NodeService
if err := rows.Scan(&item.NodeID, &item.ServiceType, &item.Enabled, &item.DesiredState, &item.ReportedState, &item.LastReportedAt, &item.Metadata, &item.UpdatedAt); err != nil {
return nil, err
}
out = append(out, item)
}
return out, rows.Err()
}
func (m *Module) getUpdatePolicy(ctx context.Context, nodeID string) (*NodeUpdatePolicy, error) {
row := m.db.QueryRow(ctx, `
SELECT node_id, mode, channel, maintenance_window, canary, automatic_rollout, created_at, updated_at
FROM node_update_policies WHERE node_id = $1
`, nodeID)
var item NodeUpdatePolicy
if err := row.Scan(&item.NodeID, &item.Mode, &item.Channel, &item.MaintenanceWindow, &item.Canary, &item.AutomaticRollout, &item.CreatedAt, &item.UpdatedAt); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return nil, err
}
return &item, nil
}
func (m *Module) getPartitionState(ctx context.Context, nodeID string) (*NodePartitionState, error) {
row := m.db.QueryRow(ctx, `
SELECT node_id, cluster_state, recovery_mode, notes, updated_at
FROM node_partition_states WHERE node_id = $1
`, nodeID)
var item NodePartitionState
if err := row.Scan(&item.NodeID, &item.ClusterState, &item.RecoveryMode, &item.Notes, &item.UpdatedAt); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return nil, err
}
return &item, nil
}
type rowScanner interface {
Scan(dest ...any) error
}
func scanNode(row rowScanner) (Node, error) {
var item Node
if err := row.Scan(
&item.ID,
&item.OwnerOrganizationID,
&item.NodeKey,
&item.Name,
&item.OwnershipType,
&item.RegistrationStatus,
&item.HealthStatus,
&item.VersionState,
&item.PartitionState,
&item.DesiredVersion,
&item.ReportedVersion,
&item.LastSeenAt,
&item.Metadata,
&item.CreatedAt,
&item.UpdatedAt,
); err != nil {
return Node{}, err
}
if len(item.Metadata) == 0 {
item.Metadata = json.RawMessage(`{}`)
}
return item, nil
}
@@ -0,0 +1,356 @@
package nodeagent
import (
"encoding/json"
"errors"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
clustermodule "github.com/example/remote-access-platform/backend/internal/modules/cluster"
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
"github.com/example/remote-access-platform/backend/internal/platform/module"
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
)
type Module struct {
db *pgxpool.Pool
cluster *clustermodule.Service
}
func NewModule(deps module.Dependencies) *Module {
clusterStore := clustermodule.NewPostgresStore(deps.Infra.DB)
if deps.Config.Secret.EncryptionKeyBase64 != "" {
if encryptor, err := secrets.NewEncryptor(deps.Config.Secret.EncryptionKeyBase64, deps.Config.Secret.EncryptionKeyID); err == nil {
clusterStore.WithClusterKeyEncryptor(encryptor)
}
}
return &Module{
db: deps.Infra.DB,
cluster: clustermodule.NewService(clusterStore),
}
}
func (m *Module) Name() string {
return "nodeagent"
}
func (m *Module) RegisterRoutes(router chi.Router) {
router.Route("/node-agents", func(r chi.Router) {
r.Post("/enroll", m.enrollAgent)
r.Post("/enrollments/{requestID}/bootstrap", m.bootstrapEnrollment)
r.Post("/register", m.registerAgent)
r.Post("/{nodeID}/health", m.reportHealth)
r.Post("/{nodeID}/services/status", m.reportServiceStatus)
r.Post("/{nodeID}/update-manifest/request", m.requestUpdateManifest)
r.Post("/{nodeID}/update-result", m.acknowledgeUpdateResult)
r.Post("/{nodeID}/rollback-result", m.reportRollbackResult)
r.Get("/{nodeID}/clusters/{clusterID}/vpn-assignments/desired", m.listVPNAssignments)
r.Post("/{nodeID}/clusters/{clusterID}/vpn-assignments/{vpnConnectionID}/status", m.reportVPNAssignmentStatus)
})
}
func (m *Module) enrollAgent(w http.ResponseWriter, r *http.Request) {
var payload struct {
ClusterID string `json:"cluster_id"`
JoinToken string `json:"join_token"`
NodeName string `json:"node_name"`
NodeFingerprint string `json:"node_fingerprint"`
PublicKey string `json:"public_key"`
ReportedCapabilities json.RawMessage `json:"reported_capabilities"`
ReportedFacts json.RawMessage `json:"reported_facts"`
RequestedRoles json.RawMessage `json:"requested_roles"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid agent enrollment payload")
return
}
joinRequest, err := m.cluster.CreateJoinRequest(r.Context(), clustermodule.CreateJoinRequestInput{
ClusterID: payload.ClusterID,
JoinToken: payload.JoinToken,
NodeName: payload.NodeName,
NodeFingerprint: payload.NodeFingerprint,
PublicKey: payload.PublicKey,
ReportedCapabilities: payload.ReportedCapabilities,
ReportedFacts: payload.ReportedFacts,
RequestedRoles: payload.RequestedRoles,
})
if err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
"status": "pending_approval",
"join_request": joinRequest,
})
}
func (m *Module) bootstrapEnrollment(w http.ResponseWriter, r *http.Request) {
var payload struct {
ClusterID string `json:"cluster_id"`
NodeFingerprint string `json:"node_fingerprint"`
PublicKey string `json:"public_key"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid enrollment bootstrap payload")
return
}
result, err := m.cluster.GetJoinRequestBootstrap(r.Context(), clustermodule.GetJoinRequestBootstrapInput{
ClusterID: payload.ClusterID,
JoinRequestID: chi.URLParam(r, "requestID"),
NodeFingerprint: payload.NodeFingerprint,
PublicKey: payload.PublicKey,
})
if err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, result)
}
func (m *Module) registerAgent(w http.ResponseWriter, r *http.Request) {
var payload struct {
NodeKey string `json:"node_key"`
Name string `json:"name"`
OwnershipType string `json:"ownership_type"`
OwnerOrganizationID *string `json:"owner_organization_id"`
ReportedVersion *string `json:"reported_version"`
Metadata json.RawMessage `json:"metadata"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid agent registration payload")
return
}
if payload.NodeKey == "" || payload.Name == "" || payload.OwnershipType == "" {
httpx.WriteError(w, http.StatusBadRequest, "node_key, name, and ownership_type are required")
return
}
if len(payload.Metadata) == 0 {
payload.Metadata = json.RawMessage(`{}`)
}
now := time.Now().UTC()
nodeID := uuid.NewString()
if err := m.db.QueryRow(r.Context(), `
INSERT INTO nodes (
id, owner_organization_id, node_key, name, ownership_type, registration_status, health_status,
version_state, partition_state, desired_version, reported_version, last_seen_at, metadata, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, 'active', 'unknown', 'unknown', 'healthy', NULL, $6, $7, $8::jsonb, $9, $10)
ON CONFLICT (node_key) DO UPDATE SET
name = EXCLUDED.name,
ownership_type = EXCLUDED.ownership_type,
owner_organization_id = EXCLUDED.owner_organization_id,
registration_status = 'active',
reported_version = EXCLUDED.reported_version,
last_seen_at = EXCLUDED.last_seen_at,
metadata = EXCLUDED.metadata,
updated_at = EXCLUDED.updated_at
RETURNING id
`, nodeID, payload.OwnerOrganizationID, payload.NodeKey, payload.Name, payload.OwnershipType, payload.ReportedVersion, now, []byte(payload.Metadata), now, now).Scan(&nodeID); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{
"node_id": nodeID,
"status": "registered",
"legacy": true,
"warning": "direct node-agent registration is retained for compatibility; production enrollment must use /node-agents/enroll",
})
}
func (m *Module) reportHealth(w http.ResponseWriter, r *http.Request) {
var payload struct {
HealthStatus string `json:"health_status"`
ReportedVersion *string `json:"reported_version"`
Metadata json.RawMessage `json:"metadata"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid node health payload")
return
}
if payload.HealthStatus == "" {
payload.HealthStatus = "unknown"
}
if len(payload.Metadata) == 0 {
payload.Metadata = json.RawMessage(`{}`)
}
if _, err := m.db.Exec(r.Context(), `
UPDATE nodes
SET health_status = $2, reported_version = COALESCE($3, reported_version), last_seen_at = $4, metadata = $5::jsonb, updated_at = $4
WHERE id = $1
`, chi.URLParam(r, "nodeID"), payload.HealthStatus, payload.ReportedVersion, time.Now().UTC(), []byte(payload.Metadata)); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{"status": "accepted"})
}
func (m *Module) reportServiceStatus(w http.ResponseWriter, r *http.Request) {
var payload struct {
Services []struct {
ServiceType string `json:"service_type"`
ReportedState string `json:"reported_state"`
Metadata json.RawMessage `json:"metadata"`
} `json:"services"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid node service status payload")
return
}
now := time.Now().UTC()
for _, service := range payload.Services {
if len(service.Metadata) == 0 {
service.Metadata = json.RawMessage(`{}`)
}
if _, err := m.db.Exec(r.Context(), `
INSERT INTO node_services (
node_id, service_type, enabled, desired_state, reported_state, last_reported_at, metadata, updated_at
) VALUES ($1, $2, FALSE, 'disabled', $3, $4, $5::jsonb, $4)
ON CONFLICT (node_id, service_type) DO UPDATE SET
reported_state = EXCLUDED.reported_state,
last_reported_at = EXCLUDED.last_reported_at,
metadata = EXCLUDED.metadata,
updated_at = EXCLUDED.updated_at
`, chi.URLParam(r, "nodeID"), service.ServiceType, service.ReportedState, now, []byte(service.Metadata)); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
}
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{"status": "accepted"})
}
func (m *Module) listVPNAssignments(w http.ResponseWriter, r *http.Request) {
items, err := m.cluster.ListNodeVPNAssignments(r.Context(), chi.URLParam(r, "clusterID"), chi.URLParam(r, "nodeID"))
if writeClusterServiceError(w, err) {
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{
"vpn_assignments": items,
"runtime_execution_enabled": false,
})
}
func (m *Module) reportVPNAssignmentStatus(w http.ResponseWriter, r *http.Request) {
var payload struct {
ObservedStatus string `json:"observed_status"`
StatusPayload json.RawMessage `json:"status_payload"`
ObservedAt *time.Time `json:"observed_at"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid vpn assignment status payload")
return
}
observedAt := time.Time{}
if payload.ObservedAt != nil {
observedAt = *payload.ObservedAt
}
item, err := m.cluster.ReportNodeVPNAssignmentStatus(r.Context(), clustermodule.ReportNodeVPNAssignmentStatusInput{
ClusterID: chi.URLParam(r, "clusterID"),
NodeID: chi.URLParam(r, "nodeID"),
VPNConnectionID: chi.URLParam(r, "vpnConnectionID"),
ObservedStatus: payload.ObservedStatus,
StatusPayload: payload.StatusPayload,
ObservedAt: observedAt,
})
if writeClusterServiceError(w, err) {
return
}
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
"vpn_assignment_status": item,
"runtime_execution_enabled": false,
})
}
func (m *Module) requestUpdateManifest(w http.ResponseWriter, r *http.Request) {
nodeID := chi.URLParam(r, "nodeID")
var mode, channel string
var canary, automatic bool
var desiredVersion *string
if err := m.db.QueryRow(r.Context(), `
SELECT n.desired_version, p.mode, p.channel, p.canary, p.automatic_rollout
FROM nodes n
LEFT JOIN node_update_policies p ON p.node_id = n.id
WHERE n.id = $1
`, nodeID).Scan(&desiredVersion, &mode, &channel, &canary, &automatic); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{
"manifest": map[string]any{
"node_id": nodeID,
"desired_version": desiredVersion,
"mode": mode,
"channel": channel,
"canary": canary,
"automatic_rollout": automatic,
},
})
}
func (m *Module) acknowledgeUpdateResult(w http.ResponseWriter, r *http.Request) {
m.recordUpdateRun(w, r, "update")
}
func (m *Module) reportRollbackResult(w http.ResponseWriter, r *http.Request) {
m.recordUpdateRun(w, r, "rollback")
}
func (m *Module) recordUpdateRun(w http.ResponseWriter, r *http.Request, action string) {
var payload struct {
TargetVersion string `json:"target_version"`
Status string `json:"status"`
Payload json.RawMessage `json:"payload"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid update result payload")
return
}
if payload.Status == "" {
payload.Status = "acknowledged"
}
if len(payload.Payload) == 0 {
payload.Payload = json.RawMessage(`{}`)
}
now := time.Now().UTC()
runID := uuid.NewString()
if _, err := m.db.Exec(r.Context(), `
INSERT INTO node_agent_update_runs (
id, node_id, action, target_version, status, requested_at, acknowledged_at, completed_at, payload
) VALUES ($1, $2, $3, $4, $5, $6, $6, CASE WHEN $5 IN ('succeeded', 'failed') THEN $6 ELSE NULL END, $7::jsonb)
`, runID, chi.URLParam(r, "nodeID"), action, payload.TargetVersion, payload.Status, now, []byte(payload.Payload)); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if action == "update" && payload.Status == "succeeded" {
_, _ = m.db.Exec(r.Context(), `
UPDATE nodes
SET reported_version = $2, version_state = 'current', updated_at = $3
WHERE id = $1
`, chi.URLParam(r, "nodeID"), payload.TargetVersion, now)
}
if action == "rollback" && payload.Status == "succeeded" {
_, _ = m.db.Exec(r.Context(), `
UPDATE nodes
SET reported_version = $2, version_state = 'rollback', updated_at = $3
WHERE id = $1
`, chi.URLParam(r, "nodeID"), payload.TargetVersion, now)
}
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{"status": "accepted", "run_id": runID})
}
func writeClusterServiceError(w http.ResponseWriter, err error) bool {
if err == nil {
return false
}
switch {
case errors.Is(err, clustermodule.ErrVPNLeaseOwnerNotAllowed), errors.Is(err, clustermodule.ErrVPNLeaseOwnerRoleRequired):
httpx.WriteError(w, http.StatusForbidden, err.Error())
case errors.Is(err, clustermodule.ErrInvalidPayload):
httpx.WriteError(w, http.StatusBadRequest, err.Error())
default:
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
}
return true
}
@@ -0,0 +1,21 @@
package organization
import "testing"
func TestTenantSafeTopologyExposureDoesNotExposeCoreMesh(t *testing.T) {
value := tenantSafeTopologyExposure()
forbidden := []string{
"core_node_id",
"mesh_route",
"cluster_private_topology",
"certificate_serial",
}
for _, token := range forbidden {
if value == token {
t.Fatalf("topology exposure leaked forbidden token %q", token)
}
}
if value != "tenant_safe_no_core_mesh_topology" {
t.Fatalf("unexpected topology exposure marker: %q", value)
}
}
@@ -0,0 +1,518 @@
package organization
import (
"context"
"encoding/json"
"errors"
"net/http"
"strings"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/example/remote-access-platform/backend/internal/platform/authority"
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
"github.com/example/remote-access-platform/backend/internal/platform/module"
)
const (
RoleOrgOwner = "org_owner"
RoleOrgAdmin = "org_admin"
RoleOrgOperator = "org_operator"
RoleOrgMember = "org_member"
RoleOrgViewer = "org_viewer"
)
type Module struct {
db *pgxpool.Pool
authority *authority.Verifier
}
type Organization struct {
ID string `json:"id"`
Slug string `json:"slug"`
Name string `json:"name"`
Status string `json:"status"`
Metadata json.RawMessage `json:"metadata"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type Membership struct {
ID string `json:"id"`
OrganizationID string `json:"organization_id"`
UserID string `json:"user_id"`
RoleID string `json:"role_id"`
Status string `json:"status"`
InvitedByUser *string `json:"invited_by_user_id,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type AdminSummary struct {
OrganizationID string `json:"organization_id"`
ResourceCount int64 `json:"resource_count"`
ActiveSessionCount int64 `json:"active_session_count"`
ServiceEndpoints []ServiceSummary `json:"service_endpoints"`
ConnectorStatus map[string]any `json:"connector_status"`
RecentAudit []OrgAuditEvent `json:"recent_audit"`
TopologyExposure string `json:"topology_exposure"`
}
type ServiceSummary struct {
Protocol string `json:"protocol"`
Count int64 `json:"count"`
}
type OrgAuditEvent struct {
ID string `json:"id"`
EventType string `json:"event_type"`
TargetType string `json:"target_type"`
TargetID string `json:"target_id"`
Payload json.RawMessage `json:"payload"`
CreatedAt time.Time `json:"created_at"`
}
type createOrganizationRequest struct {
ActorUserID string `json:"actor_user_id"`
Slug string `json:"slug"`
Name string `json:"name"`
Metadata json.RawMessage `json:"metadata"`
}
type addMembershipRequest struct {
ActorUserID string `json:"actor_user_id"`
UserID string `json:"user_id"`
RoleID string `json:"role_id"`
}
func NewModule(deps module.Dependencies) *Module {
authorityVerifier, _ := authority.NewVerifier(deps.Config.Installation)
return &Module{db: deps.Infra.DB, authority: authorityVerifier}
}
func (m *Module) Name() string {
return "organization"
}
func (m *Module) RegisterRoutes(router chi.Router) {
router.Route("/organizations", func(r chi.Router) {
r.Get("/", m.listOrganizations)
r.Post("/", m.createOrganization)
r.Get("/{organizationID}", m.getOrganization)
r.Get("/{organizationID}/admin-summary", m.getAdminSummary)
r.Get("/{organizationID}/memberships", m.listMemberships)
r.Post("/{organizationID}/memberships", m.addMembership)
})
}
func (m *Module) listOrganizations(w http.ResponseWriter, r *http.Request) {
userID := r.URL.Query().Get("user_id")
if userID == "" {
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
return
}
platformRole, err := m.getPlatformRole(r.Context(), userID)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
var rows pgx.Rows
if isPlatformAdmin(platformRole) {
rows, err = m.db.Query(r.Context(), `
SELECT id, slug, name, status, metadata, created_at, updated_at
FROM organizations
ORDER BY created_at DESC
`)
} else {
rows, err = m.db.Query(r.Context(), `
SELECT o.id, o.slug, o.name, o.status, o.metadata, o.created_at, o.updated_at
FROM organizations o
INNER JOIN organization_memberships om ON om.organization_id = o.id
WHERE om.user_id = $1 AND om.status = 'active'
ORDER BY o.created_at DESC
`, userID)
}
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer rows.Close()
var organizations []Organization
for rows.Next() {
org, err := scanOrganization(rows)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
organizations = append(organizations, org)
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"organizations": organizations})
}
func (m *Module) getOrganization(w http.ResponseWriter, r *http.Request) {
orgID := chi.URLParam(r, "organizationID")
userID := r.URL.Query().Get("user_id")
if userID == "" {
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
return
}
if err := m.ensureOrgAccess(r.Context(), orgID, userID, false); err != nil {
status := http.StatusInternalServerError
if errors.Is(err, pgx.ErrNoRows) || errors.Is(err, errForbidden) {
status = http.StatusForbidden
}
httpx.WriteError(w, status, err.Error())
return
}
org, err := m.getOrganizationByID(r.Context(), orgID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
httpx.WriteError(w, http.StatusNotFound, "organization not found")
return
}
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"organization": org})
}
func (m *Module) getAdminSummary(w http.ResponseWriter, r *http.Request) {
orgID := chi.URLParam(r, "organizationID")
actorUserID := r.URL.Query().Get("actor_user_id")
if actorUserID == "" {
httpx.WriteError(w, http.StatusBadRequest, "actor_user_id is required")
return
}
if err := m.ensureOrgAccess(r.Context(), orgID, actorUserID, true); err != nil {
httpx.WriteError(w, http.StatusForbidden, err.Error())
return
}
summary, err := m.loadAdminSummary(r.Context(), orgID)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"admin_summary": summary})
}
func (m *Module) loadAdminSummary(ctx context.Context, orgID string) (AdminSummary, error) {
var resourceCount int64
if err := m.db.QueryRow(ctx, `
SELECT COUNT(*)
FROM resources
WHERE organization_id = $1::uuid
`, orgID).Scan(&resourceCount); err != nil {
return AdminSummary{}, err
}
var activeSessionCount int64
if err := m.db.QueryRow(ctx, `
SELECT COUNT(*)
FROM remote_sessions
WHERE organization_id = $1::uuid
AND state = 'active'
`, orgID).Scan(&activeSessionCount); err != nil {
return AdminSummary{}, err
}
rows, err := m.db.Query(ctx, `
SELECT protocol, COUNT(*)
FROM resources
WHERE organization_id = $1::uuid
GROUP BY protocol
ORDER BY protocol
`, orgID)
if err != nil {
return AdminSummary{}, err
}
defer rows.Close()
var services []ServiceSummary
for rows.Next() {
var item ServiceSummary
if err := rows.Scan(&item.Protocol, &item.Count); err != nil {
return AdminSummary{}, err
}
services = append(services, item)
}
if err := rows.Err(); err != nil {
return AdminSummary{}, err
}
auditRows, err := m.db.Query(ctx, `
SELECT ae.id::text, ae.event_type, ae.target_type, ae.target_id, ae.payload, ae.created_at
FROM audit_events ae
LEFT JOIN remote_sessions rs ON rs.id = ae.remote_session_id
WHERE rs.organization_id = $1::uuid
ORDER BY ae.created_at DESC
LIMIT 20
`, orgID)
if err != nil {
return AdminSummary{}, err
}
defer auditRows.Close()
var audit []OrgAuditEvent
for auditRows.Next() {
var item OrgAuditEvent
if err := auditRows.Scan(&item.ID, &item.EventType, &item.TargetType, &item.TargetID, &item.Payload, &item.CreatedAt); err != nil {
return AdminSummary{}, err
}
audit = append(audit, item)
}
if err := auditRows.Err(); err != nil {
return AdminSummary{}, err
}
return AdminSummary{
OrganizationID: orgID,
ResourceCount: resourceCount,
ActiveSessionCount: activeSessionCount,
ServiceEndpoints: services,
ConnectorStatus: map[string]any{
"vpn": "not_implemented",
"connector": "not_implemented",
},
RecentAudit: audit,
TopologyExposure: tenantSafeTopologyExposure(),
}, nil
}
func tenantSafeTopologyExposure() string {
return "tenant_safe_no_core_mesh_topology"
}
func (m *Module) createOrganization(w http.ResponseWriter, r *http.Request) {
var req createOrganizationRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid organization payload")
return
}
if req.ActorUserID == "" || req.Name == "" || req.Slug == "" {
httpx.WriteError(w, http.StatusBadRequest, "actor_user_id, slug, and name are required")
return
}
role, err := m.getPlatformRole(r.Context(), req.ActorUserID)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if !isPlatformAdmin(role) {
httpx.WriteError(w, http.StatusForbidden, "platform admin role is required")
return
}
if len(req.Metadata) == 0 {
req.Metadata = json.RawMessage(`{}`)
}
if !json.Valid(req.Metadata) {
httpx.WriteError(w, http.StatusBadRequest, "metadata must be valid json")
return
}
now := time.Now().UTC()
org := Organization{
ID: uuid.NewString(),
Slug: normalizeSlug(req.Slug),
Name: req.Name,
Status: "active",
Metadata: req.Metadata,
CreatedAt: now,
UpdatedAt: now,
}
membership := Membership{
ID: uuid.NewString(),
OrganizationID: org.ID,
UserID: req.ActorUserID,
RoleID: RoleOrgOwner,
Status: "active",
CreatedAt: now,
UpdatedAt: now,
}
tx, err := m.db.Begin(r.Context())
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer tx.Rollback(r.Context())
if _, err := tx.Exec(r.Context(), `
INSERT INTO organizations (id, slug, name, status, metadata, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5::jsonb, $6, $7)
`, org.ID, org.Slug, org.Name, org.Status, []byte(org.Metadata), org.CreatedAt, org.UpdatedAt); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if _, err := tx.Exec(r.Context(), `
INSERT INTO organization_memberships (
id, organization_id, user_id, role_id, status, invited_by_user_id, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
`, membership.ID, membership.OrganizationID, membership.UserID, membership.RoleID, membership.Status, req.ActorUserID, membership.CreatedAt, membership.UpdatedAt); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := tx.Commit(r.Context()); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusCreated, map[string]any{
"organization": org,
"membership": membership,
})
}
func (m *Module) listMemberships(w http.ResponseWriter, r *http.Request) {
orgID := chi.URLParam(r, "organizationID")
userID := r.URL.Query().Get("user_id")
if userID == "" {
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
return
}
if err := m.ensureOrgAccess(r.Context(), orgID, userID, true); err != nil {
httpx.WriteError(w, http.StatusForbidden, err.Error())
return
}
rows, err := m.db.Query(r.Context(), `
SELECT id, organization_id, user_id, role_id, status, invited_by_user_id, created_at, updated_at
FROM organization_memberships
WHERE organization_id = $1
ORDER BY created_at DESC
`, orgID)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer rows.Close()
var memberships []Membership
for rows.Next() {
membership, err := scanMembership(rows)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
memberships = append(memberships, membership)
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"memberships": memberships})
}
func (m *Module) addMembership(w http.ResponseWriter, r *http.Request) {
orgID := chi.URLParam(r, "organizationID")
var req addMembershipRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid membership payload")
return
}
if req.ActorUserID == "" || req.UserID == "" || req.RoleID == "" {
httpx.WriteError(w, http.StatusBadRequest, "actor_user_id, user_id, and role_id are required")
return
}
if err := m.ensureOrgAccess(r.Context(), orgID, req.ActorUserID, true); err != nil {
httpx.WriteError(w, http.StatusForbidden, err.Error())
return
}
now := time.Now().UTC()
membership := Membership{
ID: uuid.NewString(),
OrganizationID: orgID,
UserID: req.UserID,
RoleID: req.RoleID,
Status: "active",
InvitedByUser: &req.ActorUserID,
CreatedAt: now,
UpdatedAt: now,
}
if _, err := m.db.Exec(r.Context(), `
INSERT INTO organization_memberships (
id, organization_id, user_id, role_id, status, invited_by_user_id, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (organization_id, user_id) DO UPDATE SET
role_id = EXCLUDED.role_id,
status = 'active',
invited_by_user_id = EXCLUDED.invited_by_user_id,
updated_at = EXCLUDED.updated_at
`, membership.ID, membership.OrganizationID, membership.UserID, membership.RoleID, membership.Status, membership.InvitedByUser, membership.CreatedAt, membership.UpdatedAt); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusCreated, map[string]any{"membership": membership})
}
var errForbidden = errors.New("forbidden")
func (m *Module) ensureOrgAccess(ctx context.Context, orgID, userID string, adminRequired bool) error {
role, err := m.getPlatformRole(ctx, userID)
if err != nil {
return err
}
if isPlatformAdmin(role) {
return nil
}
query := `
SELECT role_id
FROM organization_memberships
WHERE organization_id = $1 AND user_id = $2 AND status = 'active'
`
var roleID string
if err := m.db.QueryRow(ctx, query, orgID, userID).Scan(&roleID); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return errForbidden
}
return err
}
if adminRequired && roleID != RoleOrgOwner && roleID != RoleOrgAdmin {
return errForbidden
}
return nil
}
func (m *Module) getPlatformRole(ctx context.Context, userID string) (string, error) {
return authority.EffectivePlatformRole(ctx, m.db, m.authority, userID)
}
func isPlatformAdmin(role string) bool {
return role == "platform_admin" || role == "platform_recovery_admin"
}
func (m *Module) getOrganizationByID(ctx context.Context, orgID string) (Organization, error) {
row := m.db.QueryRow(ctx, `
SELECT id, slug, name, status, metadata, created_at, updated_at
FROM organizations
WHERE id = $1
`, orgID)
return scanOrganization(row)
}
type rowScanner interface {
Scan(dest ...any) error
}
func scanOrganization(row rowScanner) (Organization, error) {
var org Organization
if err := row.Scan(&org.ID, &org.Slug, &org.Name, &org.Status, &org.Metadata, &org.CreatedAt, &org.UpdatedAt); err != nil {
return Organization{}, err
}
if len(org.Metadata) == 0 {
org.Metadata = json.RawMessage(`{}`)
}
return org, nil
}
func scanMembership(row rowScanner) (Membership, error) {
var membership Membership
if err := row.Scan(
&membership.ID,
&membership.OrganizationID,
&membership.UserID,
&membership.RoleID,
&membership.Status,
&membership.InvitedByUser,
&membership.CreatedAt,
&membership.UpdatedAt,
); err != nil {
return Membership{}, err
}
return membership, nil
}
func normalizeSlug(in string) string {
return strings.ToLower(strings.TrimSpace(in))
}
+639
View File
@@ -0,0 +1,639 @@
package resource
import (
"context"
"encoding/json"
"errors"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/example/remote-access-platform/backend/internal/platform/authority"
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
"github.com/example/remote-access-platform/backend/internal/platform/module"
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
)
const (
CertificateVerificationModeStrict = "strict"
CertificateVerificationModeIgnore = "ignore"
RenderQualityProfileLowBandwidth = "low_bandwidth"
RenderQualityProfileBalanced = "balanced"
RenderQualityProfileHighQuality = "high_quality"
RenderQualityProfileTextPriority = "text_priority"
ClipboardModeDisabled = "disabled"
ClipboardModeClientToServer = "client_to_server"
ClipboardModeServerToClient = "server_to_client"
ClipboardModeBidirectional = "bidirectional"
FileTransferModeDisabled = "disabled"
FileTransferModeClientToServer = "client_to_server"
FileTransferModeServerToClient = "server_to_client"
FileTransferModeBidirectional = "bidirectional"
)
type Module struct {
db *pgxpool.Pool
appEnv string
secretStore *secrets.ResourceSecretStore
authority *authority.Verifier
}
type Resource struct {
ID string `json:"id"`
OrganizationID string `json:"organization_id"`
Name string `json:"name"`
Address string `json:"address"`
Protocol string `json:"protocol"`
SecretRef *string `json:"secret_ref,omitempty"`
CertificateVerificationMode string `json:"certificate_verification_mode"`
RenderQualityProfile string `json:"render_quality_profile"`
ClipboardMode string `json:"clipboard_mode"`
FileTransferMode string `json:"file_transfer_mode"`
Metadata json.RawMessage `json:"metadata"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type upsertResourceRequest struct {
ActorUserID string `json:"actor_user_id"`
OrganizationID string `json:"organization_id"`
Name string `json:"name"`
Address string `json:"address"`
Protocol string `json:"protocol"`
SecretRef *string `json:"secret_ref"`
CertificateVerificationMode string `json:"certificate_verification_mode"`
RenderQualityProfile string `json:"render_quality_profile"`
ClipboardMode string `json:"clipboard_mode"`
FileTransferMode string `json:"file_transfer_mode"`
Metadata json.RawMessage `json:"metadata"`
}
type upsertResourceSecretRequest struct {
ActorUserID string `json:"actor_user_id"`
Payload json.RawMessage `json:"payload"`
Metadata json.RawMessage `json:"metadata"`
}
func NewModule(deps module.Dependencies, secretStores ...*secrets.ResourceSecretStore) *Module {
var secretStore *secrets.ResourceSecretStore
if len(secretStores) > 0 {
secretStore = secretStores[0]
}
authorityVerifier, _ := authority.NewVerifier(deps.Config.Installation)
return &Module{db: deps.Infra.DB, appEnv: deps.Config.App.Env, secretStore: secretStore, authority: authorityVerifier}
}
func (m *Module) Name() string {
return "resource"
}
func (m *Module) RegisterRoutes(router chi.Router) {
router.Route("/resources", func(r chi.Router) {
r.Get("/", m.listResources)
r.Post("/", m.createResource)
r.Get("/{resourceID}", m.getResource)
r.Put("/{resourceID}", m.updateResource)
r.Put("/{resourceID}/secret", m.upsertResourceSecret)
})
}
func (m *Module) listResources(w http.ResponseWriter, r *http.Request) {
userID := r.URL.Query().Get("user_id")
orgID := r.URL.Query().Get("organization_id")
if userID == "" {
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
return
}
platformRole, err := m.getPlatformRole(r.Context(), userID)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
query := `
SELECT r.id, r.organization_id, r.name, r.address, r.protocol, r.secret_ref,
r.certificate_verification_mode, r.metadata, r.created_at, r.updated_at,
COALESCE(rp.clipboard_mode, 'disabled') AS clipboard_mode,
COALESCE(rp.file_transfer_mode, 'disabled') AS file_transfer_mode
FROM resources r
LEFT JOIN resource_policies rp ON rp.resource_id = r.id
`
args := make([]any, 0, 2)
if platformRole == "platform_admin" || platformRole == "platform_recovery_admin" {
if orgID != "" {
query += ` WHERE r.organization_id = $1`
args = append(args, orgID)
}
query += ` ORDER BY r.created_at DESC`
} else {
query += `
INNER JOIN organization_memberships om ON om.organization_id = r.organization_id
WHERE om.user_id = $1 AND om.status = 'active'
`
args = append(args, userID)
if orgID != "" {
query += ` AND r.organization_id = $2`
args = append(args, orgID)
}
query += ` ORDER BY r.created_at DESC`
}
rows, err := m.db.Query(r.Context(), query, args...)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer rows.Close()
resources := make([]Resource, 0)
for rows.Next() {
resource, err := scanResource(rows)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
resources = append(resources, resource)
}
if err := rows.Err(); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"resources": resources})
}
func (m *Module) getResource(w http.ResponseWriter, r *http.Request) {
userID := r.URL.Query().Get("user_id")
if userID == "" {
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
return
}
resource, err := m.getByID(r.Context(), chi.URLParam(r, "resourceID"))
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
httpx.WriteError(w, http.StatusNotFound, "resource not found")
return
}
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := m.ensureResourceAccess(r.Context(), resource.OrganizationID, userID, false); err != nil {
httpx.WriteError(w, http.StatusForbidden, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"resource": resource})
}
func (m *Module) createResource(w http.ResponseWriter, r *http.Request) {
req, err := decodeUpsertRequest(r)
if err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
if err := secrets.ValidateResourceSecretReadiness(req.Protocol, req.SecretRef, req.Metadata, m.appEnv); err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
now := time.Now().UTC()
resource := Resource{
ID: uuid.NewString(),
OrganizationID: req.OrganizationID,
Name: req.Name,
Address: req.Address,
Protocol: req.Protocol,
SecretRef: req.SecretRef,
CertificateVerificationMode: req.CertificateVerificationMode,
RenderQualityProfile: req.RenderQualityProfile,
ClipboardMode: req.ClipboardMode,
FileTransferMode: req.FileTransferMode,
Metadata: req.Metadata,
CreatedAt: now,
UpdatedAt: now,
}
if err := m.ensureResourceAccess(r.Context(), req.OrganizationID, req.ActorUserID, true); err != nil {
httpx.WriteError(w, http.StatusForbidden, err.Error())
return
}
tx, err := m.db.Begin(r.Context())
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer tx.Rollback(r.Context())
if _, err := tx.Exec(r.Context(), `
INSERT INTO resources (
id, organization_id, name, address, protocol, secret_ref, certificate_verification_mode, metadata, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9, $10)
`, resource.ID, resource.OrganizationID, resource.Name, resource.Address, resource.Protocol, resource.SecretRef, resource.CertificateVerificationMode, []byte(resource.Metadata), resource.CreatedAt, resource.UpdatedAt); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := upsertResourcePolicy(r.Context(), tx, resource.ID, resource.ClipboardMode, resource.FileTransferMode, now); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := tx.Commit(r.Context()); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusCreated, map[string]any{"resource": resource})
}
func (m *Module) updateResource(w http.ResponseWriter, r *http.Request) {
req, err := decodeUpsertRequest(r)
if err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
if err := secrets.ValidateResourceSecretReadiness(req.Protocol, req.SecretRef, req.Metadata, m.appEnv); err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
resourceID := chi.URLParam(r, "resourceID")
existing, err := m.getByID(r.Context(), resourceID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
httpx.WriteError(w, http.StatusNotFound, "resource not found")
return
}
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := m.ensureResourceAccess(r.Context(), existing.OrganizationID, req.ActorUserID, true); err != nil {
httpx.WriteError(w, http.StatusForbidden, err.Error())
return
}
if req.OrganizationID != existing.OrganizationID {
if err := m.ensureResourceAccess(r.Context(), req.OrganizationID, req.ActorUserID, true); err != nil {
httpx.WriteError(w, http.StatusForbidden, err.Error())
return
}
}
now := time.Now().UTC()
tx, err := m.db.Begin(r.Context())
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer tx.Rollback(r.Context())
tag, err := tx.Exec(r.Context(), `
UPDATE resources
SET
organization_id = $2,
name = $3,
address = $4,
protocol = $5,
secret_ref = $6,
certificate_verification_mode = $7,
metadata = $8::jsonb,
updated_at = $9
WHERE id = $1
`, resourceID, req.OrganizationID, req.Name, req.Address, req.Protocol, req.SecretRef, req.CertificateVerificationMode, []byte(req.Metadata), now)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if tag.RowsAffected() == 0 {
httpx.WriteError(w, http.StatusNotFound, "resource not found")
return
}
if err := upsertResourcePolicy(r.Context(), tx, resourceID, req.ClipboardMode, req.FileTransferMode, now); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := tx.Commit(r.Context()); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
resource, err := m.getByID(r.Context(), resourceID)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"resource": resource})
}
func (m *Module) upsertResourceSecret(w http.ResponseWriter, r *http.Request) {
if m.secretStore == nil {
httpx.WriteError(w, http.StatusServiceUnavailable, "resource secret encryption is not configured")
return
}
var req upsertResourceSecretRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid resource secret payload")
return
}
if req.ActorUserID == "" {
httpx.WriteError(w, http.StatusBadRequest, "actor_user_id is required")
return
}
resourceID := chi.URLParam(r, "resourceID")
resource, err := m.getByID(r.Context(), resourceID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
httpx.WriteError(w, http.StatusNotFound, "resource not found")
return
}
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := m.ensureResourceAccess(r.Context(), resource.OrganizationID, req.ActorUserID, true); err != nil {
httpx.WriteError(w, http.StatusForbidden, err.Error())
return
}
tx, err := m.db.Begin(r.Context())
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer tx.Rollback(r.Context())
secretStore := m.secretStore.WithDB(tx)
secretRef := secrets.DefaultResourceSecretRef(resource.OrganizationID, resource.ID)
descriptor, err := secretStore.Upsert(r.Context(), secrets.UpsertResourceSecretCommand{
OrganizationID: resource.OrganizationID,
ResourceID: resource.ID,
Protocol: resource.Protocol,
SecretRef: secretRef,
Payload: req.Payload,
Metadata: req.Metadata,
ActorUserID: req.ActorUserID,
})
if err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
if _, err := tx.Exec(r.Context(), `
UPDATE resources
SET secret_ref = $2, updated_at = $3
WHERE id = $1::uuid
`, resource.ID, descriptor.SecretRef, time.Now().UTC()); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := writeAuditEvent(r.Context(), tx, "resource_secret_rotated", req.ActorUserID, "resource_secret", descriptor.SecretRef, map[string]any{
"resource_id": resource.ID,
"organization_id": resource.OrganizationID,
"protocol": resource.Protocol,
"version": descriptor.Version,
"secret_ref": descriptor.SecretRef,
}); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := tx.Commit(r.Context()); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"secret": descriptor})
}
func decodeUpsertRequest(r *http.Request) (*upsertResourceRequest, error) {
var req upsertResourceRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.New("invalid resource payload")
}
if req.Name == "" {
return nil, errors.New("name is required")
}
if req.ActorUserID == "" {
return nil, errors.New("actor_user_id is required")
}
if req.OrganizationID == "" {
return nil, errors.New("organization_id is required")
}
if req.Address == "" {
return nil, errors.New("address is required")
}
if req.Protocol == "" {
req.Protocol = "rdp"
}
mode, err := normalizeCertificateVerificationMode(req.CertificateVerificationMode)
if err != nil {
return nil, err
}
req.CertificateVerificationMode = mode
renderQualityProfile, err := normalizeRenderQualityProfile(req.RenderQualityProfile)
if err != nil {
return nil, err
}
req.RenderQualityProfile = renderQualityProfile
clipboardMode, err := normalizeClipboardMode(req.ClipboardMode)
if err != nil {
return nil, err
}
req.ClipboardMode = clipboardMode
fileTransferMode, err := normalizeFileTransferMode(req.FileTransferMode)
if err != nil {
return nil, err
}
req.FileTransferMode = fileTransferMode
metadata, err := normalizeMetadata(req.Metadata, req.CertificateVerificationMode, req.RenderQualityProfile)
if err != nil {
return nil, err
}
req.Metadata = metadata
return &req, nil
}
func normalizeCertificateVerificationMode(mode string) (string, error) {
switch mode {
case "", CertificateVerificationModeStrict:
return CertificateVerificationModeStrict, nil
case CertificateVerificationModeIgnore:
return CertificateVerificationModeIgnore, nil
default:
return "", errors.New("certificate_verification_mode must be one of: strict, ignore")
}
}
func normalizeClipboardMode(mode string) (string, error) {
switch mode {
case "", ClipboardModeDisabled:
return ClipboardModeDisabled, nil
case ClipboardModeClientToServer, ClipboardModeServerToClient, ClipboardModeBidirectional:
return mode, nil
default:
return "", errors.New("clipboard_mode must be one of: disabled, client_to_server, server_to_client, bidirectional")
}
}
func normalizeFileTransferMode(mode string) (string, error) {
switch mode {
case "", FileTransferModeDisabled:
return FileTransferModeDisabled, nil
case FileTransferModeClientToServer, FileTransferModeServerToClient, FileTransferModeBidirectional:
return mode, nil
default:
return "", errors.New("file_transfer_mode must be one of: disabled, client_to_server, server_to_client, bidirectional")
}
}
func normalizeMetadata(raw json.RawMessage, certificateVerificationMode, renderQualityProfile string) (json.RawMessage, error) {
if len(raw) == 0 {
raw = json.RawMessage(`{}`)
}
if !json.Valid(raw) {
return nil, errors.New("metadata must be valid json")
}
var metadata map[string]any
if err := json.Unmarshal(raw, &metadata); err != nil {
return nil, errors.New("metadata must be a json object")
}
metadata["certificate_verification_mode"] = certificateVerificationMode
metadata["render_quality_profile"] = renderQualityProfile
encoded, err := json.Marshal(metadata)
if err != nil {
return nil, err
}
return json.RawMessage(encoded), nil
}
func (m *Module) getByID(ctx context.Context, resourceID string) (Resource, error) {
row := m.db.QueryRow(ctx, `
SELECT r.id, r.organization_id, r.name, r.address, r.protocol, r.secret_ref,
r.certificate_verification_mode, r.metadata, r.created_at, r.updated_at,
COALESCE(rp.clipboard_mode, 'disabled') AS clipboard_mode,
COALESCE(rp.file_transfer_mode, 'disabled') AS file_transfer_mode
FROM resources r
LEFT JOIN resource_policies rp ON rp.resource_id = r.id
WHERE r.id = $1
`, resourceID)
return scanResource(row)
}
func (m *Module) ensureResourceAccess(ctx context.Context, orgID, userID string, adminRequired bool) error {
role, err := m.getPlatformRole(ctx, userID)
if err != nil {
return err
}
if role == "platform_admin" || role == "platform_recovery_admin" {
return nil
}
var membershipRole string
if err := m.db.QueryRow(ctx, `
SELECT role_id
FROM organization_memberships
WHERE organization_id = $1 AND user_id = $2 AND status = 'active'
`, orgID, userID).Scan(&membershipRole); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return errors.New("forbidden")
}
return err
}
if adminRequired && membershipRole != "org_owner" && membershipRole != "org_admin" {
return errors.New("forbidden")
}
return nil
}
func (m *Module) getPlatformRole(ctx context.Context, userID string) (string, error) {
return authority.EffectivePlatformRole(ctx, m.db, m.authority, userID)
}
type rowScanner interface {
Scan(dest ...any) error
}
func scanResource(row rowScanner) (Resource, error) {
var resource Resource
if err := row.Scan(
&resource.ID,
&resource.OrganizationID,
&resource.Name,
&resource.Address,
&resource.Protocol,
&resource.SecretRef,
&resource.CertificateVerificationMode,
&resource.Metadata,
&resource.CreatedAt,
&resource.UpdatedAt,
&resource.ClipboardMode,
&resource.FileTransferMode,
); err != nil {
return Resource{}, err
}
if len(resource.Metadata) == 0 {
resource.Metadata = json.RawMessage(`{}`)
}
if resource.CertificateVerificationMode == "" {
resource.CertificateVerificationMode = CertificateVerificationModeStrict
}
if resource.RenderQualityProfile == "" {
resource.RenderQualityProfile = renderQualityProfileFromMetadata(resource.Metadata)
}
if resource.ClipboardMode == "" {
resource.ClipboardMode = ClipboardModeDisabled
}
if resource.FileTransferMode == "" {
resource.FileTransferMode = FileTransferModeDisabled
}
return resource, nil
}
func upsertResourcePolicy(ctx context.Context, tx pgx.Tx, resourceID, clipboardMode, fileTransferMode string, now time.Time) error {
clipboardEnabled := clipboardMode != ClipboardModeDisabled
fileTransferEnabled := fileTransferMode == FileTransferModeClientToServer || fileTransferMode == FileTransferModeBidirectional
_, err := tx.Exec(ctx, `
INSERT INTO resource_policies (
resource_id, clipboard_enabled, clipboard_mode, file_transfer_enabled, file_transfer_mode, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $6)
ON CONFLICT (resource_id) DO UPDATE SET
clipboard_enabled = EXCLUDED.clipboard_enabled,
clipboard_mode = EXCLUDED.clipboard_mode,
file_transfer_enabled = EXCLUDED.file_transfer_enabled,
file_transfer_mode = EXCLUDED.file_transfer_mode,
updated_at = EXCLUDED.updated_at
`, resourceID, clipboardEnabled, clipboardMode, fileTransferEnabled, fileTransferMode, now)
return err
}
func writeAuditEvent(ctx context.Context, tx pgx.Tx, eventType, actorUserID, targetType, targetID string, payload map[string]any) error {
encoded, err := json.Marshal(payload)
if err != nil {
return err
}
_, err = tx.Exec(ctx, `
INSERT INTO audit_events (
id, actor_user_id, event_type, target_type, target_id, payload, created_at
) VALUES (
$1::uuid, NULLIF($2, '')::uuid, $3, $4, $5, $6::jsonb, $7
)
`, uuid.NewString(), actorUserID, eventType, targetType, targetID, encoded, time.Now().UTC())
return err
}
func normalizeRenderQualityProfile(profile string) (string, error) {
switch profile {
case "", RenderQualityProfileBalanced:
return RenderQualityProfileBalanced, nil
case RenderQualityProfileLowBandwidth, RenderQualityProfileHighQuality, RenderQualityProfileTextPriority:
return profile, nil
default:
return "", errors.New("render_quality_profile must be one of: low_bandwidth, balanced, high_quality, text_priority")
}
}
func renderQualityProfileFromMetadata(raw json.RawMessage) string {
if len(raw) == 0 {
return RenderQualityProfileBalanced
}
var metadata map[string]any
if err := json.Unmarshal(raw, &metadata); err != nil {
return RenderQualityProfileBalanced
}
if profile, ok := metadata["render_quality_profile"].(string); ok {
switch profile {
case RenderQualityProfileLowBandwidth, RenderQualityProfileBalanced, RenderQualityProfileHighQuality, RenderQualityProfileTextPriority:
return profile
}
}
return RenderQualityProfileBalanced
}
@@ -0,0 +1,219 @@
package sessionbroker
import (
"crypto/rsa"
"fmt"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
const (
directWorkerTLSTrustModeSmokeInsecure = "smoke_insecure"
directWorkerTLSTrustModePublicCA = "public_ca"
directWorkerTLSTrustModePlatformCA = "platform_ca"
)
type DataPlaneTokenClaims struct {
SessionID string `json:"session_id"`
AttachmentID string `json:"attachment_id"`
UserID string `json:"user_id"`
OrganizationID string `json:"organization_id"`
ClusterID string `json:"cluster_id,omitempty"`
WorkerID string `json:"worker_id"`
ResourceID string `json:"resource_id"`
AllowedChannels []string `json:"allowed_channels"`
ExpiresAtValue time.Time `json:"expires_at"`
jwt.RegisteredClaims
}
func (s *Service) buildDataPlaneOffer(session RemoteSession, attachment SessionAttachment) (*sessioncontracts.DataPlaneOffer, error) {
if s.cfg.DataPlane.TokenTTL <= 0 || s.cfg.DataPlane.TokenPrivateKeyPEM == "" {
return nil, nil
}
now := s.now().UTC()
expiresAt := now.Add(s.cfg.DataPlane.TokenTTL)
allowedChannels := dataPlaneAllowedChannelsFromSession(session)
jti := uuid.NewString()
claims := DataPlaneTokenClaims{
SessionID: session.ID,
AttachmentID: attachment.ID,
UserID: attachment.UserID,
OrganizationID: session.OrganizationID,
WorkerID: session.WorkerID,
ResourceID: session.ResourceID,
AllowedChannels: allowedChannels,
ExpiresAtValue: expiresAt,
RegisteredClaims: jwt.RegisteredClaims{
ID: jti,
Issuer: s.cfg.Auth.Issuer,
Subject: attachment.UserID,
Audience: jwt.ClaimStrings{"rap-data-plane", "worker:" + session.WorkerID},
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(expiresAt),
},
}
token, err := signDataPlaneToken(claims, s.cfg.DataPlane.TokenPrivateKeyPEM)
if err != nil {
return nil, err
}
candidates := s.buildDataPlaneCandidates(session)
preferred := sessioncontracts.DataPlaneCandidateBackendGateway
if len(candidates) > 0 {
preferred = candidates[0].Type
}
return &sessioncontracts.DataPlaneOffer{
Preferred: preferred,
Token: token,
ExpiresAt: expiresAt,
Candidates: candidates,
}, nil
}
func (s *Service) buildDataPlaneCandidates(session RemoteSession) []sessioncontracts.DataPlaneCandidate {
var candidates []sessioncontracts.DataPlaneCandidate
if directURL := s.directWorkerWSSURL(session.WorkerID); directURL != "" && s.canAdvertiseDirectWorkerWSS() {
metadata := map[string]any(nil)
if s.cfg.DataPlane.DirectWorkerJSONRuntime {
metadata = map[string]any{
"runtime_transport": "json_v1",
"traffic_ready": true,
}
s.addDirectWorkerTLSTrustMetadata(metadata)
if s.cfg.DataPlane.DirectWorkerBinaryRender {
metadata["render_transport"] = "binary_v1"
metadata["binary_render"] = true
metadata["supported_color_modes"] = []string{"full_color", "grayscale"}
metadata["default_color_mode"] = "full_color"
}
}
candidates = append(candidates, sessioncontracts.DataPlaneCandidate{
Type: sessioncontracts.DataPlaneCandidateDirectWorkerWSS,
URL: directURL,
WorkerID: session.WorkerID,
Priority: 10,
Metadata: metadata,
})
}
if s.cfg.DataPlane.BackendGatewayURL != "" {
candidates = append(candidates, sessioncontracts.DataPlaneCandidate{
Type: sessioncontracts.DataPlaneCandidateBackendGateway,
URL: s.cfg.DataPlane.BackendGatewayURL,
Priority: 100,
})
}
return candidates
}
func (s *Service) canAdvertiseDirectWorkerWSS() bool {
trustMode := normalizeDirectWorkerTLSTrustMode(s.cfg.DataPlane.DirectWorkerTLSTrustMode)
return !secrets.IsProductionEnv(s.cfg.App.Env) || directWorkerTLSTrustModeIsProductionTrusted(trustMode)
}
func (s *Service) addDirectWorkerTLSTrustMetadata(metadata map[string]any) {
trustMode := normalizeDirectWorkerTLSTrustMode(s.cfg.DataPlane.DirectWorkerTLSTrustMode)
metadata["tls_trust_mode"] = trustMode
metadata["production_trusted"] = directWorkerTLSTrustModeIsProductionTrusted(trustMode)
metadata["smoke_only"] = trustMode == directWorkerTLSTrustModeSmokeInsecure
if s.cfg.DataPlane.DirectWorkerTLSCARef != "" {
metadata["tls_ca_ref"] = s.cfg.DataPlane.DirectWorkerTLSCARef
}
}
func normalizeDirectWorkerTLSTrustMode(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case directWorkerTLSTrustModePublicCA:
return directWorkerTLSTrustModePublicCA
case directWorkerTLSTrustModePlatformCA:
return directWorkerTLSTrustModePlatformCA
default:
return directWorkerTLSTrustModeSmokeInsecure
}
}
func directWorkerTLSTrustModeIsProductionTrusted(mode string) bool {
return mode == directWorkerTLSTrustModePublicCA || mode == directWorkerTLSTrustModePlatformCA
}
func (s *Service) directWorkerWSSURL(workerID string) string {
template := strings.TrimSpace(s.cfg.DataPlane.DirectWorkerWSSURLTemplate)
if template == "" || workerID == "" {
return ""
}
return strings.ReplaceAll(template, "{worker_id}", workerID)
}
func signDataPlaneToken(claims DataPlaneTokenClaims, privateKeyPEM string) (string, error) {
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(privateKeyPEM))
if err != nil {
return "", fmt.Errorf("parse data-plane private key: %w", err)
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
signed, err := token.SignedString(privateKey)
if err != nil {
return "", fmt.Errorf("sign data-plane token: %w", err)
}
return signed, nil
}
func parseDataPlaneToken(tokenValue string, publicKey *rsa.PublicKey) (*DataPlaneTokenClaims, error) {
claims := &DataPlaneTokenClaims{}
token, err := jwt.ParseWithClaims(tokenValue, claims, func(token *jwt.Token) (any, error) {
if token.Method != jwt.SigningMethodRS256 {
return nil, fmt.Errorf("unexpected data-plane signing method: %s", token.Header["alg"])
}
return publicKey, nil
})
if err != nil {
return nil, err
}
if !token.Valid {
return nil, fmt.Errorf("data-plane token invalid")
}
return claims, nil
}
func dataPlaneAllowedChannelsFromSession(session RemoteSession) []string {
channels := []string{
sessioncontracts.DataPlaneChannelControl,
sessioncontracts.DataPlaneChannelInput,
sessioncontracts.DataPlaneChannelRender,
sessioncontracts.DataPlaneChannelTelemetry,
}
metadata := decodeJSONMap(session.Metadata)
policy, _ := metadata["policy"].(map[string]any)
if policy != nil {
if mode, _ := policy["clipboard_mode"].(string); mode != "" && mode != string(ResourceClipboardModeDisabled) {
channels = append(channels, sessioncontracts.DataPlaneChannelClipboard)
}
if mode, _ := policy["file_transfer_mode"].(string); fileTransferAllowsClientToServer(ResourceFileTransferMode(mode)) {
channels = append(channels, sessioncontracts.DataPlaneChannelFileUpload)
}
if mode, _ := policy["file_transfer_mode"].(string); fileTransferAllowsServerToClient(ResourceFileTransferMode(mode)) {
channels = append(channels, sessioncontracts.DataPlaneChannelFileDownload)
}
}
return channels
}
func (s *Service) attachDataPlaneOffer(result *SessionControlResult) error {
if result == nil || result.Attachment == nil {
return nil
}
result.GatewayURL = s.cfg.DataPlane.BackendGatewayURL
offer, err := s.buildDataPlaneOffer(result.Session, *result.Attachment)
if err != nil {
return err
}
result.DataPlane = offer
return nil
}
@@ -0,0 +1,357 @@
package sessionbroker
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"slices"
"strings"
"testing"
"time"
"github.com/example/remote-access-platform/backend/internal/platform/config"
"github.com/example/remote-access-platform/backend/internal/platform/module"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
func TestDataPlaneTokenScopeValidation(t *testing.T) {
now := time.Now().UTC().Truncate(time.Second)
privateKeyPEM, publicKey := testRS256Key(t)
service := &Service{
cfg: module.Config{
Auth: config.AuthConfig{
Issuer: "rap-api-test",
},
DataPlane: config.DataPlaneConfig{
TokenTTL: time.Minute,
TokenPrivateKeyPEM: privateKeyPEM,
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
},
},
now: func() time.Time { return now },
}
session := RemoteSession{
ID: "session-1",
OrganizationID: "org-1",
ResourceID: "resource-1",
WorkerID: "worker-1",
Metadata: mustJSON(t, map[string]any{"policy": map[string]any{"clipboard_mode": "bidirectional", "file_transfer_mode": "client_to_server"}}),
}
attachment := SessionAttachment{
ID: "attachment-1",
UserID: "user-1",
}
offer, err := service.buildDataPlaneOffer(session, attachment)
if err != nil {
t.Fatalf("buildDataPlaneOffer returned error: %v", err)
}
if offer == nil {
t.Fatal("expected data-plane offer")
}
claims, err := parseDataPlaneToken(offer.Token, publicKey)
if err != nil {
t.Fatalf("parseDataPlaneToken returned error: %v", err)
}
assertEqual(t, claims.SessionID, session.ID, "session_id")
assertEqual(t, claims.AttachmentID, attachment.ID, "attachment_id")
assertEqual(t, claims.UserID, attachment.UserID, "user_id")
assertEqual(t, claims.OrganizationID, session.OrganizationID, "organization_id")
assertEqual(t, claims.WorkerID, session.WorkerID, "worker_id")
assertEqual(t, claims.ResourceID, session.ResourceID, "resource_id")
if claims.ID == "" {
t.Fatal("expected jti")
}
if claims.ExpiresAt == nil || !claims.ExpiresAt.Time.Equal(now.Add(time.Minute)) {
t.Fatalf("unexpected expires_at: %v", claims.ExpiresAt)
}
if !claims.ExpiresAtValue.Equal(now.Add(time.Minute)) {
t.Fatalf("unexpected expires_at claim value: %v", claims.ExpiresAtValue)
}
for _, channel := range []string{
sessioncontracts.DataPlaneChannelControl,
sessioncontracts.DataPlaneChannelInput,
sessioncontracts.DataPlaneChannelRender,
sessioncontracts.DataPlaneChannelTelemetry,
sessioncontracts.DataPlaneChannelClipboard,
sessioncontracts.DataPlaneChannelFileUpload,
} {
if !slices.Contains(claims.AllowedChannels, channel) {
t.Fatalf("expected allowed channel %q in %v", channel, claims.AllowedChannels)
}
}
}
func TestDataPlaneOfferResponseShapeCompatibility(t *testing.T) {
now := time.Now().UTC().Truncate(time.Second)
privateKeyPEM, _ := testRS256Key(t)
service := &Service{
cfg: module.Config{
Auth: config.AuthConfig{Issuer: "rap-api-test"},
DataPlane: config.DataPlaneConfig{
TokenTTL: time.Minute,
TokenPrivateKeyPEM: privateKeyPEM,
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
DirectWorkerJSONRuntime: true,
DirectWorkerTLSTrustMode: "smoke_insecure",
},
},
now: func() time.Time { return now },
}
result := &SessionControlResult{
Session: RemoteSession{
ID: "session-1",
OrganizationID: "org-1",
ResourceID: "resource-1",
WorkerID: "worker-1",
Metadata: mustJSON(t, map[string]any{"policy": map[string]any{"clipboard_mode": "disabled", "file_transfer_mode": "disabled"}}),
},
Attachment: &SessionAttachment{ID: "attachment-1", UserID: "user-1"},
AttachToken: &sessioncontracts.AttachTokenClaims{
Token: "existing-attach-token",
SessionID: "session-1",
AttachmentID: "attachment-1",
UserID: "user-1",
WorkerID: "worker-1",
ExpiresAt: now.Add(2 * time.Minute),
},
}
if err := service.attachDataPlaneOffer(result); err != nil {
t.Fatalf("attachDataPlaneOffer returned error: %v", err)
}
payload, err := json.Marshal(result)
if err != nil {
t.Fatalf("marshal response: %v", err)
}
var decoded map[string]any
if err := json.Unmarshal(payload, &decoded); err != nil {
t.Fatalf("decode response: %v", err)
}
if decoded["session"] == nil || decoded["attachment"] == nil || decoded["attach_token"] == nil {
t.Fatalf("response lost existing fields: %s", payload)
}
if decoded["data_plane"] == nil || decoded["gateway_url"] == nil {
t.Fatalf("response missing data-plane fields: %s", payload)
}
if result.DataPlane == nil {
t.Fatal("expected data-plane offer")
}
if result.DataPlane.Preferred != sessioncontracts.DataPlaneCandidateDirectWorkerWSS {
t.Fatalf("unexpected preferred candidate: %s", result.DataPlane.Preferred)
}
if len(result.DataPlane.Candidates) != 2 {
t.Fatalf("expected direct and fallback candidates, got %d", len(result.DataPlane.Candidates))
}
if result.DataPlane.Candidates[0].URL != "wss://worker-1.worker.example.test/rap/v1/data-plane" {
t.Fatalf("unexpected direct candidate URL: %s", result.DataPlane.Candidates[0].URL)
}
if result.DataPlane.Candidates[0].Metadata["runtime_transport"] != "json_v1" {
t.Fatalf("direct candidate is missing json_v1 runtime metadata: %#v", result.DataPlane.Candidates[0].Metadata)
}
if result.DataPlane.Candidates[0].Metadata["traffic_ready"] != true {
t.Fatalf("direct candidate is missing traffic_ready metadata: %#v", result.DataPlane.Candidates[0].Metadata)
}
if result.DataPlane.Candidates[0].Metadata["smoke_only"] != true {
t.Fatalf("direct candidate should be marked smoke-only by default: %#v", result.DataPlane.Candidates[0].Metadata)
}
if result.DataPlane.Candidates[0].Metadata["production_trusted"] != false {
t.Fatalf("smoke direct candidate must not be production-trusted: %#v", result.DataPlane.Candidates[0].Metadata)
}
if !strings.Contains(result.DataPlane.Candidates[1].URL, "/api/v1/gateway/ws") {
t.Fatalf("unexpected backend candidate URL: %s", result.DataPlane.Candidates[1].URL)
}
}
func TestDataPlaneDirectCandidateMetadataRequiresRuntimeFlag(t *testing.T) {
service := &Service{
cfg: module.Config{
DataPlane: config.DataPlaneConfig{
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
DirectWorkerTLSTrustMode: "smoke_insecure",
},
},
}
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
if len(candidates) != 2 {
t.Fatalf("expected direct and fallback candidates, got %d", len(candidates))
}
if candidates[0].Metadata != nil {
t.Fatalf("direct candidate must not advertise json_v1 before runtime flag is enabled: %#v", candidates[0].Metadata)
}
}
func TestDataPlaneDirectCandidateAdvertisesBinaryRenderOnlyWhenEnabled(t *testing.T) {
service := &Service{
cfg: module.Config{
DataPlane: config.DataPlaneConfig{
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
DirectWorkerJSONRuntime: true,
DirectWorkerBinaryRender: true,
DirectWorkerTLSTrustMode: "platform_ca",
DirectWorkerTLSCARef: "rap-platform-ca:v1",
},
},
}
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
if len(candidates) != 2 {
t.Fatalf("expected direct and fallback candidates, got %d", len(candidates))
}
if candidates[0].Metadata["render_transport"] != "binary_v1" {
t.Fatalf("direct candidate is missing binary render metadata: %#v", candidates[0].Metadata)
}
if candidates[0].Metadata["binary_render"] != true {
t.Fatalf("direct candidate is missing binary_render metadata: %#v", candidates[0].Metadata)
}
if candidates[0].Metadata["default_color_mode"] != "full_color" {
t.Fatalf("direct candidate is missing default_color_mode metadata: %#v", candidates[0].Metadata)
}
if candidates[0].Metadata["production_trusted"] != true || candidates[0].Metadata["tls_trust_mode"] != "platform_ca" {
t.Fatalf("direct candidate is missing production trust metadata: %#v", candidates[0].Metadata)
}
if candidates[0].Metadata["tls_ca_ref"] != "rap-platform-ca:v1" {
t.Fatalf("direct candidate is missing tls_ca_ref metadata: %#v", candidates[0].Metadata)
}
modes, ok := candidates[0].Metadata["supported_color_modes"].([]string)
if !ok || !slices.Contains(modes, "full_color") || !slices.Contains(modes, "grayscale") {
t.Fatalf("direct candidate is missing supported_color_modes metadata: %#v", candidates[0].Metadata)
}
}
func TestDataPlaneDirectCandidateOmittedInProductionWhenSmokeOnly(t *testing.T) {
service := &Service{
cfg: module.Config{
App: config.AppConfig{Env: "production"},
DataPlane: config.DataPlaneConfig{
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
DirectWorkerJSONRuntime: true,
DirectWorkerTLSTrustMode: "smoke_insecure",
},
},
}
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
if len(candidates) != 1 {
t.Fatalf("expected fallback-only candidates in production with smoke TLS, got %d", len(candidates))
}
if candidates[0].Type != sessioncontracts.DataPlaneCandidateBackendGateway {
t.Fatalf("production must not advertise smoke-only direct candidate: %#v", candidates)
}
}
func TestDataPlaneDirectCandidateAdvertisedInProductionWhenTrusted(t *testing.T) {
service := &Service{
cfg: module.Config{
App: config.AppConfig{Env: "production"},
DataPlane: config.DataPlaneConfig{
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
DirectWorkerJSONRuntime: true,
DirectWorkerTLSTrustMode: "public_ca",
},
},
}
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
if len(candidates) != 2 {
t.Fatalf("expected trusted direct and fallback candidates, got %d", len(candidates))
}
if candidates[0].Metadata["production_trusted"] != true || candidates[0].Metadata["tls_trust_mode"] != "public_ca" {
t.Fatalf("trusted production direct candidate metadata mismatch: %#v", candidates[0].Metadata)
}
}
func TestDataPlaneCandidatesFallbackOnlyWhenDirectTemplateMissing(t *testing.T) {
service := &Service{
cfg: module.Config{
DataPlane: config.DataPlaneConfig{
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
},
},
}
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
if len(candidates) != 1 {
t.Fatalf("expected fallback-only candidate list, got %d", len(candidates))
}
if candidates[0].Type != sessioncontracts.DataPlaneCandidateBackendGateway {
t.Fatalf("unexpected candidate type: %s", candidates[0].Type)
}
}
func TestDataPlaneAllowedChannelsRespectRuntimePolicy(t *testing.T) {
cases := []struct {
name string
policy map[string]any
expected []string
blocked []string
}{
{
name: "disabled policies expose only control input render telemetry",
policy: map[string]any{"clipboard_mode": "disabled", "file_transfer_mode": "disabled"},
expected: []string{sessioncontracts.DataPlaneChannelControl, sessioncontracts.DataPlaneChannelInput, sessioncontracts.DataPlaneChannelRender, sessioncontracts.DataPlaneChannelTelemetry},
blocked: []string{sessioncontracts.DataPlaneChannelClipboard, sessioncontracts.DataPlaneChannelFileUpload},
},
{
name: "clipboard policy adds clipboard channel",
policy: map[string]any{"clipboard_mode": "server_to_client", "file_transfer_mode": "disabled"},
expected: []string{sessioncontracts.DataPlaneChannelClipboard},
blocked: []string{sessioncontracts.DataPlaneChannelFileUpload},
},
{
name: "client upload policy adds file upload channel",
policy: map[string]any{"clipboard_mode": "disabled", "file_transfer_mode": "client_to_server"},
expected: []string{sessioncontracts.DataPlaneChannelFileUpload},
blocked: []string{sessioncontracts.DataPlaneChannelClipboard},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
session := RemoteSession{Metadata: mustJSON(t, map[string]any{"policy": tc.policy})}
channels := dataPlaneAllowedChannelsFromSession(session)
for _, channel := range tc.expected {
if !slices.Contains(channels, channel) {
t.Fatalf("expected channel %q in %v", channel, channels)
}
}
for _, channel := range tc.blocked {
if slices.Contains(channels, channel) {
t.Fatalf("did not expect channel %q in %v", channel, channels)
}
}
})
}
}
func mustJSON(t *testing.T, value any) []byte {
t.Helper()
payload, err := json.Marshal(value)
if err != nil {
t.Fatalf("marshal test metadata: %v", err)
}
return payload
}
func testRS256Key(t *testing.T) (string, *rsa.PublicKey) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("generate RSA key: %v", err)
}
encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
})
return string(encoded), &privateKey.PublicKey
}
func assertEqual(t *testing.T, got, want, name string) {
t.Helper()
if got != want {
t.Fatalf("unexpected %s: got %q want %q", name, got, want)
}
}
@@ -0,0 +1,15 @@
package sessionbroker
import "errors"
var (
ErrSessionNotFound = errors.New("remote session not found")
ErrAttachmentNotFound = errors.New("session attachment not found")
ErrActiveControllerPresent = errors.New("active controller already present")
ErrTakeoverNotAllowed = errors.New("takeover not allowed")
ErrTrustedDeviceRequired = errors.New("trusted device required")
ErrAccessDenied = errors.New("access denied")
ErrSessionNotAttachable = errors.New("session is not attachable")
ErrSessionNotTerminable = errors.New("session is not terminable")
ErrAttachTokenInvalid = errors.New("attach token invalid or expired")
)
@@ -0,0 +1,65 @@
package sessionbroker
import (
"context"
"time"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
type LiveStateStore interface {
UpsertSession(ctx context.Context, state LiveSessionState) error
GetSession(ctx context.Context, sessionID string) (*LiveSessionState, error)
DeleteSession(ctx context.Context, sessionID string) error
BindController(ctx context.Context, binding sessioncontracts.ControllerBinding, ttl time.Duration) error
GetControllerBinding(ctx context.Context, sessionID string) (*sessioncontracts.ControllerBinding, error)
ClearControllerBinding(ctx context.Context, sessionID string) error
StoreAttachToken(ctx context.Context, claims sessioncontracts.AttachTokenClaims, ttl time.Duration) error
ConsumeAttachToken(ctx context.Context, token string) (*sessioncontracts.AttachTokenClaims, error)
TouchAttachmentHeartbeat(ctx context.Context, sessionID, attachmentID string, ttl time.Duration) error
UpdateWorkerRoute(ctx context.Context, route WorkerRoute, ttl time.Duration) error
GetWorkerRoute(ctx context.Context, sessionID string) (*WorkerRoute, error)
DeleteWorkerRoute(ctx context.Context, sessionID string) error
}
type LiveSessionState struct {
SessionID string `json:"session_id"`
ResourceID string `json:"resource_id"`
WorkerID string `json:"worker_id"`
State sessioncontracts.State `json:"state"`
ControllerID string `json:"controller_id"`
AttachmentID string `json:"attachment_id"`
TakeoverVersion int `json:"takeover_version"`
RenderQualityProfile string `json:"render_quality_profile,omitempty"`
RenderState string `json:"render_state,omitempty"`
RenderWidth int `json:"render_width,omitempty"`
RenderHeight int `json:"render_height,omitempty"`
RenderFrameSequence int64 `json:"render_frame_sequence,omitempty"`
RenderFrameFormat string `json:"render_frame_format,omitempty"`
RenderFrameData string `json:"render_frame_data,omitempty"`
LastInputCorrelationID string `json:"last_input_correlation_id,omitempty"`
WorkerFrameCapturedAt string `json:"worker_frame_captured_at,omitempty"`
CursorX int `json:"cursor_x,omitempty"`
CursorY int `json:"cursor_y,omitempty"`
CursorVisible bool `json:"cursor_visible,omitempty"`
DirtyRectangles int `json:"dirty_rectangles,omitempty"`
LastRenderAt *time.Time `json:"last_render_at,omitempty"`
ClipboardSequence int64 `json:"clipboard_sequence,omitempty"`
ClipboardText string `json:"clipboard_text,omitempty"`
ClipboardOrigin string `json:"clipboard_origin,omitempty"`
ClipboardContentHash string `json:"clipboard_content_hash,omitempty"`
ClipboardUpdatedAt *time.Time `json:"clipboard_updated_at,omitempty"`
FileDownloadSequence int64 `json:"file_download_sequence,omitempty"`
FileDownloadType string `json:"file_download_type,omitempty"`
FileDownloadPayload map[string]any `json:"file_download_payload,omitempty"`
FileDownloadUpdatedAt *time.Time `json:"file_download_updated_at,omitempty"`
UpdatedAt time.Time `json:"updated_at"`
}
type WorkerRoute struct {
SessionID string `json:"session_id"`
WorkerID string `json:"worker_id"`
LeaseID string `json:"lease_id"`
ControlStream string `json:"control_stream"`
UpdatedAt time.Time `json:"updated_at"`
}
@@ -0,0 +1,132 @@
package sessionbroker
import (
"time"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
type AttachmentRole string
const (
AttachmentRoleController AttachmentRole = "controller"
)
type AttachmentState string
const (
AttachmentStateAttaching AttachmentState = "attaching"
AttachmentStateActive AttachmentState = "active"
AttachmentStateDetached AttachmentState = "detached"
AttachmentStateSuperseded AttachmentState = "superseded"
AttachmentStateRevoked AttachmentState = "revoked"
AttachmentStateClosed AttachmentState = "closed"
)
type ResourceTakeoverPolicy string
const (
ResourceTakeoverPolicyTrustedDevice ResourceTakeoverPolicy = "trusted_device"
ResourceTakeoverPolicySameUser ResourceTakeoverPolicy = "same_user"
ResourceTakeoverPolicyAdminOnly ResourceTakeoverPolicy = "admin_only"
)
type ResourceClipboardMode string
const (
ResourceClipboardModeDisabled ResourceClipboardMode = "disabled"
ResourceClipboardModeClientToServer ResourceClipboardMode = "client_to_server"
ResourceClipboardModeServerToClient ResourceClipboardMode = "server_to_client"
ResourceClipboardModeBidirectional ResourceClipboardMode = "bidirectional"
)
type ResourceFileTransferMode string
const (
ResourceFileTransferModeDisabled ResourceFileTransferMode = "disabled"
ResourceFileTransferModeClientToServer ResourceFileTransferMode = "client_to_server"
ResourceFileTransferModeServerToClient ResourceFileTransferMode = "server_to_client"
ResourceFileTransferModeBidirectional ResourceFileTransferMode = "bidirectional"
)
type RemoteSession struct {
ID string
OrganizationID string
ResourceID string
Protocol string
State sessioncontracts.State
WorkerID string
ControllerUserID string
DetachDeadlineAt *time.Time
LastHeartbeatAt *time.Time
TakeoverVersion int
RenderQualityProfile string
Metadata []byte
CreatedAt time.Time
UpdatedAt time.Time
}
type SessionAttachment struct {
ID string
RemoteSessionID string
UserID string
DeviceID string
Role AttachmentRole
State AttachmentState
SupersededBy *string
TakeoverOf *string
AttachedAt *time.Time
DetachedAt *time.Time
LastInputAt *time.Time
Metadata []byte
CreatedAt time.Time
UpdatedAt time.Time
}
type ResourcePolicy struct {
ResourceID string
MaxConcurrentSessions int
TakeoverPolicy ResourceTakeoverPolicy
RequireTrustedDevice bool
DetachGracePeriod time.Duration
ClipboardEnabled bool
ClipboardMode ResourceClipboardMode
FileTransferEnabled bool
FileTransferMode ResourceFileTransferMode
CreatedAt time.Time
UpdatedAt time.Time
}
type ResourceRuntimeSpec struct {
ID string
OrganizationID string
Name string
Address string
Protocol string
SecretRef *string
CertificateVerificationMode string
Metadata []byte
}
type AuditEvent struct {
ID string
ActorUserID *string
ActorDeviceID *string
EventType string
TargetType string
TargetID string
RemoteSessionID *string
Payload []byte
CreatedAt time.Time
}
const (
AuditEventSessionStarted = "session_started"
AuditEventSessionAttached = "session_attached"
AuditEventSessionDetached = "session_detached"
AuditEventSessionTakenOver = "session_taken_over"
AuditEventSessionTerminated = "session_terminated"
AuditEventSessionFailed = "session_failed"
AuditEventSecretAccessed = "resource_secret_accessed"
AuditEventSecretAccessDenied = "resource_secret_access_denied"
)
@@ -0,0 +1,164 @@
package sessionbroker
import (
"encoding/json"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
)
type Module struct {
service *Service
}
func NewModule(service *Service) *Module {
return &Module{service: service}
}
func (m *Module) Name() string {
return "session-broker"
}
func (m *Module) Service() *Service {
return m.service
}
func (m *Module) RegisterRoutes(router chi.Router) {
router.Route("/sessions", func(r chi.Router) {
r.Get("/", m.listSessions)
r.Post("/", m.startSession)
r.Post("/{sessionID}/attach", m.attachSession)
r.Post("/{sessionID}/detach", m.detachSession)
r.Post("/{sessionID}/takeover", m.takeoverSession)
r.Post("/{sessionID}/terminate", m.terminateSession)
r.Post("/{sessionID}/fail", m.markFailed)
})
}
func (m *Module) listSessions(w http.ResponseWriter, r *http.Request) {
userID := r.URL.Query().Get("user_id")
if userID == "" {
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
return
}
sessions, err := m.service.ListSessions(r.Context(), userID)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"sessions": sessions})
}
func (m *Module) startSession(w http.ResponseWriter, r *http.Request) {
var cmd StartRemoteSessionCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid start session payload")
return
}
result, err := m.service.StartRemoteSession(r.Context(), cmd)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusCreated, result)
}
func (m *Module) attachSession(w http.ResponseWriter, r *http.Request) {
var cmd AttachToSessionCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid attach session payload")
return
}
cmd.SessionID = chi.URLParam(r, "sessionID")
result, err := m.service.AttachToSession(r.Context(), cmd)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusOK, result)
}
func (m *Module) detachSession(w http.ResponseWriter, r *http.Request) {
var cmd DetachFromSessionCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid detach session payload")
return
}
cmd.SessionID = chi.URLParam(r, "sessionID")
result, err := m.service.DetachFromSession(r.Context(), cmd)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusAccepted, result)
}
func (m *Module) takeoverSession(w http.ResponseWriter, r *http.Request) {
var cmd TakeoverSessionCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid takeover session payload")
return
}
cmd.SessionID = chi.URLParam(r, "sessionID")
result, err := m.service.TakeoverSession(r.Context(), cmd)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusOK, result)
}
func (m *Module) terminateSession(w http.ResponseWriter, r *http.Request) {
var cmd TerminateSessionCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid terminate session payload")
return
}
cmd.SessionID = chi.URLParam(r, "sessionID")
if err := m.service.TerminateSession(r.Context(), cmd); err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
"status": "terminated",
"message": httpx.NewMessage(
"session.terminated",
"status.session.terminated",
"Session terminated.",
nil,
"",
),
})
}
func (m *Module) markFailed(w http.ResponseWriter, r *http.Request) {
var cmd MarkSessionFailedCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid fail session payload")
return
}
cmd.SessionID = chi.URLParam(r, "sessionID")
if err := m.service.MarkSessionFailed(r.Context(), cmd); err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
"status": "failed",
"message": httpx.NewMessage(
"session.failed",
"status.session.failed",
"Session marked as failed.",
nil,
"",
),
})
}
@@ -0,0 +1,17 @@
package sessionbroker
import (
"context"
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
)
type WorkerOrchestrator interface {
Reserve(ctx context.Context, request workercontracts.AttachRequest) (*workercontracts.WorkerLease, error)
GetSessionLease(ctx context.Context, sessionID string) (*workercontracts.WorkerLease, error)
ReleaseSessionLease(ctx context.Context, sessionID string) error
PrepareAttachment(ctx context.Context, session RemoteSession, attachment SessionAttachment, runtimeMetadata map[string]any) error
NotifyDetachment(ctx context.Context, session RemoteSession, attachment SessionAttachment) error
TerminateRemoteSession(ctx context.Context, sessionID, attachmentID string) error
ValidateSessionRuntime(ctx context.Context, sessionID, workerID string) (bool, string, error)
}
@@ -0,0 +1,607 @@
package sessionbroker
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/example/remote-access-platform/backend/internal/platform/authority"
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
)
type postgresStore struct {
db postgresplatform.DBTX
authority *authority.Verifier
}
type PostgresTransactor struct {
pool *pgxpool.Pool
authority *authority.Verifier
}
func NewPostgresStore(pool *pgxpool.Pool, verifiers ...*authority.Verifier) Store {
var authorityVerifier *authority.Verifier
if len(verifiers) > 0 {
authorityVerifier = verifiers[0]
}
return &postgresStore{db: pool, authority: authorityVerifier}
}
func NewPostgresTransactor(pool *pgxpool.Pool, verifiers ...*authority.Verifier) *PostgresTransactor {
var authorityVerifier *authority.Verifier
if len(verifiers) > 0 {
authorityVerifier = verifiers[0]
}
return &PostgresTransactor{pool: pool, authority: authorityVerifier}
}
func (t *PostgresTransactor) WithinTransaction(ctx context.Context, fn func(store Store) error) error {
return postgresplatform.WithTransaction(ctx, t.pool, func(tx pgx.Tx) error {
return fn(&postgresStore{db: tx, authority: t.authority})
})
}
func (s *postgresStore) RemoteSessions() RemoteSessionRepository {
return &postgresRemoteSessionRepository{db: s.db}
}
func (s *postgresStore) SessionAttachments() SessionAttachmentRepository {
return &postgresSessionAttachmentRepository{db: s.db}
}
func (s *postgresStore) ResourcePolicies() ResourcePolicyRepository {
return &postgresResourcePolicyRepository{db: s.db}
}
func (s *postgresStore) ResourceRuntime() ResourceRuntimeRepository {
return &postgresResourceRuntimeRepository{db: s.db}
}
func (s *postgresStore) AuditEvents() AuditEventRepository {
return &postgresAuditEventRepository{db: s.db}
}
func (s *postgresStore) Access() AccessRepository {
return &postgresAccessRepository{db: s.db, authority: s.authority}
}
type postgresRemoteSessionRepository struct {
db postgresplatform.DBTX
}
type postgresSessionAttachmentRepository struct {
db postgresplatform.DBTX
}
type postgresResourcePolicyRepository struct {
db postgresplatform.DBTX
}
type postgresResourceRuntimeRepository struct {
db postgresplatform.DBTX
}
type postgresAuditEventRepository struct {
db postgresplatform.DBTX
}
type postgresAccessRepository struct {
db postgresplatform.DBTX
authority *authority.Verifier
}
func (r *postgresRemoteSessionRepository) Create(ctx context.Context, session RemoteSession) error {
const query = `
INSERT INTO remote_sessions (
id, organization_id, resource_id, protocol, state, worker_id, controller_user_id, detach_deadline_at,
last_heartbeat_at, takeover_version, metadata, created_at, updated_at
) VALUES (
$1::uuid, $2::uuid, $3::uuid, $4, $5, NULLIF($6, ''), $7::uuid, $8, $9, $10, $11::jsonb, $12, $13
)
`
if _, err := r.db.Exec(ctx, query,
session.ID,
session.OrganizationID,
session.ResourceID,
session.Protocol,
session.State,
session.WorkerID,
session.ControllerUserID,
session.DetachDeadlineAt,
session.LastHeartbeatAt,
session.TakeoverVersion,
jsonPayload(session.Metadata),
session.CreatedAt,
session.UpdatedAt,
); err != nil {
return fmt.Errorf("create remote session: %w", err)
}
return nil
}
func (r *postgresRemoteSessionRepository) GetByID(ctx context.Context, sessionID string) (*RemoteSession, error) {
return r.getByID(ctx, sessionID, "")
}
func (r *postgresRemoteSessionRepository) GetByIDForUpdate(ctx context.Context, sessionID string) (*RemoteSession, error) {
return r.getByID(ctx, sessionID, " FOR UPDATE")
}
func (r *postgresRemoteSessionRepository) getByID(ctx context.Context, sessionID string, suffix string) (*RemoteSession, error) {
query := `
SELECT id::text, organization_id::text, resource_id::text, protocol, state, COALESCE(worker_id, ''), controller_user_id::text,
detach_deadline_at, last_heartbeat_at, takeover_version, metadata, created_at, updated_at
FROM remote_sessions
WHERE id = $1::uuid` + suffix
remoteSession, err := scanRemoteSession(r.db.QueryRow(ctx, query, sessionID))
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return remoteSession, err
}
func (r *postgresRemoteSessionRepository) ListByController(ctx context.Context, userID string) ([]RemoteSession, error) {
const query = `
SELECT id::text, organization_id::text, resource_id::text, protocol, state, COALESCE(worker_id, ''), controller_user_id::text,
detach_deadline_at, last_heartbeat_at, takeover_version, metadata, created_at, updated_at
FROM remote_sessions
WHERE controller_user_id = $1::uuid
ORDER BY updated_at DESC
`
rows, err := r.db.Query(ctx, query, userID)
if err != nil {
return nil, fmt.Errorf("list remote sessions: %w", err)
}
defer rows.Close()
var sessions []RemoteSession
for rows.Next() {
item, err := scanRemoteSession(rows)
if err != nil {
return nil, err
}
sessions = append(sessions, *item)
}
return sessions, rows.Err()
}
func (r *postgresRemoteSessionRepository) CountLiveByResource(ctx context.Context, resourceID string) (int, error) {
const query = `
SELECT COUNT(*)
FROM remote_sessions
WHERE resource_id = $1::uuid AND state IN ('starting', 'active', 'detached', 'reconnecting')
`
var count int
if err := r.db.QueryRow(ctx, query, resourceID).Scan(&count); err != nil {
return 0, fmt.Errorf("count live remote sessions: %w", err)
}
return count, nil
}
func (r *postgresRemoteSessionRepository) ListDetachedExpired(ctx context.Context, before time.Time, limit int) ([]RemoteSession, error) {
const query = `
SELECT id::text, organization_id::text, resource_id::text, protocol, state, COALESCE(worker_id, ''), controller_user_id::text,
detach_deadline_at, last_heartbeat_at, takeover_version, metadata, created_at, updated_at
FROM remote_sessions
WHERE state = 'detached' AND detach_deadline_at IS NOT NULL AND detach_deadline_at <= $1
ORDER BY detach_deadline_at ASC
LIMIT $2
`
rows, err := r.db.Query(ctx, query, before, limit)
if err != nil {
return nil, fmt.Errorf("list detached expired sessions: %w", err)
}
defer rows.Close()
var sessions []RemoteSession
for rows.Next() {
item, err := scanRemoteSession(rows)
if err != nil {
return nil, err
}
sessions = append(sessions, *item)
}
return sessions, rows.Err()
}
func (r *postgresRemoteSessionRepository) UpdateState(ctx context.Context, params UpdateRemoteSessionStateParams) error {
const query = `
UPDATE remote_sessions
SET state = $2,
worker_id = NULLIF($3, ''),
detach_deadline_at = $4,
last_heartbeat_at = $5,
takeover_version = $6,
updated_at = $7
WHERE id = $1::uuid
`
if _, err := r.db.Exec(ctx, query,
params.RemoteSessionID,
params.State,
params.WorkerID,
params.DetachDeadlineAt,
params.LastHeartbeatAt,
params.TakeoverVersion,
params.UpdatedAt,
); err != nil {
return fmt.Errorf("update remote session state: %w", err)
}
return nil
}
func (r *postgresSessionAttachmentRepository) Create(ctx context.Context, attachment SessionAttachment) error {
const query = `
INSERT INTO session_attachments (
id, remote_session_id, user_id, device_id, role, state, superseded_by,
takeover_of, attached_at, detached_at, last_input_at, metadata, created_at, updated_at
) VALUES (
$1::uuid, $2::uuid, $3::uuid, $4::uuid, $5, $6, NULLIF($7, '')::uuid,
NULLIF($8, '')::uuid, $9, $10, $11, $12::jsonb, $13, $14
)
`
if _, err := r.db.Exec(ctx, query,
attachment.ID,
attachment.RemoteSessionID,
attachment.UserID,
attachment.DeviceID,
attachment.Role,
attachment.State,
stringValue(attachment.SupersededBy),
stringValue(attachment.TakeoverOf),
attachment.AttachedAt,
attachment.DetachedAt,
attachment.LastInputAt,
jsonPayload(attachment.Metadata),
attachment.CreatedAt,
attachment.UpdatedAt,
); err != nil {
return fmt.Errorf("create session attachment: %w", err)
}
return nil
}
func (r *postgresSessionAttachmentRepository) GetByID(ctx context.Context, attachmentID string) (*SessionAttachment, error) {
return r.getByID(ctx, attachmentID, "")
}
func (r *postgresSessionAttachmentRepository) GetByIDForUpdate(ctx context.Context, attachmentID string) (*SessionAttachment, error) {
return r.getByID(ctx, attachmentID, " FOR UPDATE")
}
func (r *postgresSessionAttachmentRepository) getByID(ctx context.Context, attachmentID string, suffix string) (*SessionAttachment, error) {
query := `
SELECT id::text, remote_session_id::text, user_id::text, device_id::text, role, state,
superseded_by::text, takeover_of::text, attached_at, detached_at, last_input_at, metadata, created_at, updated_at
FROM session_attachments
WHERE id = $1::uuid` + suffix
attachment, err := scanSessionAttachment(r.db.QueryRow(ctx, query, attachmentID))
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return attachment, err
}
func (r *postgresSessionAttachmentRepository) ListByRemoteSession(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error) {
return r.listByRemoteSession(ctx, remoteSessionID, "")
}
func (r *postgresSessionAttachmentRepository) ListActiveByRemoteSessionForUpdate(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error) {
return r.listByRemoteSession(ctx, remoteSessionID, " AND state IN ('attaching', 'active', 'reconnecting') FOR UPDATE")
}
func (r *postgresSessionAttachmentRepository) listByRemoteSession(ctx context.Context, remoteSessionID string, suffix string) ([]SessionAttachment, error) {
query := `
SELECT id::text, remote_session_id::text, user_id::text, device_id::text, role, state,
superseded_by::text, takeover_of::text, attached_at, detached_at, last_input_at, metadata, created_at, updated_at
FROM session_attachments
WHERE remote_session_id = $1::uuid` + suffix
rows, err := r.db.Query(ctx, query, remoteSessionID)
if err != nil {
return nil, fmt.Errorf("list session attachments: %w", err)
}
defer rows.Close()
var attachments []SessionAttachment
for rows.Next() {
item, err := scanSessionAttachment(rows)
if err != nil {
return nil, err
}
attachments = append(attachments, *item)
}
return attachments, rows.Err()
}
func (r *postgresSessionAttachmentRepository) UpdateState(ctx context.Context, params UpdateSessionAttachmentStateParams) error {
const query = `
UPDATE session_attachments
SET state = $2,
detached_at = $3,
last_input_at = $4,
updated_at = $5
WHERE id = $1::uuid
`
if _, err := r.db.Exec(ctx, query,
params.AttachmentID,
params.State,
params.DetachedAt,
params.LastInputAt,
params.UpdatedAt,
); err != nil {
return fmt.Errorf("update session attachment state: %w", err)
}
return nil
}
func (r *postgresSessionAttachmentRepository) Supersede(ctx context.Context, params SupersedeAttachmentParams) error {
const query = `
UPDATE session_attachments
SET state = 'superseded',
superseded_by = $2::uuid,
detached_at = $3,
updated_at = $4
WHERE id = $1::uuid
`
if _, err := r.db.Exec(ctx, query,
params.PreviousAttachmentID,
params.NextAttachmentID,
params.DetachedAt,
params.UpdatedAt,
); err != nil {
return fmt.Errorf("supersede attachment: %w", err)
}
return nil
}
func (r *postgresResourcePolicyRepository) GetByResourceID(ctx context.Context, resourceID string) (*ResourcePolicy, error) {
const query = `
SELECT resource_id::text, max_concurrent_sessions, takeover_policy, require_trusted_device,
detach_grace_period_seconds, clipboard_enabled, clipboard_mode, file_transfer_enabled,
COALESCE(file_transfer_mode, 'disabled'), created_at, updated_at
FROM resource_policies
WHERE resource_id = $1::uuid
`
policy, err := scanResourcePolicy(r.db.QueryRow(ctx, query, resourceID))
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return policy, err
}
func (r *postgresResourcePolicyRepository) Upsert(ctx context.Context, policy ResourcePolicy) error {
const query = `
INSERT INTO resource_policies (
resource_id, max_concurrent_sessions, takeover_policy, require_trusted_device,
detach_grace_period_seconds, clipboard_enabled, clipboard_mode, file_transfer_enabled, file_transfer_mode, created_at, updated_at
) VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
ON CONFLICT (resource_id) DO UPDATE SET
max_concurrent_sessions = EXCLUDED.max_concurrent_sessions,
takeover_policy = EXCLUDED.takeover_policy,
require_trusted_device = EXCLUDED.require_trusted_device,
detach_grace_period_seconds = EXCLUDED.detach_grace_period_seconds,
clipboard_enabled = EXCLUDED.clipboard_enabled,
clipboard_mode = EXCLUDED.clipboard_mode,
file_transfer_enabled = EXCLUDED.file_transfer_enabled,
file_transfer_mode = EXCLUDED.file_transfer_mode,
updated_at = EXCLUDED.updated_at
`
clipboardMode := normalizeClipboardMode(policy.ClipboardMode)
fileTransferMode := normalizeFileTransferMode(policy.FileTransferMode)
if _, err := r.db.Exec(ctx, query,
policy.ResourceID,
policy.MaxConcurrentSessions,
policy.TakeoverPolicy,
policy.RequireTrustedDevice,
int(policy.DetachGracePeriod.Seconds()),
clipboardMode != ResourceClipboardModeDisabled,
clipboardMode,
fileTransferAllowsClientToServer(fileTransferMode),
fileTransferMode,
policy.CreatedAt,
policy.UpdatedAt,
); err != nil {
return fmt.Errorf("upsert resource policy: %w", err)
}
return nil
}
func (r *postgresResourceRuntimeRepository) GetByID(ctx context.Context, resourceID string) (*ResourceRuntimeSpec, error) {
const query = `
SELECT id::text, organization_id::text, name, address, protocol, secret_ref, certificate_verification_mode, metadata
FROM resources
WHERE id = $1::uuid
`
item := &ResourceRuntimeSpec{}
var secretRef *string
var metadata []byte
if err := r.db.QueryRow(ctx, query, resourceID).Scan(
&item.ID,
&item.OrganizationID,
&item.Name,
&item.Address,
&item.Protocol,
&secretRef,
&item.CertificateVerificationMode,
&metadata,
); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return nil, fmt.Errorf("get resource runtime spec: %w", err)
}
item.SecretRef = secretRef
item.Metadata = metadata
return item, nil
}
func (r *postgresAuditEventRepository) Create(ctx context.Context, event AuditEvent) error {
const query = `
INSERT INTO audit_events (
id, actor_user_id, actor_device_id, event_type, target_type, target_id,
remote_session_id, payload, created_at
) VALUES (
$1::uuid, NULLIF($2, '')::uuid, NULLIF($3, '')::uuid, $4, $5, $6,
NULLIF($7, '')::uuid, $8::jsonb, $9
)
`
if _, err := r.db.Exec(ctx, query,
event.ID,
stringValue(event.ActorUserID),
stringValue(event.ActorDeviceID),
event.EventType,
event.TargetType,
event.TargetID,
stringValue(event.RemoteSessionID),
jsonPayload(event.Payload),
event.CreatedAt,
); err != nil {
return fmt.Errorf("create audit event: %w", err)
}
return nil
}
func (r *postgresAccessRepository) IsTrustedDevice(ctx context.Context, userID, deviceID string) (bool, error) {
const query = `
SELECT EXISTS(
SELECT 1 FROM devices
WHERE id = $1::uuid AND user_id = $2::uuid AND trust_status = 'trusted' AND revoked_at IS NULL
)
`
var trusted bool
if err := r.db.QueryRow(ctx, query, deviceID, userID).Scan(&trusted); err != nil {
return false, fmt.Errorf("check trusted device: %w", err)
}
return trusted, nil
}
func (r *postgresAccessRepository) GetPlatformRole(ctx context.Context, userID string) (string, error) {
return authority.EffectivePlatformRole(ctx, r.db, r.authority, userID)
}
func (r *postgresAccessRepository) GetOrganizationRole(ctx context.Context, organizationID, userID string) (string, bool, error) {
const query = `
SELECT role_id
FROM organization_memberships
WHERE organization_id = $1::uuid AND user_id = $2::uuid AND status = 'active'
`
var role string
if err := r.db.QueryRow(ctx, query, organizationID, userID).Scan(&role); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return "", false, nil
}
return "", false, fmt.Errorf("get organization role: %w", err)
}
return role, true, nil
}
type scanner interface {
Scan(dest ...any) error
}
func scanRemoteSession(row scanner) (*RemoteSession, error) {
item := &RemoteSession{}
var detachDeadlineAt, lastHeartbeatAt *time.Time
var metadata []byte
if err := row.Scan(
&item.ID,
&item.OrganizationID,
&item.ResourceID,
&item.Protocol,
&item.State,
&item.WorkerID,
&item.ControllerUserID,
&detachDeadlineAt,
&lastHeartbeatAt,
&item.TakeoverVersion,
&metadata,
&item.CreatedAt,
&item.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan remote session: %w", err)
}
item.DetachDeadlineAt = detachDeadlineAt
item.LastHeartbeatAt = lastHeartbeatAt
item.Metadata = metadata
item.RenderQualityProfile = renderQualityProfileFromSessionMetadata(metadata)
return item, nil
}
func scanSessionAttachment(row scanner) (*SessionAttachment, error) {
item := &SessionAttachment{}
var supersededBy, takeoverOf *string
var attachedAt, detachedAt, lastInputAt *time.Time
var metadata []byte
if err := row.Scan(
&item.ID,
&item.RemoteSessionID,
&item.UserID,
&item.DeviceID,
&item.Role,
&item.State,
&supersededBy,
&takeoverOf,
&attachedAt,
&detachedAt,
&lastInputAt,
&metadata,
&item.CreatedAt,
&item.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan session attachment: %w", err)
}
item.SupersededBy = supersededBy
item.TakeoverOf = takeoverOf
item.AttachedAt = attachedAt
item.DetachedAt = detachedAt
item.LastInputAt = lastInputAt
item.Metadata = metadata
return item, nil
}
func scanResourcePolicy(row scanner) (*ResourcePolicy, error) {
item := &ResourcePolicy{}
var detachGraceSeconds int
if err := row.Scan(
&item.ResourceID,
&item.MaxConcurrentSessions,
&item.TakeoverPolicy,
&item.RequireTrustedDevice,
&detachGraceSeconds,
&item.ClipboardEnabled,
&item.ClipboardMode,
&item.FileTransferEnabled,
&item.FileTransferMode,
&item.CreatedAt,
&item.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan resource policy: %w", err)
}
item.DetachGracePeriod = time.Duration(detachGraceSeconds) * time.Second
item.ClipboardMode = normalizeClipboardMode(item.ClipboardMode)
item.ClipboardEnabled = item.ClipboardMode != ResourceClipboardModeDisabled
item.FileTransferMode = normalizeFileTransferMode(item.FileTransferMode)
item.FileTransferEnabled = fileTransferAllowsClientToServer(item.FileTransferMode)
return item, nil
}
func jsonPayload(payload []byte) []byte {
if len(payload) == 0 {
return []byte(`{}`)
}
if json.Valid(payload) {
return payload
}
return []byte(`{}`)
}
func stringValue(value *string) string {
if value == nil {
return ""
}
return *value
}
@@ -0,0 +1,140 @@
package sessionbroker
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
type RedisLiveStateStore struct {
client *redis.Client
}
func NewRedisLiveStateStore(client *redis.Client) *RedisLiveStateStore {
return &RedisLiveStateStore{client: client}
}
func (s *RedisLiveStateStore) UpsertSession(ctx context.Context, state LiveSessionState) error {
return s.setJSON(ctx, liveSessionKey(state.SessionID), state, 0)
}
func (s *RedisLiveStateStore) GetSession(ctx context.Context, sessionID string) (*LiveSessionState, error) {
var state LiveSessionState
ok, err := s.getJSON(ctx, liveSessionKey(sessionID), &state)
if err != nil || !ok {
return nil, err
}
return &state, nil
}
func (s *RedisLiveStateStore) DeleteSession(ctx context.Context, sessionID string) error {
return s.client.Del(ctx, liveSessionKey(sessionID)).Err()
}
func (s *RedisLiveStateStore) BindController(ctx context.Context, binding sessioncontracts.ControllerBinding, ttl time.Duration) error {
return s.setJSON(ctx, controllerBindingKey(binding.SessionID), binding, ttl)
}
func (s *RedisLiveStateStore) GetControllerBinding(ctx context.Context, sessionID string) (*sessioncontracts.ControllerBinding, error) {
var binding sessioncontracts.ControllerBinding
ok, err := s.getJSON(ctx, controllerBindingKey(sessionID), &binding)
if err != nil || !ok {
return nil, err
}
return &binding, nil
}
func (s *RedisLiveStateStore) ClearControllerBinding(ctx context.Context, sessionID string) error {
return s.client.Del(ctx, controllerBindingKey(sessionID)).Err()
}
func (s *RedisLiveStateStore) StoreAttachToken(ctx context.Context, claims sessioncontracts.AttachTokenClaims, ttl time.Duration) error {
return s.setJSON(ctx, attachTokenKey(claims.Token), claims, ttl)
}
func (s *RedisLiveStateStore) ConsumeAttachToken(ctx context.Context, token string) (*sessioncontracts.AttachTokenClaims, error) {
key := attachTokenKey(token)
payload, err := s.client.GetDel(ctx, key).Result()
if err == redis.Nil {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("consume attach token: %w", err)
}
var claims sessioncontracts.AttachTokenClaims
if err := json.Unmarshal([]byte(payload), &claims); err != nil {
return nil, fmt.Errorf("decode attach token: %w", err)
}
return &claims, nil
}
func (s *RedisLiveStateStore) TouchAttachmentHeartbeat(ctx context.Context, sessionID, attachmentID string, ttl time.Duration) error {
return s.client.Set(ctx, attachmentHeartbeatKey(sessionID, attachmentID), time.Now().UTC().Format(time.RFC3339Nano), ttl).Err()
}
func (s *RedisLiveStateStore) UpdateWorkerRoute(ctx context.Context, route WorkerRoute, ttl time.Duration) error {
return s.setJSON(ctx, workerRouteKey(route.SessionID), route, ttl)
}
func (s *RedisLiveStateStore) GetWorkerRoute(ctx context.Context, sessionID string) (*WorkerRoute, error) {
var route WorkerRoute
ok, err := s.getJSON(ctx, workerRouteKey(sessionID), &route)
if err != nil || !ok {
return nil, err
}
return &route, nil
}
func (s *RedisLiveStateStore) DeleteWorkerRoute(ctx context.Context, sessionID string) error {
return s.client.Del(ctx, workerRouteKey(sessionID)).Err()
}
func (s *RedisLiveStateStore) setJSON(ctx context.Context, key string, value any, ttl time.Duration) error {
payload, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("encode redis payload: %w", err)
}
if err := s.client.Set(ctx, key, payload, ttl).Err(); err != nil {
return fmt.Errorf("set redis key %s: %w", key, err)
}
return nil
}
func (s *RedisLiveStateStore) getJSON(ctx context.Context, key string, dest any) (bool, error) {
payload, err := s.client.Get(ctx, key).Result()
if err == redis.Nil {
return false, nil
}
if err != nil {
return false, fmt.Errorf("get redis key %s: %w", key, err)
}
if err := json.Unmarshal([]byte(payload), dest); err != nil {
return false, fmt.Errorf("decode redis key %s: %w", key, err)
}
return true, nil
}
func liveSessionKey(sessionID string) string {
return "live:session:" + sessionID
}
func controllerBindingKey(sessionID string) string {
return "live:session:" + sessionID + ":controller"
}
func attachTokenKey(token string) string {
return "live:attach:" + token
}
func attachmentHeartbeatKey(sessionID, attachmentID string) string {
return "live:session:" + sessionID + ":attachment:" + attachmentID + ":heartbeat"
}
func workerRouteKey(sessionID string) string {
return "live:session:" + sessionID + ":worker-route"
}
@@ -0,0 +1,85 @@
package sessionbroker
import (
"context"
"time"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
type RemoteSessionRepository interface {
Create(ctx context.Context, session RemoteSession) error
GetByID(ctx context.Context, sessionID string) (*RemoteSession, error)
GetByIDForUpdate(ctx context.Context, sessionID string) (*RemoteSession, error)
ListByController(ctx context.Context, userID string) ([]RemoteSession, error)
CountLiveByResource(ctx context.Context, resourceID string) (int, error)
ListDetachedExpired(ctx context.Context, before time.Time, limit int) ([]RemoteSession, error)
UpdateState(ctx context.Context, params UpdateRemoteSessionStateParams) error
}
type SessionAttachmentRepository interface {
Create(ctx context.Context, attachment SessionAttachment) error
GetByID(ctx context.Context, attachmentID string) (*SessionAttachment, error)
GetByIDForUpdate(ctx context.Context, attachmentID string) (*SessionAttachment, error)
ListByRemoteSession(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error)
ListActiveByRemoteSessionForUpdate(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error)
UpdateState(ctx context.Context, params UpdateSessionAttachmentStateParams) error
Supersede(ctx context.Context, params SupersedeAttachmentParams) error
}
type ResourcePolicyRepository interface {
GetByResourceID(ctx context.Context, resourceID string) (*ResourcePolicy, error)
Upsert(ctx context.Context, policy ResourcePolicy) error
}
type AuditEventRepository interface {
Create(ctx context.Context, event AuditEvent) error
}
type Store interface {
RemoteSessions() RemoteSessionRepository
SessionAttachments() SessionAttachmentRepository
ResourcePolicies() ResourcePolicyRepository
ResourceRuntime() ResourceRuntimeRepository
AuditEvents() AuditEventRepository
Access() AccessRepository
}
type Transactor interface {
WithinTransaction(ctx context.Context, fn func(store Store) error) error
}
type UpdateRemoteSessionStateParams struct {
RemoteSessionID string
State sessioncontracts.State
WorkerID string
DetachDeadlineAt *time.Time
LastHeartbeatAt *time.Time
TakeoverVersion int
UpdatedAt time.Time
}
type UpdateSessionAttachmentStateParams struct {
AttachmentID string
State AttachmentState
DetachedAt *time.Time
LastInputAt *time.Time
UpdatedAt time.Time
}
type SupersedeAttachmentParams struct {
PreviousAttachmentID string
NextAttachmentID string
DetachedAt time.Time
UpdatedAt time.Time
}
type AccessRepository interface {
IsTrustedDevice(ctx context.Context, userID, deviceID string) (bool, error)
GetPlatformRole(ctx context.Context, userID string) (string, error)
GetOrganizationRole(ctx context.Context, organizationID, userID string) (string, bool, error)
}
type ResourceRuntimeRepository interface {
GetByID(ctx context.Context, resourceID string) (*ResourceRuntimeSpec, error)
}
@@ -0,0 +1,138 @@
package sessionbroker
import (
"context"
"encoding/json"
"errors"
"testing"
"github.com/example/remote-access-platform/backend/internal/platform/config"
"github.com/example/remote-access-platform/backend/internal/platform/module"
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
)
type fakeSecretResolver struct {
response *secrets.ResolvedResourceSecret
err error
request secrets.ResolveResourceSecretRequest
}
func testAppConfig(env string) config.AppConfig {
return config.AppConfig{Name: "rap-api-test", Env: env}
}
func (r *fakeSecretResolver) ResolveForSession(_ context.Context, req secrets.ResolveResourceSecretRequest) (*secrets.ResolvedResourceSecret, error) {
r.request = req
if r.err != nil {
return nil, r.err
}
return r.response, nil
}
func TestRuntimeAssignmentMetadataMergesResolvedSecretWithoutMutatingSessionMetadata(t *testing.T) {
resolver := &fakeSecretResolver{
response: &secrets.ResolvedResourceSecret{
Descriptor: secrets.ResourceSecretDescriptor{Version: 3},
Payload: json.RawMessage(`{"username":"user","password":"secret","domain":"corp"}`),
},
}
service := NewService(module.Dependencies{
Config: module.Config{App: testAppConfig("production")},
}, nil, nil, nil, nil, resolver)
sessionMetadata := mustJSON(t, map[string]any{
"resource": map[string]any{
"id": "resource-1",
"organization_id": "org-1",
"secret_ref": "rap-secret://org/org-1/resources/resource-1/primary",
"metadata": map[string]any{
"rdp_host": "host",
},
},
})
session := RemoteSession{
ID: "session-1",
OrganizationID: "org-1",
ResourceID: "resource-1",
WorkerID: "worker-1",
Metadata: sessionMetadata,
}
metadata, secretRef, version, err := service.runtimeAssignmentMetadata(context.Background(), session, &workercontracts.WorkerLease{LeaseID: "lease-1"})
if err != nil {
t.Fatalf("runtimeAssignmentMetadata returned error: %v", err)
}
if secretRef == "" || version != 3 {
t.Fatalf("expected secret ref and version, got ref=%q version=%d", secretRef, version)
}
resource := metadata["resource"].(map[string]any)
resourceMetadata := resource["metadata"].(map[string]any)
if resourceMetadata["username"] != "user" || resourceMetadata["password"] != "secret" || resourceMetadata["domain"] != "corp" {
t.Fatalf("resolved secret was not merged: %#v", resourceMetadata)
}
var persisted map[string]any
if err := json.Unmarshal(session.Metadata, &persisted); err != nil {
t.Fatalf("decode persisted metadata: %v", err)
}
persistedResource := persisted["resource"].(map[string]any)
persistedMetadata := persistedResource["metadata"].(map[string]any)
if _, ok := persistedMetadata["password"]; ok {
t.Fatalf("session metadata was mutated with plaintext secret")
}
if resolver.request.LeaseID != "lease-1" || resolver.request.WorkerID != "worker-1" {
t.Fatalf("resolver request missed lease/worker proof: %#v", resolver.request)
}
}
func TestRuntimeAssignmentMetadataRequiresResolverInProduction(t *testing.T) {
service := NewService(module.Dependencies{
Config: module.Config{App: testAppConfig("production")},
}, nil, nil, nil, nil)
session := RemoteSession{
ID: "session-1",
OrganizationID: "org-1",
ResourceID: "resource-1",
WorkerID: "worker-1",
Metadata: mustJSON(t, map[string]any{
"resource": map[string]any{
"secret_ref": "rap-secret://org/org-1/resources/resource-1/primary",
},
}),
}
_, _, _, err := service.runtimeAssignmentMetadata(context.Background(), session, &workercontracts.WorkerLease{LeaseID: "lease-1"})
if !errors.Is(err, secrets.ErrSecretEncryptionKeyMissing) {
t.Fatalf("expected missing resolver error, got %v", err)
}
}
func TestRuntimeAssignmentMetadataAllowsDevelopmentMetadataWithoutResolver(t *testing.T) {
service := NewService(module.Dependencies{
Config: module.Config{App: testAppConfig("development")},
}, nil, nil, nil, nil)
session := RemoteSession{
ID: "session-1",
OrganizationID: "org-1",
ResourceID: "resource-1",
WorkerID: "worker-1",
Metadata: mustJSON(t, map[string]any{
"resource": map[string]any{
"secret_ref": "rap-secret://org/org-1/resources/resource-1/primary",
"metadata": map[string]any{
"username": "dev-user",
"password": "dev-password",
},
},
}),
}
metadata, secretRef, _, err := service.runtimeAssignmentMetadata(context.Background(), session, nil)
if err != nil {
t.Fatalf("development metadata should not require resolver: %v", err)
}
if secretRef != "" {
t.Fatalf("development fallback should not audit resolver use, got %q", secretRef)
}
resource := metadata["resource"].(map[string]any)
resourceMetadata := resource["metadata"].(map[string]any)
if resourceMetadata["password"] != "dev-password" {
t.Fatalf("development metadata was not preserved")
}
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,391 @@
package sessionbroker
import (
"context"
"io"
"log/slog"
"testing"
"time"
"github.com/example/remote-access-platform/backend/internal/platform/module"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
)
func TestHandleWorkerConnectedIgnoresTerminalSession(t *testing.T) {
service, store, live, _ := newStaleWorkerEventTestService()
store.remote.sessions["session-1"] = RemoteSession{
ID: "session-1",
State: sessioncontracts.StateTerminated,
WorkerID: "worker-1",
}
if err := service.HandleWorkerConnected(context.Background(), "session-1"); err != nil {
t.Fatalf("HandleWorkerConnected returned error for stale terminal event: %v", err)
}
if got := store.remote.sessions["session-1"].State; got != sessioncontracts.StateTerminated {
t.Fatalf("stale connected event changed terminal state to %q", got)
}
if store.remote.updateCount != 0 {
t.Fatalf("stale connected event updated authoritative session %d times", store.remote.updateCount)
}
if live.upsertCount != 0 {
t.Fatalf("stale connected event recreated live state %d times", live.upsertCount)
}
}
func TestUpdateWorkerRenderTelemetryIgnoresTerminalSession(t *testing.T) {
service, store, live, _ := newStaleWorkerEventTestService()
store.remote.sessions["session-1"] = RemoteSession{
ID: "session-1",
State: sessioncontracts.StateTerminated,
WorkerID: "worker-1",
}
err := service.UpdateWorkerRenderTelemetry(context.Background(), "session-1", map[string]any{
"render_state": "ready",
"width": 1280,
"height": 720,
"frame_sequence": int64(99),
"frame_data": "stale-frame",
})
if err != nil {
t.Fatalf("UpdateWorkerRenderTelemetry returned error for stale terminal event: %v", err)
}
if live.upsertCount != 0 {
t.Fatalf("stale render event recreated live state %d times", live.upsertCount)
}
if live.sessions["session-1"] != nil {
t.Fatalf("stale render event left live state behind: %#v", live.sessions["session-1"])
}
}
func TestMarkSessionFailedTransitionsActiveSession(t *testing.T) {
service, store, live, orchestrator := newStaleWorkerEventTestService()
store.remote.sessions["session-1"] = RemoteSession{
ID: "session-1",
State: sessioncontracts.StateActive,
WorkerID: "worker-1",
TakeoverVersion: 3,
}
store.attachments.items["attachment-1"] = SessionAttachment{
ID: "attachment-1",
RemoteSessionID: "session-1",
State: AttachmentStateActive,
}
live.sessions["session-1"] = &LiveSessionState{SessionID: "session-1", State: sessioncontracts.StateActive}
if err := service.MarkSessionFailed(context.Background(), MarkSessionFailedCommand{SessionID: "session-1", Reason: "worker_lost"}); err != nil {
t.Fatalf("MarkSessionFailed returned error: %v", err)
}
if got := store.remote.sessions["session-1"].State; got != sessioncontracts.StateFailed {
t.Fatalf("expected failed state, got %q", got)
}
if got := store.attachments.items["attachment-1"].State; got != AttachmentStateClosed {
t.Fatalf("expected attachment closed, got %q", got)
}
if store.audit.createCount != 1 {
t.Fatalf("expected one audit event, got %d", store.audit.createCount)
}
if live.sessions["session-1"] != nil {
t.Fatal("expected failed session live state to be deleted")
}
if orchestrator.releaseCount != 1 {
t.Fatalf("expected session lease release, got %d", orchestrator.releaseCount)
}
}
func TestMarkSessionFailedAlreadyFailedIsIdempotent(t *testing.T) {
service, store, _, _ := newStaleWorkerEventTestService()
store.remote.sessions["session-1"] = RemoteSession{
ID: "session-1",
State: sessioncontracts.StateFailed,
WorkerID: "worker-1",
TakeoverVersion: 1,
}
if err := service.MarkSessionFailed(context.Background(), MarkSessionFailedCommand{SessionID: "session-1", Reason: "duplicate_worker_failure"}); err != nil {
t.Fatalf("duplicate MarkSessionFailed returned error: %v", err)
}
if got := store.remote.sessions["session-1"].State; got != sessioncontracts.StateFailed {
t.Fatalf("duplicate terminal event changed state to %q", got)
}
}
func newStaleWorkerEventTestService() (*Service, *staleWorkerEventTestStore, *staleWorkerEventLiveState, *staleWorkerEventOrchestrator) {
store := &staleWorkerEventTestStore{
remote: &staleWorkerEventRemoteSessions{sessions: map[string]RemoteSession{}},
attachments: &staleWorkerEventAttachments{items: map[string]SessionAttachment{}},
policies: &staleWorkerEventPolicies{},
audit: &staleWorkerEventAudit{},
}
live := &staleWorkerEventLiveState{sessions: map[string]*LiveSessionState{}}
orchestrator := &staleWorkerEventOrchestrator{}
service := NewService(module.Dependencies{
Infra: module.Infra{Logger: slog.New(slog.NewTextHandler(io.Discard, nil))},
}, store, staleWorkerEventTransactor{store: store}, live, orchestrator)
service.now = func() time.Time { return time.Unix(100, 0).UTC() }
return service, store, live, orchestrator
}
type staleWorkerEventTransactor struct {
store Store
}
func (t staleWorkerEventTransactor) WithinTransaction(ctx context.Context, fn func(store Store) error) error {
return fn(t.store)
}
type staleWorkerEventTestStore struct {
remote *staleWorkerEventRemoteSessions
attachments *staleWorkerEventAttachments
policies *staleWorkerEventPolicies
audit *staleWorkerEventAudit
}
func (s *staleWorkerEventTestStore) RemoteSessions() RemoteSessionRepository { return s.remote }
func (s *staleWorkerEventTestStore) SessionAttachments() SessionAttachmentRepository {
return s.attachments
}
func (s *staleWorkerEventTestStore) ResourcePolicies() ResourcePolicyRepository { return s.policies }
func (s *staleWorkerEventTestStore) ResourceRuntime() ResourceRuntimeRepository {
return staleWorkerEventResourceRuntime{}
}
func (s *staleWorkerEventTestStore) AuditEvents() AuditEventRepository { return s.audit }
func (s *staleWorkerEventTestStore) Access() AccessRepository { return staleWorkerEventAccess{} }
type staleWorkerEventRemoteSessions struct {
sessions map[string]RemoteSession
updateCount int
}
func (r *staleWorkerEventRemoteSessions) Create(_ context.Context, session RemoteSession) error {
r.sessions[session.ID] = session
return nil
}
func (r *staleWorkerEventRemoteSessions) GetByID(_ context.Context, sessionID string) (*RemoteSession, error) {
session, ok := r.sessions[sessionID]
if !ok {
return nil, nil
}
return &session, nil
}
func (r *staleWorkerEventRemoteSessions) GetByIDForUpdate(ctx context.Context, sessionID string) (*RemoteSession, error) {
return r.GetByID(ctx, sessionID)
}
func (r *staleWorkerEventRemoteSessions) ListByController(_ context.Context, _ string) ([]RemoteSession, error) {
return nil, nil
}
func (r *staleWorkerEventRemoteSessions) CountLiveByResource(_ context.Context, _ string) (int, error) {
return 0, nil
}
func (r *staleWorkerEventRemoteSessions) ListDetachedExpired(_ context.Context, _ time.Time, _ int) ([]RemoteSession, error) {
return nil, nil
}
func (r *staleWorkerEventRemoteSessions) UpdateState(_ context.Context, params UpdateRemoteSessionStateParams) error {
session := r.sessions[params.RemoteSessionID]
session.State = params.State
session.WorkerID = params.WorkerID
session.DetachDeadlineAt = params.DetachDeadlineAt
session.LastHeartbeatAt = params.LastHeartbeatAt
session.TakeoverVersion = params.TakeoverVersion
session.UpdatedAt = params.UpdatedAt
r.sessions[params.RemoteSessionID] = session
r.updateCount++
return nil
}
type staleWorkerEventAttachments struct {
items map[string]SessionAttachment
}
func (r *staleWorkerEventAttachments) Create(_ context.Context, attachment SessionAttachment) error {
r.items[attachment.ID] = attachment
return nil
}
func (r *staleWorkerEventAttachments) GetByID(_ context.Context, attachmentID string) (*SessionAttachment, error) {
attachment, ok := r.items[attachmentID]
if !ok {
return nil, nil
}
return &attachment, nil
}
func (r *staleWorkerEventAttachments) GetByIDForUpdate(ctx context.Context, attachmentID string) (*SessionAttachment, error) {
return r.GetByID(ctx, attachmentID)
}
func (r *staleWorkerEventAttachments) ListByRemoteSession(_ context.Context, remoteSessionID string) ([]SessionAttachment, error) {
attachments := make([]SessionAttachment, 0)
for _, attachment := range r.items {
if attachment.RemoteSessionID == remoteSessionID {
attachments = append(attachments, attachment)
}
}
return attachments, nil
}
func (r *staleWorkerEventAttachments) ListActiveByRemoteSessionForUpdate(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error) {
return r.ListByRemoteSession(ctx, remoteSessionID)
}
func (r *staleWorkerEventAttachments) UpdateState(_ context.Context, params UpdateSessionAttachmentStateParams) error {
attachment := r.items[params.AttachmentID]
attachment.State = params.State
attachment.DetachedAt = params.DetachedAt
attachment.LastInputAt = params.LastInputAt
attachment.UpdatedAt = params.UpdatedAt
r.items[params.AttachmentID] = attachment
return nil
}
func (r *staleWorkerEventAttachments) Supersede(_ context.Context, params SupersedeAttachmentParams) error {
attachment := r.items[params.PreviousAttachmentID]
attachment.State = AttachmentStateSuperseded
attachment.SupersededBy = &params.NextAttachmentID
attachment.DetachedAt = &params.DetachedAt
attachment.UpdatedAt = params.UpdatedAt
r.items[params.PreviousAttachmentID] = attachment
return nil
}
type staleWorkerEventPolicies struct{}
func (r *staleWorkerEventPolicies) GetByResourceID(_ context.Context, _ string) (*ResourcePolicy, error) {
return nil, nil
}
func (r *staleWorkerEventPolicies) Upsert(_ context.Context, _ ResourcePolicy) error {
return nil
}
type staleWorkerEventAudit struct {
createCount int
}
func (r *staleWorkerEventAudit) Create(_ context.Context, _ AuditEvent) error {
r.createCount++
return nil
}
type staleWorkerEventResourceRuntime struct{}
func (staleWorkerEventResourceRuntime) GetByID(_ context.Context, _ string) (*ResourceRuntimeSpec, error) {
return nil, nil
}
type staleWorkerEventAccess struct{}
func (staleWorkerEventAccess) IsTrustedDevice(_ context.Context, _, _ string) (bool, error) {
return false, nil
}
func (staleWorkerEventAccess) GetPlatformRole(_ context.Context, _ string) (string, error) {
return "", nil
}
func (staleWorkerEventAccess) GetOrganizationRole(_ context.Context, _, _ string) (string, bool, error) {
return "", false, nil
}
type staleWorkerEventLiveState struct {
sessions map[string]*LiveSessionState
upsertCount int
}
func (s *staleWorkerEventLiveState) UpsertSession(_ context.Context, state LiveSessionState) error {
copied := state
s.sessions[state.SessionID] = &copied
s.upsertCount++
return nil
}
func (s *staleWorkerEventLiveState) GetSession(_ context.Context, sessionID string) (*LiveSessionState, error) {
state := s.sessions[sessionID]
if state == nil {
return nil, nil
}
copied := *state
return &copied, nil
}
func (s *staleWorkerEventLiveState) DeleteSession(_ context.Context, sessionID string) error {
delete(s.sessions, sessionID)
return nil
}
func (s *staleWorkerEventLiveState) BindController(_ context.Context, _ sessioncontracts.ControllerBinding, _ time.Duration) error {
return nil
}
func (s *staleWorkerEventLiveState) GetControllerBinding(_ context.Context, _ string) (*sessioncontracts.ControllerBinding, error) {
return nil, nil
}
func (s *staleWorkerEventLiveState) ClearControllerBinding(_ context.Context, _ string) error {
return nil
}
func (s *staleWorkerEventLiveState) StoreAttachToken(_ context.Context, _ sessioncontracts.AttachTokenClaims, _ time.Duration) error {
return nil
}
func (s *staleWorkerEventLiveState) ConsumeAttachToken(_ context.Context, _ string) (*sessioncontracts.AttachTokenClaims, error) {
return nil, nil
}
func (s *staleWorkerEventLiveState) TouchAttachmentHeartbeat(_ context.Context, _, _ string, _ time.Duration) error {
return nil
}
func (s *staleWorkerEventLiveState) UpdateWorkerRoute(_ context.Context, _ WorkerRoute, _ time.Duration) error {
return nil
}
func (s *staleWorkerEventLiveState) GetWorkerRoute(_ context.Context, _ string) (*WorkerRoute, error) {
return nil, nil
}
func (s *staleWorkerEventLiveState) DeleteWorkerRoute(_ context.Context, _ string) error {
return nil
}
type staleWorkerEventOrchestrator struct {
releaseCount int
}
func (o *staleWorkerEventOrchestrator) Reserve(_ context.Context, _ workercontracts.AttachRequest) (*workercontracts.WorkerLease, error) {
return nil, nil
}
func (o *staleWorkerEventOrchestrator) GetSessionLease(_ context.Context, _ string) (*workercontracts.WorkerLease, error) {
return nil, nil
}
func (o *staleWorkerEventOrchestrator) ReleaseSessionLease(_ context.Context, _ string) error {
o.releaseCount++
return nil
}
func (o *staleWorkerEventOrchestrator) PrepareAttachment(_ context.Context, _ RemoteSession, _ SessionAttachment, _ map[string]any) error {
return nil
}
func (o *staleWorkerEventOrchestrator) NotifyDetachment(_ context.Context, _ RemoteSession, _ SessionAttachment) error {
return nil
}
func (o *staleWorkerEventOrchestrator) TerminateRemoteSession(_ context.Context, _, _ string) error {
return nil
}
func (o *staleWorkerEventOrchestrator) ValidateSessionRuntime(_ context.Context, _, _ string) (bool, string, error) {
return true, "", nil
}
@@ -0,0 +1,44 @@
package sessionbroker
import (
"fmt"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
var allowedTransitions = map[sessioncontracts.State]map[sessioncontracts.State]struct{}{
sessioncontracts.StateStarting: {
sessioncontracts.StateActive: {},
sessioncontracts.StateFailed: {},
sessioncontracts.StateTerminated: {},
},
sessioncontracts.StateActive: {
sessioncontracts.StateDetached: {},
sessioncontracts.StateReconnecting: {},
sessioncontracts.StateFailed: {},
sessioncontracts.StateTerminated: {},
},
sessioncontracts.StateDetached: {
sessioncontracts.StateReconnecting: {},
sessioncontracts.StateTerminated: {},
sessioncontracts.StateFailed: {},
},
sessioncontracts.StateReconnecting: {
sessioncontracts.StateActive: {},
sessioncontracts.StateDetached: {},
sessioncontracts.StateFailed: {},
sessioncontracts.StateTerminated: {},
},
}
func validateTransition(from, to sessioncontracts.State) error {
if from == to {
return nil
}
if allowed, ok := allowedTransitions[from]; ok {
if _, ok := allowed[to]; ok {
return nil
}
}
return fmt.Errorf("invalid session state transition: %s -> %s", from, to)
}
File diff suppressed because it is too large Load Diff
+29
View File
@@ -0,0 +1,29 @@
package worker
type SessionEvent struct {
Type string `json:"type"`
SessionID string `json:"session_id"`
WorkerID string `json:"worker_id"`
Payload map[string]any `json:"payload,omitempty"`
}
const (
SessionEventConnected = "session_connected"
SessionEventHeartbeat = "session_heartbeat"
SessionEventFailed = "session_failed"
SessionEventTerminated = "session_terminated"
SessionEventDisplayReady = "session_display_ready"
SessionEventRenderReady = "session_render_ready"
SessionEventRenderDirty = "session_render_dirty"
SessionEventRenderResized = "session_render_resized"
SessionEventCursorUpdated = "session_cursor_updated"
SessionEventFrame = "session_frame"
SessionEventClipboardText = "session_clipboard_text"
SessionEventFileUploaded = "session_file_upload_completed"
SessionEventFileDownloadAvailable = "session_file_download_available"
SessionEventFileDownloadChunk = "session_file_download_chunk"
SessionEventFileDownloadProgress = "session_file_download_progress"
SessionEventFileDownloadCompleted = "session_file_download_completed"
SessionEventFileDownloadFailed = "session_file_download_failed"
SessionEventFileDownloadBlocked = "session_file_download_blocked"
)
@@ -0,0 +1,52 @@
package worker
import (
"context"
"errors"
"time"
"github.com/example/remote-access-platform/backend/internal/modules/sessionbroker"
)
type LeaseMonitor struct {
service *Service
broker *sessionbroker.Service
interval time.Duration
}
func NewLeaseMonitor(service *Service, broker *sessionbroker.Service, interval time.Duration) *LeaseMonitor {
if interval <= 0 {
interval = 15 * time.Second
}
return &LeaseMonitor{
service: service,
broker: broker,
interval: interval,
}
}
func (m *LeaseMonitor) Run(ctx context.Context) error {
ticker := time.NewTicker(m.interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil
case <-ticker.C:
stale, err := m.service.RecoverStaleLeases(ctx)
if err != nil {
return err
}
for _, lease := range stale {
err := m.broker.MarkSessionFailed(ctx, sessionbroker.MarkSessionFailedCommand{
SessionID: lease.SessionID,
Reason: "worker_lease_stale_or_worker_missing",
})
if err != nil && !errors.Is(err, sessionbroker.ErrSessionNotFound) && !errors.Is(err, sessionbroker.ErrSessionNotTerminable) {
return err
}
}
}
}
}
@@ -0,0 +1,153 @@
package worker
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"time"
"github.com/redis/go-redis/v9"
"github.com/example/remote-access-platform/backend/internal/modules/sessionbroker"
)
type EventProcessor struct {
client *redis.Client
broker *sessionbroker.Service
}
func NewEventProcessor(client *redis.Client, broker *sessionbroker.Service) *EventProcessor {
return &EventProcessor{client: client, broker: broker}
}
func (p *EventProcessor) Run(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return nil
default:
}
result, err := p.client.BLPop(ctx, 5*time.Second, "worker:events").Result()
if err == redis.Nil {
continue
}
if err != nil {
if ctx.Err() != nil {
return nil
}
return fmt.Errorf("consume worker event: %w", err)
}
if len(result) != 2 {
continue
}
var event SessionEvent
if err := json.Unmarshal([]byte(result[1]), &event); err != nil {
continue
}
if err := p.handleEvent(ctx, event); err != nil {
return err
}
}
}
func (p *EventProcessor) handleEvent(ctx context.Context, event SessionEvent) error {
switch event.Type {
case SessionEventConnected, SessionEventDisplayReady:
if err := p.broker.HandleWorkerConnected(ctx, event.SessionID); err != nil {
return err
}
if len(event.Payload) > 0 {
if err := p.broker.UpdateWorkerRenderTelemetry(ctx, event.SessionID, event.Payload); err != nil && !errors.Is(err, sessionbroker.ErrSessionNotFound) {
return err
}
}
return nil
case SessionEventHeartbeat:
return p.broker.HandleWorkerHeartbeat(ctx, event.SessionID)
case SessionEventRenderReady, SessionEventRenderDirty, SessionEventRenderResized, SessionEventCursorUpdated, SessionEventFrame:
if len(event.Payload) == 0 {
return nil
}
if correlationID, _ := event.Payload["input_correlation_id"].(string); correlationID != "" {
slog.Info("worker frame event received",
"session_id", event.SessionID,
"worker_id", event.WorkerID,
"frame_sequence", event.Payload["frame_sequence"],
"correlation_id", correlationID,
"worker_frame_captured_at", event.Payload["worker_frame_captured_at"],
"trace_stage", "backend_frame_receive")
}
return p.updateRenderTelemetryWithRetry(ctx, event.SessionID, event.Payload)
case SessionEventClipboardText:
if len(event.Payload) == 0 {
return nil
}
slog.Info("worker clipboard event received",
"session_id", event.SessionID,
"worker_id", event.WorkerID,
"origin", event.Payload["origin"],
"sequence_id", event.Payload["sequence_id"],
"content_hash", event.Payload["content_hash"])
return p.broker.UpdateWorkerClipboardText(ctx, event.SessionID, event.Payload)
case SessionEventFileUploaded:
slog.Info("worker file upload completed",
"session_id", event.SessionID,
"worker_id", event.WorkerID,
"transfer_id", event.Payload["transfer_id"],
"file_name", event.Payload["file_name"],
"file_size", event.Payload["file_size"],
"content_hash", event.Payload["content_hash"],
"storage_path", event.Payload["storage_path"])
return nil
case SessionEventFileDownloadAvailable, SessionEventFileDownloadChunk, SessionEventFileDownloadProgress,
SessionEventFileDownloadCompleted, SessionEventFileDownloadFailed, SessionEventFileDownloadBlocked:
slog.Info("worker file download event received",
"session_id", event.SessionID,
"worker_id", event.WorkerID,
"event_type", event.Type,
"transfer_id", event.Payload["transfer_id"],
"file_id", event.Payload["file_id"],
"file_name", event.Payload["file_name"],
"status", event.Payload["status"])
return p.broker.UpdateWorkerFileDownloadEvent(ctx, event.SessionID, event.Type, event.Payload)
case SessionEventFailed:
reason, _ := event.Payload["reason"].(string)
err := p.broker.MarkSessionFailed(ctx, sessionbroker.MarkSessionFailedCommand{
SessionID: event.SessionID,
Reason: reason,
})
if errors.Is(err, sessionbroker.ErrSessionNotFound) || errors.Is(err, sessionbroker.ErrSessionNotTerminable) {
return nil
}
return err
case SessionEventTerminated:
reason, _ := event.Payload["reason"].(string)
err := p.broker.TerminateSession(ctx, sessionbroker.TerminateSessionCommand{
SessionID: event.SessionID,
Reason: reason,
})
if errors.Is(err, sessionbroker.ErrSessionNotFound) || errors.Is(err, sessionbroker.ErrSessionNotTerminable) {
return nil
}
return err
default:
return nil
}
}
func (p *EventProcessor) updateRenderTelemetryWithRetry(ctx context.Context, sessionID string, payload map[string]any) error {
var lastErr error
for attempt := 0; attempt < 10; attempt++ {
err := p.broker.UpdateWorkerRenderTelemetry(ctx, sessionID, payload)
if err == nil || errors.Is(err, sessionbroker.ErrSessionNotFound) {
return nil
}
lastErr = err
time.Sleep(100 * time.Millisecond)
}
return lastErr
}
@@ -0,0 +1,264 @@
package worker
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"time"
"github.com/redis/go-redis/v9"
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
)
type RedisStore struct {
client *redis.Client
}
func NewRedisStore(client *redis.Client) *RedisStore {
return &RedisStore{client: client}
}
func (s *RedisStore) RegisterWorker(ctx context.Context, registration workercontracts.WorkerRegistration, ttl time.Duration) error {
payload, err := json.Marshal(registration)
if err != nil {
return fmt.Errorf("marshal worker registration: %w", err)
}
pipe := s.client.TxPipeline()
pipe.Set(ctx, workerKey(registration.WorkerID), payload, ttl)
pipe.SAdd(ctx, workerSetKey(), registration.WorkerID)
_, err = pipe.Exec(ctx)
if err != nil {
return fmt.Errorf("register worker: %w", err)
}
return nil
}
func (s *RedisStore) TouchWorkerHeartbeat(ctx context.Context, heartbeat workercontracts.WorkerHeartbeat, ttl time.Duration) error {
registration, err := s.GetWorker(ctx, heartbeat.WorkerID)
if err != nil {
return err
}
if registration == nil {
registration = &workercontracts.WorkerRegistration{
WorkerID: heartbeat.WorkerID,
Protocol: workercontracts.ProtocolRDP,
}
}
registration.Status = heartbeat.Status
registration.LastHeartbeatAt = heartbeat.LastHeartbeatAt
return s.RegisterWorker(ctx, *registration, ttl)
}
func (s *RedisStore) ListWorkers(ctx context.Context) ([]workercontracts.WorkerRegistration, error) {
ids, err := s.client.SMembers(ctx, workerSetKey()).Result()
if err != nil {
return nil, fmt.Errorf("list worker ids: %w", err)
}
workers := make([]workercontracts.WorkerRegistration, 0, len(ids))
for _, id := range ids {
worker, err := s.GetWorker(ctx, id)
if err != nil {
return nil, err
}
if worker != nil {
workers = append(workers, *worker)
}
}
return workers, nil
}
func (s *RedisStore) GetWorker(ctx context.Context, workerID string) (*workercontracts.WorkerRegistration, error) {
payload, err := s.client.Get(ctx, workerKey(workerID)).Result()
if err == redis.Nil {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("get worker: %w", err)
}
var registration workercontracts.WorkerRegistration
if err := json.Unmarshal([]byte(payload), &registration); err != nil {
return nil, fmt.Errorf("decode worker registration: %w", err)
}
return &registration, nil
}
func (s *RedisStore) AcquireLease(ctx context.Context, lease workercontracts.WorkerLease, ttl time.Duration) error {
payload, err := json.Marshal(lease)
if err != nil {
return fmt.Errorf("marshal lease: %w", err)
}
ok, err := s.client.SetNX(ctx, leaseKey(lease.LeaseID), payload, ttl).Result()
if err != nil {
return fmt.Errorf("acquire lease: %w", err)
}
if !ok {
return fmt.Errorf("lease already exists")
}
pipe := s.client.TxPipeline()
pipe.SAdd(ctx, leaseSetKey(), lease.LeaseID)
pipe.Set(ctx, sessionLeaseKey(lease.SessionID), lease.LeaseID, ttl)
_, err = pipe.Exec(ctx)
if err != nil {
return fmt.Errorf("index lease: %w", err)
}
return nil
}
func (s *RedisStore) GetLease(ctx context.Context, leaseID string) (*workercontracts.WorkerLease, error) {
payload, err := s.client.Get(ctx, leaseKey(leaseID)).Result()
if err == redis.Nil {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("get lease: %w", err)
}
var lease workercontracts.WorkerLease
if err := json.Unmarshal([]byte(payload), &lease); err != nil {
return nil, fmt.Errorf("decode lease: %w", err)
}
return &lease, nil
}
func (s *RedisStore) GetLeaseBySession(ctx context.Context, sessionID string) (*workercontracts.WorkerLease, error) {
leaseID, err := s.client.Get(ctx, sessionLeaseKey(sessionID)).Result()
if err == redis.Nil {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("get lease by session: %w", err)
}
return s.GetLease(ctx, leaseID)
}
func (s *RedisStore) RenewLease(ctx context.Context, lease workercontracts.WorkerLease, ttl time.Duration) error {
payload, err := json.Marshal(lease)
if err != nil {
return fmt.Errorf("marshal lease renewal: %w", err)
}
pipe := s.client.TxPipeline()
pipe.Set(ctx, leaseKey(lease.LeaseID), payload, ttl)
pipe.Set(ctx, sessionLeaseKey(lease.SessionID), lease.LeaseID, ttl)
_, err = pipe.Exec(ctx)
if err != nil {
return fmt.Errorf("renew lease: %w", err)
}
return nil
}
func (s *RedisStore) ReleaseLease(ctx context.Context, leaseID string) error {
lease, err := s.GetLease(ctx, leaseID)
if err != nil {
return err
}
pipe := s.client.TxPipeline()
pipe.Del(ctx, leaseKey(leaseID))
pipe.SRem(ctx, leaseSetKey(), leaseID)
if lease != nil {
pipe.Del(ctx, sessionLeaseKey(lease.SessionID))
}
_, err = pipe.Exec(ctx)
if err != nil {
return fmt.Errorf("release lease: %w", err)
}
return nil
}
func (s *RedisStore) ListLeases(ctx context.Context) ([]workercontracts.WorkerLease, error) {
ids, err := s.client.SMembers(ctx, leaseSetKey()).Result()
if err != nil {
return nil, fmt.Errorf("list lease ids: %w", err)
}
leases := make([]workercontracts.WorkerLease, 0, len(ids))
for _, id := range ids {
lease, err := s.GetLease(ctx, id)
if err != nil {
return nil, err
}
if lease != nil {
leases = append(leases, *lease)
}
}
return leases, nil
}
func (s *RedisStore) AppendEnvelope(ctx context.Context, envelope workercontracts.RoutedEnvelope) error {
payload, err := json.Marshal(envelope)
if err != nil {
return fmt.Errorf("marshal routed envelope: %w", err)
}
key := workerQueueKey(envelope.SessionID)
if err := s.client.RPush(ctx, key, payload).Err(); err != nil {
return fmt.Errorf("append routed envelope: %w", err)
}
if envelope.Type == "input" {
correlationID, _ := envelope.Payload["correlation_id"].(string)
if correlationID != "" {
if length, err := s.client.LLen(ctx, key).Result(); err == nil {
slog.Info("worker queue length after input append",
"session_id", envelope.SessionID,
"attachment_id", envelope.AttachmentID,
"correlation_id", correlationID,
"queue_key", key,
"queue_length", length,
"trace_stage", "redis_queue_append")
}
}
}
return s.client.Expire(ctx, key, 10*time.Minute).Err()
}
func (s *RedisStore) AppendAssignment(ctx context.Context, workerID string, payload map[string]any) error {
encoded, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("marshal worker assignment: %w", err)
}
if err := s.client.RPush(ctx, workerControlQueueKey(workerID), encoded).Err(); err != nil {
return fmt.Errorf("append worker assignment: %w", err)
}
return s.client.Expire(ctx, workerControlQueueKey(workerID), 10*time.Minute).Err()
}
func (s *RedisStore) AppendEvent(ctx context.Context, payload map[string]any) error {
encoded, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("marshal worker event: %w", err)
}
if err := s.client.RPush(ctx, workerEventsKey(), encoded).Err(); err != nil {
return fmt.Errorf("append worker event: %w", err)
}
return s.client.Expire(ctx, workerEventsKey(), 10*time.Minute).Err()
}
func workerKey(workerID string) string {
return "worker:registration:" + workerID
}
func workerSetKey() string {
return "worker:registrations"
}
func leaseKey(leaseID string) string {
return "worker:lease:" + leaseID
}
func leaseSetKey() string {
return "worker:leases"
}
func sessionLeaseKey(sessionID string) string {
return "worker:session-lease:" + sessionID
}
func workerQueueKey(sessionID string) string {
return "worker:queue:" + sessionID
}
func workerControlQueueKey(workerID string) string {
return "worker:control:" + workerID
}
func workerEventsKey() string {
return "worker:events"
}
+274
View File
@@ -0,0 +1,274 @@
package worker
import (
"context"
"encoding/json"
"errors"
"fmt"
"slices"
"time"
"github.com/google/uuid"
"github.com/example/remote-access-platform/backend/internal/modules/sessionbroker"
"github.com/example/remote-access-platform/backend/internal/platform/module"
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
)
var ErrNoWorkerAvailable = errors.New("no worker available")
type Service struct {
cfg module.Config
store Store
now func() time.Time
}
func NewService(deps module.Dependencies, store Store) *Service {
return &Service{
cfg: deps.Config,
store: store,
now: time.Now,
}
}
func (s *Service) Register(ctx context.Context, registration workercontracts.WorkerRegistration) error {
if registration.WorkerID == "" {
return fmt.Errorf("worker id is required")
}
registration.LastHeartbeatAt = s.now().UTC()
return s.store.RegisterWorker(ctx, registration, s.cfg.Worker.HeartbeatTTL)
}
func (s *Service) Heartbeat(ctx context.Context, heartbeat workercontracts.WorkerHeartbeat) error {
heartbeat.LastHeartbeatAt = s.now().UTC()
return s.store.TouchWorkerHeartbeat(ctx, heartbeat, s.cfg.Worker.HeartbeatTTL)
}
func (s *Service) Reserve(ctx context.Context, request workercontracts.AttachRequest) (*workercontracts.WorkerLease, error) {
registration, err := s.reserveWorker(ctx, workercontracts.ProtocolRDP, request.RequiredCapabilities)
if err != nil {
return nil, err
}
return s.AcquireLease(ctx, registration.WorkerID, request)
}
func (s *Service) reserveWorker(ctx context.Context, protocol workercontracts.Protocol, capabilities []string) (*workercontracts.WorkerRegistration, error) {
workers, err := s.store.ListWorkers(ctx)
if err != nil {
return nil, err
}
now := s.now().UTC()
for _, worker := range workers {
if worker.Protocol != protocol || worker.Status != workercontracts.StatusOnline {
continue
}
if now.Sub(worker.LastHeartbeatAt) > s.cfg.Worker.StaleLeaseGracePeriod+s.cfg.Worker.HeartbeatTTL {
continue
}
if !hasCapabilities(worker.Capabilities, capabilities) {
continue
}
return &worker, nil
}
return nil, ErrNoWorkerAvailable
}
func (s *Service) AcquireLease(ctx context.Context, workerID string, request workercontracts.AttachRequest) (*workercontracts.WorkerLease, error) {
if request.SessionID == "" {
request.SessionID = uuid.NewString()
}
now := s.now().UTC()
lease := workercontracts.WorkerLease{
LeaseID: uuid.NewString(),
WorkerID: workerID,
Protocol: workercontracts.ProtocolRDP,
ResourceID: request.ResourceID,
SessionID: request.SessionID,
Capabilities: request.RequiredCapabilities,
ControlStream: "worker://control/" + workerID,
ExpiresAt: now.Add(s.cfg.Worker.LeaseTTL),
RenderQualityProfile: normalizeRenderQualityProfile(request.RenderQualityProfile),
}
if err := s.store.AcquireLease(ctx, lease, s.cfg.Worker.LeaseTTL); err != nil {
return nil, err
}
return &lease, nil
}
func (s *Service) GetSessionLease(ctx context.Context, sessionID string) (*workercontracts.WorkerLease, error) {
return s.store.GetLeaseBySession(ctx, sessionID)
}
func (s *Service) RenewLease(ctx context.Context, leaseID string) (*workercontracts.WorkerLease, error) {
lease, err := s.store.GetLease(ctx, leaseID)
if err != nil || lease == nil {
return lease, err
}
lease.ExpiresAt = s.now().UTC().Add(s.cfg.Worker.LeaseTTL)
if err := s.store.RenewLease(ctx, *lease, s.cfg.Worker.LeaseTTL); err != nil {
return nil, err
}
return lease, nil
}
func (s *Service) ReleaseLease(ctx context.Context, leaseID string) error {
return s.store.ReleaseLease(ctx, leaseID)
}
func (s *Service) ReleaseSessionLease(ctx context.Context, sessionID string) error {
lease, err := s.store.GetLeaseBySession(ctx, sessionID)
if err != nil || lease == nil {
return err
}
return s.store.ReleaseLease(ctx, lease.LeaseID)
}
func (s *Service) RecoverStaleLeases(ctx context.Context) ([]workercontracts.WorkerLease, error) {
leases, err := s.store.ListLeases(ctx)
if err != nil {
return nil, err
}
var stale []workercontracts.WorkerLease
now := s.now().UTC()
for _, lease := range leases {
registration, err := s.store.GetWorker(ctx, lease.WorkerID)
if err != nil {
return nil, err
}
if registration == nil || now.Sub(registration.LastHeartbeatAt) > s.cfg.Worker.StaleLeaseGracePeriod+s.cfg.Worker.HeartbeatTTL {
if err := s.store.ReleaseLease(ctx, lease.LeaseID); err != nil {
return nil, err
}
stale = append(stale, lease)
}
}
return stale, nil
}
func (s *Service) ValidateSessionRuntime(ctx context.Context, sessionID, workerID string) (bool, string, error) {
lease, err := s.store.GetLeaseBySession(ctx, sessionID)
if err != nil {
return false, "", err
}
if lease == nil {
return false, "worker_lease_missing", nil
}
if workerID != "" && lease.WorkerID != workerID {
_ = s.store.ReleaseLease(ctx, lease.LeaseID)
return false, "worker_binding_mismatch", nil
}
now := s.now().UTC()
if !lease.ExpiresAt.After(now) {
_ = s.store.ReleaseLease(ctx, lease.LeaseID)
return false, "worker_lease_expired", nil
}
registration, err := s.store.GetWorker(ctx, lease.WorkerID)
if err != nil {
return false, "", err
}
if registration == nil {
_ = s.store.ReleaseLease(ctx, lease.LeaseID)
return false, "worker_registration_missing", nil
}
if registration.Status != workercontracts.StatusOnline {
return false, "worker_not_online", nil
}
if now.Sub(registration.LastHeartbeatAt) > s.cfg.Worker.StaleLeaseGracePeriod+s.cfg.Worker.HeartbeatTTL {
_ = s.store.ReleaseLease(ctx, lease.LeaseID)
return false, "worker_heartbeat_stale", nil
}
return true, "", nil
}
func (s *Service) PublishControl(ctx context.Context, envelope workercontracts.RoutedEnvelope) error {
return s.store.AppendEnvelope(ctx, envelope)
}
func (s *Service) PublishInput(ctx context.Context, envelope workercontracts.RoutedEnvelope) error {
return s.store.AppendEnvelope(ctx, envelope)
}
func hasCapabilities(workerCaps, required []string) bool {
for _, capability := range required {
if !slices.Contains(workerCaps, capability) {
return false
}
}
return true
}
func (s *Service) PrepareAttachment(ctx context.Context, session sessionbroker.RemoteSession, attachment sessionbroker.SessionAttachment, runtimeMetadata map[string]any) error {
renderQualityProfile := normalizeRenderQualityProfile(session.RenderQualityProfile)
if renderQualityProfile == "balanced" {
renderQualityProfile = renderQualityProfileFromMetadata(session.Metadata)
}
if runtimeMetadata == nil {
runtimeMetadata = decodeMetadata(session.Metadata)
}
return s.store.AppendAssignment(ctx, session.WorkerID, map[string]any{
"type": "session_assignment",
"session_id": session.ID,
"worker_id": session.WorkerID,
"attachment_id": attachment.ID,
"user_id": attachment.UserID,
"device_id": attachment.DeviceID,
"takeover_of": attachment.TakeoverOf,
"state": session.State,
"render_quality_profile": renderQualityProfile,
"metadata": runtimeMetadata,
})
}
func (s *Service) NotifyDetachment(ctx context.Context, session sessionbroker.RemoteSession, attachment sessionbroker.SessionAttachment) error {
return s.PublishControl(ctx, workercontracts.RoutedEnvelope{
SessionID: session.ID,
AttachmentID: attachment.ID,
Type: "control",
Payload: map[string]any{
"action": "detach",
},
})
}
func (s *Service) TerminateRemoteSession(ctx context.Context, sessionID, attachmentID string) error {
return s.PublishControl(ctx, workercontracts.RoutedEnvelope{
SessionID: sessionID,
AttachmentID: attachmentID,
Type: "control",
Payload: map[string]any{
"action": "terminate",
},
})
}
func decodeMetadata(payload []byte) map[string]any {
var out map[string]any
if len(payload) == 0 {
return map[string]any{}
}
if err := json.Unmarshal(payload, &out); err != nil {
return map[string]any{}
}
return out
}
func normalizeRenderQualityProfile(profile string) string {
switch profile {
case "low_bandwidth", "balanced", "high_quality", "text_priority":
return profile
default:
return "balanced"
}
}
func renderQualityProfileFromMetadata(metadata []byte) string {
decoded := decodeMetadata(metadata)
resource, _ := decoded["resource"].(map[string]any)
if resource == nil {
return "balanced"
}
if profile, ok := resource["render_quality_profile"].(string); ok && profile != "" {
return normalizeRenderQualityProfile(profile)
}
return "balanced"
}
+24
View File
@@ -0,0 +1,24 @@
package worker
import (
"context"
"time"
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
)
type Store interface {
RegisterWorker(ctx context.Context, registration workercontracts.WorkerRegistration, ttl time.Duration) error
TouchWorkerHeartbeat(ctx context.Context, heartbeat workercontracts.WorkerHeartbeat, ttl time.Duration) error
ListWorkers(ctx context.Context) ([]workercontracts.WorkerRegistration, error)
GetWorker(ctx context.Context, workerID string) (*workercontracts.WorkerRegistration, error)
AcquireLease(ctx context.Context, lease workercontracts.WorkerLease, ttl time.Duration) error
GetLease(ctx context.Context, leaseID string) (*workercontracts.WorkerLease, error)
GetLeaseBySession(ctx context.Context, sessionID string) (*workercontracts.WorkerLease, error)
RenewLease(ctx context.Context, lease workercontracts.WorkerLease, ttl time.Duration) error
ReleaseLease(ctx context.Context, leaseID string) error
ListLeases(ctx context.Context) ([]workercontracts.WorkerLease, error)
AppendAssignment(ctx context.Context, workerID string, payload map[string]any) error
AppendEnvelope(ctx context.Context, envelope workercontracts.RoutedEnvelope) error
AppendEvent(ctx context.Context, payload map[string]any) error
}
@@ -0,0 +1,329 @@
package authority
import (
"context"
"crypto/ed25519"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/example/remote-access-platform/backend/internal/platform/config"
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
)
const (
ModeStrict = "strict"
ModeLegacy = "legacy"
ActivationSchemaVersion = "rap.installation.activation.v1"
PlatformRoleUser = "user"
PlatformRoleAdmin = "platform_admin"
PlatformRoleRecoveryAdmin = "platform_recovery_admin"
)
var (
ErrInvalidAuthorityMode = errors.New("invalid installation authority mode")
ErrProductRootKeyNeeded = errors.New("product root public key is required")
ErrInvalidActivation = errors.New("invalid installation activation")
ErrInvalidGrant = errors.New("invalid platform role grant")
)
type ActivationPayload struct {
SchemaVersion string `json:"schema_version"`
InstallID string `json:"install_id"`
OwnerEmail string `json:"owner_email"`
PlatformRole string `json:"platform_role"`
IssuedAt time.Time `json:"issued_at"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
Nonce string `json:"nonce,omitempty"`
Environment string `json:"environment,omitempty"`
}
type Verifier struct {
mode string
rootPublicKey ed25519.PublicKey
rootFingerprint string
allowInsecureBootstrap bool
now func() time.Time
}
func NewVerifier(cfg config.InstallationConfig) (*Verifier, error) {
mode := strings.ToLower(strings.TrimSpace(cfg.AuthorityMode))
if mode == "" {
mode = ModeLegacy
}
verifier := &Verifier{
mode: mode,
allowInsecureBootstrap: cfg.AllowInsecureBootstrap,
now: time.Now,
}
switch mode {
case ModeLegacy:
return verifier, nil
case ModeStrict:
publicKey, err := decodeEd25519PublicKey(cfg.ProductRootPublicKeyBase64)
if err != nil {
return nil, err
}
verifier.rootPublicKey = publicKey
fingerprint := sha256.Sum256(publicKey)
verifier.rootFingerprint = hex.EncodeToString(fingerprint[:])
return verifier, nil
default:
return nil, fmt.Errorf("%w: %s", ErrInvalidAuthorityMode, mode)
}
}
func (v *Verifier) Mode() string {
if v == nil || v.mode == "" {
return ModeLegacy
}
return v.mode
}
func (v *Verifier) Strict() bool {
return v != nil && v.mode == ModeStrict
}
func (v *Verifier) AllowInsecureBootstrap() bool {
return v != nil && v.allowInsecureBootstrap
}
func (v *Verifier) RootFingerprint() string {
if v == nil {
return ""
}
return v.rootFingerprint
}
func (v *Verifier) VerifyActivation(payload json.RawMessage, signature string) (ActivationPayload, error) {
if v == nil || !v.Strict() {
return ActivationPayload{}, ErrProductRootKeyNeeded
}
activation, canonical, err := parseActivationPayload(payload)
if err != nil {
return ActivationPayload{}, err
}
if err := activation.validate(v.now().UTC()); err != nil {
return ActivationPayload{}, err
}
if err := v.verifySignature(canonical, signature); err != nil {
return ActivationPayload{}, fmt.Errorf("%w: %v", ErrInvalidActivation, err)
}
return activation, nil
}
func (v *Verifier) VerifyPlatformRoleGrant(payload json.RawMessage, signature, expectedInstallID, expectedEmail, expectedRole string) (ActivationPayload, error) {
activation, err := v.VerifyActivation(payload, signature)
if err != nil {
return ActivationPayload{}, fmt.Errorf("%w: %v", ErrInvalidGrant, err)
}
if activation.InstallID != strings.TrimSpace(expectedInstallID) {
return ActivationPayload{}, fmt.Errorf("%w: install_id mismatch", ErrInvalidGrant)
}
if !strings.EqualFold(activation.OwnerEmail, strings.TrimSpace(expectedEmail)) {
return ActivationPayload{}, fmt.Errorf("%w: owner_email mismatch", ErrInvalidGrant)
}
if activation.PlatformRole != strings.TrimSpace(expectedRole) {
return ActivationPayload{}, fmt.Errorf("%w: platform_role mismatch", ErrInvalidGrant)
}
return activation, nil
}
func CanonicalJSON(raw json.RawMessage) ([]byte, error) {
if len(raw) == 0 {
return nil, fmt.Errorf("%w: empty payload", ErrInvalidActivation)
}
var value any
if err := json.Unmarshal(raw, &value); err != nil {
return nil, fmt.Errorf("%w: invalid json: %v", ErrInvalidActivation, err)
}
canonical, err := json.Marshal(value)
if err != nil {
return nil, fmt.Errorf("%w: canonical json: %v", ErrInvalidActivation, err)
}
return canonical, nil
}
func EffectivePlatformRole(ctx context.Context, db postgresplatform.DBTX, verifier *Verifier, userID string) (string, error) {
userID = strings.TrimSpace(userID)
if userID == "" {
return PlatformRoleUser, nil
}
if verifier == nil || !verifier.Strict() {
return legacyPlatformRole(ctx, db, userID)
}
var email string
if err := db.QueryRow(ctx, `SELECT email FROM users WHERE id = $1::uuid`, userID).Scan(&email); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return PlatformRoleUser, nil
}
return "", fmt.Errorf("get user email for platform grant: %w", err)
}
rows, err := db.Query(ctx, `
SELECT prg.role, prg.install_id, prg.grant_payload, prg.grant_signature
FROM platform_role_grants prg
JOIN installation_authority ia
ON ia.id = 1
AND ia.install_id = prg.install_id
AND ia.authority_state = 'active'
WHERE prg.user_id = $1::uuid
AND prg.revoked_at IS NULL
AND (prg.expires_at IS NULL OR prg.expires_at > NOW())
ORDER BY CASE prg.role
WHEN 'platform_recovery_admin' THEN 0
WHEN 'platform_admin' THEN 1
ELSE 2
END, prg.granted_at DESC
`, userID)
if err != nil {
return "", fmt.Errorf("query platform role grants: %w", err)
}
defer rows.Close()
bestRole := PlatformRoleUser
for rows.Next() {
var role, installID, signature string
var payload []byte
if err := rows.Scan(&role, &installID, &payload, &signature); err != nil {
return "", fmt.Errorf("scan platform role grant: %w", err)
}
if _, err := verifier.VerifyPlatformRoleGrant(json.RawMessage(payload), signature, installID, email, role); err != nil {
continue
}
if role == PlatformRoleRecoveryAdmin {
return role, nil
}
if role == PlatformRoleAdmin {
bestRole = role
}
}
if err := rows.Err(); err != nil {
return "", fmt.Errorf("iterate platform role grants: %w", err)
}
return bestRole, nil
}
func legacyPlatformRole(ctx context.Context, db postgresplatform.DBTX, userID string) (string, error) {
var role string
if err := db.QueryRow(ctx, `SELECT platform_role FROM users WHERE id = $1::uuid`, userID).Scan(&role); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return PlatformRoleUser, nil
}
return "", fmt.Errorf("get platform role: %w", err)
}
if role == "" {
return PlatformRoleUser, nil
}
return role, nil
}
func parseActivationPayload(raw json.RawMessage) (ActivationPayload, []byte, error) {
canonical, err := CanonicalJSON(raw)
if err != nil {
return ActivationPayload{}, nil, err
}
var activation ActivationPayload
if err := json.Unmarshal(canonical, &activation); err != nil {
return ActivationPayload{}, nil, fmt.Errorf("%w: decode activation: %v", ErrInvalidActivation, err)
}
activation.SchemaVersion = strings.TrimSpace(activation.SchemaVersion)
activation.InstallID = strings.TrimSpace(activation.InstallID)
activation.OwnerEmail = strings.ToLower(strings.TrimSpace(activation.OwnerEmail))
activation.PlatformRole = strings.TrimSpace(activation.PlatformRole)
activation.Nonce = strings.TrimSpace(activation.Nonce)
activation.Environment = strings.TrimSpace(activation.Environment)
return activation, canonical, nil
}
func (p ActivationPayload) validate(now time.Time) error {
if p.SchemaVersion != ActivationSchemaVersion {
return fmt.Errorf("%w: schema_version must be %s", ErrInvalidActivation, ActivationSchemaVersion)
}
if p.InstallID == "" {
return fmt.Errorf("%w: install_id is required", ErrInvalidActivation)
}
if p.OwnerEmail == "" || !strings.Contains(p.OwnerEmail, "@") {
return fmt.Errorf("%w: owner_email is required", ErrInvalidActivation)
}
switch p.PlatformRole {
case PlatformRoleAdmin, PlatformRoleRecoveryAdmin:
default:
return fmt.Errorf("%w: platform_role must be platform_admin or platform_recovery_admin", ErrInvalidActivation)
}
if p.IssuedAt.IsZero() {
return fmt.Errorf("%w: issued_at is required", ErrInvalidActivation)
}
if p.IssuedAt.After(now.Add(5 * time.Minute)) {
return fmt.Errorf("%w: issued_at is too far in the future", ErrInvalidActivation)
}
if p.ExpiresAt != nil && !p.ExpiresAt.After(now) {
return fmt.Errorf("%w: activation expired", ErrInvalidActivation)
}
return nil
}
func (v *Verifier) verifySignature(payload []byte, signatureText string) error {
signature, err := decodeBase64(strings.TrimSpace(signatureText))
if err != nil {
return fmt.Errorf("signature must be base64 encoded: %w", err)
}
if len(signature) != ed25519.SignatureSize {
return fmt.Errorf("signature must decode to %d bytes", ed25519.SignatureSize)
}
if !ed25519.Verify(v.rootPublicKey, payload, signature) {
return errors.New("signature verification failed")
}
return nil
}
func decodeEd25519PublicKey(value string) (ed25519.PublicKey, error) {
value = strings.TrimSpace(value)
if value == "" {
return nil, ErrProductRootKeyNeeded
}
if block, _ := pem.Decode([]byte(value)); block != nil {
parsed, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("parse product root public key PEM: %w", err)
}
publicKey, ok := parsed.(ed25519.PublicKey)
if !ok {
return nil, fmt.Errorf("product root public key PEM must contain an Ed25519 public key")
}
return publicKey, nil
}
decoded, err := decodeBase64(value)
if err != nil {
return nil, fmt.Errorf("product root public key must be base64 encoded: %w", err)
}
if len(decoded) != ed25519.PublicKeySize {
return nil, fmt.Errorf("product root public key must decode to %d bytes", ed25519.PublicKeySize)
}
return ed25519.PublicKey(decoded), nil
}
func decodeBase64(value string) ([]byte, error) {
decoded, err := base64.StdEncoding.DecodeString(value)
if err == nil {
return decoded, nil
}
decoded, rawErr := base64.RawStdEncoding.DecodeString(value)
if rawErr == nil {
return decoded, nil
}
return nil, err
}
@@ -0,0 +1,85 @@
package authority
import (
"crypto/ed25519"
"encoding/base64"
"encoding/json"
"strings"
"testing"
"time"
"github.com/example/remote-access-platform/backend/internal/platform/config"
)
func TestVerifierAcceptsSignedActivation(t *testing.T) {
publicKey, privateKey, err := ed25519.GenerateKey(nil)
if err != nil {
t.Fatalf("generate key: %v", err)
}
verifier, err := NewVerifier(config.InstallationConfig{
AuthorityMode: ModeStrict,
ProductRootPublicKeyBase64: base64.StdEncoding.EncodeToString(publicKey),
})
if err != nil {
t.Fatalf("NewVerifier: %v", err)
}
verifier.now = func() time.Time { return time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC) }
payload := json.RawMessage(`{
"platform_role":"platform_admin",
"owner_email":"Owner@Example.test",
"install_id":"install-1",
"schema_version":"rap.installation.activation.v1",
"issued_at":"2026-04-28T11:00:00Z",
"expires_at":"2026-04-29T11:00:00Z"
}`)
canonical, err := CanonicalJSON(payload)
if err != nil {
t.Fatalf("CanonicalJSON: %v", err)
}
signature := base64.StdEncoding.EncodeToString(ed25519.Sign(privateKey, canonical))
activation, err := verifier.VerifyActivation(payload, signature)
if err != nil {
t.Fatalf("VerifyActivation: %v", err)
}
if activation.OwnerEmail != "owner@example.test" || activation.PlatformRole != PlatformRoleAdmin {
t.Fatalf("unexpected activation: %+v", activation)
}
if verifier.RootFingerprint() == "" {
t.Fatal("expected root fingerprint")
}
}
func TestVerifierRejectsTamperedActivation(t *testing.T) {
publicKey, privateKey, err := ed25519.GenerateKey(nil)
if err != nil {
t.Fatalf("generate key: %v", err)
}
verifier, err := NewVerifier(config.InstallationConfig{
AuthorityMode: ModeStrict,
ProductRootPublicKeyBase64: base64.StdEncoding.EncodeToString(publicKey),
})
if err != nil {
t.Fatalf("NewVerifier: %v", err)
}
verifier.now = func() time.Time { return time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC) }
payload := json.RawMessage(`{
"schema_version":"rap.installation.activation.v1",
"install_id":"install-1",
"owner_email":"owner@example.test",
"platform_role":"platform_admin",
"issued_at":"2026-04-28T11:00:00Z"
}`)
canonical, err := CanonicalJSON(payload)
if err != nil {
t.Fatalf("CanonicalJSON: %v", err)
}
signature := base64.StdEncoding.EncodeToString(ed25519.Sign(privateKey, canonical))
tampered := json.RawMessage(strings.Replace(string(payload), "platform_admin", "platform_recovery_admin", 1))
if _, err := verifier.VerifyActivation(tampered, signature); err == nil {
t.Fatal("expected tampered activation to fail")
}
}
@@ -0,0 +1,186 @@
package clusterauth
import (
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
)
const (
AuthoritySchemaVersion = "rap.cluster_authority.v1"
SignatureSchemaVersion = "rap.cluster_authority.signature.v1"
AlgorithmEd25519 = "ed25519"
)
var (
ErrInvalidKey = errors.New("invalid cluster authority key")
ErrInvalidSignature = errors.New("invalid cluster authority signature")
ErrInvalidPayload = errors.New("invalid cluster authority payload")
)
type KeyPair struct {
PublicKeyB64 string
PrivateKeyB64 string
Fingerprint string
}
type Signature struct {
SchemaVersion string `json:"schema_version"`
Algorithm string `json:"algorithm"`
KeyFingerprint string `json:"key_fingerprint"`
Signature string `json:"signature"`
SignedAt time.Time `json:"signed_at"`
}
func GenerateKeyPair() (KeyPair, error) {
publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return KeyPair{}, err
}
fingerprint := Fingerprint(publicKey)
return KeyPair{
PublicKeyB64: base64.StdEncoding.EncodeToString(publicKey),
PrivateKeyB64: base64.StdEncoding.EncodeToString(privateKey),
Fingerprint: fingerprint,
}, nil
}
func Fingerprint(publicKey ed25519.PublicKey) string {
sum := sha256.Sum256(publicKey)
return "rap-ca-ed25519-" + hex.EncodeToString(sum[:16])
}
func FingerprintFromBase64(publicKeyB64 string) (string, error) {
publicKey, err := DecodePublicKey(publicKeyB64)
if err != nil {
return "", err
}
return Fingerprint(publicKey), nil
}
func SignRaw(privateKeyB64 string, payload json.RawMessage, signedAt time.Time) (Signature, error) {
privateKey, err := DecodePrivateKey(privateKeyB64)
if err != nil {
return Signature{}, err
}
canonical, err := CanonicalJSON(payload)
if err != nil {
return Signature{}, err
}
publicKey, ok := privateKey.Public().(ed25519.PublicKey)
if !ok {
return Signature{}, ErrInvalidKey
}
signature := ed25519.Sign(privateKey, canonical)
return Signature{
SchemaVersion: SignatureSchemaVersion,
Algorithm: AlgorithmEd25519,
KeyFingerprint: Fingerprint(publicKey),
Signature: base64.StdEncoding.EncodeToString(signature),
SignedAt: signedAt.UTC(),
}, nil
}
func SignPayload(privateKeyB64 string, payload any, signedAt time.Time) (json.RawMessage, Signature, error) {
raw, err := json.Marshal(payload)
if err != nil {
return nil, Signature{}, fmt.Errorf("%w: marshal: %v", ErrInvalidPayload, err)
}
signature, err := SignRaw(privateKeyB64, raw, signedAt)
if err != nil {
return nil, Signature{}, err
}
return json.RawMessage(raw), signature, nil
}
func VerifyRaw(publicKeyB64 string, payload json.RawMessage, signature Signature) error {
if signature.SchemaVersion != SignatureSchemaVersion {
return fmt.Errorf("%w: schema_version must be %s", ErrInvalidSignature, SignatureSchemaVersion)
}
if signature.Algorithm != AlgorithmEd25519 {
return fmt.Errorf("%w: algorithm must be %s", ErrInvalidSignature, AlgorithmEd25519)
}
publicKey, err := DecodePublicKey(publicKeyB64)
if err != nil {
return err
}
if signature.KeyFingerprint != Fingerprint(publicKey) {
return fmt.Errorf("%w: key fingerprint mismatch", ErrInvalidSignature)
}
canonical, err := CanonicalJSON(payload)
if err != nil {
return err
}
decodedSignature, err := decodeBase64(strings.TrimSpace(signature.Signature))
if err != nil || len(decodedSignature) != ed25519.SignatureSize {
return fmt.Errorf("%w: signature must be base64 ed25519 signature", ErrInvalidSignature)
}
if !ed25519.Verify(publicKey, canonical, decodedSignature) {
return ErrInvalidSignature
}
return nil
}
func CanonicalJSON(raw json.RawMessage) ([]byte, error) {
if len(raw) == 0 {
return nil, fmt.Errorf("%w: empty payload", ErrInvalidPayload)
}
var value any
if err := json.Unmarshal(raw, &value); err != nil {
return nil, fmt.Errorf("%w: invalid json: %v", ErrInvalidPayload, err)
}
canonical, err := json.Marshal(value)
if err != nil {
return nil, fmt.Errorf("%w: canonical json: %v", ErrInvalidPayload, err)
}
return canonical, nil
}
func HashRaw(raw json.RawMessage) (string, error) {
canonical, err := CanonicalJSON(raw)
if err != nil {
return "", err
}
sum := sha256.Sum256(canonical)
return hex.EncodeToString(sum[:]), nil
}
func DecodePublicKey(value string) (ed25519.PublicKey, error) {
decoded, err := decodeBase64(strings.TrimSpace(value))
if err != nil {
return nil, fmt.Errorf("%w: public key must be base64 encoded", ErrInvalidKey)
}
if len(decoded) != ed25519.PublicKeySize {
return nil, fmt.Errorf("%w: public key must decode to %d bytes", ErrInvalidKey, ed25519.PublicKeySize)
}
return ed25519.PublicKey(decoded), nil
}
func DecodePrivateKey(value string) (ed25519.PrivateKey, error) {
decoded, err := decodeBase64(strings.TrimSpace(value))
if err != nil {
return nil, fmt.Errorf("%w: private key must be base64 encoded", ErrInvalidKey)
}
if len(decoded) != ed25519.PrivateKeySize {
return nil, fmt.Errorf("%w: private key must decode to %d bytes", ErrInvalidKey, ed25519.PrivateKeySize)
}
return ed25519.PrivateKey(decoded), nil
}
func decodeBase64(value string) ([]byte, error) {
if value == "" {
return nil, errors.New("empty base64 value")
}
decoded, err := base64.StdEncoding.DecodeString(value)
if err == nil {
return decoded, nil
}
return base64.RawStdEncoding.DecodeString(value)
}
@@ -0,0 +1,44 @@
package clusterauth
import (
"encoding/json"
"errors"
"testing"
"time"
)
func TestSignAndVerifyRawPayload(t *testing.T) {
keys, err := GenerateKeyPair()
if err != nil {
t.Fatalf("GenerateKeyPair: %v", err)
}
payload := json.RawMessage(`{"cluster_id":"cluster-1","schema_version":"test.v1","value":1}`)
signature, err := SignRaw(keys.PrivateKeyB64, payload, time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC))
if err != nil {
t.Fatalf("SignRaw: %v", err)
}
if signature.KeyFingerprint != keys.Fingerprint {
t.Fatalf("fingerprint = %q, want %q", signature.KeyFingerprint, keys.Fingerprint)
}
if err := VerifyRaw(keys.PublicKeyB64, payload, signature); err != nil {
t.Fatalf("VerifyRaw: %v", err)
}
}
func TestVerifyRawRejectsTamperedPayload(t *testing.T) {
keys, err := GenerateKeyPair()
if err != nil {
t.Fatalf("GenerateKeyPair: %v", err)
}
payload := json.RawMessage(`{"cluster_id":"cluster-1","schema_version":"test.v1","value":1}`)
signature, err := SignRaw(keys.PrivateKeyB64, payload, time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC))
if err != nil {
t.Fatalf("SignRaw: %v", err)
}
tampered := json.RawMessage(`{"cluster_id":"cluster-1","schema_version":"test.v1","value":2}`)
if err := VerifyRaw(keys.PublicKeyB64, tampered, signature); !errors.Is(err, ErrInvalidSignature) {
t.Fatalf("err = %v, want ErrInvalidSignature", err)
}
}
+307
View File
@@ -0,0 +1,307 @@
package config
import (
"encoding/base64"
"fmt"
"os"
"strconv"
"strings"
"time"
)
type Config struct {
App AppConfig
HTTP HTTPConfig
Postgres PostgresConfig
Redis RedisConfig
Auth AuthConfig
Installation InstallationConfig
DataPlane DataPlaneConfig
Secret SecretConfig
Session SessionConfig
Worker WorkerConfig
WebSocket WebSocketConfig
}
type AppConfig struct {
Name string
Env string
}
type HTTPConfig struct {
Host string
Port int
ReadTimeout time.Duration
WriteTimeout time.Duration
IdleTimeout time.Duration
ShutdownTimeout time.Duration
}
type PostgresConfig struct {
DSN string
MaxConns int32
MinConns int32
ConnectTimeout time.Duration
}
type RedisConfig struct {
Addr string
Password string
DB int
DialTimeout time.Duration
}
type AuthConfig struct {
AccessTokenTTL time.Duration
RefreshTokenTTL time.Duration
Issuer string
AccessTokenSecret string
RefreshHashSecret string
}
type InstallationConfig struct {
AuthorityMode string
ProductRootPublicKeyBase64 string
ProductRootPublicKeyFile string
AllowInsecureBootstrap bool
}
type DataPlaneConfig struct {
TokenTTL time.Duration
TokenPrivateKeyPEM string
TokenPrivateKeyFile string
BackendGatewayURL string
DirectWorkerWSSURLTemplate string
DirectWorkerJSONRuntime bool
DirectWorkerBinaryRender bool
DirectWorkerTLSTrustMode string
DirectWorkerTLSCARef string
}
type SecretConfig struct {
EncryptionKeyBase64 string
EncryptionKeyFile string
EncryptionKeyID string
}
type SessionConfig struct {
HeartbeatTTL time.Duration
DetachGracePeriod time.Duration
AttachTokenTTL time.Duration
LiveStateTTL time.Duration
RecoveryBatchSize int
}
type WorkerConfig struct {
LeaseTTL time.Duration
HeartbeatTTL time.Duration
StaleLeaseGracePeriod time.Duration
}
type WebSocketConfig struct {
WriteTimeout time.Duration
PingInterval time.Duration
PongWait time.Duration
}
func Load() (Config, error) {
cfg := Config{
App: AppConfig{
Name: getEnv("APP_NAME", "rap-api"),
Env: getEnv("APP_ENV", "development"),
},
HTTP: HTTPConfig{
Host: getEnv("HTTP_HOST", "0.0.0.0"),
Port: getInt("HTTP_PORT", 8080),
ReadTimeout: getDuration("HTTP_READ_TIMEOUT", 15*time.Second),
WriteTimeout: getDuration("HTTP_WRITE_TIMEOUT", 15*time.Second),
IdleTimeout: getDuration("HTTP_IDLE_TIMEOUT", 60*time.Second),
ShutdownTimeout: getDuration("HTTP_SHUTDOWN_TIMEOUT", 10*time.Second),
},
Postgres: PostgresConfig{
DSN: getEnv("POSTGRES_DSN", ""),
MaxConns: int32(getInt("POSTGRES_MAX_CONNS", 20)),
MinConns: int32(getInt("POSTGRES_MIN_CONNS", 2)),
ConnectTimeout: getDuration("POSTGRES_CONNECT_TIMEOUT", 5*time.Second),
},
Redis: RedisConfig{
Addr: getEnv("REDIS_ADDR", "localhost:6379"),
Password: getEnv("REDIS_PASSWORD", ""),
DB: getInt("REDIS_DB", 0),
DialTimeout: getDuration("REDIS_DIAL_TIMEOUT", 5*time.Second),
},
Auth: AuthConfig{
AccessTokenTTL: getDuration("AUTH_ACCESS_TOKEN_TTL", 15*time.Minute),
RefreshTokenTTL: getDuration("AUTH_REFRESH_TOKEN_TTL", 30*24*time.Hour),
Issuer: getEnv("AUTH_ISSUER", "rap-api"),
AccessTokenSecret: getEnv("AUTH_ACCESS_TOKEN_SECRET", ""),
RefreshHashSecret: getEnv("AUTH_REFRESH_HASH_SECRET", ""),
},
Installation: InstallationConfig{
AuthorityMode: getEnv("INSTALLATION_AUTHORITY_MODE", ""),
ProductRootPublicKeyBase64: getEnv("INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_B64", ""),
ProductRootPublicKeyFile: getEnv("INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_FILE", ""),
AllowInsecureBootstrap: getBool("INSTALLATION_INSECURE_BOOTSTRAP_ENABLED", false),
},
DataPlane: DataPlaneConfig{
TokenTTL: getDuration("DATA_PLANE_TOKEN_TTL", 1*time.Minute),
TokenPrivateKeyPEM: getEnv("DATA_PLANE_TOKEN_PRIVATE_KEY_PEM", ""),
TokenPrivateKeyFile: getEnv("DATA_PLANE_TOKEN_PRIVATE_KEY_FILE", ""),
BackendGatewayURL: getEnv("DATA_PLANE_BACKEND_GATEWAY_URL", "/api/v1/gateway/ws"),
DirectWorkerWSSURLTemplate: getEnv("DATA_PLANE_DIRECT_WORKER_WSS_URL_TEMPLATE", ""),
DirectWorkerJSONRuntime: getBool("DATA_PLANE_DIRECT_WORKER_JSON_RUNTIME", false),
DirectWorkerBinaryRender: getBool("DATA_PLANE_DIRECT_WORKER_BINARY_RENDER", false),
DirectWorkerTLSTrustMode: getEnv("DATA_PLANE_DIRECT_WORKER_TLS_TRUST_MODE", "smoke_insecure"),
DirectWorkerTLSCARef: getEnv("DATA_PLANE_DIRECT_WORKER_TLS_CA_REF", ""),
},
Secret: SecretConfig{
EncryptionKeyBase64: getEnv("SECRET_ENCRYPTION_KEY_B64", ""),
EncryptionKeyFile: getEnv("SECRET_ENCRYPTION_KEY_FILE", ""),
EncryptionKeyID: getEnv("SECRET_ENCRYPTION_KEY_ID", "local-v1"),
},
Session: SessionConfig{
HeartbeatTTL: getDuration("SESSION_HEARTBEAT_TTL", 90*time.Second),
DetachGracePeriod: getDuration("SESSION_DETACH_GRACE_PERIOD", 30*time.Minute),
AttachTokenTTL: getDuration("SESSION_ATTACH_TOKEN_TTL", 2*time.Minute),
LiveStateTTL: getDuration("SESSION_LIVE_STATE_TTL", 2*time.Minute),
RecoveryBatchSize: getInt("SESSION_RECOVERY_BATCH_SIZE", 100),
},
Worker: WorkerConfig{
LeaseTTL: getDuration("WORKER_LEASE_TTL", 45*time.Second),
HeartbeatTTL: getDuration("WORKER_HEARTBEAT_TTL", 15*time.Second),
StaleLeaseGracePeriod: getDuration("WORKER_STALE_LEASE_GRACE_PERIOD", 30*time.Second),
},
WebSocket: WebSocketConfig{
WriteTimeout: getDuration("WEBSOCKET_WRITE_TIMEOUT", 10*time.Second),
PingInterval: getDuration("WEBSOCKET_PING_INTERVAL", 20*time.Second),
PongWait: getDuration("WEBSOCKET_PONG_WAIT", 40*time.Second),
},
}
if cfg.Postgres.DSN == "" {
return Config{}, fmt.Errorf("POSTGRES_DSN is required")
}
if cfg.Auth.AccessTokenSecret == "" {
return Config{}, fmt.Errorf("AUTH_ACCESS_TOKEN_SECRET is required")
}
if cfg.Auth.RefreshHashSecret == "" {
return Config{}, fmt.Errorf("AUTH_REFRESH_HASH_SECRET is required")
}
if cfg.Installation.ProductRootPublicKeyBase64 == "" && cfg.Installation.ProductRootPublicKeyFile != "" {
publicKey, err := os.ReadFile(cfg.Installation.ProductRootPublicKeyFile)
if err != nil {
return Config{}, fmt.Errorf("read INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_FILE: %w", err)
}
cfg.Installation.ProductRootPublicKeyBase64 = strings.TrimSpace(string(publicKey))
}
cfg.Installation.AuthorityMode = normalizeInstallationAuthorityMode(cfg.Installation.AuthorityMode, cfg.Installation.ProductRootPublicKeyBase64)
if isProductionEnv(cfg.App.Env) && cfg.Installation.AuthorityMode != "strict" {
return Config{}, fmt.Errorf("INSTALLATION_AUTHORITY_MODE=strict with INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_B64 or file is required in production")
}
if cfg.DataPlane.TokenPrivateKeyPEM == "" && cfg.DataPlane.TokenPrivateKeyFile != "" {
privateKey, err := os.ReadFile(cfg.DataPlane.TokenPrivateKeyFile)
if err != nil {
return Config{}, fmt.Errorf("read DATA_PLANE_TOKEN_PRIVATE_KEY_FILE: %w", err)
}
cfg.DataPlane.TokenPrivateKeyPEM = string(privateKey)
}
if cfg.Secret.EncryptionKeyBase64 == "" && cfg.Secret.EncryptionKeyFile != "" {
secretKey, err := os.ReadFile(cfg.Secret.EncryptionKeyFile)
if err != nil {
return Config{}, fmt.Errorf("read SECRET_ENCRYPTION_KEY_FILE: %w", err)
}
cfg.Secret.EncryptionKeyBase64 = strings.TrimSpace(string(secretKey))
}
if cfg.Secret.EncryptionKeyBase64 != "" {
decoded, err := base64.StdEncoding.DecodeString(cfg.Secret.EncryptionKeyBase64)
if err != nil {
if decodedRaw, rawErr := base64.RawStdEncoding.DecodeString(cfg.Secret.EncryptionKeyBase64); rawErr == nil {
decoded = decodedRaw
} else {
return Config{}, fmt.Errorf("SECRET_ENCRYPTION_KEY_B64 must be base64 encoded: %w", err)
}
}
if len(decoded) != 32 {
return Config{}, fmt.Errorf("SECRET_ENCRYPTION_KEY_B64 must decode to 32 bytes for AES-256-GCM")
}
}
if isProductionEnv(cfg.App.Env) && cfg.Secret.EncryptionKeyBase64 == "" {
return Config{}, fmt.Errorf("SECRET_ENCRYPTION_KEY_B64 or SECRET_ENCRYPTION_KEY_FILE is required in production")
}
return cfg, nil
}
func normalizeInstallationAuthorityMode(mode string, rootPublicKey string) string {
mode = strings.ToLower(strings.TrimSpace(mode))
switch mode {
case "strict", "legacy":
return mode
case "":
if strings.TrimSpace(rootPublicKey) != "" {
return "strict"
}
return "legacy"
default:
return mode
}
}
func isProductionEnv(appEnv string) bool {
switch strings.ToLower(strings.TrimSpace(appEnv)) {
case "production", "prod":
return true
default:
return false
}
}
func getEnv(key, fallback string) string {
if value := os.Getenv(key); value != "" {
return value
}
return fallback
}
func getInt(key string, fallback int) int {
value := os.Getenv(key)
if value == "" {
return fallback
}
parsed, err := strconv.Atoi(value)
if err != nil {
return fallback
}
return parsed
}
func getBool(key string, fallback bool) bool {
value := os.Getenv(key)
if value == "" {
return fallback
}
switch value {
case "1", "true", "TRUE", "yes", "on":
return true
case "0", "false", "FALSE", "no", "off":
return false
default:
return fallback
}
}
func getDuration(key string, fallback time.Duration) time.Duration {
value := os.Getenv(key)
if value == "" {
return fallback
}
parsed, err := time.ParseDuration(value)
if err != nil {
return fallback
}
return parsed
}
@@ -0,0 +1,20 @@
package httpserver
import (
"fmt"
"net/http"
"time"
"github.com/example/remote-access-platform/backend/internal/platform/config"
)
func New(cfg config.HTTPConfig, handler http.Handler) *http.Server {
return &http.Server{
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
Handler: handler,
ReadHeaderTimeout: 5 * time.Second,
ReadTimeout: cfg.ReadTimeout,
WriteTimeout: cfg.WriteTimeout,
IdleTimeout: cfg.IdleTimeout,
}
}
+45
View File
@@ -0,0 +1,45 @@
package httpx
import (
"encoding/json"
"net/http"
)
func WriteJSON(w http.ResponseWriter, status int, payload any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(payload)
}
func WriteError(w http.ResponseWriter, status int, message string) {
traceID := ensureTraceID(w)
WriteJSON(w, status, ErrorResponse{
Error: NewErrorMessage(status, message, nil, traceID),
})
}
func WriteErrorMessage(w http.ResponseWriter, status int, message any) {
traceID := ensureTraceID(w)
switch payload := message.(type) {
case string:
WriteJSON(w, status, ErrorResponse{
Error: NewErrorMessage(status, payload, nil, traceID),
})
case ErrorResponse:
payload.Error.TraceID = traceID
WriteJSON(w, status, payload)
case *ErrorResponse:
if payload == nil {
WriteJSON(w, status, ErrorResponse{
Error: NewErrorMessage(status, "", nil, traceID),
})
return
}
payload.Error.TraceID = traceID
WriteJSON(w, status, payload)
default:
WriteJSON(w, status, ErrorResponse{
Error: NewErrorMessage(status, "Request failed.", nil, traceID),
})
}
}
+131
View File
@@ -0,0 +1,131 @@
package httpx
import (
"net/http"
"strings"
"unicode"
"github.com/google/uuid"
messagecontracts "github.com/example/remote-access-platform/backend/pkg/contracts/message"
)
type ErrorResponse struct {
Error messagecontracts.Message `json:"error"`
}
func NewMessage(code, messageKey, fallbackMessage string, details map[string]any, traceID string) messagecontracts.Message {
if traceID == "" {
traceID = uuid.NewString()
}
if details == nil {
details = map[string]any{}
}
return messagecontracts.Message{
Code: code,
MessageKey: messageKey,
FallbackMessage: fallbackMessage,
Details: details,
TraceID: traceID,
}
}
func NewErrorMessage(status int, fallbackMessage string, details map[string]any, traceID string) messagecontracts.Message {
normalizedFallback, normalizedDetails := normalizeErrorFallback(status, fallbackMessage, details)
code := deriveErrorCode(status, normalizedFallback)
return NewMessage(code, "errors."+code, normalizedFallback, normalizedDetails, traceID)
}
func ensureTraceID(w http.ResponseWriter) string {
traceID := w.Header().Get("X-Trace-Id")
if traceID == "" {
traceID = uuid.NewString()
w.Header().Set("X-Trace-Id", traceID)
}
return traceID
}
func normalizeErrorFallback(status int, fallbackMessage string, details map[string]any) (string, map[string]any) {
if details == nil {
details = map[string]any{}
}
details["http_status"] = status
if status >= http.StatusInternalServerError {
return "An internal server error occurred.", details
}
trimmed := strings.TrimSpace(fallbackMessage)
switch strings.ToLower(trimmed) {
case "forbidden", "access denied":
return "Access denied.", details
}
if field, ok := extractRequiredField(trimmed); ok {
details["field"] = field
}
return trimmed, details
}
func deriveErrorCode(status int, fallbackMessage string) string {
switch strings.ToLower(strings.TrimSpace(fallbackMessage)) {
case "invalid credentials":
return "auth.invalid_credentials"
case "session expired. please sign in again.":
return "auth.session_expired"
case "access denied.":
return "common.access_denied"
}
statusPrefix := map[int]string{
http.StatusBadRequest: "bad_request",
http.StatusUnauthorized: "unauthorized",
http.StatusForbidden: "forbidden",
http.StatusNotFound: "not_found",
http.StatusConflict: "conflict",
http.StatusUnprocessableEntity: "unprocessable_entity",
http.StatusInternalServerError: "internal_server_error",
}[status]
if statusPrefix == "" {
statusPrefix = "http_" + strings.ReplaceAll(http.StatusText(status), " ", "_")
statusPrefix = strings.ToLower(statusPrefix)
}
slug := slugifyMessage(fallbackMessage)
if slug == "" {
slug = "message"
}
if status >= http.StatusInternalServerError {
return "common." + statusPrefix
}
return statusPrefix + "." + slug
}
func slugifyMessage(input string) string {
var builder strings.Builder
lastUnderscore := false
for _, r := range strings.ToLower(strings.TrimSpace(input)) {
if unicode.IsLetter(r) || unicode.IsDigit(r) {
builder.WriteRune(r)
lastUnderscore = false
continue
}
if !lastUnderscore {
builder.WriteRune('_')
lastUnderscore = true
}
}
return strings.Trim(builder.String(), "_")
}
func extractRequiredField(message string) (string, bool) {
const suffix = " is required"
if !strings.HasSuffix(strings.ToLower(message), suffix) {
return "", false
}
field := strings.TrimSpace(message[:len(message)-len(suffix)])
field = strings.ReplaceAll(field, " ", "_")
field = strings.ToLower(field)
return field, field != ""
}
@@ -0,0 +1,17 @@
package logging
import (
"log/slog"
"os"
)
func New(env string) *slog.Logger {
level := slog.LevelInfo
if env == "development" {
level = slog.LevelDebug
}
return slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
Level: level,
}))
}
@@ -0,0 +1,38 @@
package module
import (
"log/slog"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/redis/go-redis/v9"
"github.com/example/remote-access-platform/backend/internal/platform/config"
)
type Dependencies struct {
Config Config
Infra Infra
}
type Config struct {
App config.AppConfig
Auth config.AuthConfig
Installation config.InstallationConfig
DataPlane config.DataPlaneConfig
Secret config.SecretConfig
Session config.SessionConfig
Worker config.WorkerConfig
WebSocket config.WebSocketConfig
}
type Infra struct {
Logger *slog.Logger
DB *pgxpool.Pool
Redis *redis.Client
}
type Module interface {
Name() string
RegisterRoutes(router chi.Router)
}
@@ -0,0 +1,33 @@
package postgres
import (
"context"
"fmt"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/example/remote-access-platform/backend/internal/platform/config"
)
func Open(ctx context.Context, cfg config.PostgresConfig) (*pgxpool.Pool, error) {
poolConfig, err := pgxpool.ParseConfig(cfg.DSN)
if err != nil {
return nil, fmt.Errorf("parse postgres dsn: %w", err)
}
poolConfig.MaxConns = cfg.MaxConns
poolConfig.MinConns = cfg.MinConns
poolConfig.ConnConfig.ConnectTimeout = cfg.ConnectTimeout
pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
if err != nil {
return nil, fmt.Errorf("create postgres pool: %w", err)
}
if err := pool.Ping(ctx); err != nil {
pool.Close()
return nil, fmt.Errorf("ping postgres: %w", err)
}
return pool, nil
}
+34
View File
@@ -0,0 +1,34 @@
package postgres
import (
"context"
"fmt"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
)
type DBTX interface {
Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
}
func WithTransaction(ctx context.Context, pool *pgxpool.Pool, fn func(tx pgx.Tx) error) error {
tx, err := pool.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
return fmt.Errorf("begin transaction: %w", err)
}
if err := fn(tx); err != nil {
_ = tx.Rollback(ctx)
return err
}
if err := tx.Commit(ctx); err != nil {
return fmt.Errorf("commit transaction: %w", err)
}
return nil
}
+26
View File
@@ -0,0 +1,26 @@
package redis
import (
"context"
"fmt"
goredis "github.com/redis/go-redis/v9"
"github.com/example/remote-access-platform/backend/internal/platform/config"
)
func Open(ctx context.Context, cfg config.RedisConfig) (*goredis.Client, error) {
client := goredis.NewClient(&goredis.Options{
Addr: cfg.Addr,
Password: cfg.Password,
DB: cfg.DB,
DialTimeout: cfg.DialTimeout,
})
if err := client.Ping(ctx).Err(); err != nil {
_ = client.Close()
return nil, fmt.Errorf("ping redis: %w", err)
}
return client, nil
}
+220
View File
@@ -0,0 +1,220 @@
package runtime
import (
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"time"
"github.com/go-chi/chi/v5"
chimiddleware "github.com/go-chi/chi/v5/middleware"
"github.com/example/remote-access-platform/backend/internal/modules/auth"
"github.com/example/remote-access-platform/backend/internal/modules/cluster"
"github.com/example/remote-access-platform/backend/internal/modules/identitysource"
"github.com/example/remote-access-platform/backend/internal/modules/node"
"github.com/example/remote-access-platform/backend/internal/modules/nodeagent"
"github.com/example/remote-access-platform/backend/internal/modules/organization"
"github.com/example/remote-access-platform/backend/internal/modules/resource"
"github.com/example/remote-access-platform/backend/internal/modules/sessionbroker"
"github.com/example/remote-access-platform/backend/internal/modules/sessiongateway"
"github.com/example/remote-access-platform/backend/internal/modules/worker"
"github.com/example/remote-access-platform/backend/internal/platform/authority"
"github.com/example/remote-access-platform/backend/internal/platform/config"
"github.com/example/remote-access-platform/backend/internal/platform/httpserver"
"github.com/example/remote-access-platform/backend/internal/platform/logging"
"github.com/example/remote-access-platform/backend/internal/platform/module"
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
redisplatform "github.com/example/remote-access-platform/backend/internal/platform/redis"
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
)
type App struct {
cfg config.Config
logger *slog.Logger
httpServer *http.Server
workers []backgroundRunner
db closeFunc
redis closeFunc
}
type closeFunc func() error
type backgroundRunner func(context.Context) error
func NewApp(ctx context.Context) (*App, error) {
cfg, err := config.Load()
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
logger := logging.New(cfg.App.Env)
db, err := postgresplatform.Open(ctx, cfg.Postgres)
if err != nil {
return nil, err
}
redisClient, err := redisplatform.Open(ctx, cfg.Redis)
if err != nil {
db.Close()
return nil, err
}
authorityVerifier, err := authority.NewVerifier(cfg.Installation)
if err != nil {
redisClient.Close()
db.Close()
return nil, fmt.Errorf("create installation authority verifier: %w", err)
}
deps := module.Dependencies{
Config: module.Config{
App: cfg.App,
Auth: cfg.Auth,
Installation: cfg.Installation,
DataPlane: cfg.DataPlane,
Secret: cfg.Secret,
Session: cfg.Session,
Worker: cfg.Worker,
WebSocket: cfg.WebSocket,
},
Infra: module.Infra{
Logger: logger,
DB: db,
Redis: redisClient,
},
}
workerStore := worker.NewRedisStore(redisClient)
workerService := worker.NewService(deps, workerStore)
authStore := auth.NewPostgresStore(db)
authTx := auth.NewPostgresTransactor(db)
authService := auth.NewService(deps, authStore, authTx, authorityVerifier)
var resourceSecretStore *secrets.ResourceSecretStore
if cfg.Secret.EncryptionKeyBase64 != "" {
secretEncryptor, err := secrets.NewEncryptor(cfg.Secret.EncryptionKeyBase64, cfg.Secret.EncryptionKeyID)
if err != nil {
redisClient.Close()
db.Close()
return nil, fmt.Errorf("create resource secret encryptor: %w", err)
}
resourceSecretStore = secrets.NewResourceSecretStore(db, secretEncryptor)
}
brokerStore := sessionbroker.NewPostgresStore(db, authorityVerifier)
brokerTx := sessionbroker.NewPostgresTransactor(db, authorityVerifier)
liveStateStore := sessionbroker.NewRedisLiveStateStore(redisClient)
brokerService := sessionbroker.NewService(deps, brokerStore, brokerTx, liveStateStore, workerService, resourceSecretStore)
workerEvents := worker.NewEventProcessor(redisClient, brokerService)
leaseMonitor := worker.NewLeaseMonitor(workerService, brokerService, cfg.Worker.StaleLeaseGracePeriod)
brokerModule := sessionbroker.NewModule(brokerService)
authModule := auth.NewModule(deps, authService)
clusterModule := cluster.NewModule(deps, authorityVerifier)
organizationModule := organization.NewModule(deps)
identitySourceModule := identitysource.NewModule(deps)
resourceModule := resource.NewModule(deps, resourceSecretStore)
nodeModule := node.NewModule(deps)
nodeAgentModule := nodeagent.NewModule(deps)
sessionGatewayModule := sessiongateway.NewModule(deps, brokerModule.Service(), workerService)
router := buildRouter(
logger,
authModule,
clusterModule,
organizationModule,
identitySourceModule,
resourceModule,
brokerModule,
nodeModule,
nodeAgentModule,
sessionGatewayModule,
)
return &App{
cfg: cfg,
logger: logger,
httpServer: httpserver.New(cfg.HTTP, router),
workers: []backgroundRunner{workerEvents.Run, leaseMonitor.Run},
db: func() error {
db.Close()
return nil
},
redis: redisClient.Close,
}, nil
}
func (a *App) Run(ctx context.Context) error {
errCh := make(chan error, 1)
go func() {
a.logger.Info("http server starting", "addr", a.httpServer.Addr, "service", a.cfg.App.Name)
if err := a.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
errCh <- err
return
}
errCh <- nil
}()
for _, runner := range a.workers {
runner := runner
go func() {
if err := runner(ctx); err != nil {
errCh <- err
}
}()
}
select {
case <-ctx.Done():
a.logger.Info("shutdown signal received")
case err := <-errCh:
if err != nil {
return err
}
return nil
}
shutdownCtx, cancel := context.WithTimeout(context.Background(), a.cfg.HTTP.ShutdownTimeout)
defer cancel()
if err := a.httpServer.Shutdown(shutdownCtx); err != nil {
return fmt.Errorf("shutdown http server: %w", err)
}
if err := a.redis(); err != nil {
return fmt.Errorf("close redis: %w", err)
}
if err := a.db(); err != nil {
return fmt.Errorf("close postgres: %w", err)
}
a.logger.Info("app stopped", "at", time.Now().UTC())
return nil
}
func buildRouter(logger *slog.Logger, modules ...module.Module) http.Handler {
router := chi.NewRouter()
router.Use(chimiddleware.RequestID)
router.Use(chimiddleware.RealIP)
router.Use(chimiddleware.Recoverer)
router.Use(chimiddleware.Timeout(60 * time.Second))
router.Use(chimiddleware.Heartbeat("/healthz"))
router.Get("/readyz", func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ready"))
})
router.Route("/api/v1", func(r chi.Router) {
for _, mod := range modules {
logger.Info("register module routes", "module", mod.Name())
mod.RegisterRoutes(r)
}
})
return router
}
@@ -0,0 +1,37 @@
package secrets
import (
"encoding/json"
"fmt"
)
type AssignmentSecretMergeResult struct {
Metadata map[string]any
Keys []string
}
func MergeResourceSecretIntoAssignmentMetadata(metadata map[string]any, payload json.RawMessage) (AssignmentSecretMergeResult, error) {
if metadata == nil {
metadata = map[string]any{}
}
var secretPayload map[string]any
if err := json.Unmarshal(payload, &secretPayload); err != nil {
return AssignmentSecretMergeResult{}, fmt.Errorf("decode resolved resource secret: %w", err)
}
resource, _ := metadata["resource"].(map[string]any)
if resource == nil {
resource = map[string]any{}
metadata["resource"] = resource
}
resourceMetadata, _ := resource["metadata"].(map[string]any)
if resourceMetadata == nil {
resourceMetadata = map[string]any{}
resource["metadata"] = resourceMetadata
}
keys := make([]string, 0, len(secretPayload))
for key, value := range secretPayload {
resourceMetadata[key] = value
keys = append(keys, key)
}
return AssignmentSecretMergeResult{Metadata: metadata, Keys: keys}, nil
}
@@ -0,0 +1,113 @@
package secrets
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"io"
"strings"
)
const AlgorithmAES256GCM = "AES-256-GCM"
var (
ErrSecretEncryptionKeyMissing = errors.New("secret encryption key is not configured")
ErrSecretPayloadInvalid = errors.New("secret payload must be a json object")
)
type Encryptor struct {
aead cipher.AEAD
keyID string
}
type EncryptedPayload struct {
Algorithm string
KeyID string
Nonce []byte
Ciphertext []byte
PayloadSHA256 string
}
func NewEncryptor(masterKeyBase64, keyID string) (*Encryptor, error) {
masterKeyBase64 = strings.TrimSpace(masterKeyBase64)
if masterKeyBase64 == "" {
return nil, ErrSecretEncryptionKeyMissing
}
key, err := base64.StdEncoding.DecodeString(masterKeyBase64)
if err != nil {
if rawKey, rawErr := base64.RawStdEncoding.DecodeString(masterKeyBase64); rawErr == nil {
key = rawKey
} else {
return nil, fmt.Errorf("decode secret encryption key: %w", err)
}
}
if len(key) != 32 {
return nil, fmt.Errorf("secret encryption key must decode to 32 bytes")
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("create secret cipher: %w", err)
}
aead, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("create secret gcm: %w", err)
}
if strings.TrimSpace(keyID) == "" {
keyID = "local-v1"
}
return &Encryptor{aead: aead, keyID: keyID}, nil
}
func (e *Encryptor) KeyID() string {
if e == nil {
return ""
}
return e.keyID
}
func (e *Encryptor) Encrypt(plaintext, aad []byte) (EncryptedPayload, error) {
if e == nil {
return EncryptedPayload{}, ErrSecretEncryptionKeyMissing
}
nonce := make([]byte, e.aead.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return EncryptedPayload{}, fmt.Errorf("generate secret nonce: %w", err)
}
hash := sha256.Sum256(plaintext)
return EncryptedPayload{
Algorithm: AlgorithmAES256GCM,
KeyID: e.keyID,
Nonce: nonce,
Ciphertext: e.aead.Seal(nil, nonce, plaintext, aad),
PayloadSHA256: hex.EncodeToString(hash[:]),
}, nil
}
func (e *Encryptor) Decrypt(payload EncryptedPayload, aad []byte) ([]byte, error) {
if e == nil {
return nil, ErrSecretEncryptionKeyMissing
}
if payload.Algorithm != "" && payload.Algorithm != AlgorithmAES256GCM {
return nil, fmt.Errorf("unsupported secret algorithm %q", payload.Algorithm)
}
plaintext, err := e.aead.Open(nil, payload.Nonce, payload.Ciphertext, aad)
if err != nil {
return nil, fmt.Errorf("decrypt secret payload: %w", err)
}
return plaintext, nil
}
func ResourceSecretAAD(organizationID, resourceID, secretRef, protocol string) []byte {
return []byte(strings.Join([]string{
"rap-resource-secret-v1",
strings.TrimSpace(organizationID),
strings.TrimSpace(resourceID),
strings.TrimSpace(secretRef),
strings.ToLower(strings.TrimSpace(protocol)),
}, "|"))
}
@@ -0,0 +1,65 @@
package secrets
import (
"encoding/base64"
"encoding/json"
"testing"
)
func TestEncryptorRoundTrip(t *testing.T) {
key := base64.StdEncoding.EncodeToString([]byte("0123456789abcdef0123456789abcdef"))
encryptor, err := NewEncryptor(key, "test-key")
if err != nil {
t.Fatalf("NewEncryptor returned error: %v", err)
}
aad := ResourceSecretAAD("org-1", "resource-1", "rap-secret://test", "rdp")
encrypted, err := encryptor.Encrypt([]byte(`{"username":"user","password":"secret"}`), aad)
if err != nil {
t.Fatalf("Encrypt returned error: %v", err)
}
plaintext, err := encryptor.Decrypt(encrypted, aad)
if err != nil {
t.Fatalf("Decrypt returned error: %v", err)
}
if string(plaintext) != `{"username":"user","password":"secret"}` {
t.Fatalf("unexpected plaintext: %s", plaintext)
}
}
func TestEncryptorRejectsWrongAAD(t *testing.T) {
key := base64.StdEncoding.EncodeToString([]byte("0123456789abcdef0123456789abcdef"))
encryptor, err := NewEncryptor(key, "test-key")
if err != nil {
t.Fatalf("NewEncryptor returned error: %v", err)
}
encrypted, err := encryptor.Encrypt([]byte(`{"password":"secret"}`), ResourceSecretAAD("org-1", "resource-1", "ref", "rdp"))
if err != nil {
t.Fatalf("Encrypt returned error: %v", err)
}
if _, err := encryptor.Decrypt(encrypted, ResourceSecretAAD("org-2", "resource-1", "ref", "rdp")); err == nil {
t.Fatalf("expected decrypt with wrong aad to fail")
}
}
func TestMergeResourceSecretIntoAssignmentMetadata(t *testing.T) {
metadata := map[string]any{
"resource": map[string]any{
"id": "resource-1",
"metadata": map[string]any{
"rdp_host": "host",
},
},
}
merged, err := MergeResourceSecretIntoAssignmentMetadata(metadata, json.RawMessage(`{"username":"user","password":"secret","domain":"corp"}`))
if err != nil {
t.Fatalf("MergeResourceSecretIntoAssignmentMetadata returned error: %v", err)
}
resource := merged.Metadata["resource"].(map[string]any)
resourceMetadata := resource["metadata"].(map[string]any)
if resourceMetadata["rdp_host"] != "host" {
t.Fatalf("existing metadata was not preserved")
}
if resourceMetadata["username"] != "user" || resourceMetadata["password"] != "secret" || resourceMetadata["domain"] != "corp" {
t.Fatalf("secret payload was not merged: %#v", resourceMetadata)
}
}
@@ -0,0 +1,132 @@
package secrets
import (
"encoding/json"
"errors"
"fmt"
"slices"
"sort"
"strings"
)
var (
ErrPlaintextResourceCredentials = errors.New("plaintext resource credentials are not allowed in metadata in production")
ErrMissingResourceSecretRef = errors.New("secret_ref is required for this resource protocol in production")
)
var credentialKeyFragments = []string{
"accesstoken",
"clientsecret",
"credential",
"credentials",
"domain",
"password",
"privatekey",
"refreshtoken",
"secret",
"secrets",
"token",
"user",
"username",
}
var safeReferenceKeys = []string{
"certificateverificationmode",
"renderqualityprofile",
"secretref",
"secretreference",
"vaultref",
}
func ValidateResourceSecretReadiness(protocol string, secretRef *string, metadata json.RawMessage, appEnv string) error {
if !IsProductionEnv(appEnv) {
return nil
}
paths, err := PlaintextCredentialMetadataPaths(metadata)
if err != nil {
return err
}
if len(paths) > 0 {
return fmt.Errorf("%w: %s", ErrPlaintextResourceCredentials, strings.Join(paths, ", "))
}
if ResourceProtocolRequiresSecretRef(protocol) && (secretRef == nil || strings.TrimSpace(*secretRef) == "") {
return ErrMissingResourceSecretRef
}
return nil
}
func IsProductionEnv(appEnv string) bool {
switch strings.ToLower(strings.TrimSpace(appEnv)) {
case "prod", "production":
return true
default:
return false
}
}
func ResourceProtocolRequiresSecretRef(protocol string) bool {
switch strings.ToLower(strings.TrimSpace(protocol)) {
case "rdp", "vnc", "ssh":
return true
default:
return false
}
}
func PlaintextCredentialMetadataPaths(raw json.RawMessage) ([]string, error) {
if len(raw) == 0 {
return nil, nil
}
var value any
if err := json.Unmarshal(raw, &value); err != nil {
return nil, errors.New("metadata must be valid json")
}
metadata, ok := value.(map[string]any)
if !ok {
return nil, errors.New("metadata must be a json object")
}
var paths []string
collectCredentialPaths(metadata, "", &paths)
sort.Strings(paths)
return slices.Compact(paths), nil
}
func collectCredentialPaths(value any, prefix string, paths *[]string) {
switch typed := value.(type) {
case map[string]any:
for key, child := range typed {
path := key
if prefix != "" {
path = prefix + "." + key
}
if isCredentialMetadataKey(key) {
*paths = append(*paths, path)
}
collectCredentialPaths(child, path, paths)
}
case []any:
for index, child := range typed {
collectCredentialPaths(child, fmt.Sprintf("%s[%d]", prefix, index), paths)
}
}
}
func isCredentialMetadataKey(key string) bool {
normalized := normalizeMetadataKey(key)
if slices.Contains(safeReferenceKeys, normalized) {
return false
}
for _, fragment := range credentialKeyFragments {
if normalized == fragment || strings.HasSuffix(normalized, fragment) {
return true
}
}
return false
}
func normalizeMetadataKey(key string) string {
key = strings.ToLower(strings.TrimSpace(key))
replacer := strings.NewReplacer("_", "", "-", "", " ", "", ".", "")
return replacer.Replace(key)
}
@@ -0,0 +1,52 @@
package secrets
import (
"encoding/json"
"errors"
"slices"
"testing"
)
func TestValidateResourceSecretReadinessAllowsPlaintextInDevelopment(t *testing.T) {
metadata := json.RawMessage(`{"username":"m","password":"secret"}`)
if err := ValidateResourceSecretReadiness("rdp", nil, metadata, "development"); err != nil {
t.Fatalf("development metadata should remain allowed for smoke/dev: %v", err)
}
}
func TestValidateResourceSecretReadinessRejectsPlaintextCredentialsInProduction(t *testing.T) {
metadata := json.RawMessage(`{"rdp_host":"host","credentials":{"username":"m","password":"secret"}}`)
err := ValidateResourceSecretReadiness("rdp", stringPtr("vault://org/resource"), metadata, "production")
if !errors.Is(err, ErrPlaintextResourceCredentials) {
t.Fatalf("expected plaintext credential rejection, got %v", err)
}
paths, err := PlaintextCredentialMetadataPaths(metadata)
if err != nil {
t.Fatalf("metadata paths: %v", err)
}
for _, expected := range []string{"credentials", "credentials.password", "credentials.username"} {
if !slices.Contains(paths, expected) {
t.Fatalf("expected sensitive path %q in %v", expected, paths)
}
}
}
func TestValidateResourceSecretReadinessRequiresSecretRefForProductionRDP(t *testing.T) {
metadata := json.RawMessage(`{"rdp_host":"host","rdp_port":3389}`)
err := ValidateResourceSecretReadiness("rdp", nil, metadata, "production")
if !errors.Is(err, ErrMissingResourceSecretRef) {
t.Fatalf("expected missing secret_ref rejection, got %v", err)
}
}
func TestValidateResourceSecretReadinessAllowsProductionSecretRef(t *testing.T) {
metadata := json.RawMessage(`{"rdp_host":"host","rdp_port":3389,"secret_ref":"vault://org/resource"}`)
if err := ValidateResourceSecretReadiness("rdp", stringPtr("vault://org/resource"), metadata, "production"); err != nil {
t.Fatalf("production secret_ref metadata should be accepted: %v", err)
}
}
func stringPtr(value string) *string {
return &value
}
@@ -0,0 +1,259 @@
package secrets
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/jackc/pgx/v5"
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
)
var (
ErrResourceSecretNotFound = errors.New("resource secret not found")
ErrSecretAccessDenied = errors.New("resource secret access denied")
ErrSecretLeaseRequired = errors.New("resource secret resolution requires lease proof")
)
type ResourceSecretStore struct {
db postgresplatform.DBTX
encryptor *Encryptor
now func() time.Time
}
type ResourceSecretResolver interface {
ResolveForSession(ctx context.Context, req ResolveResourceSecretRequest) (*ResolvedResourceSecret, error)
}
type ResourceSecretDescriptor struct {
ID string `json:"id"`
OrganizationID string `json:"organization_id"`
ResourceID string `json:"resource_id"`
SecretRef string `json:"secret_ref"`
Protocol string `json:"protocol"`
Version int `json:"version"`
KeyID string `json:"key_id"`
Algorithm string `json:"algorithm"`
Metadata json.RawMessage `json:"metadata"`
CreatedAt time.Time `json:"created_at"`
RotatedAt *time.Time `json:"rotated_at,omitempty"`
}
type UpsertResourceSecretCommand struct {
OrganizationID string
ResourceID string
Protocol string
SecretRef string
Payload json.RawMessage
Metadata json.RawMessage
ActorUserID string
}
type ResolveResourceSecretRequest struct {
SecretRef string
OrganizationID string
ResourceID string
SessionID string
WorkerID string
LeaseID string
}
type ResolvedResourceSecret struct {
Descriptor ResourceSecretDescriptor
Payload json.RawMessage
}
func NewResourceSecretStore(db postgresplatform.DBTX, encryptor *Encryptor) *ResourceSecretStore {
return &ResourceSecretStore{db: db, encryptor: encryptor, now: time.Now}
}
func (s *ResourceSecretStore) WithDB(db postgresplatform.DBTX) *ResourceSecretStore {
if s == nil {
return nil
}
return &ResourceSecretStore{db: db, encryptor: s.encryptor, now: s.now}
}
func DefaultResourceSecretRef(organizationID, resourceID string) string {
return "rap-secret://org/" + strings.TrimSpace(organizationID) + "/resources/" + strings.TrimSpace(resourceID) + "/primary"
}
func (s *ResourceSecretStore) Upsert(ctx context.Context, cmd UpsertResourceSecretCommand) (*ResourceSecretDescriptor, error) {
if s == nil || s.encryptor == nil {
return nil, ErrSecretEncryptionKeyMissing
}
payload, err := normalizeJSONObject(cmd.Payload)
if err != nil {
return nil, err
}
metadata, err := normalizeJSONObjectAllowEmpty(cmd.Metadata)
if err != nil {
return nil, err
}
secretRef := strings.TrimSpace(cmd.SecretRef)
if secretRef == "" {
secretRef = DefaultResourceSecretRef(cmd.OrganizationID, cmd.ResourceID)
}
protocol := strings.ToLower(strings.TrimSpace(cmd.Protocol))
encrypted, err := s.encryptor.Encrypt(payload, ResourceSecretAAD(cmd.OrganizationID, cmd.ResourceID, secretRef, protocol))
if err != nil {
return nil, err
}
now := s.now().UTC()
const query = `
INSERT INTO resource_secrets (
organization_id, resource_id, secret_ref, protocol, version, key_id,
algorithm, nonce, ciphertext, payload_sha256, metadata, created_by_user_id,
created_at, rotated_at
) VALUES (
$1::uuid, $2::uuid, $3, $4, 1, $5,
$6, $7, $8, $9, $10::jsonb, NULLIF($11, '')::uuid,
$12, NULL
)
ON CONFLICT (resource_id) DO UPDATE SET
secret_ref = EXCLUDED.secret_ref,
protocol = EXCLUDED.protocol,
version = resource_secrets.version + 1,
key_id = EXCLUDED.key_id,
algorithm = EXCLUDED.algorithm,
nonce = EXCLUDED.nonce,
ciphertext = EXCLUDED.ciphertext,
payload_sha256 = EXCLUDED.payload_sha256,
metadata = EXCLUDED.metadata,
created_by_user_id = EXCLUDED.created_by_user_id,
rotated_at = EXCLUDED.created_at
RETURNING id::text, organization_id::text, resource_id::text, secret_ref,
protocol, version, key_id, algorithm, metadata, created_at, rotated_at
`
var descriptor ResourceSecretDescriptor
if err := s.db.QueryRow(ctx, query,
cmd.OrganizationID,
cmd.ResourceID,
secretRef,
protocol,
encrypted.KeyID,
encrypted.Algorithm,
encrypted.Nonce,
encrypted.Ciphertext,
encrypted.PayloadSHA256,
metadata,
cmd.ActorUserID,
now,
).Scan(
&descriptor.ID,
&descriptor.OrganizationID,
&descriptor.ResourceID,
&descriptor.SecretRef,
&descriptor.Protocol,
&descriptor.Version,
&descriptor.KeyID,
&descriptor.Algorithm,
&descriptor.Metadata,
&descriptor.CreatedAt,
&descriptor.RotatedAt,
); err != nil {
return nil, fmt.Errorf("upsert resource secret: %w", err)
}
return &descriptor, nil
}
func (s *ResourceSecretStore) ResolveForSession(ctx context.Context, req ResolveResourceSecretRequest) (*ResolvedResourceSecret, error) {
if s == nil || s.encryptor == nil {
return nil, ErrSecretEncryptionKeyMissing
}
if strings.TrimSpace(req.LeaseID) == "" {
return nil, ErrSecretLeaseRequired
}
const query = `
SELECT sec.id::text, sec.organization_id::text, sec.resource_id::text, sec.secret_ref,
sec.protocol, sec.version, sec.key_id, sec.algorithm, sec.metadata,
sec.created_at, sec.rotated_at, sec.nonce, sec.ciphertext,
rs.organization_id::text, rs.resource_id::text, COALESCE(rs.worker_id, ''), rs.state
FROM resource_secrets sec
JOIN remote_sessions rs ON rs.resource_id = sec.resource_id
WHERE sec.secret_ref = $1 AND rs.id = $2::uuid
`
var descriptor ResourceSecretDescriptor
var nonce, ciphertext []byte
var sessionOrganizationID, sessionResourceID, sessionWorkerID, sessionState string
if err := s.db.QueryRow(ctx, query, req.SecretRef, req.SessionID).Scan(
&descriptor.ID,
&descriptor.OrganizationID,
&descriptor.ResourceID,
&descriptor.SecretRef,
&descriptor.Protocol,
&descriptor.Version,
&descriptor.KeyID,
&descriptor.Algorithm,
&descriptor.Metadata,
&descriptor.CreatedAt,
&descriptor.RotatedAt,
&nonce,
&ciphertext,
&sessionOrganizationID,
&sessionResourceID,
&sessionWorkerID,
&sessionState,
); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, ErrResourceSecretNotFound
}
return nil, fmt.Errorf("resolve resource secret: %w", err)
}
if descriptor.OrganizationID != req.OrganizationID ||
descriptor.ResourceID != req.ResourceID ||
sessionOrganizationID != req.OrganizationID ||
sessionResourceID != req.ResourceID ||
sessionWorkerID != req.WorkerID ||
!secretResolvableSessionState(sessionState) {
return nil, ErrSecretAccessDenied
}
plaintext, err := s.encryptor.Decrypt(EncryptedPayload{
Algorithm: descriptor.Algorithm,
KeyID: descriptor.KeyID,
Nonce: nonce,
Ciphertext: ciphertext,
}, ResourceSecretAAD(descriptor.OrganizationID, descriptor.ResourceID, descriptor.SecretRef, descriptor.Protocol))
if err != nil {
return nil, err
}
return &ResolvedResourceSecret{
Descriptor: descriptor,
Payload: json.RawMessage(plaintext),
}, nil
}
func normalizeJSONObject(raw json.RawMessage) (json.RawMessage, error) {
if len(raw) == 0 || !json.Valid(raw) {
return nil, ErrSecretPayloadInvalid
}
var decoded map[string]any
if err := json.Unmarshal(raw, &decoded); err != nil {
return nil, ErrSecretPayloadInvalid
}
encoded, err := json.Marshal(decoded)
if err != nil {
return nil, err
}
return json.RawMessage(encoded), nil
}
func normalizeJSONObjectAllowEmpty(raw json.RawMessage) (json.RawMessage, error) {
if len(raw) == 0 {
return json.RawMessage(`{}`), nil
}
return normalizeJSONObject(raw)
}
func secretResolvableSessionState(state string) bool {
switch state {
case "starting", "active", "reconnecting":
return true
default:
return false
}
}