Initial project snapshot

This commit is contained in:
2026-04-28 22:29:50 +03:00
commit 8ba0561f4f
365 changed files with 91832 additions and 0 deletions
+19
View File
@@ -0,0 +1,19 @@
FROM golang:1.23-bookworm AS build
WORKDIR /src
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN CGO_ENABLED=0 GOOS=linux go build -o /out/rap-api ./cmd/api
FROM debian:bookworm-slim
RUN apt-get update \
&& apt-get install -y --no-install-recommends ca-certificates \
&& rm -rf /var/lib/apt/lists/*
COPY --from=build /out/rap-api /usr/local/bin/rap-api
ENTRYPOINT ["/usr/local/bin/rap-api"]
+270
View File
@@ -0,0 +1,270 @@
# Backend Foundation
Production-oriented Go backend skeleton for the remote access platform.
## Scope included
- configuration loading from environment
- HTTP server bootstrap with graceful shutdown
- PostgreSQL and Redis connectivity wiring
- migrations scaffold
- auth foundation with access/refresh tokens, hashed refresh rotation, trusted devices, and persisted auth sessions
- persistent session storage foundation for remote sessions, attachments, resource policies, and audit events
- session broker orchestration for start, attach, detach, takeover, terminate, failure, and detached-session recovery
- Redis-backed live session state, controller binding, attach tokens, heartbeat keys, worker routing, and reconnect support
- Redis-backed worker registration, lease lifecycle, heartbeat tracking, stale lease recovery, and routing queues
- worker assignment queueing and worker event ingestion for the minimal real RDP worker runtime
- websocket live plane with attach handshake, ping/pong heartbeat, state messages, takeover detection, and transport reconnect flow
- module boundaries for auth, resources, session broker, and websocket gateway
- worker registry scaffold to prepare later RDP worker integration
- per-resource certificate verification policy for RDP connections with `strict` default and explicit `ignore` override
- platform-core v2 foundations for organizations, memberships, identity sources, nodes, and node-agent control plane
- Data Plane v1 contract scaffolding for optional session response candidates/tokens, with current backend gateway behavior preserved as fallback
- production resource secret-readiness guard for rejecting plaintext credential-like metadata and requiring `secret_ref` for RDP/VNC/SSH resources in production mode
- encrypted resource secret storage/resolver MVP for production `secret_ref` usage
## Entry point
Run the API from `cmd/api`.
## Local dev
- backend: `pwsh -File scripts/smoke/run-backend.ps1`
- infra: `pwsh -File scripts/smoke/start-infra.ps1`
- migrations: `pwsh -File scripts/smoke/apply-migrations.ps1`
- worker image build: `docker build --tag rap-rdp-worker:dev --file workers/rdp-worker/Dockerfile workers/rdp-worker`
- end-to-end smoke path: [scripts/smoke/README.md](/\\?\UNC\192.168.220.200\mst\codex\rdp-proxy\scripts\smoke\README.md)
## Configuration
Use `configs/api.example.env` as the starting point for local environment variables.
Resource secret-readiness is controlled by `APP_ENV`:
- in `APP_ENV=production` or `APP_ENV=prod`, RDP/VNC/SSH resources must carry
`secret_ref` and must not include plaintext credential-like fields in
`metadata`
- in development and smoke environments, plaintext metadata remains allowed
until the encrypted secret resolver is implemented
- the production guard is enforced both on resource create/update and on
session start, so legacy plaintext resources cannot be started in production
accidentally
- `SECRET_ENCRYPTION_KEY_B64` or `SECRET_ENCRYPTION_KEY_FILE` supplies the
AES-256-GCM master key for the MVP encrypted store; production mode refuses
to start without one
- `SECRET_ENCRYPTION_KEY_ID` labels the active key version in stored records
- `PUT /api/v1/resources/{resourceID}/secret` creates or rotates a resource
secret and updates `resources.secret_ref`; plaintext is never returned by the
API
- session assignment keeps PostgreSQL metadata safe: `remote_sessions.metadata`
stores `secret_ref`, while resolved credentials are merged only into the
transient worker assignment after session/worker/lease checks
See `docs/architecture/SECURITY_SECRETS_READINESS.md` for the target
secret-reference model and remaining resolver/PKI gaps.
Data Plane v1 contract scaffolding is controlled by:
- `DATA_PLANE_TOKEN_TTL`, default `1m`
- `DATA_PLANE_TOKEN_PRIVATE_KEY_FILE`, optional path to an RSA private key PEM used to sign RS256 data-plane tokens
- `DATA_PLANE_TOKEN_PRIVATE_KEY_PEM`, optional inline RSA private key PEM; used when file path is not configured
- `DATA_PLANE_BACKEND_GATEWAY_URL`, default `/api/v1/gateway/ws`
- `DATA_PLANE_DIRECT_WORKER_WSS_URL_TEMPLATE`, optional; supports `{worker_id}` replacement
- `DATA_PLANE_DIRECT_WORKER_JSON_RUNTIME`, default `false`; advertises
`runtime_transport=json_v1` only after the worker direct JSON bridge is
deployed and verified
- `DATA_PLANE_DIRECT_WORKER_BINARY_RENDER`, default `false`; when the direct
JSON runtime is enabled, advertises `render_transport=binary_v1` so DP-2
clients can request binary render frames over direct worker WSS. Binary
render candidates also advertise `supported_color_modes=["full_color","grayscale"]`
and `default_color_mode="full_color"` for the DP-3A grayscale foundation.
- `DATA_PLANE_DIRECT_WORKER_TLS_TRUST_MODE`, default `smoke_insecure`; allowed
values are `smoke_insecure`, `public_ca`, and `platform_ca`.
- `DATA_PLANE_DIRECT_WORKER_TLS_CA_REF`, optional label for the platform CA or
trust bundle version advertised to clients.
Data-plane tokens are RS256-signed. The backend must hold only the private key;
workers receive only the matching public key for validation. If no private key
is configured, the backend omits the optional `data_plane` offer and the
backend gateway fallback remains unchanged.
If no direct worker WSS URL template is configured, session responses still include the backend gateway fallback candidate only.
If the URL template is configured but `DATA_PLANE_DIRECT_WORKER_JSON_RUNTIME`
is `false`, the direct candidate is still present for contract visibility but is
not marked data-capable; DP-1D Windows clients will skip it and use the backend
gateway fallback.
If `DATA_PLANE_DIRECT_WORKER_BINARY_RENDER` is `false`, direct worker WSS
remains JSON/base64 for render. If it is `true`, only direct worker WSS render
is binary; backend gateway fallback remains JSON/base64.
In production, the backend does not advertise direct worker WSS when
`DATA_PLANE_DIRECT_WORKER_TLS_TRUST_MODE=smoke_insecure`; it keeps the backend
gateway fallback instead. Trusted direct candidates include `tls_trust_mode`,
`production_trusted`, `smoke_only`, and optional `tls_ca_ref` metadata. See
`docs/architecture/DIRECT_WORKER_TLS_PKI.md`.
## Module layout
- `internal/platform` shared runtime, config, infra, and bootstrap concerns
- `internal/modules/auth` auth and trusted-device boundary
- `internal/modules/organization` organization model, org roles, and memberships
- `internal/modules/identitysource` local/LDAP/OIDC identity source model and future mapping foundations
- `internal/modules/resource` remote resource inventory boundary
- `internal/modules/sessionbroker` persistent session lifecycle, orchestration, audit, and Redis live-state boundary
- `internal/modules/sessiongateway` websocket attach/reconnect/takeover transport boundary
- `internal/modules/worker` worker registration, lease coordination, and control-plane routing boundary for future C++ RDP workers
- `internal/modules/node` node inventory, capabilities, enabled services, update policy, and partition state
- `internal/modules/nodeagent` node-agent registration, health, service status, and update/rollback control interface
- `pkg/contracts` cross-module contracts for sessions and worker control
## Backend responsibilities
- PostgreSQL remains the source of truth for auth sessions, devices, remote sessions, attachments, resource policies, and audit events
- Redis is used only for live routing and coordination: attach tokens, controller bindings, live session cache, worker registration, worker leases, heartbeats, and routing queues
- `worker:control:<worker_id>` carries worker assignments, `worker:queue:<session_id>` carries live control/input envelopes, and `worker:events` carries worker-reported lifecycle events back into broker processing
- Session broker owns state transitions and orchestration rules; websocket handlers call broker services instead of talking to postgres repositories directly
- Worker runtime stays behind interfaces and Redis coordination so the backend remains isolated from FreeRDP implementation details while the minimal real RDP worker plugs into the control plane
- RDP certificate verification is configured per resource through `certificate_verification_mode`
- resources are now org-scoped in PostgreSQL and remote sessions persist their owning organization without changing the proven worker/session runtime contracts
- session start/attach/takeover responses may include optional `data_plane` candidates and a short-lived signed data-plane token for DP-1 direct worker WSS migration; existing clients continue to use the current gateway path, and direct realtime use remains gated by explicit candidate metadata
## Authorization model
- `platform_admin` and `platform_recovery_admin` have global access across organizations, resources, and sessions
- in `INSTALLATION_AUTHORITY_MODE=strict`, platform-admin power is effective only
when the user also has a valid signed row in `platform_role_grants`; changing
`users.platform_role` in PostgreSQL alone no longer grants owner access
- first-owner bootstrap is available at
`POST /api/v1/installation/bootstrap-owner` and requires a Product Root
Ed25519 signature over an activation manifest in strict mode
- production (`APP_ENV=production` or `prod`) requires strict installation
authority plus `INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_B64` or
`INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_FILE`
- legacy/dev installs can keep database-role behavior, and insecure first-owner
bootstrap is available only when
`INSTALLATION_INSECURE_BOOTSTRAP_ENABLED=true`
- `org_owner` and `org_admin` can create and update resources inside their organization and can manage any remote session inside that organization
- active non-admin memberships such as `org_operator`, `org_member`, and `org_viewer` are deny-by-default for admin actions; they can only access org-scoped reads and operate on their own session flows where the session broker explicitly allows it
- session start always authorizes the actor against the resource organization before worker reservation
- attach, detach, takeover, and terminate authorize against the owning remote session organization before any state transition is written
- worker-facing events do not bypass this model for user-originated commands; internal worker failure and heartbeat paths remain broker-internal control-plane operations
## Migration safety
- `000005_platform_core_v2` bootstraps a single `default` organization and backfills existing `resources.organization_id` and `remote_sessions.organization_id` into that organization before setting `NOT NULL`
- `000006_default_org_memberships_backfill` safely restores access continuity by inserting missing active memberships for existing users into the `default` organization
- the backfill is idempotent because it only inserts rows missing under the `(organization_id, user_id)` uniqueness constraint
- platform administrators are backfilled as `org_owner` in the default organization, while other existing users are backfilled as `org_member`
- if `000005` fails before the `NOT NULL` step, PostgreSQL rolls back the transaction and leaves pre-v2 rows untouched; if `000006` is rerun, it skips already-created memberships rather than duplicating them
## Platform-Core V2 Notes
- `organizations`, `organization_memberships`, and `organization_roles` establish multi-tenant ownership and basic org-scoped authorization boundaries
- `identity_sources` and `identity_mappings` are foundation-only in this phase; full LDAP/OIDC sync and claim/group ingestion are intentionally deferred
- `nodes`, `node_capabilities`, `node_services`, `node_update_policies`, `node_partition_states`, and `node_agent_update_runs` provide the first control-plane model for node and node-agent lifecycle
- current proven RDP session lifecycle remains preserved: the session broker still orchestrates the same worker/session behavior, but it now records organization ownership via org-scoped resources
- PostgreSQL remains the source of truth for organizations, memberships, org-scoped resources, identity sources, nodes, node-agent state, and session lifecycle state
## Resource Certificate Verification
- `strict` is the default and keeps normal certificate validation enabled in the worker runtime
- `ignore` must be explicitly stored on the resource and allows that one RDP connection to skip certificate validation
- the backend passes this policy through session assignment data; it is not a global backend toggle
## Messaging Model
- HTTP errors now use a structured envelope:
- `error.code`
- `error.message_key`
- `error.fallback_message`
- `error.details`
- `error.trace_id`
- `internal/platform/httpx` owns error normalization and trace-id generation so handlers can keep calling `WriteError(...)` without changing business logic.
- For `5xx` responses, user-facing payloads are normalized to an English generic fallback message while logs and diagnostics can still keep raw internal details elsewhere.
- For `4xx` responses, stable `code` and `message_key` are derived from the current fallback message, so clients can localize without depending on raw English text as the primary contract.
## WebSocket Messaging
- Session gateway envelopes keep the existing `type` and `payload` contract.
- User-facing websocket events now also include `event` with:
- `code`
- `message_key`
- `fallback_message`
- `details`
- `trace_id`
- `session.taken_over`, terminal `session.state`, `transport.closed`, and protocol-level errors now carry this structured event object.
- Existing payload semantics remain intact for compatibility with the already proven session lifecycle.
## Message Rules
- Keep English as the only development language for `fallback_message`, logs, and diagnostics.
- New HTTP handlers should prefer `httpx.WriteError(...)` for user-facing failures instead of hand-building `"error": "..."` JSON.
- New websocket user-facing notifications should populate `TransportEnvelope.Event` with a stable `code` and `message_key`.
- Do not use raw human-readable English text as the primary client contract; it should only remain as fallback text.
- This messaging layer is now runtime-proven against the live Windows smoke flow for invalid-login errors, websocket takeover delivery, websocket state fallback rendering, and worker-death failure handling.
## Clipboard Policy
RDP text clipboard is controlled per resource through `resource_policies.clipboard_mode`.
Allowed values are `disabled`, `client_to_server`, `server_to_client`, and
`bidirectional`; the default is `disabled`. The legacy `clipboard_enabled`
column is retained only for compatibility and migration/backfill, while new
runtime decisions use `clipboard_mode`.
Clipboard enforcement happens in the real data path:
- `sessionbroker.ResourcePolicy.ClipboardMode` is loaded from PostgreSQL and
embedded into the session assignment metadata sent to the worker.
- `sessiongateway.Module.handleEnvelope` blocks client-to-server clipboard
envelopes unless the session is `active` and the policy allows that direction.
- `worker.EventProcessor` sends worker-originated clipboard text through
`sessionbroker.Service.UpdateWorkerClipboardText`, which applies the same
active-state and server-to-client policy checks before updating live state.
- Clipboard messages carry `sequence_id`, `origin`, and `content_hash` so
clients and workers can avoid feedback loops across reattach/takeover paths.
- Redis stores clipboard text only as transient live state for routing to the
active controller; PostgreSQL remains authoritative for policy/session state.
## File Upload Policy
Stage 5.1 introduces client-to-server file upload as a policy-gated RDP
feature. The authoritative policy field is
`resource_policies.file_transfer_mode`; allowed values are `disabled`,
`client_to_server`, `server_to_client`, and `bidirectional`, but only
`client_to_server` behavior is implemented in this stage. The default is
`disabled`. The legacy `file_transfer_enabled` column is retained only as a
derived compatibility flag and must not be treated as the primary policy.
Enforcement is deliberately duplicated in the real data path:
- `resource.Module` exposes `file_transfer_mode` in resource create, update,
list, and read payloads.
- `sessionbroker.Service.StartRemoteSession` embeds `file_transfer_mode` into
assignment metadata and requests the worker `file-transfer` capability only
when client-to-server upload is allowed.
- `sessiongateway.Module.handleFileUploadStart` and
`handleFileUploadChunk` require an active session, current controller,
allowed policy mode, valid UUID `transfer_id`, safe file name, 25 MiB max
file size, and 256 KiB max chunk size before routing chunks to the worker.
- Redis is used only to route bounded upload envelopes to the worker. The file
itself is written by the worker to controlled worker storage; PostgreSQL
remains authoritative for policy and session state.
## File Download Policy
Stage 5.2 adds a runtime-proven server-to-client download path for RDP. The
policy field remains `resource_policies.file_transfer_mode`; `server_to_client`
and `bidirectional` allow download, while `disabled` and `client_to_server`
block it. The default remains `disabled`.
The v1 download model uses only the restricted `RAP_Transfers\ToClient`
drop-zone inside the existing per-session visible transfer directory. Backend
gateway accepts only `file_download.start`, `file_download.ack`, and
`file_download.cancel` from the current controller of an active session and
routes them to the worker after policy validation. Worker-origin
`file_download.*` events are stored only as transient live state for
backend-gateway fallback delivery; PostgreSQL remains authoritative for
session/resource/policy state and must not store file contents.
The direct worker WSS path is also lifecycle-gated: detach returns
`file_download.blocked`, old-controller takeover returns `session.taken_over`,
and worker failure closes the direct transport after PostgreSQL transitions the
session to `failed`.
+24
View File
@@ -0,0 +1,24 @@
package main
import (
"context"
"log"
"os/signal"
"syscall"
"github.com/example/remote-access-platform/backend/internal/platform/runtime"
)
func main() {
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()
app, err := runtime.NewApp(ctx)
if err != nil {
log.Fatalf("bootstrap app: %v", err)
}
if err := app.Run(ctx); err != nil {
log.Fatalf("run app: %v", err)
}
}
+94
View File
@@ -0,0 +1,94 @@
package main
import (
"context"
"flag"
"fmt"
"log"
"net/url"
"os"
"os/signal"
"strings"
"time"
"github.com/gorilla/websocket"
)
func main() {
var (
gatewayURL = flag.String("gateway-url", "ws://127.0.0.1:8080/api/v1/gateway/ws", "websocket gateway url")
attachToken = flag.String("attach-token", "", "short-lived attach token")
duration = flag.Duration("duration", 90*time.Second, "maximum time to keep the client attached")
heartbeatInterval = flag.Duration("heartbeat-interval", 10*time.Second, "client heartbeat interval")
exitOnTakenOver = flag.Bool("exit-on-taken-over", true, "exit successfully after receiving session.taken_over")
)
flag.Parse()
if *attachToken == "" {
log.Fatal("attach-token is required")
}
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()
if *duration > 0 {
var timeoutCancel context.CancelFunc
ctx, timeoutCancel = context.WithTimeout(ctx, *duration)
defer timeoutCancel()
}
targetURL, err := url.Parse(*gatewayURL)
if err != nil {
log.Fatalf("parse gateway url: %v", err)
}
query := targetURL.Query()
query.Set("attach_token", *attachToken)
targetURL.RawQuery = query.Encode()
conn, _, err := websocket.DefaultDialer.DialContext(ctx, targetURL.String(), nil)
if err != nil {
log.Fatalf("dial websocket: %v", err)
}
defer conn.Close()
done := make(chan error, 1)
go func() {
for {
_, payload, err := conn.ReadMessage()
if err != nil {
done <- err
return
}
fmt.Println(string(payload))
if *exitOnTakenOver && containsTakenOver(payload) {
done <- nil
return
}
}
}()
heartbeatTicker := time.NewTicker(*heartbeatInterval)
defer heartbeatTicker.Stop()
for {
select {
case err := <-done:
if err != nil {
log.Fatalf("websocket read: %v", err)
}
return
case <-heartbeatTicker.C:
if err := conn.WriteJSON(map[string]any{"type": "heartbeat"}); err != nil {
log.Fatalf("heartbeat write: %v", err)
}
case <-ctx.Done():
if err := ctx.Err(); err != nil && err != context.Canceled && err != context.DeadlineExceeded {
log.Fatalf("context ended: %v", err)
}
return
}
}
}
func containsTakenOver(payload []byte) bool {
return strings.Contains(string(payload), `"type":"session.taken_over"`)
}
+44
View File
@@ -0,0 +1,44 @@
APP_NAME=rap-api
APP_ENV=development
HTTP_HOST=0.0.0.0
HTTP_PORT=8080
HTTP_READ_TIMEOUT=15s
HTTP_WRITE_TIMEOUT=15s
HTTP_IDLE_TIMEOUT=60s
HTTP_SHUTDOWN_TIMEOUT=10s
POSTGRES_DSN=postgres://rap_user:rap_password@localhost:5432/remote_access_platform?sslmode=disable
POSTGRES_MAX_CONNS=20
POSTGRES_MIN_CONNS=2
POSTGRES_CONNECT_TIMEOUT=5s
REDIS_ADDR=localhost:6379
REDIS_PASSWORD=
REDIS_DB=0
REDIS_DIAL_TIMEOUT=5s
AUTH_ACCESS_TOKEN_TTL=15m
AUTH_REFRESH_TOKEN_TTL=720h
AUTH_ISSUER=rap-api
AUTH_ACCESS_TOKEN_SECRET=change-me-access-secret
AUTH_REFRESH_HASH_SECRET=change-me-refresh-hash-secret
DATA_PLANE_TOKEN_TTL=1m
DATA_PLANE_TOKEN_PRIVATE_KEY_FILE=
DATA_PLANE_TOKEN_PRIVATE_KEY_PEM=
DATA_PLANE_BACKEND_GATEWAY_URL=/api/v1/gateway/ws
DATA_PLANE_DIRECT_WORKER_WSS_URL_TEMPLATE=
DATA_PLANE_DIRECT_WORKER_JSON_RUNTIME=false
DATA_PLANE_DIRECT_WORKER_BINARY_RENDER=false
DATA_PLANE_DIRECT_WORKER_TLS_TRUST_MODE=smoke_insecure
DATA_PLANE_DIRECT_WORKER_TLS_CA_REF=
SECRET_ENCRYPTION_KEY_B64=
SECRET_ENCRYPTION_KEY_FILE=
SECRET_ENCRYPTION_KEY_ID=local-v1
SESSION_HEARTBEAT_TTL=90s
SESSION_DETACH_GRACE_PERIOD=30m
SESSION_ATTACH_TOKEN_TTL=2m
SESSION_LIVE_STATE_TTL=2m
SESSION_RECOVERY_BATCH_SIZE=100
WORKER_LEASE_TTL=45s
WORKER_HEARTBEAT_TTL=15s
WORKER_STALE_LEASE_GRACE_PERIOD=30s
WEBSOCKET_WRITE_TIMEOUT=10s
WEBSOCKET_PING_INTERVAL=20s
WEBSOCKET_PONG_WAIT=40s
+23
View File
@@ -0,0 +1,23 @@
module github.com/example/remote-access-platform/backend
go 1.23.2
require (
github.com/go-chi/chi/v5 v5.2.1
github.com/golang-jwt/jwt/v5 v5.2.2
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/jackc/pgx/v5 v5.7.4
github.com/redis/go-redis/v9 v9.8.0
golang.org/x/crypto v0.37.0
)
require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
golang.org/x/sync v0.13.0 // indirect
golang.org/x/text v0.24.0 // indirect
)
+46
View File
@@ -0,0 +1,46 @@
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8=
github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg=
github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI=
github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+19
View File
@@ -0,0 +1,19 @@
package auth
import "errors"
var (
ErrInvalidCredentials = errors.New("invalid credentials")
ErrInvalidRefreshToken = errors.New("invalid refresh token")
ErrAuthSessionRevoked = errors.New("auth session revoked")
ErrDeviceRevoked = errors.New("device revoked")
ErrDeviceNotTrusted = errors.New("device not trusted")
ErrAuthSessionNotFound = errors.New("auth session not found")
ErrTrustedDeviceMissing = errors.New("trusted device not found")
ErrInstallationAlreadyBootstrapped = errors.New("installation is already bootstrapped")
ErrInstallationActivationRequired = errors.New("signed installation activation is required")
ErrInvalidInstallationActivation = errors.New("invalid installation activation")
ErrInsecureBootstrapDisabled = errors.New("insecure installation bootstrap is disabled")
ErrInvalidBootstrapOwner = errors.New("invalid bootstrap owner")
)
+114
View File
@@ -0,0 +1,114 @@
package auth
import (
"encoding/json"
"time"
)
type DeviceTrustStatus string
const (
DeviceTrustStatusPending DeviceTrustStatus = "pending"
DeviceTrustStatusTrusted DeviceTrustStatus = "trusted"
DeviceTrustStatusRevoked DeviceTrustStatus = "revoked"
)
type User struct {
ID string
Email string
PasswordHash string
MFAEnabled bool
CreatedAt time.Time
UpdatedAt time.Time
}
type Device struct {
ID string
UserID string
Fingerprint string
Label string
TrustStatus DeviceTrustStatus
TrustedAt *time.Time
LastSeenAt *time.Time
RevokedAt *time.Time
RevokedReason *string
CreatedAt time.Time
UpdatedAt time.Time
}
type AuthSession struct {
ID string
UserID string
DeviceID string
RefreshTokenHash string
RefreshExpiresAt time.Time
LastSeenAt *time.Time
LastRotatedAt *time.Time
RevokedAt *time.Time
RevokedReason *string
CreatedAt time.Time
UpdatedAt time.Time
}
type LoginCommand struct {
Email string `json:"email"`
Password string `json:"password"`
DeviceFingerprint string `json:"device_fingerprint"`
DeviceLabel string `json:"device_label"`
TrustDevice bool `json:"trust_device"`
}
type RefreshCommand struct {
RefreshToken string `json:"refresh_token"`
}
type BootstrapOwnerCommand struct {
Email string `json:"email"`
Password string `json:"password"`
ActivationPayload json.RawMessage `json:"activation_payload"`
ActivationSignature string `json:"activation_signature"`
}
type RevokeAuthSessionCommand struct {
UserID string `json:"user_id"`
AuthSessionID string `json:"auth_session_id"`
Reason string `json:"reason"`
}
type RevokeDeviceCommand struct {
UserID string `json:"user_id"`
DeviceID string `json:"device_id"`
Reason string `json:"reason"`
}
type TokenPair struct {
AccessToken string `json:"access_token"`
AccessTokenExpiresAt time.Time `json:"access_token_expires_at"`
RefreshToken string `json:"refresh_token"`
RefreshTokenExpiresAt time.Time `json:"refresh_token_expires_at"`
}
type AuthResult struct {
User User `json:"user"`
Device Device `json:"device"`
AuthSession AuthSession `json:"auth_session"`
Tokens TokenPair `json:"tokens"`
}
type InstallationStatus struct {
Bootstrapped bool `json:"bootstrapped"`
AuthorityState string `json:"authority_state"`
InstallID string `json:"install_id,omitempty"`
BootstrappedOwnerEmail string `json:"bootstrapped_owner_email,omitempty"`
BootstrappedAt *time.Time `json:"bootstrapped_at,omitempty"`
AuthorityMode string `json:"authority_mode"`
StrictAuthority bool `json:"strict_authority"`
RootFingerprint string `json:"root_fingerprint,omitempty"`
InsecureBootstrapAllowed bool `json:"insecure_bootstrap_allowed"`
}
type BootstrapOwnerResult struct {
Installation InstallationStatus `json:"installation"`
User User `json:"user"`
PlatformRole string `json:"platform_role"`
}
+173
View File
@@ -0,0 +1,173 @@
package auth
import (
"encoding/json"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
"github.com/example/remote-access-platform/backend/internal/platform/module"
)
type Module struct {
service *Service
}
func NewModule(deps module.Dependencies, service *Service) *Module {
return &Module{service: service}
}
func (m *Module) Name() string {
return "auth"
}
func (m *Module) RegisterRoutes(router chi.Router) {
router.Route("/installation", func(r chi.Router) {
r.Get("/status", m.handleInstallationStatus)
r.Post("/bootstrap-owner", m.handleBootstrapOwner)
})
router.Route("/auth", func(r chi.Router) {
r.Post("/login", m.handleLogin)
r.Post("/refresh", m.handleRefresh)
r.Post("/sessions/revoke", m.handleRevokeAuthSession)
r.Get("/devices", m.handleTrustedDevices)
r.Post("/devices/{deviceID}/revoke", m.handleRevokeTrustedDevice)
})
}
func (m *Module) handleInstallationStatus(w http.ResponseWriter, r *http.Request) {
status, err := m.service.InstallationStatus(r.Context())
if err != nil {
statusCode, message := m.service.MapError(err)
httpx.WriteError(w, statusCode, message)
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"installation": status})
}
func (m *Module) handleBootstrapOwner(w http.ResponseWriter, r *http.Request) {
var cmd BootstrapOwnerCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid installation bootstrap payload")
return
}
result, err := m.service.BootstrapOwner(r.Context(), cmd)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusCreated, result)
}
func (m *Module) handleLogin(w http.ResponseWriter, r *http.Request) {
var cmd LoginCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid login payload")
return
}
result, err := m.service.Login(r.Context(), cmd)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusOK, result)
}
func (m *Module) handleRefresh(w http.ResponseWriter, r *http.Request) {
var cmd RefreshCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid refresh payload")
return
}
result, err := m.service.Refresh(r.Context(), cmd)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusOK, result)
}
func (m *Module) handleRevokeAuthSession(w http.ResponseWriter, r *http.Request) {
var cmd RevokeAuthSessionCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid auth session revoke payload")
return
}
if err := m.service.RevokeAuthSession(r.Context(), cmd); err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
"status": "revoked",
"message": httpx.NewMessage(
"auth.session.revoked",
"status.auth.session.revoked",
"Auth session revoked.",
nil,
"",
),
})
}
func (m *Module) handleTrustedDevices(w http.ResponseWriter, r *http.Request) {
userID := r.URL.Query().Get("user_id")
if userID == "" {
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
return
}
devices, err := m.service.ListTrustedDevices(r.Context(), userID)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{
"devices": devices,
})
}
func (m *Module) handleRevokeTrustedDevice(w http.ResponseWriter, r *http.Request) {
var payload struct {
UserID string `json:"user_id"`
Reason string `json:"reason"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid device revoke payload")
return
}
err := m.service.RevokeTrustedDevice(r.Context(), RevokeDeviceCommand{
UserID: payload.UserID,
DeviceID: chi.URLParam(r, "deviceID"),
Reason: payload.Reason,
})
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
"status": "revoked",
"message": httpx.NewMessage(
"auth.device.revoked",
"status.auth.device.revoked",
"Trusted device revoked.",
nil,
"",
),
})
}
@@ -0,0 +1,525 @@
package auth
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
)
type postgresStore struct {
db postgresplatform.DBTX
}
type PostgresTransactor struct {
pool *pgxpool.Pool
}
func NewPostgresStore(pool *pgxpool.Pool) Store {
return &postgresStore{db: pool}
}
func NewPostgresTransactor(pool *pgxpool.Pool) *PostgresTransactor {
return &PostgresTransactor{pool: pool}
}
func (t *PostgresTransactor) WithinTransaction(ctx context.Context, fn func(store Store) error) error {
return postgresplatform.WithTransaction(ctx, t.pool, func(tx pgx.Tx) error {
return fn(&postgresStore{db: tx})
})
}
func (s *postgresStore) Users() UserRepository {
return &postgresUserRepository{db: s.db}
}
func (s *postgresStore) Devices() DeviceRepository {
return &postgresDeviceRepository{db: s.db}
}
func (s *postgresStore) AuthSessions() AuthSessionRepository {
return &postgresAuthSessionRepository{db: s.db}
}
func (s *postgresStore) Installation() InstallationRepository {
return &postgresInstallationRepository{db: s.db}
}
type postgresUserRepository struct {
db postgresplatform.DBTX
}
type postgresDeviceRepository struct {
db postgresplatform.DBTX
}
type postgresAuthSessionRepository struct {
db postgresplatform.DBTX
}
type postgresInstallationRepository struct {
db postgresplatform.DBTX
}
func (r *postgresUserRepository) GetByEmail(ctx context.Context, email string) (*User, error) {
const query = `
SELECT id::text, email, password_hash, mfa_enabled, created_at, updated_at
FROM users
WHERE email = $1
`
return scanOptionalUser(r.db.QueryRow(ctx, query, email))
}
func (r *postgresUserRepository) GetByID(ctx context.Context, userID string) (*User, error) {
const query = `
SELECT id::text, email, password_hash, mfa_enabled, created_at, updated_at
FROM users
WHERE id = $1::uuid
`
return scanOptionalUser(r.db.QueryRow(ctx, query, userID))
}
func (r *postgresDeviceRepository) Upsert(ctx context.Context, params UpsertDeviceParams) (*Device, error) {
const query = `
INSERT INTO devices (
user_id,
device_fingerprint,
device_label,
trust_status,
trusted_at,
last_seen_at,
created_at,
updated_at
) VALUES (
$1::uuid,
$2,
$3,
CASE WHEN $4 THEN 'trusted' ELSE 'pending' END,
CASE WHEN $4 THEN $5::timestamptz ELSE NULL::timestamptz END,
$5::timestamptz,
$5::timestamptz,
$5::timestamptz
)
ON CONFLICT (user_id, device_fingerprint) DO UPDATE SET
device_label = EXCLUDED.device_label,
last_seen_at = EXCLUDED.last_seen_at,
updated_at = EXCLUDED.updated_at,
trust_status = CASE
WHEN devices.trust_status = 'revoked' THEN devices.trust_status
WHEN devices.trust_status = 'trusted' THEN devices.trust_status
WHEN EXCLUDED.trust_status = 'trusted' THEN 'trusted'
ELSE devices.trust_status
END,
trusted_at = CASE
WHEN devices.trust_status = 'trusted' THEN devices.trusted_at
WHEN EXCLUDED.trust_status = 'trusted' THEN EXCLUDED.trusted_at
ELSE devices.trusted_at
END
RETURNING
id::text, user_id::text, device_fingerprint, COALESCE(device_label, ''),
trust_status, trusted_at, last_seen_at, revoked_at, revoked_reason, created_at, updated_at
`
return scanDevice(r.db.QueryRow(ctx, query,
params.UserID,
params.Fingerprint,
params.Label,
params.TrustRequested,
params.SeenAt,
))
}
func (r *postgresDeviceRepository) GetByIDForUser(ctx context.Context, userID, deviceID string) (*Device, error) {
const query = `
SELECT id::text, user_id::text, device_fingerprint, COALESCE(device_label, ''),
trust_status, trusted_at, last_seen_at, revoked_at, revoked_reason, created_at, updated_at
FROM devices
WHERE id = $1::uuid AND user_id = $2::uuid
`
device, err := scanDevice(r.db.QueryRow(ctx, query, deviceID, userID))
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return device, err
}
func (r *postgresDeviceRepository) ListTrustedByUser(ctx context.Context, userID string) ([]Device, error) {
const query = `
SELECT id::text, user_id::text, device_fingerprint, COALESCE(device_label, ''),
trust_status, trusted_at, last_seen_at, revoked_at, revoked_reason, created_at, updated_at
FROM devices
WHERE user_id = $1::uuid AND trust_status = 'trusted' AND revoked_at IS NULL
ORDER BY created_at DESC
`
rows, err := r.db.Query(ctx, query, userID)
if err != nil {
return nil, fmt.Errorf("query trusted devices: %w", err)
}
defer rows.Close()
var devices []Device
for rows.Next() {
device, err := scanDevice(rows)
if err != nil {
return nil, err
}
devices = append(devices, *device)
}
return devices, rows.Err()
}
func (r *postgresDeviceRepository) Revoke(ctx context.Context, params RevokeDeviceParams) error {
const query = `
UPDATE devices
SET trust_status = 'revoked',
revoked_at = $3,
revoked_reason = $4,
updated_at = $3
WHERE id = $1::uuid AND user_id = $2::uuid
`
if _, err := r.db.Exec(ctx, query, params.DeviceID, params.UserID, params.RevokedAt, params.Reason); err != nil {
return fmt.Errorf("revoke device: %w", err)
}
return nil
}
func (r *postgresAuthSessionRepository) Create(ctx context.Context, session AuthSession) error {
const query = `
INSERT INTO auth_sessions (
id,
user_id,
device_id,
refresh_token_hash,
refresh_expires_at,
last_seen_at,
created_at,
updated_at
) VALUES ($1::uuid, $2::uuid, $3::uuid, $4, $5, $6, $7, $8)
`
if _, err := r.db.Exec(ctx, query,
session.ID,
session.UserID,
session.DeviceID,
session.RefreshTokenHash,
session.RefreshExpiresAt,
session.LastSeenAt,
session.CreatedAt,
session.UpdatedAt,
); err != nil {
return fmt.Errorf("create auth session: %w", err)
}
return nil
}
func (r *postgresAuthSessionRepository) GetByID(ctx context.Context, authSessionID string) (*AuthSession, error) {
return r.getByID(ctx, authSessionID, "")
}
func (r *postgresAuthSessionRepository) GetByIDForUpdate(ctx context.Context, authSessionID string) (*AuthSession, error) {
return r.getByID(ctx, authSessionID, " FOR UPDATE")
}
func (r *postgresAuthSessionRepository) getByID(ctx context.Context, authSessionID string, suffix string) (*AuthSession, error) {
query := `
SELECT id::text, user_id::text, device_id::text, refresh_token_hash, refresh_expires_at,
last_seen_at, last_rotated_at, revoked_at, revoked_reason, created_at, updated_at
FROM auth_sessions
WHERE id = $1::uuid` + suffix
session, err := scanAuthSession(r.db.QueryRow(ctx, query, authSessionID))
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return session, err
}
func (r *postgresAuthSessionRepository) Rotate(ctx context.Context, params RotateAuthSessionParams) error {
const query = `
UPDATE auth_sessions
SET refresh_token_hash = $2,
refresh_expires_at = $3,
last_seen_at = $4,
last_rotated_at = $5,
updated_at = $5
WHERE id = $1::uuid AND revoked_at IS NULL
`
if _, err := r.db.Exec(ctx, query,
params.AuthSessionID,
params.RefreshTokenHash,
params.RefreshExpiresAt,
params.LastSeenAt,
params.LastRotatedAt,
); err != nil {
return fmt.Errorf("rotate auth session: %w", err)
}
return nil
}
func (r *postgresAuthSessionRepository) Touch(ctx context.Context, authSessionID string, seenAt time.Time) error {
const query = `
UPDATE auth_sessions
SET last_seen_at = $2, updated_at = $2
WHERE id = $1::uuid AND revoked_at IS NULL
`
if _, err := r.db.Exec(ctx, query, authSessionID, seenAt); err != nil {
return fmt.Errorf("touch auth session: %w", err)
}
return nil
}
func (r *postgresAuthSessionRepository) Revoke(ctx context.Context, params RevokeAuthSessionParams) error {
const query = `
UPDATE auth_sessions
SET revoked_at = $3,
revoked_reason = $4,
updated_at = $3
WHERE id = $1::uuid AND user_id = $2::uuid AND revoked_at IS NULL
`
if _, err := r.db.Exec(ctx, query, params.AuthSessionID, params.UserID, params.RevokedAt, params.Reason); err != nil {
return fmt.Errorf("revoke auth session: %w", err)
}
return nil
}
func (r *postgresAuthSessionRepository) RevokeByDevice(ctx context.Context, userID, deviceID, reason string, revokedAt time.Time) error {
const query = `
UPDATE auth_sessions
SET revoked_at = $3,
revoked_reason = $4,
updated_at = $3
WHERE user_id = $1::uuid AND device_id = $2::uuid AND revoked_at IS NULL
`
if _, err := r.db.Exec(ctx, query, userID, deviceID, revokedAt, reason); err != nil {
return fmt.Errorf("revoke auth sessions by device: %w", err)
}
return nil
}
func (r *postgresInstallationRepository) GetStatus(ctx context.Context) (*InstallationAuthorityState, error) {
const query = `
SELECT install_id, authority_state, product_root_key_fingerprint, bootstrapped_owner_email, bootstrapped_at
FROM installation_authority
WHERE id = 1
`
status := &InstallationAuthorityState{}
if err := r.db.QueryRow(ctx, query).Scan(
&status.InstallID,
&status.AuthorityState,
&status.ProductRootFingerprint,
&status.BootstrappedOwnerEmail,
&status.BootstrappedAt,
); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return &InstallationAuthorityState{
Bootstrapped: false,
AuthorityState: "unbootstrapped",
}, nil
}
return nil, fmt.Errorf("get installation status: %w", err)
}
status.Bootstrapped = true
return status, nil
}
func (r *postgresInstallationRepository) BootstrapOwner(ctx context.Context, params BootstrapOwnerParams) (*User, error) {
var existingInstallID string
if err := r.db.QueryRow(ctx, `
SELECT install_id
FROM installation_authority
WHERE id = 1
FOR UPDATE
`).Scan(&existingInstallID); err != nil && !errors.Is(err, pgx.ErrNoRows) {
return nil, fmt.Errorf("lock installation authority: %w", err)
} else if err == nil {
return nil, ErrInstallationAlreadyBootstrapped
}
email := strings.ToLower(strings.TrimSpace(params.Email))
now := params.Now.UTC()
user, err := scanOptionalUser(r.db.QueryRow(ctx, `
INSERT INTO users (email, password_hash, mfa_enabled, platform_role, created_at, updated_at)
VALUES ($1, $2, FALSE, $3, $4, $4)
ON CONFLICT (email) DO UPDATE SET
password_hash = EXCLUDED.password_hash,
platform_role = EXCLUDED.platform_role,
updated_at = EXCLUDED.updated_at
RETURNING id::text, email, password_hash, mfa_enabled, created_at, updated_at
`, email, params.PasswordHash, params.Role, now))
if err != nil {
return nil, fmt.Errorf("upsert bootstrap owner: %w", err)
}
if user == nil {
return nil, fmt.Errorf("upsert bootstrap owner returned no user")
}
payload := json.RawMessage(`{}`)
if len(params.ActivationPayload) > 0 {
payload = params.ActivationPayload
}
if _, err := r.db.Exec(ctx, `
INSERT INTO installation_authority (
id,
install_id,
authority_state,
product_root_key_fingerprint,
activation_payload,
activation_signature,
bootstrapped_owner_email,
bootstrapped_at,
created_at,
updated_at
) VALUES (
1,
$1,
'active',
$2,
$3::jsonb,
$4,
$5,
$6,
$6,
$6
)
`, params.InstallID, params.ProductRootKeyFingerprint, []byte(payload), params.ActivationSignature, email, now); err != nil {
return nil, fmt.Errorf("insert installation authority: %w", err)
}
if _, err := r.db.Exec(ctx, `
UPDATE platform_role_grants
SET revoked_at = $4
WHERE user_id = $1::uuid
AND role = $2
AND install_id = $3
AND revoked_at IS NULL
`, user.ID, params.Role, params.InstallID, now); err != nil {
return nil, fmt.Errorf("revoke superseded platform role grants: %w", err)
}
if _, err := r.db.Exec(ctx, `
INSERT INTO platform_role_grants (
user_id,
role,
install_id,
grant_payload,
grant_signature,
grant_source,
granted_at,
expires_at,
metadata
) VALUES (
$1::uuid,
$2,
$3,
$4::jsonb,
$5,
$6,
$7,
$8,
'{"bootstrap_owner":true}'::jsonb
)
`, user.ID, params.Role, params.InstallID, []byte(payload), params.ActivationSignature, params.GrantSource, now, params.ExpiresAt); err != nil {
return nil, fmt.Errorf("insert platform role grant: %w", err)
}
if _, err := r.db.Exec(ctx, `
INSERT INTO organization_memberships (
organization_id,
user_id,
role_id,
status,
invited_by_user_id,
created_at,
updated_at
)
SELECT id, $1::uuid, 'org_owner', 'active', $1::uuid, $2, $2
FROM organizations
WHERE slug = 'default'
ON CONFLICT (organization_id, user_id) DO UPDATE SET
role_id = 'org_owner',
status = 'active',
invited_by_user_id = EXCLUDED.invited_by_user_id,
updated_at = EXCLUDED.updated_at
`, user.ID, now); err != nil {
return nil, fmt.Errorf("upsert default organization owner membership: %w", err)
}
return user, nil
}
type scanner interface {
Scan(dest ...any) error
}
func scanOptionalUser(row scanner) (*User, error) {
user := &User{}
if err := row.Scan(
&user.ID,
&user.Email,
&user.PasswordHash,
&user.MFAEnabled,
&user.CreatedAt,
&user.UpdatedAt,
); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return nil, fmt.Errorf("scan user: %w", err)
}
return user, nil
}
func scanDevice(row scanner) (*Device, error) {
device := &Device{}
var trustedAt, lastSeenAt, revokedAt *time.Time
var revokedReason *string
if err := row.Scan(
&device.ID,
&device.UserID,
&device.Fingerprint,
&device.Label,
&device.TrustStatus,
&trustedAt,
&lastSeenAt,
&revokedAt,
&revokedReason,
&device.CreatedAt,
&device.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan device: %w", err)
}
device.TrustedAt = trustedAt
device.LastSeenAt = lastSeenAt
device.RevokedAt = revokedAt
device.RevokedReason = revokedReason
return device, nil
}
func scanAuthSession(row scanner) (*AuthSession, error) {
session := &AuthSession{}
var lastSeenAt, lastRotatedAt, revokedAt *time.Time
var revokedReason *string
if err := row.Scan(
&session.ID,
&session.UserID,
&session.DeviceID,
&session.RefreshTokenHash,
&session.RefreshExpiresAt,
&lastSeenAt,
&lastRotatedAt,
&revokedAt,
&revokedReason,
&session.CreatedAt,
&session.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan auth session: %w", err)
}
session.LastSeenAt = lastSeenAt
session.LastRotatedAt = lastRotatedAt
session.RevokedAt = revokedAt
session.RevokedReason = revokedReason
return session, nil
}
@@ -0,0 +1,97 @@
package auth
import (
"context"
"encoding/json"
"time"
)
type UserRepository interface {
GetByEmail(ctx context.Context, email string) (*User, error)
GetByID(ctx context.Context, userID string) (*User, error)
}
type DeviceRepository interface {
Upsert(ctx context.Context, params UpsertDeviceParams) (*Device, error)
GetByIDForUser(ctx context.Context, userID, deviceID string) (*Device, error)
ListTrustedByUser(ctx context.Context, userID string) ([]Device, error)
Revoke(ctx context.Context, params RevokeDeviceParams) error
}
type AuthSessionRepository interface {
Create(ctx context.Context, session AuthSession) error
GetByID(ctx context.Context, authSessionID string) (*AuthSession, error)
GetByIDForUpdate(ctx context.Context, authSessionID string) (*AuthSession, error)
Rotate(ctx context.Context, params RotateAuthSessionParams) error
Touch(ctx context.Context, authSessionID string, seenAt time.Time) error
Revoke(ctx context.Context, params RevokeAuthSessionParams) error
RevokeByDevice(ctx context.Context, userID, deviceID, reason string, revokedAt time.Time) error
}
type InstallationRepository interface {
GetStatus(ctx context.Context) (*InstallationAuthorityState, error)
BootstrapOwner(ctx context.Context, params BootstrapOwnerParams) (*User, error)
}
type Store interface {
Users() UserRepository
Devices() DeviceRepository
AuthSessions() AuthSessionRepository
Installation() InstallationRepository
}
type Transactor interface {
WithinTransaction(ctx context.Context, fn func(store Store) error) error
}
type UpsertDeviceParams struct {
UserID string
Fingerprint string
Label string
TrustRequested bool
SeenAt time.Time
}
type RotateAuthSessionParams struct {
AuthSessionID string
RefreshTokenHash string
RefreshExpiresAt time.Time
LastSeenAt time.Time
LastRotatedAt time.Time
}
type RevokeAuthSessionParams struct {
AuthSessionID string
UserID string
Reason string
RevokedAt time.Time
}
type RevokeDeviceParams struct {
UserID string
DeviceID string
Reason string
RevokedAt time.Time
}
type InstallationAuthorityState struct {
Bootstrapped bool
AuthorityState string
InstallID string
ProductRootFingerprint string
BootstrappedOwnerEmail string
BootstrappedAt *time.Time
}
type BootstrapOwnerParams struct {
Email string
PasswordHash string
Role string
InstallID string
ProductRootKeyFingerprint string
ActivationPayload json.RawMessage
ActivationSignature string
GrantSource string
ExpiresAt *time.Time
Now time.Time
}
+440
View File
@@ -0,0 +1,440 @@
package auth
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"github.com/example/remote-access-platform/backend/internal/platform/authority"
"github.com/example/remote-access-platform/backend/internal/platform/module"
)
type Service struct {
cfg module.Config
store Store
transactor Transactor
tokenManager *TokenManager
authority *authority.Verifier
now func() time.Time
}
func NewService(deps module.Dependencies, store Store, transactor Transactor, verifiers ...*authority.Verifier) *Service {
var authorityVerifier *authority.Verifier
if len(verifiers) > 0 {
authorityVerifier = verifiers[0]
} else if verifier, err := authority.NewVerifier(deps.Config.Installation); err == nil {
authorityVerifier = verifier
}
return &Service{
cfg: deps.Config,
store: store,
transactor: transactor,
tokenManager: NewTokenManager(TokenConfig{
Issuer: deps.Config.Auth.Issuer,
AccessTokenSecret: deps.Config.Auth.AccessTokenSecret,
RefreshHashSecret: deps.Config.Auth.RefreshHashSecret,
AccessTokenTTL: deps.Config.Auth.AccessTokenTTL,
RefreshTokenTTL: deps.Config.Auth.RefreshTokenTTL,
}),
authority: authorityVerifier,
now: time.Now,
}
}
func (s *Service) Login(ctx context.Context, cmd LoginCommand) (*AuthResult, error) {
user, err := s.store.Users().GetByEmail(ctx, cmd.Email)
if err != nil {
return nil, err
}
if user == nil {
return nil, ErrInvalidCredentials
}
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(cmd.Password)); err != nil {
return nil, ErrInvalidCredentials
}
var result AuthResult
now := s.now().UTC()
if err := s.transactor.WithinTransaction(ctx, func(store Store) error {
device, err := store.Devices().Upsert(ctx, UpsertDeviceParams{
UserID: user.ID,
Fingerprint: cmd.DeviceFingerprint,
Label: cmd.DeviceLabel,
TrustRequested: cmd.TrustDevice,
SeenAt: now,
})
if err != nil {
return err
}
if device.TrustStatus == DeviceTrustStatusRevoked {
return ErrDeviceRevoked
}
authSessionID := uuid.NewString()
refreshToken, refreshHash, refreshExpiresAt, err := s.tokenManager.IssueRefreshToken(authSessionID, now)
if err != nil {
return err
}
accessToken, accessExpiresAt, err := s.tokenManager.IssueAccessToken(user.ID, authSessionID, device.ID, now)
if err != nil {
return err
}
session := AuthSession{
ID: authSessionID,
UserID: user.ID,
DeviceID: device.ID,
RefreshTokenHash: refreshHash,
RefreshExpiresAt: refreshExpiresAt,
CreatedAt: now,
UpdatedAt: now,
LastSeenAt: &now,
}
if err := store.AuthSessions().Create(ctx, session); err != nil {
return err
}
result = AuthResult{
User: *user,
Device: *device,
AuthSession: session,
Tokens: TokenPair{
AccessToken: accessToken,
AccessTokenExpiresAt: accessExpiresAt,
RefreshToken: refreshToken,
RefreshTokenExpiresAt: refreshExpiresAt,
},
}
return nil
}); err != nil {
return nil, err
}
return &result, nil
}
func (s *Service) Refresh(ctx context.Context, cmd RefreshCommand) (*AuthResult, error) {
authSessionID, err := s.tokenManager.ParseRefreshToken(cmd.RefreshToken)
if err != nil {
return nil, err
}
var result AuthResult
now := s.now().UTC()
if err := s.transactor.WithinTransaction(ctx, func(store Store) error {
session, err := store.AuthSessions().GetByIDForUpdate(ctx, authSessionID)
if err != nil {
return err
}
if session == nil {
return ErrInvalidRefreshToken
}
if session.RevokedAt != nil {
return ErrAuthSessionRevoked
}
if now.After(session.RefreshExpiresAt) {
if revokeErr := store.AuthSessions().Revoke(ctx, RevokeAuthSessionParams{
AuthSessionID: session.ID,
UserID: session.UserID,
Reason: "refresh_token_expired",
RevokedAt: now,
}); revokeErr != nil {
return revokeErr
}
return ErrInvalidRefreshToken
}
expectedHash := s.tokenManager.HashRefreshToken(cmd.RefreshToken)
if expectedHash != session.RefreshTokenHash {
if revokeErr := store.AuthSessions().Revoke(ctx, RevokeAuthSessionParams{
AuthSessionID: session.ID,
UserID: session.UserID,
Reason: "refresh_rotation_reuse_detected",
RevokedAt: now,
}); revokeErr != nil {
return revokeErr
}
return ErrInvalidRefreshToken
}
user, err := store.Users().GetByID(ctx, session.UserID)
if err != nil {
return err
}
if user == nil {
return ErrInvalidCredentials
}
device, err := store.Devices().GetByIDForUser(ctx, session.UserID, session.DeviceID)
if err != nil {
return err
}
if device == nil {
return ErrTrustedDeviceMissing
}
if device.TrustStatus == DeviceTrustStatusRevoked {
return ErrDeviceRevoked
}
refreshToken, refreshHash, refreshExpiresAt, err := s.tokenManager.IssueRefreshToken(session.ID, now)
if err != nil {
return err
}
accessToken, accessExpiresAt, err := s.tokenManager.IssueAccessToken(user.ID, session.ID, device.ID, now)
if err != nil {
return err
}
if err := store.AuthSessions().Rotate(ctx, RotateAuthSessionParams{
AuthSessionID: session.ID,
RefreshTokenHash: refreshHash,
RefreshExpiresAt: refreshExpiresAt,
LastSeenAt: now,
LastRotatedAt: now,
}); err != nil {
return err
}
result = AuthResult{
User: *user,
Device: *device,
AuthSession: AuthSession{
ID: session.ID,
UserID: session.UserID,
DeviceID: session.DeviceID,
RefreshTokenHash: refreshHash,
RefreshExpiresAt: refreshExpiresAt,
LastSeenAt: &now,
LastRotatedAt: &now,
},
Tokens: TokenPair{
AccessToken: accessToken,
AccessTokenExpiresAt: accessExpiresAt,
RefreshToken: refreshToken,
RefreshTokenExpiresAt: refreshExpiresAt,
},
}
return nil
}); err != nil {
return nil, err
}
return &result, nil
}
func (s *Service) InstallationStatus(ctx context.Context) (*InstallationStatus, error) {
record, err := s.store.Installation().GetStatus(ctx)
if err != nil {
return nil, err
}
return s.installationStatusFromRecord(record), nil
}
func (s *Service) BootstrapOwner(ctx context.Context, cmd BootstrapOwnerCommand) (*BootstrapOwnerResult, error) {
email := strings.ToLower(strings.TrimSpace(cmd.Email))
password := strings.TrimSpace(cmd.Password)
if email == "" || !strings.Contains(email, "@") || len(password) < 12 {
return nil, ErrInvalidBootstrapOwner
}
now := s.now().UTC()
role := authority.PlatformRoleAdmin
installID := ""
grantSource := "installation_activation"
rootFingerprint := ""
activationPayload := cmd.ActivationPayload
activationSignature := strings.TrimSpace(cmd.ActivationSignature)
var expiresAt *time.Time
if s.strictAuthority() {
if len(activationPayload) == 0 || activationSignature == "" {
return nil, ErrInstallationActivationRequired
}
activation, err := s.authority.VerifyActivation(activationPayload, activationSignature)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrInvalidInstallationActivation, err)
}
if !strings.EqualFold(activation.OwnerEmail, email) {
return nil, ErrInvalidInstallationActivation
}
role = activation.PlatformRole
installID = activation.InstallID
expiresAt = activation.ExpiresAt
rootFingerprint = s.authority.RootFingerprint()
} else {
if s.authority == nil || !s.authority.AllowInsecureBootstrap() {
return nil, ErrInsecureBootstrapDisabled
}
installID = uuid.NewString()
grantSource = "dev_insecure"
rootFingerprint = "dev-insecure"
devPayload, err := json.Marshal(authority.ActivationPayload{
SchemaVersion: authority.ActivationSchemaVersion,
InstallID: installID,
OwnerEmail: email,
PlatformRole: role,
IssuedAt: now,
Environment: s.cfg.App.Env,
})
if err != nil {
return nil, err
}
activationPayload = json.RawMessage(devPayload)
activationSignature = "dev-insecure"
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("hash bootstrap owner password: %w", err)
}
var user *User
if err := s.transactor.WithinTransaction(ctx, func(store Store) error {
created, err := store.Installation().BootstrapOwner(ctx, BootstrapOwnerParams{
Email: email,
PasswordHash: string(passwordHash),
Role: role,
InstallID: installID,
ProductRootKeyFingerprint: rootFingerprint,
ActivationPayload: activationPayload,
ActivationSignature: activationSignature,
GrantSource: grantSource,
ExpiresAt: expiresAt,
Now: now,
})
if err != nil {
return err
}
user = created
return nil
}); err != nil {
return nil, err
}
status, err := s.InstallationStatus(ctx)
if err != nil {
return nil, err
}
return &BootstrapOwnerResult{
Installation: *status,
User: *user,
PlatformRole: role,
}, nil
}
func (s *Service) RevokeAuthSession(ctx context.Context, cmd RevokeAuthSessionCommand) error {
return s.transactor.WithinTransaction(ctx, func(store Store) error {
session, err := store.AuthSessions().GetByIDForUpdate(ctx, cmd.AuthSessionID)
if err != nil {
return err
}
if session == nil || session.UserID != cmd.UserID {
return ErrAuthSessionNotFound
}
return store.AuthSessions().Revoke(ctx, RevokeAuthSessionParams{
AuthSessionID: cmd.AuthSessionID,
UserID: cmd.UserID,
Reason: cmd.Reason,
RevokedAt: s.now().UTC(),
})
})
}
func (s *Service) RevokeTrustedDevice(ctx context.Context, cmd RevokeDeviceCommand) error {
return s.transactor.WithinTransaction(ctx, func(store Store) error {
device, err := store.Devices().GetByIDForUser(ctx, cmd.UserID, cmd.DeviceID)
if err != nil {
return err
}
if device == nil {
return ErrTrustedDeviceMissing
}
if device.TrustStatus != DeviceTrustStatusTrusted {
return ErrDeviceNotTrusted
}
now := s.now().UTC()
if err := store.Devices().Revoke(ctx, RevokeDeviceParams{
UserID: cmd.UserID,
DeviceID: cmd.DeviceID,
Reason: cmd.Reason,
RevokedAt: now,
}); err != nil {
return err
}
return store.AuthSessions().RevokeByDevice(ctx, cmd.UserID, cmd.DeviceID, "device_revoked:"+cmd.Reason, now)
})
}
func (s *Service) ListTrustedDevices(ctx context.Context, userID string) ([]Device, error) {
return s.store.Devices().ListTrustedByUser(ctx, userID)
}
func (s *Service) MapError(err error) (int, string) {
switch {
case err == nil:
return 0, ""
case errors.Is(err, ErrInvalidCredentials):
return 401, "invalid credentials"
case errors.Is(err, ErrInvalidRefreshToken):
return 401, "invalid refresh token"
case errors.Is(err, ErrAuthSessionRevoked):
return 401, "auth session revoked"
case errors.Is(err, ErrDeviceRevoked):
return 403, "device revoked"
case errors.Is(err, ErrDeviceNotTrusted):
return 409, "device is not trusted"
case errors.Is(err, ErrAuthSessionNotFound), errors.Is(err, ErrTrustedDeviceMissing):
return 404, err.Error()
case errors.Is(err, ErrInstallationActivationRequired), errors.Is(err, ErrInvalidInstallationActivation), errors.Is(err, ErrInvalidBootstrapOwner):
return 400, err.Error()
case errors.Is(err, ErrInsecureBootstrapDisabled):
return 403, err.Error()
case errors.Is(err, ErrInstallationAlreadyBootstrapped):
return 409, err.Error()
default:
return 500, fmt.Sprintf("internal error: %v", err)
}
}
func (s *Service) installationStatusFromRecord(record *InstallationAuthorityState) *InstallationStatus {
if record == nil {
record = &InstallationAuthorityState{AuthorityState: "unbootstrapped"}
}
mode := authority.ModeLegacy
strict := false
rootFingerprint := ""
insecureAllowed := false
if s.authority != nil {
mode = s.authority.Mode()
strict = s.authority.Strict()
rootFingerprint = s.authority.RootFingerprint()
insecureAllowed = s.authority.AllowInsecureBootstrap()
}
if record.ProductRootFingerprint != "" {
rootFingerprint = record.ProductRootFingerprint
}
return &InstallationStatus{
Bootstrapped: record.Bootstrapped,
AuthorityState: record.AuthorityState,
InstallID: record.InstallID,
BootstrappedOwnerEmail: record.BootstrappedOwnerEmail,
BootstrappedAt: record.BootstrappedAt,
AuthorityMode: mode,
StrictAuthority: strict,
RootFingerprint: rootFingerprint,
InsecureBootstrapAllowed: insecureAllowed,
}
}
func (s *Service) strictAuthority() bool {
return s.authority != nil && s.authority.Strict()
}
+95
View File
@@ -0,0 +1,95 @@
package auth
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
)
type TokenManager struct {
issuer string
accessSecret []byte
refreshHashSecret []byte
accessTTL time.Duration
refreshTTL time.Duration
}
type AccessClaims struct {
AuthSessionID string `json:"sid"`
DeviceID string `json:"did"`
jwt.RegisteredClaims
}
func NewTokenManager(cfg TokenConfig) *TokenManager {
return &TokenManager{
issuer: cfg.Issuer,
accessSecret: []byte(cfg.AccessTokenSecret),
refreshHashSecret: []byte(cfg.RefreshHashSecret),
accessTTL: cfg.AccessTokenTTL,
refreshTTL: cfg.RefreshTokenTTL,
}
}
type TokenConfig struct {
Issuer string
AccessTokenSecret string
RefreshHashSecret string
AccessTokenTTL time.Duration
RefreshTokenTTL time.Duration
}
func (m *TokenManager) IssueAccessToken(userID, authSessionID, deviceID string, now time.Time) (string, time.Time, error) {
expiresAt := now.Add(m.accessTTL)
claims := AccessClaims{
AuthSessionID: authSessionID,
DeviceID: deviceID,
RegisteredClaims: jwt.RegisteredClaims{
Issuer: m.issuer,
Subject: userID,
ExpiresAt: jwt.NewNumericDate(expiresAt),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signed, err := token.SignedString(m.accessSecret)
if err != nil {
return "", time.Time{}, fmt.Errorf("sign access token: %w", err)
}
return signed, expiresAt, nil
}
func (m *TokenManager) IssueRefreshToken(authSessionID string, now time.Time) (raw string, hash string, expiresAt time.Time, err error) {
secret := make([]byte, 32)
if _, err = rand.Read(secret); err != nil {
return "", "", time.Time{}, fmt.Errorf("read random refresh secret: %w", err)
}
encodedSecret := base64.RawURLEncoding.EncodeToString(secret)
raw = authSessionID + "." + encodedSecret
hash = m.HashRefreshToken(raw)
expiresAt = now.Add(m.refreshTTL)
return raw, hash, expiresAt, nil
}
func (m *TokenManager) HashRefreshToken(token string) string {
mac := hmac.New(sha256.New, m.refreshHashSecret)
_, _ = mac.Write([]byte(token))
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}
func (m *TokenManager) ParseRefreshToken(token string) (string, error) {
sessionID, _, ok := strings.Cut(token, ".")
if !ok || sessionID == "" {
return "", ErrInvalidRefreshToken
}
return sessionID, nil
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,34 @@
package cluster
import (
"encoding/json"
"testing"
)
func TestMeshLatestObservationKeySeparatesRouteHealthByRoute(t *testing.T) {
key := meshLatestObservationKey(json.RawMessage(`{
"observation_type":"synthetic_route_health",
"route_id":"route-1"
}`))
if key != "synthetic_route_health:route-1" {
t.Fatalf("key = %q", key)
}
}
func TestMeshLatestObservationKeySeparatesConnectionManagerMode(t *testing.T) {
key := meshLatestObservationKey(json.RawMessage(`{
"observation_type":"peer_connection_manager",
"transport_mode":"relay_control",
"relay_node_id":"node-r"
}`))
if key != "peer_connection_manager:relay_control:node-r" {
t.Fatalf("key = %q", key)
}
}
func TestMeshLatestObservationKeyDefaults(t *testing.T) {
key := meshLatestObservationKey(json.RawMessage(`{}`))
if key != "default" {
t.Fatalf("key = %q", key)
}
}
@@ -0,0 +1,91 @@
package cluster
import (
"context"
"encoding/json"
"time"
)
type Repository interface {
GetPlatformRole(ctx context.Context, userID string) (string, error)
ListClusters(ctx context.Context) ([]Cluster, error)
GetCluster(ctx context.Context, clusterID string) (Cluster, error)
CreateCluster(ctx context.Context, input CreateClusterInput) (Cluster, error)
UpdateCluster(ctx context.Context, input UpdateClusterInput) (Cluster, error)
GetClusterAuthority(ctx context.Context, clusterID string) (ClusterAuthorityKey, error)
EnsureClusterAuthority(ctx context.Context, clusterID string, actorUserID *string) (ClusterAuthorityKey, error)
ListClusterNodes(ctx context.Context, clusterID string) ([]ClusterNode, error)
ListNodeGroups(ctx context.Context, clusterID string) ([]ClusterNodeGroup, error)
CreateNodeGroup(ctx context.Context, input CreateNodeGroupInput) (ClusterNodeGroup, error)
AssignNodeToGroup(ctx context.Context, input AssignNodeGroupInput) (ClusterNode, error)
CreateJoinToken(ctx context.Context, input CreateJoinTokenInput, tokenHash string) (NodeJoinToken, error)
SetJoinTokenAuthority(ctx context.Context, clusterID, tokenID string, payload json.RawMessage, signature ClusterSignature) (NodeJoinToken, error)
GetValidJoinTokenByHash(ctx context.Context, clusterID, tokenHash string) (NodeJoinToken, error)
RevokeJoinToken(ctx context.Context, input RevokeJoinTokenInput) (NodeJoinToken, error)
ExpireJoinTokens(ctx context.Context, clusterID string) error
CreateJoinRequest(ctx context.Context, input CreateJoinRequestInput, joinTokenID string) (NodeJoinRequest, error)
GetJoinRequestForBootstrap(ctx context.Context, input GetJoinRequestBootstrapInput) (NodeJoinRequest, error)
ListJoinRequests(ctx context.Context, clusterID string) ([]NodeJoinRequest, error)
ApproveJoinRequest(ctx context.Context, input ApproveJoinRequestInput) (ApprovedJoinRequest, error)
SetJoinRequestApprovalAuthority(ctx context.Context, clusterID, joinRequestID string, payload json.RawMessage, signature ClusterSignature) (NodeJoinRequest, error)
RejectJoinRequest(ctx context.Context, input RejectJoinRequestInput) (NodeJoinRequest, error)
AssignNodeRole(ctx context.Context, input AssignNodeRoleInput) (NodeRoleAssignment, error)
ListNodeRoleAssignments(ctx context.Context, clusterID, nodeID string) ([]NodeRoleAssignment, error)
AttachExistingNodeToCluster(ctx context.Context, input AttachExistingNodeInput) (ClusterNode, error)
RecordHeartbeat(ctx context.Context, input RecordHeartbeatInput) (NodeHeartbeat, error)
ListNodeHeartbeats(ctx context.Context, clusterID, nodeID string, limit int) ([]NodeHeartbeat, error)
RevokeNodeIdentity(ctx context.Context, input RevokeNodeIdentityInput) error
DisableClusterMembership(ctx context.Context, input DisableMembershipInput) error
UpsertFabricTestingFlag(ctx context.Context, input UpsertFabricTestingFlagInput) (FabricTestingFlag, error)
ListFabricTestingFlags(ctx context.Context) ([]FabricTestingFlag, error)
GetEffectiveNodeTestingFlags(ctx context.Context, clusterID, nodeID string) (EffectiveNodeTestingFlags, error)
RecordNodeTelemetry(ctx context.Context, input RecordNodeTelemetryInput) (NodeTelemetryObservation, error)
ListNodeTelemetry(ctx context.Context, clusterID, nodeID string, limit int) ([]NodeTelemetryObservation, error)
SetDesiredWorkload(ctx context.Context, input SetDesiredWorkloadInput) (NodeWorkloadDesiredState, error)
ListDesiredWorkloads(ctx context.Context, clusterID, nodeID string) ([]NodeWorkloadDesiredState, error)
ReportWorkloadStatus(ctx context.Context, input ReportWorkloadStatusInput) (NodeWorkloadStatus, error)
ListLatestWorkloadStatuses(ctx context.Context, clusterID, nodeID string) ([]NodeWorkloadStatus, error)
ReportMeshLink(ctx context.Context, input ReportMeshLinkInput) (MeshLinkObservation, error)
ListMeshLinks(ctx context.Context, clusterID string) ([]MeshLinkObservation, error)
CreateRouteIntent(ctx context.Context, input CreateRouteIntentInput) (MeshRouteIntent, error)
ListRouteIntents(ctx context.Context, clusterID string) ([]MeshRouteIntent, error)
ListQoSPolicies(ctx context.Context, clusterID string) ([]MeshQoSPolicy, error)
ListFabricEntryPoints(ctx context.Context, clusterID string) ([]FabricEntryPoint, error)
CreateFabricEntryPoint(ctx context.Context, input CreateFabricEntryPointInput) (FabricEntryPoint, error)
SetFabricEntryPointNode(ctx context.Context, input SetFabricEntryPointNodeInput) (FabricEntryPointNode, error)
ListFabricEntryPointNodes(ctx context.Context, clusterID, entryPointID string) ([]FabricEntryPointNode, error)
ListFabricEgressPools(ctx context.Context, clusterID string) ([]FabricEgressPool, error)
CreateFabricEgressPool(ctx context.Context, input CreateFabricEgressPoolInput) (FabricEgressPool, error)
SetFabricEgressPoolNode(ctx context.Context, input SetFabricEgressPoolNodeInput) (FabricEgressPoolNode, error)
ListFabricEgressPoolNodes(ctx context.Context, clusterID, egressPoolID string) ([]FabricEgressPoolNode, error)
GetClusterAuthorityState(ctx context.Context, clusterID string) (ClusterAuthorityState, error)
UpdateClusterAuthorityState(ctx context.Context, input UpdateClusterAuthorityInput) (ClusterAuthorityState, error)
ListClusterAdminSummaries(ctx context.Context) ([]ClusterAdminSummary, error)
CreateVPNConnection(ctx context.Context, input CreateVPNConnectionInput) (VPNConnection, error)
ListVPNConnections(ctx context.Context, clusterID string) ([]VPNConnection, error)
GetVPNConnection(ctx context.Context, clusterID, vpnConnectionID string) (VPNConnection, error)
UpdateVPNConnectionDesiredState(ctx context.Context, input UpdateVPNConnectionDesiredStateInput) (VPNConnection, error)
UpsertVPNConnectionRoutePolicy(ctx context.Context, input UpsertVPNConnectionRoutePolicyInput) (VPNConnectionRoutePolicy, error)
ListVPNConnectionRoutePolicies(ctx context.Context, clusterID, vpnConnectionID string) ([]VPNConnectionRoutePolicy, error)
SetVPNConnectionAllowedNodes(ctx context.Context, input SetVPNConnectionAllowedNodesInput) ([]VPNConnectionAllowedNode, error)
ListVPNConnectionAllowedNodes(ctx context.Context, clusterID, vpnConnectionID string) ([]VPNConnectionAllowedNode, error)
AcquireVPNConnectionLease(ctx context.Context, input AcquireVPNConnectionLeaseInput, expiresAt time.Time, fencingToken string) (VPNConnectionLease, error)
RenewVPNConnectionLease(ctx context.Context, input RenewVPNConnectionLeaseInput, expiresAt time.Time) (VPNConnectionLease, error)
ReleaseVPNConnectionLease(ctx context.Context, input ReleaseVPNConnectionLeaseInput) (VPNConnectionLease, error)
FenceVPNConnectionLease(ctx context.Context, input FenceVPNConnectionLeaseInput) (VPNConnectionLease, error)
GetActiveVPNConnectionLease(ctx context.Context, clusterID, vpnConnectionID string) (VPNConnectionLease, error)
CheckVPNLeaseOwnerEligibility(ctx context.Context, clusterID, vpnConnectionID, ownerNodeID string) (VPNLeaseOwnerEligibility, error)
ExpireStaleVPNConnectionLeases(ctx context.Context, clusterID string, now time.Time) ([]VPNConnectionLease, error)
ListNodeVPNAssignments(ctx context.Context, clusterID, nodeID string) ([]NodeVPNAssignment, error)
ReportNodeVPNAssignmentStatus(ctx context.Context, input ReportNodeVPNAssignmentStatusInput) (NodeVPNAssignmentStatus, error)
RecordAudit(ctx context.Context, event ClusterAuditEvent) error
ListAuditEvents(ctx context.Context, clusterID string, limit int) ([]ClusterAuditEvent, error)
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+43
View File
@@ -0,0 +1,43 @@
package cluster
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"strings"
"time"
)
const joinTokenHashPrefix = "sha256:"
func generateJoinToken() (string, error) {
var random [32]byte
if _, err := rand.Read(random[:]); err != nil {
return "", err
}
return "rap_join_" + base64.RawURLEncoding.EncodeToString(random[:]), nil
}
func hashJoinToken(token string) (string, error) {
trimmed := strings.TrimSpace(token)
if trimmed == "" {
return "", errors.New("join token is required")
}
sum := sha256.Sum256([]byte(trimmed))
return joinTokenHashPrefix + hex.EncodeToString(sum[:]), nil
}
func isPlatformAdminRole(role string) bool {
return role == PlatformRoleAdmin || role == PlatformRoleRecoveryAdmin
}
func isAllowedNodeRole(role string) bool {
_, ok := allowedNodeRoles[role]
return ok
}
func defaultJoinTokenExpiry(now time.Time) time.Time {
return now.Add(30 * time.Minute)
}
@@ -0,0 +1,344 @@
package identitysource
import (
"context"
"encoding/json"
"errors"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
"github.com/example/remote-access-platform/backend/internal/platform/module"
)
type Module struct {
db *pgxpool.Pool
}
type IdentitySource struct {
ID string `json:"id"`
OrganizationID string `json:"organization_id"`
Kind string `json:"kind"`
Name string `json:"name"`
Status string `json:"status"`
Config json.RawMessage `json:"config"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type IdentityMapping struct {
ID string `json:"id"`
IdentitySourceID string `json:"identity_source_id"`
MappingType string `json:"mapping_type"`
ExternalSelector json.RawMessage `json:"external_selector"`
InternalTarget json.RawMessage `json:"internal_target"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type upsertIdentitySourceRequest struct {
ActorUserID string `json:"actor_user_id"`
OrganizationID string `json:"organization_id"`
Kind string `json:"kind"`
Name string `json:"name"`
Status string `json:"status"`
Config json.RawMessage `json:"config"`
IdentityMappings []struct {
MappingType string `json:"mapping_type"`
ExternalSelector json.RawMessage `json:"external_selector"`
InternalTarget json.RawMessage `json:"internal_target"`
} `json:"identity_mappings"`
}
func NewModule(deps module.Dependencies) *Module {
return &Module{db: deps.Infra.DB}
}
func (m *Module) Name() string {
return "identitysource"
}
func (m *Module) RegisterRoutes(router chi.Router) {
router.Route("/identity-sources", func(r chi.Router) {
r.Get("/", m.listIdentitySources)
r.Post("/", m.createIdentitySource)
r.Get("/{identitySourceID}", m.getIdentitySource)
r.Put("/{identitySourceID}", m.updateIdentitySource)
})
}
func (m *Module) listIdentitySources(w http.ResponseWriter, r *http.Request) {
orgID := r.URL.Query().Get("organization_id")
if orgID == "" {
httpx.WriteError(w, http.StatusBadRequest, "organization_id is required")
return
}
rows, err := m.db.Query(r.Context(), `
SELECT id, organization_id, kind, name, status, config, created_at, updated_at
FROM identity_sources
WHERE organization_id = $1
ORDER BY created_at DESC
`, orgID)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer rows.Close()
var items []IdentitySource
for rows.Next() {
item, err := scanIdentitySource(rows)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
items = append(items, item)
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"identity_sources": items})
}
func (m *Module) getIdentitySource(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "identitySourceID")
item, err := m.getByID(r.Context(), id)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
httpx.WriteError(w, http.StatusNotFound, "identity source not found")
return
}
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
mappings, err := m.listMappings(r.Context(), id)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{
"identity_source": item,
"identity_mappings": mappings,
})
}
func (m *Module) createIdentitySource(w http.ResponseWriter, r *http.Request) {
req, err := decodeRequest(r)
if err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
now := time.Now().UTC()
item := IdentitySource{
ID: uuid.NewString(),
OrganizationID: req.OrganizationID,
Kind: req.Kind,
Name: req.Name,
Status: req.Status,
Config: req.Config,
CreatedAt: now,
UpdatedAt: now,
}
tx, err := m.db.Begin(r.Context())
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer tx.Rollback(r.Context())
if _, err := tx.Exec(r.Context(), `
INSERT INTO identity_sources (id, organization_id, kind, name, status, config, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6::jsonb, $7, $8)
`, item.ID, item.OrganizationID, item.Kind, item.Name, item.Status, []byte(item.Config), item.CreatedAt, item.UpdatedAt); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
mappings, err := upsertMappings(r.Context(), tx, item.ID, req.IdentityMappings)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := tx.Commit(r.Context()); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusCreated, map[string]any{
"identity_source": item,
"identity_mappings": mappings,
})
}
func (m *Module) updateIdentitySource(w http.ResponseWriter, r *http.Request) {
req, err := decodeRequest(r)
if err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
id := chi.URLParam(r, "identitySourceID")
tx, err := m.db.Begin(r.Context())
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer tx.Rollback(r.Context())
tag, err := tx.Exec(r.Context(), `
UPDATE identity_sources
SET organization_id = $2, kind = $3, name = $4, status = $5, config = $6::jsonb, updated_at = $7
WHERE id = $1
`, id, req.OrganizationID, req.Kind, req.Name, req.Status, []byte(req.Config), time.Now().UTC())
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if tag.RowsAffected() == 0 {
httpx.WriteError(w, http.StatusNotFound, "identity source not found")
return
}
if _, err := tx.Exec(r.Context(), `DELETE FROM identity_mappings WHERE identity_source_id = $1`, id); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
mappings, err := upsertMappings(r.Context(), tx, id, req.IdentityMappings)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := tx.Commit(r.Context()); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
item, err := m.getByID(r.Context(), id)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{
"identity_source": item,
"identity_mappings": mappings,
})
}
func decodeRequest(r *http.Request) (*upsertIdentitySourceRequest, error) {
var req upsertIdentitySourceRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.New("invalid identity source payload")
}
if req.ActorUserID == "" || req.OrganizationID == "" || req.Kind == "" || req.Name == "" {
return nil, errors.New("actor_user_id, organization_id, kind, and name are required")
}
if req.Status == "" {
req.Status = "active"
}
if len(req.Config) == 0 {
req.Config = json.RawMessage(`{}`)
}
if !json.Valid(req.Config) {
return nil, errors.New("config must be valid json")
}
for _, mapping := range req.IdentityMappings {
if len(mapping.ExternalSelector) == 0 {
mapping.ExternalSelector = json.RawMessage(`{}`)
}
if len(mapping.InternalTarget) == 0 {
mapping.InternalTarget = json.RawMessage(`{}`)
}
}
return &req, nil
}
func (m *Module) getByID(ctx context.Context, id string) (IdentitySource, error) {
row := m.db.QueryRow(ctx, `
SELECT id, organization_id, kind, name, status, config, created_at, updated_at
FROM identity_sources
WHERE id = $1
`, id)
return scanIdentitySource(row)
}
func (m *Module) listMappings(ctx context.Context, sourceID string) ([]IdentityMapping, error) {
rows, err := m.db.Query(ctx, `
SELECT id, identity_source_id, mapping_type, external_selector, internal_target, created_at, updated_at
FROM identity_mappings
WHERE identity_source_id = $1
ORDER BY created_at ASC
`, sourceID)
if err != nil {
return nil, err
}
defer rows.Close()
var mappings []IdentityMapping
for rows.Next() {
item, err := scanIdentityMapping(rows)
if err != nil {
return nil, err
}
mappings = append(mappings, item)
}
return mappings, rows.Err()
}
func upsertMappings(ctx context.Context, tx pgx.Tx, sourceID string, requested []struct {
MappingType string `json:"mapping_type"`
ExternalSelector json.RawMessage `json:"external_selector"`
InternalTarget json.RawMessage `json:"internal_target"`
}) ([]IdentityMapping, error) {
now := time.Now().UTC()
items := make([]IdentityMapping, 0, len(requested))
for _, mapping := range requested {
external := mapping.ExternalSelector
if len(external) == 0 {
external = json.RawMessage(`{}`)
}
internal := mapping.InternalTarget
if len(internal) == 0 {
internal = json.RawMessage(`{}`)
}
item := IdentityMapping{
ID: uuid.NewString(),
IdentitySourceID: sourceID,
MappingType: mapping.MappingType,
ExternalSelector: external,
InternalTarget: internal,
CreatedAt: now,
UpdatedAt: now,
}
if _, err := tx.Exec(ctx, `
INSERT INTO identity_mappings (
id, identity_source_id, mapping_type, external_selector, internal_target, created_at, updated_at
) VALUES ($1, $2, $3, $4::jsonb, $5::jsonb, $6, $7)
`, item.ID, item.IdentitySourceID, item.MappingType, []byte(item.ExternalSelector), []byte(item.InternalTarget), item.CreatedAt, item.UpdatedAt); err != nil {
return nil, err
}
items = append(items, item)
}
return items, nil
}
type rowScanner interface {
Scan(dest ...any) error
}
func scanIdentitySource(row rowScanner) (IdentitySource, error) {
var item IdentitySource
if err := row.Scan(&item.ID, &item.OrganizationID, &item.Kind, &item.Name, &item.Status, &item.Config, &item.CreatedAt, &item.UpdatedAt); err != nil {
return IdentitySource{}, err
}
if len(item.Config) == 0 {
item.Config = json.RawMessage(`{}`)
}
return item, nil
}
func scanIdentityMapping(row rowScanner) (IdentityMapping, error) {
var item IdentityMapping
if err := row.Scan(&item.ID, &item.IdentitySourceID, &item.MappingType, &item.ExternalSelector, &item.InternalTarget, &item.CreatedAt, &item.UpdatedAt); err != nil {
return IdentityMapping{}, err
}
if len(item.ExternalSelector) == 0 {
item.ExternalSelector = json.RawMessage(`{}`)
}
if len(item.InternalTarget) == 0 {
item.InternalTarget = json.RawMessage(`{}`)
}
return item, nil
}
+458
View File
@@ -0,0 +1,458 @@
package node
import (
"context"
"encoding/json"
"errors"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
"github.com/example/remote-access-platform/backend/internal/platform/module"
)
type Module struct {
db *pgxpool.Pool
}
type Node struct {
ID string `json:"id"`
OwnerOrganizationID *string `json:"owner_organization_id,omitempty"`
NodeKey string `json:"node_key"`
Name string `json:"name"`
OwnershipType string `json:"ownership_type"`
RegistrationStatus string `json:"registration_status"`
HealthStatus string `json:"health_status"`
VersionState string `json:"version_state"`
PartitionState string `json:"partition_state"`
DesiredVersion *string `json:"desired_version,omitempty"`
ReportedVersion *string `json:"reported_version,omitempty"`
LastSeenAt *time.Time `json:"last_seen_at,omitempty"`
Metadata json.RawMessage `json:"metadata"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type NodeCapability struct {
NodeID string `json:"node_id"`
Capability string `json:"capability"`
Value json.RawMessage `json:"value"`
UpdatedAt time.Time `json:"updated_at"`
}
type NodeService struct {
NodeID string `json:"node_id"`
ServiceType string `json:"service_type"`
Enabled bool `json:"enabled"`
DesiredState string `json:"desired_state"`
ReportedState string `json:"reported_state"`
LastReportedAt *time.Time `json:"last_reported_at,omitempty"`
Metadata json.RawMessage `json:"metadata"`
UpdatedAt time.Time `json:"updated_at"`
}
type NodeUpdatePolicy struct {
NodeID string `json:"node_id"`
Mode string `json:"mode"`
Channel string `json:"channel"`
MaintenanceWindow json.RawMessage `json:"maintenance_window"`
Canary bool `json:"canary"`
AutomaticRollout bool `json:"automatic_rollout"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type NodePartitionState struct {
NodeID string `json:"node_id"`
ClusterState string `json:"cluster_state"`
RecoveryMode string `json:"recovery_mode"`
Notes *string `json:"notes,omitempty"`
UpdatedAt time.Time `json:"updated_at"`
}
type upsertNodeRequest struct {
ActorUserID string `json:"actor_user_id"`
OwnerOrganizationID *string `json:"owner_organization_id"`
NodeKey string `json:"node_key"`
Name string `json:"name"`
OwnershipType string `json:"ownership_type"`
DesiredVersion *string `json:"desired_version"`
Metadata json.RawMessage `json:"metadata"`
Capabilities []struct {
Capability string `json:"capability"`
Value json.RawMessage `json:"value"`
} `json:"capabilities"`
Services []struct {
ServiceType string `json:"service_type"`
Enabled bool `json:"enabled"`
DesiredState string `json:"desired_state"`
Metadata json.RawMessage `json:"metadata"`
} `json:"services"`
UpdatePolicy struct {
Mode string `json:"mode"`
Channel string `json:"channel"`
MaintenanceWindow json.RawMessage `json:"maintenance_window"`
Canary bool `json:"canary"`
AutomaticRollout bool `json:"automatic_rollout"`
} `json:"update_policy"`
PartitionState struct {
ClusterState string `json:"cluster_state"`
RecoveryMode string `json:"recovery_mode"`
Notes *string `json:"notes"`
} `json:"partition_state"`
}
func NewModule(deps module.Dependencies) *Module {
return &Module{db: deps.Infra.DB}
}
func (m *Module) Name() string {
return "node"
}
func (m *Module) RegisterRoutes(router chi.Router) {
router.Route("/nodes", func(r chi.Router) {
r.Get("/", m.listNodes)
r.Post("/", m.createNode)
r.Get("/{nodeID}", m.getNode)
r.Put("/{nodeID}", m.updateNode)
})
}
func (m *Module) listNodes(w http.ResponseWriter, r *http.Request) {
rows, err := m.db.Query(r.Context(), `
SELECT id, owner_organization_id, node_key, name, ownership_type, registration_status, health_status,
version_state, partition_state, desired_version, reported_version, last_seen_at, metadata, created_at, updated_at
FROM nodes
ORDER BY created_at DESC
`)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer rows.Close()
var items []Node
for rows.Next() {
item, err := scanNode(rows)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
items = append(items, item)
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"nodes": items})
}
func (m *Module) getNode(w http.ResponseWriter, r *http.Request) {
nodeID := chi.URLParam(r, "nodeID")
item, err := m.getNodeByID(r.Context(), nodeID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
httpx.WriteError(w, http.StatusNotFound, "node not found")
return
}
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
caps, _ := m.listCapabilities(r.Context(), nodeID)
services, _ := m.listServices(r.Context(), nodeID)
updatePolicy, _ := m.getUpdatePolicy(r.Context(), nodeID)
partitionState, _ := m.getPartitionState(r.Context(), nodeID)
httpx.WriteJSON(w, http.StatusOK, map[string]any{
"node": item,
"capabilities": caps,
"services": services,
"update_policy": updatePolicy,
"partition_state": partitionState,
})
}
func (m *Module) createNode(w http.ResponseWriter, r *http.Request) {
req, err := decodeNodeRequest(r)
if err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
item := Node{
ID: uuid.NewString(),
OwnerOrganizationID: req.OwnerOrganizationID,
NodeKey: req.NodeKey,
Name: req.Name,
OwnershipType: req.OwnershipType,
RegistrationStatus: "pending",
HealthStatus: "unknown",
VersionState: "unknown",
PartitionState: "healthy",
DesiredVersion: req.DesiredVersion,
Metadata: req.Metadata,
CreatedAt: time.Now().UTC(),
UpdatedAt: time.Now().UTC(),
}
if err := m.persistNode(r.Context(), item, req, true); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusCreated, map[string]any{"node": item})
}
func (m *Module) updateNode(w http.ResponseWriter, r *http.Request) {
req, err := decodeNodeRequest(r)
if err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
nodeID := chi.URLParam(r, "nodeID")
item, err := m.getNodeByID(r.Context(), nodeID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
httpx.WriteError(w, http.StatusNotFound, "node not found")
return
}
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
item.OwnerOrganizationID = req.OwnerOrganizationID
item.NodeKey = req.NodeKey
item.Name = req.Name
item.OwnershipType = req.OwnershipType
item.DesiredVersion = req.DesiredVersion
item.Metadata = req.Metadata
item.UpdatedAt = time.Now().UTC()
if err := m.persistNode(r.Context(), item, req, false); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"node": item})
}
func decodeNodeRequest(r *http.Request) (*upsertNodeRequest, error) {
var req upsertNodeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.New("invalid node payload")
}
if req.ActorUserID == "" || req.NodeKey == "" || req.Name == "" || req.OwnershipType == "" {
return nil, errors.New("actor_user_id, node_key, name, and ownership_type are required")
}
if len(req.Metadata) == 0 {
req.Metadata = json.RawMessage(`{}`)
}
if !json.Valid(req.Metadata) {
return nil, errors.New("metadata must be valid json")
}
if req.UpdatePolicy.Mode == "" {
req.UpdatePolicy.Mode = "manual"
}
if req.UpdatePolicy.Channel == "" {
req.UpdatePolicy.Channel = "stable"
}
if len(req.UpdatePolicy.MaintenanceWindow) == 0 {
req.UpdatePolicy.MaintenanceWindow = json.RawMessage(`{}`)
}
if req.PartitionState.ClusterState == "" {
req.PartitionState.ClusterState = "healthy"
}
if req.PartitionState.RecoveryMode == "" {
req.PartitionState.RecoveryMode = "normal"
}
for i := range req.Capabilities {
if len(req.Capabilities[i].Value) == 0 {
req.Capabilities[i].Value = json.RawMessage(`{}`)
}
}
for i := range req.Services {
if req.Services[i].DesiredState == "" {
req.Services[i].DesiredState = "disabled"
}
if len(req.Services[i].Metadata) == 0 {
req.Services[i].Metadata = json.RawMessage(`{}`)
}
}
return &req, nil
}
func (m *Module) persistNode(ctx context.Context, item Node, req *upsertNodeRequest, create bool) error {
tx, err := m.db.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)
if create {
_, err = tx.Exec(ctx, `
INSERT INTO nodes (
id, owner_organization_id, node_key, name, ownership_type, registration_status, health_status,
version_state, partition_state, desired_version, reported_version, last_seen_at, metadata, created_at, updated_at
) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13::jsonb,$14,$15)
`, item.ID, item.OwnerOrganizationID, item.NodeKey, item.Name, item.OwnershipType, item.RegistrationStatus, item.HealthStatus, item.VersionState, item.PartitionState, item.DesiredVersion, item.ReportedVersion, item.LastSeenAt, []byte(item.Metadata), item.CreatedAt, item.UpdatedAt)
} else {
_, err = tx.Exec(ctx, `
UPDATE nodes
SET owner_organization_id=$2, node_key=$3, name=$4, ownership_type=$5, desired_version=$6, metadata=$7::jsonb, updated_at=$8
WHERE id=$1
`, item.ID, item.OwnerOrganizationID, item.NodeKey, item.Name, item.OwnershipType, item.DesiredVersion, []byte(item.Metadata), item.UpdatedAt)
}
if err != nil {
return err
}
if _, err := tx.Exec(ctx, `DELETE FROM node_capabilities WHERE node_id = $1`, item.ID); err != nil {
return err
}
for _, capability := range req.Capabilities {
if _, err := tx.Exec(ctx, `
INSERT INTO node_capabilities (node_id, capability, value, updated_at)
VALUES ($1, $2, $3::jsonb, $4)
`, item.ID, capability.Capability, []byte(capability.Value), time.Now().UTC()); err != nil {
return err
}
}
if _, err := tx.Exec(ctx, `DELETE FROM node_services WHERE node_id = $1`, item.ID); err != nil {
return err
}
for _, service := range req.Services {
if _, err := tx.Exec(ctx, `
INSERT INTO node_services (
node_id, service_type, enabled, desired_state, reported_state, last_reported_at, metadata, updated_at
) VALUES ($1, $2, $3, $4, 'unknown', NULL, $5::jsonb, $6)
`, item.ID, service.ServiceType, service.Enabled, service.DesiredState, []byte(service.Metadata), time.Now().UTC()); err != nil {
return err
}
}
if _, err := tx.Exec(ctx, `
INSERT INTO node_update_policies (
node_id, mode, channel, maintenance_window, canary, automatic_rollout, created_at, updated_at
) VALUES ($1,$2,$3,$4::jsonb,$5,$6,$7,$8)
ON CONFLICT (node_id) DO UPDATE SET
mode = EXCLUDED.mode,
channel = EXCLUDED.channel,
maintenance_window = EXCLUDED.maintenance_window,
canary = EXCLUDED.canary,
automatic_rollout = EXCLUDED.automatic_rollout,
updated_at = EXCLUDED.updated_at
`, item.ID, req.UpdatePolicy.Mode, req.UpdatePolicy.Channel, []byte(req.UpdatePolicy.MaintenanceWindow), req.UpdatePolicy.Canary, req.UpdatePolicy.AutomaticRollout, time.Now().UTC(), time.Now().UTC()); err != nil {
return err
}
if _, err := tx.Exec(ctx, `
INSERT INTO node_partition_states (node_id, cluster_state, recovery_mode, notes, updated_at)
VALUES ($1,$2,$3,$4,$5)
ON CONFLICT (node_id) DO UPDATE SET
cluster_state = EXCLUDED.cluster_state,
recovery_mode = EXCLUDED.recovery_mode,
notes = EXCLUDED.notes,
updated_at = EXCLUDED.updated_at
`, item.ID, req.PartitionState.ClusterState, req.PartitionState.RecoveryMode, req.PartitionState.Notes, time.Now().UTC()); err != nil {
return err
}
return tx.Commit(ctx)
}
func (m *Module) getNodeByID(ctx context.Context, nodeID string) (Node, error) {
row := m.db.QueryRow(ctx, `
SELECT id, owner_organization_id, node_key, name, ownership_type, registration_status, health_status,
version_state, partition_state, desired_version, reported_version, last_seen_at, metadata, created_at, updated_at
FROM nodes
WHERE id = $1
`, nodeID)
return scanNode(row)
}
func (m *Module) listCapabilities(ctx context.Context, nodeID string) ([]NodeCapability, error) {
rows, err := m.db.Query(ctx, `SELECT node_id, capability, value, updated_at FROM node_capabilities WHERE node_id = $1 ORDER BY capability`, nodeID)
if err != nil {
return nil, err
}
defer rows.Close()
var out []NodeCapability
for rows.Next() {
var item NodeCapability
if err := rows.Scan(&item.NodeID, &item.Capability, &item.Value, &item.UpdatedAt); err != nil {
return nil, err
}
out = append(out, item)
}
return out, rows.Err()
}
func (m *Module) listServices(ctx context.Context, nodeID string) ([]NodeService, error) {
rows, err := m.db.Query(ctx, `
SELECT node_id, service_type, enabled, desired_state, reported_state, last_reported_at, metadata, updated_at
FROM node_services WHERE node_id = $1 ORDER BY service_type
`, nodeID)
if err != nil {
return nil, err
}
defer rows.Close()
var out []NodeService
for rows.Next() {
var item NodeService
if err := rows.Scan(&item.NodeID, &item.ServiceType, &item.Enabled, &item.DesiredState, &item.ReportedState, &item.LastReportedAt, &item.Metadata, &item.UpdatedAt); err != nil {
return nil, err
}
out = append(out, item)
}
return out, rows.Err()
}
func (m *Module) getUpdatePolicy(ctx context.Context, nodeID string) (*NodeUpdatePolicy, error) {
row := m.db.QueryRow(ctx, `
SELECT node_id, mode, channel, maintenance_window, canary, automatic_rollout, created_at, updated_at
FROM node_update_policies WHERE node_id = $1
`, nodeID)
var item NodeUpdatePolicy
if err := row.Scan(&item.NodeID, &item.Mode, &item.Channel, &item.MaintenanceWindow, &item.Canary, &item.AutomaticRollout, &item.CreatedAt, &item.UpdatedAt); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return nil, err
}
return &item, nil
}
func (m *Module) getPartitionState(ctx context.Context, nodeID string) (*NodePartitionState, error) {
row := m.db.QueryRow(ctx, `
SELECT node_id, cluster_state, recovery_mode, notes, updated_at
FROM node_partition_states WHERE node_id = $1
`, nodeID)
var item NodePartitionState
if err := row.Scan(&item.NodeID, &item.ClusterState, &item.RecoveryMode, &item.Notes, &item.UpdatedAt); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return nil, err
}
return &item, nil
}
type rowScanner interface {
Scan(dest ...any) error
}
func scanNode(row rowScanner) (Node, error) {
var item Node
if err := row.Scan(
&item.ID,
&item.OwnerOrganizationID,
&item.NodeKey,
&item.Name,
&item.OwnershipType,
&item.RegistrationStatus,
&item.HealthStatus,
&item.VersionState,
&item.PartitionState,
&item.DesiredVersion,
&item.ReportedVersion,
&item.LastSeenAt,
&item.Metadata,
&item.CreatedAt,
&item.UpdatedAt,
); err != nil {
return Node{}, err
}
if len(item.Metadata) == 0 {
item.Metadata = json.RawMessage(`{}`)
}
return item, nil
}
@@ -0,0 +1,356 @@
package nodeagent
import (
"encoding/json"
"errors"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
clustermodule "github.com/example/remote-access-platform/backend/internal/modules/cluster"
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
"github.com/example/remote-access-platform/backend/internal/platform/module"
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
)
type Module struct {
db *pgxpool.Pool
cluster *clustermodule.Service
}
func NewModule(deps module.Dependencies) *Module {
clusterStore := clustermodule.NewPostgresStore(deps.Infra.DB)
if deps.Config.Secret.EncryptionKeyBase64 != "" {
if encryptor, err := secrets.NewEncryptor(deps.Config.Secret.EncryptionKeyBase64, deps.Config.Secret.EncryptionKeyID); err == nil {
clusterStore.WithClusterKeyEncryptor(encryptor)
}
}
return &Module{
db: deps.Infra.DB,
cluster: clustermodule.NewService(clusterStore),
}
}
func (m *Module) Name() string {
return "nodeagent"
}
func (m *Module) RegisterRoutes(router chi.Router) {
router.Route("/node-agents", func(r chi.Router) {
r.Post("/enroll", m.enrollAgent)
r.Post("/enrollments/{requestID}/bootstrap", m.bootstrapEnrollment)
r.Post("/register", m.registerAgent)
r.Post("/{nodeID}/health", m.reportHealth)
r.Post("/{nodeID}/services/status", m.reportServiceStatus)
r.Post("/{nodeID}/update-manifest/request", m.requestUpdateManifest)
r.Post("/{nodeID}/update-result", m.acknowledgeUpdateResult)
r.Post("/{nodeID}/rollback-result", m.reportRollbackResult)
r.Get("/{nodeID}/clusters/{clusterID}/vpn-assignments/desired", m.listVPNAssignments)
r.Post("/{nodeID}/clusters/{clusterID}/vpn-assignments/{vpnConnectionID}/status", m.reportVPNAssignmentStatus)
})
}
func (m *Module) enrollAgent(w http.ResponseWriter, r *http.Request) {
var payload struct {
ClusterID string `json:"cluster_id"`
JoinToken string `json:"join_token"`
NodeName string `json:"node_name"`
NodeFingerprint string `json:"node_fingerprint"`
PublicKey string `json:"public_key"`
ReportedCapabilities json.RawMessage `json:"reported_capabilities"`
ReportedFacts json.RawMessage `json:"reported_facts"`
RequestedRoles json.RawMessage `json:"requested_roles"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid agent enrollment payload")
return
}
joinRequest, err := m.cluster.CreateJoinRequest(r.Context(), clustermodule.CreateJoinRequestInput{
ClusterID: payload.ClusterID,
JoinToken: payload.JoinToken,
NodeName: payload.NodeName,
NodeFingerprint: payload.NodeFingerprint,
PublicKey: payload.PublicKey,
ReportedCapabilities: payload.ReportedCapabilities,
ReportedFacts: payload.ReportedFacts,
RequestedRoles: payload.RequestedRoles,
})
if err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
"status": "pending_approval",
"join_request": joinRequest,
})
}
func (m *Module) bootstrapEnrollment(w http.ResponseWriter, r *http.Request) {
var payload struct {
ClusterID string `json:"cluster_id"`
NodeFingerprint string `json:"node_fingerprint"`
PublicKey string `json:"public_key"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid enrollment bootstrap payload")
return
}
result, err := m.cluster.GetJoinRequestBootstrap(r.Context(), clustermodule.GetJoinRequestBootstrapInput{
ClusterID: payload.ClusterID,
JoinRequestID: chi.URLParam(r, "requestID"),
NodeFingerprint: payload.NodeFingerprint,
PublicKey: payload.PublicKey,
})
if err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, result)
}
func (m *Module) registerAgent(w http.ResponseWriter, r *http.Request) {
var payload struct {
NodeKey string `json:"node_key"`
Name string `json:"name"`
OwnershipType string `json:"ownership_type"`
OwnerOrganizationID *string `json:"owner_organization_id"`
ReportedVersion *string `json:"reported_version"`
Metadata json.RawMessage `json:"metadata"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid agent registration payload")
return
}
if payload.NodeKey == "" || payload.Name == "" || payload.OwnershipType == "" {
httpx.WriteError(w, http.StatusBadRequest, "node_key, name, and ownership_type are required")
return
}
if len(payload.Metadata) == 0 {
payload.Metadata = json.RawMessage(`{}`)
}
now := time.Now().UTC()
nodeID := uuid.NewString()
if err := m.db.QueryRow(r.Context(), `
INSERT INTO nodes (
id, owner_organization_id, node_key, name, ownership_type, registration_status, health_status,
version_state, partition_state, desired_version, reported_version, last_seen_at, metadata, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, 'active', 'unknown', 'unknown', 'healthy', NULL, $6, $7, $8::jsonb, $9, $10)
ON CONFLICT (node_key) DO UPDATE SET
name = EXCLUDED.name,
ownership_type = EXCLUDED.ownership_type,
owner_organization_id = EXCLUDED.owner_organization_id,
registration_status = 'active',
reported_version = EXCLUDED.reported_version,
last_seen_at = EXCLUDED.last_seen_at,
metadata = EXCLUDED.metadata,
updated_at = EXCLUDED.updated_at
RETURNING id
`, nodeID, payload.OwnerOrganizationID, payload.NodeKey, payload.Name, payload.OwnershipType, payload.ReportedVersion, now, []byte(payload.Metadata), now, now).Scan(&nodeID); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{
"node_id": nodeID,
"status": "registered",
"legacy": true,
"warning": "direct node-agent registration is retained for compatibility; production enrollment must use /node-agents/enroll",
})
}
func (m *Module) reportHealth(w http.ResponseWriter, r *http.Request) {
var payload struct {
HealthStatus string `json:"health_status"`
ReportedVersion *string `json:"reported_version"`
Metadata json.RawMessage `json:"metadata"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid node health payload")
return
}
if payload.HealthStatus == "" {
payload.HealthStatus = "unknown"
}
if len(payload.Metadata) == 0 {
payload.Metadata = json.RawMessage(`{}`)
}
if _, err := m.db.Exec(r.Context(), `
UPDATE nodes
SET health_status = $2, reported_version = COALESCE($3, reported_version), last_seen_at = $4, metadata = $5::jsonb, updated_at = $4
WHERE id = $1
`, chi.URLParam(r, "nodeID"), payload.HealthStatus, payload.ReportedVersion, time.Now().UTC(), []byte(payload.Metadata)); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{"status": "accepted"})
}
func (m *Module) reportServiceStatus(w http.ResponseWriter, r *http.Request) {
var payload struct {
Services []struct {
ServiceType string `json:"service_type"`
ReportedState string `json:"reported_state"`
Metadata json.RawMessage `json:"metadata"`
} `json:"services"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid node service status payload")
return
}
now := time.Now().UTC()
for _, service := range payload.Services {
if len(service.Metadata) == 0 {
service.Metadata = json.RawMessage(`{}`)
}
if _, err := m.db.Exec(r.Context(), `
INSERT INTO node_services (
node_id, service_type, enabled, desired_state, reported_state, last_reported_at, metadata, updated_at
) VALUES ($1, $2, FALSE, 'disabled', $3, $4, $5::jsonb, $4)
ON CONFLICT (node_id, service_type) DO UPDATE SET
reported_state = EXCLUDED.reported_state,
last_reported_at = EXCLUDED.last_reported_at,
metadata = EXCLUDED.metadata,
updated_at = EXCLUDED.updated_at
`, chi.URLParam(r, "nodeID"), service.ServiceType, service.ReportedState, now, []byte(service.Metadata)); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
}
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{"status": "accepted"})
}
func (m *Module) listVPNAssignments(w http.ResponseWriter, r *http.Request) {
items, err := m.cluster.ListNodeVPNAssignments(r.Context(), chi.URLParam(r, "clusterID"), chi.URLParam(r, "nodeID"))
if writeClusterServiceError(w, err) {
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{
"vpn_assignments": items,
"runtime_execution_enabled": false,
})
}
func (m *Module) reportVPNAssignmentStatus(w http.ResponseWriter, r *http.Request) {
var payload struct {
ObservedStatus string `json:"observed_status"`
StatusPayload json.RawMessage `json:"status_payload"`
ObservedAt *time.Time `json:"observed_at"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid vpn assignment status payload")
return
}
observedAt := time.Time{}
if payload.ObservedAt != nil {
observedAt = *payload.ObservedAt
}
item, err := m.cluster.ReportNodeVPNAssignmentStatus(r.Context(), clustermodule.ReportNodeVPNAssignmentStatusInput{
ClusterID: chi.URLParam(r, "clusterID"),
NodeID: chi.URLParam(r, "nodeID"),
VPNConnectionID: chi.URLParam(r, "vpnConnectionID"),
ObservedStatus: payload.ObservedStatus,
StatusPayload: payload.StatusPayload,
ObservedAt: observedAt,
})
if writeClusterServiceError(w, err) {
return
}
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
"vpn_assignment_status": item,
"runtime_execution_enabled": false,
})
}
func (m *Module) requestUpdateManifest(w http.ResponseWriter, r *http.Request) {
nodeID := chi.URLParam(r, "nodeID")
var mode, channel string
var canary, automatic bool
var desiredVersion *string
if err := m.db.QueryRow(r.Context(), `
SELECT n.desired_version, p.mode, p.channel, p.canary, p.automatic_rollout
FROM nodes n
LEFT JOIN node_update_policies p ON p.node_id = n.id
WHERE n.id = $1
`, nodeID).Scan(&desiredVersion, &mode, &channel, &canary, &automatic); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{
"manifest": map[string]any{
"node_id": nodeID,
"desired_version": desiredVersion,
"mode": mode,
"channel": channel,
"canary": canary,
"automatic_rollout": automatic,
},
})
}
func (m *Module) acknowledgeUpdateResult(w http.ResponseWriter, r *http.Request) {
m.recordUpdateRun(w, r, "update")
}
func (m *Module) reportRollbackResult(w http.ResponseWriter, r *http.Request) {
m.recordUpdateRun(w, r, "rollback")
}
func (m *Module) recordUpdateRun(w http.ResponseWriter, r *http.Request, action string) {
var payload struct {
TargetVersion string `json:"target_version"`
Status string `json:"status"`
Payload json.RawMessage `json:"payload"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid update result payload")
return
}
if payload.Status == "" {
payload.Status = "acknowledged"
}
if len(payload.Payload) == 0 {
payload.Payload = json.RawMessage(`{}`)
}
now := time.Now().UTC()
runID := uuid.NewString()
if _, err := m.db.Exec(r.Context(), `
INSERT INTO node_agent_update_runs (
id, node_id, action, target_version, status, requested_at, acknowledged_at, completed_at, payload
) VALUES ($1, $2, $3, $4, $5, $6, $6, CASE WHEN $5 IN ('succeeded', 'failed') THEN $6 ELSE NULL END, $7::jsonb)
`, runID, chi.URLParam(r, "nodeID"), action, payload.TargetVersion, payload.Status, now, []byte(payload.Payload)); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if action == "update" && payload.Status == "succeeded" {
_, _ = m.db.Exec(r.Context(), `
UPDATE nodes
SET reported_version = $2, version_state = 'current', updated_at = $3
WHERE id = $1
`, chi.URLParam(r, "nodeID"), payload.TargetVersion, now)
}
if action == "rollback" && payload.Status == "succeeded" {
_, _ = m.db.Exec(r.Context(), `
UPDATE nodes
SET reported_version = $2, version_state = 'rollback', updated_at = $3
WHERE id = $1
`, chi.URLParam(r, "nodeID"), payload.TargetVersion, now)
}
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{"status": "accepted", "run_id": runID})
}
func writeClusterServiceError(w http.ResponseWriter, err error) bool {
if err == nil {
return false
}
switch {
case errors.Is(err, clustermodule.ErrVPNLeaseOwnerNotAllowed), errors.Is(err, clustermodule.ErrVPNLeaseOwnerRoleRequired):
httpx.WriteError(w, http.StatusForbidden, err.Error())
case errors.Is(err, clustermodule.ErrInvalidPayload):
httpx.WriteError(w, http.StatusBadRequest, err.Error())
default:
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
}
return true
}
@@ -0,0 +1,21 @@
package organization
import "testing"
func TestTenantSafeTopologyExposureDoesNotExposeCoreMesh(t *testing.T) {
value := tenantSafeTopologyExposure()
forbidden := []string{
"core_node_id",
"mesh_route",
"cluster_private_topology",
"certificate_serial",
}
for _, token := range forbidden {
if value == token {
t.Fatalf("topology exposure leaked forbidden token %q", token)
}
}
if value != "tenant_safe_no_core_mesh_topology" {
t.Fatalf("unexpected topology exposure marker: %q", value)
}
}
@@ -0,0 +1,518 @@
package organization
import (
"context"
"encoding/json"
"errors"
"net/http"
"strings"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/example/remote-access-platform/backend/internal/platform/authority"
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
"github.com/example/remote-access-platform/backend/internal/platform/module"
)
const (
RoleOrgOwner = "org_owner"
RoleOrgAdmin = "org_admin"
RoleOrgOperator = "org_operator"
RoleOrgMember = "org_member"
RoleOrgViewer = "org_viewer"
)
type Module struct {
db *pgxpool.Pool
authority *authority.Verifier
}
type Organization struct {
ID string `json:"id"`
Slug string `json:"slug"`
Name string `json:"name"`
Status string `json:"status"`
Metadata json.RawMessage `json:"metadata"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type Membership struct {
ID string `json:"id"`
OrganizationID string `json:"organization_id"`
UserID string `json:"user_id"`
RoleID string `json:"role_id"`
Status string `json:"status"`
InvitedByUser *string `json:"invited_by_user_id,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type AdminSummary struct {
OrganizationID string `json:"organization_id"`
ResourceCount int64 `json:"resource_count"`
ActiveSessionCount int64 `json:"active_session_count"`
ServiceEndpoints []ServiceSummary `json:"service_endpoints"`
ConnectorStatus map[string]any `json:"connector_status"`
RecentAudit []OrgAuditEvent `json:"recent_audit"`
TopologyExposure string `json:"topology_exposure"`
}
type ServiceSummary struct {
Protocol string `json:"protocol"`
Count int64 `json:"count"`
}
type OrgAuditEvent struct {
ID string `json:"id"`
EventType string `json:"event_type"`
TargetType string `json:"target_type"`
TargetID string `json:"target_id"`
Payload json.RawMessage `json:"payload"`
CreatedAt time.Time `json:"created_at"`
}
type createOrganizationRequest struct {
ActorUserID string `json:"actor_user_id"`
Slug string `json:"slug"`
Name string `json:"name"`
Metadata json.RawMessage `json:"metadata"`
}
type addMembershipRequest struct {
ActorUserID string `json:"actor_user_id"`
UserID string `json:"user_id"`
RoleID string `json:"role_id"`
}
func NewModule(deps module.Dependencies) *Module {
authorityVerifier, _ := authority.NewVerifier(deps.Config.Installation)
return &Module{db: deps.Infra.DB, authority: authorityVerifier}
}
func (m *Module) Name() string {
return "organization"
}
func (m *Module) RegisterRoutes(router chi.Router) {
router.Route("/organizations", func(r chi.Router) {
r.Get("/", m.listOrganizations)
r.Post("/", m.createOrganization)
r.Get("/{organizationID}", m.getOrganization)
r.Get("/{organizationID}/admin-summary", m.getAdminSummary)
r.Get("/{organizationID}/memberships", m.listMemberships)
r.Post("/{organizationID}/memberships", m.addMembership)
})
}
func (m *Module) listOrganizations(w http.ResponseWriter, r *http.Request) {
userID := r.URL.Query().Get("user_id")
if userID == "" {
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
return
}
platformRole, err := m.getPlatformRole(r.Context(), userID)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
var rows pgx.Rows
if isPlatformAdmin(platformRole) {
rows, err = m.db.Query(r.Context(), `
SELECT id, slug, name, status, metadata, created_at, updated_at
FROM organizations
ORDER BY created_at DESC
`)
} else {
rows, err = m.db.Query(r.Context(), `
SELECT o.id, o.slug, o.name, o.status, o.metadata, o.created_at, o.updated_at
FROM organizations o
INNER JOIN organization_memberships om ON om.organization_id = o.id
WHERE om.user_id = $1 AND om.status = 'active'
ORDER BY o.created_at DESC
`, userID)
}
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer rows.Close()
var organizations []Organization
for rows.Next() {
org, err := scanOrganization(rows)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
organizations = append(organizations, org)
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"organizations": organizations})
}
func (m *Module) getOrganization(w http.ResponseWriter, r *http.Request) {
orgID := chi.URLParam(r, "organizationID")
userID := r.URL.Query().Get("user_id")
if userID == "" {
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
return
}
if err := m.ensureOrgAccess(r.Context(), orgID, userID, false); err != nil {
status := http.StatusInternalServerError
if errors.Is(err, pgx.ErrNoRows) || errors.Is(err, errForbidden) {
status = http.StatusForbidden
}
httpx.WriteError(w, status, err.Error())
return
}
org, err := m.getOrganizationByID(r.Context(), orgID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
httpx.WriteError(w, http.StatusNotFound, "organization not found")
return
}
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"organization": org})
}
func (m *Module) getAdminSummary(w http.ResponseWriter, r *http.Request) {
orgID := chi.URLParam(r, "organizationID")
actorUserID := r.URL.Query().Get("actor_user_id")
if actorUserID == "" {
httpx.WriteError(w, http.StatusBadRequest, "actor_user_id is required")
return
}
if err := m.ensureOrgAccess(r.Context(), orgID, actorUserID, true); err != nil {
httpx.WriteError(w, http.StatusForbidden, err.Error())
return
}
summary, err := m.loadAdminSummary(r.Context(), orgID)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"admin_summary": summary})
}
func (m *Module) loadAdminSummary(ctx context.Context, orgID string) (AdminSummary, error) {
var resourceCount int64
if err := m.db.QueryRow(ctx, `
SELECT COUNT(*)
FROM resources
WHERE organization_id = $1::uuid
`, orgID).Scan(&resourceCount); err != nil {
return AdminSummary{}, err
}
var activeSessionCount int64
if err := m.db.QueryRow(ctx, `
SELECT COUNT(*)
FROM remote_sessions
WHERE organization_id = $1::uuid
AND state = 'active'
`, orgID).Scan(&activeSessionCount); err != nil {
return AdminSummary{}, err
}
rows, err := m.db.Query(ctx, `
SELECT protocol, COUNT(*)
FROM resources
WHERE organization_id = $1::uuid
GROUP BY protocol
ORDER BY protocol
`, orgID)
if err != nil {
return AdminSummary{}, err
}
defer rows.Close()
var services []ServiceSummary
for rows.Next() {
var item ServiceSummary
if err := rows.Scan(&item.Protocol, &item.Count); err != nil {
return AdminSummary{}, err
}
services = append(services, item)
}
if err := rows.Err(); err != nil {
return AdminSummary{}, err
}
auditRows, err := m.db.Query(ctx, `
SELECT ae.id::text, ae.event_type, ae.target_type, ae.target_id, ae.payload, ae.created_at
FROM audit_events ae
LEFT JOIN remote_sessions rs ON rs.id = ae.remote_session_id
WHERE rs.organization_id = $1::uuid
ORDER BY ae.created_at DESC
LIMIT 20
`, orgID)
if err != nil {
return AdminSummary{}, err
}
defer auditRows.Close()
var audit []OrgAuditEvent
for auditRows.Next() {
var item OrgAuditEvent
if err := auditRows.Scan(&item.ID, &item.EventType, &item.TargetType, &item.TargetID, &item.Payload, &item.CreatedAt); err != nil {
return AdminSummary{}, err
}
audit = append(audit, item)
}
if err := auditRows.Err(); err != nil {
return AdminSummary{}, err
}
return AdminSummary{
OrganizationID: orgID,
ResourceCount: resourceCount,
ActiveSessionCount: activeSessionCount,
ServiceEndpoints: services,
ConnectorStatus: map[string]any{
"vpn": "not_implemented",
"connector": "not_implemented",
},
RecentAudit: audit,
TopologyExposure: tenantSafeTopologyExposure(),
}, nil
}
func tenantSafeTopologyExposure() string {
return "tenant_safe_no_core_mesh_topology"
}
func (m *Module) createOrganization(w http.ResponseWriter, r *http.Request) {
var req createOrganizationRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid organization payload")
return
}
if req.ActorUserID == "" || req.Name == "" || req.Slug == "" {
httpx.WriteError(w, http.StatusBadRequest, "actor_user_id, slug, and name are required")
return
}
role, err := m.getPlatformRole(r.Context(), req.ActorUserID)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if !isPlatformAdmin(role) {
httpx.WriteError(w, http.StatusForbidden, "platform admin role is required")
return
}
if len(req.Metadata) == 0 {
req.Metadata = json.RawMessage(`{}`)
}
if !json.Valid(req.Metadata) {
httpx.WriteError(w, http.StatusBadRequest, "metadata must be valid json")
return
}
now := time.Now().UTC()
org := Organization{
ID: uuid.NewString(),
Slug: normalizeSlug(req.Slug),
Name: req.Name,
Status: "active",
Metadata: req.Metadata,
CreatedAt: now,
UpdatedAt: now,
}
membership := Membership{
ID: uuid.NewString(),
OrganizationID: org.ID,
UserID: req.ActorUserID,
RoleID: RoleOrgOwner,
Status: "active",
CreatedAt: now,
UpdatedAt: now,
}
tx, err := m.db.Begin(r.Context())
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer tx.Rollback(r.Context())
if _, err := tx.Exec(r.Context(), `
INSERT INTO organizations (id, slug, name, status, metadata, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5::jsonb, $6, $7)
`, org.ID, org.Slug, org.Name, org.Status, []byte(org.Metadata), org.CreatedAt, org.UpdatedAt); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if _, err := tx.Exec(r.Context(), `
INSERT INTO organization_memberships (
id, organization_id, user_id, role_id, status, invited_by_user_id, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
`, membership.ID, membership.OrganizationID, membership.UserID, membership.RoleID, membership.Status, req.ActorUserID, membership.CreatedAt, membership.UpdatedAt); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := tx.Commit(r.Context()); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusCreated, map[string]any{
"organization": org,
"membership": membership,
})
}
func (m *Module) listMemberships(w http.ResponseWriter, r *http.Request) {
orgID := chi.URLParam(r, "organizationID")
userID := r.URL.Query().Get("user_id")
if userID == "" {
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
return
}
if err := m.ensureOrgAccess(r.Context(), orgID, userID, true); err != nil {
httpx.WriteError(w, http.StatusForbidden, err.Error())
return
}
rows, err := m.db.Query(r.Context(), `
SELECT id, organization_id, user_id, role_id, status, invited_by_user_id, created_at, updated_at
FROM organization_memberships
WHERE organization_id = $1
ORDER BY created_at DESC
`, orgID)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer rows.Close()
var memberships []Membership
for rows.Next() {
membership, err := scanMembership(rows)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
memberships = append(memberships, membership)
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"memberships": memberships})
}
func (m *Module) addMembership(w http.ResponseWriter, r *http.Request) {
orgID := chi.URLParam(r, "organizationID")
var req addMembershipRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid membership payload")
return
}
if req.ActorUserID == "" || req.UserID == "" || req.RoleID == "" {
httpx.WriteError(w, http.StatusBadRequest, "actor_user_id, user_id, and role_id are required")
return
}
if err := m.ensureOrgAccess(r.Context(), orgID, req.ActorUserID, true); err != nil {
httpx.WriteError(w, http.StatusForbidden, err.Error())
return
}
now := time.Now().UTC()
membership := Membership{
ID: uuid.NewString(),
OrganizationID: orgID,
UserID: req.UserID,
RoleID: req.RoleID,
Status: "active",
InvitedByUser: &req.ActorUserID,
CreatedAt: now,
UpdatedAt: now,
}
if _, err := m.db.Exec(r.Context(), `
INSERT INTO organization_memberships (
id, organization_id, user_id, role_id, status, invited_by_user_id, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (organization_id, user_id) DO UPDATE SET
role_id = EXCLUDED.role_id,
status = 'active',
invited_by_user_id = EXCLUDED.invited_by_user_id,
updated_at = EXCLUDED.updated_at
`, membership.ID, membership.OrganizationID, membership.UserID, membership.RoleID, membership.Status, membership.InvitedByUser, membership.CreatedAt, membership.UpdatedAt); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusCreated, map[string]any{"membership": membership})
}
var errForbidden = errors.New("forbidden")
func (m *Module) ensureOrgAccess(ctx context.Context, orgID, userID string, adminRequired bool) error {
role, err := m.getPlatformRole(ctx, userID)
if err != nil {
return err
}
if isPlatformAdmin(role) {
return nil
}
query := `
SELECT role_id
FROM organization_memberships
WHERE organization_id = $1 AND user_id = $2 AND status = 'active'
`
var roleID string
if err := m.db.QueryRow(ctx, query, orgID, userID).Scan(&roleID); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return errForbidden
}
return err
}
if adminRequired && roleID != RoleOrgOwner && roleID != RoleOrgAdmin {
return errForbidden
}
return nil
}
func (m *Module) getPlatformRole(ctx context.Context, userID string) (string, error) {
return authority.EffectivePlatformRole(ctx, m.db, m.authority, userID)
}
func isPlatformAdmin(role string) bool {
return role == "platform_admin" || role == "platform_recovery_admin"
}
func (m *Module) getOrganizationByID(ctx context.Context, orgID string) (Organization, error) {
row := m.db.QueryRow(ctx, `
SELECT id, slug, name, status, metadata, created_at, updated_at
FROM organizations
WHERE id = $1
`, orgID)
return scanOrganization(row)
}
type rowScanner interface {
Scan(dest ...any) error
}
func scanOrganization(row rowScanner) (Organization, error) {
var org Organization
if err := row.Scan(&org.ID, &org.Slug, &org.Name, &org.Status, &org.Metadata, &org.CreatedAt, &org.UpdatedAt); err != nil {
return Organization{}, err
}
if len(org.Metadata) == 0 {
org.Metadata = json.RawMessage(`{}`)
}
return org, nil
}
func scanMembership(row rowScanner) (Membership, error) {
var membership Membership
if err := row.Scan(
&membership.ID,
&membership.OrganizationID,
&membership.UserID,
&membership.RoleID,
&membership.Status,
&membership.InvitedByUser,
&membership.CreatedAt,
&membership.UpdatedAt,
); err != nil {
return Membership{}, err
}
return membership, nil
}
func normalizeSlug(in string) string {
return strings.ToLower(strings.TrimSpace(in))
}
+639
View File
@@ -0,0 +1,639 @@
package resource
import (
"context"
"encoding/json"
"errors"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/example/remote-access-platform/backend/internal/platform/authority"
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
"github.com/example/remote-access-platform/backend/internal/platform/module"
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
)
const (
CertificateVerificationModeStrict = "strict"
CertificateVerificationModeIgnore = "ignore"
RenderQualityProfileLowBandwidth = "low_bandwidth"
RenderQualityProfileBalanced = "balanced"
RenderQualityProfileHighQuality = "high_quality"
RenderQualityProfileTextPriority = "text_priority"
ClipboardModeDisabled = "disabled"
ClipboardModeClientToServer = "client_to_server"
ClipboardModeServerToClient = "server_to_client"
ClipboardModeBidirectional = "bidirectional"
FileTransferModeDisabled = "disabled"
FileTransferModeClientToServer = "client_to_server"
FileTransferModeServerToClient = "server_to_client"
FileTransferModeBidirectional = "bidirectional"
)
type Module struct {
db *pgxpool.Pool
appEnv string
secretStore *secrets.ResourceSecretStore
authority *authority.Verifier
}
type Resource struct {
ID string `json:"id"`
OrganizationID string `json:"organization_id"`
Name string `json:"name"`
Address string `json:"address"`
Protocol string `json:"protocol"`
SecretRef *string `json:"secret_ref,omitempty"`
CertificateVerificationMode string `json:"certificate_verification_mode"`
RenderQualityProfile string `json:"render_quality_profile"`
ClipboardMode string `json:"clipboard_mode"`
FileTransferMode string `json:"file_transfer_mode"`
Metadata json.RawMessage `json:"metadata"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type upsertResourceRequest struct {
ActorUserID string `json:"actor_user_id"`
OrganizationID string `json:"organization_id"`
Name string `json:"name"`
Address string `json:"address"`
Protocol string `json:"protocol"`
SecretRef *string `json:"secret_ref"`
CertificateVerificationMode string `json:"certificate_verification_mode"`
RenderQualityProfile string `json:"render_quality_profile"`
ClipboardMode string `json:"clipboard_mode"`
FileTransferMode string `json:"file_transfer_mode"`
Metadata json.RawMessage `json:"metadata"`
}
type upsertResourceSecretRequest struct {
ActorUserID string `json:"actor_user_id"`
Payload json.RawMessage `json:"payload"`
Metadata json.RawMessage `json:"metadata"`
}
func NewModule(deps module.Dependencies, secretStores ...*secrets.ResourceSecretStore) *Module {
var secretStore *secrets.ResourceSecretStore
if len(secretStores) > 0 {
secretStore = secretStores[0]
}
authorityVerifier, _ := authority.NewVerifier(deps.Config.Installation)
return &Module{db: deps.Infra.DB, appEnv: deps.Config.App.Env, secretStore: secretStore, authority: authorityVerifier}
}
func (m *Module) Name() string {
return "resource"
}
func (m *Module) RegisterRoutes(router chi.Router) {
router.Route("/resources", func(r chi.Router) {
r.Get("/", m.listResources)
r.Post("/", m.createResource)
r.Get("/{resourceID}", m.getResource)
r.Put("/{resourceID}", m.updateResource)
r.Put("/{resourceID}/secret", m.upsertResourceSecret)
})
}
func (m *Module) listResources(w http.ResponseWriter, r *http.Request) {
userID := r.URL.Query().Get("user_id")
orgID := r.URL.Query().Get("organization_id")
if userID == "" {
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
return
}
platformRole, err := m.getPlatformRole(r.Context(), userID)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
query := `
SELECT r.id, r.organization_id, r.name, r.address, r.protocol, r.secret_ref,
r.certificate_verification_mode, r.metadata, r.created_at, r.updated_at,
COALESCE(rp.clipboard_mode, 'disabled') AS clipboard_mode,
COALESCE(rp.file_transfer_mode, 'disabled') AS file_transfer_mode
FROM resources r
LEFT JOIN resource_policies rp ON rp.resource_id = r.id
`
args := make([]any, 0, 2)
if platformRole == "platform_admin" || platformRole == "platform_recovery_admin" {
if orgID != "" {
query += ` WHERE r.organization_id = $1`
args = append(args, orgID)
}
query += ` ORDER BY r.created_at DESC`
} else {
query += `
INNER JOIN organization_memberships om ON om.organization_id = r.organization_id
WHERE om.user_id = $1 AND om.status = 'active'
`
args = append(args, userID)
if orgID != "" {
query += ` AND r.organization_id = $2`
args = append(args, orgID)
}
query += ` ORDER BY r.created_at DESC`
}
rows, err := m.db.Query(r.Context(), query, args...)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer rows.Close()
resources := make([]Resource, 0)
for rows.Next() {
resource, err := scanResource(rows)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
resources = append(resources, resource)
}
if err := rows.Err(); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"resources": resources})
}
func (m *Module) getResource(w http.ResponseWriter, r *http.Request) {
userID := r.URL.Query().Get("user_id")
if userID == "" {
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
return
}
resource, err := m.getByID(r.Context(), chi.URLParam(r, "resourceID"))
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
httpx.WriteError(w, http.StatusNotFound, "resource not found")
return
}
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := m.ensureResourceAccess(r.Context(), resource.OrganizationID, userID, false); err != nil {
httpx.WriteError(w, http.StatusForbidden, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"resource": resource})
}
func (m *Module) createResource(w http.ResponseWriter, r *http.Request) {
req, err := decodeUpsertRequest(r)
if err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
if err := secrets.ValidateResourceSecretReadiness(req.Protocol, req.SecretRef, req.Metadata, m.appEnv); err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
now := time.Now().UTC()
resource := Resource{
ID: uuid.NewString(),
OrganizationID: req.OrganizationID,
Name: req.Name,
Address: req.Address,
Protocol: req.Protocol,
SecretRef: req.SecretRef,
CertificateVerificationMode: req.CertificateVerificationMode,
RenderQualityProfile: req.RenderQualityProfile,
ClipboardMode: req.ClipboardMode,
FileTransferMode: req.FileTransferMode,
Metadata: req.Metadata,
CreatedAt: now,
UpdatedAt: now,
}
if err := m.ensureResourceAccess(r.Context(), req.OrganizationID, req.ActorUserID, true); err != nil {
httpx.WriteError(w, http.StatusForbidden, err.Error())
return
}
tx, err := m.db.Begin(r.Context())
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer tx.Rollback(r.Context())
if _, err := tx.Exec(r.Context(), `
INSERT INTO resources (
id, organization_id, name, address, protocol, secret_ref, certificate_verification_mode, metadata, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9, $10)
`, resource.ID, resource.OrganizationID, resource.Name, resource.Address, resource.Protocol, resource.SecretRef, resource.CertificateVerificationMode, []byte(resource.Metadata), resource.CreatedAt, resource.UpdatedAt); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := upsertResourcePolicy(r.Context(), tx, resource.ID, resource.ClipboardMode, resource.FileTransferMode, now); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := tx.Commit(r.Context()); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusCreated, map[string]any{"resource": resource})
}
func (m *Module) updateResource(w http.ResponseWriter, r *http.Request) {
req, err := decodeUpsertRequest(r)
if err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
if err := secrets.ValidateResourceSecretReadiness(req.Protocol, req.SecretRef, req.Metadata, m.appEnv); err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
resourceID := chi.URLParam(r, "resourceID")
existing, err := m.getByID(r.Context(), resourceID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
httpx.WriteError(w, http.StatusNotFound, "resource not found")
return
}
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := m.ensureResourceAccess(r.Context(), existing.OrganizationID, req.ActorUserID, true); err != nil {
httpx.WriteError(w, http.StatusForbidden, err.Error())
return
}
if req.OrganizationID != existing.OrganizationID {
if err := m.ensureResourceAccess(r.Context(), req.OrganizationID, req.ActorUserID, true); err != nil {
httpx.WriteError(w, http.StatusForbidden, err.Error())
return
}
}
now := time.Now().UTC()
tx, err := m.db.Begin(r.Context())
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer tx.Rollback(r.Context())
tag, err := tx.Exec(r.Context(), `
UPDATE resources
SET
organization_id = $2,
name = $3,
address = $4,
protocol = $5,
secret_ref = $6,
certificate_verification_mode = $7,
metadata = $8::jsonb,
updated_at = $9
WHERE id = $1
`, resourceID, req.OrganizationID, req.Name, req.Address, req.Protocol, req.SecretRef, req.CertificateVerificationMode, []byte(req.Metadata), now)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if tag.RowsAffected() == 0 {
httpx.WriteError(w, http.StatusNotFound, "resource not found")
return
}
if err := upsertResourcePolicy(r.Context(), tx, resourceID, req.ClipboardMode, req.FileTransferMode, now); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := tx.Commit(r.Context()); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
resource, err := m.getByID(r.Context(), resourceID)
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"resource": resource})
}
func (m *Module) upsertResourceSecret(w http.ResponseWriter, r *http.Request) {
if m.secretStore == nil {
httpx.WriteError(w, http.StatusServiceUnavailable, "resource secret encryption is not configured")
return
}
var req upsertResourceSecretRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid resource secret payload")
return
}
if req.ActorUserID == "" {
httpx.WriteError(w, http.StatusBadRequest, "actor_user_id is required")
return
}
resourceID := chi.URLParam(r, "resourceID")
resource, err := m.getByID(r.Context(), resourceID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
httpx.WriteError(w, http.StatusNotFound, "resource not found")
return
}
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := m.ensureResourceAccess(r.Context(), resource.OrganizationID, req.ActorUserID, true); err != nil {
httpx.WriteError(w, http.StatusForbidden, err.Error())
return
}
tx, err := m.db.Begin(r.Context())
if err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
defer tx.Rollback(r.Context())
secretStore := m.secretStore.WithDB(tx)
secretRef := secrets.DefaultResourceSecretRef(resource.OrganizationID, resource.ID)
descriptor, err := secretStore.Upsert(r.Context(), secrets.UpsertResourceSecretCommand{
OrganizationID: resource.OrganizationID,
ResourceID: resource.ID,
Protocol: resource.Protocol,
SecretRef: secretRef,
Payload: req.Payload,
Metadata: req.Metadata,
ActorUserID: req.ActorUserID,
})
if err != nil {
httpx.WriteError(w, http.StatusBadRequest, err.Error())
return
}
if _, err := tx.Exec(r.Context(), `
UPDATE resources
SET secret_ref = $2, updated_at = $3
WHERE id = $1::uuid
`, resource.ID, descriptor.SecretRef, time.Now().UTC()); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := writeAuditEvent(r.Context(), tx, "resource_secret_rotated", req.ActorUserID, "resource_secret", descriptor.SecretRef, map[string]any{
"resource_id": resource.ID,
"organization_id": resource.OrganizationID,
"protocol": resource.Protocol,
"version": descriptor.Version,
"secret_ref": descriptor.SecretRef,
}); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
if err := tx.Commit(r.Context()); err != nil {
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"secret": descriptor})
}
func decodeUpsertRequest(r *http.Request) (*upsertResourceRequest, error) {
var req upsertResourceRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.New("invalid resource payload")
}
if req.Name == "" {
return nil, errors.New("name is required")
}
if req.ActorUserID == "" {
return nil, errors.New("actor_user_id is required")
}
if req.OrganizationID == "" {
return nil, errors.New("organization_id is required")
}
if req.Address == "" {
return nil, errors.New("address is required")
}
if req.Protocol == "" {
req.Protocol = "rdp"
}
mode, err := normalizeCertificateVerificationMode(req.CertificateVerificationMode)
if err != nil {
return nil, err
}
req.CertificateVerificationMode = mode
renderQualityProfile, err := normalizeRenderQualityProfile(req.RenderQualityProfile)
if err != nil {
return nil, err
}
req.RenderQualityProfile = renderQualityProfile
clipboardMode, err := normalizeClipboardMode(req.ClipboardMode)
if err != nil {
return nil, err
}
req.ClipboardMode = clipboardMode
fileTransferMode, err := normalizeFileTransferMode(req.FileTransferMode)
if err != nil {
return nil, err
}
req.FileTransferMode = fileTransferMode
metadata, err := normalizeMetadata(req.Metadata, req.CertificateVerificationMode, req.RenderQualityProfile)
if err != nil {
return nil, err
}
req.Metadata = metadata
return &req, nil
}
func normalizeCertificateVerificationMode(mode string) (string, error) {
switch mode {
case "", CertificateVerificationModeStrict:
return CertificateVerificationModeStrict, nil
case CertificateVerificationModeIgnore:
return CertificateVerificationModeIgnore, nil
default:
return "", errors.New("certificate_verification_mode must be one of: strict, ignore")
}
}
func normalizeClipboardMode(mode string) (string, error) {
switch mode {
case "", ClipboardModeDisabled:
return ClipboardModeDisabled, nil
case ClipboardModeClientToServer, ClipboardModeServerToClient, ClipboardModeBidirectional:
return mode, nil
default:
return "", errors.New("clipboard_mode must be one of: disabled, client_to_server, server_to_client, bidirectional")
}
}
func normalizeFileTransferMode(mode string) (string, error) {
switch mode {
case "", FileTransferModeDisabled:
return FileTransferModeDisabled, nil
case FileTransferModeClientToServer, FileTransferModeServerToClient, FileTransferModeBidirectional:
return mode, nil
default:
return "", errors.New("file_transfer_mode must be one of: disabled, client_to_server, server_to_client, bidirectional")
}
}
func normalizeMetadata(raw json.RawMessage, certificateVerificationMode, renderQualityProfile string) (json.RawMessage, error) {
if len(raw) == 0 {
raw = json.RawMessage(`{}`)
}
if !json.Valid(raw) {
return nil, errors.New("metadata must be valid json")
}
var metadata map[string]any
if err := json.Unmarshal(raw, &metadata); err != nil {
return nil, errors.New("metadata must be a json object")
}
metadata["certificate_verification_mode"] = certificateVerificationMode
metadata["render_quality_profile"] = renderQualityProfile
encoded, err := json.Marshal(metadata)
if err != nil {
return nil, err
}
return json.RawMessage(encoded), nil
}
func (m *Module) getByID(ctx context.Context, resourceID string) (Resource, error) {
row := m.db.QueryRow(ctx, `
SELECT r.id, r.organization_id, r.name, r.address, r.protocol, r.secret_ref,
r.certificate_verification_mode, r.metadata, r.created_at, r.updated_at,
COALESCE(rp.clipboard_mode, 'disabled') AS clipboard_mode,
COALESCE(rp.file_transfer_mode, 'disabled') AS file_transfer_mode
FROM resources r
LEFT JOIN resource_policies rp ON rp.resource_id = r.id
WHERE r.id = $1
`, resourceID)
return scanResource(row)
}
func (m *Module) ensureResourceAccess(ctx context.Context, orgID, userID string, adminRequired bool) error {
role, err := m.getPlatformRole(ctx, userID)
if err != nil {
return err
}
if role == "platform_admin" || role == "platform_recovery_admin" {
return nil
}
var membershipRole string
if err := m.db.QueryRow(ctx, `
SELECT role_id
FROM organization_memberships
WHERE organization_id = $1 AND user_id = $2 AND status = 'active'
`, orgID, userID).Scan(&membershipRole); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return errors.New("forbidden")
}
return err
}
if adminRequired && membershipRole != "org_owner" && membershipRole != "org_admin" {
return errors.New("forbidden")
}
return nil
}
func (m *Module) getPlatformRole(ctx context.Context, userID string) (string, error) {
return authority.EffectivePlatformRole(ctx, m.db, m.authority, userID)
}
type rowScanner interface {
Scan(dest ...any) error
}
func scanResource(row rowScanner) (Resource, error) {
var resource Resource
if err := row.Scan(
&resource.ID,
&resource.OrganizationID,
&resource.Name,
&resource.Address,
&resource.Protocol,
&resource.SecretRef,
&resource.CertificateVerificationMode,
&resource.Metadata,
&resource.CreatedAt,
&resource.UpdatedAt,
&resource.ClipboardMode,
&resource.FileTransferMode,
); err != nil {
return Resource{}, err
}
if len(resource.Metadata) == 0 {
resource.Metadata = json.RawMessage(`{}`)
}
if resource.CertificateVerificationMode == "" {
resource.CertificateVerificationMode = CertificateVerificationModeStrict
}
if resource.RenderQualityProfile == "" {
resource.RenderQualityProfile = renderQualityProfileFromMetadata(resource.Metadata)
}
if resource.ClipboardMode == "" {
resource.ClipboardMode = ClipboardModeDisabled
}
if resource.FileTransferMode == "" {
resource.FileTransferMode = FileTransferModeDisabled
}
return resource, nil
}
func upsertResourcePolicy(ctx context.Context, tx pgx.Tx, resourceID, clipboardMode, fileTransferMode string, now time.Time) error {
clipboardEnabled := clipboardMode != ClipboardModeDisabled
fileTransferEnabled := fileTransferMode == FileTransferModeClientToServer || fileTransferMode == FileTransferModeBidirectional
_, err := tx.Exec(ctx, `
INSERT INTO resource_policies (
resource_id, clipboard_enabled, clipboard_mode, file_transfer_enabled, file_transfer_mode, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $6)
ON CONFLICT (resource_id) DO UPDATE SET
clipboard_enabled = EXCLUDED.clipboard_enabled,
clipboard_mode = EXCLUDED.clipboard_mode,
file_transfer_enabled = EXCLUDED.file_transfer_enabled,
file_transfer_mode = EXCLUDED.file_transfer_mode,
updated_at = EXCLUDED.updated_at
`, resourceID, clipboardEnabled, clipboardMode, fileTransferEnabled, fileTransferMode, now)
return err
}
func writeAuditEvent(ctx context.Context, tx pgx.Tx, eventType, actorUserID, targetType, targetID string, payload map[string]any) error {
encoded, err := json.Marshal(payload)
if err != nil {
return err
}
_, err = tx.Exec(ctx, `
INSERT INTO audit_events (
id, actor_user_id, event_type, target_type, target_id, payload, created_at
) VALUES (
$1::uuid, NULLIF($2, '')::uuid, $3, $4, $5, $6::jsonb, $7
)
`, uuid.NewString(), actorUserID, eventType, targetType, targetID, encoded, time.Now().UTC())
return err
}
func normalizeRenderQualityProfile(profile string) (string, error) {
switch profile {
case "", RenderQualityProfileBalanced:
return RenderQualityProfileBalanced, nil
case RenderQualityProfileLowBandwidth, RenderQualityProfileHighQuality, RenderQualityProfileTextPriority:
return profile, nil
default:
return "", errors.New("render_quality_profile must be one of: low_bandwidth, balanced, high_quality, text_priority")
}
}
func renderQualityProfileFromMetadata(raw json.RawMessage) string {
if len(raw) == 0 {
return RenderQualityProfileBalanced
}
var metadata map[string]any
if err := json.Unmarshal(raw, &metadata); err != nil {
return RenderQualityProfileBalanced
}
if profile, ok := metadata["render_quality_profile"].(string); ok {
switch profile {
case RenderQualityProfileLowBandwidth, RenderQualityProfileBalanced, RenderQualityProfileHighQuality, RenderQualityProfileTextPriority:
return profile
}
}
return RenderQualityProfileBalanced
}
@@ -0,0 +1,219 @@
package sessionbroker
import (
"crypto/rsa"
"fmt"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
const (
directWorkerTLSTrustModeSmokeInsecure = "smoke_insecure"
directWorkerTLSTrustModePublicCA = "public_ca"
directWorkerTLSTrustModePlatformCA = "platform_ca"
)
type DataPlaneTokenClaims struct {
SessionID string `json:"session_id"`
AttachmentID string `json:"attachment_id"`
UserID string `json:"user_id"`
OrganizationID string `json:"organization_id"`
ClusterID string `json:"cluster_id,omitempty"`
WorkerID string `json:"worker_id"`
ResourceID string `json:"resource_id"`
AllowedChannels []string `json:"allowed_channels"`
ExpiresAtValue time.Time `json:"expires_at"`
jwt.RegisteredClaims
}
func (s *Service) buildDataPlaneOffer(session RemoteSession, attachment SessionAttachment) (*sessioncontracts.DataPlaneOffer, error) {
if s.cfg.DataPlane.TokenTTL <= 0 || s.cfg.DataPlane.TokenPrivateKeyPEM == "" {
return nil, nil
}
now := s.now().UTC()
expiresAt := now.Add(s.cfg.DataPlane.TokenTTL)
allowedChannels := dataPlaneAllowedChannelsFromSession(session)
jti := uuid.NewString()
claims := DataPlaneTokenClaims{
SessionID: session.ID,
AttachmentID: attachment.ID,
UserID: attachment.UserID,
OrganizationID: session.OrganizationID,
WorkerID: session.WorkerID,
ResourceID: session.ResourceID,
AllowedChannels: allowedChannels,
ExpiresAtValue: expiresAt,
RegisteredClaims: jwt.RegisteredClaims{
ID: jti,
Issuer: s.cfg.Auth.Issuer,
Subject: attachment.UserID,
Audience: jwt.ClaimStrings{"rap-data-plane", "worker:" + session.WorkerID},
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(expiresAt),
},
}
token, err := signDataPlaneToken(claims, s.cfg.DataPlane.TokenPrivateKeyPEM)
if err != nil {
return nil, err
}
candidates := s.buildDataPlaneCandidates(session)
preferred := sessioncontracts.DataPlaneCandidateBackendGateway
if len(candidates) > 0 {
preferred = candidates[0].Type
}
return &sessioncontracts.DataPlaneOffer{
Preferred: preferred,
Token: token,
ExpiresAt: expiresAt,
Candidates: candidates,
}, nil
}
func (s *Service) buildDataPlaneCandidates(session RemoteSession) []sessioncontracts.DataPlaneCandidate {
var candidates []sessioncontracts.DataPlaneCandidate
if directURL := s.directWorkerWSSURL(session.WorkerID); directURL != "" && s.canAdvertiseDirectWorkerWSS() {
metadata := map[string]any(nil)
if s.cfg.DataPlane.DirectWorkerJSONRuntime {
metadata = map[string]any{
"runtime_transport": "json_v1",
"traffic_ready": true,
}
s.addDirectWorkerTLSTrustMetadata(metadata)
if s.cfg.DataPlane.DirectWorkerBinaryRender {
metadata["render_transport"] = "binary_v1"
metadata["binary_render"] = true
metadata["supported_color_modes"] = []string{"full_color", "grayscale"}
metadata["default_color_mode"] = "full_color"
}
}
candidates = append(candidates, sessioncontracts.DataPlaneCandidate{
Type: sessioncontracts.DataPlaneCandidateDirectWorkerWSS,
URL: directURL,
WorkerID: session.WorkerID,
Priority: 10,
Metadata: metadata,
})
}
if s.cfg.DataPlane.BackendGatewayURL != "" {
candidates = append(candidates, sessioncontracts.DataPlaneCandidate{
Type: sessioncontracts.DataPlaneCandidateBackendGateway,
URL: s.cfg.DataPlane.BackendGatewayURL,
Priority: 100,
})
}
return candidates
}
func (s *Service) canAdvertiseDirectWorkerWSS() bool {
trustMode := normalizeDirectWorkerTLSTrustMode(s.cfg.DataPlane.DirectWorkerTLSTrustMode)
return !secrets.IsProductionEnv(s.cfg.App.Env) || directWorkerTLSTrustModeIsProductionTrusted(trustMode)
}
func (s *Service) addDirectWorkerTLSTrustMetadata(metadata map[string]any) {
trustMode := normalizeDirectWorkerTLSTrustMode(s.cfg.DataPlane.DirectWorkerTLSTrustMode)
metadata["tls_trust_mode"] = trustMode
metadata["production_trusted"] = directWorkerTLSTrustModeIsProductionTrusted(trustMode)
metadata["smoke_only"] = trustMode == directWorkerTLSTrustModeSmokeInsecure
if s.cfg.DataPlane.DirectWorkerTLSCARef != "" {
metadata["tls_ca_ref"] = s.cfg.DataPlane.DirectWorkerTLSCARef
}
}
func normalizeDirectWorkerTLSTrustMode(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case directWorkerTLSTrustModePublicCA:
return directWorkerTLSTrustModePublicCA
case directWorkerTLSTrustModePlatformCA:
return directWorkerTLSTrustModePlatformCA
default:
return directWorkerTLSTrustModeSmokeInsecure
}
}
func directWorkerTLSTrustModeIsProductionTrusted(mode string) bool {
return mode == directWorkerTLSTrustModePublicCA || mode == directWorkerTLSTrustModePlatformCA
}
func (s *Service) directWorkerWSSURL(workerID string) string {
template := strings.TrimSpace(s.cfg.DataPlane.DirectWorkerWSSURLTemplate)
if template == "" || workerID == "" {
return ""
}
return strings.ReplaceAll(template, "{worker_id}", workerID)
}
func signDataPlaneToken(claims DataPlaneTokenClaims, privateKeyPEM string) (string, error) {
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(privateKeyPEM))
if err != nil {
return "", fmt.Errorf("parse data-plane private key: %w", err)
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
signed, err := token.SignedString(privateKey)
if err != nil {
return "", fmt.Errorf("sign data-plane token: %w", err)
}
return signed, nil
}
func parseDataPlaneToken(tokenValue string, publicKey *rsa.PublicKey) (*DataPlaneTokenClaims, error) {
claims := &DataPlaneTokenClaims{}
token, err := jwt.ParseWithClaims(tokenValue, claims, func(token *jwt.Token) (any, error) {
if token.Method != jwt.SigningMethodRS256 {
return nil, fmt.Errorf("unexpected data-plane signing method: %s", token.Header["alg"])
}
return publicKey, nil
})
if err != nil {
return nil, err
}
if !token.Valid {
return nil, fmt.Errorf("data-plane token invalid")
}
return claims, nil
}
func dataPlaneAllowedChannelsFromSession(session RemoteSession) []string {
channels := []string{
sessioncontracts.DataPlaneChannelControl,
sessioncontracts.DataPlaneChannelInput,
sessioncontracts.DataPlaneChannelRender,
sessioncontracts.DataPlaneChannelTelemetry,
}
metadata := decodeJSONMap(session.Metadata)
policy, _ := metadata["policy"].(map[string]any)
if policy != nil {
if mode, _ := policy["clipboard_mode"].(string); mode != "" && mode != string(ResourceClipboardModeDisabled) {
channels = append(channels, sessioncontracts.DataPlaneChannelClipboard)
}
if mode, _ := policy["file_transfer_mode"].(string); fileTransferAllowsClientToServer(ResourceFileTransferMode(mode)) {
channels = append(channels, sessioncontracts.DataPlaneChannelFileUpload)
}
if mode, _ := policy["file_transfer_mode"].(string); fileTransferAllowsServerToClient(ResourceFileTransferMode(mode)) {
channels = append(channels, sessioncontracts.DataPlaneChannelFileDownload)
}
}
return channels
}
func (s *Service) attachDataPlaneOffer(result *SessionControlResult) error {
if result == nil || result.Attachment == nil {
return nil
}
result.GatewayURL = s.cfg.DataPlane.BackendGatewayURL
offer, err := s.buildDataPlaneOffer(result.Session, *result.Attachment)
if err != nil {
return err
}
result.DataPlane = offer
return nil
}
@@ -0,0 +1,357 @@
package sessionbroker
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"slices"
"strings"
"testing"
"time"
"github.com/example/remote-access-platform/backend/internal/platform/config"
"github.com/example/remote-access-platform/backend/internal/platform/module"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
func TestDataPlaneTokenScopeValidation(t *testing.T) {
now := time.Now().UTC().Truncate(time.Second)
privateKeyPEM, publicKey := testRS256Key(t)
service := &Service{
cfg: module.Config{
Auth: config.AuthConfig{
Issuer: "rap-api-test",
},
DataPlane: config.DataPlaneConfig{
TokenTTL: time.Minute,
TokenPrivateKeyPEM: privateKeyPEM,
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
},
},
now: func() time.Time { return now },
}
session := RemoteSession{
ID: "session-1",
OrganizationID: "org-1",
ResourceID: "resource-1",
WorkerID: "worker-1",
Metadata: mustJSON(t, map[string]any{"policy": map[string]any{"clipboard_mode": "bidirectional", "file_transfer_mode": "client_to_server"}}),
}
attachment := SessionAttachment{
ID: "attachment-1",
UserID: "user-1",
}
offer, err := service.buildDataPlaneOffer(session, attachment)
if err != nil {
t.Fatalf("buildDataPlaneOffer returned error: %v", err)
}
if offer == nil {
t.Fatal("expected data-plane offer")
}
claims, err := parseDataPlaneToken(offer.Token, publicKey)
if err != nil {
t.Fatalf("parseDataPlaneToken returned error: %v", err)
}
assertEqual(t, claims.SessionID, session.ID, "session_id")
assertEqual(t, claims.AttachmentID, attachment.ID, "attachment_id")
assertEqual(t, claims.UserID, attachment.UserID, "user_id")
assertEqual(t, claims.OrganizationID, session.OrganizationID, "organization_id")
assertEqual(t, claims.WorkerID, session.WorkerID, "worker_id")
assertEqual(t, claims.ResourceID, session.ResourceID, "resource_id")
if claims.ID == "" {
t.Fatal("expected jti")
}
if claims.ExpiresAt == nil || !claims.ExpiresAt.Time.Equal(now.Add(time.Minute)) {
t.Fatalf("unexpected expires_at: %v", claims.ExpiresAt)
}
if !claims.ExpiresAtValue.Equal(now.Add(time.Minute)) {
t.Fatalf("unexpected expires_at claim value: %v", claims.ExpiresAtValue)
}
for _, channel := range []string{
sessioncontracts.DataPlaneChannelControl,
sessioncontracts.DataPlaneChannelInput,
sessioncontracts.DataPlaneChannelRender,
sessioncontracts.DataPlaneChannelTelemetry,
sessioncontracts.DataPlaneChannelClipboard,
sessioncontracts.DataPlaneChannelFileUpload,
} {
if !slices.Contains(claims.AllowedChannels, channel) {
t.Fatalf("expected allowed channel %q in %v", channel, claims.AllowedChannels)
}
}
}
func TestDataPlaneOfferResponseShapeCompatibility(t *testing.T) {
now := time.Now().UTC().Truncate(time.Second)
privateKeyPEM, _ := testRS256Key(t)
service := &Service{
cfg: module.Config{
Auth: config.AuthConfig{Issuer: "rap-api-test"},
DataPlane: config.DataPlaneConfig{
TokenTTL: time.Minute,
TokenPrivateKeyPEM: privateKeyPEM,
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
DirectWorkerJSONRuntime: true,
DirectWorkerTLSTrustMode: "smoke_insecure",
},
},
now: func() time.Time { return now },
}
result := &SessionControlResult{
Session: RemoteSession{
ID: "session-1",
OrganizationID: "org-1",
ResourceID: "resource-1",
WorkerID: "worker-1",
Metadata: mustJSON(t, map[string]any{"policy": map[string]any{"clipboard_mode": "disabled", "file_transfer_mode": "disabled"}}),
},
Attachment: &SessionAttachment{ID: "attachment-1", UserID: "user-1"},
AttachToken: &sessioncontracts.AttachTokenClaims{
Token: "existing-attach-token",
SessionID: "session-1",
AttachmentID: "attachment-1",
UserID: "user-1",
WorkerID: "worker-1",
ExpiresAt: now.Add(2 * time.Minute),
},
}
if err := service.attachDataPlaneOffer(result); err != nil {
t.Fatalf("attachDataPlaneOffer returned error: %v", err)
}
payload, err := json.Marshal(result)
if err != nil {
t.Fatalf("marshal response: %v", err)
}
var decoded map[string]any
if err := json.Unmarshal(payload, &decoded); err != nil {
t.Fatalf("decode response: %v", err)
}
if decoded["session"] == nil || decoded["attachment"] == nil || decoded["attach_token"] == nil {
t.Fatalf("response lost existing fields: %s", payload)
}
if decoded["data_plane"] == nil || decoded["gateway_url"] == nil {
t.Fatalf("response missing data-plane fields: %s", payload)
}
if result.DataPlane == nil {
t.Fatal("expected data-plane offer")
}
if result.DataPlane.Preferred != sessioncontracts.DataPlaneCandidateDirectWorkerWSS {
t.Fatalf("unexpected preferred candidate: %s", result.DataPlane.Preferred)
}
if len(result.DataPlane.Candidates) != 2 {
t.Fatalf("expected direct and fallback candidates, got %d", len(result.DataPlane.Candidates))
}
if result.DataPlane.Candidates[0].URL != "wss://worker-1.worker.example.test/rap/v1/data-plane" {
t.Fatalf("unexpected direct candidate URL: %s", result.DataPlane.Candidates[0].URL)
}
if result.DataPlane.Candidates[0].Metadata["runtime_transport"] != "json_v1" {
t.Fatalf("direct candidate is missing json_v1 runtime metadata: %#v", result.DataPlane.Candidates[0].Metadata)
}
if result.DataPlane.Candidates[0].Metadata["traffic_ready"] != true {
t.Fatalf("direct candidate is missing traffic_ready metadata: %#v", result.DataPlane.Candidates[0].Metadata)
}
if result.DataPlane.Candidates[0].Metadata["smoke_only"] != true {
t.Fatalf("direct candidate should be marked smoke-only by default: %#v", result.DataPlane.Candidates[0].Metadata)
}
if result.DataPlane.Candidates[0].Metadata["production_trusted"] != false {
t.Fatalf("smoke direct candidate must not be production-trusted: %#v", result.DataPlane.Candidates[0].Metadata)
}
if !strings.Contains(result.DataPlane.Candidates[1].URL, "/api/v1/gateway/ws") {
t.Fatalf("unexpected backend candidate URL: %s", result.DataPlane.Candidates[1].URL)
}
}
func TestDataPlaneDirectCandidateMetadataRequiresRuntimeFlag(t *testing.T) {
service := &Service{
cfg: module.Config{
DataPlane: config.DataPlaneConfig{
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
DirectWorkerTLSTrustMode: "smoke_insecure",
},
},
}
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
if len(candidates) != 2 {
t.Fatalf("expected direct and fallback candidates, got %d", len(candidates))
}
if candidates[0].Metadata != nil {
t.Fatalf("direct candidate must not advertise json_v1 before runtime flag is enabled: %#v", candidates[0].Metadata)
}
}
func TestDataPlaneDirectCandidateAdvertisesBinaryRenderOnlyWhenEnabled(t *testing.T) {
service := &Service{
cfg: module.Config{
DataPlane: config.DataPlaneConfig{
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
DirectWorkerJSONRuntime: true,
DirectWorkerBinaryRender: true,
DirectWorkerTLSTrustMode: "platform_ca",
DirectWorkerTLSCARef: "rap-platform-ca:v1",
},
},
}
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
if len(candidates) != 2 {
t.Fatalf("expected direct and fallback candidates, got %d", len(candidates))
}
if candidates[0].Metadata["render_transport"] != "binary_v1" {
t.Fatalf("direct candidate is missing binary render metadata: %#v", candidates[0].Metadata)
}
if candidates[0].Metadata["binary_render"] != true {
t.Fatalf("direct candidate is missing binary_render metadata: %#v", candidates[0].Metadata)
}
if candidates[0].Metadata["default_color_mode"] != "full_color" {
t.Fatalf("direct candidate is missing default_color_mode metadata: %#v", candidates[0].Metadata)
}
if candidates[0].Metadata["production_trusted"] != true || candidates[0].Metadata["tls_trust_mode"] != "platform_ca" {
t.Fatalf("direct candidate is missing production trust metadata: %#v", candidates[0].Metadata)
}
if candidates[0].Metadata["tls_ca_ref"] != "rap-platform-ca:v1" {
t.Fatalf("direct candidate is missing tls_ca_ref metadata: %#v", candidates[0].Metadata)
}
modes, ok := candidates[0].Metadata["supported_color_modes"].([]string)
if !ok || !slices.Contains(modes, "full_color") || !slices.Contains(modes, "grayscale") {
t.Fatalf("direct candidate is missing supported_color_modes metadata: %#v", candidates[0].Metadata)
}
}
func TestDataPlaneDirectCandidateOmittedInProductionWhenSmokeOnly(t *testing.T) {
service := &Service{
cfg: module.Config{
App: config.AppConfig{Env: "production"},
DataPlane: config.DataPlaneConfig{
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
DirectWorkerJSONRuntime: true,
DirectWorkerTLSTrustMode: "smoke_insecure",
},
},
}
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
if len(candidates) != 1 {
t.Fatalf("expected fallback-only candidates in production with smoke TLS, got %d", len(candidates))
}
if candidates[0].Type != sessioncontracts.DataPlaneCandidateBackendGateway {
t.Fatalf("production must not advertise smoke-only direct candidate: %#v", candidates)
}
}
func TestDataPlaneDirectCandidateAdvertisedInProductionWhenTrusted(t *testing.T) {
service := &Service{
cfg: module.Config{
App: config.AppConfig{Env: "production"},
DataPlane: config.DataPlaneConfig{
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
DirectWorkerJSONRuntime: true,
DirectWorkerTLSTrustMode: "public_ca",
},
},
}
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
if len(candidates) != 2 {
t.Fatalf("expected trusted direct and fallback candidates, got %d", len(candidates))
}
if candidates[0].Metadata["production_trusted"] != true || candidates[0].Metadata["tls_trust_mode"] != "public_ca" {
t.Fatalf("trusted production direct candidate metadata mismatch: %#v", candidates[0].Metadata)
}
}
func TestDataPlaneCandidatesFallbackOnlyWhenDirectTemplateMissing(t *testing.T) {
service := &Service{
cfg: module.Config{
DataPlane: config.DataPlaneConfig{
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
},
},
}
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
if len(candidates) != 1 {
t.Fatalf("expected fallback-only candidate list, got %d", len(candidates))
}
if candidates[0].Type != sessioncontracts.DataPlaneCandidateBackendGateway {
t.Fatalf("unexpected candidate type: %s", candidates[0].Type)
}
}
func TestDataPlaneAllowedChannelsRespectRuntimePolicy(t *testing.T) {
cases := []struct {
name string
policy map[string]any
expected []string
blocked []string
}{
{
name: "disabled policies expose only control input render telemetry",
policy: map[string]any{"clipboard_mode": "disabled", "file_transfer_mode": "disabled"},
expected: []string{sessioncontracts.DataPlaneChannelControl, sessioncontracts.DataPlaneChannelInput, sessioncontracts.DataPlaneChannelRender, sessioncontracts.DataPlaneChannelTelemetry},
blocked: []string{sessioncontracts.DataPlaneChannelClipboard, sessioncontracts.DataPlaneChannelFileUpload},
},
{
name: "clipboard policy adds clipboard channel",
policy: map[string]any{"clipboard_mode": "server_to_client", "file_transfer_mode": "disabled"},
expected: []string{sessioncontracts.DataPlaneChannelClipboard},
blocked: []string{sessioncontracts.DataPlaneChannelFileUpload},
},
{
name: "client upload policy adds file upload channel",
policy: map[string]any{"clipboard_mode": "disabled", "file_transfer_mode": "client_to_server"},
expected: []string{sessioncontracts.DataPlaneChannelFileUpload},
blocked: []string{sessioncontracts.DataPlaneChannelClipboard},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
session := RemoteSession{Metadata: mustJSON(t, map[string]any{"policy": tc.policy})}
channels := dataPlaneAllowedChannelsFromSession(session)
for _, channel := range tc.expected {
if !slices.Contains(channels, channel) {
t.Fatalf("expected channel %q in %v", channel, channels)
}
}
for _, channel := range tc.blocked {
if slices.Contains(channels, channel) {
t.Fatalf("did not expect channel %q in %v", channel, channels)
}
}
})
}
}
func mustJSON(t *testing.T, value any) []byte {
t.Helper()
payload, err := json.Marshal(value)
if err != nil {
t.Fatalf("marshal test metadata: %v", err)
}
return payload
}
func testRS256Key(t *testing.T) (string, *rsa.PublicKey) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("generate RSA key: %v", err)
}
encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
})
return string(encoded), &privateKey.PublicKey
}
func assertEqual(t *testing.T, got, want, name string) {
t.Helper()
if got != want {
t.Fatalf("unexpected %s: got %q want %q", name, got, want)
}
}
@@ -0,0 +1,15 @@
package sessionbroker
import "errors"
var (
ErrSessionNotFound = errors.New("remote session not found")
ErrAttachmentNotFound = errors.New("session attachment not found")
ErrActiveControllerPresent = errors.New("active controller already present")
ErrTakeoverNotAllowed = errors.New("takeover not allowed")
ErrTrustedDeviceRequired = errors.New("trusted device required")
ErrAccessDenied = errors.New("access denied")
ErrSessionNotAttachable = errors.New("session is not attachable")
ErrSessionNotTerminable = errors.New("session is not terminable")
ErrAttachTokenInvalid = errors.New("attach token invalid or expired")
)
@@ -0,0 +1,65 @@
package sessionbroker
import (
"context"
"time"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
type LiveStateStore interface {
UpsertSession(ctx context.Context, state LiveSessionState) error
GetSession(ctx context.Context, sessionID string) (*LiveSessionState, error)
DeleteSession(ctx context.Context, sessionID string) error
BindController(ctx context.Context, binding sessioncontracts.ControllerBinding, ttl time.Duration) error
GetControllerBinding(ctx context.Context, sessionID string) (*sessioncontracts.ControllerBinding, error)
ClearControllerBinding(ctx context.Context, sessionID string) error
StoreAttachToken(ctx context.Context, claims sessioncontracts.AttachTokenClaims, ttl time.Duration) error
ConsumeAttachToken(ctx context.Context, token string) (*sessioncontracts.AttachTokenClaims, error)
TouchAttachmentHeartbeat(ctx context.Context, sessionID, attachmentID string, ttl time.Duration) error
UpdateWorkerRoute(ctx context.Context, route WorkerRoute, ttl time.Duration) error
GetWorkerRoute(ctx context.Context, sessionID string) (*WorkerRoute, error)
DeleteWorkerRoute(ctx context.Context, sessionID string) error
}
type LiveSessionState struct {
SessionID string `json:"session_id"`
ResourceID string `json:"resource_id"`
WorkerID string `json:"worker_id"`
State sessioncontracts.State `json:"state"`
ControllerID string `json:"controller_id"`
AttachmentID string `json:"attachment_id"`
TakeoverVersion int `json:"takeover_version"`
RenderQualityProfile string `json:"render_quality_profile,omitempty"`
RenderState string `json:"render_state,omitempty"`
RenderWidth int `json:"render_width,omitempty"`
RenderHeight int `json:"render_height,omitempty"`
RenderFrameSequence int64 `json:"render_frame_sequence,omitempty"`
RenderFrameFormat string `json:"render_frame_format,omitempty"`
RenderFrameData string `json:"render_frame_data,omitempty"`
LastInputCorrelationID string `json:"last_input_correlation_id,omitempty"`
WorkerFrameCapturedAt string `json:"worker_frame_captured_at,omitempty"`
CursorX int `json:"cursor_x,omitempty"`
CursorY int `json:"cursor_y,omitempty"`
CursorVisible bool `json:"cursor_visible,omitempty"`
DirtyRectangles int `json:"dirty_rectangles,omitempty"`
LastRenderAt *time.Time `json:"last_render_at,omitempty"`
ClipboardSequence int64 `json:"clipboard_sequence,omitempty"`
ClipboardText string `json:"clipboard_text,omitempty"`
ClipboardOrigin string `json:"clipboard_origin,omitempty"`
ClipboardContentHash string `json:"clipboard_content_hash,omitempty"`
ClipboardUpdatedAt *time.Time `json:"clipboard_updated_at,omitempty"`
FileDownloadSequence int64 `json:"file_download_sequence,omitempty"`
FileDownloadType string `json:"file_download_type,omitempty"`
FileDownloadPayload map[string]any `json:"file_download_payload,omitempty"`
FileDownloadUpdatedAt *time.Time `json:"file_download_updated_at,omitempty"`
UpdatedAt time.Time `json:"updated_at"`
}
type WorkerRoute struct {
SessionID string `json:"session_id"`
WorkerID string `json:"worker_id"`
LeaseID string `json:"lease_id"`
ControlStream string `json:"control_stream"`
UpdatedAt time.Time `json:"updated_at"`
}
@@ -0,0 +1,132 @@
package sessionbroker
import (
"time"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
type AttachmentRole string
const (
AttachmentRoleController AttachmentRole = "controller"
)
type AttachmentState string
const (
AttachmentStateAttaching AttachmentState = "attaching"
AttachmentStateActive AttachmentState = "active"
AttachmentStateDetached AttachmentState = "detached"
AttachmentStateSuperseded AttachmentState = "superseded"
AttachmentStateRevoked AttachmentState = "revoked"
AttachmentStateClosed AttachmentState = "closed"
)
type ResourceTakeoverPolicy string
const (
ResourceTakeoverPolicyTrustedDevice ResourceTakeoverPolicy = "trusted_device"
ResourceTakeoverPolicySameUser ResourceTakeoverPolicy = "same_user"
ResourceTakeoverPolicyAdminOnly ResourceTakeoverPolicy = "admin_only"
)
type ResourceClipboardMode string
const (
ResourceClipboardModeDisabled ResourceClipboardMode = "disabled"
ResourceClipboardModeClientToServer ResourceClipboardMode = "client_to_server"
ResourceClipboardModeServerToClient ResourceClipboardMode = "server_to_client"
ResourceClipboardModeBidirectional ResourceClipboardMode = "bidirectional"
)
type ResourceFileTransferMode string
const (
ResourceFileTransferModeDisabled ResourceFileTransferMode = "disabled"
ResourceFileTransferModeClientToServer ResourceFileTransferMode = "client_to_server"
ResourceFileTransferModeServerToClient ResourceFileTransferMode = "server_to_client"
ResourceFileTransferModeBidirectional ResourceFileTransferMode = "bidirectional"
)
type RemoteSession struct {
ID string
OrganizationID string
ResourceID string
Protocol string
State sessioncontracts.State
WorkerID string
ControllerUserID string
DetachDeadlineAt *time.Time
LastHeartbeatAt *time.Time
TakeoverVersion int
RenderQualityProfile string
Metadata []byte
CreatedAt time.Time
UpdatedAt time.Time
}
type SessionAttachment struct {
ID string
RemoteSessionID string
UserID string
DeviceID string
Role AttachmentRole
State AttachmentState
SupersededBy *string
TakeoverOf *string
AttachedAt *time.Time
DetachedAt *time.Time
LastInputAt *time.Time
Metadata []byte
CreatedAt time.Time
UpdatedAt time.Time
}
type ResourcePolicy struct {
ResourceID string
MaxConcurrentSessions int
TakeoverPolicy ResourceTakeoverPolicy
RequireTrustedDevice bool
DetachGracePeriod time.Duration
ClipboardEnabled bool
ClipboardMode ResourceClipboardMode
FileTransferEnabled bool
FileTransferMode ResourceFileTransferMode
CreatedAt time.Time
UpdatedAt time.Time
}
type ResourceRuntimeSpec struct {
ID string
OrganizationID string
Name string
Address string
Protocol string
SecretRef *string
CertificateVerificationMode string
Metadata []byte
}
type AuditEvent struct {
ID string
ActorUserID *string
ActorDeviceID *string
EventType string
TargetType string
TargetID string
RemoteSessionID *string
Payload []byte
CreatedAt time.Time
}
const (
AuditEventSessionStarted = "session_started"
AuditEventSessionAttached = "session_attached"
AuditEventSessionDetached = "session_detached"
AuditEventSessionTakenOver = "session_taken_over"
AuditEventSessionTerminated = "session_terminated"
AuditEventSessionFailed = "session_failed"
AuditEventSecretAccessed = "resource_secret_accessed"
AuditEventSecretAccessDenied = "resource_secret_access_denied"
)
@@ -0,0 +1,164 @@
package sessionbroker
import (
"encoding/json"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
)
type Module struct {
service *Service
}
func NewModule(service *Service) *Module {
return &Module{service: service}
}
func (m *Module) Name() string {
return "session-broker"
}
func (m *Module) Service() *Service {
return m.service
}
func (m *Module) RegisterRoutes(router chi.Router) {
router.Route("/sessions", func(r chi.Router) {
r.Get("/", m.listSessions)
r.Post("/", m.startSession)
r.Post("/{sessionID}/attach", m.attachSession)
r.Post("/{sessionID}/detach", m.detachSession)
r.Post("/{sessionID}/takeover", m.takeoverSession)
r.Post("/{sessionID}/terminate", m.terminateSession)
r.Post("/{sessionID}/fail", m.markFailed)
})
}
func (m *Module) listSessions(w http.ResponseWriter, r *http.Request) {
userID := r.URL.Query().Get("user_id")
if userID == "" {
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
return
}
sessions, err := m.service.ListSessions(r.Context(), userID)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"sessions": sessions})
}
func (m *Module) startSession(w http.ResponseWriter, r *http.Request) {
var cmd StartRemoteSessionCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid start session payload")
return
}
result, err := m.service.StartRemoteSession(r.Context(), cmd)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusCreated, result)
}
func (m *Module) attachSession(w http.ResponseWriter, r *http.Request) {
var cmd AttachToSessionCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid attach session payload")
return
}
cmd.SessionID = chi.URLParam(r, "sessionID")
result, err := m.service.AttachToSession(r.Context(), cmd)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusOK, result)
}
func (m *Module) detachSession(w http.ResponseWriter, r *http.Request) {
var cmd DetachFromSessionCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid detach session payload")
return
}
cmd.SessionID = chi.URLParam(r, "sessionID")
result, err := m.service.DetachFromSession(r.Context(), cmd)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusAccepted, result)
}
func (m *Module) takeoverSession(w http.ResponseWriter, r *http.Request) {
var cmd TakeoverSessionCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid takeover session payload")
return
}
cmd.SessionID = chi.URLParam(r, "sessionID")
result, err := m.service.TakeoverSession(r.Context(), cmd)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusOK, result)
}
func (m *Module) terminateSession(w http.ResponseWriter, r *http.Request) {
var cmd TerminateSessionCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid terminate session payload")
return
}
cmd.SessionID = chi.URLParam(r, "sessionID")
if err := m.service.TerminateSession(r.Context(), cmd); err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
"status": "terminated",
"message": httpx.NewMessage(
"session.terminated",
"status.session.terminated",
"Session terminated.",
nil,
"",
),
})
}
func (m *Module) markFailed(w http.ResponseWriter, r *http.Request) {
var cmd MarkSessionFailedCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid fail session payload")
return
}
cmd.SessionID = chi.URLParam(r, "sessionID")
if err := m.service.MarkSessionFailed(r.Context(), cmd); err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
"status": "failed",
"message": httpx.NewMessage(
"session.failed",
"status.session.failed",
"Session marked as failed.",
nil,
"",
),
})
}
@@ -0,0 +1,17 @@
package sessionbroker
import (
"context"
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
)
type WorkerOrchestrator interface {
Reserve(ctx context.Context, request workercontracts.AttachRequest) (*workercontracts.WorkerLease, error)
GetSessionLease(ctx context.Context, sessionID string) (*workercontracts.WorkerLease, error)
ReleaseSessionLease(ctx context.Context, sessionID string) error
PrepareAttachment(ctx context.Context, session RemoteSession, attachment SessionAttachment, runtimeMetadata map[string]any) error
NotifyDetachment(ctx context.Context, session RemoteSession, attachment SessionAttachment) error
TerminateRemoteSession(ctx context.Context, sessionID, attachmentID string) error
ValidateSessionRuntime(ctx context.Context, sessionID, workerID string) (bool, string, error)
}
@@ -0,0 +1,607 @@
package sessionbroker
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/example/remote-access-platform/backend/internal/platform/authority"
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
)
type postgresStore struct {
db postgresplatform.DBTX
authority *authority.Verifier
}
type PostgresTransactor struct {
pool *pgxpool.Pool
authority *authority.Verifier
}
func NewPostgresStore(pool *pgxpool.Pool, verifiers ...*authority.Verifier) Store {
var authorityVerifier *authority.Verifier
if len(verifiers) > 0 {
authorityVerifier = verifiers[0]
}
return &postgresStore{db: pool, authority: authorityVerifier}
}
func NewPostgresTransactor(pool *pgxpool.Pool, verifiers ...*authority.Verifier) *PostgresTransactor {
var authorityVerifier *authority.Verifier
if len(verifiers) > 0 {
authorityVerifier = verifiers[0]
}
return &PostgresTransactor{pool: pool, authority: authorityVerifier}
}
func (t *PostgresTransactor) WithinTransaction(ctx context.Context, fn func(store Store) error) error {
return postgresplatform.WithTransaction(ctx, t.pool, func(tx pgx.Tx) error {
return fn(&postgresStore{db: tx, authority: t.authority})
})
}
func (s *postgresStore) RemoteSessions() RemoteSessionRepository {
return &postgresRemoteSessionRepository{db: s.db}
}
func (s *postgresStore) SessionAttachments() SessionAttachmentRepository {
return &postgresSessionAttachmentRepository{db: s.db}
}
func (s *postgresStore) ResourcePolicies() ResourcePolicyRepository {
return &postgresResourcePolicyRepository{db: s.db}
}
func (s *postgresStore) ResourceRuntime() ResourceRuntimeRepository {
return &postgresResourceRuntimeRepository{db: s.db}
}
func (s *postgresStore) AuditEvents() AuditEventRepository {
return &postgresAuditEventRepository{db: s.db}
}
func (s *postgresStore) Access() AccessRepository {
return &postgresAccessRepository{db: s.db, authority: s.authority}
}
type postgresRemoteSessionRepository struct {
db postgresplatform.DBTX
}
type postgresSessionAttachmentRepository struct {
db postgresplatform.DBTX
}
type postgresResourcePolicyRepository struct {
db postgresplatform.DBTX
}
type postgresResourceRuntimeRepository struct {
db postgresplatform.DBTX
}
type postgresAuditEventRepository struct {
db postgresplatform.DBTX
}
type postgresAccessRepository struct {
db postgresplatform.DBTX
authority *authority.Verifier
}
func (r *postgresRemoteSessionRepository) Create(ctx context.Context, session RemoteSession) error {
const query = `
INSERT INTO remote_sessions (
id, organization_id, resource_id, protocol, state, worker_id, controller_user_id, detach_deadline_at,
last_heartbeat_at, takeover_version, metadata, created_at, updated_at
) VALUES (
$1::uuid, $2::uuid, $3::uuid, $4, $5, NULLIF($6, ''), $7::uuid, $8, $9, $10, $11::jsonb, $12, $13
)
`
if _, err := r.db.Exec(ctx, query,
session.ID,
session.OrganizationID,
session.ResourceID,
session.Protocol,
session.State,
session.WorkerID,
session.ControllerUserID,
session.DetachDeadlineAt,
session.LastHeartbeatAt,
session.TakeoverVersion,
jsonPayload(session.Metadata),
session.CreatedAt,
session.UpdatedAt,
); err != nil {
return fmt.Errorf("create remote session: %w", err)
}
return nil
}
func (r *postgresRemoteSessionRepository) GetByID(ctx context.Context, sessionID string) (*RemoteSession, error) {
return r.getByID(ctx, sessionID, "")
}
func (r *postgresRemoteSessionRepository) GetByIDForUpdate(ctx context.Context, sessionID string) (*RemoteSession, error) {
return r.getByID(ctx, sessionID, " FOR UPDATE")
}
func (r *postgresRemoteSessionRepository) getByID(ctx context.Context, sessionID string, suffix string) (*RemoteSession, error) {
query := `
SELECT id::text, organization_id::text, resource_id::text, protocol, state, COALESCE(worker_id, ''), controller_user_id::text,
detach_deadline_at, last_heartbeat_at, takeover_version, metadata, created_at, updated_at
FROM remote_sessions
WHERE id = $1::uuid` + suffix
remoteSession, err := scanRemoteSession(r.db.QueryRow(ctx, query, sessionID))
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return remoteSession, err
}
func (r *postgresRemoteSessionRepository) ListByController(ctx context.Context, userID string) ([]RemoteSession, error) {
const query = `
SELECT id::text, organization_id::text, resource_id::text, protocol, state, COALESCE(worker_id, ''), controller_user_id::text,
detach_deadline_at, last_heartbeat_at, takeover_version, metadata, created_at, updated_at
FROM remote_sessions
WHERE controller_user_id = $1::uuid
ORDER BY updated_at DESC
`
rows, err := r.db.Query(ctx, query, userID)
if err != nil {
return nil, fmt.Errorf("list remote sessions: %w", err)
}
defer rows.Close()
var sessions []RemoteSession
for rows.Next() {
item, err := scanRemoteSession(rows)
if err != nil {
return nil, err
}
sessions = append(sessions, *item)
}
return sessions, rows.Err()
}
func (r *postgresRemoteSessionRepository) CountLiveByResource(ctx context.Context, resourceID string) (int, error) {
const query = `
SELECT COUNT(*)
FROM remote_sessions
WHERE resource_id = $1::uuid AND state IN ('starting', 'active', 'detached', 'reconnecting')
`
var count int
if err := r.db.QueryRow(ctx, query, resourceID).Scan(&count); err != nil {
return 0, fmt.Errorf("count live remote sessions: %w", err)
}
return count, nil
}
func (r *postgresRemoteSessionRepository) ListDetachedExpired(ctx context.Context, before time.Time, limit int) ([]RemoteSession, error) {
const query = `
SELECT id::text, organization_id::text, resource_id::text, protocol, state, COALESCE(worker_id, ''), controller_user_id::text,
detach_deadline_at, last_heartbeat_at, takeover_version, metadata, created_at, updated_at
FROM remote_sessions
WHERE state = 'detached' AND detach_deadline_at IS NOT NULL AND detach_deadline_at <= $1
ORDER BY detach_deadline_at ASC
LIMIT $2
`
rows, err := r.db.Query(ctx, query, before, limit)
if err != nil {
return nil, fmt.Errorf("list detached expired sessions: %w", err)
}
defer rows.Close()
var sessions []RemoteSession
for rows.Next() {
item, err := scanRemoteSession(rows)
if err != nil {
return nil, err
}
sessions = append(sessions, *item)
}
return sessions, rows.Err()
}
func (r *postgresRemoteSessionRepository) UpdateState(ctx context.Context, params UpdateRemoteSessionStateParams) error {
const query = `
UPDATE remote_sessions
SET state = $2,
worker_id = NULLIF($3, ''),
detach_deadline_at = $4,
last_heartbeat_at = $5,
takeover_version = $6,
updated_at = $7
WHERE id = $1::uuid
`
if _, err := r.db.Exec(ctx, query,
params.RemoteSessionID,
params.State,
params.WorkerID,
params.DetachDeadlineAt,
params.LastHeartbeatAt,
params.TakeoverVersion,
params.UpdatedAt,
); err != nil {
return fmt.Errorf("update remote session state: %w", err)
}
return nil
}
func (r *postgresSessionAttachmentRepository) Create(ctx context.Context, attachment SessionAttachment) error {
const query = `
INSERT INTO session_attachments (
id, remote_session_id, user_id, device_id, role, state, superseded_by,
takeover_of, attached_at, detached_at, last_input_at, metadata, created_at, updated_at
) VALUES (
$1::uuid, $2::uuid, $3::uuid, $4::uuid, $5, $6, NULLIF($7, '')::uuid,
NULLIF($8, '')::uuid, $9, $10, $11, $12::jsonb, $13, $14
)
`
if _, err := r.db.Exec(ctx, query,
attachment.ID,
attachment.RemoteSessionID,
attachment.UserID,
attachment.DeviceID,
attachment.Role,
attachment.State,
stringValue(attachment.SupersededBy),
stringValue(attachment.TakeoverOf),
attachment.AttachedAt,
attachment.DetachedAt,
attachment.LastInputAt,
jsonPayload(attachment.Metadata),
attachment.CreatedAt,
attachment.UpdatedAt,
); err != nil {
return fmt.Errorf("create session attachment: %w", err)
}
return nil
}
func (r *postgresSessionAttachmentRepository) GetByID(ctx context.Context, attachmentID string) (*SessionAttachment, error) {
return r.getByID(ctx, attachmentID, "")
}
func (r *postgresSessionAttachmentRepository) GetByIDForUpdate(ctx context.Context, attachmentID string) (*SessionAttachment, error) {
return r.getByID(ctx, attachmentID, " FOR UPDATE")
}
func (r *postgresSessionAttachmentRepository) getByID(ctx context.Context, attachmentID string, suffix string) (*SessionAttachment, error) {
query := `
SELECT id::text, remote_session_id::text, user_id::text, device_id::text, role, state,
superseded_by::text, takeover_of::text, attached_at, detached_at, last_input_at, metadata, created_at, updated_at
FROM session_attachments
WHERE id = $1::uuid` + suffix
attachment, err := scanSessionAttachment(r.db.QueryRow(ctx, query, attachmentID))
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return attachment, err
}
func (r *postgresSessionAttachmentRepository) ListByRemoteSession(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error) {
return r.listByRemoteSession(ctx, remoteSessionID, "")
}
func (r *postgresSessionAttachmentRepository) ListActiveByRemoteSessionForUpdate(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error) {
return r.listByRemoteSession(ctx, remoteSessionID, " AND state IN ('attaching', 'active', 'reconnecting') FOR UPDATE")
}
func (r *postgresSessionAttachmentRepository) listByRemoteSession(ctx context.Context, remoteSessionID string, suffix string) ([]SessionAttachment, error) {
query := `
SELECT id::text, remote_session_id::text, user_id::text, device_id::text, role, state,
superseded_by::text, takeover_of::text, attached_at, detached_at, last_input_at, metadata, created_at, updated_at
FROM session_attachments
WHERE remote_session_id = $1::uuid` + suffix
rows, err := r.db.Query(ctx, query, remoteSessionID)
if err != nil {
return nil, fmt.Errorf("list session attachments: %w", err)
}
defer rows.Close()
var attachments []SessionAttachment
for rows.Next() {
item, err := scanSessionAttachment(rows)
if err != nil {
return nil, err
}
attachments = append(attachments, *item)
}
return attachments, rows.Err()
}
func (r *postgresSessionAttachmentRepository) UpdateState(ctx context.Context, params UpdateSessionAttachmentStateParams) error {
const query = `
UPDATE session_attachments
SET state = $2,
detached_at = $3,
last_input_at = $4,
updated_at = $5
WHERE id = $1::uuid
`
if _, err := r.db.Exec(ctx, query,
params.AttachmentID,
params.State,
params.DetachedAt,
params.LastInputAt,
params.UpdatedAt,
); err != nil {
return fmt.Errorf("update session attachment state: %w", err)
}
return nil
}
func (r *postgresSessionAttachmentRepository) Supersede(ctx context.Context, params SupersedeAttachmentParams) error {
const query = `
UPDATE session_attachments
SET state = 'superseded',
superseded_by = $2::uuid,
detached_at = $3,
updated_at = $4
WHERE id = $1::uuid
`
if _, err := r.db.Exec(ctx, query,
params.PreviousAttachmentID,
params.NextAttachmentID,
params.DetachedAt,
params.UpdatedAt,
); err != nil {
return fmt.Errorf("supersede attachment: %w", err)
}
return nil
}
func (r *postgresResourcePolicyRepository) GetByResourceID(ctx context.Context, resourceID string) (*ResourcePolicy, error) {
const query = `
SELECT resource_id::text, max_concurrent_sessions, takeover_policy, require_trusted_device,
detach_grace_period_seconds, clipboard_enabled, clipboard_mode, file_transfer_enabled,
COALESCE(file_transfer_mode, 'disabled'), created_at, updated_at
FROM resource_policies
WHERE resource_id = $1::uuid
`
policy, err := scanResourcePolicy(r.db.QueryRow(ctx, query, resourceID))
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return policy, err
}
func (r *postgresResourcePolicyRepository) Upsert(ctx context.Context, policy ResourcePolicy) error {
const query = `
INSERT INTO resource_policies (
resource_id, max_concurrent_sessions, takeover_policy, require_trusted_device,
detach_grace_period_seconds, clipboard_enabled, clipboard_mode, file_transfer_enabled, file_transfer_mode, created_at, updated_at
) VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
ON CONFLICT (resource_id) DO UPDATE SET
max_concurrent_sessions = EXCLUDED.max_concurrent_sessions,
takeover_policy = EXCLUDED.takeover_policy,
require_trusted_device = EXCLUDED.require_trusted_device,
detach_grace_period_seconds = EXCLUDED.detach_grace_period_seconds,
clipboard_enabled = EXCLUDED.clipboard_enabled,
clipboard_mode = EXCLUDED.clipboard_mode,
file_transfer_enabled = EXCLUDED.file_transfer_enabled,
file_transfer_mode = EXCLUDED.file_transfer_mode,
updated_at = EXCLUDED.updated_at
`
clipboardMode := normalizeClipboardMode(policy.ClipboardMode)
fileTransferMode := normalizeFileTransferMode(policy.FileTransferMode)
if _, err := r.db.Exec(ctx, query,
policy.ResourceID,
policy.MaxConcurrentSessions,
policy.TakeoverPolicy,
policy.RequireTrustedDevice,
int(policy.DetachGracePeriod.Seconds()),
clipboardMode != ResourceClipboardModeDisabled,
clipboardMode,
fileTransferAllowsClientToServer(fileTransferMode),
fileTransferMode,
policy.CreatedAt,
policy.UpdatedAt,
); err != nil {
return fmt.Errorf("upsert resource policy: %w", err)
}
return nil
}
func (r *postgresResourceRuntimeRepository) GetByID(ctx context.Context, resourceID string) (*ResourceRuntimeSpec, error) {
const query = `
SELECT id::text, organization_id::text, name, address, protocol, secret_ref, certificate_verification_mode, metadata
FROM resources
WHERE id = $1::uuid
`
item := &ResourceRuntimeSpec{}
var secretRef *string
var metadata []byte
if err := r.db.QueryRow(ctx, query, resourceID).Scan(
&item.ID,
&item.OrganizationID,
&item.Name,
&item.Address,
&item.Protocol,
&secretRef,
&item.CertificateVerificationMode,
&metadata,
); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return nil, fmt.Errorf("get resource runtime spec: %w", err)
}
item.SecretRef = secretRef
item.Metadata = metadata
return item, nil
}
func (r *postgresAuditEventRepository) Create(ctx context.Context, event AuditEvent) error {
const query = `
INSERT INTO audit_events (
id, actor_user_id, actor_device_id, event_type, target_type, target_id,
remote_session_id, payload, created_at
) VALUES (
$1::uuid, NULLIF($2, '')::uuid, NULLIF($3, '')::uuid, $4, $5, $6,
NULLIF($7, '')::uuid, $8::jsonb, $9
)
`
if _, err := r.db.Exec(ctx, query,
event.ID,
stringValue(event.ActorUserID),
stringValue(event.ActorDeviceID),
event.EventType,
event.TargetType,
event.TargetID,
stringValue(event.RemoteSessionID),
jsonPayload(event.Payload),
event.CreatedAt,
); err != nil {
return fmt.Errorf("create audit event: %w", err)
}
return nil
}
func (r *postgresAccessRepository) IsTrustedDevice(ctx context.Context, userID, deviceID string) (bool, error) {
const query = `
SELECT EXISTS(
SELECT 1 FROM devices
WHERE id = $1::uuid AND user_id = $2::uuid AND trust_status = 'trusted' AND revoked_at IS NULL
)
`
var trusted bool
if err := r.db.QueryRow(ctx, query, deviceID, userID).Scan(&trusted); err != nil {
return false, fmt.Errorf("check trusted device: %w", err)
}
return trusted, nil
}
func (r *postgresAccessRepository) GetPlatformRole(ctx context.Context, userID string) (string, error) {
return authority.EffectivePlatformRole(ctx, r.db, r.authority, userID)
}
func (r *postgresAccessRepository) GetOrganizationRole(ctx context.Context, organizationID, userID string) (string, bool, error) {
const query = `
SELECT role_id
FROM organization_memberships
WHERE organization_id = $1::uuid AND user_id = $2::uuid AND status = 'active'
`
var role string
if err := r.db.QueryRow(ctx, query, organizationID, userID).Scan(&role); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return "", false, nil
}
return "", false, fmt.Errorf("get organization role: %w", err)
}
return role, true, nil
}
type scanner interface {
Scan(dest ...any) error
}
func scanRemoteSession(row scanner) (*RemoteSession, error) {
item := &RemoteSession{}
var detachDeadlineAt, lastHeartbeatAt *time.Time
var metadata []byte
if err := row.Scan(
&item.ID,
&item.OrganizationID,
&item.ResourceID,
&item.Protocol,
&item.State,
&item.WorkerID,
&item.ControllerUserID,
&detachDeadlineAt,
&lastHeartbeatAt,
&item.TakeoverVersion,
&metadata,
&item.CreatedAt,
&item.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan remote session: %w", err)
}
item.DetachDeadlineAt = detachDeadlineAt
item.LastHeartbeatAt = lastHeartbeatAt
item.Metadata = metadata
item.RenderQualityProfile = renderQualityProfileFromSessionMetadata(metadata)
return item, nil
}
func scanSessionAttachment(row scanner) (*SessionAttachment, error) {
item := &SessionAttachment{}
var supersededBy, takeoverOf *string
var attachedAt, detachedAt, lastInputAt *time.Time
var metadata []byte
if err := row.Scan(
&item.ID,
&item.RemoteSessionID,
&item.UserID,
&item.DeviceID,
&item.Role,
&item.State,
&supersededBy,
&takeoverOf,
&attachedAt,
&detachedAt,
&lastInputAt,
&metadata,
&item.CreatedAt,
&item.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan session attachment: %w", err)
}
item.SupersededBy = supersededBy
item.TakeoverOf = takeoverOf
item.AttachedAt = attachedAt
item.DetachedAt = detachedAt
item.LastInputAt = lastInputAt
item.Metadata = metadata
return item, nil
}
func scanResourcePolicy(row scanner) (*ResourcePolicy, error) {
item := &ResourcePolicy{}
var detachGraceSeconds int
if err := row.Scan(
&item.ResourceID,
&item.MaxConcurrentSessions,
&item.TakeoverPolicy,
&item.RequireTrustedDevice,
&detachGraceSeconds,
&item.ClipboardEnabled,
&item.ClipboardMode,
&item.FileTransferEnabled,
&item.FileTransferMode,
&item.CreatedAt,
&item.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan resource policy: %w", err)
}
item.DetachGracePeriod = time.Duration(detachGraceSeconds) * time.Second
item.ClipboardMode = normalizeClipboardMode(item.ClipboardMode)
item.ClipboardEnabled = item.ClipboardMode != ResourceClipboardModeDisabled
item.FileTransferMode = normalizeFileTransferMode(item.FileTransferMode)
item.FileTransferEnabled = fileTransferAllowsClientToServer(item.FileTransferMode)
return item, nil
}
func jsonPayload(payload []byte) []byte {
if len(payload) == 0 {
return []byte(`{}`)
}
if json.Valid(payload) {
return payload
}
return []byte(`{}`)
}
func stringValue(value *string) string {
if value == nil {
return ""
}
return *value
}
@@ -0,0 +1,140 @@
package sessionbroker
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
type RedisLiveStateStore struct {
client *redis.Client
}
func NewRedisLiveStateStore(client *redis.Client) *RedisLiveStateStore {
return &RedisLiveStateStore{client: client}
}
func (s *RedisLiveStateStore) UpsertSession(ctx context.Context, state LiveSessionState) error {
return s.setJSON(ctx, liveSessionKey(state.SessionID), state, 0)
}
func (s *RedisLiveStateStore) GetSession(ctx context.Context, sessionID string) (*LiveSessionState, error) {
var state LiveSessionState
ok, err := s.getJSON(ctx, liveSessionKey(sessionID), &state)
if err != nil || !ok {
return nil, err
}
return &state, nil
}
func (s *RedisLiveStateStore) DeleteSession(ctx context.Context, sessionID string) error {
return s.client.Del(ctx, liveSessionKey(sessionID)).Err()
}
func (s *RedisLiveStateStore) BindController(ctx context.Context, binding sessioncontracts.ControllerBinding, ttl time.Duration) error {
return s.setJSON(ctx, controllerBindingKey(binding.SessionID), binding, ttl)
}
func (s *RedisLiveStateStore) GetControllerBinding(ctx context.Context, sessionID string) (*sessioncontracts.ControllerBinding, error) {
var binding sessioncontracts.ControllerBinding
ok, err := s.getJSON(ctx, controllerBindingKey(sessionID), &binding)
if err != nil || !ok {
return nil, err
}
return &binding, nil
}
func (s *RedisLiveStateStore) ClearControllerBinding(ctx context.Context, sessionID string) error {
return s.client.Del(ctx, controllerBindingKey(sessionID)).Err()
}
func (s *RedisLiveStateStore) StoreAttachToken(ctx context.Context, claims sessioncontracts.AttachTokenClaims, ttl time.Duration) error {
return s.setJSON(ctx, attachTokenKey(claims.Token), claims, ttl)
}
func (s *RedisLiveStateStore) ConsumeAttachToken(ctx context.Context, token string) (*sessioncontracts.AttachTokenClaims, error) {
key := attachTokenKey(token)
payload, err := s.client.GetDel(ctx, key).Result()
if err == redis.Nil {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("consume attach token: %w", err)
}
var claims sessioncontracts.AttachTokenClaims
if err := json.Unmarshal([]byte(payload), &claims); err != nil {
return nil, fmt.Errorf("decode attach token: %w", err)
}
return &claims, nil
}
func (s *RedisLiveStateStore) TouchAttachmentHeartbeat(ctx context.Context, sessionID, attachmentID string, ttl time.Duration) error {
return s.client.Set(ctx, attachmentHeartbeatKey(sessionID, attachmentID), time.Now().UTC().Format(time.RFC3339Nano), ttl).Err()
}
func (s *RedisLiveStateStore) UpdateWorkerRoute(ctx context.Context, route WorkerRoute, ttl time.Duration) error {
return s.setJSON(ctx, workerRouteKey(route.SessionID), route, ttl)
}
func (s *RedisLiveStateStore) GetWorkerRoute(ctx context.Context, sessionID string) (*WorkerRoute, error) {
var route WorkerRoute
ok, err := s.getJSON(ctx, workerRouteKey(sessionID), &route)
if err != nil || !ok {
return nil, err
}
return &route, nil
}
func (s *RedisLiveStateStore) DeleteWorkerRoute(ctx context.Context, sessionID string) error {
return s.client.Del(ctx, workerRouteKey(sessionID)).Err()
}
func (s *RedisLiveStateStore) setJSON(ctx context.Context, key string, value any, ttl time.Duration) error {
payload, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("encode redis payload: %w", err)
}
if err := s.client.Set(ctx, key, payload, ttl).Err(); err != nil {
return fmt.Errorf("set redis key %s: %w", key, err)
}
return nil
}
func (s *RedisLiveStateStore) getJSON(ctx context.Context, key string, dest any) (bool, error) {
payload, err := s.client.Get(ctx, key).Result()
if err == redis.Nil {
return false, nil
}
if err != nil {
return false, fmt.Errorf("get redis key %s: %w", key, err)
}
if err := json.Unmarshal([]byte(payload), dest); err != nil {
return false, fmt.Errorf("decode redis key %s: %w", key, err)
}
return true, nil
}
func liveSessionKey(sessionID string) string {
return "live:session:" + sessionID
}
func controllerBindingKey(sessionID string) string {
return "live:session:" + sessionID + ":controller"
}
func attachTokenKey(token string) string {
return "live:attach:" + token
}
func attachmentHeartbeatKey(sessionID, attachmentID string) string {
return "live:session:" + sessionID + ":attachment:" + attachmentID + ":heartbeat"
}
func workerRouteKey(sessionID string) string {
return "live:session:" + sessionID + ":worker-route"
}
@@ -0,0 +1,85 @@
package sessionbroker
import (
"context"
"time"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
type RemoteSessionRepository interface {
Create(ctx context.Context, session RemoteSession) error
GetByID(ctx context.Context, sessionID string) (*RemoteSession, error)
GetByIDForUpdate(ctx context.Context, sessionID string) (*RemoteSession, error)
ListByController(ctx context.Context, userID string) ([]RemoteSession, error)
CountLiveByResource(ctx context.Context, resourceID string) (int, error)
ListDetachedExpired(ctx context.Context, before time.Time, limit int) ([]RemoteSession, error)
UpdateState(ctx context.Context, params UpdateRemoteSessionStateParams) error
}
type SessionAttachmentRepository interface {
Create(ctx context.Context, attachment SessionAttachment) error
GetByID(ctx context.Context, attachmentID string) (*SessionAttachment, error)
GetByIDForUpdate(ctx context.Context, attachmentID string) (*SessionAttachment, error)
ListByRemoteSession(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error)
ListActiveByRemoteSessionForUpdate(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error)
UpdateState(ctx context.Context, params UpdateSessionAttachmentStateParams) error
Supersede(ctx context.Context, params SupersedeAttachmentParams) error
}
type ResourcePolicyRepository interface {
GetByResourceID(ctx context.Context, resourceID string) (*ResourcePolicy, error)
Upsert(ctx context.Context, policy ResourcePolicy) error
}
type AuditEventRepository interface {
Create(ctx context.Context, event AuditEvent) error
}
type Store interface {
RemoteSessions() RemoteSessionRepository
SessionAttachments() SessionAttachmentRepository
ResourcePolicies() ResourcePolicyRepository
ResourceRuntime() ResourceRuntimeRepository
AuditEvents() AuditEventRepository
Access() AccessRepository
}
type Transactor interface {
WithinTransaction(ctx context.Context, fn func(store Store) error) error
}
type UpdateRemoteSessionStateParams struct {
RemoteSessionID string
State sessioncontracts.State
WorkerID string
DetachDeadlineAt *time.Time
LastHeartbeatAt *time.Time
TakeoverVersion int
UpdatedAt time.Time
}
type UpdateSessionAttachmentStateParams struct {
AttachmentID string
State AttachmentState
DetachedAt *time.Time
LastInputAt *time.Time
UpdatedAt time.Time
}
type SupersedeAttachmentParams struct {
PreviousAttachmentID string
NextAttachmentID string
DetachedAt time.Time
UpdatedAt time.Time
}
type AccessRepository interface {
IsTrustedDevice(ctx context.Context, userID, deviceID string) (bool, error)
GetPlatformRole(ctx context.Context, userID string) (string, error)
GetOrganizationRole(ctx context.Context, organizationID, userID string) (string, bool, error)
}
type ResourceRuntimeRepository interface {
GetByID(ctx context.Context, resourceID string) (*ResourceRuntimeSpec, error)
}
@@ -0,0 +1,138 @@
package sessionbroker
import (
"context"
"encoding/json"
"errors"
"testing"
"github.com/example/remote-access-platform/backend/internal/platform/config"
"github.com/example/remote-access-platform/backend/internal/platform/module"
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
)
type fakeSecretResolver struct {
response *secrets.ResolvedResourceSecret
err error
request secrets.ResolveResourceSecretRequest
}
func testAppConfig(env string) config.AppConfig {
return config.AppConfig{Name: "rap-api-test", Env: env}
}
func (r *fakeSecretResolver) ResolveForSession(_ context.Context, req secrets.ResolveResourceSecretRequest) (*secrets.ResolvedResourceSecret, error) {
r.request = req
if r.err != nil {
return nil, r.err
}
return r.response, nil
}
func TestRuntimeAssignmentMetadataMergesResolvedSecretWithoutMutatingSessionMetadata(t *testing.T) {
resolver := &fakeSecretResolver{
response: &secrets.ResolvedResourceSecret{
Descriptor: secrets.ResourceSecretDescriptor{Version: 3},
Payload: json.RawMessage(`{"username":"user","password":"secret","domain":"corp"}`),
},
}
service := NewService(module.Dependencies{
Config: module.Config{App: testAppConfig("production")},
}, nil, nil, nil, nil, resolver)
sessionMetadata := mustJSON(t, map[string]any{
"resource": map[string]any{
"id": "resource-1",
"organization_id": "org-1",
"secret_ref": "rap-secret://org/org-1/resources/resource-1/primary",
"metadata": map[string]any{
"rdp_host": "host",
},
},
})
session := RemoteSession{
ID: "session-1",
OrganizationID: "org-1",
ResourceID: "resource-1",
WorkerID: "worker-1",
Metadata: sessionMetadata,
}
metadata, secretRef, version, err := service.runtimeAssignmentMetadata(context.Background(), session, &workercontracts.WorkerLease{LeaseID: "lease-1"})
if err != nil {
t.Fatalf("runtimeAssignmentMetadata returned error: %v", err)
}
if secretRef == "" || version != 3 {
t.Fatalf("expected secret ref and version, got ref=%q version=%d", secretRef, version)
}
resource := metadata["resource"].(map[string]any)
resourceMetadata := resource["metadata"].(map[string]any)
if resourceMetadata["username"] != "user" || resourceMetadata["password"] != "secret" || resourceMetadata["domain"] != "corp" {
t.Fatalf("resolved secret was not merged: %#v", resourceMetadata)
}
var persisted map[string]any
if err := json.Unmarshal(session.Metadata, &persisted); err != nil {
t.Fatalf("decode persisted metadata: %v", err)
}
persistedResource := persisted["resource"].(map[string]any)
persistedMetadata := persistedResource["metadata"].(map[string]any)
if _, ok := persistedMetadata["password"]; ok {
t.Fatalf("session metadata was mutated with plaintext secret")
}
if resolver.request.LeaseID != "lease-1" || resolver.request.WorkerID != "worker-1" {
t.Fatalf("resolver request missed lease/worker proof: %#v", resolver.request)
}
}
func TestRuntimeAssignmentMetadataRequiresResolverInProduction(t *testing.T) {
service := NewService(module.Dependencies{
Config: module.Config{App: testAppConfig("production")},
}, nil, nil, nil, nil)
session := RemoteSession{
ID: "session-1",
OrganizationID: "org-1",
ResourceID: "resource-1",
WorkerID: "worker-1",
Metadata: mustJSON(t, map[string]any{
"resource": map[string]any{
"secret_ref": "rap-secret://org/org-1/resources/resource-1/primary",
},
}),
}
_, _, _, err := service.runtimeAssignmentMetadata(context.Background(), session, &workercontracts.WorkerLease{LeaseID: "lease-1"})
if !errors.Is(err, secrets.ErrSecretEncryptionKeyMissing) {
t.Fatalf("expected missing resolver error, got %v", err)
}
}
func TestRuntimeAssignmentMetadataAllowsDevelopmentMetadataWithoutResolver(t *testing.T) {
service := NewService(module.Dependencies{
Config: module.Config{App: testAppConfig("development")},
}, nil, nil, nil, nil)
session := RemoteSession{
ID: "session-1",
OrganizationID: "org-1",
ResourceID: "resource-1",
WorkerID: "worker-1",
Metadata: mustJSON(t, map[string]any{
"resource": map[string]any{
"secret_ref": "rap-secret://org/org-1/resources/resource-1/primary",
"metadata": map[string]any{
"username": "dev-user",
"password": "dev-password",
},
},
}),
}
metadata, secretRef, _, err := service.runtimeAssignmentMetadata(context.Background(), session, nil)
if err != nil {
t.Fatalf("development metadata should not require resolver: %v", err)
}
if secretRef != "" {
t.Fatalf("development fallback should not audit resolver use, got %q", secretRef)
}
resource := metadata["resource"].(map[string]any)
resourceMetadata := resource["metadata"].(map[string]any)
if resourceMetadata["password"] != "dev-password" {
t.Fatalf("development metadata was not preserved")
}
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,391 @@
package sessionbroker
import (
"context"
"io"
"log/slog"
"testing"
"time"
"github.com/example/remote-access-platform/backend/internal/platform/module"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
)
func TestHandleWorkerConnectedIgnoresTerminalSession(t *testing.T) {
service, store, live, _ := newStaleWorkerEventTestService()
store.remote.sessions["session-1"] = RemoteSession{
ID: "session-1",
State: sessioncontracts.StateTerminated,
WorkerID: "worker-1",
}
if err := service.HandleWorkerConnected(context.Background(), "session-1"); err != nil {
t.Fatalf("HandleWorkerConnected returned error for stale terminal event: %v", err)
}
if got := store.remote.sessions["session-1"].State; got != sessioncontracts.StateTerminated {
t.Fatalf("stale connected event changed terminal state to %q", got)
}
if store.remote.updateCount != 0 {
t.Fatalf("stale connected event updated authoritative session %d times", store.remote.updateCount)
}
if live.upsertCount != 0 {
t.Fatalf("stale connected event recreated live state %d times", live.upsertCount)
}
}
func TestUpdateWorkerRenderTelemetryIgnoresTerminalSession(t *testing.T) {
service, store, live, _ := newStaleWorkerEventTestService()
store.remote.sessions["session-1"] = RemoteSession{
ID: "session-1",
State: sessioncontracts.StateTerminated,
WorkerID: "worker-1",
}
err := service.UpdateWorkerRenderTelemetry(context.Background(), "session-1", map[string]any{
"render_state": "ready",
"width": 1280,
"height": 720,
"frame_sequence": int64(99),
"frame_data": "stale-frame",
})
if err != nil {
t.Fatalf("UpdateWorkerRenderTelemetry returned error for stale terminal event: %v", err)
}
if live.upsertCount != 0 {
t.Fatalf("stale render event recreated live state %d times", live.upsertCount)
}
if live.sessions["session-1"] != nil {
t.Fatalf("stale render event left live state behind: %#v", live.sessions["session-1"])
}
}
func TestMarkSessionFailedTransitionsActiveSession(t *testing.T) {
service, store, live, orchestrator := newStaleWorkerEventTestService()
store.remote.sessions["session-1"] = RemoteSession{
ID: "session-1",
State: sessioncontracts.StateActive,
WorkerID: "worker-1",
TakeoverVersion: 3,
}
store.attachments.items["attachment-1"] = SessionAttachment{
ID: "attachment-1",
RemoteSessionID: "session-1",
State: AttachmentStateActive,
}
live.sessions["session-1"] = &LiveSessionState{SessionID: "session-1", State: sessioncontracts.StateActive}
if err := service.MarkSessionFailed(context.Background(), MarkSessionFailedCommand{SessionID: "session-1", Reason: "worker_lost"}); err != nil {
t.Fatalf("MarkSessionFailed returned error: %v", err)
}
if got := store.remote.sessions["session-1"].State; got != sessioncontracts.StateFailed {
t.Fatalf("expected failed state, got %q", got)
}
if got := store.attachments.items["attachment-1"].State; got != AttachmentStateClosed {
t.Fatalf("expected attachment closed, got %q", got)
}
if store.audit.createCount != 1 {
t.Fatalf("expected one audit event, got %d", store.audit.createCount)
}
if live.sessions["session-1"] != nil {
t.Fatal("expected failed session live state to be deleted")
}
if orchestrator.releaseCount != 1 {
t.Fatalf("expected session lease release, got %d", orchestrator.releaseCount)
}
}
func TestMarkSessionFailedAlreadyFailedIsIdempotent(t *testing.T) {
service, store, _, _ := newStaleWorkerEventTestService()
store.remote.sessions["session-1"] = RemoteSession{
ID: "session-1",
State: sessioncontracts.StateFailed,
WorkerID: "worker-1",
TakeoverVersion: 1,
}
if err := service.MarkSessionFailed(context.Background(), MarkSessionFailedCommand{SessionID: "session-1", Reason: "duplicate_worker_failure"}); err != nil {
t.Fatalf("duplicate MarkSessionFailed returned error: %v", err)
}
if got := store.remote.sessions["session-1"].State; got != sessioncontracts.StateFailed {
t.Fatalf("duplicate terminal event changed state to %q", got)
}
}
func newStaleWorkerEventTestService() (*Service, *staleWorkerEventTestStore, *staleWorkerEventLiveState, *staleWorkerEventOrchestrator) {
store := &staleWorkerEventTestStore{
remote: &staleWorkerEventRemoteSessions{sessions: map[string]RemoteSession{}},
attachments: &staleWorkerEventAttachments{items: map[string]SessionAttachment{}},
policies: &staleWorkerEventPolicies{},
audit: &staleWorkerEventAudit{},
}
live := &staleWorkerEventLiveState{sessions: map[string]*LiveSessionState{}}
orchestrator := &staleWorkerEventOrchestrator{}
service := NewService(module.Dependencies{
Infra: module.Infra{Logger: slog.New(slog.NewTextHandler(io.Discard, nil))},
}, store, staleWorkerEventTransactor{store: store}, live, orchestrator)
service.now = func() time.Time { return time.Unix(100, 0).UTC() }
return service, store, live, orchestrator
}
type staleWorkerEventTransactor struct {
store Store
}
func (t staleWorkerEventTransactor) WithinTransaction(ctx context.Context, fn func(store Store) error) error {
return fn(t.store)
}
type staleWorkerEventTestStore struct {
remote *staleWorkerEventRemoteSessions
attachments *staleWorkerEventAttachments
policies *staleWorkerEventPolicies
audit *staleWorkerEventAudit
}
func (s *staleWorkerEventTestStore) RemoteSessions() RemoteSessionRepository { return s.remote }
func (s *staleWorkerEventTestStore) SessionAttachments() SessionAttachmentRepository {
return s.attachments
}
func (s *staleWorkerEventTestStore) ResourcePolicies() ResourcePolicyRepository { return s.policies }
func (s *staleWorkerEventTestStore) ResourceRuntime() ResourceRuntimeRepository {
return staleWorkerEventResourceRuntime{}
}
func (s *staleWorkerEventTestStore) AuditEvents() AuditEventRepository { return s.audit }
func (s *staleWorkerEventTestStore) Access() AccessRepository { return staleWorkerEventAccess{} }
type staleWorkerEventRemoteSessions struct {
sessions map[string]RemoteSession
updateCount int
}
func (r *staleWorkerEventRemoteSessions) Create(_ context.Context, session RemoteSession) error {
r.sessions[session.ID] = session
return nil
}
func (r *staleWorkerEventRemoteSessions) GetByID(_ context.Context, sessionID string) (*RemoteSession, error) {
session, ok := r.sessions[sessionID]
if !ok {
return nil, nil
}
return &session, nil
}
func (r *staleWorkerEventRemoteSessions) GetByIDForUpdate(ctx context.Context, sessionID string) (*RemoteSession, error) {
return r.GetByID(ctx, sessionID)
}
func (r *staleWorkerEventRemoteSessions) ListByController(_ context.Context, _ string) ([]RemoteSession, error) {
return nil, nil
}
func (r *staleWorkerEventRemoteSessions) CountLiveByResource(_ context.Context, _ string) (int, error) {
return 0, nil
}
func (r *staleWorkerEventRemoteSessions) ListDetachedExpired(_ context.Context, _ time.Time, _ int) ([]RemoteSession, error) {
return nil, nil
}
func (r *staleWorkerEventRemoteSessions) UpdateState(_ context.Context, params UpdateRemoteSessionStateParams) error {
session := r.sessions[params.RemoteSessionID]
session.State = params.State
session.WorkerID = params.WorkerID
session.DetachDeadlineAt = params.DetachDeadlineAt
session.LastHeartbeatAt = params.LastHeartbeatAt
session.TakeoverVersion = params.TakeoverVersion
session.UpdatedAt = params.UpdatedAt
r.sessions[params.RemoteSessionID] = session
r.updateCount++
return nil
}
type staleWorkerEventAttachments struct {
items map[string]SessionAttachment
}
func (r *staleWorkerEventAttachments) Create(_ context.Context, attachment SessionAttachment) error {
r.items[attachment.ID] = attachment
return nil
}
func (r *staleWorkerEventAttachments) GetByID(_ context.Context, attachmentID string) (*SessionAttachment, error) {
attachment, ok := r.items[attachmentID]
if !ok {
return nil, nil
}
return &attachment, nil
}
func (r *staleWorkerEventAttachments) GetByIDForUpdate(ctx context.Context, attachmentID string) (*SessionAttachment, error) {
return r.GetByID(ctx, attachmentID)
}
func (r *staleWorkerEventAttachments) ListByRemoteSession(_ context.Context, remoteSessionID string) ([]SessionAttachment, error) {
attachments := make([]SessionAttachment, 0)
for _, attachment := range r.items {
if attachment.RemoteSessionID == remoteSessionID {
attachments = append(attachments, attachment)
}
}
return attachments, nil
}
func (r *staleWorkerEventAttachments) ListActiveByRemoteSessionForUpdate(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error) {
return r.ListByRemoteSession(ctx, remoteSessionID)
}
func (r *staleWorkerEventAttachments) UpdateState(_ context.Context, params UpdateSessionAttachmentStateParams) error {
attachment := r.items[params.AttachmentID]
attachment.State = params.State
attachment.DetachedAt = params.DetachedAt
attachment.LastInputAt = params.LastInputAt
attachment.UpdatedAt = params.UpdatedAt
r.items[params.AttachmentID] = attachment
return nil
}
func (r *staleWorkerEventAttachments) Supersede(_ context.Context, params SupersedeAttachmentParams) error {
attachment := r.items[params.PreviousAttachmentID]
attachment.State = AttachmentStateSuperseded
attachment.SupersededBy = &params.NextAttachmentID
attachment.DetachedAt = &params.DetachedAt
attachment.UpdatedAt = params.UpdatedAt
r.items[params.PreviousAttachmentID] = attachment
return nil
}
type staleWorkerEventPolicies struct{}
func (r *staleWorkerEventPolicies) GetByResourceID(_ context.Context, _ string) (*ResourcePolicy, error) {
return nil, nil
}
func (r *staleWorkerEventPolicies) Upsert(_ context.Context, _ ResourcePolicy) error {
return nil
}
type staleWorkerEventAudit struct {
createCount int
}
func (r *staleWorkerEventAudit) Create(_ context.Context, _ AuditEvent) error {
r.createCount++
return nil
}
type staleWorkerEventResourceRuntime struct{}
func (staleWorkerEventResourceRuntime) GetByID(_ context.Context, _ string) (*ResourceRuntimeSpec, error) {
return nil, nil
}
type staleWorkerEventAccess struct{}
func (staleWorkerEventAccess) IsTrustedDevice(_ context.Context, _, _ string) (bool, error) {
return false, nil
}
func (staleWorkerEventAccess) GetPlatformRole(_ context.Context, _ string) (string, error) {
return "", nil
}
func (staleWorkerEventAccess) GetOrganizationRole(_ context.Context, _, _ string) (string, bool, error) {
return "", false, nil
}
type staleWorkerEventLiveState struct {
sessions map[string]*LiveSessionState
upsertCount int
}
func (s *staleWorkerEventLiveState) UpsertSession(_ context.Context, state LiveSessionState) error {
copied := state
s.sessions[state.SessionID] = &copied
s.upsertCount++
return nil
}
func (s *staleWorkerEventLiveState) GetSession(_ context.Context, sessionID string) (*LiveSessionState, error) {
state := s.sessions[sessionID]
if state == nil {
return nil, nil
}
copied := *state
return &copied, nil
}
func (s *staleWorkerEventLiveState) DeleteSession(_ context.Context, sessionID string) error {
delete(s.sessions, sessionID)
return nil
}
func (s *staleWorkerEventLiveState) BindController(_ context.Context, _ sessioncontracts.ControllerBinding, _ time.Duration) error {
return nil
}
func (s *staleWorkerEventLiveState) GetControllerBinding(_ context.Context, _ string) (*sessioncontracts.ControllerBinding, error) {
return nil, nil
}
func (s *staleWorkerEventLiveState) ClearControllerBinding(_ context.Context, _ string) error {
return nil
}
func (s *staleWorkerEventLiveState) StoreAttachToken(_ context.Context, _ sessioncontracts.AttachTokenClaims, _ time.Duration) error {
return nil
}
func (s *staleWorkerEventLiveState) ConsumeAttachToken(_ context.Context, _ string) (*sessioncontracts.AttachTokenClaims, error) {
return nil, nil
}
func (s *staleWorkerEventLiveState) TouchAttachmentHeartbeat(_ context.Context, _, _ string, _ time.Duration) error {
return nil
}
func (s *staleWorkerEventLiveState) UpdateWorkerRoute(_ context.Context, _ WorkerRoute, _ time.Duration) error {
return nil
}
func (s *staleWorkerEventLiveState) GetWorkerRoute(_ context.Context, _ string) (*WorkerRoute, error) {
return nil, nil
}
func (s *staleWorkerEventLiveState) DeleteWorkerRoute(_ context.Context, _ string) error {
return nil
}
type staleWorkerEventOrchestrator struct {
releaseCount int
}
func (o *staleWorkerEventOrchestrator) Reserve(_ context.Context, _ workercontracts.AttachRequest) (*workercontracts.WorkerLease, error) {
return nil, nil
}
func (o *staleWorkerEventOrchestrator) GetSessionLease(_ context.Context, _ string) (*workercontracts.WorkerLease, error) {
return nil, nil
}
func (o *staleWorkerEventOrchestrator) ReleaseSessionLease(_ context.Context, _ string) error {
o.releaseCount++
return nil
}
func (o *staleWorkerEventOrchestrator) PrepareAttachment(_ context.Context, _ RemoteSession, _ SessionAttachment, _ map[string]any) error {
return nil
}
func (o *staleWorkerEventOrchestrator) NotifyDetachment(_ context.Context, _ RemoteSession, _ SessionAttachment) error {
return nil
}
func (o *staleWorkerEventOrchestrator) TerminateRemoteSession(_ context.Context, _, _ string) error {
return nil
}
func (o *staleWorkerEventOrchestrator) ValidateSessionRuntime(_ context.Context, _, _ string) (bool, string, error) {
return true, "", nil
}
@@ -0,0 +1,44 @@
package sessionbroker
import (
"fmt"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
var allowedTransitions = map[sessioncontracts.State]map[sessioncontracts.State]struct{}{
sessioncontracts.StateStarting: {
sessioncontracts.StateActive: {},
sessioncontracts.StateFailed: {},
sessioncontracts.StateTerminated: {},
},
sessioncontracts.StateActive: {
sessioncontracts.StateDetached: {},
sessioncontracts.StateReconnecting: {},
sessioncontracts.StateFailed: {},
sessioncontracts.StateTerminated: {},
},
sessioncontracts.StateDetached: {
sessioncontracts.StateReconnecting: {},
sessioncontracts.StateTerminated: {},
sessioncontracts.StateFailed: {},
},
sessioncontracts.StateReconnecting: {
sessioncontracts.StateActive: {},
sessioncontracts.StateDetached: {},
sessioncontracts.StateFailed: {},
sessioncontracts.StateTerminated: {},
},
}
func validateTransition(from, to sessioncontracts.State) error {
if from == to {
return nil
}
if allowed, ok := allowedTransitions[from]; ok {
if _, ok := allowed[to]; ok {
return nil
}
}
return fmt.Errorf("invalid session state transition: %s -> %s", from, to)
}
File diff suppressed because it is too large Load Diff
+29
View File
@@ -0,0 +1,29 @@
package worker
type SessionEvent struct {
Type string `json:"type"`
SessionID string `json:"session_id"`
WorkerID string `json:"worker_id"`
Payload map[string]any `json:"payload,omitempty"`
}
const (
SessionEventConnected = "session_connected"
SessionEventHeartbeat = "session_heartbeat"
SessionEventFailed = "session_failed"
SessionEventTerminated = "session_terminated"
SessionEventDisplayReady = "session_display_ready"
SessionEventRenderReady = "session_render_ready"
SessionEventRenderDirty = "session_render_dirty"
SessionEventRenderResized = "session_render_resized"
SessionEventCursorUpdated = "session_cursor_updated"
SessionEventFrame = "session_frame"
SessionEventClipboardText = "session_clipboard_text"
SessionEventFileUploaded = "session_file_upload_completed"
SessionEventFileDownloadAvailable = "session_file_download_available"
SessionEventFileDownloadChunk = "session_file_download_chunk"
SessionEventFileDownloadProgress = "session_file_download_progress"
SessionEventFileDownloadCompleted = "session_file_download_completed"
SessionEventFileDownloadFailed = "session_file_download_failed"
SessionEventFileDownloadBlocked = "session_file_download_blocked"
)
@@ -0,0 +1,52 @@
package worker
import (
"context"
"errors"
"time"
"github.com/example/remote-access-platform/backend/internal/modules/sessionbroker"
)
type LeaseMonitor struct {
service *Service
broker *sessionbroker.Service
interval time.Duration
}
func NewLeaseMonitor(service *Service, broker *sessionbroker.Service, interval time.Duration) *LeaseMonitor {
if interval <= 0 {
interval = 15 * time.Second
}
return &LeaseMonitor{
service: service,
broker: broker,
interval: interval,
}
}
func (m *LeaseMonitor) Run(ctx context.Context) error {
ticker := time.NewTicker(m.interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil
case <-ticker.C:
stale, err := m.service.RecoverStaleLeases(ctx)
if err != nil {
return err
}
for _, lease := range stale {
err := m.broker.MarkSessionFailed(ctx, sessionbroker.MarkSessionFailedCommand{
SessionID: lease.SessionID,
Reason: "worker_lease_stale_or_worker_missing",
})
if err != nil && !errors.Is(err, sessionbroker.ErrSessionNotFound) && !errors.Is(err, sessionbroker.ErrSessionNotTerminable) {
return err
}
}
}
}
}
@@ -0,0 +1,153 @@
package worker
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"time"
"github.com/redis/go-redis/v9"
"github.com/example/remote-access-platform/backend/internal/modules/sessionbroker"
)
type EventProcessor struct {
client *redis.Client
broker *sessionbroker.Service
}
func NewEventProcessor(client *redis.Client, broker *sessionbroker.Service) *EventProcessor {
return &EventProcessor{client: client, broker: broker}
}
func (p *EventProcessor) Run(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return nil
default:
}
result, err := p.client.BLPop(ctx, 5*time.Second, "worker:events").Result()
if err == redis.Nil {
continue
}
if err != nil {
if ctx.Err() != nil {
return nil
}
return fmt.Errorf("consume worker event: %w", err)
}
if len(result) != 2 {
continue
}
var event SessionEvent
if err := json.Unmarshal([]byte(result[1]), &event); err != nil {
continue
}
if err := p.handleEvent(ctx, event); err != nil {
return err
}
}
}
func (p *EventProcessor) handleEvent(ctx context.Context, event SessionEvent) error {
switch event.Type {
case SessionEventConnected, SessionEventDisplayReady:
if err := p.broker.HandleWorkerConnected(ctx, event.SessionID); err != nil {
return err
}
if len(event.Payload) > 0 {
if err := p.broker.UpdateWorkerRenderTelemetry(ctx, event.SessionID, event.Payload); err != nil && !errors.Is(err, sessionbroker.ErrSessionNotFound) {
return err
}
}
return nil
case SessionEventHeartbeat:
return p.broker.HandleWorkerHeartbeat(ctx, event.SessionID)
case SessionEventRenderReady, SessionEventRenderDirty, SessionEventRenderResized, SessionEventCursorUpdated, SessionEventFrame:
if len(event.Payload) == 0 {
return nil
}
if correlationID, _ := event.Payload["input_correlation_id"].(string); correlationID != "" {
slog.Info("worker frame event received",
"session_id", event.SessionID,
"worker_id", event.WorkerID,
"frame_sequence", event.Payload["frame_sequence"],
"correlation_id", correlationID,
"worker_frame_captured_at", event.Payload["worker_frame_captured_at"],
"trace_stage", "backend_frame_receive")
}
return p.updateRenderTelemetryWithRetry(ctx, event.SessionID, event.Payload)
case SessionEventClipboardText:
if len(event.Payload) == 0 {
return nil
}
slog.Info("worker clipboard event received",
"session_id", event.SessionID,
"worker_id", event.WorkerID,
"origin", event.Payload["origin"],
"sequence_id", event.Payload["sequence_id"],
"content_hash", event.Payload["content_hash"])
return p.broker.UpdateWorkerClipboardText(ctx, event.SessionID, event.Payload)
case SessionEventFileUploaded:
slog.Info("worker file upload completed",
"session_id", event.SessionID,
"worker_id", event.WorkerID,
"transfer_id", event.Payload["transfer_id"],
"file_name", event.Payload["file_name"],
"file_size", event.Payload["file_size"],
"content_hash", event.Payload["content_hash"],
"storage_path", event.Payload["storage_path"])
return nil
case SessionEventFileDownloadAvailable, SessionEventFileDownloadChunk, SessionEventFileDownloadProgress,
SessionEventFileDownloadCompleted, SessionEventFileDownloadFailed, SessionEventFileDownloadBlocked:
slog.Info("worker file download event received",
"session_id", event.SessionID,
"worker_id", event.WorkerID,
"event_type", event.Type,
"transfer_id", event.Payload["transfer_id"],
"file_id", event.Payload["file_id"],
"file_name", event.Payload["file_name"],
"status", event.Payload["status"])
return p.broker.UpdateWorkerFileDownloadEvent(ctx, event.SessionID, event.Type, event.Payload)
case SessionEventFailed:
reason, _ := event.Payload["reason"].(string)
err := p.broker.MarkSessionFailed(ctx, sessionbroker.MarkSessionFailedCommand{
SessionID: event.SessionID,
Reason: reason,
})
if errors.Is(err, sessionbroker.ErrSessionNotFound) || errors.Is(err, sessionbroker.ErrSessionNotTerminable) {
return nil
}
return err
case SessionEventTerminated:
reason, _ := event.Payload["reason"].(string)
err := p.broker.TerminateSession(ctx, sessionbroker.TerminateSessionCommand{
SessionID: event.SessionID,
Reason: reason,
})
if errors.Is(err, sessionbroker.ErrSessionNotFound) || errors.Is(err, sessionbroker.ErrSessionNotTerminable) {
return nil
}
return err
default:
return nil
}
}
func (p *EventProcessor) updateRenderTelemetryWithRetry(ctx context.Context, sessionID string, payload map[string]any) error {
var lastErr error
for attempt := 0; attempt < 10; attempt++ {
err := p.broker.UpdateWorkerRenderTelemetry(ctx, sessionID, payload)
if err == nil || errors.Is(err, sessionbroker.ErrSessionNotFound) {
return nil
}
lastErr = err
time.Sleep(100 * time.Millisecond)
}
return lastErr
}
@@ -0,0 +1,264 @@
package worker
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"time"
"github.com/redis/go-redis/v9"
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
)
type RedisStore struct {
client *redis.Client
}
func NewRedisStore(client *redis.Client) *RedisStore {
return &RedisStore{client: client}
}
func (s *RedisStore) RegisterWorker(ctx context.Context, registration workercontracts.WorkerRegistration, ttl time.Duration) error {
payload, err := json.Marshal(registration)
if err != nil {
return fmt.Errorf("marshal worker registration: %w", err)
}
pipe := s.client.TxPipeline()
pipe.Set(ctx, workerKey(registration.WorkerID), payload, ttl)
pipe.SAdd(ctx, workerSetKey(), registration.WorkerID)
_, err = pipe.Exec(ctx)
if err != nil {
return fmt.Errorf("register worker: %w", err)
}
return nil
}
func (s *RedisStore) TouchWorkerHeartbeat(ctx context.Context, heartbeat workercontracts.WorkerHeartbeat, ttl time.Duration) error {
registration, err := s.GetWorker(ctx, heartbeat.WorkerID)
if err != nil {
return err
}
if registration == nil {
registration = &workercontracts.WorkerRegistration{
WorkerID: heartbeat.WorkerID,
Protocol: workercontracts.ProtocolRDP,
}
}
registration.Status = heartbeat.Status
registration.LastHeartbeatAt = heartbeat.LastHeartbeatAt
return s.RegisterWorker(ctx, *registration, ttl)
}
func (s *RedisStore) ListWorkers(ctx context.Context) ([]workercontracts.WorkerRegistration, error) {
ids, err := s.client.SMembers(ctx, workerSetKey()).Result()
if err != nil {
return nil, fmt.Errorf("list worker ids: %w", err)
}
workers := make([]workercontracts.WorkerRegistration, 0, len(ids))
for _, id := range ids {
worker, err := s.GetWorker(ctx, id)
if err != nil {
return nil, err
}
if worker != nil {
workers = append(workers, *worker)
}
}
return workers, nil
}
func (s *RedisStore) GetWorker(ctx context.Context, workerID string) (*workercontracts.WorkerRegistration, error) {
payload, err := s.client.Get(ctx, workerKey(workerID)).Result()
if err == redis.Nil {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("get worker: %w", err)
}
var registration workercontracts.WorkerRegistration
if err := json.Unmarshal([]byte(payload), &registration); err != nil {
return nil, fmt.Errorf("decode worker registration: %w", err)
}
return &registration, nil
}
func (s *RedisStore) AcquireLease(ctx context.Context, lease workercontracts.WorkerLease, ttl time.Duration) error {
payload, err := json.Marshal(lease)
if err != nil {
return fmt.Errorf("marshal lease: %w", err)
}
ok, err := s.client.SetNX(ctx, leaseKey(lease.LeaseID), payload, ttl).Result()
if err != nil {
return fmt.Errorf("acquire lease: %w", err)
}
if !ok {
return fmt.Errorf("lease already exists")
}
pipe := s.client.TxPipeline()
pipe.SAdd(ctx, leaseSetKey(), lease.LeaseID)
pipe.Set(ctx, sessionLeaseKey(lease.SessionID), lease.LeaseID, ttl)
_, err = pipe.Exec(ctx)
if err != nil {
return fmt.Errorf("index lease: %w", err)
}
return nil
}
func (s *RedisStore) GetLease(ctx context.Context, leaseID string) (*workercontracts.WorkerLease, error) {
payload, err := s.client.Get(ctx, leaseKey(leaseID)).Result()
if err == redis.Nil {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("get lease: %w", err)
}
var lease workercontracts.WorkerLease
if err := json.Unmarshal([]byte(payload), &lease); err != nil {
return nil, fmt.Errorf("decode lease: %w", err)
}
return &lease, nil
}
func (s *RedisStore) GetLeaseBySession(ctx context.Context, sessionID string) (*workercontracts.WorkerLease, error) {
leaseID, err := s.client.Get(ctx, sessionLeaseKey(sessionID)).Result()
if err == redis.Nil {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("get lease by session: %w", err)
}
return s.GetLease(ctx, leaseID)
}
func (s *RedisStore) RenewLease(ctx context.Context, lease workercontracts.WorkerLease, ttl time.Duration) error {
payload, err := json.Marshal(lease)
if err != nil {
return fmt.Errorf("marshal lease renewal: %w", err)
}
pipe := s.client.TxPipeline()
pipe.Set(ctx, leaseKey(lease.LeaseID), payload, ttl)
pipe.Set(ctx, sessionLeaseKey(lease.SessionID), lease.LeaseID, ttl)
_, err = pipe.Exec(ctx)
if err != nil {
return fmt.Errorf("renew lease: %w", err)
}
return nil
}
func (s *RedisStore) ReleaseLease(ctx context.Context, leaseID string) error {
lease, err := s.GetLease(ctx, leaseID)
if err != nil {
return err
}
pipe := s.client.TxPipeline()
pipe.Del(ctx, leaseKey(leaseID))
pipe.SRem(ctx, leaseSetKey(), leaseID)
if lease != nil {
pipe.Del(ctx, sessionLeaseKey(lease.SessionID))
}
_, err = pipe.Exec(ctx)
if err != nil {
return fmt.Errorf("release lease: %w", err)
}
return nil
}
func (s *RedisStore) ListLeases(ctx context.Context) ([]workercontracts.WorkerLease, error) {
ids, err := s.client.SMembers(ctx, leaseSetKey()).Result()
if err != nil {
return nil, fmt.Errorf("list lease ids: %w", err)
}
leases := make([]workercontracts.WorkerLease, 0, len(ids))
for _, id := range ids {
lease, err := s.GetLease(ctx, id)
if err != nil {
return nil, err
}
if lease != nil {
leases = append(leases, *lease)
}
}
return leases, nil
}
func (s *RedisStore) AppendEnvelope(ctx context.Context, envelope workercontracts.RoutedEnvelope) error {
payload, err := json.Marshal(envelope)
if err != nil {
return fmt.Errorf("marshal routed envelope: %w", err)
}
key := workerQueueKey(envelope.SessionID)
if err := s.client.RPush(ctx, key, payload).Err(); err != nil {
return fmt.Errorf("append routed envelope: %w", err)
}
if envelope.Type == "input" {
correlationID, _ := envelope.Payload["correlation_id"].(string)
if correlationID != "" {
if length, err := s.client.LLen(ctx, key).Result(); err == nil {
slog.Info("worker queue length after input append",
"session_id", envelope.SessionID,
"attachment_id", envelope.AttachmentID,
"correlation_id", correlationID,
"queue_key", key,
"queue_length", length,
"trace_stage", "redis_queue_append")
}
}
}
return s.client.Expire(ctx, key, 10*time.Minute).Err()
}
func (s *RedisStore) AppendAssignment(ctx context.Context, workerID string, payload map[string]any) error {
encoded, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("marshal worker assignment: %w", err)
}
if err := s.client.RPush(ctx, workerControlQueueKey(workerID), encoded).Err(); err != nil {
return fmt.Errorf("append worker assignment: %w", err)
}
return s.client.Expire(ctx, workerControlQueueKey(workerID), 10*time.Minute).Err()
}
func (s *RedisStore) AppendEvent(ctx context.Context, payload map[string]any) error {
encoded, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("marshal worker event: %w", err)
}
if err := s.client.RPush(ctx, workerEventsKey(), encoded).Err(); err != nil {
return fmt.Errorf("append worker event: %w", err)
}
return s.client.Expire(ctx, workerEventsKey(), 10*time.Minute).Err()
}
func workerKey(workerID string) string {
return "worker:registration:" + workerID
}
func workerSetKey() string {
return "worker:registrations"
}
func leaseKey(leaseID string) string {
return "worker:lease:" + leaseID
}
func leaseSetKey() string {
return "worker:leases"
}
func sessionLeaseKey(sessionID string) string {
return "worker:session-lease:" + sessionID
}
func workerQueueKey(sessionID string) string {
return "worker:queue:" + sessionID
}
func workerControlQueueKey(workerID string) string {
return "worker:control:" + workerID
}
func workerEventsKey() string {
return "worker:events"
}
+274
View File
@@ -0,0 +1,274 @@
package worker
import (
"context"
"encoding/json"
"errors"
"fmt"
"slices"
"time"
"github.com/google/uuid"
"github.com/example/remote-access-platform/backend/internal/modules/sessionbroker"
"github.com/example/remote-access-platform/backend/internal/platform/module"
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
)
var ErrNoWorkerAvailable = errors.New("no worker available")
type Service struct {
cfg module.Config
store Store
now func() time.Time
}
func NewService(deps module.Dependencies, store Store) *Service {
return &Service{
cfg: deps.Config,
store: store,
now: time.Now,
}
}
func (s *Service) Register(ctx context.Context, registration workercontracts.WorkerRegistration) error {
if registration.WorkerID == "" {
return fmt.Errorf("worker id is required")
}
registration.LastHeartbeatAt = s.now().UTC()
return s.store.RegisterWorker(ctx, registration, s.cfg.Worker.HeartbeatTTL)
}
func (s *Service) Heartbeat(ctx context.Context, heartbeat workercontracts.WorkerHeartbeat) error {
heartbeat.LastHeartbeatAt = s.now().UTC()
return s.store.TouchWorkerHeartbeat(ctx, heartbeat, s.cfg.Worker.HeartbeatTTL)
}
func (s *Service) Reserve(ctx context.Context, request workercontracts.AttachRequest) (*workercontracts.WorkerLease, error) {
registration, err := s.reserveWorker(ctx, workercontracts.ProtocolRDP, request.RequiredCapabilities)
if err != nil {
return nil, err
}
return s.AcquireLease(ctx, registration.WorkerID, request)
}
func (s *Service) reserveWorker(ctx context.Context, protocol workercontracts.Protocol, capabilities []string) (*workercontracts.WorkerRegistration, error) {
workers, err := s.store.ListWorkers(ctx)
if err != nil {
return nil, err
}
now := s.now().UTC()
for _, worker := range workers {
if worker.Protocol != protocol || worker.Status != workercontracts.StatusOnline {
continue
}
if now.Sub(worker.LastHeartbeatAt) > s.cfg.Worker.StaleLeaseGracePeriod+s.cfg.Worker.HeartbeatTTL {
continue
}
if !hasCapabilities(worker.Capabilities, capabilities) {
continue
}
return &worker, nil
}
return nil, ErrNoWorkerAvailable
}
func (s *Service) AcquireLease(ctx context.Context, workerID string, request workercontracts.AttachRequest) (*workercontracts.WorkerLease, error) {
if request.SessionID == "" {
request.SessionID = uuid.NewString()
}
now := s.now().UTC()
lease := workercontracts.WorkerLease{
LeaseID: uuid.NewString(),
WorkerID: workerID,
Protocol: workercontracts.ProtocolRDP,
ResourceID: request.ResourceID,
SessionID: request.SessionID,
Capabilities: request.RequiredCapabilities,
ControlStream: "worker://control/" + workerID,
ExpiresAt: now.Add(s.cfg.Worker.LeaseTTL),
RenderQualityProfile: normalizeRenderQualityProfile(request.RenderQualityProfile),
}
if err := s.store.AcquireLease(ctx, lease, s.cfg.Worker.LeaseTTL); err != nil {
return nil, err
}
return &lease, nil
}
func (s *Service) GetSessionLease(ctx context.Context, sessionID string) (*workercontracts.WorkerLease, error) {
return s.store.GetLeaseBySession(ctx, sessionID)
}
func (s *Service) RenewLease(ctx context.Context, leaseID string) (*workercontracts.WorkerLease, error) {
lease, err := s.store.GetLease(ctx, leaseID)
if err != nil || lease == nil {
return lease, err
}
lease.ExpiresAt = s.now().UTC().Add(s.cfg.Worker.LeaseTTL)
if err := s.store.RenewLease(ctx, *lease, s.cfg.Worker.LeaseTTL); err != nil {
return nil, err
}
return lease, nil
}
func (s *Service) ReleaseLease(ctx context.Context, leaseID string) error {
return s.store.ReleaseLease(ctx, leaseID)
}
func (s *Service) ReleaseSessionLease(ctx context.Context, sessionID string) error {
lease, err := s.store.GetLeaseBySession(ctx, sessionID)
if err != nil || lease == nil {
return err
}
return s.store.ReleaseLease(ctx, lease.LeaseID)
}
func (s *Service) RecoverStaleLeases(ctx context.Context) ([]workercontracts.WorkerLease, error) {
leases, err := s.store.ListLeases(ctx)
if err != nil {
return nil, err
}
var stale []workercontracts.WorkerLease
now := s.now().UTC()
for _, lease := range leases {
registration, err := s.store.GetWorker(ctx, lease.WorkerID)
if err != nil {
return nil, err
}
if registration == nil || now.Sub(registration.LastHeartbeatAt) > s.cfg.Worker.StaleLeaseGracePeriod+s.cfg.Worker.HeartbeatTTL {
if err := s.store.ReleaseLease(ctx, lease.LeaseID); err != nil {
return nil, err
}
stale = append(stale, lease)
}
}
return stale, nil
}
func (s *Service) ValidateSessionRuntime(ctx context.Context, sessionID, workerID string) (bool, string, error) {
lease, err := s.store.GetLeaseBySession(ctx, sessionID)
if err != nil {
return false, "", err
}
if lease == nil {
return false, "worker_lease_missing", nil
}
if workerID != "" && lease.WorkerID != workerID {
_ = s.store.ReleaseLease(ctx, lease.LeaseID)
return false, "worker_binding_mismatch", nil
}
now := s.now().UTC()
if !lease.ExpiresAt.After(now) {
_ = s.store.ReleaseLease(ctx, lease.LeaseID)
return false, "worker_lease_expired", nil
}
registration, err := s.store.GetWorker(ctx, lease.WorkerID)
if err != nil {
return false, "", err
}
if registration == nil {
_ = s.store.ReleaseLease(ctx, lease.LeaseID)
return false, "worker_registration_missing", nil
}
if registration.Status != workercontracts.StatusOnline {
return false, "worker_not_online", nil
}
if now.Sub(registration.LastHeartbeatAt) > s.cfg.Worker.StaleLeaseGracePeriod+s.cfg.Worker.HeartbeatTTL {
_ = s.store.ReleaseLease(ctx, lease.LeaseID)
return false, "worker_heartbeat_stale", nil
}
return true, "", nil
}
func (s *Service) PublishControl(ctx context.Context, envelope workercontracts.RoutedEnvelope) error {
return s.store.AppendEnvelope(ctx, envelope)
}
func (s *Service) PublishInput(ctx context.Context, envelope workercontracts.RoutedEnvelope) error {
return s.store.AppendEnvelope(ctx, envelope)
}
func hasCapabilities(workerCaps, required []string) bool {
for _, capability := range required {
if !slices.Contains(workerCaps, capability) {
return false
}
}
return true
}
func (s *Service) PrepareAttachment(ctx context.Context, session sessionbroker.RemoteSession, attachment sessionbroker.SessionAttachment, runtimeMetadata map[string]any) error {
renderQualityProfile := normalizeRenderQualityProfile(session.RenderQualityProfile)
if renderQualityProfile == "balanced" {
renderQualityProfile = renderQualityProfileFromMetadata(session.Metadata)
}
if runtimeMetadata == nil {
runtimeMetadata = decodeMetadata(session.Metadata)
}
return s.store.AppendAssignment(ctx, session.WorkerID, map[string]any{
"type": "session_assignment",
"session_id": session.ID,
"worker_id": session.WorkerID,
"attachment_id": attachment.ID,
"user_id": attachment.UserID,
"device_id": attachment.DeviceID,
"takeover_of": attachment.TakeoverOf,
"state": session.State,
"render_quality_profile": renderQualityProfile,
"metadata": runtimeMetadata,
})
}
func (s *Service) NotifyDetachment(ctx context.Context, session sessionbroker.RemoteSession, attachment sessionbroker.SessionAttachment) error {
return s.PublishControl(ctx, workercontracts.RoutedEnvelope{
SessionID: session.ID,
AttachmentID: attachment.ID,
Type: "control",
Payload: map[string]any{
"action": "detach",
},
})
}
func (s *Service) TerminateRemoteSession(ctx context.Context, sessionID, attachmentID string) error {
return s.PublishControl(ctx, workercontracts.RoutedEnvelope{
SessionID: sessionID,
AttachmentID: attachmentID,
Type: "control",
Payload: map[string]any{
"action": "terminate",
},
})
}
func decodeMetadata(payload []byte) map[string]any {
var out map[string]any
if len(payload) == 0 {
return map[string]any{}
}
if err := json.Unmarshal(payload, &out); err != nil {
return map[string]any{}
}
return out
}
func normalizeRenderQualityProfile(profile string) string {
switch profile {
case "low_bandwidth", "balanced", "high_quality", "text_priority":
return profile
default:
return "balanced"
}
}
func renderQualityProfileFromMetadata(metadata []byte) string {
decoded := decodeMetadata(metadata)
resource, _ := decoded["resource"].(map[string]any)
if resource == nil {
return "balanced"
}
if profile, ok := resource["render_quality_profile"].(string); ok && profile != "" {
return normalizeRenderQualityProfile(profile)
}
return "balanced"
}
+24
View File
@@ -0,0 +1,24 @@
package worker
import (
"context"
"time"
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
)
type Store interface {
RegisterWorker(ctx context.Context, registration workercontracts.WorkerRegistration, ttl time.Duration) error
TouchWorkerHeartbeat(ctx context.Context, heartbeat workercontracts.WorkerHeartbeat, ttl time.Duration) error
ListWorkers(ctx context.Context) ([]workercontracts.WorkerRegistration, error)
GetWorker(ctx context.Context, workerID string) (*workercontracts.WorkerRegistration, error)
AcquireLease(ctx context.Context, lease workercontracts.WorkerLease, ttl time.Duration) error
GetLease(ctx context.Context, leaseID string) (*workercontracts.WorkerLease, error)
GetLeaseBySession(ctx context.Context, sessionID string) (*workercontracts.WorkerLease, error)
RenewLease(ctx context.Context, lease workercontracts.WorkerLease, ttl time.Duration) error
ReleaseLease(ctx context.Context, leaseID string) error
ListLeases(ctx context.Context) ([]workercontracts.WorkerLease, error)
AppendAssignment(ctx context.Context, workerID string, payload map[string]any) error
AppendEnvelope(ctx context.Context, envelope workercontracts.RoutedEnvelope) error
AppendEvent(ctx context.Context, payload map[string]any) error
}
@@ -0,0 +1,329 @@
package authority
import (
"context"
"crypto/ed25519"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/example/remote-access-platform/backend/internal/platform/config"
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
)
const (
ModeStrict = "strict"
ModeLegacy = "legacy"
ActivationSchemaVersion = "rap.installation.activation.v1"
PlatformRoleUser = "user"
PlatformRoleAdmin = "platform_admin"
PlatformRoleRecoveryAdmin = "platform_recovery_admin"
)
var (
ErrInvalidAuthorityMode = errors.New("invalid installation authority mode")
ErrProductRootKeyNeeded = errors.New("product root public key is required")
ErrInvalidActivation = errors.New("invalid installation activation")
ErrInvalidGrant = errors.New("invalid platform role grant")
)
type ActivationPayload struct {
SchemaVersion string `json:"schema_version"`
InstallID string `json:"install_id"`
OwnerEmail string `json:"owner_email"`
PlatformRole string `json:"platform_role"`
IssuedAt time.Time `json:"issued_at"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
Nonce string `json:"nonce,omitempty"`
Environment string `json:"environment,omitempty"`
}
type Verifier struct {
mode string
rootPublicKey ed25519.PublicKey
rootFingerprint string
allowInsecureBootstrap bool
now func() time.Time
}
func NewVerifier(cfg config.InstallationConfig) (*Verifier, error) {
mode := strings.ToLower(strings.TrimSpace(cfg.AuthorityMode))
if mode == "" {
mode = ModeLegacy
}
verifier := &Verifier{
mode: mode,
allowInsecureBootstrap: cfg.AllowInsecureBootstrap,
now: time.Now,
}
switch mode {
case ModeLegacy:
return verifier, nil
case ModeStrict:
publicKey, err := decodeEd25519PublicKey(cfg.ProductRootPublicKeyBase64)
if err != nil {
return nil, err
}
verifier.rootPublicKey = publicKey
fingerprint := sha256.Sum256(publicKey)
verifier.rootFingerprint = hex.EncodeToString(fingerprint[:])
return verifier, nil
default:
return nil, fmt.Errorf("%w: %s", ErrInvalidAuthorityMode, mode)
}
}
func (v *Verifier) Mode() string {
if v == nil || v.mode == "" {
return ModeLegacy
}
return v.mode
}
func (v *Verifier) Strict() bool {
return v != nil && v.mode == ModeStrict
}
func (v *Verifier) AllowInsecureBootstrap() bool {
return v != nil && v.allowInsecureBootstrap
}
func (v *Verifier) RootFingerprint() string {
if v == nil {
return ""
}
return v.rootFingerprint
}
func (v *Verifier) VerifyActivation(payload json.RawMessage, signature string) (ActivationPayload, error) {
if v == nil || !v.Strict() {
return ActivationPayload{}, ErrProductRootKeyNeeded
}
activation, canonical, err := parseActivationPayload(payload)
if err != nil {
return ActivationPayload{}, err
}
if err := activation.validate(v.now().UTC()); err != nil {
return ActivationPayload{}, err
}
if err := v.verifySignature(canonical, signature); err != nil {
return ActivationPayload{}, fmt.Errorf("%w: %v", ErrInvalidActivation, err)
}
return activation, nil
}
func (v *Verifier) VerifyPlatformRoleGrant(payload json.RawMessage, signature, expectedInstallID, expectedEmail, expectedRole string) (ActivationPayload, error) {
activation, err := v.VerifyActivation(payload, signature)
if err != nil {
return ActivationPayload{}, fmt.Errorf("%w: %v", ErrInvalidGrant, err)
}
if activation.InstallID != strings.TrimSpace(expectedInstallID) {
return ActivationPayload{}, fmt.Errorf("%w: install_id mismatch", ErrInvalidGrant)
}
if !strings.EqualFold(activation.OwnerEmail, strings.TrimSpace(expectedEmail)) {
return ActivationPayload{}, fmt.Errorf("%w: owner_email mismatch", ErrInvalidGrant)
}
if activation.PlatformRole != strings.TrimSpace(expectedRole) {
return ActivationPayload{}, fmt.Errorf("%w: platform_role mismatch", ErrInvalidGrant)
}
return activation, nil
}
func CanonicalJSON(raw json.RawMessage) ([]byte, error) {
if len(raw) == 0 {
return nil, fmt.Errorf("%w: empty payload", ErrInvalidActivation)
}
var value any
if err := json.Unmarshal(raw, &value); err != nil {
return nil, fmt.Errorf("%w: invalid json: %v", ErrInvalidActivation, err)
}
canonical, err := json.Marshal(value)
if err != nil {
return nil, fmt.Errorf("%w: canonical json: %v", ErrInvalidActivation, err)
}
return canonical, nil
}
func EffectivePlatformRole(ctx context.Context, db postgresplatform.DBTX, verifier *Verifier, userID string) (string, error) {
userID = strings.TrimSpace(userID)
if userID == "" {
return PlatformRoleUser, nil
}
if verifier == nil || !verifier.Strict() {
return legacyPlatformRole(ctx, db, userID)
}
var email string
if err := db.QueryRow(ctx, `SELECT email FROM users WHERE id = $1::uuid`, userID).Scan(&email); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return PlatformRoleUser, nil
}
return "", fmt.Errorf("get user email for platform grant: %w", err)
}
rows, err := db.Query(ctx, `
SELECT prg.role, prg.install_id, prg.grant_payload, prg.grant_signature
FROM platform_role_grants prg
JOIN installation_authority ia
ON ia.id = 1
AND ia.install_id = prg.install_id
AND ia.authority_state = 'active'
WHERE prg.user_id = $1::uuid
AND prg.revoked_at IS NULL
AND (prg.expires_at IS NULL OR prg.expires_at > NOW())
ORDER BY CASE prg.role
WHEN 'platform_recovery_admin' THEN 0
WHEN 'platform_admin' THEN 1
ELSE 2
END, prg.granted_at DESC
`, userID)
if err != nil {
return "", fmt.Errorf("query platform role grants: %w", err)
}
defer rows.Close()
bestRole := PlatformRoleUser
for rows.Next() {
var role, installID, signature string
var payload []byte
if err := rows.Scan(&role, &installID, &payload, &signature); err != nil {
return "", fmt.Errorf("scan platform role grant: %w", err)
}
if _, err := verifier.VerifyPlatformRoleGrant(json.RawMessage(payload), signature, installID, email, role); err != nil {
continue
}
if role == PlatformRoleRecoveryAdmin {
return role, nil
}
if role == PlatformRoleAdmin {
bestRole = role
}
}
if err := rows.Err(); err != nil {
return "", fmt.Errorf("iterate platform role grants: %w", err)
}
return bestRole, nil
}
func legacyPlatformRole(ctx context.Context, db postgresplatform.DBTX, userID string) (string, error) {
var role string
if err := db.QueryRow(ctx, `SELECT platform_role FROM users WHERE id = $1::uuid`, userID).Scan(&role); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return PlatformRoleUser, nil
}
return "", fmt.Errorf("get platform role: %w", err)
}
if role == "" {
return PlatformRoleUser, nil
}
return role, nil
}
func parseActivationPayload(raw json.RawMessage) (ActivationPayload, []byte, error) {
canonical, err := CanonicalJSON(raw)
if err != nil {
return ActivationPayload{}, nil, err
}
var activation ActivationPayload
if err := json.Unmarshal(canonical, &activation); err != nil {
return ActivationPayload{}, nil, fmt.Errorf("%w: decode activation: %v", ErrInvalidActivation, err)
}
activation.SchemaVersion = strings.TrimSpace(activation.SchemaVersion)
activation.InstallID = strings.TrimSpace(activation.InstallID)
activation.OwnerEmail = strings.ToLower(strings.TrimSpace(activation.OwnerEmail))
activation.PlatformRole = strings.TrimSpace(activation.PlatformRole)
activation.Nonce = strings.TrimSpace(activation.Nonce)
activation.Environment = strings.TrimSpace(activation.Environment)
return activation, canonical, nil
}
func (p ActivationPayload) validate(now time.Time) error {
if p.SchemaVersion != ActivationSchemaVersion {
return fmt.Errorf("%w: schema_version must be %s", ErrInvalidActivation, ActivationSchemaVersion)
}
if p.InstallID == "" {
return fmt.Errorf("%w: install_id is required", ErrInvalidActivation)
}
if p.OwnerEmail == "" || !strings.Contains(p.OwnerEmail, "@") {
return fmt.Errorf("%w: owner_email is required", ErrInvalidActivation)
}
switch p.PlatformRole {
case PlatformRoleAdmin, PlatformRoleRecoveryAdmin:
default:
return fmt.Errorf("%w: platform_role must be platform_admin or platform_recovery_admin", ErrInvalidActivation)
}
if p.IssuedAt.IsZero() {
return fmt.Errorf("%w: issued_at is required", ErrInvalidActivation)
}
if p.IssuedAt.After(now.Add(5 * time.Minute)) {
return fmt.Errorf("%w: issued_at is too far in the future", ErrInvalidActivation)
}
if p.ExpiresAt != nil && !p.ExpiresAt.After(now) {
return fmt.Errorf("%w: activation expired", ErrInvalidActivation)
}
return nil
}
func (v *Verifier) verifySignature(payload []byte, signatureText string) error {
signature, err := decodeBase64(strings.TrimSpace(signatureText))
if err != nil {
return fmt.Errorf("signature must be base64 encoded: %w", err)
}
if len(signature) != ed25519.SignatureSize {
return fmt.Errorf("signature must decode to %d bytes", ed25519.SignatureSize)
}
if !ed25519.Verify(v.rootPublicKey, payload, signature) {
return errors.New("signature verification failed")
}
return nil
}
func decodeEd25519PublicKey(value string) (ed25519.PublicKey, error) {
value = strings.TrimSpace(value)
if value == "" {
return nil, ErrProductRootKeyNeeded
}
if block, _ := pem.Decode([]byte(value)); block != nil {
parsed, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("parse product root public key PEM: %w", err)
}
publicKey, ok := parsed.(ed25519.PublicKey)
if !ok {
return nil, fmt.Errorf("product root public key PEM must contain an Ed25519 public key")
}
return publicKey, nil
}
decoded, err := decodeBase64(value)
if err != nil {
return nil, fmt.Errorf("product root public key must be base64 encoded: %w", err)
}
if len(decoded) != ed25519.PublicKeySize {
return nil, fmt.Errorf("product root public key must decode to %d bytes", ed25519.PublicKeySize)
}
return ed25519.PublicKey(decoded), nil
}
func decodeBase64(value string) ([]byte, error) {
decoded, err := base64.StdEncoding.DecodeString(value)
if err == nil {
return decoded, nil
}
decoded, rawErr := base64.RawStdEncoding.DecodeString(value)
if rawErr == nil {
return decoded, nil
}
return nil, err
}
@@ -0,0 +1,85 @@
package authority
import (
"crypto/ed25519"
"encoding/base64"
"encoding/json"
"strings"
"testing"
"time"
"github.com/example/remote-access-platform/backend/internal/platform/config"
)
func TestVerifierAcceptsSignedActivation(t *testing.T) {
publicKey, privateKey, err := ed25519.GenerateKey(nil)
if err != nil {
t.Fatalf("generate key: %v", err)
}
verifier, err := NewVerifier(config.InstallationConfig{
AuthorityMode: ModeStrict,
ProductRootPublicKeyBase64: base64.StdEncoding.EncodeToString(publicKey),
})
if err != nil {
t.Fatalf("NewVerifier: %v", err)
}
verifier.now = func() time.Time { return time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC) }
payload := json.RawMessage(`{
"platform_role":"platform_admin",
"owner_email":"Owner@Example.test",
"install_id":"install-1",
"schema_version":"rap.installation.activation.v1",
"issued_at":"2026-04-28T11:00:00Z",
"expires_at":"2026-04-29T11:00:00Z"
}`)
canonical, err := CanonicalJSON(payload)
if err != nil {
t.Fatalf("CanonicalJSON: %v", err)
}
signature := base64.StdEncoding.EncodeToString(ed25519.Sign(privateKey, canonical))
activation, err := verifier.VerifyActivation(payload, signature)
if err != nil {
t.Fatalf("VerifyActivation: %v", err)
}
if activation.OwnerEmail != "owner@example.test" || activation.PlatformRole != PlatformRoleAdmin {
t.Fatalf("unexpected activation: %+v", activation)
}
if verifier.RootFingerprint() == "" {
t.Fatal("expected root fingerprint")
}
}
func TestVerifierRejectsTamperedActivation(t *testing.T) {
publicKey, privateKey, err := ed25519.GenerateKey(nil)
if err != nil {
t.Fatalf("generate key: %v", err)
}
verifier, err := NewVerifier(config.InstallationConfig{
AuthorityMode: ModeStrict,
ProductRootPublicKeyBase64: base64.StdEncoding.EncodeToString(publicKey),
})
if err != nil {
t.Fatalf("NewVerifier: %v", err)
}
verifier.now = func() time.Time { return time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC) }
payload := json.RawMessage(`{
"schema_version":"rap.installation.activation.v1",
"install_id":"install-1",
"owner_email":"owner@example.test",
"platform_role":"platform_admin",
"issued_at":"2026-04-28T11:00:00Z"
}`)
canonical, err := CanonicalJSON(payload)
if err != nil {
t.Fatalf("CanonicalJSON: %v", err)
}
signature := base64.StdEncoding.EncodeToString(ed25519.Sign(privateKey, canonical))
tampered := json.RawMessage(strings.Replace(string(payload), "platform_admin", "platform_recovery_admin", 1))
if _, err := verifier.VerifyActivation(tampered, signature); err == nil {
t.Fatal("expected tampered activation to fail")
}
}
@@ -0,0 +1,186 @@
package clusterauth
import (
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
)
const (
AuthoritySchemaVersion = "rap.cluster_authority.v1"
SignatureSchemaVersion = "rap.cluster_authority.signature.v1"
AlgorithmEd25519 = "ed25519"
)
var (
ErrInvalidKey = errors.New("invalid cluster authority key")
ErrInvalidSignature = errors.New("invalid cluster authority signature")
ErrInvalidPayload = errors.New("invalid cluster authority payload")
)
type KeyPair struct {
PublicKeyB64 string
PrivateKeyB64 string
Fingerprint string
}
type Signature struct {
SchemaVersion string `json:"schema_version"`
Algorithm string `json:"algorithm"`
KeyFingerprint string `json:"key_fingerprint"`
Signature string `json:"signature"`
SignedAt time.Time `json:"signed_at"`
}
func GenerateKeyPair() (KeyPair, error) {
publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return KeyPair{}, err
}
fingerprint := Fingerprint(publicKey)
return KeyPair{
PublicKeyB64: base64.StdEncoding.EncodeToString(publicKey),
PrivateKeyB64: base64.StdEncoding.EncodeToString(privateKey),
Fingerprint: fingerprint,
}, nil
}
func Fingerprint(publicKey ed25519.PublicKey) string {
sum := sha256.Sum256(publicKey)
return "rap-ca-ed25519-" + hex.EncodeToString(sum[:16])
}
func FingerprintFromBase64(publicKeyB64 string) (string, error) {
publicKey, err := DecodePublicKey(publicKeyB64)
if err != nil {
return "", err
}
return Fingerprint(publicKey), nil
}
func SignRaw(privateKeyB64 string, payload json.RawMessage, signedAt time.Time) (Signature, error) {
privateKey, err := DecodePrivateKey(privateKeyB64)
if err != nil {
return Signature{}, err
}
canonical, err := CanonicalJSON(payload)
if err != nil {
return Signature{}, err
}
publicKey, ok := privateKey.Public().(ed25519.PublicKey)
if !ok {
return Signature{}, ErrInvalidKey
}
signature := ed25519.Sign(privateKey, canonical)
return Signature{
SchemaVersion: SignatureSchemaVersion,
Algorithm: AlgorithmEd25519,
KeyFingerprint: Fingerprint(publicKey),
Signature: base64.StdEncoding.EncodeToString(signature),
SignedAt: signedAt.UTC(),
}, nil
}
func SignPayload(privateKeyB64 string, payload any, signedAt time.Time) (json.RawMessage, Signature, error) {
raw, err := json.Marshal(payload)
if err != nil {
return nil, Signature{}, fmt.Errorf("%w: marshal: %v", ErrInvalidPayload, err)
}
signature, err := SignRaw(privateKeyB64, raw, signedAt)
if err != nil {
return nil, Signature{}, err
}
return json.RawMessage(raw), signature, nil
}
func VerifyRaw(publicKeyB64 string, payload json.RawMessage, signature Signature) error {
if signature.SchemaVersion != SignatureSchemaVersion {
return fmt.Errorf("%w: schema_version must be %s", ErrInvalidSignature, SignatureSchemaVersion)
}
if signature.Algorithm != AlgorithmEd25519 {
return fmt.Errorf("%w: algorithm must be %s", ErrInvalidSignature, AlgorithmEd25519)
}
publicKey, err := DecodePublicKey(publicKeyB64)
if err != nil {
return err
}
if signature.KeyFingerprint != Fingerprint(publicKey) {
return fmt.Errorf("%w: key fingerprint mismatch", ErrInvalidSignature)
}
canonical, err := CanonicalJSON(payload)
if err != nil {
return err
}
decodedSignature, err := decodeBase64(strings.TrimSpace(signature.Signature))
if err != nil || len(decodedSignature) != ed25519.SignatureSize {
return fmt.Errorf("%w: signature must be base64 ed25519 signature", ErrInvalidSignature)
}
if !ed25519.Verify(publicKey, canonical, decodedSignature) {
return ErrInvalidSignature
}
return nil
}
func CanonicalJSON(raw json.RawMessage) ([]byte, error) {
if len(raw) == 0 {
return nil, fmt.Errorf("%w: empty payload", ErrInvalidPayload)
}
var value any
if err := json.Unmarshal(raw, &value); err != nil {
return nil, fmt.Errorf("%w: invalid json: %v", ErrInvalidPayload, err)
}
canonical, err := json.Marshal(value)
if err != nil {
return nil, fmt.Errorf("%w: canonical json: %v", ErrInvalidPayload, err)
}
return canonical, nil
}
func HashRaw(raw json.RawMessage) (string, error) {
canonical, err := CanonicalJSON(raw)
if err != nil {
return "", err
}
sum := sha256.Sum256(canonical)
return hex.EncodeToString(sum[:]), nil
}
func DecodePublicKey(value string) (ed25519.PublicKey, error) {
decoded, err := decodeBase64(strings.TrimSpace(value))
if err != nil {
return nil, fmt.Errorf("%w: public key must be base64 encoded", ErrInvalidKey)
}
if len(decoded) != ed25519.PublicKeySize {
return nil, fmt.Errorf("%w: public key must decode to %d bytes", ErrInvalidKey, ed25519.PublicKeySize)
}
return ed25519.PublicKey(decoded), nil
}
func DecodePrivateKey(value string) (ed25519.PrivateKey, error) {
decoded, err := decodeBase64(strings.TrimSpace(value))
if err != nil {
return nil, fmt.Errorf("%w: private key must be base64 encoded", ErrInvalidKey)
}
if len(decoded) != ed25519.PrivateKeySize {
return nil, fmt.Errorf("%w: private key must decode to %d bytes", ErrInvalidKey, ed25519.PrivateKeySize)
}
return ed25519.PrivateKey(decoded), nil
}
func decodeBase64(value string) ([]byte, error) {
if value == "" {
return nil, errors.New("empty base64 value")
}
decoded, err := base64.StdEncoding.DecodeString(value)
if err == nil {
return decoded, nil
}
return base64.RawStdEncoding.DecodeString(value)
}
@@ -0,0 +1,44 @@
package clusterauth
import (
"encoding/json"
"errors"
"testing"
"time"
)
func TestSignAndVerifyRawPayload(t *testing.T) {
keys, err := GenerateKeyPair()
if err != nil {
t.Fatalf("GenerateKeyPair: %v", err)
}
payload := json.RawMessage(`{"cluster_id":"cluster-1","schema_version":"test.v1","value":1}`)
signature, err := SignRaw(keys.PrivateKeyB64, payload, time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC))
if err != nil {
t.Fatalf("SignRaw: %v", err)
}
if signature.KeyFingerprint != keys.Fingerprint {
t.Fatalf("fingerprint = %q, want %q", signature.KeyFingerprint, keys.Fingerprint)
}
if err := VerifyRaw(keys.PublicKeyB64, payload, signature); err != nil {
t.Fatalf("VerifyRaw: %v", err)
}
}
func TestVerifyRawRejectsTamperedPayload(t *testing.T) {
keys, err := GenerateKeyPair()
if err != nil {
t.Fatalf("GenerateKeyPair: %v", err)
}
payload := json.RawMessage(`{"cluster_id":"cluster-1","schema_version":"test.v1","value":1}`)
signature, err := SignRaw(keys.PrivateKeyB64, payload, time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC))
if err != nil {
t.Fatalf("SignRaw: %v", err)
}
tampered := json.RawMessage(`{"cluster_id":"cluster-1","schema_version":"test.v1","value":2}`)
if err := VerifyRaw(keys.PublicKeyB64, tampered, signature); !errors.Is(err, ErrInvalidSignature) {
t.Fatalf("err = %v, want ErrInvalidSignature", err)
}
}
+307
View File
@@ -0,0 +1,307 @@
package config
import (
"encoding/base64"
"fmt"
"os"
"strconv"
"strings"
"time"
)
type Config struct {
App AppConfig
HTTP HTTPConfig
Postgres PostgresConfig
Redis RedisConfig
Auth AuthConfig
Installation InstallationConfig
DataPlane DataPlaneConfig
Secret SecretConfig
Session SessionConfig
Worker WorkerConfig
WebSocket WebSocketConfig
}
type AppConfig struct {
Name string
Env string
}
type HTTPConfig struct {
Host string
Port int
ReadTimeout time.Duration
WriteTimeout time.Duration
IdleTimeout time.Duration
ShutdownTimeout time.Duration
}
type PostgresConfig struct {
DSN string
MaxConns int32
MinConns int32
ConnectTimeout time.Duration
}
type RedisConfig struct {
Addr string
Password string
DB int
DialTimeout time.Duration
}
type AuthConfig struct {
AccessTokenTTL time.Duration
RefreshTokenTTL time.Duration
Issuer string
AccessTokenSecret string
RefreshHashSecret string
}
type InstallationConfig struct {
AuthorityMode string
ProductRootPublicKeyBase64 string
ProductRootPublicKeyFile string
AllowInsecureBootstrap bool
}
type DataPlaneConfig struct {
TokenTTL time.Duration
TokenPrivateKeyPEM string
TokenPrivateKeyFile string
BackendGatewayURL string
DirectWorkerWSSURLTemplate string
DirectWorkerJSONRuntime bool
DirectWorkerBinaryRender bool
DirectWorkerTLSTrustMode string
DirectWorkerTLSCARef string
}
type SecretConfig struct {
EncryptionKeyBase64 string
EncryptionKeyFile string
EncryptionKeyID string
}
type SessionConfig struct {
HeartbeatTTL time.Duration
DetachGracePeriod time.Duration
AttachTokenTTL time.Duration
LiveStateTTL time.Duration
RecoveryBatchSize int
}
type WorkerConfig struct {
LeaseTTL time.Duration
HeartbeatTTL time.Duration
StaleLeaseGracePeriod time.Duration
}
type WebSocketConfig struct {
WriteTimeout time.Duration
PingInterval time.Duration
PongWait time.Duration
}
func Load() (Config, error) {
cfg := Config{
App: AppConfig{
Name: getEnv("APP_NAME", "rap-api"),
Env: getEnv("APP_ENV", "development"),
},
HTTP: HTTPConfig{
Host: getEnv("HTTP_HOST", "0.0.0.0"),
Port: getInt("HTTP_PORT", 8080),
ReadTimeout: getDuration("HTTP_READ_TIMEOUT", 15*time.Second),
WriteTimeout: getDuration("HTTP_WRITE_TIMEOUT", 15*time.Second),
IdleTimeout: getDuration("HTTP_IDLE_TIMEOUT", 60*time.Second),
ShutdownTimeout: getDuration("HTTP_SHUTDOWN_TIMEOUT", 10*time.Second),
},
Postgres: PostgresConfig{
DSN: getEnv("POSTGRES_DSN", ""),
MaxConns: int32(getInt("POSTGRES_MAX_CONNS", 20)),
MinConns: int32(getInt("POSTGRES_MIN_CONNS", 2)),
ConnectTimeout: getDuration("POSTGRES_CONNECT_TIMEOUT", 5*time.Second),
},
Redis: RedisConfig{
Addr: getEnv("REDIS_ADDR", "localhost:6379"),
Password: getEnv("REDIS_PASSWORD", ""),
DB: getInt("REDIS_DB", 0),
DialTimeout: getDuration("REDIS_DIAL_TIMEOUT", 5*time.Second),
},
Auth: AuthConfig{
AccessTokenTTL: getDuration("AUTH_ACCESS_TOKEN_TTL", 15*time.Minute),
RefreshTokenTTL: getDuration("AUTH_REFRESH_TOKEN_TTL", 30*24*time.Hour),
Issuer: getEnv("AUTH_ISSUER", "rap-api"),
AccessTokenSecret: getEnv("AUTH_ACCESS_TOKEN_SECRET", ""),
RefreshHashSecret: getEnv("AUTH_REFRESH_HASH_SECRET", ""),
},
Installation: InstallationConfig{
AuthorityMode: getEnv("INSTALLATION_AUTHORITY_MODE", ""),
ProductRootPublicKeyBase64: getEnv("INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_B64", ""),
ProductRootPublicKeyFile: getEnv("INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_FILE", ""),
AllowInsecureBootstrap: getBool("INSTALLATION_INSECURE_BOOTSTRAP_ENABLED", false),
},
DataPlane: DataPlaneConfig{
TokenTTL: getDuration("DATA_PLANE_TOKEN_TTL", 1*time.Minute),
TokenPrivateKeyPEM: getEnv("DATA_PLANE_TOKEN_PRIVATE_KEY_PEM", ""),
TokenPrivateKeyFile: getEnv("DATA_PLANE_TOKEN_PRIVATE_KEY_FILE", ""),
BackendGatewayURL: getEnv("DATA_PLANE_BACKEND_GATEWAY_URL", "/api/v1/gateway/ws"),
DirectWorkerWSSURLTemplate: getEnv("DATA_PLANE_DIRECT_WORKER_WSS_URL_TEMPLATE", ""),
DirectWorkerJSONRuntime: getBool("DATA_PLANE_DIRECT_WORKER_JSON_RUNTIME", false),
DirectWorkerBinaryRender: getBool("DATA_PLANE_DIRECT_WORKER_BINARY_RENDER", false),
DirectWorkerTLSTrustMode: getEnv("DATA_PLANE_DIRECT_WORKER_TLS_TRUST_MODE", "smoke_insecure"),
DirectWorkerTLSCARef: getEnv("DATA_PLANE_DIRECT_WORKER_TLS_CA_REF", ""),
},
Secret: SecretConfig{
EncryptionKeyBase64: getEnv("SECRET_ENCRYPTION_KEY_B64", ""),
EncryptionKeyFile: getEnv("SECRET_ENCRYPTION_KEY_FILE", ""),
EncryptionKeyID: getEnv("SECRET_ENCRYPTION_KEY_ID", "local-v1"),
},
Session: SessionConfig{
HeartbeatTTL: getDuration("SESSION_HEARTBEAT_TTL", 90*time.Second),
DetachGracePeriod: getDuration("SESSION_DETACH_GRACE_PERIOD", 30*time.Minute),
AttachTokenTTL: getDuration("SESSION_ATTACH_TOKEN_TTL", 2*time.Minute),
LiveStateTTL: getDuration("SESSION_LIVE_STATE_TTL", 2*time.Minute),
RecoveryBatchSize: getInt("SESSION_RECOVERY_BATCH_SIZE", 100),
},
Worker: WorkerConfig{
LeaseTTL: getDuration("WORKER_LEASE_TTL", 45*time.Second),
HeartbeatTTL: getDuration("WORKER_HEARTBEAT_TTL", 15*time.Second),
StaleLeaseGracePeriod: getDuration("WORKER_STALE_LEASE_GRACE_PERIOD", 30*time.Second),
},
WebSocket: WebSocketConfig{
WriteTimeout: getDuration("WEBSOCKET_WRITE_TIMEOUT", 10*time.Second),
PingInterval: getDuration("WEBSOCKET_PING_INTERVAL", 20*time.Second),
PongWait: getDuration("WEBSOCKET_PONG_WAIT", 40*time.Second),
},
}
if cfg.Postgres.DSN == "" {
return Config{}, fmt.Errorf("POSTGRES_DSN is required")
}
if cfg.Auth.AccessTokenSecret == "" {
return Config{}, fmt.Errorf("AUTH_ACCESS_TOKEN_SECRET is required")
}
if cfg.Auth.RefreshHashSecret == "" {
return Config{}, fmt.Errorf("AUTH_REFRESH_HASH_SECRET is required")
}
if cfg.Installation.ProductRootPublicKeyBase64 == "" && cfg.Installation.ProductRootPublicKeyFile != "" {
publicKey, err := os.ReadFile(cfg.Installation.ProductRootPublicKeyFile)
if err != nil {
return Config{}, fmt.Errorf("read INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_FILE: %w", err)
}
cfg.Installation.ProductRootPublicKeyBase64 = strings.TrimSpace(string(publicKey))
}
cfg.Installation.AuthorityMode = normalizeInstallationAuthorityMode(cfg.Installation.AuthorityMode, cfg.Installation.ProductRootPublicKeyBase64)
if isProductionEnv(cfg.App.Env) && cfg.Installation.AuthorityMode != "strict" {
return Config{}, fmt.Errorf("INSTALLATION_AUTHORITY_MODE=strict with INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_B64 or file is required in production")
}
if cfg.DataPlane.TokenPrivateKeyPEM == "" && cfg.DataPlane.TokenPrivateKeyFile != "" {
privateKey, err := os.ReadFile(cfg.DataPlane.TokenPrivateKeyFile)
if err != nil {
return Config{}, fmt.Errorf("read DATA_PLANE_TOKEN_PRIVATE_KEY_FILE: %w", err)
}
cfg.DataPlane.TokenPrivateKeyPEM = string(privateKey)
}
if cfg.Secret.EncryptionKeyBase64 == "" && cfg.Secret.EncryptionKeyFile != "" {
secretKey, err := os.ReadFile(cfg.Secret.EncryptionKeyFile)
if err != nil {
return Config{}, fmt.Errorf("read SECRET_ENCRYPTION_KEY_FILE: %w", err)
}
cfg.Secret.EncryptionKeyBase64 = strings.TrimSpace(string(secretKey))
}
if cfg.Secret.EncryptionKeyBase64 != "" {
decoded, err := base64.StdEncoding.DecodeString(cfg.Secret.EncryptionKeyBase64)
if err != nil {
if decodedRaw, rawErr := base64.RawStdEncoding.DecodeString(cfg.Secret.EncryptionKeyBase64); rawErr == nil {
decoded = decodedRaw
} else {
return Config{}, fmt.Errorf("SECRET_ENCRYPTION_KEY_B64 must be base64 encoded: %w", err)
}
}
if len(decoded) != 32 {
return Config{}, fmt.Errorf("SECRET_ENCRYPTION_KEY_B64 must decode to 32 bytes for AES-256-GCM")
}
}
if isProductionEnv(cfg.App.Env) && cfg.Secret.EncryptionKeyBase64 == "" {
return Config{}, fmt.Errorf("SECRET_ENCRYPTION_KEY_B64 or SECRET_ENCRYPTION_KEY_FILE is required in production")
}
return cfg, nil
}
func normalizeInstallationAuthorityMode(mode string, rootPublicKey string) string {
mode = strings.ToLower(strings.TrimSpace(mode))
switch mode {
case "strict", "legacy":
return mode
case "":
if strings.TrimSpace(rootPublicKey) != "" {
return "strict"
}
return "legacy"
default:
return mode
}
}
func isProductionEnv(appEnv string) bool {
switch strings.ToLower(strings.TrimSpace(appEnv)) {
case "production", "prod":
return true
default:
return false
}
}
func getEnv(key, fallback string) string {
if value := os.Getenv(key); value != "" {
return value
}
return fallback
}
func getInt(key string, fallback int) int {
value := os.Getenv(key)
if value == "" {
return fallback
}
parsed, err := strconv.Atoi(value)
if err != nil {
return fallback
}
return parsed
}
func getBool(key string, fallback bool) bool {
value := os.Getenv(key)
if value == "" {
return fallback
}
switch value {
case "1", "true", "TRUE", "yes", "on":
return true
case "0", "false", "FALSE", "no", "off":
return false
default:
return fallback
}
}
func getDuration(key string, fallback time.Duration) time.Duration {
value := os.Getenv(key)
if value == "" {
return fallback
}
parsed, err := time.ParseDuration(value)
if err != nil {
return fallback
}
return parsed
}
@@ -0,0 +1,20 @@
package httpserver
import (
"fmt"
"net/http"
"time"
"github.com/example/remote-access-platform/backend/internal/platform/config"
)
func New(cfg config.HTTPConfig, handler http.Handler) *http.Server {
return &http.Server{
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
Handler: handler,
ReadHeaderTimeout: 5 * time.Second,
ReadTimeout: cfg.ReadTimeout,
WriteTimeout: cfg.WriteTimeout,
IdleTimeout: cfg.IdleTimeout,
}
}
+45
View File
@@ -0,0 +1,45 @@
package httpx
import (
"encoding/json"
"net/http"
)
func WriteJSON(w http.ResponseWriter, status int, payload any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(payload)
}
func WriteError(w http.ResponseWriter, status int, message string) {
traceID := ensureTraceID(w)
WriteJSON(w, status, ErrorResponse{
Error: NewErrorMessage(status, message, nil, traceID),
})
}
func WriteErrorMessage(w http.ResponseWriter, status int, message any) {
traceID := ensureTraceID(w)
switch payload := message.(type) {
case string:
WriteJSON(w, status, ErrorResponse{
Error: NewErrorMessage(status, payload, nil, traceID),
})
case ErrorResponse:
payload.Error.TraceID = traceID
WriteJSON(w, status, payload)
case *ErrorResponse:
if payload == nil {
WriteJSON(w, status, ErrorResponse{
Error: NewErrorMessage(status, "", nil, traceID),
})
return
}
payload.Error.TraceID = traceID
WriteJSON(w, status, payload)
default:
WriteJSON(w, status, ErrorResponse{
Error: NewErrorMessage(status, "Request failed.", nil, traceID),
})
}
}
+131
View File
@@ -0,0 +1,131 @@
package httpx
import (
"net/http"
"strings"
"unicode"
"github.com/google/uuid"
messagecontracts "github.com/example/remote-access-platform/backend/pkg/contracts/message"
)
type ErrorResponse struct {
Error messagecontracts.Message `json:"error"`
}
func NewMessage(code, messageKey, fallbackMessage string, details map[string]any, traceID string) messagecontracts.Message {
if traceID == "" {
traceID = uuid.NewString()
}
if details == nil {
details = map[string]any{}
}
return messagecontracts.Message{
Code: code,
MessageKey: messageKey,
FallbackMessage: fallbackMessage,
Details: details,
TraceID: traceID,
}
}
func NewErrorMessage(status int, fallbackMessage string, details map[string]any, traceID string) messagecontracts.Message {
normalizedFallback, normalizedDetails := normalizeErrorFallback(status, fallbackMessage, details)
code := deriveErrorCode(status, normalizedFallback)
return NewMessage(code, "errors."+code, normalizedFallback, normalizedDetails, traceID)
}
func ensureTraceID(w http.ResponseWriter) string {
traceID := w.Header().Get("X-Trace-Id")
if traceID == "" {
traceID = uuid.NewString()
w.Header().Set("X-Trace-Id", traceID)
}
return traceID
}
func normalizeErrorFallback(status int, fallbackMessage string, details map[string]any) (string, map[string]any) {
if details == nil {
details = map[string]any{}
}
details["http_status"] = status
if status >= http.StatusInternalServerError {
return "An internal server error occurred.", details
}
trimmed := strings.TrimSpace(fallbackMessage)
switch strings.ToLower(trimmed) {
case "forbidden", "access denied":
return "Access denied.", details
}
if field, ok := extractRequiredField(trimmed); ok {
details["field"] = field
}
return trimmed, details
}
func deriveErrorCode(status int, fallbackMessage string) string {
switch strings.ToLower(strings.TrimSpace(fallbackMessage)) {
case "invalid credentials":
return "auth.invalid_credentials"
case "session expired. please sign in again.":
return "auth.session_expired"
case "access denied.":
return "common.access_denied"
}
statusPrefix := map[int]string{
http.StatusBadRequest: "bad_request",
http.StatusUnauthorized: "unauthorized",
http.StatusForbidden: "forbidden",
http.StatusNotFound: "not_found",
http.StatusConflict: "conflict",
http.StatusUnprocessableEntity: "unprocessable_entity",
http.StatusInternalServerError: "internal_server_error",
}[status]
if statusPrefix == "" {
statusPrefix = "http_" + strings.ReplaceAll(http.StatusText(status), " ", "_")
statusPrefix = strings.ToLower(statusPrefix)
}
slug := slugifyMessage(fallbackMessage)
if slug == "" {
slug = "message"
}
if status >= http.StatusInternalServerError {
return "common." + statusPrefix
}
return statusPrefix + "." + slug
}
func slugifyMessage(input string) string {
var builder strings.Builder
lastUnderscore := false
for _, r := range strings.ToLower(strings.TrimSpace(input)) {
if unicode.IsLetter(r) || unicode.IsDigit(r) {
builder.WriteRune(r)
lastUnderscore = false
continue
}
if !lastUnderscore {
builder.WriteRune('_')
lastUnderscore = true
}
}
return strings.Trim(builder.String(), "_")
}
func extractRequiredField(message string) (string, bool) {
const suffix = " is required"
if !strings.HasSuffix(strings.ToLower(message), suffix) {
return "", false
}
field := strings.TrimSpace(message[:len(message)-len(suffix)])
field = strings.ReplaceAll(field, " ", "_")
field = strings.ToLower(field)
return field, field != ""
}
@@ -0,0 +1,17 @@
package logging
import (
"log/slog"
"os"
)
func New(env string) *slog.Logger {
level := slog.LevelInfo
if env == "development" {
level = slog.LevelDebug
}
return slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
Level: level,
}))
}
@@ -0,0 +1,38 @@
package module
import (
"log/slog"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/redis/go-redis/v9"
"github.com/example/remote-access-platform/backend/internal/platform/config"
)
type Dependencies struct {
Config Config
Infra Infra
}
type Config struct {
App config.AppConfig
Auth config.AuthConfig
Installation config.InstallationConfig
DataPlane config.DataPlaneConfig
Secret config.SecretConfig
Session config.SessionConfig
Worker config.WorkerConfig
WebSocket config.WebSocketConfig
}
type Infra struct {
Logger *slog.Logger
DB *pgxpool.Pool
Redis *redis.Client
}
type Module interface {
Name() string
RegisterRoutes(router chi.Router)
}
@@ -0,0 +1,33 @@
package postgres
import (
"context"
"fmt"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/example/remote-access-platform/backend/internal/platform/config"
)
func Open(ctx context.Context, cfg config.PostgresConfig) (*pgxpool.Pool, error) {
poolConfig, err := pgxpool.ParseConfig(cfg.DSN)
if err != nil {
return nil, fmt.Errorf("parse postgres dsn: %w", err)
}
poolConfig.MaxConns = cfg.MaxConns
poolConfig.MinConns = cfg.MinConns
poolConfig.ConnConfig.ConnectTimeout = cfg.ConnectTimeout
pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
if err != nil {
return nil, fmt.Errorf("create postgres pool: %w", err)
}
if err := pool.Ping(ctx); err != nil {
pool.Close()
return nil, fmt.Errorf("ping postgres: %w", err)
}
return pool, nil
}
+34
View File
@@ -0,0 +1,34 @@
package postgres
import (
"context"
"fmt"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
)
type DBTX interface {
Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
}
func WithTransaction(ctx context.Context, pool *pgxpool.Pool, fn func(tx pgx.Tx) error) error {
tx, err := pool.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
return fmt.Errorf("begin transaction: %w", err)
}
if err := fn(tx); err != nil {
_ = tx.Rollback(ctx)
return err
}
if err := tx.Commit(ctx); err != nil {
return fmt.Errorf("commit transaction: %w", err)
}
return nil
}
+26
View File
@@ -0,0 +1,26 @@
package redis
import (
"context"
"fmt"
goredis "github.com/redis/go-redis/v9"
"github.com/example/remote-access-platform/backend/internal/platform/config"
)
func Open(ctx context.Context, cfg config.RedisConfig) (*goredis.Client, error) {
client := goredis.NewClient(&goredis.Options{
Addr: cfg.Addr,
Password: cfg.Password,
DB: cfg.DB,
DialTimeout: cfg.DialTimeout,
})
if err := client.Ping(ctx).Err(); err != nil {
_ = client.Close()
return nil, fmt.Errorf("ping redis: %w", err)
}
return client, nil
}
+220
View File
@@ -0,0 +1,220 @@
package runtime
import (
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"time"
"github.com/go-chi/chi/v5"
chimiddleware "github.com/go-chi/chi/v5/middleware"
"github.com/example/remote-access-platform/backend/internal/modules/auth"
"github.com/example/remote-access-platform/backend/internal/modules/cluster"
"github.com/example/remote-access-platform/backend/internal/modules/identitysource"
"github.com/example/remote-access-platform/backend/internal/modules/node"
"github.com/example/remote-access-platform/backend/internal/modules/nodeagent"
"github.com/example/remote-access-platform/backend/internal/modules/organization"
"github.com/example/remote-access-platform/backend/internal/modules/resource"
"github.com/example/remote-access-platform/backend/internal/modules/sessionbroker"
"github.com/example/remote-access-platform/backend/internal/modules/sessiongateway"
"github.com/example/remote-access-platform/backend/internal/modules/worker"
"github.com/example/remote-access-platform/backend/internal/platform/authority"
"github.com/example/remote-access-platform/backend/internal/platform/config"
"github.com/example/remote-access-platform/backend/internal/platform/httpserver"
"github.com/example/remote-access-platform/backend/internal/platform/logging"
"github.com/example/remote-access-platform/backend/internal/platform/module"
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
redisplatform "github.com/example/remote-access-platform/backend/internal/platform/redis"
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
)
type App struct {
cfg config.Config
logger *slog.Logger
httpServer *http.Server
workers []backgroundRunner
db closeFunc
redis closeFunc
}
type closeFunc func() error
type backgroundRunner func(context.Context) error
func NewApp(ctx context.Context) (*App, error) {
cfg, err := config.Load()
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
logger := logging.New(cfg.App.Env)
db, err := postgresplatform.Open(ctx, cfg.Postgres)
if err != nil {
return nil, err
}
redisClient, err := redisplatform.Open(ctx, cfg.Redis)
if err != nil {
db.Close()
return nil, err
}
authorityVerifier, err := authority.NewVerifier(cfg.Installation)
if err != nil {
redisClient.Close()
db.Close()
return nil, fmt.Errorf("create installation authority verifier: %w", err)
}
deps := module.Dependencies{
Config: module.Config{
App: cfg.App,
Auth: cfg.Auth,
Installation: cfg.Installation,
DataPlane: cfg.DataPlane,
Secret: cfg.Secret,
Session: cfg.Session,
Worker: cfg.Worker,
WebSocket: cfg.WebSocket,
},
Infra: module.Infra{
Logger: logger,
DB: db,
Redis: redisClient,
},
}
workerStore := worker.NewRedisStore(redisClient)
workerService := worker.NewService(deps, workerStore)
authStore := auth.NewPostgresStore(db)
authTx := auth.NewPostgresTransactor(db)
authService := auth.NewService(deps, authStore, authTx, authorityVerifier)
var resourceSecretStore *secrets.ResourceSecretStore
if cfg.Secret.EncryptionKeyBase64 != "" {
secretEncryptor, err := secrets.NewEncryptor(cfg.Secret.EncryptionKeyBase64, cfg.Secret.EncryptionKeyID)
if err != nil {
redisClient.Close()
db.Close()
return nil, fmt.Errorf("create resource secret encryptor: %w", err)
}
resourceSecretStore = secrets.NewResourceSecretStore(db, secretEncryptor)
}
brokerStore := sessionbroker.NewPostgresStore(db, authorityVerifier)
brokerTx := sessionbroker.NewPostgresTransactor(db, authorityVerifier)
liveStateStore := sessionbroker.NewRedisLiveStateStore(redisClient)
brokerService := sessionbroker.NewService(deps, brokerStore, brokerTx, liveStateStore, workerService, resourceSecretStore)
workerEvents := worker.NewEventProcessor(redisClient, brokerService)
leaseMonitor := worker.NewLeaseMonitor(workerService, brokerService, cfg.Worker.StaleLeaseGracePeriod)
brokerModule := sessionbroker.NewModule(brokerService)
authModule := auth.NewModule(deps, authService)
clusterModule := cluster.NewModule(deps, authorityVerifier)
organizationModule := organization.NewModule(deps)
identitySourceModule := identitysource.NewModule(deps)
resourceModule := resource.NewModule(deps, resourceSecretStore)
nodeModule := node.NewModule(deps)
nodeAgentModule := nodeagent.NewModule(deps)
sessionGatewayModule := sessiongateway.NewModule(deps, brokerModule.Service(), workerService)
router := buildRouter(
logger,
authModule,
clusterModule,
organizationModule,
identitySourceModule,
resourceModule,
brokerModule,
nodeModule,
nodeAgentModule,
sessionGatewayModule,
)
return &App{
cfg: cfg,
logger: logger,
httpServer: httpserver.New(cfg.HTTP, router),
workers: []backgroundRunner{workerEvents.Run, leaseMonitor.Run},
db: func() error {
db.Close()
return nil
},
redis: redisClient.Close,
}, nil
}
func (a *App) Run(ctx context.Context) error {
errCh := make(chan error, 1)
go func() {
a.logger.Info("http server starting", "addr", a.httpServer.Addr, "service", a.cfg.App.Name)
if err := a.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
errCh <- err
return
}
errCh <- nil
}()
for _, runner := range a.workers {
runner := runner
go func() {
if err := runner(ctx); err != nil {
errCh <- err
}
}()
}
select {
case <-ctx.Done():
a.logger.Info("shutdown signal received")
case err := <-errCh:
if err != nil {
return err
}
return nil
}
shutdownCtx, cancel := context.WithTimeout(context.Background(), a.cfg.HTTP.ShutdownTimeout)
defer cancel()
if err := a.httpServer.Shutdown(shutdownCtx); err != nil {
return fmt.Errorf("shutdown http server: %w", err)
}
if err := a.redis(); err != nil {
return fmt.Errorf("close redis: %w", err)
}
if err := a.db(); err != nil {
return fmt.Errorf("close postgres: %w", err)
}
a.logger.Info("app stopped", "at", time.Now().UTC())
return nil
}
func buildRouter(logger *slog.Logger, modules ...module.Module) http.Handler {
router := chi.NewRouter()
router.Use(chimiddleware.RequestID)
router.Use(chimiddleware.RealIP)
router.Use(chimiddleware.Recoverer)
router.Use(chimiddleware.Timeout(60 * time.Second))
router.Use(chimiddleware.Heartbeat("/healthz"))
router.Get("/readyz", func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ready"))
})
router.Route("/api/v1", func(r chi.Router) {
for _, mod := range modules {
logger.Info("register module routes", "module", mod.Name())
mod.RegisterRoutes(r)
}
})
return router
}
@@ -0,0 +1,37 @@
package secrets
import (
"encoding/json"
"fmt"
)
type AssignmentSecretMergeResult struct {
Metadata map[string]any
Keys []string
}
func MergeResourceSecretIntoAssignmentMetadata(metadata map[string]any, payload json.RawMessage) (AssignmentSecretMergeResult, error) {
if metadata == nil {
metadata = map[string]any{}
}
var secretPayload map[string]any
if err := json.Unmarshal(payload, &secretPayload); err != nil {
return AssignmentSecretMergeResult{}, fmt.Errorf("decode resolved resource secret: %w", err)
}
resource, _ := metadata["resource"].(map[string]any)
if resource == nil {
resource = map[string]any{}
metadata["resource"] = resource
}
resourceMetadata, _ := resource["metadata"].(map[string]any)
if resourceMetadata == nil {
resourceMetadata = map[string]any{}
resource["metadata"] = resourceMetadata
}
keys := make([]string, 0, len(secretPayload))
for key, value := range secretPayload {
resourceMetadata[key] = value
keys = append(keys, key)
}
return AssignmentSecretMergeResult{Metadata: metadata, Keys: keys}, nil
}
@@ -0,0 +1,113 @@
package secrets
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"io"
"strings"
)
const AlgorithmAES256GCM = "AES-256-GCM"
var (
ErrSecretEncryptionKeyMissing = errors.New("secret encryption key is not configured")
ErrSecretPayloadInvalid = errors.New("secret payload must be a json object")
)
type Encryptor struct {
aead cipher.AEAD
keyID string
}
type EncryptedPayload struct {
Algorithm string
KeyID string
Nonce []byte
Ciphertext []byte
PayloadSHA256 string
}
func NewEncryptor(masterKeyBase64, keyID string) (*Encryptor, error) {
masterKeyBase64 = strings.TrimSpace(masterKeyBase64)
if masterKeyBase64 == "" {
return nil, ErrSecretEncryptionKeyMissing
}
key, err := base64.StdEncoding.DecodeString(masterKeyBase64)
if err != nil {
if rawKey, rawErr := base64.RawStdEncoding.DecodeString(masterKeyBase64); rawErr == nil {
key = rawKey
} else {
return nil, fmt.Errorf("decode secret encryption key: %w", err)
}
}
if len(key) != 32 {
return nil, fmt.Errorf("secret encryption key must decode to 32 bytes")
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("create secret cipher: %w", err)
}
aead, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("create secret gcm: %w", err)
}
if strings.TrimSpace(keyID) == "" {
keyID = "local-v1"
}
return &Encryptor{aead: aead, keyID: keyID}, nil
}
func (e *Encryptor) KeyID() string {
if e == nil {
return ""
}
return e.keyID
}
func (e *Encryptor) Encrypt(plaintext, aad []byte) (EncryptedPayload, error) {
if e == nil {
return EncryptedPayload{}, ErrSecretEncryptionKeyMissing
}
nonce := make([]byte, e.aead.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return EncryptedPayload{}, fmt.Errorf("generate secret nonce: %w", err)
}
hash := sha256.Sum256(plaintext)
return EncryptedPayload{
Algorithm: AlgorithmAES256GCM,
KeyID: e.keyID,
Nonce: nonce,
Ciphertext: e.aead.Seal(nil, nonce, plaintext, aad),
PayloadSHA256: hex.EncodeToString(hash[:]),
}, nil
}
func (e *Encryptor) Decrypt(payload EncryptedPayload, aad []byte) ([]byte, error) {
if e == nil {
return nil, ErrSecretEncryptionKeyMissing
}
if payload.Algorithm != "" && payload.Algorithm != AlgorithmAES256GCM {
return nil, fmt.Errorf("unsupported secret algorithm %q", payload.Algorithm)
}
plaintext, err := e.aead.Open(nil, payload.Nonce, payload.Ciphertext, aad)
if err != nil {
return nil, fmt.Errorf("decrypt secret payload: %w", err)
}
return plaintext, nil
}
func ResourceSecretAAD(organizationID, resourceID, secretRef, protocol string) []byte {
return []byte(strings.Join([]string{
"rap-resource-secret-v1",
strings.TrimSpace(organizationID),
strings.TrimSpace(resourceID),
strings.TrimSpace(secretRef),
strings.ToLower(strings.TrimSpace(protocol)),
}, "|"))
}
@@ -0,0 +1,65 @@
package secrets
import (
"encoding/base64"
"encoding/json"
"testing"
)
func TestEncryptorRoundTrip(t *testing.T) {
key := base64.StdEncoding.EncodeToString([]byte("0123456789abcdef0123456789abcdef"))
encryptor, err := NewEncryptor(key, "test-key")
if err != nil {
t.Fatalf("NewEncryptor returned error: %v", err)
}
aad := ResourceSecretAAD("org-1", "resource-1", "rap-secret://test", "rdp")
encrypted, err := encryptor.Encrypt([]byte(`{"username":"user","password":"secret"}`), aad)
if err != nil {
t.Fatalf("Encrypt returned error: %v", err)
}
plaintext, err := encryptor.Decrypt(encrypted, aad)
if err != nil {
t.Fatalf("Decrypt returned error: %v", err)
}
if string(plaintext) != `{"username":"user","password":"secret"}` {
t.Fatalf("unexpected plaintext: %s", plaintext)
}
}
func TestEncryptorRejectsWrongAAD(t *testing.T) {
key := base64.StdEncoding.EncodeToString([]byte("0123456789abcdef0123456789abcdef"))
encryptor, err := NewEncryptor(key, "test-key")
if err != nil {
t.Fatalf("NewEncryptor returned error: %v", err)
}
encrypted, err := encryptor.Encrypt([]byte(`{"password":"secret"}`), ResourceSecretAAD("org-1", "resource-1", "ref", "rdp"))
if err != nil {
t.Fatalf("Encrypt returned error: %v", err)
}
if _, err := encryptor.Decrypt(encrypted, ResourceSecretAAD("org-2", "resource-1", "ref", "rdp")); err == nil {
t.Fatalf("expected decrypt with wrong aad to fail")
}
}
func TestMergeResourceSecretIntoAssignmentMetadata(t *testing.T) {
metadata := map[string]any{
"resource": map[string]any{
"id": "resource-1",
"metadata": map[string]any{
"rdp_host": "host",
},
},
}
merged, err := MergeResourceSecretIntoAssignmentMetadata(metadata, json.RawMessage(`{"username":"user","password":"secret","domain":"corp"}`))
if err != nil {
t.Fatalf("MergeResourceSecretIntoAssignmentMetadata returned error: %v", err)
}
resource := merged.Metadata["resource"].(map[string]any)
resourceMetadata := resource["metadata"].(map[string]any)
if resourceMetadata["rdp_host"] != "host" {
t.Fatalf("existing metadata was not preserved")
}
if resourceMetadata["username"] != "user" || resourceMetadata["password"] != "secret" || resourceMetadata["domain"] != "corp" {
t.Fatalf("secret payload was not merged: %#v", resourceMetadata)
}
}
@@ -0,0 +1,132 @@
package secrets
import (
"encoding/json"
"errors"
"fmt"
"slices"
"sort"
"strings"
)
var (
ErrPlaintextResourceCredentials = errors.New("plaintext resource credentials are not allowed in metadata in production")
ErrMissingResourceSecretRef = errors.New("secret_ref is required for this resource protocol in production")
)
var credentialKeyFragments = []string{
"accesstoken",
"clientsecret",
"credential",
"credentials",
"domain",
"password",
"privatekey",
"refreshtoken",
"secret",
"secrets",
"token",
"user",
"username",
}
var safeReferenceKeys = []string{
"certificateverificationmode",
"renderqualityprofile",
"secretref",
"secretreference",
"vaultref",
}
func ValidateResourceSecretReadiness(protocol string, secretRef *string, metadata json.RawMessage, appEnv string) error {
if !IsProductionEnv(appEnv) {
return nil
}
paths, err := PlaintextCredentialMetadataPaths(metadata)
if err != nil {
return err
}
if len(paths) > 0 {
return fmt.Errorf("%w: %s", ErrPlaintextResourceCredentials, strings.Join(paths, ", "))
}
if ResourceProtocolRequiresSecretRef(protocol) && (secretRef == nil || strings.TrimSpace(*secretRef) == "") {
return ErrMissingResourceSecretRef
}
return nil
}
func IsProductionEnv(appEnv string) bool {
switch strings.ToLower(strings.TrimSpace(appEnv)) {
case "prod", "production":
return true
default:
return false
}
}
func ResourceProtocolRequiresSecretRef(protocol string) bool {
switch strings.ToLower(strings.TrimSpace(protocol)) {
case "rdp", "vnc", "ssh":
return true
default:
return false
}
}
func PlaintextCredentialMetadataPaths(raw json.RawMessage) ([]string, error) {
if len(raw) == 0 {
return nil, nil
}
var value any
if err := json.Unmarshal(raw, &value); err != nil {
return nil, errors.New("metadata must be valid json")
}
metadata, ok := value.(map[string]any)
if !ok {
return nil, errors.New("metadata must be a json object")
}
var paths []string
collectCredentialPaths(metadata, "", &paths)
sort.Strings(paths)
return slices.Compact(paths), nil
}
func collectCredentialPaths(value any, prefix string, paths *[]string) {
switch typed := value.(type) {
case map[string]any:
for key, child := range typed {
path := key
if prefix != "" {
path = prefix + "." + key
}
if isCredentialMetadataKey(key) {
*paths = append(*paths, path)
}
collectCredentialPaths(child, path, paths)
}
case []any:
for index, child := range typed {
collectCredentialPaths(child, fmt.Sprintf("%s[%d]", prefix, index), paths)
}
}
}
func isCredentialMetadataKey(key string) bool {
normalized := normalizeMetadataKey(key)
if slices.Contains(safeReferenceKeys, normalized) {
return false
}
for _, fragment := range credentialKeyFragments {
if normalized == fragment || strings.HasSuffix(normalized, fragment) {
return true
}
}
return false
}
func normalizeMetadataKey(key string) string {
key = strings.ToLower(strings.TrimSpace(key))
replacer := strings.NewReplacer("_", "", "-", "", " ", "", ".", "")
return replacer.Replace(key)
}
@@ -0,0 +1,52 @@
package secrets
import (
"encoding/json"
"errors"
"slices"
"testing"
)
func TestValidateResourceSecretReadinessAllowsPlaintextInDevelopment(t *testing.T) {
metadata := json.RawMessage(`{"username":"m","password":"secret"}`)
if err := ValidateResourceSecretReadiness("rdp", nil, metadata, "development"); err != nil {
t.Fatalf("development metadata should remain allowed for smoke/dev: %v", err)
}
}
func TestValidateResourceSecretReadinessRejectsPlaintextCredentialsInProduction(t *testing.T) {
metadata := json.RawMessage(`{"rdp_host":"host","credentials":{"username":"m","password":"secret"}}`)
err := ValidateResourceSecretReadiness("rdp", stringPtr("vault://org/resource"), metadata, "production")
if !errors.Is(err, ErrPlaintextResourceCredentials) {
t.Fatalf("expected plaintext credential rejection, got %v", err)
}
paths, err := PlaintextCredentialMetadataPaths(metadata)
if err != nil {
t.Fatalf("metadata paths: %v", err)
}
for _, expected := range []string{"credentials", "credentials.password", "credentials.username"} {
if !slices.Contains(paths, expected) {
t.Fatalf("expected sensitive path %q in %v", expected, paths)
}
}
}
func TestValidateResourceSecretReadinessRequiresSecretRefForProductionRDP(t *testing.T) {
metadata := json.RawMessage(`{"rdp_host":"host","rdp_port":3389}`)
err := ValidateResourceSecretReadiness("rdp", nil, metadata, "production")
if !errors.Is(err, ErrMissingResourceSecretRef) {
t.Fatalf("expected missing secret_ref rejection, got %v", err)
}
}
func TestValidateResourceSecretReadinessAllowsProductionSecretRef(t *testing.T) {
metadata := json.RawMessage(`{"rdp_host":"host","rdp_port":3389,"secret_ref":"vault://org/resource"}`)
if err := ValidateResourceSecretReadiness("rdp", stringPtr("vault://org/resource"), metadata, "production"); err != nil {
t.Fatalf("production secret_ref metadata should be accepted: %v", err)
}
}
func stringPtr(value string) *string {
return &value
}
@@ -0,0 +1,259 @@
package secrets
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/jackc/pgx/v5"
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
)
var (
ErrResourceSecretNotFound = errors.New("resource secret not found")
ErrSecretAccessDenied = errors.New("resource secret access denied")
ErrSecretLeaseRequired = errors.New("resource secret resolution requires lease proof")
)
type ResourceSecretStore struct {
db postgresplatform.DBTX
encryptor *Encryptor
now func() time.Time
}
type ResourceSecretResolver interface {
ResolveForSession(ctx context.Context, req ResolveResourceSecretRequest) (*ResolvedResourceSecret, error)
}
type ResourceSecretDescriptor struct {
ID string `json:"id"`
OrganizationID string `json:"organization_id"`
ResourceID string `json:"resource_id"`
SecretRef string `json:"secret_ref"`
Protocol string `json:"protocol"`
Version int `json:"version"`
KeyID string `json:"key_id"`
Algorithm string `json:"algorithm"`
Metadata json.RawMessage `json:"metadata"`
CreatedAt time.Time `json:"created_at"`
RotatedAt *time.Time `json:"rotated_at,omitempty"`
}
type UpsertResourceSecretCommand struct {
OrganizationID string
ResourceID string
Protocol string
SecretRef string
Payload json.RawMessage
Metadata json.RawMessage
ActorUserID string
}
type ResolveResourceSecretRequest struct {
SecretRef string
OrganizationID string
ResourceID string
SessionID string
WorkerID string
LeaseID string
}
type ResolvedResourceSecret struct {
Descriptor ResourceSecretDescriptor
Payload json.RawMessage
}
func NewResourceSecretStore(db postgresplatform.DBTX, encryptor *Encryptor) *ResourceSecretStore {
return &ResourceSecretStore{db: db, encryptor: encryptor, now: time.Now}
}
func (s *ResourceSecretStore) WithDB(db postgresplatform.DBTX) *ResourceSecretStore {
if s == nil {
return nil
}
return &ResourceSecretStore{db: db, encryptor: s.encryptor, now: s.now}
}
func DefaultResourceSecretRef(organizationID, resourceID string) string {
return "rap-secret://org/" + strings.TrimSpace(organizationID) + "/resources/" + strings.TrimSpace(resourceID) + "/primary"
}
func (s *ResourceSecretStore) Upsert(ctx context.Context, cmd UpsertResourceSecretCommand) (*ResourceSecretDescriptor, error) {
if s == nil || s.encryptor == nil {
return nil, ErrSecretEncryptionKeyMissing
}
payload, err := normalizeJSONObject(cmd.Payload)
if err != nil {
return nil, err
}
metadata, err := normalizeJSONObjectAllowEmpty(cmd.Metadata)
if err != nil {
return nil, err
}
secretRef := strings.TrimSpace(cmd.SecretRef)
if secretRef == "" {
secretRef = DefaultResourceSecretRef(cmd.OrganizationID, cmd.ResourceID)
}
protocol := strings.ToLower(strings.TrimSpace(cmd.Protocol))
encrypted, err := s.encryptor.Encrypt(payload, ResourceSecretAAD(cmd.OrganizationID, cmd.ResourceID, secretRef, protocol))
if err != nil {
return nil, err
}
now := s.now().UTC()
const query = `
INSERT INTO resource_secrets (
organization_id, resource_id, secret_ref, protocol, version, key_id,
algorithm, nonce, ciphertext, payload_sha256, metadata, created_by_user_id,
created_at, rotated_at
) VALUES (
$1::uuid, $2::uuid, $3, $4, 1, $5,
$6, $7, $8, $9, $10::jsonb, NULLIF($11, '')::uuid,
$12, NULL
)
ON CONFLICT (resource_id) DO UPDATE SET
secret_ref = EXCLUDED.secret_ref,
protocol = EXCLUDED.protocol,
version = resource_secrets.version + 1,
key_id = EXCLUDED.key_id,
algorithm = EXCLUDED.algorithm,
nonce = EXCLUDED.nonce,
ciphertext = EXCLUDED.ciphertext,
payload_sha256 = EXCLUDED.payload_sha256,
metadata = EXCLUDED.metadata,
created_by_user_id = EXCLUDED.created_by_user_id,
rotated_at = EXCLUDED.created_at
RETURNING id::text, organization_id::text, resource_id::text, secret_ref,
protocol, version, key_id, algorithm, metadata, created_at, rotated_at
`
var descriptor ResourceSecretDescriptor
if err := s.db.QueryRow(ctx, query,
cmd.OrganizationID,
cmd.ResourceID,
secretRef,
protocol,
encrypted.KeyID,
encrypted.Algorithm,
encrypted.Nonce,
encrypted.Ciphertext,
encrypted.PayloadSHA256,
metadata,
cmd.ActorUserID,
now,
).Scan(
&descriptor.ID,
&descriptor.OrganizationID,
&descriptor.ResourceID,
&descriptor.SecretRef,
&descriptor.Protocol,
&descriptor.Version,
&descriptor.KeyID,
&descriptor.Algorithm,
&descriptor.Metadata,
&descriptor.CreatedAt,
&descriptor.RotatedAt,
); err != nil {
return nil, fmt.Errorf("upsert resource secret: %w", err)
}
return &descriptor, nil
}
func (s *ResourceSecretStore) ResolveForSession(ctx context.Context, req ResolveResourceSecretRequest) (*ResolvedResourceSecret, error) {
if s == nil || s.encryptor == nil {
return nil, ErrSecretEncryptionKeyMissing
}
if strings.TrimSpace(req.LeaseID) == "" {
return nil, ErrSecretLeaseRequired
}
const query = `
SELECT sec.id::text, sec.organization_id::text, sec.resource_id::text, sec.secret_ref,
sec.protocol, sec.version, sec.key_id, sec.algorithm, sec.metadata,
sec.created_at, sec.rotated_at, sec.nonce, sec.ciphertext,
rs.organization_id::text, rs.resource_id::text, COALESCE(rs.worker_id, ''), rs.state
FROM resource_secrets sec
JOIN remote_sessions rs ON rs.resource_id = sec.resource_id
WHERE sec.secret_ref = $1 AND rs.id = $2::uuid
`
var descriptor ResourceSecretDescriptor
var nonce, ciphertext []byte
var sessionOrganizationID, sessionResourceID, sessionWorkerID, sessionState string
if err := s.db.QueryRow(ctx, query, req.SecretRef, req.SessionID).Scan(
&descriptor.ID,
&descriptor.OrganizationID,
&descriptor.ResourceID,
&descriptor.SecretRef,
&descriptor.Protocol,
&descriptor.Version,
&descriptor.KeyID,
&descriptor.Algorithm,
&descriptor.Metadata,
&descriptor.CreatedAt,
&descriptor.RotatedAt,
&nonce,
&ciphertext,
&sessionOrganizationID,
&sessionResourceID,
&sessionWorkerID,
&sessionState,
); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, ErrResourceSecretNotFound
}
return nil, fmt.Errorf("resolve resource secret: %w", err)
}
if descriptor.OrganizationID != req.OrganizationID ||
descriptor.ResourceID != req.ResourceID ||
sessionOrganizationID != req.OrganizationID ||
sessionResourceID != req.ResourceID ||
sessionWorkerID != req.WorkerID ||
!secretResolvableSessionState(sessionState) {
return nil, ErrSecretAccessDenied
}
plaintext, err := s.encryptor.Decrypt(EncryptedPayload{
Algorithm: descriptor.Algorithm,
KeyID: descriptor.KeyID,
Nonce: nonce,
Ciphertext: ciphertext,
}, ResourceSecretAAD(descriptor.OrganizationID, descriptor.ResourceID, descriptor.SecretRef, descriptor.Protocol))
if err != nil {
return nil, err
}
return &ResolvedResourceSecret{
Descriptor: descriptor,
Payload: json.RawMessage(plaintext),
}, nil
}
func normalizeJSONObject(raw json.RawMessage) (json.RawMessage, error) {
if len(raw) == 0 || !json.Valid(raw) {
return nil, ErrSecretPayloadInvalid
}
var decoded map[string]any
if err := json.Unmarshal(raw, &decoded); err != nil {
return nil, ErrSecretPayloadInvalid
}
encoded, err := json.Marshal(decoded)
if err != nil {
return nil, err
}
return json.RawMessage(encoded), nil
}
func normalizeJSONObjectAllowEmpty(raw json.RawMessage) (json.RawMessage, error) {
if len(raw) == 0 {
return json.RawMessage(`{}`), nil
}
return normalizeJSONObject(raw)
}
func secretResolvableSessionState(state string) bool {
switch state {
case "starting", "active", "reconnecting":
return true
default:
return false
}
}
+6
View File
@@ -0,0 +1,6 @@
DROP TABLE IF EXISTS audit_logs;
DROP TABLE IF EXISTS secrets;
DROP TABLE IF EXISTS sessions;
DROP TABLE IF EXISTS resources;
DROP TABLE IF EXISTS devices;
DROP TABLE IF EXISTS users;
+65
View File
@@ -0,0 +1,65 @@
CREATE EXTENSION IF NOT EXISTS "pgcrypto";
CREATE TABLE IF NOT EXISTS users (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
email TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
mfa_enabled BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS devices (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
device_fingerprint TEXT NOT NULL,
trusted_at TIMESTAMPTZ,
last_seen_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS resources (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
name TEXT NOT NULL,
address TEXT NOT NULL,
protocol TEXT NOT NULL,
secret_ref TEXT,
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS sessions (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
resource_id UUID NOT NULL REFERENCES resources(id) ON DELETE RESTRICT,
controller_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
worker_id TEXT,
state TEXT NOT NULL,
detached_until TIMESTAMPTZ,
last_heartbeat_at TIMESTAMPTZ,
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS secrets (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
scope TEXT NOT NULL,
encrypted_payload BYTEA NOT NULL,
key_version TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS audit_logs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
actor_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
action TEXT NOT NULL,
target_type TEXT NOT NULL,
target_id TEXT NOT NULL,
payload JSONB NOT NULL DEFAULT '{}'::JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_devices_user_id ON devices(user_id);
CREATE INDEX IF NOT EXISTS idx_sessions_resource_state ON sessions(resource_id, state);
CREATE INDEX IF NOT EXISTS idx_audit_logs_created_at ON audit_logs(created_at DESC);
@@ -0,0 +1,9 @@
DROP TABLE IF EXISTS auth_sessions;
DROP INDEX IF EXISTS idx_devices_user_fingerprint;
ALTER TABLE devices
DROP COLUMN IF EXISTS updated_at,
DROP COLUMN IF EXISTS revoked_reason,
DROP COLUMN IF EXISTS revoked_at,
DROP COLUMN IF EXISTS trust_status,
DROP COLUMN IF EXISTS device_label;
@@ -0,0 +1,32 @@
ALTER TABLE devices
ADD COLUMN IF NOT EXISTS device_label TEXT,
ADD COLUMN IF NOT EXISTS trust_status TEXT NOT NULL DEFAULT 'pending',
ADD COLUMN IF NOT EXISTS revoked_at TIMESTAMPTZ,
ADD COLUMN IF NOT EXISTS revoked_reason TEXT,
ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW();
CREATE UNIQUE INDEX IF NOT EXISTS idx_devices_user_fingerprint
ON devices(user_id, device_fingerprint);
CREATE TABLE IF NOT EXISTS auth_sessions (
id UUID PRIMARY KEY,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
device_id UUID NOT NULL REFERENCES devices(id) ON DELETE RESTRICT,
refresh_token_hash TEXT NOT NULL,
refresh_expires_at TIMESTAMPTZ NOT NULL,
last_seen_at TIMESTAMPTZ,
last_rotated_at TIMESTAMPTZ,
revoked_at TIMESTAMPTZ,
revoked_reason TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_auth_sessions_user_id
ON auth_sessions(user_id);
CREATE INDEX IF NOT EXISTS idx_auth_sessions_device_id
ON auth_sessions(device_id);
CREATE INDEX IF NOT EXISTS idx_auth_sessions_revoked_at
ON auth_sessions(revoked_at);
@@ -0,0 +1,4 @@
DROP TABLE IF EXISTS audit_events;
DROP TABLE IF EXISTS session_attachments;
DROP TABLE IF EXISTS remote_sessions;
DROP TABLE IF EXISTS resource_policies;
@@ -0,0 +1,79 @@
CREATE TABLE IF NOT EXISTS resource_policies (
resource_id UUID PRIMARY KEY REFERENCES resources(id) ON DELETE CASCADE,
max_concurrent_sessions INTEGER NOT NULL DEFAULT 1,
takeover_policy TEXT NOT NULL DEFAULT 'trusted_device',
require_trusted_device BOOLEAN NOT NULL DEFAULT TRUE,
detach_grace_period_seconds INTEGER NOT NULL DEFAULT 1800,
clipboard_enabled BOOLEAN NOT NULL DEFAULT TRUE,
file_transfer_enabled BOOLEAN NOT NULL DEFAULT TRUE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS remote_sessions (
id UUID PRIMARY KEY,
resource_id UUID NOT NULL REFERENCES resources(id) ON DELETE RESTRICT,
protocol TEXT NOT NULL,
state TEXT NOT NULL,
worker_id TEXT,
controller_user_id UUID NOT NULL REFERENCES users(id) ON DELETE RESTRICT,
detach_deadline_at TIMESTAMPTZ,
last_heartbeat_at TIMESTAMPTZ,
takeover_version INTEGER NOT NULL DEFAULT 1,
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_remote_sessions_resource_id
ON remote_sessions(resource_id);
CREATE INDEX IF NOT EXISTS idx_remote_sessions_controller_user_id
ON remote_sessions(controller_user_id);
CREATE INDEX IF NOT EXISTS idx_remote_sessions_state
ON remote_sessions(state);
CREATE TABLE IF NOT EXISTS session_attachments (
id UUID PRIMARY KEY,
remote_session_id UUID NOT NULL REFERENCES remote_sessions(id) ON DELETE CASCADE,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE RESTRICT,
device_id UUID NOT NULL REFERENCES devices(id) ON DELETE RESTRICT,
role TEXT NOT NULL,
state TEXT NOT NULL,
superseded_by UUID REFERENCES session_attachments(id) ON DELETE SET NULL,
takeover_of UUID REFERENCES session_attachments(id) ON DELETE SET NULL,
attached_at TIMESTAMPTZ,
detached_at TIMESTAMPTZ,
last_input_at TIMESTAMPTZ,
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_session_attachments_remote_session_id
ON session_attachments(remote_session_id);
CREATE INDEX IF NOT EXISTS idx_session_attachments_user_id
ON session_attachments(user_id);
CREATE INDEX IF NOT EXISTS idx_session_attachments_state
ON session_attachments(state);
CREATE TABLE IF NOT EXISTS audit_events (
id UUID PRIMARY KEY,
actor_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
actor_device_id UUID REFERENCES devices(id) ON DELETE SET NULL,
event_type TEXT NOT NULL,
target_type TEXT NOT NULL,
target_id TEXT NOT NULL,
remote_session_id UUID REFERENCES remote_sessions(id) ON DELETE SET NULL,
payload JSONB NOT NULL DEFAULT '{}'::JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_audit_events_created_at
ON audit_events(created_at DESC);
CREATE INDEX IF NOT EXISTS idx_audit_events_remote_session_id
ON audit_events(remote_session_id);
@@ -0,0 +1,5 @@
ALTER TABLE resources
DROP CONSTRAINT IF EXISTS resources_certificate_verification_mode_check;
ALTER TABLE resources
DROP COLUMN IF EXISTS certificate_verification_mode;
@@ -0,0 +1,6 @@
ALTER TABLE resources
ADD COLUMN certificate_verification_mode TEXT NOT NULL DEFAULT 'strict';
ALTER TABLE resources
ADD CONSTRAINT resources_certificate_verification_mode_check
CHECK (certificate_verification_mode IN ('strict', 'ignore'));
@@ -0,0 +1,26 @@
DROP TABLE IF EXISTS node_agent_update_runs;
DROP TABLE IF EXISTS node_partition_states;
DROP TABLE IF EXISTS node_update_policies;
DROP TABLE IF EXISTS node_services;
DROP TABLE IF EXISTS node_capabilities;
DROP TABLE IF EXISTS nodes;
DROP TABLE IF EXISTS identity_mappings;
DROP TABLE IF EXISTS identity_sources;
DROP INDEX IF EXISTS idx_remote_sessions_organization_id;
ALTER TABLE remote_sessions
DROP COLUMN IF EXISTS organization_id;
DROP INDEX IF EXISTS idx_resources_organization_id;
ALTER TABLE resources
DROP COLUMN IF EXISTS organization_id;
DROP TABLE IF EXISTS organization_memberships;
DROP TABLE IF EXISTS organization_roles;
DROP TABLE IF EXISTS organizations;
ALTER TABLE users
DROP CONSTRAINT IF EXISTS users_platform_role_check;
ALTER TABLE users
DROP COLUMN IF EXISTS platform_role;
@@ -0,0 +1,219 @@
ALTER TABLE users
ADD COLUMN IF NOT EXISTS platform_role TEXT NOT NULL DEFAULT 'user';
ALTER TABLE users
ADD CONSTRAINT users_platform_role_check
CHECK (platform_role IN ('user', 'platform_admin', 'platform_recovery_admin'));
CREATE TABLE IF NOT EXISTS organizations (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
slug TEXT NOT NULL UNIQUE,
name TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'active',
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT organizations_status_check
CHECK (status IN ('active', 'suspended', 'archived'))
);
CREATE TABLE IF NOT EXISTS organization_roles (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
scope TEXT NOT NULL DEFAULT 'organization',
description TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT organization_roles_scope_check
CHECK (scope IN ('organization'))
);
INSERT INTO organization_roles (id, name, scope, description)
VALUES
('org_owner', 'Organization Owner', 'organization', 'Full organization control including membership administration.'),
('org_admin', 'Organization Admin', 'organization', 'Administrative access within one organization.'),
('org_operator', 'Organization Operator', 'organization', 'Operational access within one organization.'),
('org_member', 'Organization Member', 'organization', 'Standard organization member access.'),
('org_viewer', 'Organization Viewer', 'organization', 'Read-only organization visibility.')
ON CONFLICT (id) DO UPDATE SET
name = EXCLUDED.name,
description = EXCLUDED.description;
CREATE TABLE IF NOT EXISTS organization_memberships (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
organization_id UUID NOT NULL REFERENCES organizations(id) ON DELETE CASCADE,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
role_id TEXT NOT NULL REFERENCES organization_roles(id) ON DELETE RESTRICT,
status TEXT NOT NULL DEFAULT 'active',
invited_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE (organization_id, user_id),
CONSTRAINT organization_memberships_status_check
CHECK (status IN ('active', 'suspended', 'revoked'))
);
CREATE INDEX IF NOT EXISTS idx_organization_memberships_user_id
ON organization_memberships(user_id);
INSERT INTO organizations (slug, name, status, metadata)
VALUES ('default', 'Default Organization', 'active', '{"bootstrap":true}'::jsonb)
ON CONFLICT (slug) DO NOTHING;
ALTER TABLE resources
ADD COLUMN IF NOT EXISTS organization_id UUID REFERENCES organizations(id) ON DELETE RESTRICT;
UPDATE resources
SET organization_id = organizations.id
FROM organizations
WHERE organizations.slug = 'default'
AND resources.organization_id IS NULL;
ALTER TABLE resources
ALTER COLUMN organization_id SET NOT NULL;
CREATE INDEX IF NOT EXISTS idx_resources_organization_id
ON resources(organization_id);
ALTER TABLE remote_sessions
ADD COLUMN IF NOT EXISTS organization_id UUID REFERENCES organizations(id) ON DELETE RESTRICT;
UPDATE remote_sessions
SET organization_id = resources.organization_id
FROM resources
WHERE resources.id = remote_sessions.resource_id
AND remote_sessions.organization_id IS NULL;
ALTER TABLE remote_sessions
ALTER COLUMN organization_id SET NOT NULL;
CREATE INDEX IF NOT EXISTS idx_remote_sessions_organization_id
ON remote_sessions(organization_id);
CREATE TABLE IF NOT EXISTS identity_sources (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
organization_id UUID NOT NULL REFERENCES organizations(id) ON DELETE CASCADE,
kind TEXT NOT NULL,
name TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'active',
config JSONB NOT NULL DEFAULT '{}'::JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT identity_sources_kind_check
CHECK (kind IN ('local', 'ldap_ad', 'oidc')),
CONSTRAINT identity_sources_status_check
CHECK (status IN ('active', 'disabled'))
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_identity_sources_org_name
ON identity_sources(organization_id, name);
CREATE TABLE IF NOT EXISTS identity_mappings (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
identity_source_id UUID NOT NULL REFERENCES identity_sources(id) ON DELETE CASCADE,
mapping_type TEXT NOT NULL,
external_selector JSONB NOT NULL DEFAULT '{}'::JSONB,
internal_target JSONB NOT NULL DEFAULT '{}'::JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT identity_mappings_type_check
CHECK (mapping_type IN ('group_binding', 'claim_binding', 'user_binding'))
);
CREATE INDEX IF NOT EXISTS idx_identity_mappings_source_id
ON identity_mappings(identity_source_id);
CREATE TABLE IF NOT EXISTS nodes (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
owner_organization_id UUID REFERENCES organizations(id) ON DELETE SET NULL,
node_key TEXT NOT NULL UNIQUE,
name TEXT NOT NULL,
ownership_type TEXT NOT NULL,
registration_status TEXT NOT NULL DEFAULT 'pending',
health_status TEXT NOT NULL DEFAULT 'unknown',
version_state TEXT NOT NULL DEFAULT 'unknown',
partition_state TEXT NOT NULL DEFAULT 'healthy',
desired_version TEXT,
reported_version TEXT,
last_seen_at TIMESTAMPTZ,
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT nodes_ownership_type_check
CHECK (ownership_type IN ('platform_managed', 'customer_managed')),
CONSTRAINT nodes_registration_status_check
CHECK (registration_status IN ('pending', 'active', 'disabled', 'revoked')),
CONSTRAINT nodes_health_status_check
CHECK (health_status IN ('unknown', 'healthy', 'warning', 'critical')),
CONSTRAINT nodes_version_state_check
CHECK (version_state IN ('unknown', 'current', 'outdated', 'updating', 'rollback', 'failed')),
CONSTRAINT nodes_partition_state_check
CHECK (partition_state IN ('healthy', 'degraded', 'recovery', 'isolated'))
);
CREATE TABLE IF NOT EXISTS node_capabilities (
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
capability TEXT NOT NULL,
value JSONB NOT NULL DEFAULT '{}'::JSONB,
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (node_id, capability)
);
CREATE TABLE IF NOT EXISTS node_services (
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
service_type TEXT NOT NULL,
enabled BOOLEAN NOT NULL DEFAULT FALSE,
desired_state TEXT NOT NULL DEFAULT 'disabled',
reported_state TEXT NOT NULL DEFAULT 'unknown',
last_reported_at TIMESTAMPTZ,
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (node_id, service_type),
CONSTRAINT node_services_desired_state_check
CHECK (desired_state IN ('enabled', 'disabled', 'drain')),
CONSTRAINT node_services_reported_state_check
CHECK (reported_state IN ('unknown', 'starting', 'running', 'degraded', 'stopped', 'failed'))
);
CREATE TABLE IF NOT EXISTS node_update_policies (
node_id UUID PRIMARY KEY REFERENCES nodes(id) ON DELETE CASCADE,
mode TEXT NOT NULL DEFAULT 'manual',
channel TEXT NOT NULL DEFAULT 'stable',
maintenance_window JSONB NOT NULL DEFAULT '{}'::JSONB,
canary BOOLEAN NOT NULL DEFAULT FALSE,
automatic_rollout BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT node_update_policies_mode_check
CHECK (mode IN ('manual', 'automatic', 'staged', 'canary'))
);
CREATE TABLE IF NOT EXISTS node_partition_states (
node_id UUID PRIMARY KEY REFERENCES nodes(id) ON DELETE CASCADE,
cluster_state TEXT NOT NULL DEFAULT 'healthy',
recovery_mode TEXT NOT NULL DEFAULT 'normal',
notes TEXT,
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT node_partition_states_cluster_state_check
CHECK (cluster_state IN ('healthy', 'degraded', 'recovery', 'isolated')),
CONSTRAINT node_partition_states_recovery_mode_check
CHECK (recovery_mode IN ('normal', 'manual_recovery', 'emergency'))
);
CREATE TABLE IF NOT EXISTS node_agent_update_runs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
action TEXT NOT NULL,
target_version TEXT,
status TEXT NOT NULL DEFAULT 'pending',
requested_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
acknowledged_at TIMESTAMPTZ,
completed_at TIMESTAMPTZ,
payload JSONB NOT NULL DEFAULT '{}'::JSONB,
CONSTRAINT node_agent_update_runs_action_check
CHECK (action IN ('update', 'rollback')),
CONSTRAINT node_agent_update_runs_status_check
CHECK (status IN ('pending', 'acknowledged', 'succeeded', 'failed'))
);
CREATE INDEX IF NOT EXISTS idx_node_agent_update_runs_node_id
ON node_agent_update_runs(node_id, requested_at DESC);
@@ -0,0 +1,7 @@
DELETE FROM organization_memberships
WHERE organization_id IN (
SELECT id
FROM organizations
WHERE slug = 'default'
)
AND invited_by_user_id IS NULL;
@@ -0,0 +1,29 @@
INSERT INTO organization_memberships (
id,
organization_id,
user_id,
role_id,
status,
invited_by_user_id,
created_at,
updated_at
)
SELECT
gen_random_uuid(),
org.id,
u.id,
CASE
WHEN u.platform_role IN ('platform_admin', 'platform_recovery_admin') THEN 'org_owner'
ELSE 'org_member'
END,
'active',
NULL,
NOW(),
NOW()
FROM users u
CROSS JOIN organizations org
LEFT JOIN organization_memberships om
ON om.organization_id = org.id
AND om.user_id = u.id
WHERE org.slug = 'default'
AND om.id IS NULL;
@@ -0,0 +1,5 @@
ALTER TABLE resource_policies
DROP CONSTRAINT IF EXISTS resource_policies_clipboard_mode_check;
ALTER TABLE resource_policies
DROP COLUMN IF EXISTS clipboard_mode;
@@ -0,0 +1,16 @@
ALTER TABLE resource_policies
ADD COLUMN IF NOT EXISTS clipboard_mode TEXT NOT NULL DEFAULT 'disabled';
UPDATE resource_policies
SET clipboard_mode = CASE
WHEN clipboard_enabled THEN 'bidirectional'
ELSE 'disabled'
END
WHERE clipboard_mode = 'disabled';
ALTER TABLE resource_policies
DROP CONSTRAINT IF EXISTS resource_policies_clipboard_mode_check;
ALTER TABLE resource_policies
ADD CONSTRAINT resource_policies_clipboard_mode_check
CHECK (clipboard_mode IN ('disabled', 'client_to_server', 'server_to_client', 'bidirectional'));
@@ -0,0 +1,5 @@
ALTER TABLE resource_policies
DROP CONSTRAINT IF EXISTS resource_policies_file_transfer_mode_check;
ALTER TABLE resource_policies
DROP COLUMN IF EXISTS file_transfer_mode;
@@ -0,0 +1,16 @@
ALTER TABLE resource_policies
ADD COLUMN IF NOT EXISTS file_transfer_mode TEXT NOT NULL DEFAULT 'disabled';
UPDATE resource_policies
SET file_transfer_mode = 'disabled',
file_transfer_enabled = FALSE
WHERE file_transfer_mode IS NULL
OR file_transfer_mode = ''
OR file_transfer_mode = 'disabled';
ALTER TABLE resource_policies
DROP CONSTRAINT IF EXISTS resource_policies_file_transfer_mode_check;
ALTER TABLE resource_policies
ADD CONSTRAINT resource_policies_file_transfer_mode_check
CHECK (file_transfer_mode IN ('disabled', 'client_to_server', 'server_to_client', 'bidirectional'));
@@ -0,0 +1 @@
DROP TABLE IF EXISTS resource_secrets;
@@ -0,0 +1,27 @@
CREATE TABLE IF NOT EXISTS resource_secrets (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
organization_id UUID NOT NULL REFERENCES organizations(id) ON DELETE CASCADE,
resource_id UUID NOT NULL REFERENCES resources(id) ON DELETE CASCADE,
secret_ref TEXT NOT NULL UNIQUE,
protocol TEXT NOT NULL,
version INTEGER NOT NULL DEFAULT 1,
key_id TEXT NOT NULL,
algorithm TEXT NOT NULL DEFAULT 'AES-256-GCM',
nonce BYTEA NOT NULL,
ciphertext BYTEA NOT NULL,
payload_sha256 TEXT NOT NULL,
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
created_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
rotated_at TIMESTAMPTZ,
UNIQUE (resource_id)
);
CREATE INDEX IF NOT EXISTS idx_resource_secrets_organization_id
ON resource_secrets(organization_id);
CREATE INDEX IF NOT EXISTS idx_resource_secrets_resource_id
ON resource_secrets(resource_id);
CREATE INDEX IF NOT EXISTS idx_resource_secrets_secret_ref
ON resource_secrets(secret_ref);
@@ -0,0 +1,9 @@
DROP TABLE IF EXISTS cluster_audit_events;
DROP TABLE IF EXISTS node_latest_heartbeats;
DROP TABLE IF EXISTS node_heartbeats;
DROP TABLE IF EXISTS node_role_assignments;
DROP TABLE IF EXISTS node_join_requests;
DROP TABLE IF EXISTS node_join_tokens;
DROP TABLE IF EXISTS node_identities;
DROP TABLE IF EXISTS cluster_memberships;
DROP TABLE IF EXISTS clusters;
@@ -0,0 +1,186 @@
CREATE TABLE IF NOT EXISTS clusters (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
slug TEXT NOT NULL UNIQUE,
name TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'active',
region TEXT,
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT clusters_status_check
CHECK (status IN ('active', 'disabled', 'archived', 'degraded'))
);
INSERT INTO clusters (slug, name, status, region, metadata)
VALUES ('default', 'Default Cluster', 'active', NULL, '{"bootstrap":true}'::jsonb)
ON CONFLICT (slug) DO NOTHING;
CREATE TABLE IF NOT EXISTS cluster_memberships (
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
membership_status TEXT NOT NULL DEFAULT 'active',
joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
last_seen_at TIMESTAMPTZ,
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
PRIMARY KEY (cluster_id, node_id),
CONSTRAINT cluster_memberships_status_check
CHECK (membership_status IN ('active', 'draining', 'disabled', 'revoked'))
);
INSERT INTO cluster_memberships (cluster_id, node_id, membership_status, joined_at, last_seen_at, metadata)
SELECT c.id, n.id, 'active', COALESCE(n.created_at, NOW()), n.last_seen_at, '{"backfilled":true}'::jsonb
FROM clusters c
CROSS JOIN nodes n
WHERE c.slug = 'default'
ON CONFLICT (cluster_id, node_id) DO NOTHING;
CREATE TABLE IF NOT EXISTS node_identities (
node_id UUID PRIMARY KEY REFERENCES nodes(id) ON DELETE CASCADE,
public_key TEXT NOT NULL,
certificate_serial TEXT,
certificate_not_before TIMESTAMPTZ,
certificate_not_after TIMESTAMPTZ,
identity_status TEXT NOT NULL DEFAULT 'pending',
rotated_at TIMESTAMPTZ,
revoked_at TIMESTAMPTZ,
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT node_identities_status_check
CHECK (identity_status IN ('pending', 'active', 'rotating', 'revoked'))
);
CREATE TABLE IF NOT EXISTS node_join_tokens (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
token_hash TEXT NOT NULL UNIQUE,
scope JSONB NOT NULL DEFAULT '{}'::JSONB,
expires_at TIMESTAMPTZ NOT NULL,
max_uses INTEGER NOT NULL DEFAULT 1,
used_count INTEGER NOT NULL DEFAULT 0,
status TEXT NOT NULL DEFAULT 'active',
created_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
revoked_at TIMESTAMPTZ,
CONSTRAINT node_join_tokens_status_check
CHECK (status IN ('active', 'revoked', 'expired')),
CONSTRAINT node_join_tokens_max_uses_check
CHECK (max_uses > 0),
CONSTRAINT node_join_tokens_used_count_check
CHECK (used_count >= 0)
);
CREATE INDEX IF NOT EXISTS idx_node_join_tokens_cluster_status
ON node_join_tokens(cluster_id, status, expires_at);
CREATE TABLE IF NOT EXISTS node_join_requests (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
join_token_id UUID REFERENCES node_join_tokens(id) ON DELETE SET NULL,
node_name TEXT NOT NULL,
node_fingerprint TEXT NOT NULL,
public_key TEXT NOT NULL,
reported_capabilities JSONB NOT NULL DEFAULT '{}'::JSONB,
reported_facts JSONB NOT NULL DEFAULT '{}'::JSONB,
requested_roles JSONB NOT NULL DEFAULT '[]'::JSONB,
status TEXT NOT NULL DEFAULT 'pending',
reviewed_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
reviewed_at TIMESTAMPTZ,
approved_node_id UUID REFERENCES nodes(id) ON DELETE SET NULL,
rejection_reason TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT node_join_requests_status_check
CHECK (status IN ('pending', 'approved', 'rejected', 'cancelled'))
);
CREATE INDEX IF NOT EXISTS idx_node_join_requests_cluster_status
ON node_join_requests(cluster_id, status, created_at DESC);
CREATE UNIQUE INDEX IF NOT EXISTS idx_node_join_requests_pending_fingerprint
ON node_join_requests(cluster_id, node_fingerprint)
WHERE status = 'pending';
CREATE TABLE IF NOT EXISTS node_role_assignments (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
organization_id UUID REFERENCES organizations(id) ON DELETE CASCADE,
role TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'active',
policy JSONB NOT NULL DEFAULT '{}'::JSONB,
assigned_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
assigned_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
revoked_at TIMESTAMPTZ,
CONSTRAINT node_role_assignments_status_check
CHECK (status IN ('active', 'disabled', 'revoked')),
CONSTRAINT node_role_assignments_role_check
CHECK (role IN (
'entry-node',
'relay-node',
'core-mesh',
'rdp-worker',
'vnc-worker',
'vpn-exit',
'vpn-connector',
'file-storage-cache',
'update-cache',
'video-relay'
))
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_node_role_assignments_unique_active
ON node_role_assignments(cluster_id, node_id, role, COALESCE(organization_id, '00000000-0000-0000-0000-000000000000'::uuid))
WHERE status = 'active';
CREATE INDEX IF NOT EXISTS idx_node_role_assignments_cluster
ON node_role_assignments(cluster_id, role, status);
CREATE TABLE IF NOT EXISTS node_heartbeats (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
health_status TEXT NOT NULL DEFAULT 'unknown',
reported_version TEXT,
capabilities JSONB NOT NULL DEFAULT '{}'::JSONB,
service_states JSONB NOT NULL DEFAULT '{}'::JSONB,
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
observed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT node_heartbeats_health_status_check
CHECK (health_status IN ('unknown', 'healthy', 'warning', 'critical'))
);
CREATE INDEX IF NOT EXISTS idx_node_heartbeats_cluster_node_observed
ON node_heartbeats(cluster_id, node_id, observed_at DESC);
CREATE TABLE IF NOT EXISTS node_latest_heartbeats (
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
heartbeat_id UUID REFERENCES node_heartbeats(id) ON DELETE SET NULL,
health_status TEXT NOT NULL DEFAULT 'unknown',
reported_version TEXT,
capabilities JSONB NOT NULL DEFAULT '{}'::JSONB,
service_states JSONB NOT NULL DEFAULT '{}'::JSONB,
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
observed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (cluster_id, node_id),
CONSTRAINT node_latest_heartbeats_health_status_check
CHECK (health_status IN ('unknown', 'healthy', 'warning', 'critical'))
);
CREATE TABLE IF NOT EXISTS cluster_audit_events (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
cluster_id UUID REFERENCES clusters(id) ON DELETE SET NULL,
actor_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
event_type TEXT NOT NULL,
target_type TEXT NOT NULL,
target_id TEXT,
payload JSONB NOT NULL DEFAULT '{}'::JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_cluster_audit_events_cluster_created
ON cluster_audit_events(cluster_id, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_cluster_audit_events_type_created
ON cluster_audit_events(event_type, created_at DESC);
@@ -0,0 +1,3 @@
DROP TABLE IF EXISTS node_workload_latest_statuses;
DROP TABLE IF EXISTS node_workload_status_reports;
DROP TABLE IF EXISTS node_workload_desired_states;
@@ -0,0 +1,57 @@
CREATE TABLE IF NOT EXISTS node_workload_desired_states (
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
service_type TEXT NOT NULL,
desired_state TEXT NOT NULL DEFAULT 'disabled',
version TEXT,
runtime_mode TEXT NOT NULL DEFAULT 'container',
artifact_ref TEXT,
config JSONB NOT NULL DEFAULT '{}'::JSONB,
environment JSONB NOT NULL DEFAULT '{}'::JSONB,
updated_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (cluster_id, node_id, service_type),
CONSTRAINT node_workload_desired_states_desired_state_check
CHECK (desired_state IN ('enabled', 'disabled', 'drain')),
CONSTRAINT node_workload_desired_states_runtime_mode_check
CHECK (runtime_mode IN ('native', 'container'))
);
CREATE INDEX IF NOT EXISTS idx_node_workload_desired_states_cluster
ON node_workload_desired_states(cluster_id, service_type, desired_state);
CREATE TABLE IF NOT EXISTS node_workload_status_reports (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
service_type TEXT NOT NULL,
reported_state TEXT NOT NULL DEFAULT 'unknown',
runtime_mode TEXT NOT NULL DEFAULT 'container',
version TEXT,
status_payload JSONB NOT NULL DEFAULT '{}'::JSONB,
observed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT node_workload_status_reports_reported_state_check
CHECK (reported_state IN ('unknown', 'starting', 'running', 'degraded', 'stopped', 'failed', 'not_implemented')),
CONSTRAINT node_workload_status_reports_runtime_mode_check
CHECK (runtime_mode IN ('native', 'container'))
);
CREATE INDEX IF NOT EXISTS idx_node_workload_status_reports_node_observed
ON node_workload_status_reports(cluster_id, node_id, observed_at DESC);
CREATE TABLE IF NOT EXISTS node_workload_latest_statuses (
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
service_type TEXT NOT NULL,
status_report_id UUID REFERENCES node_workload_status_reports(id) ON DELETE SET NULL,
reported_state TEXT NOT NULL DEFAULT 'unknown',
runtime_mode TEXT NOT NULL DEFAULT 'container',
version TEXT,
status_payload JSONB NOT NULL DEFAULT '{}'::JSONB,
observed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (cluster_id, node_id, service_type),
CONSTRAINT node_workload_latest_statuses_reported_state_check
CHECK (reported_state IN ('unknown', 'starting', 'running', 'degraded', 'stopped', 'failed', 'not_implemented')),
CONSTRAINT node_workload_latest_statuses_runtime_mode_check
CHECK (runtime_mode IN ('native', 'container'))
);
@@ -0,0 +1,4 @@
DROP TABLE IF EXISTS mesh_qos_policies;
DROP TABLE IF EXISTS mesh_route_intents;
DROP TABLE IF EXISTS mesh_latest_links;
DROP TABLE IF EXISTS mesh_link_observations;
@@ -0,0 +1,94 @@
CREATE TABLE IF NOT EXISTS mesh_link_observations (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
source_node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
target_node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
link_status TEXT NOT NULL DEFAULT 'unknown',
latency_ms INTEGER,
quality_score INTEGER,
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
observed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT mesh_link_observations_status_check
CHECK (link_status IN ('unknown', 'reachable', 'degraded', 'unreachable')),
CONSTRAINT mesh_link_observations_quality_check
CHECK (quality_score IS NULL OR (quality_score >= 0 AND quality_score <= 100))
);
CREATE INDEX IF NOT EXISTS idx_mesh_link_observations_cluster_observed
ON mesh_link_observations(cluster_id, observed_at DESC);
CREATE TABLE IF NOT EXISTS mesh_latest_links (
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
source_node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
target_node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
observation_id UUID REFERENCES mesh_link_observations(id) ON DELETE SET NULL,
link_status TEXT NOT NULL DEFAULT 'unknown',
latency_ms INTEGER,
quality_score INTEGER,
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
observed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (cluster_id, source_node_id, target_node_id),
CONSTRAINT mesh_latest_links_status_check
CHECK (link_status IN ('unknown', 'reachable', 'degraded', 'unreachable')),
CONSTRAINT mesh_latest_links_quality_check
CHECK (quality_score IS NULL OR (quality_score >= 0 AND quality_score <= 100))
);
CREATE TABLE IF NOT EXISTS mesh_route_intents (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
source_selector JSONB NOT NULL DEFAULT '{}'::JSONB,
destination_selector JSONB NOT NULL DEFAULT '{}'::JSONB,
service_class TEXT NOT NULL,
priority INTEGER NOT NULL DEFAULT 100,
status TEXT NOT NULL DEFAULT 'active',
policy JSONB NOT NULL DEFAULT '{}'::JSONB,
created_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT mesh_route_intents_service_class_check
CHECK (service_class IN ('input', 'control', 'render', 'clipboard', 'file_transfer', 'vpn_packets', 'telemetry')),
CONSTRAINT mesh_route_intents_status_check
CHECK (status IN ('active', 'disabled'))
);
CREATE INDEX IF NOT EXISTS idx_mesh_route_intents_cluster_class
ON mesh_route_intents(cluster_id, service_class, status);
CREATE TABLE IF NOT EXISTS mesh_qos_policies (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
service_class TEXT NOT NULL,
priority INTEGER NOT NULL,
reliability_mode TEXT NOT NULL,
drop_policy TEXT NOT NULL,
bandwidth_policy JSONB NOT NULL DEFAULT '{}'::JSONB,
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE (cluster_id, service_class),
CONSTRAINT mesh_qos_policies_service_class_check
CHECK (service_class IN ('input', 'control', 'render', 'clipboard', 'file_transfer', 'vpn_packets', 'telemetry')),
CONSTRAINT mesh_qos_policies_reliability_check
CHECK (reliability_mode IN ('reliable', 'droppable', 'adaptive')),
CONSTRAINT mesh_qos_policies_drop_policy_check
CHECK (drop_policy IN ('never', 'latest_only', 'adaptive'))
);
INSERT INTO mesh_qos_policies (
cluster_id, service_class, priority, reliability_mode, drop_policy, bandwidth_policy, metadata
)
SELECT c.id, defaults.service_class, defaults.priority, defaults.reliability_mode,
defaults.drop_policy, '{}'::jsonb, '{"default":true}'::jsonb
FROM clusters c
CROSS JOIN (
VALUES
('input', 10, 'reliable', 'never'),
('control', 20, 'reliable', 'never'),
('clipboard', 40, 'reliable', 'never'),
('render', 60, 'droppable', 'latest_only'),
('file_transfer', 80, 'reliable', 'never'),
('telemetry', 120, 'adaptive', 'adaptive'),
('vpn_packets', 160, 'adaptive', 'adaptive')
) AS defaults(service_class, priority, reliability_mode, drop_policy)
ON CONFLICT (cluster_id, service_class) DO NOTHING;
@@ -0,0 +1,2 @@
DROP VIEW IF EXISTS cluster_admin_summaries;
DROP TABLE IF EXISTS cluster_authority_states;
@@ -0,0 +1,40 @@
CREATE TABLE IF NOT EXISTS cluster_authority_states (
cluster_id UUID PRIMARY KEY REFERENCES clusters(id) ON DELETE CASCADE,
authority_state TEXT NOT NULL DEFAULT 'authoritative',
mutation_mode TEXT NOT NULL DEFAULT 'normal',
term BIGINT NOT NULL DEFAULT 1,
notes TEXT,
updated_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT cluster_authority_states_authority_check
CHECK (authority_state IN ('authoritative', 'minority', 'isolated', 'recovery')),
CONSTRAINT cluster_authority_states_mutation_check
CHECK (mutation_mode IN ('normal', 'read_only', 'recovery_override'))
);
INSERT INTO cluster_authority_states (cluster_id, authority_state, mutation_mode, term, notes)
SELECT id, 'authoritative', 'normal', 1, 'default authority state'
FROM clusters
ON CONFLICT (cluster_id) DO NOTHING;
CREATE OR REPLACE VIEW cluster_admin_summaries AS
SELECT
c.id AS cluster_id,
c.slug,
c.name,
c.status,
c.region,
COALESCE(cas.authority_state, 'authoritative') AS authority_state,
COALESCE(cas.mutation_mode, 'normal') AS mutation_mode,
COUNT(DISTINCT cm.node_id) AS node_count,
COUNT(DISTINCT CASE WHEN n.health_status = 'healthy' THEN n.id END) AS healthy_node_count,
COUNT(DISTINCT CASE WHEN njr.status = 'pending' THEN njr.id END) AS pending_join_count,
COUNT(DISTINCT nra.id) AS active_role_assignment_count,
MAX(n.last_seen_at) AS last_node_seen_at
FROM clusters c
LEFT JOIN cluster_authority_states cas ON cas.cluster_id = c.id
LEFT JOIN cluster_memberships cm ON cm.cluster_id = c.id
LEFT JOIN nodes n ON n.id = cm.node_id
LEFT JOIN node_join_requests njr ON njr.cluster_id = c.id
LEFT JOIN node_role_assignments nra ON nra.cluster_id = c.id AND nra.status = 'active'
GROUP BY c.id, c.slug, c.name, c.status, c.region, cas.authority_state, cas.mutation_mode;
@@ -0,0 +1,4 @@
DROP TABLE IF EXISTS vpn_connection_leases;
DROP TABLE IF EXISTS vpn_connection_route_policies;
DROP TABLE IF EXISTS vpn_connection_allowed_nodes;
DROP TABLE IF EXISTS vpn_connections;
@@ -0,0 +1,125 @@
CREATE TABLE IF NOT EXISTS vpn_connections (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
organization_id UUID NOT NULL REFERENCES organizations(id) ON DELETE CASCADE,
name TEXT NOT NULL,
target_endpoint JSONB NOT NULL DEFAULT '{}'::JSONB,
protocol_family TEXT NOT NULL DEFAULT 'generic',
credential_ref TEXT,
mode TEXT NOT NULL DEFAULT 'single_active',
desired_state TEXT NOT NULL DEFAULT 'disabled',
allowed_node_policy JSONB NOT NULL DEFAULT '{"mode":"explicit","node_ids":[]}'::JSONB,
routing_usage JSONB NOT NULL DEFAULT '[]'::JSONB,
route_policy JSONB NOT NULL DEFAULT '{}'::JSONB,
qos_policy JSONB NOT NULL DEFAULT '{}'::JSONB,
placement_policy JSONB NOT NULL DEFAULT '{}'::JSONB,
status TEXT NOT NULL DEFAULT 'disabled',
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
created_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
updated_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT vpn_connections_mode_check
CHECK (mode IN ('single_active')),
CONSTRAINT vpn_connections_desired_state_check
CHECK (desired_state IN ('enabled', 'disabled')),
CONSTRAINT vpn_connections_status_check
CHECK (status IN ('disabled', 'enabled', 'connecting', 'active', 'degraded', 'failed')),
CONSTRAINT vpn_connections_protocol_family_check
CHECK (protocol_family IN ('generic', 'wireguard', 'ipsec', 'openvpn'))
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_vpn_connections_cluster_org_name
ON vpn_connections(cluster_id, organization_id, lower(name));
CREATE INDEX IF NOT EXISTS idx_vpn_connections_cluster_org_state
ON vpn_connections(cluster_id, organization_id, desired_state, status);
CREATE TABLE IF NOT EXISTS vpn_connection_allowed_nodes (
vpn_connection_id UUID NOT NULL REFERENCES vpn_connections(id) ON DELETE CASCADE,
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
role_preference TEXT NOT NULL DEFAULT 'candidate',
status TEXT NOT NULL DEFAULT 'active',
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
created_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (vpn_connection_id, node_id),
CONSTRAINT vpn_connection_allowed_nodes_membership_fk
FOREIGN KEY (cluster_id, node_id)
REFERENCES cluster_memberships(cluster_id, node_id)
ON DELETE CASCADE,
CONSTRAINT vpn_connection_allowed_nodes_preference_check
CHECK (role_preference IN ('candidate', 'standby', 'preferred')),
CONSTRAINT vpn_connection_allowed_nodes_status_check
CHECK (status IN ('active', 'disabled'))
);
CREATE INDEX IF NOT EXISTS idx_vpn_connection_allowed_nodes_cluster_node
ON vpn_connection_allowed_nodes(cluster_id, node_id, status);
CREATE TABLE IF NOT EXISTS vpn_connection_route_policies (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
vpn_connection_id UUID NOT NULL REFERENCES vpn_connections(id) ON DELETE CASCADE,
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
organization_id UUID NOT NULL REFERENCES organizations(id) ON DELETE CASCADE,
route_type TEXT NOT NULL,
destination TEXT NOT NULL,
action TEXT NOT NULL DEFAULT 'allow',
service_type TEXT,
priority INTEGER NOT NULL DEFAULT 100,
policy JSONB NOT NULL DEFAULT '{}'::JSONB,
status TEXT NOT NULL DEFAULT 'active',
created_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT vpn_connection_route_policies_route_type_check
CHECK (route_type IN ('cidr', 'dns_suffix', 'service', 'resource')),
CONSTRAINT vpn_connection_route_policies_action_check
CHECK (action IN ('allow', 'deny')),
CONSTRAINT vpn_connection_route_policies_status_check
CHECK (status IN ('active', 'disabled'))
);
CREATE INDEX IF NOT EXISTS idx_vpn_connection_route_policies_connection
ON vpn_connection_route_policies(vpn_connection_id, status, priority);
CREATE INDEX IF NOT EXISTS idx_vpn_connection_route_policies_cluster_org
ON vpn_connection_route_policies(cluster_id, organization_id, route_type, status);
CREATE TABLE IF NOT EXISTS vpn_connection_leases (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
vpn_connection_id UUID NOT NULL REFERENCES vpn_connections(id) ON DELETE CASCADE,
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
owner_node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
lease_generation BIGINT NOT NULL,
fencing_token TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'active',
acquired_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
renewed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
expires_at TIMESTAMPTZ NOT NULL,
released_at TIMESTAMPTZ,
fenced_at TIMESTAMPTZ,
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
CONSTRAINT vpn_connection_leases_membership_fk
FOREIGN KEY (cluster_id, owner_node_id)
REFERENCES cluster_memberships(cluster_id, node_id)
ON DELETE CASCADE,
CONSTRAINT vpn_connection_leases_status_check
CHECK (status IN ('active', 'released', 'expired', 'fenced')),
CONSTRAINT vpn_connection_leases_expiry_check
CHECK (expires_at > acquired_at)
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_vpn_connection_leases_generation
ON vpn_connection_leases(vpn_connection_id, lease_generation);
CREATE UNIQUE INDEX IF NOT EXISTS idx_vpn_connection_leases_single_active
ON vpn_connection_leases(vpn_connection_id)
WHERE status = 'active';
CREATE INDEX IF NOT EXISTS idx_vpn_connection_leases_owner
ON vpn_connection_leases(cluster_id, owner_node_id, status, expires_at);
CREATE INDEX IF NOT EXISTS idx_vpn_connection_leases_connection_time
ON vpn_connection_leases(vpn_connection_id, acquired_at DESC);
@@ -0,0 +1,2 @@
DROP TABLE IF EXISTS vpn_connection_assignment_latest_statuses;
DROP TABLE IF EXISTS vpn_connection_assignment_status_reports;
@@ -0,0 +1,42 @@
CREATE TABLE IF NOT EXISTS vpn_connection_assignment_status_reports (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
vpn_connection_id UUID NOT NULL REFERENCES vpn_connections(id) ON DELETE CASCADE,
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
observed_status TEXT NOT NULL DEFAULT 'unknown',
status_payload JSONB NOT NULL DEFAULT '{}'::JSONB,
observed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT vpn_connection_assignment_status_membership_fk
FOREIGN KEY (cluster_id, node_id)
REFERENCES cluster_memberships(cluster_id, node_id)
ON DELETE CASCADE,
CONSTRAINT vpn_connection_assignment_status_check
CHECK (observed_status IN ('not_started', 'assigned', 'lease_required', 'blocked', 'unknown'))
);
CREATE INDEX IF NOT EXISTS idx_vpn_assignment_status_reports_node
ON vpn_connection_assignment_status_reports(cluster_id, node_id, observed_at DESC);
CREATE INDEX IF NOT EXISTS idx_vpn_assignment_status_reports_connection
ON vpn_connection_assignment_status_reports(vpn_connection_id, observed_at DESC);
CREATE TABLE IF NOT EXISTS vpn_connection_assignment_latest_statuses (
vpn_connection_id UUID NOT NULL REFERENCES vpn_connections(id) ON DELETE CASCADE,
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
report_id UUID NOT NULL REFERENCES vpn_connection_assignment_status_reports(id) ON DELETE CASCADE,
observed_status TEXT NOT NULL,
status_payload JSONB NOT NULL DEFAULT '{}'::JSONB,
observed_at TIMESTAMPTZ NOT NULL,
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
PRIMARY KEY (vpn_connection_id, node_id),
CONSTRAINT vpn_connection_assignment_latest_membership_fk
FOREIGN KEY (cluster_id, node_id)
REFERENCES cluster_memberships(cluster_id, node_id)
ON DELETE CASCADE,
CONSTRAINT vpn_connection_assignment_latest_status_check
CHECK (observed_status IN ('not_started', 'assigned', 'lease_required', 'blocked', 'unknown'))
);
CREATE INDEX IF NOT EXISTS idx_vpn_assignment_latest_node
ON vpn_connection_assignment_latest_statuses(cluster_id, node_id, updated_at DESC);
@@ -0,0 +1,5 @@
DROP INDEX IF EXISTS idx_node_telemetry_cluster_node_observed;
DROP TABLE IF EXISTS node_telemetry_observations;
DROP INDEX IF EXISTS idx_fabric_testing_flags_unique_scope;
DROP TABLE IF EXISTS fabric_testing_flags;

Some files were not shown because too many files have changed in this diff Show More