Initial project snapshot
This commit is contained in:
@@ -0,0 +1,19 @@
|
||||
FROM golang:1.23-bookworm AS build
|
||||
|
||||
WORKDIR /src
|
||||
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY . .
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /out/rap-api ./cmd/api
|
||||
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY --from=build /out/rap-api /usr/local/bin/rap-api
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/rap-api"]
|
||||
@@ -0,0 +1,270 @@
|
||||
# Backend Foundation
|
||||
|
||||
Production-oriented Go backend skeleton for the remote access platform.
|
||||
|
||||
## Scope included
|
||||
|
||||
- configuration loading from environment
|
||||
- HTTP server bootstrap with graceful shutdown
|
||||
- PostgreSQL and Redis connectivity wiring
|
||||
- migrations scaffold
|
||||
- auth foundation with access/refresh tokens, hashed refresh rotation, trusted devices, and persisted auth sessions
|
||||
- persistent session storage foundation for remote sessions, attachments, resource policies, and audit events
|
||||
- session broker orchestration for start, attach, detach, takeover, terminate, failure, and detached-session recovery
|
||||
- Redis-backed live session state, controller binding, attach tokens, heartbeat keys, worker routing, and reconnect support
|
||||
- Redis-backed worker registration, lease lifecycle, heartbeat tracking, stale lease recovery, and routing queues
|
||||
- worker assignment queueing and worker event ingestion for the minimal real RDP worker runtime
|
||||
- websocket live plane with attach handshake, ping/pong heartbeat, state messages, takeover detection, and transport reconnect flow
|
||||
- module boundaries for auth, resources, session broker, and websocket gateway
|
||||
- worker registry scaffold to prepare later RDP worker integration
|
||||
- per-resource certificate verification policy for RDP connections with `strict` default and explicit `ignore` override
|
||||
- platform-core v2 foundations for organizations, memberships, identity sources, nodes, and node-agent control plane
|
||||
- Data Plane v1 contract scaffolding for optional session response candidates/tokens, with current backend gateway behavior preserved as fallback
|
||||
- production resource secret-readiness guard for rejecting plaintext credential-like metadata and requiring `secret_ref` for RDP/VNC/SSH resources in production mode
|
||||
- encrypted resource secret storage/resolver MVP for production `secret_ref` usage
|
||||
|
||||
## Entry point
|
||||
|
||||
Run the API from `cmd/api`.
|
||||
|
||||
## Local dev
|
||||
|
||||
- backend: `pwsh -File scripts/smoke/run-backend.ps1`
|
||||
- infra: `pwsh -File scripts/smoke/start-infra.ps1`
|
||||
- migrations: `pwsh -File scripts/smoke/apply-migrations.ps1`
|
||||
- worker image build: `docker build --tag rap-rdp-worker:dev --file workers/rdp-worker/Dockerfile workers/rdp-worker`
|
||||
- end-to-end smoke path: [scripts/smoke/README.md](/\\?\UNC\192.168.220.200\mst\codex\rdp-proxy\scripts\smoke\README.md)
|
||||
|
||||
## Configuration
|
||||
|
||||
Use `configs/api.example.env` as the starting point for local environment variables.
|
||||
|
||||
Resource secret-readiness is controlled by `APP_ENV`:
|
||||
|
||||
- in `APP_ENV=production` or `APP_ENV=prod`, RDP/VNC/SSH resources must carry
|
||||
`secret_ref` and must not include plaintext credential-like fields in
|
||||
`metadata`
|
||||
- in development and smoke environments, plaintext metadata remains allowed
|
||||
until the encrypted secret resolver is implemented
|
||||
- the production guard is enforced both on resource create/update and on
|
||||
session start, so legacy plaintext resources cannot be started in production
|
||||
accidentally
|
||||
- `SECRET_ENCRYPTION_KEY_B64` or `SECRET_ENCRYPTION_KEY_FILE` supplies the
|
||||
AES-256-GCM master key for the MVP encrypted store; production mode refuses
|
||||
to start without one
|
||||
- `SECRET_ENCRYPTION_KEY_ID` labels the active key version in stored records
|
||||
- `PUT /api/v1/resources/{resourceID}/secret` creates or rotates a resource
|
||||
secret and updates `resources.secret_ref`; plaintext is never returned by the
|
||||
API
|
||||
- session assignment keeps PostgreSQL metadata safe: `remote_sessions.metadata`
|
||||
stores `secret_ref`, while resolved credentials are merged only into the
|
||||
transient worker assignment after session/worker/lease checks
|
||||
|
||||
See `docs/architecture/SECURITY_SECRETS_READINESS.md` for the target
|
||||
secret-reference model and remaining resolver/PKI gaps.
|
||||
|
||||
Data Plane v1 contract scaffolding is controlled by:
|
||||
|
||||
- `DATA_PLANE_TOKEN_TTL`, default `1m`
|
||||
- `DATA_PLANE_TOKEN_PRIVATE_KEY_FILE`, optional path to an RSA private key PEM used to sign RS256 data-plane tokens
|
||||
- `DATA_PLANE_TOKEN_PRIVATE_KEY_PEM`, optional inline RSA private key PEM; used when file path is not configured
|
||||
- `DATA_PLANE_BACKEND_GATEWAY_URL`, default `/api/v1/gateway/ws`
|
||||
- `DATA_PLANE_DIRECT_WORKER_WSS_URL_TEMPLATE`, optional; supports `{worker_id}` replacement
|
||||
- `DATA_PLANE_DIRECT_WORKER_JSON_RUNTIME`, default `false`; advertises
|
||||
`runtime_transport=json_v1` only after the worker direct JSON bridge is
|
||||
deployed and verified
|
||||
- `DATA_PLANE_DIRECT_WORKER_BINARY_RENDER`, default `false`; when the direct
|
||||
JSON runtime is enabled, advertises `render_transport=binary_v1` so DP-2
|
||||
clients can request binary render frames over direct worker WSS. Binary
|
||||
render candidates also advertise `supported_color_modes=["full_color","grayscale"]`
|
||||
and `default_color_mode="full_color"` for the DP-3A grayscale foundation.
|
||||
- `DATA_PLANE_DIRECT_WORKER_TLS_TRUST_MODE`, default `smoke_insecure`; allowed
|
||||
values are `smoke_insecure`, `public_ca`, and `platform_ca`.
|
||||
- `DATA_PLANE_DIRECT_WORKER_TLS_CA_REF`, optional label for the platform CA or
|
||||
trust bundle version advertised to clients.
|
||||
|
||||
Data-plane tokens are RS256-signed. The backend must hold only the private key;
|
||||
workers receive only the matching public key for validation. If no private key
|
||||
is configured, the backend omits the optional `data_plane` offer and the
|
||||
backend gateway fallback remains unchanged.
|
||||
|
||||
If no direct worker WSS URL template is configured, session responses still include the backend gateway fallback candidate only.
|
||||
If the URL template is configured but `DATA_PLANE_DIRECT_WORKER_JSON_RUNTIME`
|
||||
is `false`, the direct candidate is still present for contract visibility but is
|
||||
not marked data-capable; DP-1D Windows clients will skip it and use the backend
|
||||
gateway fallback.
|
||||
If `DATA_PLANE_DIRECT_WORKER_BINARY_RENDER` is `false`, direct worker WSS
|
||||
remains JSON/base64 for render. If it is `true`, only direct worker WSS render
|
||||
is binary; backend gateway fallback remains JSON/base64.
|
||||
In production, the backend does not advertise direct worker WSS when
|
||||
`DATA_PLANE_DIRECT_WORKER_TLS_TRUST_MODE=smoke_insecure`; it keeps the backend
|
||||
gateway fallback instead. Trusted direct candidates include `tls_trust_mode`,
|
||||
`production_trusted`, `smoke_only`, and optional `tls_ca_ref` metadata. See
|
||||
`docs/architecture/DIRECT_WORKER_TLS_PKI.md`.
|
||||
|
||||
## Module layout
|
||||
|
||||
- `internal/platform` shared runtime, config, infra, and bootstrap concerns
|
||||
- `internal/modules/auth` auth and trusted-device boundary
|
||||
- `internal/modules/organization` organization model, org roles, and memberships
|
||||
- `internal/modules/identitysource` local/LDAP/OIDC identity source model and future mapping foundations
|
||||
- `internal/modules/resource` remote resource inventory boundary
|
||||
- `internal/modules/sessionbroker` persistent session lifecycle, orchestration, audit, and Redis live-state boundary
|
||||
- `internal/modules/sessiongateway` websocket attach/reconnect/takeover transport boundary
|
||||
- `internal/modules/worker` worker registration, lease coordination, and control-plane routing boundary for future C++ RDP workers
|
||||
- `internal/modules/node` node inventory, capabilities, enabled services, update policy, and partition state
|
||||
- `internal/modules/nodeagent` node-agent registration, health, service status, and update/rollback control interface
|
||||
- `pkg/contracts` cross-module contracts for sessions and worker control
|
||||
|
||||
## Backend responsibilities
|
||||
|
||||
- PostgreSQL remains the source of truth for auth sessions, devices, remote sessions, attachments, resource policies, and audit events
|
||||
- Redis is used only for live routing and coordination: attach tokens, controller bindings, live session cache, worker registration, worker leases, heartbeats, and routing queues
|
||||
- `worker:control:<worker_id>` carries worker assignments, `worker:queue:<session_id>` carries live control/input envelopes, and `worker:events` carries worker-reported lifecycle events back into broker processing
|
||||
- Session broker owns state transitions and orchestration rules; websocket handlers call broker services instead of talking to postgres repositories directly
|
||||
- Worker runtime stays behind interfaces and Redis coordination so the backend remains isolated from FreeRDP implementation details while the minimal real RDP worker plugs into the control plane
|
||||
- RDP certificate verification is configured per resource through `certificate_verification_mode`
|
||||
- resources are now org-scoped in PostgreSQL and remote sessions persist their owning organization without changing the proven worker/session runtime contracts
|
||||
- session start/attach/takeover responses may include optional `data_plane` candidates and a short-lived signed data-plane token for DP-1 direct worker WSS migration; existing clients continue to use the current gateway path, and direct realtime use remains gated by explicit candidate metadata
|
||||
|
||||
## Authorization model
|
||||
|
||||
- `platform_admin` and `platform_recovery_admin` have global access across organizations, resources, and sessions
|
||||
- in `INSTALLATION_AUTHORITY_MODE=strict`, platform-admin power is effective only
|
||||
when the user also has a valid signed row in `platform_role_grants`; changing
|
||||
`users.platform_role` in PostgreSQL alone no longer grants owner access
|
||||
- first-owner bootstrap is available at
|
||||
`POST /api/v1/installation/bootstrap-owner` and requires a Product Root
|
||||
Ed25519 signature over an activation manifest in strict mode
|
||||
- production (`APP_ENV=production` or `prod`) requires strict installation
|
||||
authority plus `INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_B64` or
|
||||
`INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_FILE`
|
||||
- legacy/dev installs can keep database-role behavior, and insecure first-owner
|
||||
bootstrap is available only when
|
||||
`INSTALLATION_INSECURE_BOOTSTRAP_ENABLED=true`
|
||||
- `org_owner` and `org_admin` can create and update resources inside their organization and can manage any remote session inside that organization
|
||||
- active non-admin memberships such as `org_operator`, `org_member`, and `org_viewer` are deny-by-default for admin actions; they can only access org-scoped reads and operate on their own session flows where the session broker explicitly allows it
|
||||
- session start always authorizes the actor against the resource organization before worker reservation
|
||||
- attach, detach, takeover, and terminate authorize against the owning remote session organization before any state transition is written
|
||||
- worker-facing events do not bypass this model for user-originated commands; internal worker failure and heartbeat paths remain broker-internal control-plane operations
|
||||
|
||||
## Migration safety
|
||||
|
||||
- `000005_platform_core_v2` bootstraps a single `default` organization and backfills existing `resources.organization_id` and `remote_sessions.organization_id` into that organization before setting `NOT NULL`
|
||||
- `000006_default_org_memberships_backfill` safely restores access continuity by inserting missing active memberships for existing users into the `default` organization
|
||||
- the backfill is idempotent because it only inserts rows missing under the `(organization_id, user_id)` uniqueness constraint
|
||||
- platform administrators are backfilled as `org_owner` in the default organization, while other existing users are backfilled as `org_member`
|
||||
- if `000005` fails before the `NOT NULL` step, PostgreSQL rolls back the transaction and leaves pre-v2 rows untouched; if `000006` is rerun, it skips already-created memberships rather than duplicating them
|
||||
|
||||
## Platform-Core V2 Notes
|
||||
|
||||
- `organizations`, `organization_memberships`, and `organization_roles` establish multi-tenant ownership and basic org-scoped authorization boundaries
|
||||
- `identity_sources` and `identity_mappings` are foundation-only in this phase; full LDAP/OIDC sync and claim/group ingestion are intentionally deferred
|
||||
- `nodes`, `node_capabilities`, `node_services`, `node_update_policies`, `node_partition_states`, and `node_agent_update_runs` provide the first control-plane model for node and node-agent lifecycle
|
||||
- current proven RDP session lifecycle remains preserved: the session broker still orchestrates the same worker/session behavior, but it now records organization ownership via org-scoped resources
|
||||
- PostgreSQL remains the source of truth for organizations, memberships, org-scoped resources, identity sources, nodes, node-agent state, and session lifecycle state
|
||||
|
||||
## Resource Certificate Verification
|
||||
|
||||
- `strict` is the default and keeps normal certificate validation enabled in the worker runtime
|
||||
- `ignore` must be explicitly stored on the resource and allows that one RDP connection to skip certificate validation
|
||||
- the backend passes this policy through session assignment data; it is not a global backend toggle
|
||||
|
||||
## Messaging Model
|
||||
|
||||
- HTTP errors now use a structured envelope:
|
||||
- `error.code`
|
||||
- `error.message_key`
|
||||
- `error.fallback_message`
|
||||
- `error.details`
|
||||
- `error.trace_id`
|
||||
- `internal/platform/httpx` owns error normalization and trace-id generation so handlers can keep calling `WriteError(...)` without changing business logic.
|
||||
- For `5xx` responses, user-facing payloads are normalized to an English generic fallback message while logs and diagnostics can still keep raw internal details elsewhere.
|
||||
- For `4xx` responses, stable `code` and `message_key` are derived from the current fallback message, so clients can localize without depending on raw English text as the primary contract.
|
||||
|
||||
## WebSocket Messaging
|
||||
|
||||
- Session gateway envelopes keep the existing `type` and `payload` contract.
|
||||
- User-facing websocket events now also include `event` with:
|
||||
- `code`
|
||||
- `message_key`
|
||||
- `fallback_message`
|
||||
- `details`
|
||||
- `trace_id`
|
||||
- `session.taken_over`, terminal `session.state`, `transport.closed`, and protocol-level errors now carry this structured event object.
|
||||
- Existing payload semantics remain intact for compatibility with the already proven session lifecycle.
|
||||
|
||||
## Message Rules
|
||||
|
||||
- Keep English as the only development language for `fallback_message`, logs, and diagnostics.
|
||||
- New HTTP handlers should prefer `httpx.WriteError(...)` for user-facing failures instead of hand-building `"error": "..."` JSON.
|
||||
- New websocket user-facing notifications should populate `TransportEnvelope.Event` with a stable `code` and `message_key`.
|
||||
- Do not use raw human-readable English text as the primary client contract; it should only remain as fallback text.
|
||||
- This messaging layer is now runtime-proven against the live Windows smoke flow for invalid-login errors, websocket takeover delivery, websocket state fallback rendering, and worker-death failure handling.
|
||||
## Clipboard Policy
|
||||
|
||||
RDP text clipboard is controlled per resource through `resource_policies.clipboard_mode`.
|
||||
Allowed values are `disabled`, `client_to_server`, `server_to_client`, and
|
||||
`bidirectional`; the default is `disabled`. The legacy `clipboard_enabled`
|
||||
column is retained only for compatibility and migration/backfill, while new
|
||||
runtime decisions use `clipboard_mode`.
|
||||
|
||||
Clipboard enforcement happens in the real data path:
|
||||
|
||||
- `sessionbroker.ResourcePolicy.ClipboardMode` is loaded from PostgreSQL and
|
||||
embedded into the session assignment metadata sent to the worker.
|
||||
- `sessiongateway.Module.handleEnvelope` blocks client-to-server clipboard
|
||||
envelopes unless the session is `active` and the policy allows that direction.
|
||||
- `worker.EventProcessor` sends worker-originated clipboard text through
|
||||
`sessionbroker.Service.UpdateWorkerClipboardText`, which applies the same
|
||||
active-state and server-to-client policy checks before updating live state.
|
||||
- Clipboard messages carry `sequence_id`, `origin`, and `content_hash` so
|
||||
clients and workers can avoid feedback loops across reattach/takeover paths.
|
||||
- Redis stores clipboard text only as transient live state for routing to the
|
||||
active controller; PostgreSQL remains authoritative for policy/session state.
|
||||
|
||||
## File Upload Policy
|
||||
|
||||
Stage 5.1 introduces client-to-server file upload as a policy-gated RDP
|
||||
feature. The authoritative policy field is
|
||||
`resource_policies.file_transfer_mode`; allowed values are `disabled`,
|
||||
`client_to_server`, `server_to_client`, and `bidirectional`, but only
|
||||
`client_to_server` behavior is implemented in this stage. The default is
|
||||
`disabled`. The legacy `file_transfer_enabled` column is retained only as a
|
||||
derived compatibility flag and must not be treated as the primary policy.
|
||||
|
||||
Enforcement is deliberately duplicated in the real data path:
|
||||
|
||||
- `resource.Module` exposes `file_transfer_mode` in resource create, update,
|
||||
list, and read payloads.
|
||||
- `sessionbroker.Service.StartRemoteSession` embeds `file_transfer_mode` into
|
||||
assignment metadata and requests the worker `file-transfer` capability only
|
||||
when client-to-server upload is allowed.
|
||||
- `sessiongateway.Module.handleFileUploadStart` and
|
||||
`handleFileUploadChunk` require an active session, current controller,
|
||||
allowed policy mode, valid UUID `transfer_id`, safe file name, 25 MiB max
|
||||
file size, and 256 KiB max chunk size before routing chunks to the worker.
|
||||
- Redis is used only to route bounded upload envelopes to the worker. The file
|
||||
itself is written by the worker to controlled worker storage; PostgreSQL
|
||||
remains authoritative for policy and session state.
|
||||
|
||||
## File Download Policy
|
||||
|
||||
Stage 5.2 adds a runtime-proven server-to-client download path for RDP. The
|
||||
policy field remains `resource_policies.file_transfer_mode`; `server_to_client`
|
||||
and `bidirectional` allow download, while `disabled` and `client_to_server`
|
||||
block it. The default remains `disabled`.
|
||||
|
||||
The v1 download model uses only the restricted `RAP_Transfers\ToClient`
|
||||
drop-zone inside the existing per-session visible transfer directory. Backend
|
||||
gateway accepts only `file_download.start`, `file_download.ack`, and
|
||||
`file_download.cancel` from the current controller of an active session and
|
||||
routes them to the worker after policy validation. Worker-origin
|
||||
`file_download.*` events are stored only as transient live state for
|
||||
backend-gateway fallback delivery; PostgreSQL remains authoritative for
|
||||
session/resource/policy state and must not store file contents.
|
||||
|
||||
The direct worker WSS path is also lifecycle-gated: detach returns
|
||||
`file_download.blocked`, old-controller takeover returns `session.taken_over`,
|
||||
and worker failure closes the direct transport after PostgreSQL transitions the
|
||||
session to `failed`.
|
||||
@@ -0,0 +1,24 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/runtime"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
app, err := runtime.NewApp(ctx)
|
||||
if err != nil {
|
||||
log.Fatalf("bootstrap app: %v", err)
|
||||
}
|
||||
|
||||
if err := app.Run(ctx); err != nil {
|
||||
log.Fatalf("run app: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var (
|
||||
gatewayURL = flag.String("gateway-url", "ws://127.0.0.1:8080/api/v1/gateway/ws", "websocket gateway url")
|
||||
attachToken = flag.String("attach-token", "", "short-lived attach token")
|
||||
duration = flag.Duration("duration", 90*time.Second, "maximum time to keep the client attached")
|
||||
heartbeatInterval = flag.Duration("heartbeat-interval", 10*time.Second, "client heartbeat interval")
|
||||
exitOnTakenOver = flag.Bool("exit-on-taken-over", true, "exit successfully after receiving session.taken_over")
|
||||
)
|
||||
flag.Parse()
|
||||
|
||||
if *attachToken == "" {
|
||||
log.Fatal("attach-token is required")
|
||||
}
|
||||
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||||
defer cancel()
|
||||
if *duration > 0 {
|
||||
var timeoutCancel context.CancelFunc
|
||||
ctx, timeoutCancel = context.WithTimeout(ctx, *duration)
|
||||
defer timeoutCancel()
|
||||
}
|
||||
|
||||
targetURL, err := url.Parse(*gatewayURL)
|
||||
if err != nil {
|
||||
log.Fatalf("parse gateway url: %v", err)
|
||||
}
|
||||
query := targetURL.Query()
|
||||
query.Set("attach_token", *attachToken)
|
||||
targetURL.RawQuery = query.Encode()
|
||||
|
||||
conn, _, err := websocket.DefaultDialer.DialContext(ctx, targetURL.String(), nil)
|
||||
if err != nil {
|
||||
log.Fatalf("dial websocket: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
for {
|
||||
_, payload, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
fmt.Println(string(payload))
|
||||
if *exitOnTakenOver && containsTakenOver(payload) {
|
||||
done <- nil
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
heartbeatTicker := time.NewTicker(*heartbeatInterval)
|
||||
defer heartbeatTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
log.Fatalf("websocket read: %v", err)
|
||||
}
|
||||
return
|
||||
case <-heartbeatTicker.C:
|
||||
if err := conn.WriteJSON(map[string]any{"type": "heartbeat"}); err != nil {
|
||||
log.Fatalf("heartbeat write: %v", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
if err := ctx.Err(); err != nil && err != context.Canceled && err != context.DeadlineExceeded {
|
||||
log.Fatalf("context ended: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func containsTakenOver(payload []byte) bool {
|
||||
return strings.Contains(string(payload), `"type":"session.taken_over"`)
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
APP_NAME=rap-api
|
||||
APP_ENV=development
|
||||
HTTP_HOST=0.0.0.0
|
||||
HTTP_PORT=8080
|
||||
HTTP_READ_TIMEOUT=15s
|
||||
HTTP_WRITE_TIMEOUT=15s
|
||||
HTTP_IDLE_TIMEOUT=60s
|
||||
HTTP_SHUTDOWN_TIMEOUT=10s
|
||||
POSTGRES_DSN=postgres://rap_user:rap_password@localhost:5432/remote_access_platform?sslmode=disable
|
||||
POSTGRES_MAX_CONNS=20
|
||||
POSTGRES_MIN_CONNS=2
|
||||
POSTGRES_CONNECT_TIMEOUT=5s
|
||||
REDIS_ADDR=localhost:6379
|
||||
REDIS_PASSWORD=
|
||||
REDIS_DB=0
|
||||
REDIS_DIAL_TIMEOUT=5s
|
||||
AUTH_ACCESS_TOKEN_TTL=15m
|
||||
AUTH_REFRESH_TOKEN_TTL=720h
|
||||
AUTH_ISSUER=rap-api
|
||||
AUTH_ACCESS_TOKEN_SECRET=change-me-access-secret
|
||||
AUTH_REFRESH_HASH_SECRET=change-me-refresh-hash-secret
|
||||
DATA_PLANE_TOKEN_TTL=1m
|
||||
DATA_PLANE_TOKEN_PRIVATE_KEY_FILE=
|
||||
DATA_PLANE_TOKEN_PRIVATE_KEY_PEM=
|
||||
DATA_PLANE_BACKEND_GATEWAY_URL=/api/v1/gateway/ws
|
||||
DATA_PLANE_DIRECT_WORKER_WSS_URL_TEMPLATE=
|
||||
DATA_PLANE_DIRECT_WORKER_JSON_RUNTIME=false
|
||||
DATA_PLANE_DIRECT_WORKER_BINARY_RENDER=false
|
||||
DATA_PLANE_DIRECT_WORKER_TLS_TRUST_MODE=smoke_insecure
|
||||
DATA_PLANE_DIRECT_WORKER_TLS_CA_REF=
|
||||
SECRET_ENCRYPTION_KEY_B64=
|
||||
SECRET_ENCRYPTION_KEY_FILE=
|
||||
SECRET_ENCRYPTION_KEY_ID=local-v1
|
||||
SESSION_HEARTBEAT_TTL=90s
|
||||
SESSION_DETACH_GRACE_PERIOD=30m
|
||||
SESSION_ATTACH_TOKEN_TTL=2m
|
||||
SESSION_LIVE_STATE_TTL=2m
|
||||
SESSION_RECOVERY_BATCH_SIZE=100
|
||||
WORKER_LEASE_TTL=45s
|
||||
WORKER_HEARTBEAT_TTL=15s
|
||||
WORKER_STALE_LEASE_GRACE_PERIOD=30s
|
||||
WEBSOCKET_WRITE_TIMEOUT=10s
|
||||
WEBSOCKET_PING_INTERVAL=20s
|
||||
WEBSOCKET_PONG_WAIT=40s
|
||||
@@ -0,0 +1,23 @@
|
||||
module github.com/example/remote-access-platform/backend
|
||||
|
||||
go 1.23.2
|
||||
|
||||
require (
|
||||
github.com/go-chi/chi/v5 v5.2.1
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/jackc/pgx/v5 v5.7.4
|
||||
github.com/redis/go-redis/v9 v9.8.0
|
||||
golang.org/x/crypto v0.37.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
golang.org/x/sync v0.13.0 // indirect
|
||||
golang.org/x/text v0.24.0 // indirect
|
||||
)
|
||||
@@ -0,0 +1,46 @@
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8=
|
||||
github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg=
|
||||
github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI=
|
||||
github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
|
||||
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
|
||||
golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
|
||||
golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
|
||||
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
@@ -0,0 +1,19 @@
|
||||
package auth
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrInvalidCredentials = errors.New("invalid credentials")
|
||||
ErrInvalidRefreshToken = errors.New("invalid refresh token")
|
||||
ErrAuthSessionRevoked = errors.New("auth session revoked")
|
||||
ErrDeviceRevoked = errors.New("device revoked")
|
||||
ErrDeviceNotTrusted = errors.New("device not trusted")
|
||||
ErrAuthSessionNotFound = errors.New("auth session not found")
|
||||
ErrTrustedDeviceMissing = errors.New("trusted device not found")
|
||||
|
||||
ErrInstallationAlreadyBootstrapped = errors.New("installation is already bootstrapped")
|
||||
ErrInstallationActivationRequired = errors.New("signed installation activation is required")
|
||||
ErrInvalidInstallationActivation = errors.New("invalid installation activation")
|
||||
ErrInsecureBootstrapDisabled = errors.New("insecure installation bootstrap is disabled")
|
||||
ErrInvalidBootstrapOwner = errors.New("invalid bootstrap owner")
|
||||
)
|
||||
@@ -0,0 +1,114 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DeviceTrustStatus string
|
||||
|
||||
const (
|
||||
DeviceTrustStatusPending DeviceTrustStatus = "pending"
|
||||
DeviceTrustStatusTrusted DeviceTrustStatus = "trusted"
|
||||
DeviceTrustStatusRevoked DeviceTrustStatus = "revoked"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID string
|
||||
Email string
|
||||
PasswordHash string
|
||||
MFAEnabled bool
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type Device struct {
|
||||
ID string
|
||||
UserID string
|
||||
Fingerprint string
|
||||
Label string
|
||||
TrustStatus DeviceTrustStatus
|
||||
TrustedAt *time.Time
|
||||
LastSeenAt *time.Time
|
||||
RevokedAt *time.Time
|
||||
RevokedReason *string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type AuthSession struct {
|
||||
ID string
|
||||
UserID string
|
||||
DeviceID string
|
||||
RefreshTokenHash string
|
||||
RefreshExpiresAt time.Time
|
||||
LastSeenAt *time.Time
|
||||
LastRotatedAt *time.Time
|
||||
RevokedAt *time.Time
|
||||
RevokedReason *string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type LoginCommand struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
DeviceFingerprint string `json:"device_fingerprint"`
|
||||
DeviceLabel string `json:"device_label"`
|
||||
TrustDevice bool `json:"trust_device"`
|
||||
}
|
||||
|
||||
type RefreshCommand struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
type BootstrapOwnerCommand struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
ActivationPayload json.RawMessage `json:"activation_payload"`
|
||||
ActivationSignature string `json:"activation_signature"`
|
||||
}
|
||||
|
||||
type RevokeAuthSessionCommand struct {
|
||||
UserID string `json:"user_id"`
|
||||
AuthSessionID string `json:"auth_session_id"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
type RevokeDeviceCommand struct {
|
||||
UserID string `json:"user_id"`
|
||||
DeviceID string `json:"device_id"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
type TokenPair struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
AccessTokenExpiresAt time.Time `json:"access_token_expires_at"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
RefreshTokenExpiresAt time.Time `json:"refresh_token_expires_at"`
|
||||
}
|
||||
|
||||
type AuthResult struct {
|
||||
User User `json:"user"`
|
||||
Device Device `json:"device"`
|
||||
AuthSession AuthSession `json:"auth_session"`
|
||||
Tokens TokenPair `json:"tokens"`
|
||||
}
|
||||
|
||||
type InstallationStatus struct {
|
||||
Bootstrapped bool `json:"bootstrapped"`
|
||||
AuthorityState string `json:"authority_state"`
|
||||
InstallID string `json:"install_id,omitempty"`
|
||||
BootstrappedOwnerEmail string `json:"bootstrapped_owner_email,omitempty"`
|
||||
BootstrappedAt *time.Time `json:"bootstrapped_at,omitempty"`
|
||||
AuthorityMode string `json:"authority_mode"`
|
||||
StrictAuthority bool `json:"strict_authority"`
|
||||
RootFingerprint string `json:"root_fingerprint,omitempty"`
|
||||
InsecureBootstrapAllowed bool `json:"insecure_bootstrap_allowed"`
|
||||
}
|
||||
|
||||
type BootstrapOwnerResult struct {
|
||||
Installation InstallationStatus `json:"installation"`
|
||||
User User `json:"user"`
|
||||
PlatformRole string `json:"platform_role"`
|
||||
}
|
||||
@@ -0,0 +1,173 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/module"
|
||||
)
|
||||
|
||||
type Module struct {
|
||||
service *Service
|
||||
}
|
||||
|
||||
func NewModule(deps module.Dependencies, service *Service) *Module {
|
||||
return &Module{service: service}
|
||||
}
|
||||
|
||||
func (m *Module) Name() string {
|
||||
return "auth"
|
||||
}
|
||||
|
||||
func (m *Module) RegisterRoutes(router chi.Router) {
|
||||
router.Route("/installation", func(r chi.Router) {
|
||||
r.Get("/status", m.handleInstallationStatus)
|
||||
r.Post("/bootstrap-owner", m.handleBootstrapOwner)
|
||||
})
|
||||
router.Route("/auth", func(r chi.Router) {
|
||||
r.Post("/login", m.handleLogin)
|
||||
r.Post("/refresh", m.handleRefresh)
|
||||
r.Post("/sessions/revoke", m.handleRevokeAuthSession)
|
||||
r.Get("/devices", m.handleTrustedDevices)
|
||||
r.Post("/devices/{deviceID}/revoke", m.handleRevokeTrustedDevice)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) handleInstallationStatus(w http.ResponseWriter, r *http.Request) {
|
||||
status, err := m.service.InstallationStatus(r.Context())
|
||||
if err != nil {
|
||||
statusCode, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, statusCode, message)
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{"installation": status})
|
||||
}
|
||||
|
||||
func (m *Module) handleBootstrapOwner(w http.ResponseWriter, r *http.Request) {
|
||||
var cmd BootstrapOwnerCommand
|
||||
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid installation bootstrap payload")
|
||||
return
|
||||
}
|
||||
result, err := m.service.BootstrapOwner(r.Context(), cmd)
|
||||
if err != nil {
|
||||
status, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, status, message)
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusCreated, result)
|
||||
}
|
||||
|
||||
func (m *Module) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
var cmd LoginCommand
|
||||
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid login payload")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := m.service.Login(r.Context(), cmd)
|
||||
if err != nil {
|
||||
status, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, status, message)
|
||||
return
|
||||
}
|
||||
|
||||
httpx.WriteJSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (m *Module) handleRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
var cmd RefreshCommand
|
||||
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid refresh payload")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := m.service.Refresh(r.Context(), cmd)
|
||||
if err != nil {
|
||||
status, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, status, message)
|
||||
return
|
||||
}
|
||||
|
||||
httpx.WriteJSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (m *Module) handleRevokeAuthSession(w http.ResponseWriter, r *http.Request) {
|
||||
var cmd RevokeAuthSessionCommand
|
||||
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid auth session revoke payload")
|
||||
return
|
||||
}
|
||||
|
||||
if err := m.service.RevokeAuthSession(r.Context(), cmd); err != nil {
|
||||
status, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, status, message)
|
||||
return
|
||||
}
|
||||
|
||||
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
|
||||
"status": "revoked",
|
||||
"message": httpx.NewMessage(
|
||||
"auth.session.revoked",
|
||||
"status.auth.session.revoked",
|
||||
"Auth session revoked.",
|
||||
nil,
|
||||
"",
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) handleTrustedDevices(w http.ResponseWriter, r *http.Request) {
|
||||
userID := r.URL.Query().Get("user_id")
|
||||
if userID == "" {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
devices, err := m.service.ListTrustedDevices(r.Context(), userID)
|
||||
if err != nil {
|
||||
status, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, status, message)
|
||||
return
|
||||
}
|
||||
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{
|
||||
"devices": devices,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) handleRevokeTrustedDevice(w http.ResponseWriter, r *http.Request) {
|
||||
var payload struct {
|
||||
UserID string `json:"user_id"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid device revoke payload")
|
||||
return
|
||||
}
|
||||
|
||||
err := m.service.RevokeTrustedDevice(r.Context(), RevokeDeviceCommand{
|
||||
UserID: payload.UserID,
|
||||
DeviceID: chi.URLParam(r, "deviceID"),
|
||||
Reason: payload.Reason,
|
||||
})
|
||||
if err != nil {
|
||||
status, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, status, message)
|
||||
return
|
||||
}
|
||||
|
||||
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
|
||||
"status": "revoked",
|
||||
"message": httpx.NewMessage(
|
||||
"auth.device.revoked",
|
||||
"status.auth.device.revoked",
|
||||
"Trusted device revoked.",
|
||||
nil,
|
||||
"",
|
||||
),
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,525 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
|
||||
)
|
||||
|
||||
type postgresStore struct {
|
||||
db postgresplatform.DBTX
|
||||
}
|
||||
|
||||
type PostgresTransactor struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewPostgresStore(pool *pgxpool.Pool) Store {
|
||||
return &postgresStore{db: pool}
|
||||
}
|
||||
|
||||
func NewPostgresTransactor(pool *pgxpool.Pool) *PostgresTransactor {
|
||||
return &PostgresTransactor{pool: pool}
|
||||
}
|
||||
|
||||
func (t *PostgresTransactor) WithinTransaction(ctx context.Context, fn func(store Store) error) error {
|
||||
return postgresplatform.WithTransaction(ctx, t.pool, func(tx pgx.Tx) error {
|
||||
return fn(&postgresStore{db: tx})
|
||||
})
|
||||
}
|
||||
|
||||
func (s *postgresStore) Users() UserRepository {
|
||||
return &postgresUserRepository{db: s.db}
|
||||
}
|
||||
|
||||
func (s *postgresStore) Devices() DeviceRepository {
|
||||
return &postgresDeviceRepository{db: s.db}
|
||||
}
|
||||
|
||||
func (s *postgresStore) AuthSessions() AuthSessionRepository {
|
||||
return &postgresAuthSessionRepository{db: s.db}
|
||||
}
|
||||
|
||||
func (s *postgresStore) Installation() InstallationRepository {
|
||||
return &postgresInstallationRepository{db: s.db}
|
||||
}
|
||||
|
||||
type postgresUserRepository struct {
|
||||
db postgresplatform.DBTX
|
||||
}
|
||||
|
||||
type postgresDeviceRepository struct {
|
||||
db postgresplatform.DBTX
|
||||
}
|
||||
|
||||
type postgresAuthSessionRepository struct {
|
||||
db postgresplatform.DBTX
|
||||
}
|
||||
|
||||
type postgresInstallationRepository struct {
|
||||
db postgresplatform.DBTX
|
||||
}
|
||||
|
||||
func (r *postgresUserRepository) GetByEmail(ctx context.Context, email string) (*User, error) {
|
||||
const query = `
|
||||
SELECT id::text, email, password_hash, mfa_enabled, created_at, updated_at
|
||||
FROM users
|
||||
WHERE email = $1
|
||||
`
|
||||
return scanOptionalUser(r.db.QueryRow(ctx, query, email))
|
||||
}
|
||||
|
||||
func (r *postgresUserRepository) GetByID(ctx context.Context, userID string) (*User, error) {
|
||||
const query = `
|
||||
SELECT id::text, email, password_hash, mfa_enabled, created_at, updated_at
|
||||
FROM users
|
||||
WHERE id = $1::uuid
|
||||
`
|
||||
return scanOptionalUser(r.db.QueryRow(ctx, query, userID))
|
||||
}
|
||||
|
||||
func (r *postgresDeviceRepository) Upsert(ctx context.Context, params UpsertDeviceParams) (*Device, error) {
|
||||
const query = `
|
||||
INSERT INTO devices (
|
||||
user_id,
|
||||
device_fingerprint,
|
||||
device_label,
|
||||
trust_status,
|
||||
trusted_at,
|
||||
last_seen_at,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES (
|
||||
$1::uuid,
|
||||
$2,
|
||||
$3,
|
||||
CASE WHEN $4 THEN 'trusted' ELSE 'pending' END,
|
||||
CASE WHEN $4 THEN $5::timestamptz ELSE NULL::timestamptz END,
|
||||
$5::timestamptz,
|
||||
$5::timestamptz,
|
||||
$5::timestamptz
|
||||
)
|
||||
ON CONFLICT (user_id, device_fingerprint) DO UPDATE SET
|
||||
device_label = EXCLUDED.device_label,
|
||||
last_seen_at = EXCLUDED.last_seen_at,
|
||||
updated_at = EXCLUDED.updated_at,
|
||||
trust_status = CASE
|
||||
WHEN devices.trust_status = 'revoked' THEN devices.trust_status
|
||||
WHEN devices.trust_status = 'trusted' THEN devices.trust_status
|
||||
WHEN EXCLUDED.trust_status = 'trusted' THEN 'trusted'
|
||||
ELSE devices.trust_status
|
||||
END,
|
||||
trusted_at = CASE
|
||||
WHEN devices.trust_status = 'trusted' THEN devices.trusted_at
|
||||
WHEN EXCLUDED.trust_status = 'trusted' THEN EXCLUDED.trusted_at
|
||||
ELSE devices.trusted_at
|
||||
END
|
||||
RETURNING
|
||||
id::text, user_id::text, device_fingerprint, COALESCE(device_label, ''),
|
||||
trust_status, trusted_at, last_seen_at, revoked_at, revoked_reason, created_at, updated_at
|
||||
`
|
||||
return scanDevice(r.db.QueryRow(ctx, query,
|
||||
params.UserID,
|
||||
params.Fingerprint,
|
||||
params.Label,
|
||||
params.TrustRequested,
|
||||
params.SeenAt,
|
||||
))
|
||||
}
|
||||
|
||||
func (r *postgresDeviceRepository) GetByIDForUser(ctx context.Context, userID, deviceID string) (*Device, error) {
|
||||
const query = `
|
||||
SELECT id::text, user_id::text, device_fingerprint, COALESCE(device_label, ''),
|
||||
trust_status, trusted_at, last_seen_at, revoked_at, revoked_reason, created_at, updated_at
|
||||
FROM devices
|
||||
WHERE id = $1::uuid AND user_id = $2::uuid
|
||||
`
|
||||
device, err := scanDevice(r.db.QueryRow(ctx, query, deviceID, userID))
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return device, err
|
||||
}
|
||||
|
||||
func (r *postgresDeviceRepository) ListTrustedByUser(ctx context.Context, userID string) ([]Device, error) {
|
||||
const query = `
|
||||
SELECT id::text, user_id::text, device_fingerprint, COALESCE(device_label, ''),
|
||||
trust_status, trusted_at, last_seen_at, revoked_at, revoked_reason, created_at, updated_at
|
||||
FROM devices
|
||||
WHERE user_id = $1::uuid AND trust_status = 'trusted' AND revoked_at IS NULL
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
rows, err := r.db.Query(ctx, query, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query trusted devices: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var devices []Device
|
||||
for rows.Next() {
|
||||
device, err := scanDevice(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
devices = append(devices, *device)
|
||||
}
|
||||
return devices, rows.Err()
|
||||
}
|
||||
|
||||
func (r *postgresDeviceRepository) Revoke(ctx context.Context, params RevokeDeviceParams) error {
|
||||
const query = `
|
||||
UPDATE devices
|
||||
SET trust_status = 'revoked',
|
||||
revoked_at = $3,
|
||||
revoked_reason = $4,
|
||||
updated_at = $3
|
||||
WHERE id = $1::uuid AND user_id = $2::uuid
|
||||
`
|
||||
if _, err := r.db.Exec(ctx, query, params.DeviceID, params.UserID, params.RevokedAt, params.Reason); err != nil {
|
||||
return fmt.Errorf("revoke device: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *postgresAuthSessionRepository) Create(ctx context.Context, session AuthSession) error {
|
||||
const query = `
|
||||
INSERT INTO auth_sessions (
|
||||
id,
|
||||
user_id,
|
||||
device_id,
|
||||
refresh_token_hash,
|
||||
refresh_expires_at,
|
||||
last_seen_at,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES ($1::uuid, $2::uuid, $3::uuid, $4, $5, $6, $7, $8)
|
||||
`
|
||||
if _, err := r.db.Exec(ctx, query,
|
||||
session.ID,
|
||||
session.UserID,
|
||||
session.DeviceID,
|
||||
session.RefreshTokenHash,
|
||||
session.RefreshExpiresAt,
|
||||
session.LastSeenAt,
|
||||
session.CreatedAt,
|
||||
session.UpdatedAt,
|
||||
); err != nil {
|
||||
return fmt.Errorf("create auth session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *postgresAuthSessionRepository) GetByID(ctx context.Context, authSessionID string) (*AuthSession, error) {
|
||||
return r.getByID(ctx, authSessionID, "")
|
||||
}
|
||||
|
||||
func (r *postgresAuthSessionRepository) GetByIDForUpdate(ctx context.Context, authSessionID string) (*AuthSession, error) {
|
||||
return r.getByID(ctx, authSessionID, " FOR UPDATE")
|
||||
}
|
||||
|
||||
func (r *postgresAuthSessionRepository) getByID(ctx context.Context, authSessionID string, suffix string) (*AuthSession, error) {
|
||||
query := `
|
||||
SELECT id::text, user_id::text, device_id::text, refresh_token_hash, refresh_expires_at,
|
||||
last_seen_at, last_rotated_at, revoked_at, revoked_reason, created_at, updated_at
|
||||
FROM auth_sessions
|
||||
WHERE id = $1::uuid` + suffix
|
||||
session, err := scanAuthSession(r.db.QueryRow(ctx, query, authSessionID))
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return session, err
|
||||
}
|
||||
|
||||
func (r *postgresAuthSessionRepository) Rotate(ctx context.Context, params RotateAuthSessionParams) error {
|
||||
const query = `
|
||||
UPDATE auth_sessions
|
||||
SET refresh_token_hash = $2,
|
||||
refresh_expires_at = $3,
|
||||
last_seen_at = $4,
|
||||
last_rotated_at = $5,
|
||||
updated_at = $5
|
||||
WHERE id = $1::uuid AND revoked_at IS NULL
|
||||
`
|
||||
if _, err := r.db.Exec(ctx, query,
|
||||
params.AuthSessionID,
|
||||
params.RefreshTokenHash,
|
||||
params.RefreshExpiresAt,
|
||||
params.LastSeenAt,
|
||||
params.LastRotatedAt,
|
||||
); err != nil {
|
||||
return fmt.Errorf("rotate auth session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *postgresAuthSessionRepository) Touch(ctx context.Context, authSessionID string, seenAt time.Time) error {
|
||||
const query = `
|
||||
UPDATE auth_sessions
|
||||
SET last_seen_at = $2, updated_at = $2
|
||||
WHERE id = $1::uuid AND revoked_at IS NULL
|
||||
`
|
||||
if _, err := r.db.Exec(ctx, query, authSessionID, seenAt); err != nil {
|
||||
return fmt.Errorf("touch auth session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *postgresAuthSessionRepository) Revoke(ctx context.Context, params RevokeAuthSessionParams) error {
|
||||
const query = `
|
||||
UPDATE auth_sessions
|
||||
SET revoked_at = $3,
|
||||
revoked_reason = $4,
|
||||
updated_at = $3
|
||||
WHERE id = $1::uuid AND user_id = $2::uuid AND revoked_at IS NULL
|
||||
`
|
||||
if _, err := r.db.Exec(ctx, query, params.AuthSessionID, params.UserID, params.RevokedAt, params.Reason); err != nil {
|
||||
return fmt.Errorf("revoke auth session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *postgresAuthSessionRepository) RevokeByDevice(ctx context.Context, userID, deviceID, reason string, revokedAt time.Time) error {
|
||||
const query = `
|
||||
UPDATE auth_sessions
|
||||
SET revoked_at = $3,
|
||||
revoked_reason = $4,
|
||||
updated_at = $3
|
||||
WHERE user_id = $1::uuid AND device_id = $2::uuid AND revoked_at IS NULL
|
||||
`
|
||||
if _, err := r.db.Exec(ctx, query, userID, deviceID, revokedAt, reason); err != nil {
|
||||
return fmt.Errorf("revoke auth sessions by device: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *postgresInstallationRepository) GetStatus(ctx context.Context) (*InstallationAuthorityState, error) {
|
||||
const query = `
|
||||
SELECT install_id, authority_state, product_root_key_fingerprint, bootstrapped_owner_email, bootstrapped_at
|
||||
FROM installation_authority
|
||||
WHERE id = 1
|
||||
`
|
||||
status := &InstallationAuthorityState{}
|
||||
if err := r.db.QueryRow(ctx, query).Scan(
|
||||
&status.InstallID,
|
||||
&status.AuthorityState,
|
||||
&status.ProductRootFingerprint,
|
||||
&status.BootstrappedOwnerEmail,
|
||||
&status.BootstrappedAt,
|
||||
); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return &InstallationAuthorityState{
|
||||
Bootstrapped: false,
|
||||
AuthorityState: "unbootstrapped",
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("get installation status: %w", err)
|
||||
}
|
||||
status.Bootstrapped = true
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (r *postgresInstallationRepository) BootstrapOwner(ctx context.Context, params BootstrapOwnerParams) (*User, error) {
|
||||
var existingInstallID string
|
||||
if err := r.db.QueryRow(ctx, `
|
||||
SELECT install_id
|
||||
FROM installation_authority
|
||||
WHERE id = 1
|
||||
FOR UPDATE
|
||||
`).Scan(&existingInstallID); err != nil && !errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, fmt.Errorf("lock installation authority: %w", err)
|
||||
} else if err == nil {
|
||||
return nil, ErrInstallationAlreadyBootstrapped
|
||||
}
|
||||
|
||||
email := strings.ToLower(strings.TrimSpace(params.Email))
|
||||
now := params.Now.UTC()
|
||||
user, err := scanOptionalUser(r.db.QueryRow(ctx, `
|
||||
INSERT INTO users (email, password_hash, mfa_enabled, platform_role, created_at, updated_at)
|
||||
VALUES ($1, $2, FALSE, $3, $4, $4)
|
||||
ON CONFLICT (email) DO UPDATE SET
|
||||
password_hash = EXCLUDED.password_hash,
|
||||
platform_role = EXCLUDED.platform_role,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
RETURNING id::text, email, password_hash, mfa_enabled, created_at, updated_at
|
||||
`, email, params.PasswordHash, params.Role, now))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upsert bootstrap owner: %w", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, fmt.Errorf("upsert bootstrap owner returned no user")
|
||||
}
|
||||
|
||||
payload := json.RawMessage(`{}`)
|
||||
if len(params.ActivationPayload) > 0 {
|
||||
payload = params.ActivationPayload
|
||||
}
|
||||
if _, err := r.db.Exec(ctx, `
|
||||
INSERT INTO installation_authority (
|
||||
id,
|
||||
install_id,
|
||||
authority_state,
|
||||
product_root_key_fingerprint,
|
||||
activation_payload,
|
||||
activation_signature,
|
||||
bootstrapped_owner_email,
|
||||
bootstrapped_at,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES (
|
||||
1,
|
||||
$1,
|
||||
'active',
|
||||
$2,
|
||||
$3::jsonb,
|
||||
$4,
|
||||
$5,
|
||||
$6,
|
||||
$6,
|
||||
$6
|
||||
)
|
||||
`, params.InstallID, params.ProductRootKeyFingerprint, []byte(payload), params.ActivationSignature, email, now); err != nil {
|
||||
return nil, fmt.Errorf("insert installation authority: %w", err)
|
||||
}
|
||||
|
||||
if _, err := r.db.Exec(ctx, `
|
||||
UPDATE platform_role_grants
|
||||
SET revoked_at = $4
|
||||
WHERE user_id = $1::uuid
|
||||
AND role = $2
|
||||
AND install_id = $3
|
||||
AND revoked_at IS NULL
|
||||
`, user.ID, params.Role, params.InstallID, now); err != nil {
|
||||
return nil, fmt.Errorf("revoke superseded platform role grants: %w", err)
|
||||
}
|
||||
if _, err := r.db.Exec(ctx, `
|
||||
INSERT INTO platform_role_grants (
|
||||
user_id,
|
||||
role,
|
||||
install_id,
|
||||
grant_payload,
|
||||
grant_signature,
|
||||
grant_source,
|
||||
granted_at,
|
||||
expires_at,
|
||||
metadata
|
||||
) VALUES (
|
||||
$1::uuid,
|
||||
$2,
|
||||
$3,
|
||||
$4::jsonb,
|
||||
$5,
|
||||
$6,
|
||||
$7,
|
||||
$8,
|
||||
'{"bootstrap_owner":true}'::jsonb
|
||||
)
|
||||
`, user.ID, params.Role, params.InstallID, []byte(payload), params.ActivationSignature, params.GrantSource, now, params.ExpiresAt); err != nil {
|
||||
return nil, fmt.Errorf("insert platform role grant: %w", err)
|
||||
}
|
||||
|
||||
if _, err := r.db.Exec(ctx, `
|
||||
INSERT INTO organization_memberships (
|
||||
organization_id,
|
||||
user_id,
|
||||
role_id,
|
||||
status,
|
||||
invited_by_user_id,
|
||||
created_at,
|
||||
updated_at
|
||||
)
|
||||
SELECT id, $1::uuid, 'org_owner', 'active', $1::uuid, $2, $2
|
||||
FROM organizations
|
||||
WHERE slug = 'default'
|
||||
ON CONFLICT (organization_id, user_id) DO UPDATE SET
|
||||
role_id = 'org_owner',
|
||||
status = 'active',
|
||||
invited_by_user_id = EXCLUDED.invited_by_user_id,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
`, user.ID, now); err != nil {
|
||||
return nil, fmt.Errorf("upsert default organization owner membership: %w", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
type scanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func scanOptionalUser(row scanner) (*User, error) {
|
||||
user := &User{}
|
||||
if err := row.Scan(
|
||||
&user.ID,
|
||||
&user.Email,
|
||||
&user.PasswordHash,
|
||||
&user.MFAEnabled,
|
||||
&user.CreatedAt,
|
||||
&user.UpdatedAt,
|
||||
); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("scan user: %w", err)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func scanDevice(row scanner) (*Device, error) {
|
||||
device := &Device{}
|
||||
var trustedAt, lastSeenAt, revokedAt *time.Time
|
||||
var revokedReason *string
|
||||
if err := row.Scan(
|
||||
&device.ID,
|
||||
&device.UserID,
|
||||
&device.Fingerprint,
|
||||
&device.Label,
|
||||
&device.TrustStatus,
|
||||
&trustedAt,
|
||||
&lastSeenAt,
|
||||
&revokedAt,
|
||||
&revokedReason,
|
||||
&device.CreatedAt,
|
||||
&device.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan device: %w", err)
|
||||
}
|
||||
device.TrustedAt = trustedAt
|
||||
device.LastSeenAt = lastSeenAt
|
||||
device.RevokedAt = revokedAt
|
||||
device.RevokedReason = revokedReason
|
||||
return device, nil
|
||||
}
|
||||
|
||||
func scanAuthSession(row scanner) (*AuthSession, error) {
|
||||
session := &AuthSession{}
|
||||
var lastSeenAt, lastRotatedAt, revokedAt *time.Time
|
||||
var revokedReason *string
|
||||
if err := row.Scan(
|
||||
&session.ID,
|
||||
&session.UserID,
|
||||
&session.DeviceID,
|
||||
&session.RefreshTokenHash,
|
||||
&session.RefreshExpiresAt,
|
||||
&lastSeenAt,
|
||||
&lastRotatedAt,
|
||||
&revokedAt,
|
||||
&revokedReason,
|
||||
&session.CreatedAt,
|
||||
&session.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan auth session: %w", err)
|
||||
}
|
||||
session.LastSeenAt = lastSeenAt
|
||||
session.LastRotatedAt = lastRotatedAt
|
||||
session.RevokedAt = revokedAt
|
||||
session.RevokedReason = revokedReason
|
||||
return session, nil
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UserRepository interface {
|
||||
GetByEmail(ctx context.Context, email string) (*User, error)
|
||||
GetByID(ctx context.Context, userID string) (*User, error)
|
||||
}
|
||||
|
||||
type DeviceRepository interface {
|
||||
Upsert(ctx context.Context, params UpsertDeviceParams) (*Device, error)
|
||||
GetByIDForUser(ctx context.Context, userID, deviceID string) (*Device, error)
|
||||
ListTrustedByUser(ctx context.Context, userID string) ([]Device, error)
|
||||
Revoke(ctx context.Context, params RevokeDeviceParams) error
|
||||
}
|
||||
|
||||
type AuthSessionRepository interface {
|
||||
Create(ctx context.Context, session AuthSession) error
|
||||
GetByID(ctx context.Context, authSessionID string) (*AuthSession, error)
|
||||
GetByIDForUpdate(ctx context.Context, authSessionID string) (*AuthSession, error)
|
||||
Rotate(ctx context.Context, params RotateAuthSessionParams) error
|
||||
Touch(ctx context.Context, authSessionID string, seenAt time.Time) error
|
||||
Revoke(ctx context.Context, params RevokeAuthSessionParams) error
|
||||
RevokeByDevice(ctx context.Context, userID, deviceID, reason string, revokedAt time.Time) error
|
||||
}
|
||||
|
||||
type InstallationRepository interface {
|
||||
GetStatus(ctx context.Context) (*InstallationAuthorityState, error)
|
||||
BootstrapOwner(ctx context.Context, params BootstrapOwnerParams) (*User, error)
|
||||
}
|
||||
|
||||
type Store interface {
|
||||
Users() UserRepository
|
||||
Devices() DeviceRepository
|
||||
AuthSessions() AuthSessionRepository
|
||||
Installation() InstallationRepository
|
||||
}
|
||||
|
||||
type Transactor interface {
|
||||
WithinTransaction(ctx context.Context, fn func(store Store) error) error
|
||||
}
|
||||
|
||||
type UpsertDeviceParams struct {
|
||||
UserID string
|
||||
Fingerprint string
|
||||
Label string
|
||||
TrustRequested bool
|
||||
SeenAt time.Time
|
||||
}
|
||||
|
||||
type RotateAuthSessionParams struct {
|
||||
AuthSessionID string
|
||||
RefreshTokenHash string
|
||||
RefreshExpiresAt time.Time
|
||||
LastSeenAt time.Time
|
||||
LastRotatedAt time.Time
|
||||
}
|
||||
|
||||
type RevokeAuthSessionParams struct {
|
||||
AuthSessionID string
|
||||
UserID string
|
||||
Reason string
|
||||
RevokedAt time.Time
|
||||
}
|
||||
|
||||
type RevokeDeviceParams struct {
|
||||
UserID string
|
||||
DeviceID string
|
||||
Reason string
|
||||
RevokedAt time.Time
|
||||
}
|
||||
|
||||
type InstallationAuthorityState struct {
|
||||
Bootstrapped bool
|
||||
AuthorityState string
|
||||
InstallID string
|
||||
ProductRootFingerprint string
|
||||
BootstrappedOwnerEmail string
|
||||
BootstrappedAt *time.Time
|
||||
}
|
||||
|
||||
type BootstrapOwnerParams struct {
|
||||
Email string
|
||||
PasswordHash string
|
||||
Role string
|
||||
InstallID string
|
||||
ProductRootKeyFingerprint string
|
||||
ActivationPayload json.RawMessage
|
||||
ActivationSignature string
|
||||
GrantSource string
|
||||
ExpiresAt *time.Time
|
||||
Now time.Time
|
||||
}
|
||||
@@ -0,0 +1,440 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/authority"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/module"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
cfg module.Config
|
||||
store Store
|
||||
transactor Transactor
|
||||
tokenManager *TokenManager
|
||||
authority *authority.Verifier
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
func NewService(deps module.Dependencies, store Store, transactor Transactor, verifiers ...*authority.Verifier) *Service {
|
||||
var authorityVerifier *authority.Verifier
|
||||
if len(verifiers) > 0 {
|
||||
authorityVerifier = verifiers[0]
|
||||
} else if verifier, err := authority.NewVerifier(deps.Config.Installation); err == nil {
|
||||
authorityVerifier = verifier
|
||||
}
|
||||
return &Service{
|
||||
cfg: deps.Config,
|
||||
store: store,
|
||||
transactor: transactor,
|
||||
tokenManager: NewTokenManager(TokenConfig{
|
||||
Issuer: deps.Config.Auth.Issuer,
|
||||
AccessTokenSecret: deps.Config.Auth.AccessTokenSecret,
|
||||
RefreshHashSecret: deps.Config.Auth.RefreshHashSecret,
|
||||
AccessTokenTTL: deps.Config.Auth.AccessTokenTTL,
|
||||
RefreshTokenTTL: deps.Config.Auth.RefreshTokenTTL,
|
||||
}),
|
||||
authority: authorityVerifier,
|
||||
now: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Login(ctx context.Context, cmd LoginCommand) (*AuthResult, error) {
|
||||
user, err := s.store.Users().GetByEmail(ctx, cmd.Email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if user == nil {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(cmd.Password)); err != nil {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
var result AuthResult
|
||||
now := s.now().UTC()
|
||||
|
||||
if err := s.transactor.WithinTransaction(ctx, func(store Store) error {
|
||||
device, err := store.Devices().Upsert(ctx, UpsertDeviceParams{
|
||||
UserID: user.ID,
|
||||
Fingerprint: cmd.DeviceFingerprint,
|
||||
Label: cmd.DeviceLabel,
|
||||
TrustRequested: cmd.TrustDevice,
|
||||
SeenAt: now,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if device.TrustStatus == DeviceTrustStatusRevoked {
|
||||
return ErrDeviceRevoked
|
||||
}
|
||||
|
||||
authSessionID := uuid.NewString()
|
||||
refreshToken, refreshHash, refreshExpiresAt, err := s.tokenManager.IssueRefreshToken(authSessionID, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
accessToken, accessExpiresAt, err := s.tokenManager.IssueAccessToken(user.ID, authSessionID, device.ID, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session := AuthSession{
|
||||
ID: authSessionID,
|
||||
UserID: user.ID,
|
||||
DeviceID: device.ID,
|
||||
RefreshTokenHash: refreshHash,
|
||||
RefreshExpiresAt: refreshExpiresAt,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
LastSeenAt: &now,
|
||||
}
|
||||
if err := store.AuthSessions().Create(ctx, session); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result = AuthResult{
|
||||
User: *user,
|
||||
Device: *device,
|
||||
AuthSession: session,
|
||||
Tokens: TokenPair{
|
||||
AccessToken: accessToken,
|
||||
AccessTokenExpiresAt: accessExpiresAt,
|
||||
RefreshToken: refreshToken,
|
||||
RefreshTokenExpiresAt: refreshExpiresAt,
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (s *Service) Refresh(ctx context.Context, cmd RefreshCommand) (*AuthResult, error) {
|
||||
authSessionID, err := s.tokenManager.ParseRefreshToken(cmd.RefreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result AuthResult
|
||||
now := s.now().UTC()
|
||||
|
||||
if err := s.transactor.WithinTransaction(ctx, func(store Store) error {
|
||||
session, err := store.AuthSessions().GetByIDForUpdate(ctx, authSessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if session == nil {
|
||||
return ErrInvalidRefreshToken
|
||||
}
|
||||
if session.RevokedAt != nil {
|
||||
return ErrAuthSessionRevoked
|
||||
}
|
||||
if now.After(session.RefreshExpiresAt) {
|
||||
if revokeErr := store.AuthSessions().Revoke(ctx, RevokeAuthSessionParams{
|
||||
AuthSessionID: session.ID,
|
||||
UserID: session.UserID,
|
||||
Reason: "refresh_token_expired",
|
||||
RevokedAt: now,
|
||||
}); revokeErr != nil {
|
||||
return revokeErr
|
||||
}
|
||||
return ErrInvalidRefreshToken
|
||||
}
|
||||
|
||||
expectedHash := s.tokenManager.HashRefreshToken(cmd.RefreshToken)
|
||||
if expectedHash != session.RefreshTokenHash {
|
||||
if revokeErr := store.AuthSessions().Revoke(ctx, RevokeAuthSessionParams{
|
||||
AuthSessionID: session.ID,
|
||||
UserID: session.UserID,
|
||||
Reason: "refresh_rotation_reuse_detected",
|
||||
RevokedAt: now,
|
||||
}); revokeErr != nil {
|
||||
return revokeErr
|
||||
}
|
||||
return ErrInvalidRefreshToken
|
||||
}
|
||||
|
||||
user, err := store.Users().GetByID(ctx, session.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if user == nil {
|
||||
return ErrInvalidCredentials
|
||||
}
|
||||
|
||||
device, err := store.Devices().GetByIDForUser(ctx, session.UserID, session.DeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if device == nil {
|
||||
return ErrTrustedDeviceMissing
|
||||
}
|
||||
if device.TrustStatus == DeviceTrustStatusRevoked {
|
||||
return ErrDeviceRevoked
|
||||
}
|
||||
|
||||
refreshToken, refreshHash, refreshExpiresAt, err := s.tokenManager.IssueRefreshToken(session.ID, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
accessToken, accessExpiresAt, err := s.tokenManager.IssueAccessToken(user.ID, session.ID, device.ID, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := store.AuthSessions().Rotate(ctx, RotateAuthSessionParams{
|
||||
AuthSessionID: session.ID,
|
||||
RefreshTokenHash: refreshHash,
|
||||
RefreshExpiresAt: refreshExpiresAt,
|
||||
LastSeenAt: now,
|
||||
LastRotatedAt: now,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result = AuthResult{
|
||||
User: *user,
|
||||
Device: *device,
|
||||
AuthSession: AuthSession{
|
||||
ID: session.ID,
|
||||
UserID: session.UserID,
|
||||
DeviceID: session.DeviceID,
|
||||
RefreshTokenHash: refreshHash,
|
||||
RefreshExpiresAt: refreshExpiresAt,
|
||||
LastSeenAt: &now,
|
||||
LastRotatedAt: &now,
|
||||
},
|
||||
Tokens: TokenPair{
|
||||
AccessToken: accessToken,
|
||||
AccessTokenExpiresAt: accessExpiresAt,
|
||||
RefreshToken: refreshToken,
|
||||
RefreshTokenExpiresAt: refreshExpiresAt,
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (s *Service) InstallationStatus(ctx context.Context) (*InstallationStatus, error) {
|
||||
record, err := s.store.Installation().GetStatus(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.installationStatusFromRecord(record), nil
|
||||
}
|
||||
|
||||
func (s *Service) BootstrapOwner(ctx context.Context, cmd BootstrapOwnerCommand) (*BootstrapOwnerResult, error) {
|
||||
email := strings.ToLower(strings.TrimSpace(cmd.Email))
|
||||
password := strings.TrimSpace(cmd.Password)
|
||||
if email == "" || !strings.Contains(email, "@") || len(password) < 12 {
|
||||
return nil, ErrInvalidBootstrapOwner
|
||||
}
|
||||
|
||||
now := s.now().UTC()
|
||||
role := authority.PlatformRoleAdmin
|
||||
installID := ""
|
||||
grantSource := "installation_activation"
|
||||
rootFingerprint := ""
|
||||
activationPayload := cmd.ActivationPayload
|
||||
activationSignature := strings.TrimSpace(cmd.ActivationSignature)
|
||||
var expiresAt *time.Time
|
||||
|
||||
if s.strictAuthority() {
|
||||
if len(activationPayload) == 0 || activationSignature == "" {
|
||||
return nil, ErrInstallationActivationRequired
|
||||
}
|
||||
activation, err := s.authority.VerifyActivation(activationPayload, activationSignature)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidInstallationActivation, err)
|
||||
}
|
||||
if !strings.EqualFold(activation.OwnerEmail, email) {
|
||||
return nil, ErrInvalidInstallationActivation
|
||||
}
|
||||
role = activation.PlatformRole
|
||||
installID = activation.InstallID
|
||||
expiresAt = activation.ExpiresAt
|
||||
rootFingerprint = s.authority.RootFingerprint()
|
||||
} else {
|
||||
if s.authority == nil || !s.authority.AllowInsecureBootstrap() {
|
||||
return nil, ErrInsecureBootstrapDisabled
|
||||
}
|
||||
installID = uuid.NewString()
|
||||
grantSource = "dev_insecure"
|
||||
rootFingerprint = "dev-insecure"
|
||||
devPayload, err := json.Marshal(authority.ActivationPayload{
|
||||
SchemaVersion: authority.ActivationSchemaVersion,
|
||||
InstallID: installID,
|
||||
OwnerEmail: email,
|
||||
PlatformRole: role,
|
||||
IssuedAt: now,
|
||||
Environment: s.cfg.App.Env,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
activationPayload = json.RawMessage(devPayload)
|
||||
activationSignature = "dev-insecure"
|
||||
}
|
||||
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hash bootstrap owner password: %w", err)
|
||||
}
|
||||
|
||||
var user *User
|
||||
if err := s.transactor.WithinTransaction(ctx, func(store Store) error {
|
||||
created, err := store.Installation().BootstrapOwner(ctx, BootstrapOwnerParams{
|
||||
Email: email,
|
||||
PasswordHash: string(passwordHash),
|
||||
Role: role,
|
||||
InstallID: installID,
|
||||
ProductRootKeyFingerprint: rootFingerprint,
|
||||
ActivationPayload: activationPayload,
|
||||
ActivationSignature: activationSignature,
|
||||
GrantSource: grantSource,
|
||||
ExpiresAt: expiresAt,
|
||||
Now: now,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user = created
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
status, err := s.InstallationStatus(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &BootstrapOwnerResult{
|
||||
Installation: *status,
|
||||
User: *user,
|
||||
PlatformRole: role,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) RevokeAuthSession(ctx context.Context, cmd RevokeAuthSessionCommand) error {
|
||||
return s.transactor.WithinTransaction(ctx, func(store Store) error {
|
||||
session, err := store.AuthSessions().GetByIDForUpdate(ctx, cmd.AuthSessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if session == nil || session.UserID != cmd.UserID {
|
||||
return ErrAuthSessionNotFound
|
||||
}
|
||||
return store.AuthSessions().Revoke(ctx, RevokeAuthSessionParams{
|
||||
AuthSessionID: cmd.AuthSessionID,
|
||||
UserID: cmd.UserID,
|
||||
Reason: cmd.Reason,
|
||||
RevokedAt: s.now().UTC(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) RevokeTrustedDevice(ctx context.Context, cmd RevokeDeviceCommand) error {
|
||||
return s.transactor.WithinTransaction(ctx, func(store Store) error {
|
||||
device, err := store.Devices().GetByIDForUser(ctx, cmd.UserID, cmd.DeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if device == nil {
|
||||
return ErrTrustedDeviceMissing
|
||||
}
|
||||
if device.TrustStatus != DeviceTrustStatusTrusted {
|
||||
return ErrDeviceNotTrusted
|
||||
}
|
||||
|
||||
now := s.now().UTC()
|
||||
if err := store.Devices().Revoke(ctx, RevokeDeviceParams{
|
||||
UserID: cmd.UserID,
|
||||
DeviceID: cmd.DeviceID,
|
||||
Reason: cmd.Reason,
|
||||
RevokedAt: now,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return store.AuthSessions().RevokeByDevice(ctx, cmd.UserID, cmd.DeviceID, "device_revoked:"+cmd.Reason, now)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) ListTrustedDevices(ctx context.Context, userID string) ([]Device, error) {
|
||||
return s.store.Devices().ListTrustedByUser(ctx, userID)
|
||||
}
|
||||
|
||||
func (s *Service) MapError(err error) (int, string) {
|
||||
switch {
|
||||
case err == nil:
|
||||
return 0, ""
|
||||
case errors.Is(err, ErrInvalidCredentials):
|
||||
return 401, "invalid credentials"
|
||||
case errors.Is(err, ErrInvalidRefreshToken):
|
||||
return 401, "invalid refresh token"
|
||||
case errors.Is(err, ErrAuthSessionRevoked):
|
||||
return 401, "auth session revoked"
|
||||
case errors.Is(err, ErrDeviceRevoked):
|
||||
return 403, "device revoked"
|
||||
case errors.Is(err, ErrDeviceNotTrusted):
|
||||
return 409, "device is not trusted"
|
||||
case errors.Is(err, ErrAuthSessionNotFound), errors.Is(err, ErrTrustedDeviceMissing):
|
||||
return 404, err.Error()
|
||||
case errors.Is(err, ErrInstallationActivationRequired), errors.Is(err, ErrInvalidInstallationActivation), errors.Is(err, ErrInvalidBootstrapOwner):
|
||||
return 400, err.Error()
|
||||
case errors.Is(err, ErrInsecureBootstrapDisabled):
|
||||
return 403, err.Error()
|
||||
case errors.Is(err, ErrInstallationAlreadyBootstrapped):
|
||||
return 409, err.Error()
|
||||
default:
|
||||
return 500, fmt.Sprintf("internal error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) installationStatusFromRecord(record *InstallationAuthorityState) *InstallationStatus {
|
||||
if record == nil {
|
||||
record = &InstallationAuthorityState{AuthorityState: "unbootstrapped"}
|
||||
}
|
||||
mode := authority.ModeLegacy
|
||||
strict := false
|
||||
rootFingerprint := ""
|
||||
insecureAllowed := false
|
||||
if s.authority != nil {
|
||||
mode = s.authority.Mode()
|
||||
strict = s.authority.Strict()
|
||||
rootFingerprint = s.authority.RootFingerprint()
|
||||
insecureAllowed = s.authority.AllowInsecureBootstrap()
|
||||
}
|
||||
if record.ProductRootFingerprint != "" {
|
||||
rootFingerprint = record.ProductRootFingerprint
|
||||
}
|
||||
return &InstallationStatus{
|
||||
Bootstrapped: record.Bootstrapped,
|
||||
AuthorityState: record.AuthorityState,
|
||||
InstallID: record.InstallID,
|
||||
BootstrappedOwnerEmail: record.BootstrappedOwnerEmail,
|
||||
BootstrappedAt: record.BootstrappedAt,
|
||||
AuthorityMode: mode,
|
||||
StrictAuthority: strict,
|
||||
RootFingerprint: rootFingerprint,
|
||||
InsecureBootstrapAllowed: insecureAllowed,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) strictAuthority() bool {
|
||||
return s.authority != nil && s.authority.Strict()
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
type TokenManager struct {
|
||||
issuer string
|
||||
accessSecret []byte
|
||||
refreshHashSecret []byte
|
||||
accessTTL time.Duration
|
||||
refreshTTL time.Duration
|
||||
}
|
||||
|
||||
type AccessClaims struct {
|
||||
AuthSessionID string `json:"sid"`
|
||||
DeviceID string `json:"did"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
func NewTokenManager(cfg TokenConfig) *TokenManager {
|
||||
return &TokenManager{
|
||||
issuer: cfg.Issuer,
|
||||
accessSecret: []byte(cfg.AccessTokenSecret),
|
||||
refreshHashSecret: []byte(cfg.RefreshHashSecret),
|
||||
accessTTL: cfg.AccessTokenTTL,
|
||||
refreshTTL: cfg.RefreshTokenTTL,
|
||||
}
|
||||
}
|
||||
|
||||
type TokenConfig struct {
|
||||
Issuer string
|
||||
AccessTokenSecret string
|
||||
RefreshHashSecret string
|
||||
AccessTokenTTL time.Duration
|
||||
RefreshTokenTTL time.Duration
|
||||
}
|
||||
|
||||
func (m *TokenManager) IssueAccessToken(userID, authSessionID, deviceID string, now time.Time) (string, time.Time, error) {
|
||||
expiresAt := now.Add(m.accessTTL)
|
||||
claims := AccessClaims{
|
||||
AuthSessionID: authSessionID,
|
||||
DeviceID: deviceID,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: m.issuer,
|
||||
Subject: userID,
|
||||
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signed, err := token.SignedString(m.accessSecret)
|
||||
if err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("sign access token: %w", err)
|
||||
}
|
||||
|
||||
return signed, expiresAt, nil
|
||||
}
|
||||
|
||||
func (m *TokenManager) IssueRefreshToken(authSessionID string, now time.Time) (raw string, hash string, expiresAt time.Time, err error) {
|
||||
secret := make([]byte, 32)
|
||||
if _, err = rand.Read(secret); err != nil {
|
||||
return "", "", time.Time{}, fmt.Errorf("read random refresh secret: %w", err)
|
||||
}
|
||||
|
||||
encodedSecret := base64.RawURLEncoding.EncodeToString(secret)
|
||||
raw = authSessionID + "." + encodedSecret
|
||||
hash = m.HashRefreshToken(raw)
|
||||
expiresAt = now.Add(m.refreshTTL)
|
||||
return raw, hash, expiresAt, nil
|
||||
}
|
||||
|
||||
func (m *TokenManager) HashRefreshToken(token string) string {
|
||||
mac := hmac.New(sha256.New, m.refreshHashSecret)
|
||||
_, _ = mac.Write([]byte(token))
|
||||
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
|
||||
func (m *TokenManager) ParseRefreshToken(token string) (string, error) {
|
||||
sessionID, _, ok := strings.Cut(token, ".")
|
||||
if !ok || sessionID == "" {
|
||||
return "", ErrInvalidRefreshToken
|
||||
}
|
||||
return sessionID, nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,34 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMeshLatestObservationKeySeparatesRouteHealthByRoute(t *testing.T) {
|
||||
key := meshLatestObservationKey(json.RawMessage(`{
|
||||
"observation_type":"synthetic_route_health",
|
||||
"route_id":"route-1"
|
||||
}`))
|
||||
if key != "synthetic_route_health:route-1" {
|
||||
t.Fatalf("key = %q", key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMeshLatestObservationKeySeparatesConnectionManagerMode(t *testing.T) {
|
||||
key := meshLatestObservationKey(json.RawMessage(`{
|
||||
"observation_type":"peer_connection_manager",
|
||||
"transport_mode":"relay_control",
|
||||
"relay_node_id":"node-r"
|
||||
}`))
|
||||
if key != "peer_connection_manager:relay_control:node-r" {
|
||||
t.Fatalf("key = %q", key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMeshLatestObservationKeyDefaults(t *testing.T) {
|
||||
key := meshLatestObservationKey(json.RawMessage(`{}`))
|
||||
if key != "default" {
|
||||
t.Fatalf("key = %q", key)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
GetPlatformRole(ctx context.Context, userID string) (string, error)
|
||||
|
||||
ListClusters(ctx context.Context) ([]Cluster, error)
|
||||
GetCluster(ctx context.Context, clusterID string) (Cluster, error)
|
||||
CreateCluster(ctx context.Context, input CreateClusterInput) (Cluster, error)
|
||||
UpdateCluster(ctx context.Context, input UpdateClusterInput) (Cluster, error)
|
||||
GetClusterAuthority(ctx context.Context, clusterID string) (ClusterAuthorityKey, error)
|
||||
EnsureClusterAuthority(ctx context.Context, clusterID string, actorUserID *string) (ClusterAuthorityKey, error)
|
||||
|
||||
ListClusterNodes(ctx context.Context, clusterID string) ([]ClusterNode, error)
|
||||
ListNodeGroups(ctx context.Context, clusterID string) ([]ClusterNodeGroup, error)
|
||||
CreateNodeGroup(ctx context.Context, input CreateNodeGroupInput) (ClusterNodeGroup, error)
|
||||
AssignNodeToGroup(ctx context.Context, input AssignNodeGroupInput) (ClusterNode, error)
|
||||
|
||||
CreateJoinToken(ctx context.Context, input CreateJoinTokenInput, tokenHash string) (NodeJoinToken, error)
|
||||
SetJoinTokenAuthority(ctx context.Context, clusterID, tokenID string, payload json.RawMessage, signature ClusterSignature) (NodeJoinToken, error)
|
||||
GetValidJoinTokenByHash(ctx context.Context, clusterID, tokenHash string) (NodeJoinToken, error)
|
||||
RevokeJoinToken(ctx context.Context, input RevokeJoinTokenInput) (NodeJoinToken, error)
|
||||
ExpireJoinTokens(ctx context.Context, clusterID string) error
|
||||
|
||||
CreateJoinRequest(ctx context.Context, input CreateJoinRequestInput, joinTokenID string) (NodeJoinRequest, error)
|
||||
GetJoinRequestForBootstrap(ctx context.Context, input GetJoinRequestBootstrapInput) (NodeJoinRequest, error)
|
||||
ListJoinRequests(ctx context.Context, clusterID string) ([]NodeJoinRequest, error)
|
||||
ApproveJoinRequest(ctx context.Context, input ApproveJoinRequestInput) (ApprovedJoinRequest, error)
|
||||
SetJoinRequestApprovalAuthority(ctx context.Context, clusterID, joinRequestID string, payload json.RawMessage, signature ClusterSignature) (NodeJoinRequest, error)
|
||||
RejectJoinRequest(ctx context.Context, input RejectJoinRequestInput) (NodeJoinRequest, error)
|
||||
|
||||
AssignNodeRole(ctx context.Context, input AssignNodeRoleInput) (NodeRoleAssignment, error)
|
||||
ListNodeRoleAssignments(ctx context.Context, clusterID, nodeID string) ([]NodeRoleAssignment, error)
|
||||
AttachExistingNodeToCluster(ctx context.Context, input AttachExistingNodeInput) (ClusterNode, error)
|
||||
|
||||
RecordHeartbeat(ctx context.Context, input RecordHeartbeatInput) (NodeHeartbeat, error)
|
||||
ListNodeHeartbeats(ctx context.Context, clusterID, nodeID string, limit int) ([]NodeHeartbeat, error)
|
||||
RevokeNodeIdentity(ctx context.Context, input RevokeNodeIdentityInput) error
|
||||
DisableClusterMembership(ctx context.Context, input DisableMembershipInput) error
|
||||
UpsertFabricTestingFlag(ctx context.Context, input UpsertFabricTestingFlagInput) (FabricTestingFlag, error)
|
||||
ListFabricTestingFlags(ctx context.Context) ([]FabricTestingFlag, error)
|
||||
GetEffectiveNodeTestingFlags(ctx context.Context, clusterID, nodeID string) (EffectiveNodeTestingFlags, error)
|
||||
RecordNodeTelemetry(ctx context.Context, input RecordNodeTelemetryInput) (NodeTelemetryObservation, error)
|
||||
ListNodeTelemetry(ctx context.Context, clusterID, nodeID string, limit int) ([]NodeTelemetryObservation, error)
|
||||
SetDesiredWorkload(ctx context.Context, input SetDesiredWorkloadInput) (NodeWorkloadDesiredState, error)
|
||||
ListDesiredWorkloads(ctx context.Context, clusterID, nodeID string) ([]NodeWorkloadDesiredState, error)
|
||||
ReportWorkloadStatus(ctx context.Context, input ReportWorkloadStatusInput) (NodeWorkloadStatus, error)
|
||||
ListLatestWorkloadStatuses(ctx context.Context, clusterID, nodeID string) ([]NodeWorkloadStatus, error)
|
||||
ReportMeshLink(ctx context.Context, input ReportMeshLinkInput) (MeshLinkObservation, error)
|
||||
ListMeshLinks(ctx context.Context, clusterID string) ([]MeshLinkObservation, error)
|
||||
CreateRouteIntent(ctx context.Context, input CreateRouteIntentInput) (MeshRouteIntent, error)
|
||||
ListRouteIntents(ctx context.Context, clusterID string) ([]MeshRouteIntent, error)
|
||||
ListQoSPolicies(ctx context.Context, clusterID string) ([]MeshQoSPolicy, error)
|
||||
ListFabricEntryPoints(ctx context.Context, clusterID string) ([]FabricEntryPoint, error)
|
||||
CreateFabricEntryPoint(ctx context.Context, input CreateFabricEntryPointInput) (FabricEntryPoint, error)
|
||||
SetFabricEntryPointNode(ctx context.Context, input SetFabricEntryPointNodeInput) (FabricEntryPointNode, error)
|
||||
ListFabricEntryPointNodes(ctx context.Context, clusterID, entryPointID string) ([]FabricEntryPointNode, error)
|
||||
ListFabricEgressPools(ctx context.Context, clusterID string) ([]FabricEgressPool, error)
|
||||
CreateFabricEgressPool(ctx context.Context, input CreateFabricEgressPoolInput) (FabricEgressPool, error)
|
||||
SetFabricEgressPoolNode(ctx context.Context, input SetFabricEgressPoolNodeInput) (FabricEgressPoolNode, error)
|
||||
ListFabricEgressPoolNodes(ctx context.Context, clusterID, egressPoolID string) ([]FabricEgressPoolNode, error)
|
||||
GetClusterAuthorityState(ctx context.Context, clusterID string) (ClusterAuthorityState, error)
|
||||
UpdateClusterAuthorityState(ctx context.Context, input UpdateClusterAuthorityInput) (ClusterAuthorityState, error)
|
||||
ListClusterAdminSummaries(ctx context.Context) ([]ClusterAdminSummary, error)
|
||||
|
||||
CreateVPNConnection(ctx context.Context, input CreateVPNConnectionInput) (VPNConnection, error)
|
||||
ListVPNConnections(ctx context.Context, clusterID string) ([]VPNConnection, error)
|
||||
GetVPNConnection(ctx context.Context, clusterID, vpnConnectionID string) (VPNConnection, error)
|
||||
UpdateVPNConnectionDesiredState(ctx context.Context, input UpdateVPNConnectionDesiredStateInput) (VPNConnection, error)
|
||||
UpsertVPNConnectionRoutePolicy(ctx context.Context, input UpsertVPNConnectionRoutePolicyInput) (VPNConnectionRoutePolicy, error)
|
||||
ListVPNConnectionRoutePolicies(ctx context.Context, clusterID, vpnConnectionID string) ([]VPNConnectionRoutePolicy, error)
|
||||
SetVPNConnectionAllowedNodes(ctx context.Context, input SetVPNConnectionAllowedNodesInput) ([]VPNConnectionAllowedNode, error)
|
||||
ListVPNConnectionAllowedNodes(ctx context.Context, clusterID, vpnConnectionID string) ([]VPNConnectionAllowedNode, error)
|
||||
AcquireVPNConnectionLease(ctx context.Context, input AcquireVPNConnectionLeaseInput, expiresAt time.Time, fencingToken string) (VPNConnectionLease, error)
|
||||
RenewVPNConnectionLease(ctx context.Context, input RenewVPNConnectionLeaseInput, expiresAt time.Time) (VPNConnectionLease, error)
|
||||
ReleaseVPNConnectionLease(ctx context.Context, input ReleaseVPNConnectionLeaseInput) (VPNConnectionLease, error)
|
||||
FenceVPNConnectionLease(ctx context.Context, input FenceVPNConnectionLeaseInput) (VPNConnectionLease, error)
|
||||
GetActiveVPNConnectionLease(ctx context.Context, clusterID, vpnConnectionID string) (VPNConnectionLease, error)
|
||||
CheckVPNLeaseOwnerEligibility(ctx context.Context, clusterID, vpnConnectionID, ownerNodeID string) (VPNLeaseOwnerEligibility, error)
|
||||
ExpireStaleVPNConnectionLeases(ctx context.Context, clusterID string, now time.Time) ([]VPNConnectionLease, error)
|
||||
ListNodeVPNAssignments(ctx context.Context, clusterID, nodeID string) ([]NodeVPNAssignment, error)
|
||||
ReportNodeVPNAssignmentStatus(ctx context.Context, input ReportNodeVPNAssignmentStatusInput) (NodeVPNAssignmentStatus, error)
|
||||
|
||||
RecordAudit(ctx context.Context, event ClusterAuditEvent) error
|
||||
ListAuditEvents(ctx context.Context, clusterID string, limit int) ([]ClusterAuditEvent, error)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,43 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const joinTokenHashPrefix = "sha256:"
|
||||
|
||||
func generateJoinToken() (string, error) {
|
||||
var random [32]byte
|
||||
if _, err := rand.Read(random[:]); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return "rap_join_" + base64.RawURLEncoding.EncodeToString(random[:]), nil
|
||||
}
|
||||
|
||||
func hashJoinToken(token string) (string, error) {
|
||||
trimmed := strings.TrimSpace(token)
|
||||
if trimmed == "" {
|
||||
return "", errors.New("join token is required")
|
||||
}
|
||||
sum := sha256.Sum256([]byte(trimmed))
|
||||
return joinTokenHashPrefix + hex.EncodeToString(sum[:]), nil
|
||||
}
|
||||
|
||||
func isPlatformAdminRole(role string) bool {
|
||||
return role == PlatformRoleAdmin || role == PlatformRoleRecoveryAdmin
|
||||
}
|
||||
|
||||
func isAllowedNodeRole(role string) bool {
|
||||
_, ok := allowedNodeRoles[role]
|
||||
return ok
|
||||
}
|
||||
|
||||
func defaultJoinTokenExpiry(now time.Time) time.Time {
|
||||
return now.Add(30 * time.Minute)
|
||||
}
|
||||
@@ -0,0 +1,344 @@
|
||||
package identitysource
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/module"
|
||||
)
|
||||
|
||||
type Module struct {
|
||||
db *pgxpool.Pool
|
||||
}
|
||||
|
||||
type IdentitySource struct {
|
||||
ID string `json:"id"`
|
||||
OrganizationID string `json:"organization_id"`
|
||||
Kind string `json:"kind"`
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
Config json.RawMessage `json:"config"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type IdentityMapping struct {
|
||||
ID string `json:"id"`
|
||||
IdentitySourceID string `json:"identity_source_id"`
|
||||
MappingType string `json:"mapping_type"`
|
||||
ExternalSelector json.RawMessage `json:"external_selector"`
|
||||
InternalTarget json.RawMessage `json:"internal_target"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type upsertIdentitySourceRequest struct {
|
||||
ActorUserID string `json:"actor_user_id"`
|
||||
OrganizationID string `json:"organization_id"`
|
||||
Kind string `json:"kind"`
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
Config json.RawMessage `json:"config"`
|
||||
IdentityMappings []struct {
|
||||
MappingType string `json:"mapping_type"`
|
||||
ExternalSelector json.RawMessage `json:"external_selector"`
|
||||
InternalTarget json.RawMessage `json:"internal_target"`
|
||||
} `json:"identity_mappings"`
|
||||
}
|
||||
|
||||
func NewModule(deps module.Dependencies) *Module {
|
||||
return &Module{db: deps.Infra.DB}
|
||||
}
|
||||
|
||||
func (m *Module) Name() string {
|
||||
return "identitysource"
|
||||
}
|
||||
|
||||
func (m *Module) RegisterRoutes(router chi.Router) {
|
||||
router.Route("/identity-sources", func(r chi.Router) {
|
||||
r.Get("/", m.listIdentitySources)
|
||||
r.Post("/", m.createIdentitySource)
|
||||
r.Get("/{identitySourceID}", m.getIdentitySource)
|
||||
r.Put("/{identitySourceID}", m.updateIdentitySource)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) listIdentitySources(w http.ResponseWriter, r *http.Request) {
|
||||
orgID := r.URL.Query().Get("organization_id")
|
||||
if orgID == "" {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "organization_id is required")
|
||||
return
|
||||
}
|
||||
rows, err := m.db.Query(r.Context(), `
|
||||
SELECT id, organization_id, kind, name, status, config, created_at, updated_at
|
||||
FROM identity_sources
|
||||
WHERE organization_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`, orgID)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []IdentitySource
|
||||
for rows.Next() {
|
||||
item, err := scanIdentitySource(rows)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{"identity_sources": items})
|
||||
}
|
||||
|
||||
func (m *Module) getIdentitySource(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "identitySourceID")
|
||||
item, err := m.getByID(r.Context(), id)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
httpx.WriteError(w, http.StatusNotFound, "identity source not found")
|
||||
return
|
||||
}
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
mappings, err := m.listMappings(r.Context(), id)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{
|
||||
"identity_source": item,
|
||||
"identity_mappings": mappings,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) createIdentitySource(w http.ResponseWriter, r *http.Request) {
|
||||
req, err := decodeRequest(r)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
item := IdentitySource{
|
||||
ID: uuid.NewString(),
|
||||
OrganizationID: req.OrganizationID,
|
||||
Kind: req.Kind,
|
||||
Name: req.Name,
|
||||
Status: req.Status,
|
||||
Config: req.Config,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
tx, err := m.db.Begin(r.Context())
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
defer tx.Rollback(r.Context())
|
||||
if _, err := tx.Exec(r.Context(), `
|
||||
INSERT INTO identity_sources (id, organization_id, kind, name, status, config, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6::jsonb, $7, $8)
|
||||
`, item.ID, item.OrganizationID, item.Kind, item.Name, item.Status, []byte(item.Config), item.CreatedAt, item.UpdatedAt); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
mappings, err := upsertMappings(r.Context(), tx, item.ID, req.IdentityMappings)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if err := tx.Commit(r.Context()); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusCreated, map[string]any{
|
||||
"identity_source": item,
|
||||
"identity_mappings": mappings,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) updateIdentitySource(w http.ResponseWriter, r *http.Request) {
|
||||
req, err := decodeRequest(r)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
id := chi.URLParam(r, "identitySourceID")
|
||||
tx, err := m.db.Begin(r.Context())
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
defer tx.Rollback(r.Context())
|
||||
tag, err := tx.Exec(r.Context(), `
|
||||
UPDATE identity_sources
|
||||
SET organization_id = $2, kind = $3, name = $4, status = $5, config = $6::jsonb, updated_at = $7
|
||||
WHERE id = $1
|
||||
`, id, req.OrganizationID, req.Kind, req.Name, req.Status, []byte(req.Config), time.Now().UTC())
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
httpx.WriteError(w, http.StatusNotFound, "identity source not found")
|
||||
return
|
||||
}
|
||||
if _, err := tx.Exec(r.Context(), `DELETE FROM identity_mappings WHERE identity_source_id = $1`, id); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
mappings, err := upsertMappings(r.Context(), tx, id, req.IdentityMappings)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if err := tx.Commit(r.Context()); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
item, err := m.getByID(r.Context(), id)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{
|
||||
"identity_source": item,
|
||||
"identity_mappings": mappings,
|
||||
})
|
||||
}
|
||||
|
||||
func decodeRequest(r *http.Request) (*upsertIdentitySourceRequest, error) {
|
||||
var req upsertIdentitySourceRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.New("invalid identity source payload")
|
||||
}
|
||||
if req.ActorUserID == "" || req.OrganizationID == "" || req.Kind == "" || req.Name == "" {
|
||||
return nil, errors.New("actor_user_id, organization_id, kind, and name are required")
|
||||
}
|
||||
if req.Status == "" {
|
||||
req.Status = "active"
|
||||
}
|
||||
if len(req.Config) == 0 {
|
||||
req.Config = json.RawMessage(`{}`)
|
||||
}
|
||||
if !json.Valid(req.Config) {
|
||||
return nil, errors.New("config must be valid json")
|
||||
}
|
||||
for _, mapping := range req.IdentityMappings {
|
||||
if len(mapping.ExternalSelector) == 0 {
|
||||
mapping.ExternalSelector = json.RawMessage(`{}`)
|
||||
}
|
||||
if len(mapping.InternalTarget) == 0 {
|
||||
mapping.InternalTarget = json.RawMessage(`{}`)
|
||||
}
|
||||
}
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
func (m *Module) getByID(ctx context.Context, id string) (IdentitySource, error) {
|
||||
row := m.db.QueryRow(ctx, `
|
||||
SELECT id, organization_id, kind, name, status, config, created_at, updated_at
|
||||
FROM identity_sources
|
||||
WHERE id = $1
|
||||
`, id)
|
||||
return scanIdentitySource(row)
|
||||
}
|
||||
|
||||
func (m *Module) listMappings(ctx context.Context, sourceID string) ([]IdentityMapping, error) {
|
||||
rows, err := m.db.Query(ctx, `
|
||||
SELECT id, identity_source_id, mapping_type, external_selector, internal_target, created_at, updated_at
|
||||
FROM identity_mappings
|
||||
WHERE identity_source_id = $1
|
||||
ORDER BY created_at ASC
|
||||
`, sourceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var mappings []IdentityMapping
|
||||
for rows.Next() {
|
||||
item, err := scanIdentityMapping(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mappings = append(mappings, item)
|
||||
}
|
||||
return mappings, rows.Err()
|
||||
}
|
||||
|
||||
func upsertMappings(ctx context.Context, tx pgx.Tx, sourceID string, requested []struct {
|
||||
MappingType string `json:"mapping_type"`
|
||||
ExternalSelector json.RawMessage `json:"external_selector"`
|
||||
InternalTarget json.RawMessage `json:"internal_target"`
|
||||
}) ([]IdentityMapping, error) {
|
||||
now := time.Now().UTC()
|
||||
items := make([]IdentityMapping, 0, len(requested))
|
||||
for _, mapping := range requested {
|
||||
external := mapping.ExternalSelector
|
||||
if len(external) == 0 {
|
||||
external = json.RawMessage(`{}`)
|
||||
}
|
||||
internal := mapping.InternalTarget
|
||||
if len(internal) == 0 {
|
||||
internal = json.RawMessage(`{}`)
|
||||
}
|
||||
item := IdentityMapping{
|
||||
ID: uuid.NewString(),
|
||||
IdentitySourceID: sourceID,
|
||||
MappingType: mapping.MappingType,
|
||||
ExternalSelector: external,
|
||||
InternalTarget: internal,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO identity_mappings (
|
||||
id, identity_source_id, mapping_type, external_selector, internal_target, created_at, updated_at
|
||||
) VALUES ($1, $2, $3, $4::jsonb, $5::jsonb, $6, $7)
|
||||
`, item.ID, item.IdentitySourceID, item.MappingType, []byte(item.ExternalSelector), []byte(item.InternalTarget), item.CreatedAt, item.UpdatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
type rowScanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func scanIdentitySource(row rowScanner) (IdentitySource, error) {
|
||||
var item IdentitySource
|
||||
if err := row.Scan(&item.ID, &item.OrganizationID, &item.Kind, &item.Name, &item.Status, &item.Config, &item.CreatedAt, &item.UpdatedAt); err != nil {
|
||||
return IdentitySource{}, err
|
||||
}
|
||||
if len(item.Config) == 0 {
|
||||
item.Config = json.RawMessage(`{}`)
|
||||
}
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func scanIdentityMapping(row rowScanner) (IdentityMapping, error) {
|
||||
var item IdentityMapping
|
||||
if err := row.Scan(&item.ID, &item.IdentitySourceID, &item.MappingType, &item.ExternalSelector, &item.InternalTarget, &item.CreatedAt, &item.UpdatedAt); err != nil {
|
||||
return IdentityMapping{}, err
|
||||
}
|
||||
if len(item.ExternalSelector) == 0 {
|
||||
item.ExternalSelector = json.RawMessage(`{}`)
|
||||
}
|
||||
if len(item.InternalTarget) == 0 {
|
||||
item.InternalTarget = json.RawMessage(`{}`)
|
||||
}
|
||||
return item, nil
|
||||
}
|
||||
@@ -0,0 +1,458 @@
|
||||
package node
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/module"
|
||||
)
|
||||
|
||||
type Module struct {
|
||||
db *pgxpool.Pool
|
||||
}
|
||||
|
||||
type Node struct {
|
||||
ID string `json:"id"`
|
||||
OwnerOrganizationID *string `json:"owner_organization_id,omitempty"`
|
||||
NodeKey string `json:"node_key"`
|
||||
Name string `json:"name"`
|
||||
OwnershipType string `json:"ownership_type"`
|
||||
RegistrationStatus string `json:"registration_status"`
|
||||
HealthStatus string `json:"health_status"`
|
||||
VersionState string `json:"version_state"`
|
||||
PartitionState string `json:"partition_state"`
|
||||
DesiredVersion *string `json:"desired_version,omitempty"`
|
||||
ReportedVersion *string `json:"reported_version,omitempty"`
|
||||
LastSeenAt *time.Time `json:"last_seen_at,omitempty"`
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type NodeCapability struct {
|
||||
NodeID string `json:"node_id"`
|
||||
Capability string `json:"capability"`
|
||||
Value json.RawMessage `json:"value"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type NodeService struct {
|
||||
NodeID string `json:"node_id"`
|
||||
ServiceType string `json:"service_type"`
|
||||
Enabled bool `json:"enabled"`
|
||||
DesiredState string `json:"desired_state"`
|
||||
ReportedState string `json:"reported_state"`
|
||||
LastReportedAt *time.Time `json:"last_reported_at,omitempty"`
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type NodeUpdatePolicy struct {
|
||||
NodeID string `json:"node_id"`
|
||||
Mode string `json:"mode"`
|
||||
Channel string `json:"channel"`
|
||||
MaintenanceWindow json.RawMessage `json:"maintenance_window"`
|
||||
Canary bool `json:"canary"`
|
||||
AutomaticRollout bool `json:"automatic_rollout"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type NodePartitionState struct {
|
||||
NodeID string `json:"node_id"`
|
||||
ClusterState string `json:"cluster_state"`
|
||||
RecoveryMode string `json:"recovery_mode"`
|
||||
Notes *string `json:"notes,omitempty"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type upsertNodeRequest struct {
|
||||
ActorUserID string `json:"actor_user_id"`
|
||||
OwnerOrganizationID *string `json:"owner_organization_id"`
|
||||
NodeKey string `json:"node_key"`
|
||||
Name string `json:"name"`
|
||||
OwnershipType string `json:"ownership_type"`
|
||||
DesiredVersion *string `json:"desired_version"`
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
Capabilities []struct {
|
||||
Capability string `json:"capability"`
|
||||
Value json.RawMessage `json:"value"`
|
||||
} `json:"capabilities"`
|
||||
Services []struct {
|
||||
ServiceType string `json:"service_type"`
|
||||
Enabled bool `json:"enabled"`
|
||||
DesiredState string `json:"desired_state"`
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
} `json:"services"`
|
||||
UpdatePolicy struct {
|
||||
Mode string `json:"mode"`
|
||||
Channel string `json:"channel"`
|
||||
MaintenanceWindow json.RawMessage `json:"maintenance_window"`
|
||||
Canary bool `json:"canary"`
|
||||
AutomaticRollout bool `json:"automatic_rollout"`
|
||||
} `json:"update_policy"`
|
||||
PartitionState struct {
|
||||
ClusterState string `json:"cluster_state"`
|
||||
RecoveryMode string `json:"recovery_mode"`
|
||||
Notes *string `json:"notes"`
|
||||
} `json:"partition_state"`
|
||||
}
|
||||
|
||||
func NewModule(deps module.Dependencies) *Module {
|
||||
return &Module{db: deps.Infra.DB}
|
||||
}
|
||||
|
||||
func (m *Module) Name() string {
|
||||
return "node"
|
||||
}
|
||||
|
||||
func (m *Module) RegisterRoutes(router chi.Router) {
|
||||
router.Route("/nodes", func(r chi.Router) {
|
||||
r.Get("/", m.listNodes)
|
||||
r.Post("/", m.createNode)
|
||||
r.Get("/{nodeID}", m.getNode)
|
||||
r.Put("/{nodeID}", m.updateNode)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) listNodes(w http.ResponseWriter, r *http.Request) {
|
||||
rows, err := m.db.Query(r.Context(), `
|
||||
SELECT id, owner_organization_id, node_key, name, ownership_type, registration_status, health_status,
|
||||
version_state, partition_state, desired_version, reported_version, last_seen_at, metadata, created_at, updated_at
|
||||
FROM nodes
|
||||
ORDER BY created_at DESC
|
||||
`)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []Node
|
||||
for rows.Next() {
|
||||
item, err := scanNode(rows)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{"nodes": items})
|
||||
}
|
||||
|
||||
func (m *Module) getNode(w http.ResponseWriter, r *http.Request) {
|
||||
nodeID := chi.URLParam(r, "nodeID")
|
||||
item, err := m.getNodeByID(r.Context(), nodeID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
httpx.WriteError(w, http.StatusNotFound, "node not found")
|
||||
return
|
||||
}
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
caps, _ := m.listCapabilities(r.Context(), nodeID)
|
||||
services, _ := m.listServices(r.Context(), nodeID)
|
||||
updatePolicy, _ := m.getUpdatePolicy(r.Context(), nodeID)
|
||||
partitionState, _ := m.getPartitionState(r.Context(), nodeID)
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{
|
||||
"node": item,
|
||||
"capabilities": caps,
|
||||
"services": services,
|
||||
"update_policy": updatePolicy,
|
||||
"partition_state": partitionState,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) createNode(w http.ResponseWriter, r *http.Request) {
|
||||
req, err := decodeNodeRequest(r)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
item := Node{
|
||||
ID: uuid.NewString(),
|
||||
OwnerOrganizationID: req.OwnerOrganizationID,
|
||||
NodeKey: req.NodeKey,
|
||||
Name: req.Name,
|
||||
OwnershipType: req.OwnershipType,
|
||||
RegistrationStatus: "pending",
|
||||
HealthStatus: "unknown",
|
||||
VersionState: "unknown",
|
||||
PartitionState: "healthy",
|
||||
DesiredVersion: req.DesiredVersion,
|
||||
Metadata: req.Metadata,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
}
|
||||
if err := m.persistNode(r.Context(), item, req, true); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusCreated, map[string]any{"node": item})
|
||||
}
|
||||
|
||||
func (m *Module) updateNode(w http.ResponseWriter, r *http.Request) {
|
||||
req, err := decodeNodeRequest(r)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
nodeID := chi.URLParam(r, "nodeID")
|
||||
item, err := m.getNodeByID(r.Context(), nodeID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
httpx.WriteError(w, http.StatusNotFound, "node not found")
|
||||
return
|
||||
}
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
item.OwnerOrganizationID = req.OwnerOrganizationID
|
||||
item.NodeKey = req.NodeKey
|
||||
item.Name = req.Name
|
||||
item.OwnershipType = req.OwnershipType
|
||||
item.DesiredVersion = req.DesiredVersion
|
||||
item.Metadata = req.Metadata
|
||||
item.UpdatedAt = time.Now().UTC()
|
||||
if err := m.persistNode(r.Context(), item, req, false); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{"node": item})
|
||||
}
|
||||
|
||||
func decodeNodeRequest(r *http.Request) (*upsertNodeRequest, error) {
|
||||
var req upsertNodeRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.New("invalid node payload")
|
||||
}
|
||||
if req.ActorUserID == "" || req.NodeKey == "" || req.Name == "" || req.OwnershipType == "" {
|
||||
return nil, errors.New("actor_user_id, node_key, name, and ownership_type are required")
|
||||
}
|
||||
if len(req.Metadata) == 0 {
|
||||
req.Metadata = json.RawMessage(`{}`)
|
||||
}
|
||||
if !json.Valid(req.Metadata) {
|
||||
return nil, errors.New("metadata must be valid json")
|
||||
}
|
||||
if req.UpdatePolicy.Mode == "" {
|
||||
req.UpdatePolicy.Mode = "manual"
|
||||
}
|
||||
if req.UpdatePolicy.Channel == "" {
|
||||
req.UpdatePolicy.Channel = "stable"
|
||||
}
|
||||
if len(req.UpdatePolicy.MaintenanceWindow) == 0 {
|
||||
req.UpdatePolicy.MaintenanceWindow = json.RawMessage(`{}`)
|
||||
}
|
||||
if req.PartitionState.ClusterState == "" {
|
||||
req.PartitionState.ClusterState = "healthy"
|
||||
}
|
||||
if req.PartitionState.RecoveryMode == "" {
|
||||
req.PartitionState.RecoveryMode = "normal"
|
||||
}
|
||||
for i := range req.Capabilities {
|
||||
if len(req.Capabilities[i].Value) == 0 {
|
||||
req.Capabilities[i].Value = json.RawMessage(`{}`)
|
||||
}
|
||||
}
|
||||
for i := range req.Services {
|
||||
if req.Services[i].DesiredState == "" {
|
||||
req.Services[i].DesiredState = "disabled"
|
||||
}
|
||||
if len(req.Services[i].Metadata) == 0 {
|
||||
req.Services[i].Metadata = json.RawMessage(`{}`)
|
||||
}
|
||||
}
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
func (m *Module) persistNode(ctx context.Context, item Node, req *upsertNodeRequest, create bool) error {
|
||||
tx, err := m.db.Begin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
if create {
|
||||
_, err = tx.Exec(ctx, `
|
||||
INSERT INTO nodes (
|
||||
id, owner_organization_id, node_key, name, ownership_type, registration_status, health_status,
|
||||
version_state, partition_state, desired_version, reported_version, last_seen_at, metadata, created_at, updated_at
|
||||
) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13::jsonb,$14,$15)
|
||||
`, item.ID, item.OwnerOrganizationID, item.NodeKey, item.Name, item.OwnershipType, item.RegistrationStatus, item.HealthStatus, item.VersionState, item.PartitionState, item.DesiredVersion, item.ReportedVersion, item.LastSeenAt, []byte(item.Metadata), item.CreatedAt, item.UpdatedAt)
|
||||
} else {
|
||||
_, err = tx.Exec(ctx, `
|
||||
UPDATE nodes
|
||||
SET owner_organization_id=$2, node_key=$3, name=$4, ownership_type=$5, desired_version=$6, metadata=$7::jsonb, updated_at=$8
|
||||
WHERE id=$1
|
||||
`, item.ID, item.OwnerOrganizationID, item.NodeKey, item.Name, item.OwnershipType, item.DesiredVersion, []byte(item.Metadata), item.UpdatedAt)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(ctx, `DELETE FROM node_capabilities WHERE node_id = $1`, item.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, capability := range req.Capabilities {
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO node_capabilities (node_id, capability, value, updated_at)
|
||||
VALUES ($1, $2, $3::jsonb, $4)
|
||||
`, item.ID, capability.Capability, []byte(capability.Value), time.Now().UTC()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := tx.Exec(ctx, `DELETE FROM node_services WHERE node_id = $1`, item.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, service := range req.Services {
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO node_services (
|
||||
node_id, service_type, enabled, desired_state, reported_state, last_reported_at, metadata, updated_at
|
||||
) VALUES ($1, $2, $3, $4, 'unknown', NULL, $5::jsonb, $6)
|
||||
`, item.ID, service.ServiceType, service.Enabled, service.DesiredState, []byte(service.Metadata), time.Now().UTC()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO node_update_policies (
|
||||
node_id, mode, channel, maintenance_window, canary, automatic_rollout, created_at, updated_at
|
||||
) VALUES ($1,$2,$3,$4::jsonb,$5,$6,$7,$8)
|
||||
ON CONFLICT (node_id) DO UPDATE SET
|
||||
mode = EXCLUDED.mode,
|
||||
channel = EXCLUDED.channel,
|
||||
maintenance_window = EXCLUDED.maintenance_window,
|
||||
canary = EXCLUDED.canary,
|
||||
automatic_rollout = EXCLUDED.automatic_rollout,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
`, item.ID, req.UpdatePolicy.Mode, req.UpdatePolicy.Channel, []byte(req.UpdatePolicy.MaintenanceWindow), req.UpdatePolicy.Canary, req.UpdatePolicy.AutomaticRollout, time.Now().UTC(), time.Now().UTC()); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO node_partition_states (node_id, cluster_state, recovery_mode, notes, updated_at)
|
||||
VALUES ($1,$2,$3,$4,$5)
|
||||
ON CONFLICT (node_id) DO UPDATE SET
|
||||
cluster_state = EXCLUDED.cluster_state,
|
||||
recovery_mode = EXCLUDED.recovery_mode,
|
||||
notes = EXCLUDED.notes,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
`, item.ID, req.PartitionState.ClusterState, req.PartitionState.RecoveryMode, req.PartitionState.Notes, time.Now().UTC()); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (m *Module) getNodeByID(ctx context.Context, nodeID string) (Node, error) {
|
||||
row := m.db.QueryRow(ctx, `
|
||||
SELECT id, owner_organization_id, node_key, name, ownership_type, registration_status, health_status,
|
||||
version_state, partition_state, desired_version, reported_version, last_seen_at, metadata, created_at, updated_at
|
||||
FROM nodes
|
||||
WHERE id = $1
|
||||
`, nodeID)
|
||||
return scanNode(row)
|
||||
}
|
||||
|
||||
func (m *Module) listCapabilities(ctx context.Context, nodeID string) ([]NodeCapability, error) {
|
||||
rows, err := m.db.Query(ctx, `SELECT node_id, capability, value, updated_at FROM node_capabilities WHERE node_id = $1 ORDER BY capability`, nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []NodeCapability
|
||||
for rows.Next() {
|
||||
var item NodeCapability
|
||||
if err := rows.Scan(&item.NodeID, &item.Capability, &item.Value, &item.UpdatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, item)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (m *Module) listServices(ctx context.Context, nodeID string) ([]NodeService, error) {
|
||||
rows, err := m.db.Query(ctx, `
|
||||
SELECT node_id, service_type, enabled, desired_state, reported_state, last_reported_at, metadata, updated_at
|
||||
FROM node_services WHERE node_id = $1 ORDER BY service_type
|
||||
`, nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []NodeService
|
||||
for rows.Next() {
|
||||
var item NodeService
|
||||
if err := rows.Scan(&item.NodeID, &item.ServiceType, &item.Enabled, &item.DesiredState, &item.ReportedState, &item.LastReportedAt, &item.Metadata, &item.UpdatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, item)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (m *Module) getUpdatePolicy(ctx context.Context, nodeID string) (*NodeUpdatePolicy, error) {
|
||||
row := m.db.QueryRow(ctx, `
|
||||
SELECT node_id, mode, channel, maintenance_window, canary, automatic_rollout, created_at, updated_at
|
||||
FROM node_update_policies WHERE node_id = $1
|
||||
`, nodeID)
|
||||
var item NodeUpdatePolicy
|
||||
if err := row.Scan(&item.NodeID, &item.Mode, &item.Channel, &item.MaintenanceWindow, &item.Canary, &item.AutomaticRollout, &item.CreatedAt, &item.UpdatedAt); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &item, nil
|
||||
}
|
||||
|
||||
func (m *Module) getPartitionState(ctx context.Context, nodeID string) (*NodePartitionState, error) {
|
||||
row := m.db.QueryRow(ctx, `
|
||||
SELECT node_id, cluster_state, recovery_mode, notes, updated_at
|
||||
FROM node_partition_states WHERE node_id = $1
|
||||
`, nodeID)
|
||||
var item NodePartitionState
|
||||
if err := row.Scan(&item.NodeID, &item.ClusterState, &item.RecoveryMode, &item.Notes, &item.UpdatedAt); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &item, nil
|
||||
}
|
||||
|
||||
type rowScanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func scanNode(row rowScanner) (Node, error) {
|
||||
var item Node
|
||||
if err := row.Scan(
|
||||
&item.ID,
|
||||
&item.OwnerOrganizationID,
|
||||
&item.NodeKey,
|
||||
&item.Name,
|
||||
&item.OwnershipType,
|
||||
&item.RegistrationStatus,
|
||||
&item.HealthStatus,
|
||||
&item.VersionState,
|
||||
&item.PartitionState,
|
||||
&item.DesiredVersion,
|
||||
&item.ReportedVersion,
|
||||
&item.LastSeenAt,
|
||||
&item.Metadata,
|
||||
&item.CreatedAt,
|
||||
&item.UpdatedAt,
|
||||
); err != nil {
|
||||
return Node{}, err
|
||||
}
|
||||
if len(item.Metadata) == 0 {
|
||||
item.Metadata = json.RawMessage(`{}`)
|
||||
}
|
||||
return item, nil
|
||||
}
|
||||
@@ -0,0 +1,356 @@
|
||||
package nodeagent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
clustermodule "github.com/example/remote-access-platform/backend/internal/modules/cluster"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/module"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
|
||||
)
|
||||
|
||||
type Module struct {
|
||||
db *pgxpool.Pool
|
||||
cluster *clustermodule.Service
|
||||
}
|
||||
|
||||
func NewModule(deps module.Dependencies) *Module {
|
||||
clusterStore := clustermodule.NewPostgresStore(deps.Infra.DB)
|
||||
if deps.Config.Secret.EncryptionKeyBase64 != "" {
|
||||
if encryptor, err := secrets.NewEncryptor(deps.Config.Secret.EncryptionKeyBase64, deps.Config.Secret.EncryptionKeyID); err == nil {
|
||||
clusterStore.WithClusterKeyEncryptor(encryptor)
|
||||
}
|
||||
}
|
||||
return &Module{
|
||||
db: deps.Infra.DB,
|
||||
cluster: clustermodule.NewService(clusterStore),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Module) Name() string {
|
||||
return "nodeagent"
|
||||
}
|
||||
|
||||
func (m *Module) RegisterRoutes(router chi.Router) {
|
||||
router.Route("/node-agents", func(r chi.Router) {
|
||||
r.Post("/enroll", m.enrollAgent)
|
||||
r.Post("/enrollments/{requestID}/bootstrap", m.bootstrapEnrollment)
|
||||
r.Post("/register", m.registerAgent)
|
||||
r.Post("/{nodeID}/health", m.reportHealth)
|
||||
r.Post("/{nodeID}/services/status", m.reportServiceStatus)
|
||||
r.Post("/{nodeID}/update-manifest/request", m.requestUpdateManifest)
|
||||
r.Post("/{nodeID}/update-result", m.acknowledgeUpdateResult)
|
||||
r.Post("/{nodeID}/rollback-result", m.reportRollbackResult)
|
||||
r.Get("/{nodeID}/clusters/{clusterID}/vpn-assignments/desired", m.listVPNAssignments)
|
||||
r.Post("/{nodeID}/clusters/{clusterID}/vpn-assignments/{vpnConnectionID}/status", m.reportVPNAssignmentStatus)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) enrollAgent(w http.ResponseWriter, r *http.Request) {
|
||||
var payload struct {
|
||||
ClusterID string `json:"cluster_id"`
|
||||
JoinToken string `json:"join_token"`
|
||||
NodeName string `json:"node_name"`
|
||||
NodeFingerprint string `json:"node_fingerprint"`
|
||||
PublicKey string `json:"public_key"`
|
||||
ReportedCapabilities json.RawMessage `json:"reported_capabilities"`
|
||||
ReportedFacts json.RawMessage `json:"reported_facts"`
|
||||
RequestedRoles json.RawMessage `json:"requested_roles"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid agent enrollment payload")
|
||||
return
|
||||
}
|
||||
joinRequest, err := m.cluster.CreateJoinRequest(r.Context(), clustermodule.CreateJoinRequestInput{
|
||||
ClusterID: payload.ClusterID,
|
||||
JoinToken: payload.JoinToken,
|
||||
NodeName: payload.NodeName,
|
||||
NodeFingerprint: payload.NodeFingerprint,
|
||||
PublicKey: payload.PublicKey,
|
||||
ReportedCapabilities: payload.ReportedCapabilities,
|
||||
ReportedFacts: payload.ReportedFacts,
|
||||
RequestedRoles: payload.RequestedRoles,
|
||||
})
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
|
||||
"status": "pending_approval",
|
||||
"join_request": joinRequest,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) bootstrapEnrollment(w http.ResponseWriter, r *http.Request) {
|
||||
var payload struct {
|
||||
ClusterID string `json:"cluster_id"`
|
||||
NodeFingerprint string `json:"node_fingerprint"`
|
||||
PublicKey string `json:"public_key"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid enrollment bootstrap payload")
|
||||
return
|
||||
}
|
||||
result, err := m.cluster.GetJoinRequestBootstrap(r.Context(), clustermodule.GetJoinRequestBootstrapInput{
|
||||
ClusterID: payload.ClusterID,
|
||||
JoinRequestID: chi.URLParam(r, "requestID"),
|
||||
NodeFingerprint: payload.NodeFingerprint,
|
||||
PublicKey: payload.PublicKey,
|
||||
})
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (m *Module) registerAgent(w http.ResponseWriter, r *http.Request) {
|
||||
var payload struct {
|
||||
NodeKey string `json:"node_key"`
|
||||
Name string `json:"name"`
|
||||
OwnershipType string `json:"ownership_type"`
|
||||
OwnerOrganizationID *string `json:"owner_organization_id"`
|
||||
ReportedVersion *string `json:"reported_version"`
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid agent registration payload")
|
||||
return
|
||||
}
|
||||
if payload.NodeKey == "" || payload.Name == "" || payload.OwnershipType == "" {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "node_key, name, and ownership_type are required")
|
||||
return
|
||||
}
|
||||
if len(payload.Metadata) == 0 {
|
||||
payload.Metadata = json.RawMessage(`{}`)
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
nodeID := uuid.NewString()
|
||||
if err := m.db.QueryRow(r.Context(), `
|
||||
INSERT INTO nodes (
|
||||
id, owner_organization_id, node_key, name, ownership_type, registration_status, health_status,
|
||||
version_state, partition_state, desired_version, reported_version, last_seen_at, metadata, created_at, updated_at
|
||||
) VALUES ($1, $2, $3, $4, $5, 'active', 'unknown', 'unknown', 'healthy', NULL, $6, $7, $8::jsonb, $9, $10)
|
||||
ON CONFLICT (node_key) DO UPDATE SET
|
||||
name = EXCLUDED.name,
|
||||
ownership_type = EXCLUDED.ownership_type,
|
||||
owner_organization_id = EXCLUDED.owner_organization_id,
|
||||
registration_status = 'active',
|
||||
reported_version = EXCLUDED.reported_version,
|
||||
last_seen_at = EXCLUDED.last_seen_at,
|
||||
metadata = EXCLUDED.metadata,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
RETURNING id
|
||||
`, nodeID, payload.OwnerOrganizationID, payload.NodeKey, payload.Name, payload.OwnershipType, payload.ReportedVersion, now, []byte(payload.Metadata), now, now).Scan(&nodeID); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{
|
||||
"node_id": nodeID,
|
||||
"status": "registered",
|
||||
"legacy": true,
|
||||
"warning": "direct node-agent registration is retained for compatibility; production enrollment must use /node-agents/enroll",
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) reportHealth(w http.ResponseWriter, r *http.Request) {
|
||||
var payload struct {
|
||||
HealthStatus string `json:"health_status"`
|
||||
ReportedVersion *string `json:"reported_version"`
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid node health payload")
|
||||
return
|
||||
}
|
||||
if payload.HealthStatus == "" {
|
||||
payload.HealthStatus = "unknown"
|
||||
}
|
||||
if len(payload.Metadata) == 0 {
|
||||
payload.Metadata = json.RawMessage(`{}`)
|
||||
}
|
||||
if _, err := m.db.Exec(r.Context(), `
|
||||
UPDATE nodes
|
||||
SET health_status = $2, reported_version = COALESCE($3, reported_version), last_seen_at = $4, metadata = $5::jsonb, updated_at = $4
|
||||
WHERE id = $1
|
||||
`, chi.URLParam(r, "nodeID"), payload.HealthStatus, payload.ReportedVersion, time.Now().UTC(), []byte(payload.Metadata)); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{"status": "accepted"})
|
||||
}
|
||||
|
||||
func (m *Module) reportServiceStatus(w http.ResponseWriter, r *http.Request) {
|
||||
var payload struct {
|
||||
Services []struct {
|
||||
ServiceType string `json:"service_type"`
|
||||
ReportedState string `json:"reported_state"`
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
} `json:"services"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid node service status payload")
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
for _, service := range payload.Services {
|
||||
if len(service.Metadata) == 0 {
|
||||
service.Metadata = json.RawMessage(`{}`)
|
||||
}
|
||||
if _, err := m.db.Exec(r.Context(), `
|
||||
INSERT INTO node_services (
|
||||
node_id, service_type, enabled, desired_state, reported_state, last_reported_at, metadata, updated_at
|
||||
) VALUES ($1, $2, FALSE, 'disabled', $3, $4, $5::jsonb, $4)
|
||||
ON CONFLICT (node_id, service_type) DO UPDATE SET
|
||||
reported_state = EXCLUDED.reported_state,
|
||||
last_reported_at = EXCLUDED.last_reported_at,
|
||||
metadata = EXCLUDED.metadata,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
`, chi.URLParam(r, "nodeID"), service.ServiceType, service.ReportedState, now, []byte(service.Metadata)); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{"status": "accepted"})
|
||||
}
|
||||
|
||||
func (m *Module) listVPNAssignments(w http.ResponseWriter, r *http.Request) {
|
||||
items, err := m.cluster.ListNodeVPNAssignments(r.Context(), chi.URLParam(r, "clusterID"), chi.URLParam(r, "nodeID"))
|
||||
if writeClusterServiceError(w, err) {
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{
|
||||
"vpn_assignments": items,
|
||||
"runtime_execution_enabled": false,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) reportVPNAssignmentStatus(w http.ResponseWriter, r *http.Request) {
|
||||
var payload struct {
|
||||
ObservedStatus string `json:"observed_status"`
|
||||
StatusPayload json.RawMessage `json:"status_payload"`
|
||||
ObservedAt *time.Time `json:"observed_at"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid vpn assignment status payload")
|
||||
return
|
||||
}
|
||||
observedAt := time.Time{}
|
||||
if payload.ObservedAt != nil {
|
||||
observedAt = *payload.ObservedAt
|
||||
}
|
||||
item, err := m.cluster.ReportNodeVPNAssignmentStatus(r.Context(), clustermodule.ReportNodeVPNAssignmentStatusInput{
|
||||
ClusterID: chi.URLParam(r, "clusterID"),
|
||||
NodeID: chi.URLParam(r, "nodeID"),
|
||||
VPNConnectionID: chi.URLParam(r, "vpnConnectionID"),
|
||||
ObservedStatus: payload.ObservedStatus,
|
||||
StatusPayload: payload.StatusPayload,
|
||||
ObservedAt: observedAt,
|
||||
})
|
||||
if writeClusterServiceError(w, err) {
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
|
||||
"vpn_assignment_status": item,
|
||||
"runtime_execution_enabled": false,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) requestUpdateManifest(w http.ResponseWriter, r *http.Request) {
|
||||
nodeID := chi.URLParam(r, "nodeID")
|
||||
var mode, channel string
|
||||
var canary, automatic bool
|
||||
var desiredVersion *string
|
||||
if err := m.db.QueryRow(r.Context(), `
|
||||
SELECT n.desired_version, p.mode, p.channel, p.canary, p.automatic_rollout
|
||||
FROM nodes n
|
||||
LEFT JOIN node_update_policies p ON p.node_id = n.id
|
||||
WHERE n.id = $1
|
||||
`, nodeID).Scan(&desiredVersion, &mode, &channel, &canary, &automatic); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{
|
||||
"manifest": map[string]any{
|
||||
"node_id": nodeID,
|
||||
"desired_version": desiredVersion,
|
||||
"mode": mode,
|
||||
"channel": channel,
|
||||
"canary": canary,
|
||||
"automatic_rollout": automatic,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) acknowledgeUpdateResult(w http.ResponseWriter, r *http.Request) {
|
||||
m.recordUpdateRun(w, r, "update")
|
||||
}
|
||||
|
||||
func (m *Module) reportRollbackResult(w http.ResponseWriter, r *http.Request) {
|
||||
m.recordUpdateRun(w, r, "rollback")
|
||||
}
|
||||
|
||||
func (m *Module) recordUpdateRun(w http.ResponseWriter, r *http.Request, action string) {
|
||||
var payload struct {
|
||||
TargetVersion string `json:"target_version"`
|
||||
Status string `json:"status"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid update result payload")
|
||||
return
|
||||
}
|
||||
if payload.Status == "" {
|
||||
payload.Status = "acknowledged"
|
||||
}
|
||||
if len(payload.Payload) == 0 {
|
||||
payload.Payload = json.RawMessage(`{}`)
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
runID := uuid.NewString()
|
||||
if _, err := m.db.Exec(r.Context(), `
|
||||
INSERT INTO node_agent_update_runs (
|
||||
id, node_id, action, target_version, status, requested_at, acknowledged_at, completed_at, payload
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $6, CASE WHEN $5 IN ('succeeded', 'failed') THEN $6 ELSE NULL END, $7::jsonb)
|
||||
`, runID, chi.URLParam(r, "nodeID"), action, payload.TargetVersion, payload.Status, now, []byte(payload.Payload)); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if action == "update" && payload.Status == "succeeded" {
|
||||
_, _ = m.db.Exec(r.Context(), `
|
||||
UPDATE nodes
|
||||
SET reported_version = $2, version_state = 'current', updated_at = $3
|
||||
WHERE id = $1
|
||||
`, chi.URLParam(r, "nodeID"), payload.TargetVersion, now)
|
||||
}
|
||||
if action == "rollback" && payload.Status == "succeeded" {
|
||||
_, _ = m.db.Exec(r.Context(), `
|
||||
UPDATE nodes
|
||||
SET reported_version = $2, version_state = 'rollback', updated_at = $3
|
||||
WHERE id = $1
|
||||
`, chi.URLParam(r, "nodeID"), payload.TargetVersion, now)
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{"status": "accepted", "run_id": runID})
|
||||
}
|
||||
|
||||
func writeClusterServiceError(w http.ResponseWriter, err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
switch {
|
||||
case errors.Is(err, clustermodule.ErrVPNLeaseOwnerNotAllowed), errors.Is(err, clustermodule.ErrVPNLeaseOwnerRoleRequired):
|
||||
httpx.WriteError(w, http.StatusForbidden, err.Error())
|
||||
case errors.Is(err, clustermodule.ErrInvalidPayload):
|
||||
httpx.WriteError(w, http.StatusBadRequest, err.Error())
|
||||
default:
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package organization
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestTenantSafeTopologyExposureDoesNotExposeCoreMesh(t *testing.T) {
|
||||
value := tenantSafeTopologyExposure()
|
||||
forbidden := []string{
|
||||
"core_node_id",
|
||||
"mesh_route",
|
||||
"cluster_private_topology",
|
||||
"certificate_serial",
|
||||
}
|
||||
for _, token := range forbidden {
|
||||
if value == token {
|
||||
t.Fatalf("topology exposure leaked forbidden token %q", token)
|
||||
}
|
||||
}
|
||||
if value != "tenant_safe_no_core_mesh_topology" {
|
||||
t.Fatalf("unexpected topology exposure marker: %q", value)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,518 @@
|
||||
package organization
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/authority"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/module"
|
||||
)
|
||||
|
||||
const (
|
||||
RoleOrgOwner = "org_owner"
|
||||
RoleOrgAdmin = "org_admin"
|
||||
RoleOrgOperator = "org_operator"
|
||||
RoleOrgMember = "org_member"
|
||||
RoleOrgViewer = "org_viewer"
|
||||
)
|
||||
|
||||
type Module struct {
|
||||
db *pgxpool.Pool
|
||||
authority *authority.Verifier
|
||||
}
|
||||
|
||||
type Organization struct {
|
||||
ID string `json:"id"`
|
||||
Slug string `json:"slug"`
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type Membership struct {
|
||||
ID string `json:"id"`
|
||||
OrganizationID string `json:"organization_id"`
|
||||
UserID string `json:"user_id"`
|
||||
RoleID string `json:"role_id"`
|
||||
Status string `json:"status"`
|
||||
InvitedByUser *string `json:"invited_by_user_id,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type AdminSummary struct {
|
||||
OrganizationID string `json:"organization_id"`
|
||||
ResourceCount int64 `json:"resource_count"`
|
||||
ActiveSessionCount int64 `json:"active_session_count"`
|
||||
ServiceEndpoints []ServiceSummary `json:"service_endpoints"`
|
||||
ConnectorStatus map[string]any `json:"connector_status"`
|
||||
RecentAudit []OrgAuditEvent `json:"recent_audit"`
|
||||
TopologyExposure string `json:"topology_exposure"`
|
||||
}
|
||||
|
||||
type ServiceSummary struct {
|
||||
Protocol string `json:"protocol"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
type OrgAuditEvent struct {
|
||||
ID string `json:"id"`
|
||||
EventType string `json:"event_type"`
|
||||
TargetType string `json:"target_type"`
|
||||
TargetID string `json:"target_id"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type createOrganizationRequest struct {
|
||||
ActorUserID string `json:"actor_user_id"`
|
||||
Slug string `json:"slug"`
|
||||
Name string `json:"name"`
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
}
|
||||
|
||||
type addMembershipRequest struct {
|
||||
ActorUserID string `json:"actor_user_id"`
|
||||
UserID string `json:"user_id"`
|
||||
RoleID string `json:"role_id"`
|
||||
}
|
||||
|
||||
func NewModule(deps module.Dependencies) *Module {
|
||||
authorityVerifier, _ := authority.NewVerifier(deps.Config.Installation)
|
||||
return &Module{db: deps.Infra.DB, authority: authorityVerifier}
|
||||
}
|
||||
|
||||
func (m *Module) Name() string {
|
||||
return "organization"
|
||||
}
|
||||
|
||||
func (m *Module) RegisterRoutes(router chi.Router) {
|
||||
router.Route("/organizations", func(r chi.Router) {
|
||||
r.Get("/", m.listOrganizations)
|
||||
r.Post("/", m.createOrganization)
|
||||
r.Get("/{organizationID}", m.getOrganization)
|
||||
r.Get("/{organizationID}/admin-summary", m.getAdminSummary)
|
||||
r.Get("/{organizationID}/memberships", m.listMemberships)
|
||||
r.Post("/{organizationID}/memberships", m.addMembership)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) listOrganizations(w http.ResponseWriter, r *http.Request) {
|
||||
userID := r.URL.Query().Get("user_id")
|
||||
if userID == "" {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
|
||||
return
|
||||
}
|
||||
platformRole, err := m.getPlatformRole(r.Context(), userID)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
var rows pgx.Rows
|
||||
if isPlatformAdmin(platformRole) {
|
||||
rows, err = m.db.Query(r.Context(), `
|
||||
SELECT id, slug, name, status, metadata, created_at, updated_at
|
||||
FROM organizations
|
||||
ORDER BY created_at DESC
|
||||
`)
|
||||
} else {
|
||||
rows, err = m.db.Query(r.Context(), `
|
||||
SELECT o.id, o.slug, o.name, o.status, o.metadata, o.created_at, o.updated_at
|
||||
FROM organizations o
|
||||
INNER JOIN organization_memberships om ON om.organization_id = o.id
|
||||
WHERE om.user_id = $1 AND om.status = 'active'
|
||||
ORDER BY o.created_at DESC
|
||||
`, userID)
|
||||
}
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
var organizations []Organization
|
||||
for rows.Next() {
|
||||
org, err := scanOrganization(rows)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
organizations = append(organizations, org)
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{"organizations": organizations})
|
||||
}
|
||||
|
||||
func (m *Module) getOrganization(w http.ResponseWriter, r *http.Request) {
|
||||
orgID := chi.URLParam(r, "organizationID")
|
||||
userID := r.URL.Query().Get("user_id")
|
||||
if userID == "" {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
|
||||
return
|
||||
}
|
||||
if err := m.ensureOrgAccess(r.Context(), orgID, userID, false); err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if errors.Is(err, pgx.ErrNoRows) || errors.Is(err, errForbidden) {
|
||||
status = http.StatusForbidden
|
||||
}
|
||||
httpx.WriteError(w, status, err.Error())
|
||||
return
|
||||
}
|
||||
org, err := m.getOrganizationByID(r.Context(), orgID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
httpx.WriteError(w, http.StatusNotFound, "organization not found")
|
||||
return
|
||||
}
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{"organization": org})
|
||||
}
|
||||
|
||||
func (m *Module) getAdminSummary(w http.ResponseWriter, r *http.Request) {
|
||||
orgID := chi.URLParam(r, "organizationID")
|
||||
actorUserID := r.URL.Query().Get("actor_user_id")
|
||||
if actorUserID == "" {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "actor_user_id is required")
|
||||
return
|
||||
}
|
||||
if err := m.ensureOrgAccess(r.Context(), orgID, actorUserID, true); err != nil {
|
||||
httpx.WriteError(w, http.StatusForbidden, err.Error())
|
||||
return
|
||||
}
|
||||
summary, err := m.loadAdminSummary(r.Context(), orgID)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{"admin_summary": summary})
|
||||
}
|
||||
|
||||
func (m *Module) loadAdminSummary(ctx context.Context, orgID string) (AdminSummary, error) {
|
||||
var resourceCount int64
|
||||
if err := m.db.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM resources
|
||||
WHERE organization_id = $1::uuid
|
||||
`, orgID).Scan(&resourceCount); err != nil {
|
||||
return AdminSummary{}, err
|
||||
}
|
||||
|
||||
var activeSessionCount int64
|
||||
if err := m.db.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM remote_sessions
|
||||
WHERE organization_id = $1::uuid
|
||||
AND state = 'active'
|
||||
`, orgID).Scan(&activeSessionCount); err != nil {
|
||||
return AdminSummary{}, err
|
||||
}
|
||||
|
||||
rows, err := m.db.Query(ctx, `
|
||||
SELECT protocol, COUNT(*)
|
||||
FROM resources
|
||||
WHERE organization_id = $1::uuid
|
||||
GROUP BY protocol
|
||||
ORDER BY protocol
|
||||
`, orgID)
|
||||
if err != nil {
|
||||
return AdminSummary{}, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var services []ServiceSummary
|
||||
for rows.Next() {
|
||||
var item ServiceSummary
|
||||
if err := rows.Scan(&item.Protocol, &item.Count); err != nil {
|
||||
return AdminSummary{}, err
|
||||
}
|
||||
services = append(services, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return AdminSummary{}, err
|
||||
}
|
||||
|
||||
auditRows, err := m.db.Query(ctx, `
|
||||
SELECT ae.id::text, ae.event_type, ae.target_type, ae.target_id, ae.payload, ae.created_at
|
||||
FROM audit_events ae
|
||||
LEFT JOIN remote_sessions rs ON rs.id = ae.remote_session_id
|
||||
WHERE rs.organization_id = $1::uuid
|
||||
ORDER BY ae.created_at DESC
|
||||
LIMIT 20
|
||||
`, orgID)
|
||||
if err != nil {
|
||||
return AdminSummary{}, err
|
||||
}
|
||||
defer auditRows.Close()
|
||||
var audit []OrgAuditEvent
|
||||
for auditRows.Next() {
|
||||
var item OrgAuditEvent
|
||||
if err := auditRows.Scan(&item.ID, &item.EventType, &item.TargetType, &item.TargetID, &item.Payload, &item.CreatedAt); err != nil {
|
||||
return AdminSummary{}, err
|
||||
}
|
||||
audit = append(audit, item)
|
||||
}
|
||||
if err := auditRows.Err(); err != nil {
|
||||
return AdminSummary{}, err
|
||||
}
|
||||
|
||||
return AdminSummary{
|
||||
OrganizationID: orgID,
|
||||
ResourceCount: resourceCount,
|
||||
ActiveSessionCount: activeSessionCount,
|
||||
ServiceEndpoints: services,
|
||||
ConnectorStatus: map[string]any{
|
||||
"vpn": "not_implemented",
|
||||
"connector": "not_implemented",
|
||||
},
|
||||
RecentAudit: audit,
|
||||
TopologyExposure: tenantSafeTopologyExposure(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func tenantSafeTopologyExposure() string {
|
||||
return "tenant_safe_no_core_mesh_topology"
|
||||
}
|
||||
|
||||
func (m *Module) createOrganization(w http.ResponseWriter, r *http.Request) {
|
||||
var req createOrganizationRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid organization payload")
|
||||
return
|
||||
}
|
||||
if req.ActorUserID == "" || req.Name == "" || req.Slug == "" {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "actor_user_id, slug, and name are required")
|
||||
return
|
||||
}
|
||||
role, err := m.getPlatformRole(r.Context(), req.ActorUserID)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if !isPlatformAdmin(role) {
|
||||
httpx.WriteError(w, http.StatusForbidden, "platform admin role is required")
|
||||
return
|
||||
}
|
||||
if len(req.Metadata) == 0 {
|
||||
req.Metadata = json.RawMessage(`{}`)
|
||||
}
|
||||
if !json.Valid(req.Metadata) {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "metadata must be valid json")
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
org := Organization{
|
||||
ID: uuid.NewString(),
|
||||
Slug: normalizeSlug(req.Slug),
|
||||
Name: req.Name,
|
||||
Status: "active",
|
||||
Metadata: req.Metadata,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
membership := Membership{
|
||||
ID: uuid.NewString(),
|
||||
OrganizationID: org.ID,
|
||||
UserID: req.ActorUserID,
|
||||
RoleID: RoleOrgOwner,
|
||||
Status: "active",
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
tx, err := m.db.Begin(r.Context())
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
defer tx.Rollback(r.Context())
|
||||
if _, err := tx.Exec(r.Context(), `
|
||||
INSERT INTO organizations (id, slug, name, status, metadata, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5::jsonb, $6, $7)
|
||||
`, org.ID, org.Slug, org.Name, org.Status, []byte(org.Metadata), org.CreatedAt, org.UpdatedAt); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if _, err := tx.Exec(r.Context(), `
|
||||
INSERT INTO organization_memberships (
|
||||
id, organization_id, user_id, role_id, status, invited_by_user_id, created_at, updated_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
`, membership.ID, membership.OrganizationID, membership.UserID, membership.RoleID, membership.Status, req.ActorUserID, membership.CreatedAt, membership.UpdatedAt); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if err := tx.Commit(r.Context()); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusCreated, map[string]any{
|
||||
"organization": org,
|
||||
"membership": membership,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) listMemberships(w http.ResponseWriter, r *http.Request) {
|
||||
orgID := chi.URLParam(r, "organizationID")
|
||||
userID := r.URL.Query().Get("user_id")
|
||||
if userID == "" {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
|
||||
return
|
||||
}
|
||||
if err := m.ensureOrgAccess(r.Context(), orgID, userID, true); err != nil {
|
||||
httpx.WriteError(w, http.StatusForbidden, err.Error())
|
||||
return
|
||||
}
|
||||
rows, err := m.db.Query(r.Context(), `
|
||||
SELECT id, organization_id, user_id, role_id, status, invited_by_user_id, created_at, updated_at
|
||||
FROM organization_memberships
|
||||
WHERE organization_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`, orgID)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
var memberships []Membership
|
||||
for rows.Next() {
|
||||
membership, err := scanMembership(rows)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
memberships = append(memberships, membership)
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{"memberships": memberships})
|
||||
}
|
||||
|
||||
func (m *Module) addMembership(w http.ResponseWriter, r *http.Request) {
|
||||
orgID := chi.URLParam(r, "organizationID")
|
||||
var req addMembershipRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid membership payload")
|
||||
return
|
||||
}
|
||||
if req.ActorUserID == "" || req.UserID == "" || req.RoleID == "" {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "actor_user_id, user_id, and role_id are required")
|
||||
return
|
||||
}
|
||||
if err := m.ensureOrgAccess(r.Context(), orgID, req.ActorUserID, true); err != nil {
|
||||
httpx.WriteError(w, http.StatusForbidden, err.Error())
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
membership := Membership{
|
||||
ID: uuid.NewString(),
|
||||
OrganizationID: orgID,
|
||||
UserID: req.UserID,
|
||||
RoleID: req.RoleID,
|
||||
Status: "active",
|
||||
InvitedByUser: &req.ActorUserID,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if _, err := m.db.Exec(r.Context(), `
|
||||
INSERT INTO organization_memberships (
|
||||
id, organization_id, user_id, role_id, status, invited_by_user_id, created_at, updated_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
ON CONFLICT (organization_id, user_id) DO UPDATE SET
|
||||
role_id = EXCLUDED.role_id,
|
||||
status = 'active',
|
||||
invited_by_user_id = EXCLUDED.invited_by_user_id,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
`, membership.ID, membership.OrganizationID, membership.UserID, membership.RoleID, membership.Status, membership.InvitedByUser, membership.CreatedAt, membership.UpdatedAt); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusCreated, map[string]any{"membership": membership})
|
||||
}
|
||||
|
||||
var errForbidden = errors.New("forbidden")
|
||||
|
||||
func (m *Module) ensureOrgAccess(ctx context.Context, orgID, userID string, adminRequired bool) error {
|
||||
role, err := m.getPlatformRole(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if isPlatformAdmin(role) {
|
||||
return nil
|
||||
}
|
||||
query := `
|
||||
SELECT role_id
|
||||
FROM organization_memberships
|
||||
WHERE organization_id = $1 AND user_id = $2 AND status = 'active'
|
||||
`
|
||||
var roleID string
|
||||
if err := m.db.QueryRow(ctx, query, orgID, userID).Scan(&roleID); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return errForbidden
|
||||
}
|
||||
return err
|
||||
}
|
||||
if adminRequired && roleID != RoleOrgOwner && roleID != RoleOrgAdmin {
|
||||
return errForbidden
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Module) getPlatformRole(ctx context.Context, userID string) (string, error) {
|
||||
return authority.EffectivePlatformRole(ctx, m.db, m.authority, userID)
|
||||
}
|
||||
|
||||
func isPlatformAdmin(role string) bool {
|
||||
return role == "platform_admin" || role == "platform_recovery_admin"
|
||||
}
|
||||
|
||||
func (m *Module) getOrganizationByID(ctx context.Context, orgID string) (Organization, error) {
|
||||
row := m.db.QueryRow(ctx, `
|
||||
SELECT id, slug, name, status, metadata, created_at, updated_at
|
||||
FROM organizations
|
||||
WHERE id = $1
|
||||
`, orgID)
|
||||
return scanOrganization(row)
|
||||
}
|
||||
|
||||
type rowScanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func scanOrganization(row rowScanner) (Organization, error) {
|
||||
var org Organization
|
||||
if err := row.Scan(&org.ID, &org.Slug, &org.Name, &org.Status, &org.Metadata, &org.CreatedAt, &org.UpdatedAt); err != nil {
|
||||
return Organization{}, err
|
||||
}
|
||||
if len(org.Metadata) == 0 {
|
||||
org.Metadata = json.RawMessage(`{}`)
|
||||
}
|
||||
return org, nil
|
||||
}
|
||||
|
||||
func scanMembership(row rowScanner) (Membership, error) {
|
||||
var membership Membership
|
||||
if err := row.Scan(
|
||||
&membership.ID,
|
||||
&membership.OrganizationID,
|
||||
&membership.UserID,
|
||||
&membership.RoleID,
|
||||
&membership.Status,
|
||||
&membership.InvitedByUser,
|
||||
&membership.CreatedAt,
|
||||
&membership.UpdatedAt,
|
||||
); err != nil {
|
||||
return Membership{}, err
|
||||
}
|
||||
return membership, nil
|
||||
}
|
||||
|
||||
func normalizeSlug(in string) string {
|
||||
return strings.ToLower(strings.TrimSpace(in))
|
||||
}
|
||||
@@ -0,0 +1,639 @@
|
||||
package resource
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/authority"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/module"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
|
||||
)
|
||||
|
||||
const (
|
||||
CertificateVerificationModeStrict = "strict"
|
||||
CertificateVerificationModeIgnore = "ignore"
|
||||
RenderQualityProfileLowBandwidth = "low_bandwidth"
|
||||
RenderQualityProfileBalanced = "balanced"
|
||||
RenderQualityProfileHighQuality = "high_quality"
|
||||
RenderQualityProfileTextPriority = "text_priority"
|
||||
ClipboardModeDisabled = "disabled"
|
||||
ClipboardModeClientToServer = "client_to_server"
|
||||
ClipboardModeServerToClient = "server_to_client"
|
||||
ClipboardModeBidirectional = "bidirectional"
|
||||
FileTransferModeDisabled = "disabled"
|
||||
FileTransferModeClientToServer = "client_to_server"
|
||||
FileTransferModeServerToClient = "server_to_client"
|
||||
FileTransferModeBidirectional = "bidirectional"
|
||||
)
|
||||
|
||||
type Module struct {
|
||||
db *pgxpool.Pool
|
||||
appEnv string
|
||||
secretStore *secrets.ResourceSecretStore
|
||||
authority *authority.Verifier
|
||||
}
|
||||
|
||||
type Resource struct {
|
||||
ID string `json:"id"`
|
||||
OrganizationID string `json:"organization_id"`
|
||||
Name string `json:"name"`
|
||||
Address string `json:"address"`
|
||||
Protocol string `json:"protocol"`
|
||||
SecretRef *string `json:"secret_ref,omitempty"`
|
||||
CertificateVerificationMode string `json:"certificate_verification_mode"`
|
||||
RenderQualityProfile string `json:"render_quality_profile"`
|
||||
ClipboardMode string `json:"clipboard_mode"`
|
||||
FileTransferMode string `json:"file_transfer_mode"`
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type upsertResourceRequest struct {
|
||||
ActorUserID string `json:"actor_user_id"`
|
||||
OrganizationID string `json:"organization_id"`
|
||||
Name string `json:"name"`
|
||||
Address string `json:"address"`
|
||||
Protocol string `json:"protocol"`
|
||||
SecretRef *string `json:"secret_ref"`
|
||||
CertificateVerificationMode string `json:"certificate_verification_mode"`
|
||||
RenderQualityProfile string `json:"render_quality_profile"`
|
||||
ClipboardMode string `json:"clipboard_mode"`
|
||||
FileTransferMode string `json:"file_transfer_mode"`
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
}
|
||||
|
||||
type upsertResourceSecretRequest struct {
|
||||
ActorUserID string `json:"actor_user_id"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
}
|
||||
|
||||
func NewModule(deps module.Dependencies, secretStores ...*secrets.ResourceSecretStore) *Module {
|
||||
var secretStore *secrets.ResourceSecretStore
|
||||
if len(secretStores) > 0 {
|
||||
secretStore = secretStores[0]
|
||||
}
|
||||
authorityVerifier, _ := authority.NewVerifier(deps.Config.Installation)
|
||||
return &Module{db: deps.Infra.DB, appEnv: deps.Config.App.Env, secretStore: secretStore, authority: authorityVerifier}
|
||||
}
|
||||
|
||||
func (m *Module) Name() string {
|
||||
return "resource"
|
||||
}
|
||||
|
||||
func (m *Module) RegisterRoutes(router chi.Router) {
|
||||
router.Route("/resources", func(r chi.Router) {
|
||||
r.Get("/", m.listResources)
|
||||
r.Post("/", m.createResource)
|
||||
r.Get("/{resourceID}", m.getResource)
|
||||
r.Put("/{resourceID}", m.updateResource)
|
||||
r.Put("/{resourceID}/secret", m.upsertResourceSecret)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) listResources(w http.ResponseWriter, r *http.Request) {
|
||||
userID := r.URL.Query().Get("user_id")
|
||||
orgID := r.URL.Query().Get("organization_id")
|
||||
if userID == "" {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
|
||||
return
|
||||
}
|
||||
platformRole, err := m.getPlatformRole(r.Context(), userID)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
query := `
|
||||
SELECT r.id, r.organization_id, r.name, r.address, r.protocol, r.secret_ref,
|
||||
r.certificate_verification_mode, r.metadata, r.created_at, r.updated_at,
|
||||
COALESCE(rp.clipboard_mode, 'disabled') AS clipboard_mode,
|
||||
COALESCE(rp.file_transfer_mode, 'disabled') AS file_transfer_mode
|
||||
FROM resources r
|
||||
LEFT JOIN resource_policies rp ON rp.resource_id = r.id
|
||||
`
|
||||
args := make([]any, 0, 2)
|
||||
if platformRole == "platform_admin" || platformRole == "platform_recovery_admin" {
|
||||
if orgID != "" {
|
||||
query += ` WHERE r.organization_id = $1`
|
||||
args = append(args, orgID)
|
||||
}
|
||||
query += ` ORDER BY r.created_at DESC`
|
||||
} else {
|
||||
query += `
|
||||
INNER JOIN organization_memberships om ON om.organization_id = r.organization_id
|
||||
WHERE om.user_id = $1 AND om.status = 'active'
|
||||
`
|
||||
args = append(args, userID)
|
||||
if orgID != "" {
|
||||
query += ` AND r.organization_id = $2`
|
||||
args = append(args, orgID)
|
||||
}
|
||||
query += ` ORDER BY r.created_at DESC`
|
||||
}
|
||||
rows, err := m.db.Query(r.Context(), query, args...)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
resources := make([]Resource, 0)
|
||||
for rows.Next() {
|
||||
resource, err := scanResource(rows)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
resources = append(resources, resource)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{"resources": resources})
|
||||
}
|
||||
|
||||
func (m *Module) getResource(w http.ResponseWriter, r *http.Request) {
|
||||
userID := r.URL.Query().Get("user_id")
|
||||
if userID == "" {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
|
||||
return
|
||||
}
|
||||
resource, err := m.getByID(r.Context(), chi.URLParam(r, "resourceID"))
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
httpx.WriteError(w, http.StatusNotFound, "resource not found")
|
||||
return
|
||||
}
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := m.ensureResourceAccess(r.Context(), resource.OrganizationID, userID, false); err != nil {
|
||||
httpx.WriteError(w, http.StatusForbidden, err.Error())
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{"resource": resource})
|
||||
}
|
||||
|
||||
func (m *Module) createResource(w http.ResponseWriter, r *http.Request) {
|
||||
req, err := decodeUpsertRequest(r)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
if err := secrets.ValidateResourceSecretReadiness(req.Protocol, req.SecretRef, req.Metadata, m.appEnv); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
resource := Resource{
|
||||
ID: uuid.NewString(),
|
||||
OrganizationID: req.OrganizationID,
|
||||
Name: req.Name,
|
||||
Address: req.Address,
|
||||
Protocol: req.Protocol,
|
||||
SecretRef: req.SecretRef,
|
||||
CertificateVerificationMode: req.CertificateVerificationMode,
|
||||
RenderQualityProfile: req.RenderQualityProfile,
|
||||
ClipboardMode: req.ClipboardMode,
|
||||
FileTransferMode: req.FileTransferMode,
|
||||
Metadata: req.Metadata,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if err := m.ensureResourceAccess(r.Context(), req.OrganizationID, req.ActorUserID, true); err != nil {
|
||||
httpx.WriteError(w, http.StatusForbidden, err.Error())
|
||||
return
|
||||
}
|
||||
tx, err := m.db.Begin(r.Context())
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
defer tx.Rollback(r.Context())
|
||||
if _, err := tx.Exec(r.Context(), `
|
||||
INSERT INTO resources (
|
||||
id, organization_id, name, address, protocol, secret_ref, certificate_verification_mode, metadata, created_at, updated_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9, $10)
|
||||
`, resource.ID, resource.OrganizationID, resource.Name, resource.Address, resource.Protocol, resource.SecretRef, resource.CertificateVerificationMode, []byte(resource.Metadata), resource.CreatedAt, resource.UpdatedAt); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if err := upsertResourcePolicy(r.Context(), tx, resource.ID, resource.ClipboardMode, resource.FileTransferMode, now); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if err := tx.Commit(r.Context()); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
httpx.WriteJSON(w, http.StatusCreated, map[string]any{"resource": resource})
|
||||
}
|
||||
|
||||
func (m *Module) updateResource(w http.ResponseWriter, r *http.Request) {
|
||||
req, err := decodeUpsertRequest(r)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
if err := secrets.ValidateResourceSecretReadiness(req.Protocol, req.SecretRef, req.Metadata, m.appEnv); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resourceID := chi.URLParam(r, "resourceID")
|
||||
existing, err := m.getByID(r.Context(), resourceID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
httpx.WriteError(w, http.StatusNotFound, "resource not found")
|
||||
return
|
||||
}
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if err := m.ensureResourceAccess(r.Context(), existing.OrganizationID, req.ActorUserID, true); err != nil {
|
||||
httpx.WriteError(w, http.StatusForbidden, err.Error())
|
||||
return
|
||||
}
|
||||
if req.OrganizationID != existing.OrganizationID {
|
||||
if err := m.ensureResourceAccess(r.Context(), req.OrganizationID, req.ActorUserID, true); err != nil {
|
||||
httpx.WriteError(w, http.StatusForbidden, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
tx, err := m.db.Begin(r.Context())
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
defer tx.Rollback(r.Context())
|
||||
tag, err := tx.Exec(r.Context(), `
|
||||
UPDATE resources
|
||||
SET
|
||||
organization_id = $2,
|
||||
name = $3,
|
||||
address = $4,
|
||||
protocol = $5,
|
||||
secret_ref = $6,
|
||||
certificate_verification_mode = $7,
|
||||
metadata = $8::jsonb,
|
||||
updated_at = $9
|
||||
WHERE id = $1
|
||||
`, resourceID, req.OrganizationID, req.Name, req.Address, req.Protocol, req.SecretRef, req.CertificateVerificationMode, []byte(req.Metadata), now)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
httpx.WriteError(w, http.StatusNotFound, "resource not found")
|
||||
return
|
||||
}
|
||||
if err := upsertResourcePolicy(r.Context(), tx, resourceID, req.ClipboardMode, req.FileTransferMode, now); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if err := tx.Commit(r.Context()); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resource, err := m.getByID(r.Context(), resourceID)
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{"resource": resource})
|
||||
}
|
||||
|
||||
func (m *Module) upsertResourceSecret(w http.ResponseWriter, r *http.Request) {
|
||||
if m.secretStore == nil {
|
||||
httpx.WriteError(w, http.StatusServiceUnavailable, "resource secret encryption is not configured")
|
||||
return
|
||||
}
|
||||
var req upsertResourceSecretRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid resource secret payload")
|
||||
return
|
||||
}
|
||||
if req.ActorUserID == "" {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "actor_user_id is required")
|
||||
return
|
||||
}
|
||||
resourceID := chi.URLParam(r, "resourceID")
|
||||
resource, err := m.getByID(r.Context(), resourceID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
httpx.WriteError(w, http.StatusNotFound, "resource not found")
|
||||
return
|
||||
}
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if err := m.ensureResourceAccess(r.Context(), resource.OrganizationID, req.ActorUserID, true); err != nil {
|
||||
httpx.WriteError(w, http.StatusForbidden, err.Error())
|
||||
return
|
||||
}
|
||||
tx, err := m.db.Begin(r.Context())
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
defer tx.Rollback(r.Context())
|
||||
secretStore := m.secretStore.WithDB(tx)
|
||||
secretRef := secrets.DefaultResourceSecretRef(resource.OrganizationID, resource.ID)
|
||||
descriptor, err := secretStore.Upsert(r.Context(), secrets.UpsertResourceSecretCommand{
|
||||
OrganizationID: resource.OrganizationID,
|
||||
ResourceID: resource.ID,
|
||||
Protocol: resource.Protocol,
|
||||
SecretRef: secretRef,
|
||||
Payload: req.Payload,
|
||||
Metadata: req.Metadata,
|
||||
ActorUserID: req.ActorUserID,
|
||||
})
|
||||
if err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
if _, err := tx.Exec(r.Context(), `
|
||||
UPDATE resources
|
||||
SET secret_ref = $2, updated_at = $3
|
||||
WHERE id = $1::uuid
|
||||
`, resource.ID, descriptor.SecretRef, time.Now().UTC()); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if err := writeAuditEvent(r.Context(), tx, "resource_secret_rotated", req.ActorUserID, "resource_secret", descriptor.SecretRef, map[string]any{
|
||||
"resource_id": resource.ID,
|
||||
"organization_id": resource.OrganizationID,
|
||||
"protocol": resource.Protocol,
|
||||
"version": descriptor.Version,
|
||||
"secret_ref": descriptor.SecretRef,
|
||||
}); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if err := tx.Commit(r.Context()); err != nil {
|
||||
httpx.WriteError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{"secret": descriptor})
|
||||
}
|
||||
|
||||
func decodeUpsertRequest(r *http.Request) (*upsertResourceRequest, error) {
|
||||
var req upsertResourceRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.New("invalid resource payload")
|
||||
}
|
||||
if req.Name == "" {
|
||||
return nil, errors.New("name is required")
|
||||
}
|
||||
if req.ActorUserID == "" {
|
||||
return nil, errors.New("actor_user_id is required")
|
||||
}
|
||||
if req.OrganizationID == "" {
|
||||
return nil, errors.New("organization_id is required")
|
||||
}
|
||||
if req.Address == "" {
|
||||
return nil, errors.New("address is required")
|
||||
}
|
||||
if req.Protocol == "" {
|
||||
req.Protocol = "rdp"
|
||||
}
|
||||
mode, err := normalizeCertificateVerificationMode(req.CertificateVerificationMode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.CertificateVerificationMode = mode
|
||||
renderQualityProfile, err := normalizeRenderQualityProfile(req.RenderQualityProfile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.RenderQualityProfile = renderQualityProfile
|
||||
clipboardMode, err := normalizeClipboardMode(req.ClipboardMode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.ClipboardMode = clipboardMode
|
||||
fileTransferMode, err := normalizeFileTransferMode(req.FileTransferMode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.FileTransferMode = fileTransferMode
|
||||
metadata, err := normalizeMetadata(req.Metadata, req.CertificateVerificationMode, req.RenderQualityProfile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Metadata = metadata
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
func normalizeCertificateVerificationMode(mode string) (string, error) {
|
||||
switch mode {
|
||||
case "", CertificateVerificationModeStrict:
|
||||
return CertificateVerificationModeStrict, nil
|
||||
case CertificateVerificationModeIgnore:
|
||||
return CertificateVerificationModeIgnore, nil
|
||||
default:
|
||||
return "", errors.New("certificate_verification_mode must be one of: strict, ignore")
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeClipboardMode(mode string) (string, error) {
|
||||
switch mode {
|
||||
case "", ClipboardModeDisabled:
|
||||
return ClipboardModeDisabled, nil
|
||||
case ClipboardModeClientToServer, ClipboardModeServerToClient, ClipboardModeBidirectional:
|
||||
return mode, nil
|
||||
default:
|
||||
return "", errors.New("clipboard_mode must be one of: disabled, client_to_server, server_to_client, bidirectional")
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeFileTransferMode(mode string) (string, error) {
|
||||
switch mode {
|
||||
case "", FileTransferModeDisabled:
|
||||
return FileTransferModeDisabled, nil
|
||||
case FileTransferModeClientToServer, FileTransferModeServerToClient, FileTransferModeBidirectional:
|
||||
return mode, nil
|
||||
default:
|
||||
return "", errors.New("file_transfer_mode must be one of: disabled, client_to_server, server_to_client, bidirectional")
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeMetadata(raw json.RawMessage, certificateVerificationMode, renderQualityProfile string) (json.RawMessage, error) {
|
||||
if len(raw) == 0 {
|
||||
raw = json.RawMessage(`{}`)
|
||||
}
|
||||
if !json.Valid(raw) {
|
||||
return nil, errors.New("metadata must be valid json")
|
||||
}
|
||||
var metadata map[string]any
|
||||
if err := json.Unmarshal(raw, &metadata); err != nil {
|
||||
return nil, errors.New("metadata must be a json object")
|
||||
}
|
||||
metadata["certificate_verification_mode"] = certificateVerificationMode
|
||||
metadata["render_quality_profile"] = renderQualityProfile
|
||||
encoded, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.RawMessage(encoded), nil
|
||||
}
|
||||
|
||||
func (m *Module) getByID(ctx context.Context, resourceID string) (Resource, error) {
|
||||
row := m.db.QueryRow(ctx, `
|
||||
SELECT r.id, r.organization_id, r.name, r.address, r.protocol, r.secret_ref,
|
||||
r.certificate_verification_mode, r.metadata, r.created_at, r.updated_at,
|
||||
COALESCE(rp.clipboard_mode, 'disabled') AS clipboard_mode,
|
||||
COALESCE(rp.file_transfer_mode, 'disabled') AS file_transfer_mode
|
||||
FROM resources r
|
||||
LEFT JOIN resource_policies rp ON rp.resource_id = r.id
|
||||
WHERE r.id = $1
|
||||
`, resourceID)
|
||||
return scanResource(row)
|
||||
}
|
||||
|
||||
func (m *Module) ensureResourceAccess(ctx context.Context, orgID, userID string, adminRequired bool) error {
|
||||
role, err := m.getPlatformRole(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if role == "platform_admin" || role == "platform_recovery_admin" {
|
||||
return nil
|
||||
}
|
||||
var membershipRole string
|
||||
if err := m.db.QueryRow(ctx, `
|
||||
SELECT role_id
|
||||
FROM organization_memberships
|
||||
WHERE organization_id = $1 AND user_id = $2 AND status = 'active'
|
||||
`, orgID, userID).Scan(&membershipRole); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return errors.New("forbidden")
|
||||
}
|
||||
return err
|
||||
}
|
||||
if adminRequired && membershipRole != "org_owner" && membershipRole != "org_admin" {
|
||||
return errors.New("forbidden")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Module) getPlatformRole(ctx context.Context, userID string) (string, error) {
|
||||
return authority.EffectivePlatformRole(ctx, m.db, m.authority, userID)
|
||||
}
|
||||
|
||||
type rowScanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func scanResource(row rowScanner) (Resource, error) {
|
||||
var resource Resource
|
||||
if err := row.Scan(
|
||||
&resource.ID,
|
||||
&resource.OrganizationID,
|
||||
&resource.Name,
|
||||
&resource.Address,
|
||||
&resource.Protocol,
|
||||
&resource.SecretRef,
|
||||
&resource.CertificateVerificationMode,
|
||||
&resource.Metadata,
|
||||
&resource.CreatedAt,
|
||||
&resource.UpdatedAt,
|
||||
&resource.ClipboardMode,
|
||||
&resource.FileTransferMode,
|
||||
); err != nil {
|
||||
return Resource{}, err
|
||||
}
|
||||
if len(resource.Metadata) == 0 {
|
||||
resource.Metadata = json.RawMessage(`{}`)
|
||||
}
|
||||
if resource.CertificateVerificationMode == "" {
|
||||
resource.CertificateVerificationMode = CertificateVerificationModeStrict
|
||||
}
|
||||
if resource.RenderQualityProfile == "" {
|
||||
resource.RenderQualityProfile = renderQualityProfileFromMetadata(resource.Metadata)
|
||||
}
|
||||
if resource.ClipboardMode == "" {
|
||||
resource.ClipboardMode = ClipboardModeDisabled
|
||||
}
|
||||
if resource.FileTransferMode == "" {
|
||||
resource.FileTransferMode = FileTransferModeDisabled
|
||||
}
|
||||
return resource, nil
|
||||
}
|
||||
|
||||
func upsertResourcePolicy(ctx context.Context, tx pgx.Tx, resourceID, clipboardMode, fileTransferMode string, now time.Time) error {
|
||||
clipboardEnabled := clipboardMode != ClipboardModeDisabled
|
||||
fileTransferEnabled := fileTransferMode == FileTransferModeClientToServer || fileTransferMode == FileTransferModeBidirectional
|
||||
_, err := tx.Exec(ctx, `
|
||||
INSERT INTO resource_policies (
|
||||
resource_id, clipboard_enabled, clipboard_mode, file_transfer_enabled, file_transfer_mode, created_at, updated_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $6)
|
||||
ON CONFLICT (resource_id) DO UPDATE SET
|
||||
clipboard_enabled = EXCLUDED.clipboard_enabled,
|
||||
clipboard_mode = EXCLUDED.clipboard_mode,
|
||||
file_transfer_enabled = EXCLUDED.file_transfer_enabled,
|
||||
file_transfer_mode = EXCLUDED.file_transfer_mode,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
`, resourceID, clipboardEnabled, clipboardMode, fileTransferEnabled, fileTransferMode, now)
|
||||
return err
|
||||
}
|
||||
|
||||
func writeAuditEvent(ctx context.Context, tx pgx.Tx, eventType, actorUserID, targetType, targetID string, payload map[string]any) error {
|
||||
encoded, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(ctx, `
|
||||
INSERT INTO audit_events (
|
||||
id, actor_user_id, event_type, target_type, target_id, payload, created_at
|
||||
) VALUES (
|
||||
$1::uuid, NULLIF($2, '')::uuid, $3, $4, $5, $6::jsonb, $7
|
||||
)
|
||||
`, uuid.NewString(), actorUserID, eventType, targetType, targetID, encoded, time.Now().UTC())
|
||||
return err
|
||||
}
|
||||
|
||||
func normalizeRenderQualityProfile(profile string) (string, error) {
|
||||
switch profile {
|
||||
case "", RenderQualityProfileBalanced:
|
||||
return RenderQualityProfileBalanced, nil
|
||||
case RenderQualityProfileLowBandwidth, RenderQualityProfileHighQuality, RenderQualityProfileTextPriority:
|
||||
return profile, nil
|
||||
default:
|
||||
return "", errors.New("render_quality_profile must be one of: low_bandwidth, balanced, high_quality, text_priority")
|
||||
}
|
||||
}
|
||||
|
||||
func renderQualityProfileFromMetadata(raw json.RawMessage) string {
|
||||
if len(raw) == 0 {
|
||||
return RenderQualityProfileBalanced
|
||||
}
|
||||
var metadata map[string]any
|
||||
if err := json.Unmarshal(raw, &metadata); err != nil {
|
||||
return RenderQualityProfileBalanced
|
||||
}
|
||||
if profile, ok := metadata["render_quality_profile"].(string); ok {
|
||||
switch profile {
|
||||
case RenderQualityProfileLowBandwidth, RenderQualityProfileBalanced, RenderQualityProfileHighQuality, RenderQualityProfileTextPriority:
|
||||
return profile
|
||||
}
|
||||
}
|
||||
return RenderQualityProfileBalanced
|
||||
}
|
||||
@@ -0,0 +1,219 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
|
||||
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
|
||||
)
|
||||
|
||||
const (
|
||||
directWorkerTLSTrustModeSmokeInsecure = "smoke_insecure"
|
||||
directWorkerTLSTrustModePublicCA = "public_ca"
|
||||
directWorkerTLSTrustModePlatformCA = "platform_ca"
|
||||
)
|
||||
|
||||
type DataPlaneTokenClaims struct {
|
||||
SessionID string `json:"session_id"`
|
||||
AttachmentID string `json:"attachment_id"`
|
||||
UserID string `json:"user_id"`
|
||||
OrganizationID string `json:"organization_id"`
|
||||
ClusterID string `json:"cluster_id,omitempty"`
|
||||
WorkerID string `json:"worker_id"`
|
||||
ResourceID string `json:"resource_id"`
|
||||
AllowedChannels []string `json:"allowed_channels"`
|
||||
ExpiresAtValue time.Time `json:"expires_at"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
func (s *Service) buildDataPlaneOffer(session RemoteSession, attachment SessionAttachment) (*sessioncontracts.DataPlaneOffer, error) {
|
||||
if s.cfg.DataPlane.TokenTTL <= 0 || s.cfg.DataPlane.TokenPrivateKeyPEM == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
now := s.now().UTC()
|
||||
expiresAt := now.Add(s.cfg.DataPlane.TokenTTL)
|
||||
allowedChannels := dataPlaneAllowedChannelsFromSession(session)
|
||||
jti := uuid.NewString()
|
||||
claims := DataPlaneTokenClaims{
|
||||
SessionID: session.ID,
|
||||
AttachmentID: attachment.ID,
|
||||
UserID: attachment.UserID,
|
||||
OrganizationID: session.OrganizationID,
|
||||
WorkerID: session.WorkerID,
|
||||
ResourceID: session.ResourceID,
|
||||
AllowedChannels: allowedChannels,
|
||||
ExpiresAtValue: expiresAt,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ID: jti,
|
||||
Issuer: s.cfg.Auth.Issuer,
|
||||
Subject: attachment.UserID,
|
||||
Audience: jwt.ClaimStrings{"rap-data-plane", "worker:" + session.WorkerID},
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
||||
},
|
||||
}
|
||||
token, err := signDataPlaneToken(claims, s.cfg.DataPlane.TokenPrivateKeyPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
candidates := s.buildDataPlaneCandidates(session)
|
||||
preferred := sessioncontracts.DataPlaneCandidateBackendGateway
|
||||
if len(candidates) > 0 {
|
||||
preferred = candidates[0].Type
|
||||
}
|
||||
|
||||
return &sessioncontracts.DataPlaneOffer{
|
||||
Preferred: preferred,
|
||||
Token: token,
|
||||
ExpiresAt: expiresAt,
|
||||
Candidates: candidates,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) buildDataPlaneCandidates(session RemoteSession) []sessioncontracts.DataPlaneCandidate {
|
||||
var candidates []sessioncontracts.DataPlaneCandidate
|
||||
if directURL := s.directWorkerWSSURL(session.WorkerID); directURL != "" && s.canAdvertiseDirectWorkerWSS() {
|
||||
metadata := map[string]any(nil)
|
||||
if s.cfg.DataPlane.DirectWorkerJSONRuntime {
|
||||
metadata = map[string]any{
|
||||
"runtime_transport": "json_v1",
|
||||
"traffic_ready": true,
|
||||
}
|
||||
s.addDirectWorkerTLSTrustMetadata(metadata)
|
||||
if s.cfg.DataPlane.DirectWorkerBinaryRender {
|
||||
metadata["render_transport"] = "binary_v1"
|
||||
metadata["binary_render"] = true
|
||||
metadata["supported_color_modes"] = []string{"full_color", "grayscale"}
|
||||
metadata["default_color_mode"] = "full_color"
|
||||
}
|
||||
}
|
||||
candidates = append(candidates, sessioncontracts.DataPlaneCandidate{
|
||||
Type: sessioncontracts.DataPlaneCandidateDirectWorkerWSS,
|
||||
URL: directURL,
|
||||
WorkerID: session.WorkerID,
|
||||
Priority: 10,
|
||||
Metadata: metadata,
|
||||
})
|
||||
}
|
||||
if s.cfg.DataPlane.BackendGatewayURL != "" {
|
||||
candidates = append(candidates, sessioncontracts.DataPlaneCandidate{
|
||||
Type: sessioncontracts.DataPlaneCandidateBackendGateway,
|
||||
URL: s.cfg.DataPlane.BackendGatewayURL,
|
||||
Priority: 100,
|
||||
})
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
func (s *Service) canAdvertiseDirectWorkerWSS() bool {
|
||||
trustMode := normalizeDirectWorkerTLSTrustMode(s.cfg.DataPlane.DirectWorkerTLSTrustMode)
|
||||
return !secrets.IsProductionEnv(s.cfg.App.Env) || directWorkerTLSTrustModeIsProductionTrusted(trustMode)
|
||||
}
|
||||
|
||||
func (s *Service) addDirectWorkerTLSTrustMetadata(metadata map[string]any) {
|
||||
trustMode := normalizeDirectWorkerTLSTrustMode(s.cfg.DataPlane.DirectWorkerTLSTrustMode)
|
||||
metadata["tls_trust_mode"] = trustMode
|
||||
metadata["production_trusted"] = directWorkerTLSTrustModeIsProductionTrusted(trustMode)
|
||||
metadata["smoke_only"] = trustMode == directWorkerTLSTrustModeSmokeInsecure
|
||||
if s.cfg.DataPlane.DirectWorkerTLSCARef != "" {
|
||||
metadata["tls_ca_ref"] = s.cfg.DataPlane.DirectWorkerTLSCARef
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeDirectWorkerTLSTrustMode(mode string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||
case directWorkerTLSTrustModePublicCA:
|
||||
return directWorkerTLSTrustModePublicCA
|
||||
case directWorkerTLSTrustModePlatformCA:
|
||||
return directWorkerTLSTrustModePlatformCA
|
||||
default:
|
||||
return directWorkerTLSTrustModeSmokeInsecure
|
||||
}
|
||||
}
|
||||
|
||||
func directWorkerTLSTrustModeIsProductionTrusted(mode string) bool {
|
||||
return mode == directWorkerTLSTrustModePublicCA || mode == directWorkerTLSTrustModePlatformCA
|
||||
}
|
||||
|
||||
func (s *Service) directWorkerWSSURL(workerID string) string {
|
||||
template := strings.TrimSpace(s.cfg.DataPlane.DirectWorkerWSSURLTemplate)
|
||||
if template == "" || workerID == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.ReplaceAll(template, "{worker_id}", workerID)
|
||||
}
|
||||
|
||||
func signDataPlaneToken(claims DataPlaneTokenClaims, privateKeyPEM string) (string, error) {
|
||||
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(privateKeyPEM))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse data-plane private key: %w", err)
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
signed, err := token.SignedString(privateKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("sign data-plane token: %w", err)
|
||||
}
|
||||
return signed, nil
|
||||
}
|
||||
|
||||
func parseDataPlaneToken(tokenValue string, publicKey *rsa.PublicKey) (*DataPlaneTokenClaims, error) {
|
||||
claims := &DataPlaneTokenClaims{}
|
||||
token, err := jwt.ParseWithClaims(tokenValue, claims, func(token *jwt.Token) (any, error) {
|
||||
if token.Method != jwt.SigningMethodRS256 {
|
||||
return nil, fmt.Errorf("unexpected data-plane signing method: %s", token.Header["alg"])
|
||||
}
|
||||
return publicKey, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !token.Valid {
|
||||
return nil, fmt.Errorf("data-plane token invalid")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func dataPlaneAllowedChannelsFromSession(session RemoteSession) []string {
|
||||
channels := []string{
|
||||
sessioncontracts.DataPlaneChannelControl,
|
||||
sessioncontracts.DataPlaneChannelInput,
|
||||
sessioncontracts.DataPlaneChannelRender,
|
||||
sessioncontracts.DataPlaneChannelTelemetry,
|
||||
}
|
||||
metadata := decodeJSONMap(session.Metadata)
|
||||
policy, _ := metadata["policy"].(map[string]any)
|
||||
if policy != nil {
|
||||
if mode, _ := policy["clipboard_mode"].(string); mode != "" && mode != string(ResourceClipboardModeDisabled) {
|
||||
channels = append(channels, sessioncontracts.DataPlaneChannelClipboard)
|
||||
}
|
||||
if mode, _ := policy["file_transfer_mode"].(string); fileTransferAllowsClientToServer(ResourceFileTransferMode(mode)) {
|
||||
channels = append(channels, sessioncontracts.DataPlaneChannelFileUpload)
|
||||
}
|
||||
if mode, _ := policy["file_transfer_mode"].(string); fileTransferAllowsServerToClient(ResourceFileTransferMode(mode)) {
|
||||
channels = append(channels, sessioncontracts.DataPlaneChannelFileDownload)
|
||||
}
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
func (s *Service) attachDataPlaneOffer(result *SessionControlResult) error {
|
||||
if result == nil || result.Attachment == nil {
|
||||
return nil
|
||||
}
|
||||
result.GatewayURL = s.cfg.DataPlane.BackendGatewayURL
|
||||
offer, err := s.buildDataPlaneOffer(result.Session, *result.Attachment)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result.DataPlane = offer
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,357 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/config"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/module"
|
||||
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
|
||||
)
|
||||
|
||||
func TestDataPlaneTokenScopeValidation(t *testing.T) {
|
||||
now := time.Now().UTC().Truncate(time.Second)
|
||||
privateKeyPEM, publicKey := testRS256Key(t)
|
||||
service := &Service{
|
||||
cfg: module.Config{
|
||||
Auth: config.AuthConfig{
|
||||
Issuer: "rap-api-test",
|
||||
},
|
||||
DataPlane: config.DataPlaneConfig{
|
||||
TokenTTL: time.Minute,
|
||||
TokenPrivateKeyPEM: privateKeyPEM,
|
||||
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
|
||||
},
|
||||
},
|
||||
now: func() time.Time { return now },
|
||||
}
|
||||
session := RemoteSession{
|
||||
ID: "session-1",
|
||||
OrganizationID: "org-1",
|
||||
ResourceID: "resource-1",
|
||||
WorkerID: "worker-1",
|
||||
Metadata: mustJSON(t, map[string]any{"policy": map[string]any{"clipboard_mode": "bidirectional", "file_transfer_mode": "client_to_server"}}),
|
||||
}
|
||||
attachment := SessionAttachment{
|
||||
ID: "attachment-1",
|
||||
UserID: "user-1",
|
||||
}
|
||||
|
||||
offer, err := service.buildDataPlaneOffer(session, attachment)
|
||||
if err != nil {
|
||||
t.Fatalf("buildDataPlaneOffer returned error: %v", err)
|
||||
}
|
||||
if offer == nil {
|
||||
t.Fatal("expected data-plane offer")
|
||||
}
|
||||
|
||||
claims, err := parseDataPlaneToken(offer.Token, publicKey)
|
||||
if err != nil {
|
||||
t.Fatalf("parseDataPlaneToken returned error: %v", err)
|
||||
}
|
||||
assertEqual(t, claims.SessionID, session.ID, "session_id")
|
||||
assertEqual(t, claims.AttachmentID, attachment.ID, "attachment_id")
|
||||
assertEqual(t, claims.UserID, attachment.UserID, "user_id")
|
||||
assertEqual(t, claims.OrganizationID, session.OrganizationID, "organization_id")
|
||||
assertEqual(t, claims.WorkerID, session.WorkerID, "worker_id")
|
||||
assertEqual(t, claims.ResourceID, session.ResourceID, "resource_id")
|
||||
if claims.ID == "" {
|
||||
t.Fatal("expected jti")
|
||||
}
|
||||
if claims.ExpiresAt == nil || !claims.ExpiresAt.Time.Equal(now.Add(time.Minute)) {
|
||||
t.Fatalf("unexpected expires_at: %v", claims.ExpiresAt)
|
||||
}
|
||||
if !claims.ExpiresAtValue.Equal(now.Add(time.Minute)) {
|
||||
t.Fatalf("unexpected expires_at claim value: %v", claims.ExpiresAtValue)
|
||||
}
|
||||
for _, channel := range []string{
|
||||
sessioncontracts.DataPlaneChannelControl,
|
||||
sessioncontracts.DataPlaneChannelInput,
|
||||
sessioncontracts.DataPlaneChannelRender,
|
||||
sessioncontracts.DataPlaneChannelTelemetry,
|
||||
sessioncontracts.DataPlaneChannelClipboard,
|
||||
sessioncontracts.DataPlaneChannelFileUpload,
|
||||
} {
|
||||
if !slices.Contains(claims.AllowedChannels, channel) {
|
||||
t.Fatalf("expected allowed channel %q in %v", channel, claims.AllowedChannels)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataPlaneOfferResponseShapeCompatibility(t *testing.T) {
|
||||
now := time.Now().UTC().Truncate(time.Second)
|
||||
privateKeyPEM, _ := testRS256Key(t)
|
||||
service := &Service{
|
||||
cfg: module.Config{
|
||||
Auth: config.AuthConfig{Issuer: "rap-api-test"},
|
||||
DataPlane: config.DataPlaneConfig{
|
||||
TokenTTL: time.Minute,
|
||||
TokenPrivateKeyPEM: privateKeyPEM,
|
||||
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
|
||||
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
|
||||
DirectWorkerJSONRuntime: true,
|
||||
DirectWorkerTLSTrustMode: "smoke_insecure",
|
||||
},
|
||||
},
|
||||
now: func() time.Time { return now },
|
||||
}
|
||||
result := &SessionControlResult{
|
||||
Session: RemoteSession{
|
||||
ID: "session-1",
|
||||
OrganizationID: "org-1",
|
||||
ResourceID: "resource-1",
|
||||
WorkerID: "worker-1",
|
||||
Metadata: mustJSON(t, map[string]any{"policy": map[string]any{"clipboard_mode": "disabled", "file_transfer_mode": "disabled"}}),
|
||||
},
|
||||
Attachment: &SessionAttachment{ID: "attachment-1", UserID: "user-1"},
|
||||
AttachToken: &sessioncontracts.AttachTokenClaims{
|
||||
Token: "existing-attach-token",
|
||||
SessionID: "session-1",
|
||||
AttachmentID: "attachment-1",
|
||||
UserID: "user-1",
|
||||
WorkerID: "worker-1",
|
||||
ExpiresAt: now.Add(2 * time.Minute),
|
||||
},
|
||||
}
|
||||
|
||||
if err := service.attachDataPlaneOffer(result); err != nil {
|
||||
t.Fatalf("attachDataPlaneOffer returned error: %v", err)
|
||||
}
|
||||
payload, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal response: %v", err)
|
||||
}
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(payload, &decoded); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if decoded["session"] == nil || decoded["attachment"] == nil || decoded["attach_token"] == nil {
|
||||
t.Fatalf("response lost existing fields: %s", payload)
|
||||
}
|
||||
if decoded["data_plane"] == nil || decoded["gateway_url"] == nil {
|
||||
t.Fatalf("response missing data-plane fields: %s", payload)
|
||||
}
|
||||
if result.DataPlane == nil {
|
||||
t.Fatal("expected data-plane offer")
|
||||
}
|
||||
if result.DataPlane.Preferred != sessioncontracts.DataPlaneCandidateDirectWorkerWSS {
|
||||
t.Fatalf("unexpected preferred candidate: %s", result.DataPlane.Preferred)
|
||||
}
|
||||
if len(result.DataPlane.Candidates) != 2 {
|
||||
t.Fatalf("expected direct and fallback candidates, got %d", len(result.DataPlane.Candidates))
|
||||
}
|
||||
if result.DataPlane.Candidates[0].URL != "wss://worker-1.worker.example.test/rap/v1/data-plane" {
|
||||
t.Fatalf("unexpected direct candidate URL: %s", result.DataPlane.Candidates[0].URL)
|
||||
}
|
||||
if result.DataPlane.Candidates[0].Metadata["runtime_transport"] != "json_v1" {
|
||||
t.Fatalf("direct candidate is missing json_v1 runtime metadata: %#v", result.DataPlane.Candidates[0].Metadata)
|
||||
}
|
||||
if result.DataPlane.Candidates[0].Metadata["traffic_ready"] != true {
|
||||
t.Fatalf("direct candidate is missing traffic_ready metadata: %#v", result.DataPlane.Candidates[0].Metadata)
|
||||
}
|
||||
if result.DataPlane.Candidates[0].Metadata["smoke_only"] != true {
|
||||
t.Fatalf("direct candidate should be marked smoke-only by default: %#v", result.DataPlane.Candidates[0].Metadata)
|
||||
}
|
||||
if result.DataPlane.Candidates[0].Metadata["production_trusted"] != false {
|
||||
t.Fatalf("smoke direct candidate must not be production-trusted: %#v", result.DataPlane.Candidates[0].Metadata)
|
||||
}
|
||||
if !strings.Contains(result.DataPlane.Candidates[1].URL, "/api/v1/gateway/ws") {
|
||||
t.Fatalf("unexpected backend candidate URL: %s", result.DataPlane.Candidates[1].URL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataPlaneDirectCandidateMetadataRequiresRuntimeFlag(t *testing.T) {
|
||||
service := &Service{
|
||||
cfg: module.Config{
|
||||
DataPlane: config.DataPlaneConfig{
|
||||
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
|
||||
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
|
||||
DirectWorkerTLSTrustMode: "smoke_insecure",
|
||||
},
|
||||
},
|
||||
}
|
||||
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
|
||||
if len(candidates) != 2 {
|
||||
t.Fatalf("expected direct and fallback candidates, got %d", len(candidates))
|
||||
}
|
||||
if candidates[0].Metadata != nil {
|
||||
t.Fatalf("direct candidate must not advertise json_v1 before runtime flag is enabled: %#v", candidates[0].Metadata)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataPlaneDirectCandidateAdvertisesBinaryRenderOnlyWhenEnabled(t *testing.T) {
|
||||
service := &Service{
|
||||
cfg: module.Config{
|
||||
DataPlane: config.DataPlaneConfig{
|
||||
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
|
||||
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
|
||||
DirectWorkerJSONRuntime: true,
|
||||
DirectWorkerBinaryRender: true,
|
||||
DirectWorkerTLSTrustMode: "platform_ca",
|
||||
DirectWorkerTLSCARef: "rap-platform-ca:v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
|
||||
if len(candidates) != 2 {
|
||||
t.Fatalf("expected direct and fallback candidates, got %d", len(candidates))
|
||||
}
|
||||
if candidates[0].Metadata["render_transport"] != "binary_v1" {
|
||||
t.Fatalf("direct candidate is missing binary render metadata: %#v", candidates[0].Metadata)
|
||||
}
|
||||
if candidates[0].Metadata["binary_render"] != true {
|
||||
t.Fatalf("direct candidate is missing binary_render metadata: %#v", candidates[0].Metadata)
|
||||
}
|
||||
if candidates[0].Metadata["default_color_mode"] != "full_color" {
|
||||
t.Fatalf("direct candidate is missing default_color_mode metadata: %#v", candidates[0].Metadata)
|
||||
}
|
||||
if candidates[0].Metadata["production_trusted"] != true || candidates[0].Metadata["tls_trust_mode"] != "platform_ca" {
|
||||
t.Fatalf("direct candidate is missing production trust metadata: %#v", candidates[0].Metadata)
|
||||
}
|
||||
if candidates[0].Metadata["tls_ca_ref"] != "rap-platform-ca:v1" {
|
||||
t.Fatalf("direct candidate is missing tls_ca_ref metadata: %#v", candidates[0].Metadata)
|
||||
}
|
||||
modes, ok := candidates[0].Metadata["supported_color_modes"].([]string)
|
||||
if !ok || !slices.Contains(modes, "full_color") || !slices.Contains(modes, "grayscale") {
|
||||
t.Fatalf("direct candidate is missing supported_color_modes metadata: %#v", candidates[0].Metadata)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataPlaneDirectCandidateOmittedInProductionWhenSmokeOnly(t *testing.T) {
|
||||
service := &Service{
|
||||
cfg: module.Config{
|
||||
App: config.AppConfig{Env: "production"},
|
||||
DataPlane: config.DataPlaneConfig{
|
||||
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
|
||||
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
|
||||
DirectWorkerJSONRuntime: true,
|
||||
DirectWorkerTLSTrustMode: "smoke_insecure",
|
||||
},
|
||||
},
|
||||
}
|
||||
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
|
||||
if len(candidates) != 1 {
|
||||
t.Fatalf("expected fallback-only candidates in production with smoke TLS, got %d", len(candidates))
|
||||
}
|
||||
if candidates[0].Type != sessioncontracts.DataPlaneCandidateBackendGateway {
|
||||
t.Fatalf("production must not advertise smoke-only direct candidate: %#v", candidates)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataPlaneDirectCandidateAdvertisedInProductionWhenTrusted(t *testing.T) {
|
||||
service := &Service{
|
||||
cfg: module.Config{
|
||||
App: config.AppConfig{Env: "production"},
|
||||
DataPlane: config.DataPlaneConfig{
|
||||
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
|
||||
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
|
||||
DirectWorkerJSONRuntime: true,
|
||||
DirectWorkerTLSTrustMode: "public_ca",
|
||||
},
|
||||
},
|
||||
}
|
||||
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
|
||||
if len(candidates) != 2 {
|
||||
t.Fatalf("expected trusted direct and fallback candidates, got %d", len(candidates))
|
||||
}
|
||||
if candidates[0].Metadata["production_trusted"] != true || candidates[0].Metadata["tls_trust_mode"] != "public_ca" {
|
||||
t.Fatalf("trusted production direct candidate metadata mismatch: %#v", candidates[0].Metadata)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataPlaneCandidatesFallbackOnlyWhenDirectTemplateMissing(t *testing.T) {
|
||||
service := &Service{
|
||||
cfg: module.Config{
|
||||
DataPlane: config.DataPlaneConfig{
|
||||
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
|
||||
},
|
||||
},
|
||||
}
|
||||
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
|
||||
if len(candidates) != 1 {
|
||||
t.Fatalf("expected fallback-only candidate list, got %d", len(candidates))
|
||||
}
|
||||
if candidates[0].Type != sessioncontracts.DataPlaneCandidateBackendGateway {
|
||||
t.Fatalf("unexpected candidate type: %s", candidates[0].Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataPlaneAllowedChannelsRespectRuntimePolicy(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
policy map[string]any
|
||||
expected []string
|
||||
blocked []string
|
||||
}{
|
||||
{
|
||||
name: "disabled policies expose only control input render telemetry",
|
||||
policy: map[string]any{"clipboard_mode": "disabled", "file_transfer_mode": "disabled"},
|
||||
expected: []string{sessioncontracts.DataPlaneChannelControl, sessioncontracts.DataPlaneChannelInput, sessioncontracts.DataPlaneChannelRender, sessioncontracts.DataPlaneChannelTelemetry},
|
||||
blocked: []string{sessioncontracts.DataPlaneChannelClipboard, sessioncontracts.DataPlaneChannelFileUpload},
|
||||
},
|
||||
{
|
||||
name: "clipboard policy adds clipboard channel",
|
||||
policy: map[string]any{"clipboard_mode": "server_to_client", "file_transfer_mode": "disabled"},
|
||||
expected: []string{sessioncontracts.DataPlaneChannelClipboard},
|
||||
blocked: []string{sessioncontracts.DataPlaneChannelFileUpload},
|
||||
},
|
||||
{
|
||||
name: "client upload policy adds file upload channel",
|
||||
policy: map[string]any{"clipboard_mode": "disabled", "file_transfer_mode": "client_to_server"},
|
||||
expected: []string{sessioncontracts.DataPlaneChannelFileUpload},
|
||||
blocked: []string{sessioncontracts.DataPlaneChannelClipboard},
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session := RemoteSession{Metadata: mustJSON(t, map[string]any{"policy": tc.policy})}
|
||||
channels := dataPlaneAllowedChannelsFromSession(session)
|
||||
for _, channel := range tc.expected {
|
||||
if !slices.Contains(channels, channel) {
|
||||
t.Fatalf("expected channel %q in %v", channel, channels)
|
||||
}
|
||||
}
|
||||
for _, channel := range tc.blocked {
|
||||
if slices.Contains(channels, channel) {
|
||||
t.Fatalf("did not expect channel %q in %v", channel, channels)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustJSON(t *testing.T, value any) []byte {
|
||||
t.Helper()
|
||||
payload, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal test metadata: %v", err)
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func testRS256Key(t *testing.T) (string, *rsa.PublicKey) {
|
||||
t.Helper()
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("generate RSA key: %v", err)
|
||||
}
|
||||
encoded := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||
})
|
||||
return string(encoded), &privateKey.PublicKey
|
||||
}
|
||||
|
||||
func assertEqual(t *testing.T, got, want, name string) {
|
||||
t.Helper()
|
||||
if got != want {
|
||||
t.Fatalf("unexpected %s: got %q want %q", name, got, want)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package sessionbroker
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrSessionNotFound = errors.New("remote session not found")
|
||||
ErrAttachmentNotFound = errors.New("session attachment not found")
|
||||
ErrActiveControllerPresent = errors.New("active controller already present")
|
||||
ErrTakeoverNotAllowed = errors.New("takeover not allowed")
|
||||
ErrTrustedDeviceRequired = errors.New("trusted device required")
|
||||
ErrAccessDenied = errors.New("access denied")
|
||||
ErrSessionNotAttachable = errors.New("session is not attachable")
|
||||
ErrSessionNotTerminable = errors.New("session is not terminable")
|
||||
ErrAttachTokenInvalid = errors.New("attach token invalid or expired")
|
||||
)
|
||||
@@ -0,0 +1,65 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
|
||||
)
|
||||
|
||||
type LiveStateStore interface {
|
||||
UpsertSession(ctx context.Context, state LiveSessionState) error
|
||||
GetSession(ctx context.Context, sessionID string) (*LiveSessionState, error)
|
||||
DeleteSession(ctx context.Context, sessionID string) error
|
||||
BindController(ctx context.Context, binding sessioncontracts.ControllerBinding, ttl time.Duration) error
|
||||
GetControllerBinding(ctx context.Context, sessionID string) (*sessioncontracts.ControllerBinding, error)
|
||||
ClearControllerBinding(ctx context.Context, sessionID string) error
|
||||
StoreAttachToken(ctx context.Context, claims sessioncontracts.AttachTokenClaims, ttl time.Duration) error
|
||||
ConsumeAttachToken(ctx context.Context, token string) (*sessioncontracts.AttachTokenClaims, error)
|
||||
TouchAttachmentHeartbeat(ctx context.Context, sessionID, attachmentID string, ttl time.Duration) error
|
||||
UpdateWorkerRoute(ctx context.Context, route WorkerRoute, ttl time.Duration) error
|
||||
GetWorkerRoute(ctx context.Context, sessionID string) (*WorkerRoute, error)
|
||||
DeleteWorkerRoute(ctx context.Context, sessionID string) error
|
||||
}
|
||||
|
||||
type LiveSessionState struct {
|
||||
SessionID string `json:"session_id"`
|
||||
ResourceID string `json:"resource_id"`
|
||||
WorkerID string `json:"worker_id"`
|
||||
State sessioncontracts.State `json:"state"`
|
||||
ControllerID string `json:"controller_id"`
|
||||
AttachmentID string `json:"attachment_id"`
|
||||
TakeoverVersion int `json:"takeover_version"`
|
||||
RenderQualityProfile string `json:"render_quality_profile,omitempty"`
|
||||
RenderState string `json:"render_state,omitempty"`
|
||||
RenderWidth int `json:"render_width,omitempty"`
|
||||
RenderHeight int `json:"render_height,omitempty"`
|
||||
RenderFrameSequence int64 `json:"render_frame_sequence,omitempty"`
|
||||
RenderFrameFormat string `json:"render_frame_format,omitempty"`
|
||||
RenderFrameData string `json:"render_frame_data,omitempty"`
|
||||
LastInputCorrelationID string `json:"last_input_correlation_id,omitempty"`
|
||||
WorkerFrameCapturedAt string `json:"worker_frame_captured_at,omitempty"`
|
||||
CursorX int `json:"cursor_x,omitempty"`
|
||||
CursorY int `json:"cursor_y,omitempty"`
|
||||
CursorVisible bool `json:"cursor_visible,omitempty"`
|
||||
DirtyRectangles int `json:"dirty_rectangles,omitempty"`
|
||||
LastRenderAt *time.Time `json:"last_render_at,omitempty"`
|
||||
ClipboardSequence int64 `json:"clipboard_sequence,omitempty"`
|
||||
ClipboardText string `json:"clipboard_text,omitempty"`
|
||||
ClipboardOrigin string `json:"clipboard_origin,omitempty"`
|
||||
ClipboardContentHash string `json:"clipboard_content_hash,omitempty"`
|
||||
ClipboardUpdatedAt *time.Time `json:"clipboard_updated_at,omitempty"`
|
||||
FileDownloadSequence int64 `json:"file_download_sequence,omitempty"`
|
||||
FileDownloadType string `json:"file_download_type,omitempty"`
|
||||
FileDownloadPayload map[string]any `json:"file_download_payload,omitempty"`
|
||||
FileDownloadUpdatedAt *time.Time `json:"file_download_updated_at,omitempty"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type WorkerRoute struct {
|
||||
SessionID string `json:"session_id"`
|
||||
WorkerID string `json:"worker_id"`
|
||||
LeaseID string `json:"lease_id"`
|
||||
ControlStream string `json:"control_stream"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
|
||||
)
|
||||
|
||||
type AttachmentRole string
|
||||
|
||||
const (
|
||||
AttachmentRoleController AttachmentRole = "controller"
|
||||
)
|
||||
|
||||
type AttachmentState string
|
||||
|
||||
const (
|
||||
AttachmentStateAttaching AttachmentState = "attaching"
|
||||
AttachmentStateActive AttachmentState = "active"
|
||||
AttachmentStateDetached AttachmentState = "detached"
|
||||
AttachmentStateSuperseded AttachmentState = "superseded"
|
||||
AttachmentStateRevoked AttachmentState = "revoked"
|
||||
AttachmentStateClosed AttachmentState = "closed"
|
||||
)
|
||||
|
||||
type ResourceTakeoverPolicy string
|
||||
|
||||
const (
|
||||
ResourceTakeoverPolicyTrustedDevice ResourceTakeoverPolicy = "trusted_device"
|
||||
ResourceTakeoverPolicySameUser ResourceTakeoverPolicy = "same_user"
|
||||
ResourceTakeoverPolicyAdminOnly ResourceTakeoverPolicy = "admin_only"
|
||||
)
|
||||
|
||||
type ResourceClipboardMode string
|
||||
|
||||
const (
|
||||
ResourceClipboardModeDisabled ResourceClipboardMode = "disabled"
|
||||
ResourceClipboardModeClientToServer ResourceClipboardMode = "client_to_server"
|
||||
ResourceClipboardModeServerToClient ResourceClipboardMode = "server_to_client"
|
||||
ResourceClipboardModeBidirectional ResourceClipboardMode = "bidirectional"
|
||||
)
|
||||
|
||||
type ResourceFileTransferMode string
|
||||
|
||||
const (
|
||||
ResourceFileTransferModeDisabled ResourceFileTransferMode = "disabled"
|
||||
ResourceFileTransferModeClientToServer ResourceFileTransferMode = "client_to_server"
|
||||
ResourceFileTransferModeServerToClient ResourceFileTransferMode = "server_to_client"
|
||||
ResourceFileTransferModeBidirectional ResourceFileTransferMode = "bidirectional"
|
||||
)
|
||||
|
||||
type RemoteSession struct {
|
||||
ID string
|
||||
OrganizationID string
|
||||
ResourceID string
|
||||
Protocol string
|
||||
State sessioncontracts.State
|
||||
WorkerID string
|
||||
ControllerUserID string
|
||||
DetachDeadlineAt *time.Time
|
||||
LastHeartbeatAt *time.Time
|
||||
TakeoverVersion int
|
||||
RenderQualityProfile string
|
||||
Metadata []byte
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type SessionAttachment struct {
|
||||
ID string
|
||||
RemoteSessionID string
|
||||
UserID string
|
||||
DeviceID string
|
||||
Role AttachmentRole
|
||||
State AttachmentState
|
||||
SupersededBy *string
|
||||
TakeoverOf *string
|
||||
AttachedAt *time.Time
|
||||
DetachedAt *time.Time
|
||||
LastInputAt *time.Time
|
||||
Metadata []byte
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type ResourcePolicy struct {
|
||||
ResourceID string
|
||||
MaxConcurrentSessions int
|
||||
TakeoverPolicy ResourceTakeoverPolicy
|
||||
RequireTrustedDevice bool
|
||||
DetachGracePeriod time.Duration
|
||||
ClipboardEnabled bool
|
||||
ClipboardMode ResourceClipboardMode
|
||||
FileTransferEnabled bool
|
||||
FileTransferMode ResourceFileTransferMode
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type ResourceRuntimeSpec struct {
|
||||
ID string
|
||||
OrganizationID string
|
||||
Name string
|
||||
Address string
|
||||
Protocol string
|
||||
SecretRef *string
|
||||
CertificateVerificationMode string
|
||||
Metadata []byte
|
||||
}
|
||||
|
||||
type AuditEvent struct {
|
||||
ID string
|
||||
ActorUserID *string
|
||||
ActorDeviceID *string
|
||||
EventType string
|
||||
TargetType string
|
||||
TargetID string
|
||||
RemoteSessionID *string
|
||||
Payload []byte
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
AuditEventSessionStarted = "session_started"
|
||||
AuditEventSessionAttached = "session_attached"
|
||||
AuditEventSessionDetached = "session_detached"
|
||||
AuditEventSessionTakenOver = "session_taken_over"
|
||||
AuditEventSessionTerminated = "session_terminated"
|
||||
AuditEventSessionFailed = "session_failed"
|
||||
AuditEventSecretAccessed = "resource_secret_accessed"
|
||||
AuditEventSecretAccessDenied = "resource_secret_access_denied"
|
||||
)
|
||||
@@ -0,0 +1,164 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
|
||||
)
|
||||
|
||||
type Module struct {
|
||||
service *Service
|
||||
}
|
||||
|
||||
func NewModule(service *Service) *Module {
|
||||
return &Module{service: service}
|
||||
}
|
||||
|
||||
func (m *Module) Name() string {
|
||||
return "session-broker"
|
||||
}
|
||||
|
||||
func (m *Module) Service() *Service {
|
||||
return m.service
|
||||
}
|
||||
|
||||
func (m *Module) RegisterRoutes(router chi.Router) {
|
||||
router.Route("/sessions", func(r chi.Router) {
|
||||
r.Get("/", m.listSessions)
|
||||
r.Post("/", m.startSession)
|
||||
r.Post("/{sessionID}/attach", m.attachSession)
|
||||
r.Post("/{sessionID}/detach", m.detachSession)
|
||||
r.Post("/{sessionID}/takeover", m.takeoverSession)
|
||||
r.Post("/{sessionID}/terminate", m.terminateSession)
|
||||
r.Post("/{sessionID}/fail", m.markFailed)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) listSessions(w http.ResponseWriter, r *http.Request) {
|
||||
userID := r.URL.Query().Get("user_id")
|
||||
if userID == "" {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
|
||||
return
|
||||
}
|
||||
sessions, err := m.service.ListSessions(r.Context(), userID)
|
||||
if err != nil {
|
||||
status, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, status, message)
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{"sessions": sessions})
|
||||
}
|
||||
|
||||
func (m *Module) startSession(w http.ResponseWriter, r *http.Request) {
|
||||
var cmd StartRemoteSessionCommand
|
||||
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid start session payload")
|
||||
return
|
||||
}
|
||||
result, err := m.service.StartRemoteSession(r.Context(), cmd)
|
||||
if err != nil {
|
||||
status, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, status, message)
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusCreated, result)
|
||||
}
|
||||
|
||||
func (m *Module) attachSession(w http.ResponseWriter, r *http.Request) {
|
||||
var cmd AttachToSessionCommand
|
||||
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid attach session payload")
|
||||
return
|
||||
}
|
||||
cmd.SessionID = chi.URLParam(r, "sessionID")
|
||||
result, err := m.service.AttachToSession(r.Context(), cmd)
|
||||
if err != nil {
|
||||
status, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, status, message)
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (m *Module) detachSession(w http.ResponseWriter, r *http.Request) {
|
||||
var cmd DetachFromSessionCommand
|
||||
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid detach session payload")
|
||||
return
|
||||
}
|
||||
cmd.SessionID = chi.URLParam(r, "sessionID")
|
||||
result, err := m.service.DetachFromSession(r.Context(), cmd)
|
||||
if err != nil {
|
||||
status, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, status, message)
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusAccepted, result)
|
||||
}
|
||||
|
||||
func (m *Module) takeoverSession(w http.ResponseWriter, r *http.Request) {
|
||||
var cmd TakeoverSessionCommand
|
||||
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid takeover session payload")
|
||||
return
|
||||
}
|
||||
cmd.SessionID = chi.URLParam(r, "sessionID")
|
||||
result, err := m.service.TakeoverSession(r.Context(), cmd)
|
||||
if err != nil {
|
||||
status, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, status, message)
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (m *Module) terminateSession(w http.ResponseWriter, r *http.Request) {
|
||||
var cmd TerminateSessionCommand
|
||||
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid terminate session payload")
|
||||
return
|
||||
}
|
||||
cmd.SessionID = chi.URLParam(r, "sessionID")
|
||||
if err := m.service.TerminateSession(r.Context(), cmd); err != nil {
|
||||
status, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, status, message)
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
|
||||
"status": "terminated",
|
||||
"message": httpx.NewMessage(
|
||||
"session.terminated",
|
||||
"status.session.terminated",
|
||||
"Session terminated.",
|
||||
nil,
|
||||
"",
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) markFailed(w http.ResponseWriter, r *http.Request) {
|
||||
var cmd MarkSessionFailedCommand
|
||||
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid fail session payload")
|
||||
return
|
||||
}
|
||||
cmd.SessionID = chi.URLParam(r, "sessionID")
|
||||
if err := m.service.MarkSessionFailed(r.Context(), cmd); err != nil {
|
||||
status, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, status, message)
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
|
||||
"status": "failed",
|
||||
"message": httpx.NewMessage(
|
||||
"session.failed",
|
||||
"status.session.failed",
|
||||
"Session marked as failed.",
|
||||
nil,
|
||||
"",
|
||||
),
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
|
||||
)
|
||||
|
||||
type WorkerOrchestrator interface {
|
||||
Reserve(ctx context.Context, request workercontracts.AttachRequest) (*workercontracts.WorkerLease, error)
|
||||
GetSessionLease(ctx context.Context, sessionID string) (*workercontracts.WorkerLease, error)
|
||||
ReleaseSessionLease(ctx context.Context, sessionID string) error
|
||||
PrepareAttachment(ctx context.Context, session RemoteSession, attachment SessionAttachment, runtimeMetadata map[string]any) error
|
||||
NotifyDetachment(ctx context.Context, session RemoteSession, attachment SessionAttachment) error
|
||||
TerminateRemoteSession(ctx context.Context, sessionID, attachmentID string) error
|
||||
ValidateSessionRuntime(ctx context.Context, sessionID, workerID string) (bool, string, error)
|
||||
}
|
||||
@@ -0,0 +1,607 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/authority"
|
||||
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
|
||||
)
|
||||
|
||||
type postgresStore struct {
|
||||
db postgresplatform.DBTX
|
||||
authority *authority.Verifier
|
||||
}
|
||||
|
||||
type PostgresTransactor struct {
|
||||
pool *pgxpool.Pool
|
||||
authority *authority.Verifier
|
||||
}
|
||||
|
||||
func NewPostgresStore(pool *pgxpool.Pool, verifiers ...*authority.Verifier) Store {
|
||||
var authorityVerifier *authority.Verifier
|
||||
if len(verifiers) > 0 {
|
||||
authorityVerifier = verifiers[0]
|
||||
}
|
||||
return &postgresStore{db: pool, authority: authorityVerifier}
|
||||
}
|
||||
|
||||
func NewPostgresTransactor(pool *pgxpool.Pool, verifiers ...*authority.Verifier) *PostgresTransactor {
|
||||
var authorityVerifier *authority.Verifier
|
||||
if len(verifiers) > 0 {
|
||||
authorityVerifier = verifiers[0]
|
||||
}
|
||||
return &PostgresTransactor{pool: pool, authority: authorityVerifier}
|
||||
}
|
||||
|
||||
func (t *PostgresTransactor) WithinTransaction(ctx context.Context, fn func(store Store) error) error {
|
||||
return postgresplatform.WithTransaction(ctx, t.pool, func(tx pgx.Tx) error {
|
||||
return fn(&postgresStore{db: tx, authority: t.authority})
|
||||
})
|
||||
}
|
||||
|
||||
func (s *postgresStore) RemoteSessions() RemoteSessionRepository {
|
||||
return &postgresRemoteSessionRepository{db: s.db}
|
||||
}
|
||||
|
||||
func (s *postgresStore) SessionAttachments() SessionAttachmentRepository {
|
||||
return &postgresSessionAttachmentRepository{db: s.db}
|
||||
}
|
||||
|
||||
func (s *postgresStore) ResourcePolicies() ResourcePolicyRepository {
|
||||
return &postgresResourcePolicyRepository{db: s.db}
|
||||
}
|
||||
|
||||
func (s *postgresStore) ResourceRuntime() ResourceRuntimeRepository {
|
||||
return &postgresResourceRuntimeRepository{db: s.db}
|
||||
}
|
||||
|
||||
func (s *postgresStore) AuditEvents() AuditEventRepository {
|
||||
return &postgresAuditEventRepository{db: s.db}
|
||||
}
|
||||
|
||||
func (s *postgresStore) Access() AccessRepository {
|
||||
return &postgresAccessRepository{db: s.db, authority: s.authority}
|
||||
}
|
||||
|
||||
type postgresRemoteSessionRepository struct {
|
||||
db postgresplatform.DBTX
|
||||
}
|
||||
|
||||
type postgresSessionAttachmentRepository struct {
|
||||
db postgresplatform.DBTX
|
||||
}
|
||||
|
||||
type postgresResourcePolicyRepository struct {
|
||||
db postgresplatform.DBTX
|
||||
}
|
||||
|
||||
type postgresResourceRuntimeRepository struct {
|
||||
db postgresplatform.DBTX
|
||||
}
|
||||
|
||||
type postgresAuditEventRepository struct {
|
||||
db postgresplatform.DBTX
|
||||
}
|
||||
|
||||
type postgresAccessRepository struct {
|
||||
db postgresplatform.DBTX
|
||||
authority *authority.Verifier
|
||||
}
|
||||
|
||||
func (r *postgresRemoteSessionRepository) Create(ctx context.Context, session RemoteSession) error {
|
||||
const query = `
|
||||
INSERT INTO remote_sessions (
|
||||
id, organization_id, resource_id, protocol, state, worker_id, controller_user_id, detach_deadline_at,
|
||||
last_heartbeat_at, takeover_version, metadata, created_at, updated_at
|
||||
) VALUES (
|
||||
$1::uuid, $2::uuid, $3::uuid, $4, $5, NULLIF($6, ''), $7::uuid, $8, $9, $10, $11::jsonb, $12, $13
|
||||
)
|
||||
`
|
||||
if _, err := r.db.Exec(ctx, query,
|
||||
session.ID,
|
||||
session.OrganizationID,
|
||||
session.ResourceID,
|
||||
session.Protocol,
|
||||
session.State,
|
||||
session.WorkerID,
|
||||
session.ControllerUserID,
|
||||
session.DetachDeadlineAt,
|
||||
session.LastHeartbeatAt,
|
||||
session.TakeoverVersion,
|
||||
jsonPayload(session.Metadata),
|
||||
session.CreatedAt,
|
||||
session.UpdatedAt,
|
||||
); err != nil {
|
||||
return fmt.Errorf("create remote session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *postgresRemoteSessionRepository) GetByID(ctx context.Context, sessionID string) (*RemoteSession, error) {
|
||||
return r.getByID(ctx, sessionID, "")
|
||||
}
|
||||
|
||||
func (r *postgresRemoteSessionRepository) GetByIDForUpdate(ctx context.Context, sessionID string) (*RemoteSession, error) {
|
||||
return r.getByID(ctx, sessionID, " FOR UPDATE")
|
||||
}
|
||||
|
||||
func (r *postgresRemoteSessionRepository) getByID(ctx context.Context, sessionID string, suffix string) (*RemoteSession, error) {
|
||||
query := `
|
||||
SELECT id::text, organization_id::text, resource_id::text, protocol, state, COALESCE(worker_id, ''), controller_user_id::text,
|
||||
detach_deadline_at, last_heartbeat_at, takeover_version, metadata, created_at, updated_at
|
||||
FROM remote_sessions
|
||||
WHERE id = $1::uuid` + suffix
|
||||
remoteSession, err := scanRemoteSession(r.db.QueryRow(ctx, query, sessionID))
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return remoteSession, err
|
||||
}
|
||||
|
||||
func (r *postgresRemoteSessionRepository) ListByController(ctx context.Context, userID string) ([]RemoteSession, error) {
|
||||
const query = `
|
||||
SELECT id::text, organization_id::text, resource_id::text, protocol, state, COALESCE(worker_id, ''), controller_user_id::text,
|
||||
detach_deadline_at, last_heartbeat_at, takeover_version, metadata, created_at, updated_at
|
||||
FROM remote_sessions
|
||||
WHERE controller_user_id = $1::uuid
|
||||
ORDER BY updated_at DESC
|
||||
`
|
||||
rows, err := r.db.Query(ctx, query, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list remote sessions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sessions []RemoteSession
|
||||
for rows.Next() {
|
||||
item, err := scanRemoteSession(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessions = append(sessions, *item)
|
||||
}
|
||||
return sessions, rows.Err()
|
||||
}
|
||||
|
||||
func (r *postgresRemoteSessionRepository) CountLiveByResource(ctx context.Context, resourceID string) (int, error) {
|
||||
const query = `
|
||||
SELECT COUNT(*)
|
||||
FROM remote_sessions
|
||||
WHERE resource_id = $1::uuid AND state IN ('starting', 'active', 'detached', 'reconnecting')
|
||||
`
|
||||
var count int
|
||||
if err := r.db.QueryRow(ctx, query, resourceID).Scan(&count); err != nil {
|
||||
return 0, fmt.Errorf("count live remote sessions: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *postgresRemoteSessionRepository) ListDetachedExpired(ctx context.Context, before time.Time, limit int) ([]RemoteSession, error) {
|
||||
const query = `
|
||||
SELECT id::text, organization_id::text, resource_id::text, protocol, state, COALESCE(worker_id, ''), controller_user_id::text,
|
||||
detach_deadline_at, last_heartbeat_at, takeover_version, metadata, created_at, updated_at
|
||||
FROM remote_sessions
|
||||
WHERE state = 'detached' AND detach_deadline_at IS NOT NULL AND detach_deadline_at <= $1
|
||||
ORDER BY detach_deadline_at ASC
|
||||
LIMIT $2
|
||||
`
|
||||
rows, err := r.db.Query(ctx, query, before, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list detached expired sessions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sessions []RemoteSession
|
||||
for rows.Next() {
|
||||
item, err := scanRemoteSession(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessions = append(sessions, *item)
|
||||
}
|
||||
return sessions, rows.Err()
|
||||
}
|
||||
|
||||
func (r *postgresRemoteSessionRepository) UpdateState(ctx context.Context, params UpdateRemoteSessionStateParams) error {
|
||||
const query = `
|
||||
UPDATE remote_sessions
|
||||
SET state = $2,
|
||||
worker_id = NULLIF($3, ''),
|
||||
detach_deadline_at = $4,
|
||||
last_heartbeat_at = $5,
|
||||
takeover_version = $6,
|
||||
updated_at = $7
|
||||
WHERE id = $1::uuid
|
||||
`
|
||||
if _, err := r.db.Exec(ctx, query,
|
||||
params.RemoteSessionID,
|
||||
params.State,
|
||||
params.WorkerID,
|
||||
params.DetachDeadlineAt,
|
||||
params.LastHeartbeatAt,
|
||||
params.TakeoverVersion,
|
||||
params.UpdatedAt,
|
||||
); err != nil {
|
||||
return fmt.Errorf("update remote session state: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *postgresSessionAttachmentRepository) Create(ctx context.Context, attachment SessionAttachment) error {
|
||||
const query = `
|
||||
INSERT INTO session_attachments (
|
||||
id, remote_session_id, user_id, device_id, role, state, superseded_by,
|
||||
takeover_of, attached_at, detached_at, last_input_at, metadata, created_at, updated_at
|
||||
) VALUES (
|
||||
$1::uuid, $2::uuid, $3::uuid, $4::uuid, $5, $6, NULLIF($7, '')::uuid,
|
||||
NULLIF($8, '')::uuid, $9, $10, $11, $12::jsonb, $13, $14
|
||||
)
|
||||
`
|
||||
if _, err := r.db.Exec(ctx, query,
|
||||
attachment.ID,
|
||||
attachment.RemoteSessionID,
|
||||
attachment.UserID,
|
||||
attachment.DeviceID,
|
||||
attachment.Role,
|
||||
attachment.State,
|
||||
stringValue(attachment.SupersededBy),
|
||||
stringValue(attachment.TakeoverOf),
|
||||
attachment.AttachedAt,
|
||||
attachment.DetachedAt,
|
||||
attachment.LastInputAt,
|
||||
jsonPayload(attachment.Metadata),
|
||||
attachment.CreatedAt,
|
||||
attachment.UpdatedAt,
|
||||
); err != nil {
|
||||
return fmt.Errorf("create session attachment: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *postgresSessionAttachmentRepository) GetByID(ctx context.Context, attachmentID string) (*SessionAttachment, error) {
|
||||
return r.getByID(ctx, attachmentID, "")
|
||||
}
|
||||
|
||||
func (r *postgresSessionAttachmentRepository) GetByIDForUpdate(ctx context.Context, attachmentID string) (*SessionAttachment, error) {
|
||||
return r.getByID(ctx, attachmentID, " FOR UPDATE")
|
||||
}
|
||||
|
||||
func (r *postgresSessionAttachmentRepository) getByID(ctx context.Context, attachmentID string, suffix string) (*SessionAttachment, error) {
|
||||
query := `
|
||||
SELECT id::text, remote_session_id::text, user_id::text, device_id::text, role, state,
|
||||
superseded_by::text, takeover_of::text, attached_at, detached_at, last_input_at, metadata, created_at, updated_at
|
||||
FROM session_attachments
|
||||
WHERE id = $1::uuid` + suffix
|
||||
attachment, err := scanSessionAttachment(r.db.QueryRow(ctx, query, attachmentID))
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return attachment, err
|
||||
}
|
||||
|
||||
func (r *postgresSessionAttachmentRepository) ListByRemoteSession(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error) {
|
||||
return r.listByRemoteSession(ctx, remoteSessionID, "")
|
||||
}
|
||||
|
||||
func (r *postgresSessionAttachmentRepository) ListActiveByRemoteSessionForUpdate(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error) {
|
||||
return r.listByRemoteSession(ctx, remoteSessionID, " AND state IN ('attaching', 'active', 'reconnecting') FOR UPDATE")
|
||||
}
|
||||
|
||||
func (r *postgresSessionAttachmentRepository) listByRemoteSession(ctx context.Context, remoteSessionID string, suffix string) ([]SessionAttachment, error) {
|
||||
query := `
|
||||
SELECT id::text, remote_session_id::text, user_id::text, device_id::text, role, state,
|
||||
superseded_by::text, takeover_of::text, attached_at, detached_at, last_input_at, metadata, created_at, updated_at
|
||||
FROM session_attachments
|
||||
WHERE remote_session_id = $1::uuid` + suffix
|
||||
rows, err := r.db.Query(ctx, query, remoteSessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list session attachments: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var attachments []SessionAttachment
|
||||
for rows.Next() {
|
||||
item, err := scanSessionAttachment(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attachments = append(attachments, *item)
|
||||
}
|
||||
return attachments, rows.Err()
|
||||
}
|
||||
|
||||
func (r *postgresSessionAttachmentRepository) UpdateState(ctx context.Context, params UpdateSessionAttachmentStateParams) error {
|
||||
const query = `
|
||||
UPDATE session_attachments
|
||||
SET state = $2,
|
||||
detached_at = $3,
|
||||
last_input_at = $4,
|
||||
updated_at = $5
|
||||
WHERE id = $1::uuid
|
||||
`
|
||||
if _, err := r.db.Exec(ctx, query,
|
||||
params.AttachmentID,
|
||||
params.State,
|
||||
params.DetachedAt,
|
||||
params.LastInputAt,
|
||||
params.UpdatedAt,
|
||||
); err != nil {
|
||||
return fmt.Errorf("update session attachment state: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *postgresSessionAttachmentRepository) Supersede(ctx context.Context, params SupersedeAttachmentParams) error {
|
||||
const query = `
|
||||
UPDATE session_attachments
|
||||
SET state = 'superseded',
|
||||
superseded_by = $2::uuid,
|
||||
detached_at = $3,
|
||||
updated_at = $4
|
||||
WHERE id = $1::uuid
|
||||
`
|
||||
if _, err := r.db.Exec(ctx, query,
|
||||
params.PreviousAttachmentID,
|
||||
params.NextAttachmentID,
|
||||
params.DetachedAt,
|
||||
params.UpdatedAt,
|
||||
); err != nil {
|
||||
return fmt.Errorf("supersede attachment: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *postgresResourcePolicyRepository) GetByResourceID(ctx context.Context, resourceID string) (*ResourcePolicy, error) {
|
||||
const query = `
|
||||
SELECT resource_id::text, max_concurrent_sessions, takeover_policy, require_trusted_device,
|
||||
detach_grace_period_seconds, clipboard_enabled, clipboard_mode, file_transfer_enabled,
|
||||
COALESCE(file_transfer_mode, 'disabled'), created_at, updated_at
|
||||
FROM resource_policies
|
||||
WHERE resource_id = $1::uuid
|
||||
`
|
||||
policy, err := scanResourcePolicy(r.db.QueryRow(ctx, query, resourceID))
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return policy, err
|
||||
}
|
||||
|
||||
func (r *postgresResourcePolicyRepository) Upsert(ctx context.Context, policy ResourcePolicy) error {
|
||||
const query = `
|
||||
INSERT INTO resource_policies (
|
||||
resource_id, max_concurrent_sessions, takeover_policy, require_trusted_device,
|
||||
detach_grace_period_seconds, clipboard_enabled, clipboard_mode, file_transfer_enabled, file_transfer_mode, created_at, updated_at
|
||||
) VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
ON CONFLICT (resource_id) DO UPDATE SET
|
||||
max_concurrent_sessions = EXCLUDED.max_concurrent_sessions,
|
||||
takeover_policy = EXCLUDED.takeover_policy,
|
||||
require_trusted_device = EXCLUDED.require_trusted_device,
|
||||
detach_grace_period_seconds = EXCLUDED.detach_grace_period_seconds,
|
||||
clipboard_enabled = EXCLUDED.clipboard_enabled,
|
||||
clipboard_mode = EXCLUDED.clipboard_mode,
|
||||
file_transfer_enabled = EXCLUDED.file_transfer_enabled,
|
||||
file_transfer_mode = EXCLUDED.file_transfer_mode,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
`
|
||||
clipboardMode := normalizeClipboardMode(policy.ClipboardMode)
|
||||
fileTransferMode := normalizeFileTransferMode(policy.FileTransferMode)
|
||||
if _, err := r.db.Exec(ctx, query,
|
||||
policy.ResourceID,
|
||||
policy.MaxConcurrentSessions,
|
||||
policy.TakeoverPolicy,
|
||||
policy.RequireTrustedDevice,
|
||||
int(policy.DetachGracePeriod.Seconds()),
|
||||
clipboardMode != ResourceClipboardModeDisabled,
|
||||
clipboardMode,
|
||||
fileTransferAllowsClientToServer(fileTransferMode),
|
||||
fileTransferMode,
|
||||
policy.CreatedAt,
|
||||
policy.UpdatedAt,
|
||||
); err != nil {
|
||||
return fmt.Errorf("upsert resource policy: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *postgresResourceRuntimeRepository) GetByID(ctx context.Context, resourceID string) (*ResourceRuntimeSpec, error) {
|
||||
const query = `
|
||||
SELECT id::text, organization_id::text, name, address, protocol, secret_ref, certificate_verification_mode, metadata
|
||||
FROM resources
|
||||
WHERE id = $1::uuid
|
||||
`
|
||||
item := &ResourceRuntimeSpec{}
|
||||
var secretRef *string
|
||||
var metadata []byte
|
||||
if err := r.db.QueryRow(ctx, query, resourceID).Scan(
|
||||
&item.ID,
|
||||
&item.OrganizationID,
|
||||
&item.Name,
|
||||
&item.Address,
|
||||
&item.Protocol,
|
||||
&secretRef,
|
||||
&item.CertificateVerificationMode,
|
||||
&metadata,
|
||||
); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("get resource runtime spec: %w", err)
|
||||
}
|
||||
item.SecretRef = secretRef
|
||||
item.Metadata = metadata
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func (r *postgresAuditEventRepository) Create(ctx context.Context, event AuditEvent) error {
|
||||
const query = `
|
||||
INSERT INTO audit_events (
|
||||
id, actor_user_id, actor_device_id, event_type, target_type, target_id,
|
||||
remote_session_id, payload, created_at
|
||||
) VALUES (
|
||||
$1::uuid, NULLIF($2, '')::uuid, NULLIF($3, '')::uuid, $4, $5, $6,
|
||||
NULLIF($7, '')::uuid, $8::jsonb, $9
|
||||
)
|
||||
`
|
||||
if _, err := r.db.Exec(ctx, query,
|
||||
event.ID,
|
||||
stringValue(event.ActorUserID),
|
||||
stringValue(event.ActorDeviceID),
|
||||
event.EventType,
|
||||
event.TargetType,
|
||||
event.TargetID,
|
||||
stringValue(event.RemoteSessionID),
|
||||
jsonPayload(event.Payload),
|
||||
event.CreatedAt,
|
||||
); err != nil {
|
||||
return fmt.Errorf("create audit event: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *postgresAccessRepository) IsTrustedDevice(ctx context.Context, userID, deviceID string) (bool, error) {
|
||||
const query = `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM devices
|
||||
WHERE id = $1::uuid AND user_id = $2::uuid AND trust_status = 'trusted' AND revoked_at IS NULL
|
||||
)
|
||||
`
|
||||
var trusted bool
|
||||
if err := r.db.QueryRow(ctx, query, deviceID, userID).Scan(&trusted); err != nil {
|
||||
return false, fmt.Errorf("check trusted device: %w", err)
|
||||
}
|
||||
return trusted, nil
|
||||
}
|
||||
|
||||
func (r *postgresAccessRepository) GetPlatformRole(ctx context.Context, userID string) (string, error) {
|
||||
return authority.EffectivePlatformRole(ctx, r.db, r.authority, userID)
|
||||
}
|
||||
|
||||
func (r *postgresAccessRepository) GetOrganizationRole(ctx context.Context, organizationID, userID string) (string, bool, error) {
|
||||
const query = `
|
||||
SELECT role_id
|
||||
FROM organization_memberships
|
||||
WHERE organization_id = $1::uuid AND user_id = $2::uuid AND status = 'active'
|
||||
`
|
||||
var role string
|
||||
if err := r.db.QueryRow(ctx, query, organizationID, userID).Scan(&role); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", false, nil
|
||||
}
|
||||
return "", false, fmt.Errorf("get organization role: %w", err)
|
||||
}
|
||||
return role, true, nil
|
||||
}
|
||||
|
||||
type scanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func scanRemoteSession(row scanner) (*RemoteSession, error) {
|
||||
item := &RemoteSession{}
|
||||
var detachDeadlineAt, lastHeartbeatAt *time.Time
|
||||
var metadata []byte
|
||||
if err := row.Scan(
|
||||
&item.ID,
|
||||
&item.OrganizationID,
|
||||
&item.ResourceID,
|
||||
&item.Protocol,
|
||||
&item.State,
|
||||
&item.WorkerID,
|
||||
&item.ControllerUserID,
|
||||
&detachDeadlineAt,
|
||||
&lastHeartbeatAt,
|
||||
&item.TakeoverVersion,
|
||||
&metadata,
|
||||
&item.CreatedAt,
|
||||
&item.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan remote session: %w", err)
|
||||
}
|
||||
item.DetachDeadlineAt = detachDeadlineAt
|
||||
item.LastHeartbeatAt = lastHeartbeatAt
|
||||
item.Metadata = metadata
|
||||
item.RenderQualityProfile = renderQualityProfileFromSessionMetadata(metadata)
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func scanSessionAttachment(row scanner) (*SessionAttachment, error) {
|
||||
item := &SessionAttachment{}
|
||||
var supersededBy, takeoverOf *string
|
||||
var attachedAt, detachedAt, lastInputAt *time.Time
|
||||
var metadata []byte
|
||||
if err := row.Scan(
|
||||
&item.ID,
|
||||
&item.RemoteSessionID,
|
||||
&item.UserID,
|
||||
&item.DeviceID,
|
||||
&item.Role,
|
||||
&item.State,
|
||||
&supersededBy,
|
||||
&takeoverOf,
|
||||
&attachedAt,
|
||||
&detachedAt,
|
||||
&lastInputAt,
|
||||
&metadata,
|
||||
&item.CreatedAt,
|
||||
&item.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan session attachment: %w", err)
|
||||
}
|
||||
item.SupersededBy = supersededBy
|
||||
item.TakeoverOf = takeoverOf
|
||||
item.AttachedAt = attachedAt
|
||||
item.DetachedAt = detachedAt
|
||||
item.LastInputAt = lastInputAt
|
||||
item.Metadata = metadata
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func scanResourcePolicy(row scanner) (*ResourcePolicy, error) {
|
||||
item := &ResourcePolicy{}
|
||||
var detachGraceSeconds int
|
||||
if err := row.Scan(
|
||||
&item.ResourceID,
|
||||
&item.MaxConcurrentSessions,
|
||||
&item.TakeoverPolicy,
|
||||
&item.RequireTrustedDevice,
|
||||
&detachGraceSeconds,
|
||||
&item.ClipboardEnabled,
|
||||
&item.ClipboardMode,
|
||||
&item.FileTransferEnabled,
|
||||
&item.FileTransferMode,
|
||||
&item.CreatedAt,
|
||||
&item.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan resource policy: %w", err)
|
||||
}
|
||||
item.DetachGracePeriod = time.Duration(detachGraceSeconds) * time.Second
|
||||
item.ClipboardMode = normalizeClipboardMode(item.ClipboardMode)
|
||||
item.ClipboardEnabled = item.ClipboardMode != ResourceClipboardModeDisabled
|
||||
item.FileTransferMode = normalizeFileTransferMode(item.FileTransferMode)
|
||||
item.FileTransferEnabled = fileTransferAllowsClientToServer(item.FileTransferMode)
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func jsonPayload(payload []byte) []byte {
|
||||
if len(payload) == 0 {
|
||||
return []byte(`{}`)
|
||||
}
|
||||
if json.Valid(payload) {
|
||||
return payload
|
||||
}
|
||||
return []byte(`{}`)
|
||||
}
|
||||
|
||||
func stringValue(value *string) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
return *value
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
|
||||
)
|
||||
|
||||
type RedisLiveStateStore struct {
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
func NewRedisLiveStateStore(client *redis.Client) *RedisLiveStateStore {
|
||||
return &RedisLiveStateStore{client: client}
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) UpsertSession(ctx context.Context, state LiveSessionState) error {
|
||||
return s.setJSON(ctx, liveSessionKey(state.SessionID), state, 0)
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) GetSession(ctx context.Context, sessionID string) (*LiveSessionState, error) {
|
||||
var state LiveSessionState
|
||||
ok, err := s.getJSON(ctx, liveSessionKey(sessionID), &state)
|
||||
if err != nil || !ok {
|
||||
return nil, err
|
||||
}
|
||||
return &state, nil
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) DeleteSession(ctx context.Context, sessionID string) error {
|
||||
return s.client.Del(ctx, liveSessionKey(sessionID)).Err()
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) BindController(ctx context.Context, binding sessioncontracts.ControllerBinding, ttl time.Duration) error {
|
||||
return s.setJSON(ctx, controllerBindingKey(binding.SessionID), binding, ttl)
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) GetControllerBinding(ctx context.Context, sessionID string) (*sessioncontracts.ControllerBinding, error) {
|
||||
var binding sessioncontracts.ControllerBinding
|
||||
ok, err := s.getJSON(ctx, controllerBindingKey(sessionID), &binding)
|
||||
if err != nil || !ok {
|
||||
return nil, err
|
||||
}
|
||||
return &binding, nil
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) ClearControllerBinding(ctx context.Context, sessionID string) error {
|
||||
return s.client.Del(ctx, controllerBindingKey(sessionID)).Err()
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) StoreAttachToken(ctx context.Context, claims sessioncontracts.AttachTokenClaims, ttl time.Duration) error {
|
||||
return s.setJSON(ctx, attachTokenKey(claims.Token), claims, ttl)
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) ConsumeAttachToken(ctx context.Context, token string) (*sessioncontracts.AttachTokenClaims, error) {
|
||||
key := attachTokenKey(token)
|
||||
payload, err := s.client.GetDel(ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("consume attach token: %w", err)
|
||||
}
|
||||
var claims sessioncontracts.AttachTokenClaims
|
||||
if err := json.Unmarshal([]byte(payload), &claims); err != nil {
|
||||
return nil, fmt.Errorf("decode attach token: %w", err)
|
||||
}
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) TouchAttachmentHeartbeat(ctx context.Context, sessionID, attachmentID string, ttl time.Duration) error {
|
||||
return s.client.Set(ctx, attachmentHeartbeatKey(sessionID, attachmentID), time.Now().UTC().Format(time.RFC3339Nano), ttl).Err()
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) UpdateWorkerRoute(ctx context.Context, route WorkerRoute, ttl time.Duration) error {
|
||||
return s.setJSON(ctx, workerRouteKey(route.SessionID), route, ttl)
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) GetWorkerRoute(ctx context.Context, sessionID string) (*WorkerRoute, error) {
|
||||
var route WorkerRoute
|
||||
ok, err := s.getJSON(ctx, workerRouteKey(sessionID), &route)
|
||||
if err != nil || !ok {
|
||||
return nil, err
|
||||
}
|
||||
return &route, nil
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) DeleteWorkerRoute(ctx context.Context, sessionID string) error {
|
||||
return s.client.Del(ctx, workerRouteKey(sessionID)).Err()
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) setJSON(ctx context.Context, key string, value any, ttl time.Duration) error {
|
||||
payload, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode redis payload: %w", err)
|
||||
}
|
||||
if err := s.client.Set(ctx, key, payload, ttl).Err(); err != nil {
|
||||
return fmt.Errorf("set redis key %s: %w", key, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) getJSON(ctx context.Context, key string, dest any) (bool, error) {
|
||||
payload, err := s.client.Get(ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get redis key %s: %w", key, err)
|
||||
}
|
||||
if err := json.Unmarshal([]byte(payload), dest); err != nil {
|
||||
return false, fmt.Errorf("decode redis key %s: %w", key, err)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func liveSessionKey(sessionID string) string {
|
||||
return "live:session:" + sessionID
|
||||
}
|
||||
|
||||
func controllerBindingKey(sessionID string) string {
|
||||
return "live:session:" + sessionID + ":controller"
|
||||
}
|
||||
|
||||
func attachTokenKey(token string) string {
|
||||
return "live:attach:" + token
|
||||
}
|
||||
|
||||
func attachmentHeartbeatKey(sessionID, attachmentID string) string {
|
||||
return "live:session:" + sessionID + ":attachment:" + attachmentID + ":heartbeat"
|
||||
}
|
||||
|
||||
func workerRouteKey(sessionID string) string {
|
||||
return "live:session:" + sessionID + ":worker-route"
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
|
||||
)
|
||||
|
||||
type RemoteSessionRepository interface {
|
||||
Create(ctx context.Context, session RemoteSession) error
|
||||
GetByID(ctx context.Context, sessionID string) (*RemoteSession, error)
|
||||
GetByIDForUpdate(ctx context.Context, sessionID string) (*RemoteSession, error)
|
||||
ListByController(ctx context.Context, userID string) ([]RemoteSession, error)
|
||||
CountLiveByResource(ctx context.Context, resourceID string) (int, error)
|
||||
ListDetachedExpired(ctx context.Context, before time.Time, limit int) ([]RemoteSession, error)
|
||||
UpdateState(ctx context.Context, params UpdateRemoteSessionStateParams) error
|
||||
}
|
||||
|
||||
type SessionAttachmentRepository interface {
|
||||
Create(ctx context.Context, attachment SessionAttachment) error
|
||||
GetByID(ctx context.Context, attachmentID string) (*SessionAttachment, error)
|
||||
GetByIDForUpdate(ctx context.Context, attachmentID string) (*SessionAttachment, error)
|
||||
ListByRemoteSession(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error)
|
||||
ListActiveByRemoteSessionForUpdate(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error)
|
||||
UpdateState(ctx context.Context, params UpdateSessionAttachmentStateParams) error
|
||||
Supersede(ctx context.Context, params SupersedeAttachmentParams) error
|
||||
}
|
||||
|
||||
type ResourcePolicyRepository interface {
|
||||
GetByResourceID(ctx context.Context, resourceID string) (*ResourcePolicy, error)
|
||||
Upsert(ctx context.Context, policy ResourcePolicy) error
|
||||
}
|
||||
|
||||
type AuditEventRepository interface {
|
||||
Create(ctx context.Context, event AuditEvent) error
|
||||
}
|
||||
|
||||
type Store interface {
|
||||
RemoteSessions() RemoteSessionRepository
|
||||
SessionAttachments() SessionAttachmentRepository
|
||||
ResourcePolicies() ResourcePolicyRepository
|
||||
ResourceRuntime() ResourceRuntimeRepository
|
||||
AuditEvents() AuditEventRepository
|
||||
Access() AccessRepository
|
||||
}
|
||||
|
||||
type Transactor interface {
|
||||
WithinTransaction(ctx context.Context, fn func(store Store) error) error
|
||||
}
|
||||
|
||||
type UpdateRemoteSessionStateParams struct {
|
||||
RemoteSessionID string
|
||||
State sessioncontracts.State
|
||||
WorkerID string
|
||||
DetachDeadlineAt *time.Time
|
||||
LastHeartbeatAt *time.Time
|
||||
TakeoverVersion int
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type UpdateSessionAttachmentStateParams struct {
|
||||
AttachmentID string
|
||||
State AttachmentState
|
||||
DetachedAt *time.Time
|
||||
LastInputAt *time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type SupersedeAttachmentParams struct {
|
||||
PreviousAttachmentID string
|
||||
NextAttachmentID string
|
||||
DetachedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type AccessRepository interface {
|
||||
IsTrustedDevice(ctx context.Context, userID, deviceID string) (bool, error)
|
||||
GetPlatformRole(ctx context.Context, userID string) (string, error)
|
||||
GetOrganizationRole(ctx context.Context, organizationID, userID string) (string, bool, error)
|
||||
}
|
||||
|
||||
type ResourceRuntimeRepository interface {
|
||||
GetByID(ctx context.Context, resourceID string) (*ResourceRuntimeSpec, error)
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/config"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/module"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
|
||||
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
|
||||
)
|
||||
|
||||
type fakeSecretResolver struct {
|
||||
response *secrets.ResolvedResourceSecret
|
||||
err error
|
||||
request secrets.ResolveResourceSecretRequest
|
||||
}
|
||||
|
||||
func testAppConfig(env string) config.AppConfig {
|
||||
return config.AppConfig{Name: "rap-api-test", Env: env}
|
||||
}
|
||||
|
||||
func (r *fakeSecretResolver) ResolveForSession(_ context.Context, req secrets.ResolveResourceSecretRequest) (*secrets.ResolvedResourceSecret, error) {
|
||||
r.request = req
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
return r.response, nil
|
||||
}
|
||||
|
||||
func TestRuntimeAssignmentMetadataMergesResolvedSecretWithoutMutatingSessionMetadata(t *testing.T) {
|
||||
resolver := &fakeSecretResolver{
|
||||
response: &secrets.ResolvedResourceSecret{
|
||||
Descriptor: secrets.ResourceSecretDescriptor{Version: 3},
|
||||
Payload: json.RawMessage(`{"username":"user","password":"secret","domain":"corp"}`),
|
||||
},
|
||||
}
|
||||
service := NewService(module.Dependencies{
|
||||
Config: module.Config{App: testAppConfig("production")},
|
||||
}, nil, nil, nil, nil, resolver)
|
||||
sessionMetadata := mustJSON(t, map[string]any{
|
||||
"resource": map[string]any{
|
||||
"id": "resource-1",
|
||||
"organization_id": "org-1",
|
||||
"secret_ref": "rap-secret://org/org-1/resources/resource-1/primary",
|
||||
"metadata": map[string]any{
|
||||
"rdp_host": "host",
|
||||
},
|
||||
},
|
||||
})
|
||||
session := RemoteSession{
|
||||
ID: "session-1",
|
||||
OrganizationID: "org-1",
|
||||
ResourceID: "resource-1",
|
||||
WorkerID: "worker-1",
|
||||
Metadata: sessionMetadata,
|
||||
}
|
||||
metadata, secretRef, version, err := service.runtimeAssignmentMetadata(context.Background(), session, &workercontracts.WorkerLease{LeaseID: "lease-1"})
|
||||
if err != nil {
|
||||
t.Fatalf("runtimeAssignmentMetadata returned error: %v", err)
|
||||
}
|
||||
if secretRef == "" || version != 3 {
|
||||
t.Fatalf("expected secret ref and version, got ref=%q version=%d", secretRef, version)
|
||||
}
|
||||
resource := metadata["resource"].(map[string]any)
|
||||
resourceMetadata := resource["metadata"].(map[string]any)
|
||||
if resourceMetadata["username"] != "user" || resourceMetadata["password"] != "secret" || resourceMetadata["domain"] != "corp" {
|
||||
t.Fatalf("resolved secret was not merged: %#v", resourceMetadata)
|
||||
}
|
||||
var persisted map[string]any
|
||||
if err := json.Unmarshal(session.Metadata, &persisted); err != nil {
|
||||
t.Fatalf("decode persisted metadata: %v", err)
|
||||
}
|
||||
persistedResource := persisted["resource"].(map[string]any)
|
||||
persistedMetadata := persistedResource["metadata"].(map[string]any)
|
||||
if _, ok := persistedMetadata["password"]; ok {
|
||||
t.Fatalf("session metadata was mutated with plaintext secret")
|
||||
}
|
||||
if resolver.request.LeaseID != "lease-1" || resolver.request.WorkerID != "worker-1" {
|
||||
t.Fatalf("resolver request missed lease/worker proof: %#v", resolver.request)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeAssignmentMetadataRequiresResolverInProduction(t *testing.T) {
|
||||
service := NewService(module.Dependencies{
|
||||
Config: module.Config{App: testAppConfig("production")},
|
||||
}, nil, nil, nil, nil)
|
||||
session := RemoteSession{
|
||||
ID: "session-1",
|
||||
OrganizationID: "org-1",
|
||||
ResourceID: "resource-1",
|
||||
WorkerID: "worker-1",
|
||||
Metadata: mustJSON(t, map[string]any{
|
||||
"resource": map[string]any{
|
||||
"secret_ref": "rap-secret://org/org-1/resources/resource-1/primary",
|
||||
},
|
||||
}),
|
||||
}
|
||||
_, _, _, err := service.runtimeAssignmentMetadata(context.Background(), session, &workercontracts.WorkerLease{LeaseID: "lease-1"})
|
||||
if !errors.Is(err, secrets.ErrSecretEncryptionKeyMissing) {
|
||||
t.Fatalf("expected missing resolver error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeAssignmentMetadataAllowsDevelopmentMetadataWithoutResolver(t *testing.T) {
|
||||
service := NewService(module.Dependencies{
|
||||
Config: module.Config{App: testAppConfig("development")},
|
||||
}, nil, nil, nil, nil)
|
||||
session := RemoteSession{
|
||||
ID: "session-1",
|
||||
OrganizationID: "org-1",
|
||||
ResourceID: "resource-1",
|
||||
WorkerID: "worker-1",
|
||||
Metadata: mustJSON(t, map[string]any{
|
||||
"resource": map[string]any{
|
||||
"secret_ref": "rap-secret://org/org-1/resources/resource-1/primary",
|
||||
"metadata": map[string]any{
|
||||
"username": "dev-user",
|
||||
"password": "dev-password",
|
||||
},
|
||||
},
|
||||
}),
|
||||
}
|
||||
metadata, secretRef, _, err := service.runtimeAssignmentMetadata(context.Background(), session, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("development metadata should not require resolver: %v", err)
|
||||
}
|
||||
if secretRef != "" {
|
||||
t.Fatalf("development fallback should not audit resolver use, got %q", secretRef)
|
||||
}
|
||||
resource := metadata["resource"].(map[string]any)
|
||||
resourceMetadata := resource["metadata"].(map[string]any)
|
||||
if resourceMetadata["password"] != "dev-password" {
|
||||
t.Fatalf("development metadata was not preserved")
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,391 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/module"
|
||||
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
|
||||
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
|
||||
)
|
||||
|
||||
func TestHandleWorkerConnectedIgnoresTerminalSession(t *testing.T) {
|
||||
service, store, live, _ := newStaleWorkerEventTestService()
|
||||
store.remote.sessions["session-1"] = RemoteSession{
|
||||
ID: "session-1",
|
||||
State: sessioncontracts.StateTerminated,
|
||||
WorkerID: "worker-1",
|
||||
}
|
||||
|
||||
if err := service.HandleWorkerConnected(context.Background(), "session-1"); err != nil {
|
||||
t.Fatalf("HandleWorkerConnected returned error for stale terminal event: %v", err)
|
||||
}
|
||||
if got := store.remote.sessions["session-1"].State; got != sessioncontracts.StateTerminated {
|
||||
t.Fatalf("stale connected event changed terminal state to %q", got)
|
||||
}
|
||||
if store.remote.updateCount != 0 {
|
||||
t.Fatalf("stale connected event updated authoritative session %d times", store.remote.updateCount)
|
||||
}
|
||||
if live.upsertCount != 0 {
|
||||
t.Fatalf("stale connected event recreated live state %d times", live.upsertCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateWorkerRenderTelemetryIgnoresTerminalSession(t *testing.T) {
|
||||
service, store, live, _ := newStaleWorkerEventTestService()
|
||||
store.remote.sessions["session-1"] = RemoteSession{
|
||||
ID: "session-1",
|
||||
State: sessioncontracts.StateTerminated,
|
||||
WorkerID: "worker-1",
|
||||
}
|
||||
|
||||
err := service.UpdateWorkerRenderTelemetry(context.Background(), "session-1", map[string]any{
|
||||
"render_state": "ready",
|
||||
"width": 1280,
|
||||
"height": 720,
|
||||
"frame_sequence": int64(99),
|
||||
"frame_data": "stale-frame",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateWorkerRenderTelemetry returned error for stale terminal event: %v", err)
|
||||
}
|
||||
if live.upsertCount != 0 {
|
||||
t.Fatalf("stale render event recreated live state %d times", live.upsertCount)
|
||||
}
|
||||
if live.sessions["session-1"] != nil {
|
||||
t.Fatalf("stale render event left live state behind: %#v", live.sessions["session-1"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkSessionFailedTransitionsActiveSession(t *testing.T) {
|
||||
service, store, live, orchestrator := newStaleWorkerEventTestService()
|
||||
store.remote.sessions["session-1"] = RemoteSession{
|
||||
ID: "session-1",
|
||||
State: sessioncontracts.StateActive,
|
||||
WorkerID: "worker-1",
|
||||
TakeoverVersion: 3,
|
||||
}
|
||||
store.attachments.items["attachment-1"] = SessionAttachment{
|
||||
ID: "attachment-1",
|
||||
RemoteSessionID: "session-1",
|
||||
State: AttachmentStateActive,
|
||||
}
|
||||
live.sessions["session-1"] = &LiveSessionState{SessionID: "session-1", State: sessioncontracts.StateActive}
|
||||
|
||||
if err := service.MarkSessionFailed(context.Background(), MarkSessionFailedCommand{SessionID: "session-1", Reason: "worker_lost"}); err != nil {
|
||||
t.Fatalf("MarkSessionFailed returned error: %v", err)
|
||||
}
|
||||
if got := store.remote.sessions["session-1"].State; got != sessioncontracts.StateFailed {
|
||||
t.Fatalf("expected failed state, got %q", got)
|
||||
}
|
||||
if got := store.attachments.items["attachment-1"].State; got != AttachmentStateClosed {
|
||||
t.Fatalf("expected attachment closed, got %q", got)
|
||||
}
|
||||
if store.audit.createCount != 1 {
|
||||
t.Fatalf("expected one audit event, got %d", store.audit.createCount)
|
||||
}
|
||||
if live.sessions["session-1"] != nil {
|
||||
t.Fatal("expected failed session live state to be deleted")
|
||||
}
|
||||
if orchestrator.releaseCount != 1 {
|
||||
t.Fatalf("expected session lease release, got %d", orchestrator.releaseCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkSessionFailedAlreadyFailedIsIdempotent(t *testing.T) {
|
||||
service, store, _, _ := newStaleWorkerEventTestService()
|
||||
store.remote.sessions["session-1"] = RemoteSession{
|
||||
ID: "session-1",
|
||||
State: sessioncontracts.StateFailed,
|
||||
WorkerID: "worker-1",
|
||||
TakeoverVersion: 1,
|
||||
}
|
||||
|
||||
if err := service.MarkSessionFailed(context.Background(), MarkSessionFailedCommand{SessionID: "session-1", Reason: "duplicate_worker_failure"}); err != nil {
|
||||
t.Fatalf("duplicate MarkSessionFailed returned error: %v", err)
|
||||
}
|
||||
if got := store.remote.sessions["session-1"].State; got != sessioncontracts.StateFailed {
|
||||
t.Fatalf("duplicate terminal event changed state to %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func newStaleWorkerEventTestService() (*Service, *staleWorkerEventTestStore, *staleWorkerEventLiveState, *staleWorkerEventOrchestrator) {
|
||||
store := &staleWorkerEventTestStore{
|
||||
remote: &staleWorkerEventRemoteSessions{sessions: map[string]RemoteSession{}},
|
||||
attachments: &staleWorkerEventAttachments{items: map[string]SessionAttachment{}},
|
||||
policies: &staleWorkerEventPolicies{},
|
||||
audit: &staleWorkerEventAudit{},
|
||||
}
|
||||
live := &staleWorkerEventLiveState{sessions: map[string]*LiveSessionState{}}
|
||||
orchestrator := &staleWorkerEventOrchestrator{}
|
||||
service := NewService(module.Dependencies{
|
||||
Infra: module.Infra{Logger: slog.New(slog.NewTextHandler(io.Discard, nil))},
|
||||
}, store, staleWorkerEventTransactor{store: store}, live, orchestrator)
|
||||
service.now = func() time.Time { return time.Unix(100, 0).UTC() }
|
||||
return service, store, live, orchestrator
|
||||
}
|
||||
|
||||
type staleWorkerEventTransactor struct {
|
||||
store Store
|
||||
}
|
||||
|
||||
func (t staleWorkerEventTransactor) WithinTransaction(ctx context.Context, fn func(store Store) error) error {
|
||||
return fn(t.store)
|
||||
}
|
||||
|
||||
type staleWorkerEventTestStore struct {
|
||||
remote *staleWorkerEventRemoteSessions
|
||||
attachments *staleWorkerEventAttachments
|
||||
policies *staleWorkerEventPolicies
|
||||
audit *staleWorkerEventAudit
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventTestStore) RemoteSessions() RemoteSessionRepository { return s.remote }
|
||||
func (s *staleWorkerEventTestStore) SessionAttachments() SessionAttachmentRepository {
|
||||
return s.attachments
|
||||
}
|
||||
func (s *staleWorkerEventTestStore) ResourcePolicies() ResourcePolicyRepository { return s.policies }
|
||||
func (s *staleWorkerEventTestStore) ResourceRuntime() ResourceRuntimeRepository {
|
||||
return staleWorkerEventResourceRuntime{}
|
||||
}
|
||||
func (s *staleWorkerEventTestStore) AuditEvents() AuditEventRepository { return s.audit }
|
||||
func (s *staleWorkerEventTestStore) Access() AccessRepository { return staleWorkerEventAccess{} }
|
||||
|
||||
type staleWorkerEventRemoteSessions struct {
|
||||
sessions map[string]RemoteSession
|
||||
updateCount int
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventRemoteSessions) Create(_ context.Context, session RemoteSession) error {
|
||||
r.sessions[session.ID] = session
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventRemoteSessions) GetByID(_ context.Context, sessionID string) (*RemoteSession, error) {
|
||||
session, ok := r.sessions[sessionID]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventRemoteSessions) GetByIDForUpdate(ctx context.Context, sessionID string) (*RemoteSession, error) {
|
||||
return r.GetByID(ctx, sessionID)
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventRemoteSessions) ListByController(_ context.Context, _ string) ([]RemoteSession, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventRemoteSessions) CountLiveByResource(_ context.Context, _ string) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventRemoteSessions) ListDetachedExpired(_ context.Context, _ time.Time, _ int) ([]RemoteSession, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventRemoteSessions) UpdateState(_ context.Context, params UpdateRemoteSessionStateParams) error {
|
||||
session := r.sessions[params.RemoteSessionID]
|
||||
session.State = params.State
|
||||
session.WorkerID = params.WorkerID
|
||||
session.DetachDeadlineAt = params.DetachDeadlineAt
|
||||
session.LastHeartbeatAt = params.LastHeartbeatAt
|
||||
session.TakeoverVersion = params.TakeoverVersion
|
||||
session.UpdatedAt = params.UpdatedAt
|
||||
r.sessions[params.RemoteSessionID] = session
|
||||
r.updateCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
type staleWorkerEventAttachments struct {
|
||||
items map[string]SessionAttachment
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventAttachments) Create(_ context.Context, attachment SessionAttachment) error {
|
||||
r.items[attachment.ID] = attachment
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventAttachments) GetByID(_ context.Context, attachmentID string) (*SessionAttachment, error) {
|
||||
attachment, ok := r.items[attachmentID]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
return &attachment, nil
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventAttachments) GetByIDForUpdate(ctx context.Context, attachmentID string) (*SessionAttachment, error) {
|
||||
return r.GetByID(ctx, attachmentID)
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventAttachments) ListByRemoteSession(_ context.Context, remoteSessionID string) ([]SessionAttachment, error) {
|
||||
attachments := make([]SessionAttachment, 0)
|
||||
for _, attachment := range r.items {
|
||||
if attachment.RemoteSessionID == remoteSessionID {
|
||||
attachments = append(attachments, attachment)
|
||||
}
|
||||
}
|
||||
return attachments, nil
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventAttachments) ListActiveByRemoteSessionForUpdate(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error) {
|
||||
return r.ListByRemoteSession(ctx, remoteSessionID)
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventAttachments) UpdateState(_ context.Context, params UpdateSessionAttachmentStateParams) error {
|
||||
attachment := r.items[params.AttachmentID]
|
||||
attachment.State = params.State
|
||||
attachment.DetachedAt = params.DetachedAt
|
||||
attachment.LastInputAt = params.LastInputAt
|
||||
attachment.UpdatedAt = params.UpdatedAt
|
||||
r.items[params.AttachmentID] = attachment
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventAttachments) Supersede(_ context.Context, params SupersedeAttachmentParams) error {
|
||||
attachment := r.items[params.PreviousAttachmentID]
|
||||
attachment.State = AttachmentStateSuperseded
|
||||
attachment.SupersededBy = ¶ms.NextAttachmentID
|
||||
attachment.DetachedAt = ¶ms.DetachedAt
|
||||
attachment.UpdatedAt = params.UpdatedAt
|
||||
r.items[params.PreviousAttachmentID] = attachment
|
||||
return nil
|
||||
}
|
||||
|
||||
type staleWorkerEventPolicies struct{}
|
||||
|
||||
func (r *staleWorkerEventPolicies) GetByResourceID(_ context.Context, _ string) (*ResourcePolicy, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventPolicies) Upsert(_ context.Context, _ ResourcePolicy) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type staleWorkerEventAudit struct {
|
||||
createCount int
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventAudit) Create(_ context.Context, _ AuditEvent) error {
|
||||
r.createCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
type staleWorkerEventResourceRuntime struct{}
|
||||
|
||||
func (staleWorkerEventResourceRuntime) GetByID(_ context.Context, _ string) (*ResourceRuntimeSpec, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type staleWorkerEventAccess struct{}
|
||||
|
||||
func (staleWorkerEventAccess) IsTrustedDevice(_ context.Context, _, _ string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (staleWorkerEventAccess) GetPlatformRole(_ context.Context, _ string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (staleWorkerEventAccess) GetOrganizationRole(_ context.Context, _, _ string) (string, bool, error) {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
type staleWorkerEventLiveState struct {
|
||||
sessions map[string]*LiveSessionState
|
||||
upsertCount int
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) UpsertSession(_ context.Context, state LiveSessionState) error {
|
||||
copied := state
|
||||
s.sessions[state.SessionID] = &copied
|
||||
s.upsertCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) GetSession(_ context.Context, sessionID string) (*LiveSessionState, error) {
|
||||
state := s.sessions[sessionID]
|
||||
if state == nil {
|
||||
return nil, nil
|
||||
}
|
||||
copied := *state
|
||||
return &copied, nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) DeleteSession(_ context.Context, sessionID string) error {
|
||||
delete(s.sessions, sessionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) BindController(_ context.Context, _ sessioncontracts.ControllerBinding, _ time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) GetControllerBinding(_ context.Context, _ string) (*sessioncontracts.ControllerBinding, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) ClearControllerBinding(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) StoreAttachToken(_ context.Context, _ sessioncontracts.AttachTokenClaims, _ time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) ConsumeAttachToken(_ context.Context, _ string) (*sessioncontracts.AttachTokenClaims, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) TouchAttachmentHeartbeat(_ context.Context, _, _ string, _ time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) UpdateWorkerRoute(_ context.Context, _ WorkerRoute, _ time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) GetWorkerRoute(_ context.Context, _ string) (*WorkerRoute, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) DeleteWorkerRoute(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type staleWorkerEventOrchestrator struct {
|
||||
releaseCount int
|
||||
}
|
||||
|
||||
func (o *staleWorkerEventOrchestrator) Reserve(_ context.Context, _ workercontracts.AttachRequest) (*workercontracts.WorkerLease, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (o *staleWorkerEventOrchestrator) GetSessionLease(_ context.Context, _ string) (*workercontracts.WorkerLease, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (o *staleWorkerEventOrchestrator) ReleaseSessionLease(_ context.Context, _ string) error {
|
||||
o.releaseCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *staleWorkerEventOrchestrator) PrepareAttachment(_ context.Context, _ RemoteSession, _ SessionAttachment, _ map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *staleWorkerEventOrchestrator) NotifyDetachment(_ context.Context, _ RemoteSession, _ SessionAttachment) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *staleWorkerEventOrchestrator) TerminateRemoteSession(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *staleWorkerEventOrchestrator) ValidateSessionRuntime(_ context.Context, _, _ string) (bool, string, error) {
|
||||
return true, "", nil
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
|
||||
)
|
||||
|
||||
var allowedTransitions = map[sessioncontracts.State]map[sessioncontracts.State]struct{}{
|
||||
sessioncontracts.StateStarting: {
|
||||
sessioncontracts.StateActive: {},
|
||||
sessioncontracts.StateFailed: {},
|
||||
sessioncontracts.StateTerminated: {},
|
||||
},
|
||||
sessioncontracts.StateActive: {
|
||||
sessioncontracts.StateDetached: {},
|
||||
sessioncontracts.StateReconnecting: {},
|
||||
sessioncontracts.StateFailed: {},
|
||||
sessioncontracts.StateTerminated: {},
|
||||
},
|
||||
sessioncontracts.StateDetached: {
|
||||
sessioncontracts.StateReconnecting: {},
|
||||
sessioncontracts.StateTerminated: {},
|
||||
sessioncontracts.StateFailed: {},
|
||||
},
|
||||
sessioncontracts.StateReconnecting: {
|
||||
sessioncontracts.StateActive: {},
|
||||
sessioncontracts.StateDetached: {},
|
||||
sessioncontracts.StateFailed: {},
|
||||
sessioncontracts.StateTerminated: {},
|
||||
},
|
||||
}
|
||||
|
||||
func validateTransition(from, to sessioncontracts.State) error {
|
||||
if from == to {
|
||||
return nil
|
||||
}
|
||||
if allowed, ok := allowedTransitions[from]; ok {
|
||||
if _, ok := allowed[to]; ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("invalid session state transition: %s -> %s", from, to)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,29 @@
|
||||
package worker
|
||||
|
||||
type SessionEvent struct {
|
||||
Type string `json:"type"`
|
||||
SessionID string `json:"session_id"`
|
||||
WorkerID string `json:"worker_id"`
|
||||
Payload map[string]any `json:"payload,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
SessionEventConnected = "session_connected"
|
||||
SessionEventHeartbeat = "session_heartbeat"
|
||||
SessionEventFailed = "session_failed"
|
||||
SessionEventTerminated = "session_terminated"
|
||||
SessionEventDisplayReady = "session_display_ready"
|
||||
SessionEventRenderReady = "session_render_ready"
|
||||
SessionEventRenderDirty = "session_render_dirty"
|
||||
SessionEventRenderResized = "session_render_resized"
|
||||
SessionEventCursorUpdated = "session_cursor_updated"
|
||||
SessionEventFrame = "session_frame"
|
||||
SessionEventClipboardText = "session_clipboard_text"
|
||||
SessionEventFileUploaded = "session_file_upload_completed"
|
||||
SessionEventFileDownloadAvailable = "session_file_download_available"
|
||||
SessionEventFileDownloadChunk = "session_file_download_chunk"
|
||||
SessionEventFileDownloadProgress = "session_file_download_progress"
|
||||
SessionEventFileDownloadCompleted = "session_file_download_completed"
|
||||
SessionEventFileDownloadFailed = "session_file_download_failed"
|
||||
SessionEventFileDownloadBlocked = "session_file_download_blocked"
|
||||
)
|
||||
@@ -0,0 +1,52 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/sessionbroker"
|
||||
)
|
||||
|
||||
type LeaseMonitor struct {
|
||||
service *Service
|
||||
broker *sessionbroker.Service
|
||||
interval time.Duration
|
||||
}
|
||||
|
||||
func NewLeaseMonitor(service *Service, broker *sessionbroker.Service, interval time.Duration) *LeaseMonitor {
|
||||
if interval <= 0 {
|
||||
interval = 15 * time.Second
|
||||
}
|
||||
return &LeaseMonitor{
|
||||
service: service,
|
||||
broker: broker,
|
||||
interval: interval,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *LeaseMonitor) Run(ctx context.Context) error {
|
||||
ticker := time.NewTicker(m.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
stale, err := m.service.RecoverStaleLeases(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, lease := range stale {
|
||||
err := m.broker.MarkSessionFailed(ctx, sessionbroker.MarkSessionFailedCommand{
|
||||
SessionID: lease.SessionID,
|
||||
Reason: "worker_lease_stale_or_worker_missing",
|
||||
})
|
||||
if err != nil && !errors.Is(err, sessionbroker.ErrSessionNotFound) && !errors.Is(err, sessionbroker.ErrSessionNotTerminable) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/sessionbroker"
|
||||
)
|
||||
|
||||
type EventProcessor struct {
|
||||
client *redis.Client
|
||||
broker *sessionbroker.Service
|
||||
}
|
||||
|
||||
func NewEventProcessor(client *redis.Client, broker *sessionbroker.Service) *EventProcessor {
|
||||
return &EventProcessor{client: client, broker: broker}
|
||||
}
|
||||
|
||||
func (p *EventProcessor) Run(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
result, err := p.client.BLPop(ctx, 5*time.Second, "worker:events").Result()
|
||||
if err == redis.Nil {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("consume worker event: %w", err)
|
||||
}
|
||||
if len(result) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
var event SessionEvent
|
||||
if err := json.Unmarshal([]byte(result[1]), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
if err := p.handleEvent(ctx, event); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *EventProcessor) handleEvent(ctx context.Context, event SessionEvent) error {
|
||||
switch event.Type {
|
||||
case SessionEventConnected, SessionEventDisplayReady:
|
||||
if err := p.broker.HandleWorkerConnected(ctx, event.SessionID); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(event.Payload) > 0 {
|
||||
if err := p.broker.UpdateWorkerRenderTelemetry(ctx, event.SessionID, event.Payload); err != nil && !errors.Is(err, sessionbroker.ErrSessionNotFound) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case SessionEventHeartbeat:
|
||||
return p.broker.HandleWorkerHeartbeat(ctx, event.SessionID)
|
||||
case SessionEventRenderReady, SessionEventRenderDirty, SessionEventRenderResized, SessionEventCursorUpdated, SessionEventFrame:
|
||||
if len(event.Payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
if correlationID, _ := event.Payload["input_correlation_id"].(string); correlationID != "" {
|
||||
slog.Info("worker frame event received",
|
||||
"session_id", event.SessionID,
|
||||
"worker_id", event.WorkerID,
|
||||
"frame_sequence", event.Payload["frame_sequence"],
|
||||
"correlation_id", correlationID,
|
||||
"worker_frame_captured_at", event.Payload["worker_frame_captured_at"],
|
||||
"trace_stage", "backend_frame_receive")
|
||||
}
|
||||
return p.updateRenderTelemetryWithRetry(ctx, event.SessionID, event.Payload)
|
||||
case SessionEventClipboardText:
|
||||
if len(event.Payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
slog.Info("worker clipboard event received",
|
||||
"session_id", event.SessionID,
|
||||
"worker_id", event.WorkerID,
|
||||
"origin", event.Payload["origin"],
|
||||
"sequence_id", event.Payload["sequence_id"],
|
||||
"content_hash", event.Payload["content_hash"])
|
||||
return p.broker.UpdateWorkerClipboardText(ctx, event.SessionID, event.Payload)
|
||||
case SessionEventFileUploaded:
|
||||
slog.Info("worker file upload completed",
|
||||
"session_id", event.SessionID,
|
||||
"worker_id", event.WorkerID,
|
||||
"transfer_id", event.Payload["transfer_id"],
|
||||
"file_name", event.Payload["file_name"],
|
||||
"file_size", event.Payload["file_size"],
|
||||
"content_hash", event.Payload["content_hash"],
|
||||
"storage_path", event.Payload["storage_path"])
|
||||
return nil
|
||||
case SessionEventFileDownloadAvailable, SessionEventFileDownloadChunk, SessionEventFileDownloadProgress,
|
||||
SessionEventFileDownloadCompleted, SessionEventFileDownloadFailed, SessionEventFileDownloadBlocked:
|
||||
slog.Info("worker file download event received",
|
||||
"session_id", event.SessionID,
|
||||
"worker_id", event.WorkerID,
|
||||
"event_type", event.Type,
|
||||
"transfer_id", event.Payload["transfer_id"],
|
||||
"file_id", event.Payload["file_id"],
|
||||
"file_name", event.Payload["file_name"],
|
||||
"status", event.Payload["status"])
|
||||
return p.broker.UpdateWorkerFileDownloadEvent(ctx, event.SessionID, event.Type, event.Payload)
|
||||
case SessionEventFailed:
|
||||
reason, _ := event.Payload["reason"].(string)
|
||||
err := p.broker.MarkSessionFailed(ctx, sessionbroker.MarkSessionFailedCommand{
|
||||
SessionID: event.SessionID,
|
||||
Reason: reason,
|
||||
})
|
||||
if errors.Is(err, sessionbroker.ErrSessionNotFound) || errors.Is(err, sessionbroker.ErrSessionNotTerminable) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
case SessionEventTerminated:
|
||||
reason, _ := event.Payload["reason"].(string)
|
||||
err := p.broker.TerminateSession(ctx, sessionbroker.TerminateSessionCommand{
|
||||
SessionID: event.SessionID,
|
||||
Reason: reason,
|
||||
})
|
||||
if errors.Is(err, sessionbroker.ErrSessionNotFound) || errors.Is(err, sessionbroker.ErrSessionNotTerminable) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *EventProcessor) updateRenderTelemetryWithRetry(ctx context.Context, sessionID string, payload map[string]any) error {
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < 10; attempt++ {
|
||||
err := p.broker.UpdateWorkerRenderTelemetry(ctx, sessionID, payload)
|
||||
if err == nil || errors.Is(err, sessionbroker.ErrSessionNotFound) {
|
||||
return nil
|
||||
}
|
||||
lastErr = err
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
@@ -0,0 +1,264 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
|
||||
)
|
||||
|
||||
type RedisStore struct {
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
func NewRedisStore(client *redis.Client) *RedisStore {
|
||||
return &RedisStore{client: client}
|
||||
}
|
||||
|
||||
func (s *RedisStore) RegisterWorker(ctx context.Context, registration workercontracts.WorkerRegistration, ttl time.Duration) error {
|
||||
payload, err := json.Marshal(registration)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal worker registration: %w", err)
|
||||
}
|
||||
pipe := s.client.TxPipeline()
|
||||
pipe.Set(ctx, workerKey(registration.WorkerID), payload, ttl)
|
||||
pipe.SAdd(ctx, workerSetKey(), registration.WorkerID)
|
||||
_, err = pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("register worker: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RedisStore) TouchWorkerHeartbeat(ctx context.Context, heartbeat workercontracts.WorkerHeartbeat, ttl time.Duration) error {
|
||||
registration, err := s.GetWorker(ctx, heartbeat.WorkerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if registration == nil {
|
||||
registration = &workercontracts.WorkerRegistration{
|
||||
WorkerID: heartbeat.WorkerID,
|
||||
Protocol: workercontracts.ProtocolRDP,
|
||||
}
|
||||
}
|
||||
registration.Status = heartbeat.Status
|
||||
registration.LastHeartbeatAt = heartbeat.LastHeartbeatAt
|
||||
return s.RegisterWorker(ctx, *registration, ttl)
|
||||
}
|
||||
|
||||
func (s *RedisStore) ListWorkers(ctx context.Context) ([]workercontracts.WorkerRegistration, error) {
|
||||
ids, err := s.client.SMembers(ctx, workerSetKey()).Result()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list worker ids: %w", err)
|
||||
}
|
||||
workers := make([]workercontracts.WorkerRegistration, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
worker, err := s.GetWorker(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if worker != nil {
|
||||
workers = append(workers, *worker)
|
||||
}
|
||||
}
|
||||
return workers, nil
|
||||
}
|
||||
|
||||
func (s *RedisStore) GetWorker(ctx context.Context, workerID string) (*workercontracts.WorkerRegistration, error) {
|
||||
payload, err := s.client.Get(ctx, workerKey(workerID)).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get worker: %w", err)
|
||||
}
|
||||
var registration workercontracts.WorkerRegistration
|
||||
if err := json.Unmarshal([]byte(payload), ®istration); err != nil {
|
||||
return nil, fmt.Errorf("decode worker registration: %w", err)
|
||||
}
|
||||
return ®istration, nil
|
||||
}
|
||||
|
||||
func (s *RedisStore) AcquireLease(ctx context.Context, lease workercontracts.WorkerLease, ttl time.Duration) error {
|
||||
payload, err := json.Marshal(lease)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal lease: %w", err)
|
||||
}
|
||||
ok, err := s.client.SetNX(ctx, leaseKey(lease.LeaseID), payload, ttl).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("acquire lease: %w", err)
|
||||
}
|
||||
if !ok {
|
||||
return fmt.Errorf("lease already exists")
|
||||
}
|
||||
pipe := s.client.TxPipeline()
|
||||
pipe.SAdd(ctx, leaseSetKey(), lease.LeaseID)
|
||||
pipe.Set(ctx, sessionLeaseKey(lease.SessionID), lease.LeaseID, ttl)
|
||||
_, err = pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("index lease: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RedisStore) GetLease(ctx context.Context, leaseID string) (*workercontracts.WorkerLease, error) {
|
||||
payload, err := s.client.Get(ctx, leaseKey(leaseID)).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get lease: %w", err)
|
||||
}
|
||||
var lease workercontracts.WorkerLease
|
||||
if err := json.Unmarshal([]byte(payload), &lease); err != nil {
|
||||
return nil, fmt.Errorf("decode lease: %w", err)
|
||||
}
|
||||
return &lease, nil
|
||||
}
|
||||
|
||||
func (s *RedisStore) GetLeaseBySession(ctx context.Context, sessionID string) (*workercontracts.WorkerLease, error) {
|
||||
leaseID, err := s.client.Get(ctx, sessionLeaseKey(sessionID)).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get lease by session: %w", err)
|
||||
}
|
||||
return s.GetLease(ctx, leaseID)
|
||||
}
|
||||
|
||||
func (s *RedisStore) RenewLease(ctx context.Context, lease workercontracts.WorkerLease, ttl time.Duration) error {
|
||||
payload, err := json.Marshal(lease)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal lease renewal: %w", err)
|
||||
}
|
||||
pipe := s.client.TxPipeline()
|
||||
pipe.Set(ctx, leaseKey(lease.LeaseID), payload, ttl)
|
||||
pipe.Set(ctx, sessionLeaseKey(lease.SessionID), lease.LeaseID, ttl)
|
||||
_, err = pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("renew lease: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RedisStore) ReleaseLease(ctx context.Context, leaseID string) error {
|
||||
lease, err := s.GetLease(ctx, leaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pipe := s.client.TxPipeline()
|
||||
pipe.Del(ctx, leaseKey(leaseID))
|
||||
pipe.SRem(ctx, leaseSetKey(), leaseID)
|
||||
if lease != nil {
|
||||
pipe.Del(ctx, sessionLeaseKey(lease.SessionID))
|
||||
}
|
||||
_, err = pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("release lease: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RedisStore) ListLeases(ctx context.Context) ([]workercontracts.WorkerLease, error) {
|
||||
ids, err := s.client.SMembers(ctx, leaseSetKey()).Result()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list lease ids: %w", err)
|
||||
}
|
||||
leases := make([]workercontracts.WorkerLease, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
lease, err := s.GetLease(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lease != nil {
|
||||
leases = append(leases, *lease)
|
||||
}
|
||||
}
|
||||
return leases, nil
|
||||
}
|
||||
|
||||
func (s *RedisStore) AppendEnvelope(ctx context.Context, envelope workercontracts.RoutedEnvelope) error {
|
||||
payload, err := json.Marshal(envelope)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal routed envelope: %w", err)
|
||||
}
|
||||
key := workerQueueKey(envelope.SessionID)
|
||||
if err := s.client.RPush(ctx, key, payload).Err(); err != nil {
|
||||
return fmt.Errorf("append routed envelope: %w", err)
|
||||
}
|
||||
if envelope.Type == "input" {
|
||||
correlationID, _ := envelope.Payload["correlation_id"].(string)
|
||||
if correlationID != "" {
|
||||
if length, err := s.client.LLen(ctx, key).Result(); err == nil {
|
||||
slog.Info("worker queue length after input append",
|
||||
"session_id", envelope.SessionID,
|
||||
"attachment_id", envelope.AttachmentID,
|
||||
"correlation_id", correlationID,
|
||||
"queue_key", key,
|
||||
"queue_length", length,
|
||||
"trace_stage", "redis_queue_append")
|
||||
}
|
||||
}
|
||||
}
|
||||
return s.client.Expire(ctx, key, 10*time.Minute).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) AppendAssignment(ctx context.Context, workerID string, payload map[string]any) error {
|
||||
encoded, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal worker assignment: %w", err)
|
||||
}
|
||||
if err := s.client.RPush(ctx, workerControlQueueKey(workerID), encoded).Err(); err != nil {
|
||||
return fmt.Errorf("append worker assignment: %w", err)
|
||||
}
|
||||
return s.client.Expire(ctx, workerControlQueueKey(workerID), 10*time.Minute).Err()
|
||||
}
|
||||
|
||||
func (s *RedisStore) AppendEvent(ctx context.Context, payload map[string]any) error {
|
||||
encoded, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal worker event: %w", err)
|
||||
}
|
||||
if err := s.client.RPush(ctx, workerEventsKey(), encoded).Err(); err != nil {
|
||||
return fmt.Errorf("append worker event: %w", err)
|
||||
}
|
||||
return s.client.Expire(ctx, workerEventsKey(), 10*time.Minute).Err()
|
||||
}
|
||||
|
||||
func workerKey(workerID string) string {
|
||||
return "worker:registration:" + workerID
|
||||
}
|
||||
|
||||
func workerSetKey() string {
|
||||
return "worker:registrations"
|
||||
}
|
||||
|
||||
func leaseKey(leaseID string) string {
|
||||
return "worker:lease:" + leaseID
|
||||
}
|
||||
|
||||
func leaseSetKey() string {
|
||||
return "worker:leases"
|
||||
}
|
||||
|
||||
func sessionLeaseKey(sessionID string) string {
|
||||
return "worker:session-lease:" + sessionID
|
||||
}
|
||||
|
||||
func workerQueueKey(sessionID string) string {
|
||||
return "worker:queue:" + sessionID
|
||||
}
|
||||
|
||||
func workerControlQueueKey(workerID string) string {
|
||||
return "worker:control:" + workerID
|
||||
}
|
||||
|
||||
func workerEventsKey() string {
|
||||
return "worker:events"
|
||||
}
|
||||
@@ -0,0 +1,274 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/sessionbroker"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/module"
|
||||
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
|
||||
)
|
||||
|
||||
var ErrNoWorkerAvailable = errors.New("no worker available")
|
||||
|
||||
type Service struct {
|
||||
cfg module.Config
|
||||
store Store
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
func NewService(deps module.Dependencies, store Store) *Service {
|
||||
return &Service{
|
||||
cfg: deps.Config,
|
||||
store: store,
|
||||
now: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Register(ctx context.Context, registration workercontracts.WorkerRegistration) error {
|
||||
if registration.WorkerID == "" {
|
||||
return fmt.Errorf("worker id is required")
|
||||
}
|
||||
registration.LastHeartbeatAt = s.now().UTC()
|
||||
return s.store.RegisterWorker(ctx, registration, s.cfg.Worker.HeartbeatTTL)
|
||||
}
|
||||
|
||||
func (s *Service) Heartbeat(ctx context.Context, heartbeat workercontracts.WorkerHeartbeat) error {
|
||||
heartbeat.LastHeartbeatAt = s.now().UTC()
|
||||
return s.store.TouchWorkerHeartbeat(ctx, heartbeat, s.cfg.Worker.HeartbeatTTL)
|
||||
}
|
||||
|
||||
func (s *Service) Reserve(ctx context.Context, request workercontracts.AttachRequest) (*workercontracts.WorkerLease, error) {
|
||||
registration, err := s.reserveWorker(ctx, workercontracts.ProtocolRDP, request.RequiredCapabilities)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.AcquireLease(ctx, registration.WorkerID, request)
|
||||
}
|
||||
|
||||
func (s *Service) reserveWorker(ctx context.Context, protocol workercontracts.Protocol, capabilities []string) (*workercontracts.WorkerRegistration, error) {
|
||||
workers, err := s.store.ListWorkers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
now := s.now().UTC()
|
||||
for _, worker := range workers {
|
||||
if worker.Protocol != protocol || worker.Status != workercontracts.StatusOnline {
|
||||
continue
|
||||
}
|
||||
if now.Sub(worker.LastHeartbeatAt) > s.cfg.Worker.StaleLeaseGracePeriod+s.cfg.Worker.HeartbeatTTL {
|
||||
continue
|
||||
}
|
||||
if !hasCapabilities(worker.Capabilities, capabilities) {
|
||||
continue
|
||||
}
|
||||
return &worker, nil
|
||||
}
|
||||
return nil, ErrNoWorkerAvailable
|
||||
}
|
||||
|
||||
func (s *Service) AcquireLease(ctx context.Context, workerID string, request workercontracts.AttachRequest) (*workercontracts.WorkerLease, error) {
|
||||
if request.SessionID == "" {
|
||||
request.SessionID = uuid.NewString()
|
||||
}
|
||||
now := s.now().UTC()
|
||||
lease := workercontracts.WorkerLease{
|
||||
LeaseID: uuid.NewString(),
|
||||
WorkerID: workerID,
|
||||
Protocol: workercontracts.ProtocolRDP,
|
||||
ResourceID: request.ResourceID,
|
||||
SessionID: request.SessionID,
|
||||
Capabilities: request.RequiredCapabilities,
|
||||
ControlStream: "worker://control/" + workerID,
|
||||
ExpiresAt: now.Add(s.cfg.Worker.LeaseTTL),
|
||||
RenderQualityProfile: normalizeRenderQualityProfile(request.RenderQualityProfile),
|
||||
}
|
||||
if err := s.store.AcquireLease(ctx, lease, s.cfg.Worker.LeaseTTL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &lease, nil
|
||||
}
|
||||
|
||||
func (s *Service) GetSessionLease(ctx context.Context, sessionID string) (*workercontracts.WorkerLease, error) {
|
||||
return s.store.GetLeaseBySession(ctx, sessionID)
|
||||
}
|
||||
|
||||
func (s *Service) RenewLease(ctx context.Context, leaseID string) (*workercontracts.WorkerLease, error) {
|
||||
lease, err := s.store.GetLease(ctx, leaseID)
|
||||
if err != nil || lease == nil {
|
||||
return lease, err
|
||||
}
|
||||
lease.ExpiresAt = s.now().UTC().Add(s.cfg.Worker.LeaseTTL)
|
||||
if err := s.store.RenewLease(ctx, *lease, s.cfg.Worker.LeaseTTL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return lease, nil
|
||||
}
|
||||
|
||||
func (s *Service) ReleaseLease(ctx context.Context, leaseID string) error {
|
||||
return s.store.ReleaseLease(ctx, leaseID)
|
||||
}
|
||||
|
||||
func (s *Service) ReleaseSessionLease(ctx context.Context, sessionID string) error {
|
||||
lease, err := s.store.GetLeaseBySession(ctx, sessionID)
|
||||
if err != nil || lease == nil {
|
||||
return err
|
||||
}
|
||||
return s.store.ReleaseLease(ctx, lease.LeaseID)
|
||||
}
|
||||
|
||||
func (s *Service) RecoverStaleLeases(ctx context.Context) ([]workercontracts.WorkerLease, error) {
|
||||
leases, err := s.store.ListLeases(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var stale []workercontracts.WorkerLease
|
||||
now := s.now().UTC()
|
||||
for _, lease := range leases {
|
||||
registration, err := s.store.GetWorker(ctx, lease.WorkerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if registration == nil || now.Sub(registration.LastHeartbeatAt) > s.cfg.Worker.StaleLeaseGracePeriod+s.cfg.Worker.HeartbeatTTL {
|
||||
if err := s.store.ReleaseLease(ctx, lease.LeaseID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stale = append(stale, lease)
|
||||
}
|
||||
}
|
||||
return stale, nil
|
||||
}
|
||||
|
||||
func (s *Service) ValidateSessionRuntime(ctx context.Context, sessionID, workerID string) (bool, string, error) {
|
||||
lease, err := s.store.GetLeaseBySession(ctx, sessionID)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
if lease == nil {
|
||||
return false, "worker_lease_missing", nil
|
||||
}
|
||||
if workerID != "" && lease.WorkerID != workerID {
|
||||
_ = s.store.ReleaseLease(ctx, lease.LeaseID)
|
||||
return false, "worker_binding_mismatch", nil
|
||||
}
|
||||
now := s.now().UTC()
|
||||
if !lease.ExpiresAt.After(now) {
|
||||
_ = s.store.ReleaseLease(ctx, lease.LeaseID)
|
||||
return false, "worker_lease_expired", nil
|
||||
}
|
||||
registration, err := s.store.GetWorker(ctx, lease.WorkerID)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
if registration == nil {
|
||||
_ = s.store.ReleaseLease(ctx, lease.LeaseID)
|
||||
return false, "worker_registration_missing", nil
|
||||
}
|
||||
if registration.Status != workercontracts.StatusOnline {
|
||||
return false, "worker_not_online", nil
|
||||
}
|
||||
if now.Sub(registration.LastHeartbeatAt) > s.cfg.Worker.StaleLeaseGracePeriod+s.cfg.Worker.HeartbeatTTL {
|
||||
_ = s.store.ReleaseLease(ctx, lease.LeaseID)
|
||||
return false, "worker_heartbeat_stale", nil
|
||||
}
|
||||
return true, "", nil
|
||||
}
|
||||
|
||||
func (s *Service) PublishControl(ctx context.Context, envelope workercontracts.RoutedEnvelope) error {
|
||||
return s.store.AppendEnvelope(ctx, envelope)
|
||||
}
|
||||
|
||||
func (s *Service) PublishInput(ctx context.Context, envelope workercontracts.RoutedEnvelope) error {
|
||||
return s.store.AppendEnvelope(ctx, envelope)
|
||||
}
|
||||
|
||||
func hasCapabilities(workerCaps, required []string) bool {
|
||||
for _, capability := range required {
|
||||
if !slices.Contains(workerCaps, capability) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Service) PrepareAttachment(ctx context.Context, session sessionbroker.RemoteSession, attachment sessionbroker.SessionAttachment, runtimeMetadata map[string]any) error {
|
||||
renderQualityProfile := normalizeRenderQualityProfile(session.RenderQualityProfile)
|
||||
if renderQualityProfile == "balanced" {
|
||||
renderQualityProfile = renderQualityProfileFromMetadata(session.Metadata)
|
||||
}
|
||||
if runtimeMetadata == nil {
|
||||
runtimeMetadata = decodeMetadata(session.Metadata)
|
||||
}
|
||||
return s.store.AppendAssignment(ctx, session.WorkerID, map[string]any{
|
||||
"type": "session_assignment",
|
||||
"session_id": session.ID,
|
||||
"worker_id": session.WorkerID,
|
||||
"attachment_id": attachment.ID,
|
||||
"user_id": attachment.UserID,
|
||||
"device_id": attachment.DeviceID,
|
||||
"takeover_of": attachment.TakeoverOf,
|
||||
"state": session.State,
|
||||
"render_quality_profile": renderQualityProfile,
|
||||
"metadata": runtimeMetadata,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) NotifyDetachment(ctx context.Context, session sessionbroker.RemoteSession, attachment sessionbroker.SessionAttachment) error {
|
||||
return s.PublishControl(ctx, workercontracts.RoutedEnvelope{
|
||||
SessionID: session.ID,
|
||||
AttachmentID: attachment.ID,
|
||||
Type: "control",
|
||||
Payload: map[string]any{
|
||||
"action": "detach",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) TerminateRemoteSession(ctx context.Context, sessionID, attachmentID string) error {
|
||||
return s.PublishControl(ctx, workercontracts.RoutedEnvelope{
|
||||
SessionID: sessionID,
|
||||
AttachmentID: attachmentID,
|
||||
Type: "control",
|
||||
Payload: map[string]any{
|
||||
"action": "terminate",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func decodeMetadata(payload []byte) map[string]any {
|
||||
var out map[string]any
|
||||
if len(payload) == 0 {
|
||||
return map[string]any{}
|
||||
}
|
||||
if err := json.Unmarshal(payload, &out); err != nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeRenderQualityProfile(profile string) string {
|
||||
switch profile {
|
||||
case "low_bandwidth", "balanced", "high_quality", "text_priority":
|
||||
return profile
|
||||
default:
|
||||
return "balanced"
|
||||
}
|
||||
}
|
||||
|
||||
func renderQualityProfileFromMetadata(metadata []byte) string {
|
||||
decoded := decodeMetadata(metadata)
|
||||
resource, _ := decoded["resource"].(map[string]any)
|
||||
if resource == nil {
|
||||
return "balanced"
|
||||
}
|
||||
if profile, ok := resource["render_quality_profile"].(string); ok && profile != "" {
|
||||
return normalizeRenderQualityProfile(profile)
|
||||
}
|
||||
return "balanced"
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
RegisterWorker(ctx context.Context, registration workercontracts.WorkerRegistration, ttl time.Duration) error
|
||||
TouchWorkerHeartbeat(ctx context.Context, heartbeat workercontracts.WorkerHeartbeat, ttl time.Duration) error
|
||||
ListWorkers(ctx context.Context) ([]workercontracts.WorkerRegistration, error)
|
||||
GetWorker(ctx context.Context, workerID string) (*workercontracts.WorkerRegistration, error)
|
||||
AcquireLease(ctx context.Context, lease workercontracts.WorkerLease, ttl time.Duration) error
|
||||
GetLease(ctx context.Context, leaseID string) (*workercontracts.WorkerLease, error)
|
||||
GetLeaseBySession(ctx context.Context, sessionID string) (*workercontracts.WorkerLease, error)
|
||||
RenewLease(ctx context.Context, lease workercontracts.WorkerLease, ttl time.Duration) error
|
||||
ReleaseLease(ctx context.Context, leaseID string) error
|
||||
ListLeases(ctx context.Context) ([]workercontracts.WorkerLease, error)
|
||||
AppendAssignment(ctx context.Context, workerID string, payload map[string]any) error
|
||||
AppendEnvelope(ctx context.Context, envelope workercontracts.RoutedEnvelope) error
|
||||
AppendEvent(ctx context.Context, payload map[string]any) error
|
||||
}
|
||||
@@ -0,0 +1,329 @@
|
||||
package authority
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/config"
|
||||
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
|
||||
)
|
||||
|
||||
const (
|
||||
ModeStrict = "strict"
|
||||
ModeLegacy = "legacy"
|
||||
|
||||
ActivationSchemaVersion = "rap.installation.activation.v1"
|
||||
|
||||
PlatformRoleUser = "user"
|
||||
PlatformRoleAdmin = "platform_admin"
|
||||
PlatformRoleRecoveryAdmin = "platform_recovery_admin"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidAuthorityMode = errors.New("invalid installation authority mode")
|
||||
ErrProductRootKeyNeeded = errors.New("product root public key is required")
|
||||
ErrInvalidActivation = errors.New("invalid installation activation")
|
||||
ErrInvalidGrant = errors.New("invalid platform role grant")
|
||||
)
|
||||
|
||||
type ActivationPayload struct {
|
||||
SchemaVersion string `json:"schema_version"`
|
||||
InstallID string `json:"install_id"`
|
||||
OwnerEmail string `json:"owner_email"`
|
||||
PlatformRole string `json:"platform_role"`
|
||||
IssuedAt time.Time `json:"issued_at"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
Environment string `json:"environment,omitempty"`
|
||||
}
|
||||
|
||||
type Verifier struct {
|
||||
mode string
|
||||
rootPublicKey ed25519.PublicKey
|
||||
rootFingerprint string
|
||||
allowInsecureBootstrap bool
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
func NewVerifier(cfg config.InstallationConfig) (*Verifier, error) {
|
||||
mode := strings.ToLower(strings.TrimSpace(cfg.AuthorityMode))
|
||||
if mode == "" {
|
||||
mode = ModeLegacy
|
||||
}
|
||||
verifier := &Verifier{
|
||||
mode: mode,
|
||||
allowInsecureBootstrap: cfg.AllowInsecureBootstrap,
|
||||
now: time.Now,
|
||||
}
|
||||
|
||||
switch mode {
|
||||
case ModeLegacy:
|
||||
return verifier, nil
|
||||
case ModeStrict:
|
||||
publicKey, err := decodeEd25519PublicKey(cfg.ProductRootPublicKeyBase64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
verifier.rootPublicKey = publicKey
|
||||
fingerprint := sha256.Sum256(publicKey)
|
||||
verifier.rootFingerprint = hex.EncodeToString(fingerprint[:])
|
||||
return verifier, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: %s", ErrInvalidAuthorityMode, mode)
|
||||
}
|
||||
}
|
||||
|
||||
func (v *Verifier) Mode() string {
|
||||
if v == nil || v.mode == "" {
|
||||
return ModeLegacy
|
||||
}
|
||||
return v.mode
|
||||
}
|
||||
|
||||
func (v *Verifier) Strict() bool {
|
||||
return v != nil && v.mode == ModeStrict
|
||||
}
|
||||
|
||||
func (v *Verifier) AllowInsecureBootstrap() bool {
|
||||
return v != nil && v.allowInsecureBootstrap
|
||||
}
|
||||
|
||||
func (v *Verifier) RootFingerprint() string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
return v.rootFingerprint
|
||||
}
|
||||
|
||||
func (v *Verifier) VerifyActivation(payload json.RawMessage, signature string) (ActivationPayload, error) {
|
||||
if v == nil || !v.Strict() {
|
||||
return ActivationPayload{}, ErrProductRootKeyNeeded
|
||||
}
|
||||
activation, canonical, err := parseActivationPayload(payload)
|
||||
if err != nil {
|
||||
return ActivationPayload{}, err
|
||||
}
|
||||
if err := activation.validate(v.now().UTC()); err != nil {
|
||||
return ActivationPayload{}, err
|
||||
}
|
||||
if err := v.verifySignature(canonical, signature); err != nil {
|
||||
return ActivationPayload{}, fmt.Errorf("%w: %v", ErrInvalidActivation, err)
|
||||
}
|
||||
return activation, nil
|
||||
}
|
||||
|
||||
func (v *Verifier) VerifyPlatformRoleGrant(payload json.RawMessage, signature, expectedInstallID, expectedEmail, expectedRole string) (ActivationPayload, error) {
|
||||
activation, err := v.VerifyActivation(payload, signature)
|
||||
if err != nil {
|
||||
return ActivationPayload{}, fmt.Errorf("%w: %v", ErrInvalidGrant, err)
|
||||
}
|
||||
if activation.InstallID != strings.TrimSpace(expectedInstallID) {
|
||||
return ActivationPayload{}, fmt.Errorf("%w: install_id mismatch", ErrInvalidGrant)
|
||||
}
|
||||
if !strings.EqualFold(activation.OwnerEmail, strings.TrimSpace(expectedEmail)) {
|
||||
return ActivationPayload{}, fmt.Errorf("%w: owner_email mismatch", ErrInvalidGrant)
|
||||
}
|
||||
if activation.PlatformRole != strings.TrimSpace(expectedRole) {
|
||||
return ActivationPayload{}, fmt.Errorf("%w: platform_role mismatch", ErrInvalidGrant)
|
||||
}
|
||||
return activation, nil
|
||||
}
|
||||
|
||||
func CanonicalJSON(raw json.RawMessage) ([]byte, error) {
|
||||
if len(raw) == 0 {
|
||||
return nil, fmt.Errorf("%w: empty payload", ErrInvalidActivation)
|
||||
}
|
||||
var value any
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid json: %v", ErrInvalidActivation, err)
|
||||
}
|
||||
canonical, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: canonical json: %v", ErrInvalidActivation, err)
|
||||
}
|
||||
return canonical, nil
|
||||
}
|
||||
|
||||
func EffectivePlatformRole(ctx context.Context, db postgresplatform.DBTX, verifier *Verifier, userID string) (string, error) {
|
||||
userID = strings.TrimSpace(userID)
|
||||
if userID == "" {
|
||||
return PlatformRoleUser, nil
|
||||
}
|
||||
if verifier == nil || !verifier.Strict() {
|
||||
return legacyPlatformRole(ctx, db, userID)
|
||||
}
|
||||
|
||||
var email string
|
||||
if err := db.QueryRow(ctx, `SELECT email FROM users WHERE id = $1::uuid`, userID).Scan(&email); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return PlatformRoleUser, nil
|
||||
}
|
||||
return "", fmt.Errorf("get user email for platform grant: %w", err)
|
||||
}
|
||||
|
||||
rows, err := db.Query(ctx, `
|
||||
SELECT prg.role, prg.install_id, prg.grant_payload, prg.grant_signature
|
||||
FROM platform_role_grants prg
|
||||
JOIN installation_authority ia
|
||||
ON ia.id = 1
|
||||
AND ia.install_id = prg.install_id
|
||||
AND ia.authority_state = 'active'
|
||||
WHERE prg.user_id = $1::uuid
|
||||
AND prg.revoked_at IS NULL
|
||||
AND (prg.expires_at IS NULL OR prg.expires_at > NOW())
|
||||
ORDER BY CASE prg.role
|
||||
WHEN 'platform_recovery_admin' THEN 0
|
||||
WHEN 'platform_admin' THEN 1
|
||||
ELSE 2
|
||||
END, prg.granted_at DESC
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("query platform role grants: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
bestRole := PlatformRoleUser
|
||||
for rows.Next() {
|
||||
var role, installID, signature string
|
||||
var payload []byte
|
||||
if err := rows.Scan(&role, &installID, &payload, &signature); err != nil {
|
||||
return "", fmt.Errorf("scan platform role grant: %w", err)
|
||||
}
|
||||
if _, err := verifier.VerifyPlatformRoleGrant(json.RawMessage(payload), signature, installID, email, role); err != nil {
|
||||
continue
|
||||
}
|
||||
if role == PlatformRoleRecoveryAdmin {
|
||||
return role, nil
|
||||
}
|
||||
if role == PlatformRoleAdmin {
|
||||
bestRole = role
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return "", fmt.Errorf("iterate platform role grants: %w", err)
|
||||
}
|
||||
return bestRole, nil
|
||||
}
|
||||
|
||||
func legacyPlatformRole(ctx context.Context, db postgresplatform.DBTX, userID string) (string, error) {
|
||||
var role string
|
||||
if err := db.QueryRow(ctx, `SELECT platform_role FROM users WHERE id = $1::uuid`, userID).Scan(&role); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return PlatformRoleUser, nil
|
||||
}
|
||||
return "", fmt.Errorf("get platform role: %w", err)
|
||||
}
|
||||
if role == "" {
|
||||
return PlatformRoleUser, nil
|
||||
}
|
||||
return role, nil
|
||||
}
|
||||
|
||||
func parseActivationPayload(raw json.RawMessage) (ActivationPayload, []byte, error) {
|
||||
canonical, err := CanonicalJSON(raw)
|
||||
if err != nil {
|
||||
return ActivationPayload{}, nil, err
|
||||
}
|
||||
var activation ActivationPayload
|
||||
if err := json.Unmarshal(canonical, &activation); err != nil {
|
||||
return ActivationPayload{}, nil, fmt.Errorf("%w: decode activation: %v", ErrInvalidActivation, err)
|
||||
}
|
||||
activation.SchemaVersion = strings.TrimSpace(activation.SchemaVersion)
|
||||
activation.InstallID = strings.TrimSpace(activation.InstallID)
|
||||
activation.OwnerEmail = strings.ToLower(strings.TrimSpace(activation.OwnerEmail))
|
||||
activation.PlatformRole = strings.TrimSpace(activation.PlatformRole)
|
||||
activation.Nonce = strings.TrimSpace(activation.Nonce)
|
||||
activation.Environment = strings.TrimSpace(activation.Environment)
|
||||
return activation, canonical, nil
|
||||
}
|
||||
|
||||
func (p ActivationPayload) validate(now time.Time) error {
|
||||
if p.SchemaVersion != ActivationSchemaVersion {
|
||||
return fmt.Errorf("%w: schema_version must be %s", ErrInvalidActivation, ActivationSchemaVersion)
|
||||
}
|
||||
if p.InstallID == "" {
|
||||
return fmt.Errorf("%w: install_id is required", ErrInvalidActivation)
|
||||
}
|
||||
if p.OwnerEmail == "" || !strings.Contains(p.OwnerEmail, "@") {
|
||||
return fmt.Errorf("%w: owner_email is required", ErrInvalidActivation)
|
||||
}
|
||||
switch p.PlatformRole {
|
||||
case PlatformRoleAdmin, PlatformRoleRecoveryAdmin:
|
||||
default:
|
||||
return fmt.Errorf("%w: platform_role must be platform_admin or platform_recovery_admin", ErrInvalidActivation)
|
||||
}
|
||||
if p.IssuedAt.IsZero() {
|
||||
return fmt.Errorf("%w: issued_at is required", ErrInvalidActivation)
|
||||
}
|
||||
if p.IssuedAt.After(now.Add(5 * time.Minute)) {
|
||||
return fmt.Errorf("%w: issued_at is too far in the future", ErrInvalidActivation)
|
||||
}
|
||||
if p.ExpiresAt != nil && !p.ExpiresAt.After(now) {
|
||||
return fmt.Errorf("%w: activation expired", ErrInvalidActivation)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *Verifier) verifySignature(payload []byte, signatureText string) error {
|
||||
signature, err := decodeBase64(strings.TrimSpace(signatureText))
|
||||
if err != nil {
|
||||
return fmt.Errorf("signature must be base64 encoded: %w", err)
|
||||
}
|
||||
if len(signature) != ed25519.SignatureSize {
|
||||
return fmt.Errorf("signature must decode to %d bytes", ed25519.SignatureSize)
|
||||
}
|
||||
if !ed25519.Verify(v.rootPublicKey, payload, signature) {
|
||||
return errors.New("signature verification failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeEd25519PublicKey(value string) (ed25519.PublicKey, error) {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return nil, ErrProductRootKeyNeeded
|
||||
}
|
||||
if block, _ := pem.Decode([]byte(value)); block != nil {
|
||||
parsed, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse product root public key PEM: %w", err)
|
||||
}
|
||||
publicKey, ok := parsed.(ed25519.PublicKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("product root public key PEM must contain an Ed25519 public key")
|
||||
}
|
||||
return publicKey, nil
|
||||
}
|
||||
decoded, err := decodeBase64(value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("product root public key must be base64 encoded: %w", err)
|
||||
}
|
||||
if len(decoded) != ed25519.PublicKeySize {
|
||||
return nil, fmt.Errorf("product root public key must decode to %d bytes", ed25519.PublicKeySize)
|
||||
}
|
||||
return ed25519.PublicKey(decoded), nil
|
||||
}
|
||||
|
||||
func decodeBase64(value string) ([]byte, error) {
|
||||
decoded, err := base64.StdEncoding.DecodeString(value)
|
||||
if err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
decoded, rawErr := base64.RawStdEncoding.DecodeString(value)
|
||||
if rawErr == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package authority
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/config"
|
||||
)
|
||||
|
||||
func TestVerifierAcceptsSignedActivation(t *testing.T) {
|
||||
publicKey, privateKey, err := ed25519.GenerateKey(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("generate key: %v", err)
|
||||
}
|
||||
verifier, err := NewVerifier(config.InstallationConfig{
|
||||
AuthorityMode: ModeStrict,
|
||||
ProductRootPublicKeyBase64: base64.StdEncoding.EncodeToString(publicKey),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewVerifier: %v", err)
|
||||
}
|
||||
verifier.now = func() time.Time { return time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC) }
|
||||
|
||||
payload := json.RawMessage(`{
|
||||
"platform_role":"platform_admin",
|
||||
"owner_email":"Owner@Example.test",
|
||||
"install_id":"install-1",
|
||||
"schema_version":"rap.installation.activation.v1",
|
||||
"issued_at":"2026-04-28T11:00:00Z",
|
||||
"expires_at":"2026-04-29T11:00:00Z"
|
||||
}`)
|
||||
canonical, err := CanonicalJSON(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("CanonicalJSON: %v", err)
|
||||
}
|
||||
signature := base64.StdEncoding.EncodeToString(ed25519.Sign(privateKey, canonical))
|
||||
|
||||
activation, err := verifier.VerifyActivation(payload, signature)
|
||||
if err != nil {
|
||||
t.Fatalf("VerifyActivation: %v", err)
|
||||
}
|
||||
if activation.OwnerEmail != "owner@example.test" || activation.PlatformRole != PlatformRoleAdmin {
|
||||
t.Fatalf("unexpected activation: %+v", activation)
|
||||
}
|
||||
if verifier.RootFingerprint() == "" {
|
||||
t.Fatal("expected root fingerprint")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifierRejectsTamperedActivation(t *testing.T) {
|
||||
publicKey, privateKey, err := ed25519.GenerateKey(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("generate key: %v", err)
|
||||
}
|
||||
verifier, err := NewVerifier(config.InstallationConfig{
|
||||
AuthorityMode: ModeStrict,
|
||||
ProductRootPublicKeyBase64: base64.StdEncoding.EncodeToString(publicKey),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewVerifier: %v", err)
|
||||
}
|
||||
verifier.now = func() time.Time { return time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC) }
|
||||
|
||||
payload := json.RawMessage(`{
|
||||
"schema_version":"rap.installation.activation.v1",
|
||||
"install_id":"install-1",
|
||||
"owner_email":"owner@example.test",
|
||||
"platform_role":"platform_admin",
|
||||
"issued_at":"2026-04-28T11:00:00Z"
|
||||
}`)
|
||||
canonical, err := CanonicalJSON(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("CanonicalJSON: %v", err)
|
||||
}
|
||||
signature := base64.StdEncoding.EncodeToString(ed25519.Sign(privateKey, canonical))
|
||||
tampered := json.RawMessage(strings.Replace(string(payload), "platform_admin", "platform_recovery_admin", 1))
|
||||
|
||||
if _, err := verifier.VerifyActivation(tampered, signature); err == nil {
|
||||
t.Fatal("expected tampered activation to fail")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package clusterauth
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
AuthoritySchemaVersion = "rap.cluster_authority.v1"
|
||||
SignatureSchemaVersion = "rap.cluster_authority.signature.v1"
|
||||
AlgorithmEd25519 = "ed25519"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidKey = errors.New("invalid cluster authority key")
|
||||
ErrInvalidSignature = errors.New("invalid cluster authority signature")
|
||||
ErrInvalidPayload = errors.New("invalid cluster authority payload")
|
||||
)
|
||||
|
||||
type KeyPair struct {
|
||||
PublicKeyB64 string
|
||||
PrivateKeyB64 string
|
||||
Fingerprint string
|
||||
}
|
||||
|
||||
type Signature struct {
|
||||
SchemaVersion string `json:"schema_version"`
|
||||
Algorithm string `json:"algorithm"`
|
||||
KeyFingerprint string `json:"key_fingerprint"`
|
||||
Signature string `json:"signature"`
|
||||
SignedAt time.Time `json:"signed_at"`
|
||||
}
|
||||
|
||||
func GenerateKeyPair() (KeyPair, error) {
|
||||
publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return KeyPair{}, err
|
||||
}
|
||||
fingerprint := Fingerprint(publicKey)
|
||||
return KeyPair{
|
||||
PublicKeyB64: base64.StdEncoding.EncodeToString(publicKey),
|
||||
PrivateKeyB64: base64.StdEncoding.EncodeToString(privateKey),
|
||||
Fingerprint: fingerprint,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func Fingerprint(publicKey ed25519.PublicKey) string {
|
||||
sum := sha256.Sum256(publicKey)
|
||||
return "rap-ca-ed25519-" + hex.EncodeToString(sum[:16])
|
||||
}
|
||||
|
||||
func FingerprintFromBase64(publicKeyB64 string) (string, error) {
|
||||
publicKey, err := DecodePublicKey(publicKeyB64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return Fingerprint(publicKey), nil
|
||||
}
|
||||
|
||||
func SignRaw(privateKeyB64 string, payload json.RawMessage, signedAt time.Time) (Signature, error) {
|
||||
privateKey, err := DecodePrivateKey(privateKeyB64)
|
||||
if err != nil {
|
||||
return Signature{}, err
|
||||
}
|
||||
canonical, err := CanonicalJSON(payload)
|
||||
if err != nil {
|
||||
return Signature{}, err
|
||||
}
|
||||
publicKey, ok := privateKey.Public().(ed25519.PublicKey)
|
||||
if !ok {
|
||||
return Signature{}, ErrInvalidKey
|
||||
}
|
||||
signature := ed25519.Sign(privateKey, canonical)
|
||||
return Signature{
|
||||
SchemaVersion: SignatureSchemaVersion,
|
||||
Algorithm: AlgorithmEd25519,
|
||||
KeyFingerprint: Fingerprint(publicKey),
|
||||
Signature: base64.StdEncoding.EncodeToString(signature),
|
||||
SignedAt: signedAt.UTC(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func SignPayload(privateKeyB64 string, payload any, signedAt time.Time) (json.RawMessage, Signature, error) {
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, Signature{}, fmt.Errorf("%w: marshal: %v", ErrInvalidPayload, err)
|
||||
}
|
||||
signature, err := SignRaw(privateKeyB64, raw, signedAt)
|
||||
if err != nil {
|
||||
return nil, Signature{}, err
|
||||
}
|
||||
return json.RawMessage(raw), signature, nil
|
||||
}
|
||||
|
||||
func VerifyRaw(publicKeyB64 string, payload json.RawMessage, signature Signature) error {
|
||||
if signature.SchemaVersion != SignatureSchemaVersion {
|
||||
return fmt.Errorf("%w: schema_version must be %s", ErrInvalidSignature, SignatureSchemaVersion)
|
||||
}
|
||||
if signature.Algorithm != AlgorithmEd25519 {
|
||||
return fmt.Errorf("%w: algorithm must be %s", ErrInvalidSignature, AlgorithmEd25519)
|
||||
}
|
||||
publicKey, err := DecodePublicKey(publicKeyB64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if signature.KeyFingerprint != Fingerprint(publicKey) {
|
||||
return fmt.Errorf("%w: key fingerprint mismatch", ErrInvalidSignature)
|
||||
}
|
||||
canonical, err := CanonicalJSON(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decodedSignature, err := decodeBase64(strings.TrimSpace(signature.Signature))
|
||||
if err != nil || len(decodedSignature) != ed25519.SignatureSize {
|
||||
return fmt.Errorf("%w: signature must be base64 ed25519 signature", ErrInvalidSignature)
|
||||
}
|
||||
if !ed25519.Verify(publicKey, canonical, decodedSignature) {
|
||||
return ErrInvalidSignature
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func CanonicalJSON(raw json.RawMessage) ([]byte, error) {
|
||||
if len(raw) == 0 {
|
||||
return nil, fmt.Errorf("%w: empty payload", ErrInvalidPayload)
|
||||
}
|
||||
var value any
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid json: %v", ErrInvalidPayload, err)
|
||||
}
|
||||
canonical, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: canonical json: %v", ErrInvalidPayload, err)
|
||||
}
|
||||
return canonical, nil
|
||||
}
|
||||
|
||||
func HashRaw(raw json.RawMessage) (string, error) {
|
||||
canonical, err := CanonicalJSON(raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sum := sha256.Sum256(canonical)
|
||||
return hex.EncodeToString(sum[:]), nil
|
||||
}
|
||||
|
||||
func DecodePublicKey(value string) (ed25519.PublicKey, error) {
|
||||
decoded, err := decodeBase64(strings.TrimSpace(value))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: public key must be base64 encoded", ErrInvalidKey)
|
||||
}
|
||||
if len(decoded) != ed25519.PublicKeySize {
|
||||
return nil, fmt.Errorf("%w: public key must decode to %d bytes", ErrInvalidKey, ed25519.PublicKeySize)
|
||||
}
|
||||
return ed25519.PublicKey(decoded), nil
|
||||
}
|
||||
|
||||
func DecodePrivateKey(value string) (ed25519.PrivateKey, error) {
|
||||
decoded, err := decodeBase64(strings.TrimSpace(value))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: private key must be base64 encoded", ErrInvalidKey)
|
||||
}
|
||||
if len(decoded) != ed25519.PrivateKeySize {
|
||||
return nil, fmt.Errorf("%w: private key must decode to %d bytes", ErrInvalidKey, ed25519.PrivateKeySize)
|
||||
}
|
||||
return ed25519.PrivateKey(decoded), nil
|
||||
}
|
||||
|
||||
func decodeBase64(value string) ([]byte, error) {
|
||||
if value == "" {
|
||||
return nil, errors.New("empty base64 value")
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(value)
|
||||
if err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
return base64.RawStdEncoding.DecodeString(value)
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package clusterauth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSignAndVerifyRawPayload(t *testing.T) {
|
||||
keys, err := GenerateKeyPair()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKeyPair: %v", err)
|
||||
}
|
||||
payload := json.RawMessage(`{"cluster_id":"cluster-1","schema_version":"test.v1","value":1}`)
|
||||
|
||||
signature, err := SignRaw(keys.PrivateKeyB64, payload, time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC))
|
||||
if err != nil {
|
||||
t.Fatalf("SignRaw: %v", err)
|
||||
}
|
||||
if signature.KeyFingerprint != keys.Fingerprint {
|
||||
t.Fatalf("fingerprint = %q, want %q", signature.KeyFingerprint, keys.Fingerprint)
|
||||
}
|
||||
if err := VerifyRaw(keys.PublicKeyB64, payload, signature); err != nil {
|
||||
t.Fatalf("VerifyRaw: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyRawRejectsTamperedPayload(t *testing.T) {
|
||||
keys, err := GenerateKeyPair()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKeyPair: %v", err)
|
||||
}
|
||||
payload := json.RawMessage(`{"cluster_id":"cluster-1","schema_version":"test.v1","value":1}`)
|
||||
signature, err := SignRaw(keys.PrivateKeyB64, payload, time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC))
|
||||
if err != nil {
|
||||
t.Fatalf("SignRaw: %v", err)
|
||||
}
|
||||
|
||||
tampered := json.RawMessage(`{"cluster_id":"cluster-1","schema_version":"test.v1","value":2}`)
|
||||
if err := VerifyRaw(keys.PublicKeyB64, tampered, signature); !errors.Is(err, ErrInvalidSignature) {
|
||||
t.Fatalf("err = %v, want ErrInvalidSignature", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,307 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
App AppConfig
|
||||
HTTP HTTPConfig
|
||||
Postgres PostgresConfig
|
||||
Redis RedisConfig
|
||||
Auth AuthConfig
|
||||
Installation InstallationConfig
|
||||
DataPlane DataPlaneConfig
|
||||
Secret SecretConfig
|
||||
Session SessionConfig
|
||||
Worker WorkerConfig
|
||||
WebSocket WebSocketConfig
|
||||
}
|
||||
|
||||
type AppConfig struct {
|
||||
Name string
|
||||
Env string
|
||||
}
|
||||
|
||||
type HTTPConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
IdleTimeout time.Duration
|
||||
ShutdownTimeout time.Duration
|
||||
}
|
||||
|
||||
type PostgresConfig struct {
|
||||
DSN string
|
||||
MaxConns int32
|
||||
MinConns int32
|
||||
ConnectTimeout time.Duration
|
||||
}
|
||||
|
||||
type RedisConfig struct {
|
||||
Addr string
|
||||
Password string
|
||||
DB int
|
||||
DialTimeout time.Duration
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
AccessTokenTTL time.Duration
|
||||
RefreshTokenTTL time.Duration
|
||||
Issuer string
|
||||
AccessTokenSecret string
|
||||
RefreshHashSecret string
|
||||
}
|
||||
|
||||
type InstallationConfig struct {
|
||||
AuthorityMode string
|
||||
ProductRootPublicKeyBase64 string
|
||||
ProductRootPublicKeyFile string
|
||||
AllowInsecureBootstrap bool
|
||||
}
|
||||
|
||||
type DataPlaneConfig struct {
|
||||
TokenTTL time.Duration
|
||||
TokenPrivateKeyPEM string
|
||||
TokenPrivateKeyFile string
|
||||
BackendGatewayURL string
|
||||
DirectWorkerWSSURLTemplate string
|
||||
DirectWorkerJSONRuntime bool
|
||||
DirectWorkerBinaryRender bool
|
||||
DirectWorkerTLSTrustMode string
|
||||
DirectWorkerTLSCARef string
|
||||
}
|
||||
|
||||
type SecretConfig struct {
|
||||
EncryptionKeyBase64 string
|
||||
EncryptionKeyFile string
|
||||
EncryptionKeyID string
|
||||
}
|
||||
|
||||
type SessionConfig struct {
|
||||
HeartbeatTTL time.Duration
|
||||
DetachGracePeriod time.Duration
|
||||
AttachTokenTTL time.Duration
|
||||
LiveStateTTL time.Duration
|
||||
RecoveryBatchSize int
|
||||
}
|
||||
|
||||
type WorkerConfig struct {
|
||||
LeaseTTL time.Duration
|
||||
HeartbeatTTL time.Duration
|
||||
StaleLeaseGracePeriod time.Duration
|
||||
}
|
||||
|
||||
type WebSocketConfig struct {
|
||||
WriteTimeout time.Duration
|
||||
PingInterval time.Duration
|
||||
PongWait time.Duration
|
||||
}
|
||||
|
||||
func Load() (Config, error) {
|
||||
cfg := Config{
|
||||
App: AppConfig{
|
||||
Name: getEnv("APP_NAME", "rap-api"),
|
||||
Env: getEnv("APP_ENV", "development"),
|
||||
},
|
||||
HTTP: HTTPConfig{
|
||||
Host: getEnv("HTTP_HOST", "0.0.0.0"),
|
||||
Port: getInt("HTTP_PORT", 8080),
|
||||
ReadTimeout: getDuration("HTTP_READ_TIMEOUT", 15*time.Second),
|
||||
WriteTimeout: getDuration("HTTP_WRITE_TIMEOUT", 15*time.Second),
|
||||
IdleTimeout: getDuration("HTTP_IDLE_TIMEOUT", 60*time.Second),
|
||||
ShutdownTimeout: getDuration("HTTP_SHUTDOWN_TIMEOUT", 10*time.Second),
|
||||
},
|
||||
Postgres: PostgresConfig{
|
||||
DSN: getEnv("POSTGRES_DSN", ""),
|
||||
MaxConns: int32(getInt("POSTGRES_MAX_CONNS", 20)),
|
||||
MinConns: int32(getInt("POSTGRES_MIN_CONNS", 2)),
|
||||
ConnectTimeout: getDuration("POSTGRES_CONNECT_TIMEOUT", 5*time.Second),
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Addr: getEnv("REDIS_ADDR", "localhost:6379"),
|
||||
Password: getEnv("REDIS_PASSWORD", ""),
|
||||
DB: getInt("REDIS_DB", 0),
|
||||
DialTimeout: getDuration("REDIS_DIAL_TIMEOUT", 5*time.Second),
|
||||
},
|
||||
Auth: AuthConfig{
|
||||
AccessTokenTTL: getDuration("AUTH_ACCESS_TOKEN_TTL", 15*time.Minute),
|
||||
RefreshTokenTTL: getDuration("AUTH_REFRESH_TOKEN_TTL", 30*24*time.Hour),
|
||||
Issuer: getEnv("AUTH_ISSUER", "rap-api"),
|
||||
AccessTokenSecret: getEnv("AUTH_ACCESS_TOKEN_SECRET", ""),
|
||||
RefreshHashSecret: getEnv("AUTH_REFRESH_HASH_SECRET", ""),
|
||||
},
|
||||
Installation: InstallationConfig{
|
||||
AuthorityMode: getEnv("INSTALLATION_AUTHORITY_MODE", ""),
|
||||
ProductRootPublicKeyBase64: getEnv("INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_B64", ""),
|
||||
ProductRootPublicKeyFile: getEnv("INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_FILE", ""),
|
||||
AllowInsecureBootstrap: getBool("INSTALLATION_INSECURE_BOOTSTRAP_ENABLED", false),
|
||||
},
|
||||
DataPlane: DataPlaneConfig{
|
||||
TokenTTL: getDuration("DATA_PLANE_TOKEN_TTL", 1*time.Minute),
|
||||
TokenPrivateKeyPEM: getEnv("DATA_PLANE_TOKEN_PRIVATE_KEY_PEM", ""),
|
||||
TokenPrivateKeyFile: getEnv("DATA_PLANE_TOKEN_PRIVATE_KEY_FILE", ""),
|
||||
BackendGatewayURL: getEnv("DATA_PLANE_BACKEND_GATEWAY_URL", "/api/v1/gateway/ws"),
|
||||
DirectWorkerWSSURLTemplate: getEnv("DATA_PLANE_DIRECT_WORKER_WSS_URL_TEMPLATE", ""),
|
||||
DirectWorkerJSONRuntime: getBool("DATA_PLANE_DIRECT_WORKER_JSON_RUNTIME", false),
|
||||
DirectWorkerBinaryRender: getBool("DATA_PLANE_DIRECT_WORKER_BINARY_RENDER", false),
|
||||
DirectWorkerTLSTrustMode: getEnv("DATA_PLANE_DIRECT_WORKER_TLS_TRUST_MODE", "smoke_insecure"),
|
||||
DirectWorkerTLSCARef: getEnv("DATA_PLANE_DIRECT_WORKER_TLS_CA_REF", ""),
|
||||
},
|
||||
Secret: SecretConfig{
|
||||
EncryptionKeyBase64: getEnv("SECRET_ENCRYPTION_KEY_B64", ""),
|
||||
EncryptionKeyFile: getEnv("SECRET_ENCRYPTION_KEY_FILE", ""),
|
||||
EncryptionKeyID: getEnv("SECRET_ENCRYPTION_KEY_ID", "local-v1"),
|
||||
},
|
||||
Session: SessionConfig{
|
||||
HeartbeatTTL: getDuration("SESSION_HEARTBEAT_TTL", 90*time.Second),
|
||||
DetachGracePeriod: getDuration("SESSION_DETACH_GRACE_PERIOD", 30*time.Minute),
|
||||
AttachTokenTTL: getDuration("SESSION_ATTACH_TOKEN_TTL", 2*time.Minute),
|
||||
LiveStateTTL: getDuration("SESSION_LIVE_STATE_TTL", 2*time.Minute),
|
||||
RecoveryBatchSize: getInt("SESSION_RECOVERY_BATCH_SIZE", 100),
|
||||
},
|
||||
Worker: WorkerConfig{
|
||||
LeaseTTL: getDuration("WORKER_LEASE_TTL", 45*time.Second),
|
||||
HeartbeatTTL: getDuration("WORKER_HEARTBEAT_TTL", 15*time.Second),
|
||||
StaleLeaseGracePeriod: getDuration("WORKER_STALE_LEASE_GRACE_PERIOD", 30*time.Second),
|
||||
},
|
||||
WebSocket: WebSocketConfig{
|
||||
WriteTimeout: getDuration("WEBSOCKET_WRITE_TIMEOUT", 10*time.Second),
|
||||
PingInterval: getDuration("WEBSOCKET_PING_INTERVAL", 20*time.Second),
|
||||
PongWait: getDuration("WEBSOCKET_PONG_WAIT", 40*time.Second),
|
||||
},
|
||||
}
|
||||
|
||||
if cfg.Postgres.DSN == "" {
|
||||
return Config{}, fmt.Errorf("POSTGRES_DSN is required")
|
||||
}
|
||||
if cfg.Auth.AccessTokenSecret == "" {
|
||||
return Config{}, fmt.Errorf("AUTH_ACCESS_TOKEN_SECRET is required")
|
||||
}
|
||||
if cfg.Auth.RefreshHashSecret == "" {
|
||||
return Config{}, fmt.Errorf("AUTH_REFRESH_HASH_SECRET is required")
|
||||
}
|
||||
if cfg.Installation.ProductRootPublicKeyBase64 == "" && cfg.Installation.ProductRootPublicKeyFile != "" {
|
||||
publicKey, err := os.ReadFile(cfg.Installation.ProductRootPublicKeyFile)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("read INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_FILE: %w", err)
|
||||
}
|
||||
cfg.Installation.ProductRootPublicKeyBase64 = strings.TrimSpace(string(publicKey))
|
||||
}
|
||||
cfg.Installation.AuthorityMode = normalizeInstallationAuthorityMode(cfg.Installation.AuthorityMode, cfg.Installation.ProductRootPublicKeyBase64)
|
||||
if isProductionEnv(cfg.App.Env) && cfg.Installation.AuthorityMode != "strict" {
|
||||
return Config{}, fmt.Errorf("INSTALLATION_AUTHORITY_MODE=strict with INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_B64 or file is required in production")
|
||||
}
|
||||
if cfg.DataPlane.TokenPrivateKeyPEM == "" && cfg.DataPlane.TokenPrivateKeyFile != "" {
|
||||
privateKey, err := os.ReadFile(cfg.DataPlane.TokenPrivateKeyFile)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("read DATA_PLANE_TOKEN_PRIVATE_KEY_FILE: %w", err)
|
||||
}
|
||||
cfg.DataPlane.TokenPrivateKeyPEM = string(privateKey)
|
||||
}
|
||||
if cfg.Secret.EncryptionKeyBase64 == "" && cfg.Secret.EncryptionKeyFile != "" {
|
||||
secretKey, err := os.ReadFile(cfg.Secret.EncryptionKeyFile)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("read SECRET_ENCRYPTION_KEY_FILE: %w", err)
|
||||
}
|
||||
cfg.Secret.EncryptionKeyBase64 = strings.TrimSpace(string(secretKey))
|
||||
}
|
||||
if cfg.Secret.EncryptionKeyBase64 != "" {
|
||||
decoded, err := base64.StdEncoding.DecodeString(cfg.Secret.EncryptionKeyBase64)
|
||||
if err != nil {
|
||||
if decodedRaw, rawErr := base64.RawStdEncoding.DecodeString(cfg.Secret.EncryptionKeyBase64); rawErr == nil {
|
||||
decoded = decodedRaw
|
||||
} else {
|
||||
return Config{}, fmt.Errorf("SECRET_ENCRYPTION_KEY_B64 must be base64 encoded: %w", err)
|
||||
}
|
||||
}
|
||||
if len(decoded) != 32 {
|
||||
return Config{}, fmt.Errorf("SECRET_ENCRYPTION_KEY_B64 must decode to 32 bytes for AES-256-GCM")
|
||||
}
|
||||
}
|
||||
if isProductionEnv(cfg.App.Env) && cfg.Secret.EncryptionKeyBase64 == "" {
|
||||
return Config{}, fmt.Errorf("SECRET_ENCRYPTION_KEY_B64 or SECRET_ENCRYPTION_KEY_FILE is required in production")
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func normalizeInstallationAuthorityMode(mode string, rootPublicKey string) string {
|
||||
mode = strings.ToLower(strings.TrimSpace(mode))
|
||||
switch mode {
|
||||
case "strict", "legacy":
|
||||
return mode
|
||||
case "":
|
||||
if strings.TrimSpace(rootPublicKey) != "" {
|
||||
return "strict"
|
||||
}
|
||||
return "legacy"
|
||||
default:
|
||||
return mode
|
||||
}
|
||||
}
|
||||
|
||||
func isProductionEnv(appEnv string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(appEnv)) {
|
||||
case "production", "prod":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func getEnv(key, fallback string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func getInt(key string, fallback int) int {
|
||||
value := os.Getenv(key)
|
||||
if value == "" {
|
||||
return fallback
|
||||
}
|
||||
|
||||
parsed, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
|
||||
return parsed
|
||||
}
|
||||
|
||||
func getBool(key string, fallback bool) bool {
|
||||
value := os.Getenv(key)
|
||||
if value == "" {
|
||||
return fallback
|
||||
}
|
||||
switch value {
|
||||
case "1", "true", "TRUE", "yes", "on":
|
||||
return true
|
||||
case "0", "false", "FALSE", "no", "off":
|
||||
return false
|
||||
default:
|
||||
return fallback
|
||||
}
|
||||
}
|
||||
|
||||
func getDuration(key string, fallback time.Duration) time.Duration {
|
||||
value := os.Getenv(key)
|
||||
if value == "" {
|
||||
return fallback
|
||||
}
|
||||
|
||||
parsed, err := time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
|
||||
return parsed
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package httpserver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/config"
|
||||
)
|
||||
|
||||
func New(cfg config.HTTPConfig, handler http.Handler) *http.Server {
|
||||
return &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||
Handler: handler,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
ReadTimeout: cfg.ReadTimeout,
|
||||
WriteTimeout: cfg.WriteTimeout,
|
||||
IdleTimeout: cfg.IdleTimeout,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func WriteJSON(w http.ResponseWriter, status int, payload any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(payload)
|
||||
}
|
||||
|
||||
func WriteError(w http.ResponseWriter, status int, message string) {
|
||||
traceID := ensureTraceID(w)
|
||||
WriteJSON(w, status, ErrorResponse{
|
||||
Error: NewErrorMessage(status, message, nil, traceID),
|
||||
})
|
||||
}
|
||||
|
||||
func WriteErrorMessage(w http.ResponseWriter, status int, message any) {
|
||||
traceID := ensureTraceID(w)
|
||||
switch payload := message.(type) {
|
||||
case string:
|
||||
WriteJSON(w, status, ErrorResponse{
|
||||
Error: NewErrorMessage(status, payload, nil, traceID),
|
||||
})
|
||||
case ErrorResponse:
|
||||
payload.Error.TraceID = traceID
|
||||
WriteJSON(w, status, payload)
|
||||
case *ErrorResponse:
|
||||
if payload == nil {
|
||||
WriteJSON(w, status, ErrorResponse{
|
||||
Error: NewErrorMessage(status, "", nil, traceID),
|
||||
})
|
||||
return
|
||||
}
|
||||
payload.Error.TraceID = traceID
|
||||
WriteJSON(w, status, payload)
|
||||
default:
|
||||
WriteJSON(w, status, ErrorResponse{
|
||||
Error: NewErrorMessage(status, "Request failed.", nil, traceID),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
messagecontracts "github.com/example/remote-access-platform/backend/pkg/contracts/message"
|
||||
)
|
||||
|
||||
type ErrorResponse struct {
|
||||
Error messagecontracts.Message `json:"error"`
|
||||
}
|
||||
|
||||
func NewMessage(code, messageKey, fallbackMessage string, details map[string]any, traceID string) messagecontracts.Message {
|
||||
if traceID == "" {
|
||||
traceID = uuid.NewString()
|
||||
}
|
||||
if details == nil {
|
||||
details = map[string]any{}
|
||||
}
|
||||
return messagecontracts.Message{
|
||||
Code: code,
|
||||
MessageKey: messageKey,
|
||||
FallbackMessage: fallbackMessage,
|
||||
Details: details,
|
||||
TraceID: traceID,
|
||||
}
|
||||
}
|
||||
|
||||
func NewErrorMessage(status int, fallbackMessage string, details map[string]any, traceID string) messagecontracts.Message {
|
||||
normalizedFallback, normalizedDetails := normalizeErrorFallback(status, fallbackMessage, details)
|
||||
code := deriveErrorCode(status, normalizedFallback)
|
||||
return NewMessage(code, "errors."+code, normalizedFallback, normalizedDetails, traceID)
|
||||
}
|
||||
|
||||
func ensureTraceID(w http.ResponseWriter) string {
|
||||
traceID := w.Header().Get("X-Trace-Id")
|
||||
if traceID == "" {
|
||||
traceID = uuid.NewString()
|
||||
w.Header().Set("X-Trace-Id", traceID)
|
||||
}
|
||||
return traceID
|
||||
}
|
||||
|
||||
func normalizeErrorFallback(status int, fallbackMessage string, details map[string]any) (string, map[string]any) {
|
||||
if details == nil {
|
||||
details = map[string]any{}
|
||||
}
|
||||
details["http_status"] = status
|
||||
|
||||
if status >= http.StatusInternalServerError {
|
||||
return "An internal server error occurred.", details
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(fallbackMessage)
|
||||
switch strings.ToLower(trimmed) {
|
||||
case "forbidden", "access denied":
|
||||
return "Access denied.", details
|
||||
}
|
||||
|
||||
if field, ok := extractRequiredField(trimmed); ok {
|
||||
details["field"] = field
|
||||
}
|
||||
|
||||
return trimmed, details
|
||||
}
|
||||
|
||||
func deriveErrorCode(status int, fallbackMessage string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(fallbackMessage)) {
|
||||
case "invalid credentials":
|
||||
return "auth.invalid_credentials"
|
||||
case "session expired. please sign in again.":
|
||||
return "auth.session_expired"
|
||||
case "access denied.":
|
||||
return "common.access_denied"
|
||||
}
|
||||
|
||||
statusPrefix := map[int]string{
|
||||
http.StatusBadRequest: "bad_request",
|
||||
http.StatusUnauthorized: "unauthorized",
|
||||
http.StatusForbidden: "forbidden",
|
||||
http.StatusNotFound: "not_found",
|
||||
http.StatusConflict: "conflict",
|
||||
http.StatusUnprocessableEntity: "unprocessable_entity",
|
||||
http.StatusInternalServerError: "internal_server_error",
|
||||
}[status]
|
||||
if statusPrefix == "" {
|
||||
statusPrefix = "http_" + strings.ReplaceAll(http.StatusText(status), " ", "_")
|
||||
statusPrefix = strings.ToLower(statusPrefix)
|
||||
}
|
||||
|
||||
slug := slugifyMessage(fallbackMessage)
|
||||
if slug == "" {
|
||||
slug = "message"
|
||||
}
|
||||
if status >= http.StatusInternalServerError {
|
||||
return "common." + statusPrefix
|
||||
}
|
||||
return statusPrefix + "." + slug
|
||||
}
|
||||
|
||||
func slugifyMessage(input string) string {
|
||||
var builder strings.Builder
|
||||
lastUnderscore := false
|
||||
for _, r := range strings.ToLower(strings.TrimSpace(input)) {
|
||||
if unicode.IsLetter(r) || unicode.IsDigit(r) {
|
||||
builder.WriteRune(r)
|
||||
lastUnderscore = false
|
||||
continue
|
||||
}
|
||||
if !lastUnderscore {
|
||||
builder.WriteRune('_')
|
||||
lastUnderscore = true
|
||||
}
|
||||
}
|
||||
return strings.Trim(builder.String(), "_")
|
||||
}
|
||||
|
||||
func extractRequiredField(message string) (string, bool) {
|
||||
const suffix = " is required"
|
||||
if !strings.HasSuffix(strings.ToLower(message), suffix) {
|
||||
return "", false
|
||||
}
|
||||
field := strings.TrimSpace(message[:len(message)-len(suffix)])
|
||||
field = strings.ReplaceAll(field, " ", "_")
|
||||
field = strings.ToLower(field)
|
||||
return field, field != ""
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
)
|
||||
|
||||
func New(env string) *slog.Logger {
|
||||
level := slog.LevelInfo
|
||||
if env == "development" {
|
||||
level = slog.LevelDebug
|
||||
}
|
||||
|
||||
return slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
}))
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package module
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/config"
|
||||
)
|
||||
|
||||
type Dependencies struct {
|
||||
Config Config
|
||||
Infra Infra
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
App config.AppConfig
|
||||
Auth config.AuthConfig
|
||||
Installation config.InstallationConfig
|
||||
DataPlane config.DataPlaneConfig
|
||||
Secret config.SecretConfig
|
||||
Session config.SessionConfig
|
||||
Worker config.WorkerConfig
|
||||
WebSocket config.WebSocketConfig
|
||||
}
|
||||
|
||||
type Infra struct {
|
||||
Logger *slog.Logger
|
||||
DB *pgxpool.Pool
|
||||
Redis *redis.Client
|
||||
}
|
||||
|
||||
type Module interface {
|
||||
Name() string
|
||||
RegisterRoutes(router chi.Router)
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/config"
|
||||
)
|
||||
|
||||
func Open(ctx context.Context, cfg config.PostgresConfig) (*pgxpool.Pool, error) {
|
||||
poolConfig, err := pgxpool.ParseConfig(cfg.DSN)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse postgres dsn: %w", err)
|
||||
}
|
||||
|
||||
poolConfig.MaxConns = cfg.MaxConns
|
||||
poolConfig.MinConns = cfg.MinConns
|
||||
poolConfig.ConnConfig.ConnectTimeout = cfg.ConnectTimeout
|
||||
|
||||
pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create postgres pool: %w", err)
|
||||
}
|
||||
|
||||
if err := pool.Ping(ctx); err != nil {
|
||||
pool.Close()
|
||||
return nil, fmt.Errorf("ping postgres: %w", err)
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type DBTX interface {
|
||||
Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
|
||||
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
|
||||
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
|
||||
}
|
||||
|
||||
func WithTransaction(ctx context.Context, pool *pgxpool.Pool, fn func(tx pgx.Tx) error) error {
|
||||
tx, err := pool.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
_ = tx.Rollback(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
goredis "github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/config"
|
||||
)
|
||||
|
||||
func Open(ctx context.Context, cfg config.RedisConfig) (*goredis.Client, error) {
|
||||
client := goredis.NewClient(&goredis.Options{
|
||||
Addr: cfg.Addr,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
DialTimeout: cfg.DialTimeout,
|
||||
})
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, fmt.Errorf("ping redis: %w", err)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
@@ -0,0 +1,220 @@
|
||||
package runtime
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
chimiddleware "github.com/go-chi/chi/v5/middleware"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/auth"
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/cluster"
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/identitysource"
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/node"
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/nodeagent"
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/organization"
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/resource"
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/sessionbroker"
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/sessiongateway"
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/worker"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/authority"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/config"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/httpserver"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/logging"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/module"
|
||||
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
|
||||
redisplatform "github.com/example/remote-access-platform/backend/internal/platform/redis"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
|
||||
)
|
||||
|
||||
type App struct {
|
||||
cfg config.Config
|
||||
logger *slog.Logger
|
||||
httpServer *http.Server
|
||||
workers []backgroundRunner
|
||||
db closeFunc
|
||||
redis closeFunc
|
||||
}
|
||||
|
||||
type closeFunc func() error
|
||||
type backgroundRunner func(context.Context) error
|
||||
|
||||
func NewApp(ctx context.Context) (*App, error) {
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
logger := logging.New(cfg.App.Env)
|
||||
|
||||
db, err := postgresplatform.Open(ctx, cfg.Postgres)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
redisClient, err := redisplatform.Open(ctx, cfg.Redis)
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authorityVerifier, err := authority.NewVerifier(cfg.Installation)
|
||||
if err != nil {
|
||||
redisClient.Close()
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("create installation authority verifier: %w", err)
|
||||
}
|
||||
|
||||
deps := module.Dependencies{
|
||||
Config: module.Config{
|
||||
App: cfg.App,
|
||||
Auth: cfg.Auth,
|
||||
Installation: cfg.Installation,
|
||||
DataPlane: cfg.DataPlane,
|
||||
Secret: cfg.Secret,
|
||||
Session: cfg.Session,
|
||||
Worker: cfg.Worker,
|
||||
WebSocket: cfg.WebSocket,
|
||||
},
|
||||
Infra: module.Infra{
|
||||
Logger: logger,
|
||||
DB: db,
|
||||
Redis: redisClient,
|
||||
},
|
||||
}
|
||||
|
||||
workerStore := worker.NewRedisStore(redisClient)
|
||||
workerService := worker.NewService(deps, workerStore)
|
||||
authStore := auth.NewPostgresStore(db)
|
||||
authTx := auth.NewPostgresTransactor(db)
|
||||
authService := auth.NewService(deps, authStore, authTx, authorityVerifier)
|
||||
var resourceSecretStore *secrets.ResourceSecretStore
|
||||
if cfg.Secret.EncryptionKeyBase64 != "" {
|
||||
secretEncryptor, err := secrets.NewEncryptor(cfg.Secret.EncryptionKeyBase64, cfg.Secret.EncryptionKeyID)
|
||||
if err != nil {
|
||||
redisClient.Close()
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("create resource secret encryptor: %w", err)
|
||||
}
|
||||
resourceSecretStore = secrets.NewResourceSecretStore(db, secretEncryptor)
|
||||
}
|
||||
|
||||
brokerStore := sessionbroker.NewPostgresStore(db, authorityVerifier)
|
||||
brokerTx := sessionbroker.NewPostgresTransactor(db, authorityVerifier)
|
||||
liveStateStore := sessionbroker.NewRedisLiveStateStore(redisClient)
|
||||
brokerService := sessionbroker.NewService(deps, brokerStore, brokerTx, liveStateStore, workerService, resourceSecretStore)
|
||||
workerEvents := worker.NewEventProcessor(redisClient, brokerService)
|
||||
leaseMonitor := worker.NewLeaseMonitor(workerService, brokerService, cfg.Worker.StaleLeaseGracePeriod)
|
||||
|
||||
brokerModule := sessionbroker.NewModule(brokerService)
|
||||
authModule := auth.NewModule(deps, authService)
|
||||
clusterModule := cluster.NewModule(deps, authorityVerifier)
|
||||
organizationModule := organization.NewModule(deps)
|
||||
identitySourceModule := identitysource.NewModule(deps)
|
||||
resourceModule := resource.NewModule(deps, resourceSecretStore)
|
||||
nodeModule := node.NewModule(deps)
|
||||
nodeAgentModule := nodeagent.NewModule(deps)
|
||||
sessionGatewayModule := sessiongateway.NewModule(deps, brokerModule.Service(), workerService)
|
||||
|
||||
router := buildRouter(
|
||||
logger,
|
||||
authModule,
|
||||
clusterModule,
|
||||
organizationModule,
|
||||
identitySourceModule,
|
||||
resourceModule,
|
||||
brokerModule,
|
||||
nodeModule,
|
||||
nodeAgentModule,
|
||||
sessionGatewayModule,
|
||||
)
|
||||
|
||||
return &App{
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
httpServer: httpserver.New(cfg.HTTP, router),
|
||||
workers: []backgroundRunner{workerEvents.Run, leaseMonitor.Run},
|
||||
db: func() error {
|
||||
db.Close()
|
||||
return nil
|
||||
},
|
||||
redis: redisClient.Close,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *App) Run(ctx context.Context) error {
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
a.logger.Info("http server starting", "addr", a.httpServer.Addr, "service", a.cfg.App.Name)
|
||||
if err := a.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
errCh <- nil
|
||||
}()
|
||||
|
||||
for _, runner := range a.workers {
|
||||
runner := runner
|
||||
go func() {
|
||||
if err := runner(ctx); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
a.logger.Info("shutdown signal received")
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), a.cfg.HTTP.ShutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := a.httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
return fmt.Errorf("shutdown http server: %w", err)
|
||||
}
|
||||
|
||||
if err := a.redis(); err != nil {
|
||||
return fmt.Errorf("close redis: %w", err)
|
||||
}
|
||||
|
||||
if err := a.db(); err != nil {
|
||||
return fmt.Errorf("close postgres: %w", err)
|
||||
}
|
||||
|
||||
a.logger.Info("app stopped", "at", time.Now().UTC())
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildRouter(logger *slog.Logger, modules ...module.Module) http.Handler {
|
||||
router := chi.NewRouter()
|
||||
router.Use(chimiddleware.RequestID)
|
||||
router.Use(chimiddleware.RealIP)
|
||||
router.Use(chimiddleware.Recoverer)
|
||||
router.Use(chimiddleware.Timeout(60 * time.Second))
|
||||
router.Use(chimiddleware.Heartbeat("/healthz"))
|
||||
|
||||
router.Get("/readyz", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ready"))
|
||||
})
|
||||
|
||||
router.Route("/api/v1", func(r chi.Router) {
|
||||
for _, mod := range modules {
|
||||
logger.Info("register module routes", "module", mod.Name())
|
||||
mod.RegisterRoutes(r)
|
||||
}
|
||||
})
|
||||
|
||||
return router
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package secrets
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type AssignmentSecretMergeResult struct {
|
||||
Metadata map[string]any
|
||||
Keys []string
|
||||
}
|
||||
|
||||
func MergeResourceSecretIntoAssignmentMetadata(metadata map[string]any, payload json.RawMessage) (AssignmentSecretMergeResult, error) {
|
||||
if metadata == nil {
|
||||
metadata = map[string]any{}
|
||||
}
|
||||
var secretPayload map[string]any
|
||||
if err := json.Unmarshal(payload, &secretPayload); err != nil {
|
||||
return AssignmentSecretMergeResult{}, fmt.Errorf("decode resolved resource secret: %w", err)
|
||||
}
|
||||
resource, _ := metadata["resource"].(map[string]any)
|
||||
if resource == nil {
|
||||
resource = map[string]any{}
|
||||
metadata["resource"] = resource
|
||||
}
|
||||
resourceMetadata, _ := resource["metadata"].(map[string]any)
|
||||
if resourceMetadata == nil {
|
||||
resourceMetadata = map[string]any{}
|
||||
resource["metadata"] = resourceMetadata
|
||||
}
|
||||
keys := make([]string, 0, len(secretPayload))
|
||||
for key, value := range secretPayload {
|
||||
resourceMetadata[key] = value
|
||||
keys = append(keys, key)
|
||||
}
|
||||
return AssignmentSecretMergeResult{Metadata: metadata, Keys: keys}, nil
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
package secrets
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const AlgorithmAES256GCM = "AES-256-GCM"
|
||||
|
||||
var (
|
||||
ErrSecretEncryptionKeyMissing = errors.New("secret encryption key is not configured")
|
||||
ErrSecretPayloadInvalid = errors.New("secret payload must be a json object")
|
||||
)
|
||||
|
||||
type Encryptor struct {
|
||||
aead cipher.AEAD
|
||||
keyID string
|
||||
}
|
||||
|
||||
type EncryptedPayload struct {
|
||||
Algorithm string
|
||||
KeyID string
|
||||
Nonce []byte
|
||||
Ciphertext []byte
|
||||
PayloadSHA256 string
|
||||
}
|
||||
|
||||
func NewEncryptor(masterKeyBase64, keyID string) (*Encryptor, error) {
|
||||
masterKeyBase64 = strings.TrimSpace(masterKeyBase64)
|
||||
if masterKeyBase64 == "" {
|
||||
return nil, ErrSecretEncryptionKeyMissing
|
||||
}
|
||||
key, err := base64.StdEncoding.DecodeString(masterKeyBase64)
|
||||
if err != nil {
|
||||
if rawKey, rawErr := base64.RawStdEncoding.DecodeString(masterKeyBase64); rawErr == nil {
|
||||
key = rawKey
|
||||
} else {
|
||||
return nil, fmt.Errorf("decode secret encryption key: %w", err)
|
||||
}
|
||||
}
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("secret encryption key must decode to 32 bytes")
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create secret cipher: %w", err)
|
||||
}
|
||||
aead, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create secret gcm: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(keyID) == "" {
|
||||
keyID = "local-v1"
|
||||
}
|
||||
return &Encryptor{aead: aead, keyID: keyID}, nil
|
||||
}
|
||||
|
||||
func (e *Encryptor) KeyID() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return e.keyID
|
||||
}
|
||||
|
||||
func (e *Encryptor) Encrypt(plaintext, aad []byte) (EncryptedPayload, error) {
|
||||
if e == nil {
|
||||
return EncryptedPayload{}, ErrSecretEncryptionKeyMissing
|
||||
}
|
||||
nonce := make([]byte, e.aead.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return EncryptedPayload{}, fmt.Errorf("generate secret nonce: %w", err)
|
||||
}
|
||||
hash := sha256.Sum256(plaintext)
|
||||
return EncryptedPayload{
|
||||
Algorithm: AlgorithmAES256GCM,
|
||||
KeyID: e.keyID,
|
||||
Nonce: nonce,
|
||||
Ciphertext: e.aead.Seal(nil, nonce, plaintext, aad),
|
||||
PayloadSHA256: hex.EncodeToString(hash[:]),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *Encryptor) Decrypt(payload EncryptedPayload, aad []byte) ([]byte, error) {
|
||||
if e == nil {
|
||||
return nil, ErrSecretEncryptionKeyMissing
|
||||
}
|
||||
if payload.Algorithm != "" && payload.Algorithm != AlgorithmAES256GCM {
|
||||
return nil, fmt.Errorf("unsupported secret algorithm %q", payload.Algorithm)
|
||||
}
|
||||
plaintext, err := e.aead.Open(nil, payload.Nonce, payload.Ciphertext, aad)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt secret payload: %w", err)
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
func ResourceSecretAAD(organizationID, resourceID, secretRef, protocol string) []byte {
|
||||
return []byte(strings.Join([]string{
|
||||
"rap-resource-secret-v1",
|
||||
strings.TrimSpace(organizationID),
|
||||
strings.TrimSpace(resourceID),
|
||||
strings.TrimSpace(secretRef),
|
||||
strings.ToLower(strings.TrimSpace(protocol)),
|
||||
}, "|"))
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package secrets
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEncryptorRoundTrip(t *testing.T) {
|
||||
key := base64.StdEncoding.EncodeToString([]byte("0123456789abcdef0123456789abcdef"))
|
||||
encryptor, err := NewEncryptor(key, "test-key")
|
||||
if err != nil {
|
||||
t.Fatalf("NewEncryptor returned error: %v", err)
|
||||
}
|
||||
aad := ResourceSecretAAD("org-1", "resource-1", "rap-secret://test", "rdp")
|
||||
encrypted, err := encryptor.Encrypt([]byte(`{"username":"user","password":"secret"}`), aad)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt returned error: %v", err)
|
||||
}
|
||||
plaintext, err := encryptor.Decrypt(encrypted, aad)
|
||||
if err != nil {
|
||||
t.Fatalf("Decrypt returned error: %v", err)
|
||||
}
|
||||
if string(plaintext) != `{"username":"user","password":"secret"}` {
|
||||
t.Fatalf("unexpected plaintext: %s", plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptorRejectsWrongAAD(t *testing.T) {
|
||||
key := base64.StdEncoding.EncodeToString([]byte("0123456789abcdef0123456789abcdef"))
|
||||
encryptor, err := NewEncryptor(key, "test-key")
|
||||
if err != nil {
|
||||
t.Fatalf("NewEncryptor returned error: %v", err)
|
||||
}
|
||||
encrypted, err := encryptor.Encrypt([]byte(`{"password":"secret"}`), ResourceSecretAAD("org-1", "resource-1", "ref", "rdp"))
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt returned error: %v", err)
|
||||
}
|
||||
if _, err := encryptor.Decrypt(encrypted, ResourceSecretAAD("org-2", "resource-1", "ref", "rdp")); err == nil {
|
||||
t.Fatalf("expected decrypt with wrong aad to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeResourceSecretIntoAssignmentMetadata(t *testing.T) {
|
||||
metadata := map[string]any{
|
||||
"resource": map[string]any{
|
||||
"id": "resource-1",
|
||||
"metadata": map[string]any{
|
||||
"rdp_host": "host",
|
||||
},
|
||||
},
|
||||
}
|
||||
merged, err := MergeResourceSecretIntoAssignmentMetadata(metadata, json.RawMessage(`{"username":"user","password":"secret","domain":"corp"}`))
|
||||
if err != nil {
|
||||
t.Fatalf("MergeResourceSecretIntoAssignmentMetadata returned error: %v", err)
|
||||
}
|
||||
resource := merged.Metadata["resource"].(map[string]any)
|
||||
resourceMetadata := resource["metadata"].(map[string]any)
|
||||
if resourceMetadata["rdp_host"] != "host" {
|
||||
t.Fatalf("existing metadata was not preserved")
|
||||
}
|
||||
if resourceMetadata["username"] != "user" || resourceMetadata["password"] != "secret" || resourceMetadata["domain"] != "corp" {
|
||||
t.Fatalf("secret payload was not merged: %#v", resourceMetadata)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
package secrets
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPlaintextResourceCredentials = errors.New("plaintext resource credentials are not allowed in metadata in production")
|
||||
ErrMissingResourceSecretRef = errors.New("secret_ref is required for this resource protocol in production")
|
||||
)
|
||||
|
||||
var credentialKeyFragments = []string{
|
||||
"accesstoken",
|
||||
"clientsecret",
|
||||
"credential",
|
||||
"credentials",
|
||||
"domain",
|
||||
"password",
|
||||
"privatekey",
|
||||
"refreshtoken",
|
||||
"secret",
|
||||
"secrets",
|
||||
"token",
|
||||
"user",
|
||||
"username",
|
||||
}
|
||||
|
||||
var safeReferenceKeys = []string{
|
||||
"certificateverificationmode",
|
||||
"renderqualityprofile",
|
||||
"secretref",
|
||||
"secretreference",
|
||||
"vaultref",
|
||||
}
|
||||
|
||||
func ValidateResourceSecretReadiness(protocol string, secretRef *string, metadata json.RawMessage, appEnv string) error {
|
||||
if !IsProductionEnv(appEnv) {
|
||||
return nil
|
||||
}
|
||||
|
||||
paths, err := PlaintextCredentialMetadataPaths(metadata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(paths) > 0 {
|
||||
return fmt.Errorf("%w: %s", ErrPlaintextResourceCredentials, strings.Join(paths, ", "))
|
||||
}
|
||||
if ResourceProtocolRequiresSecretRef(protocol) && (secretRef == nil || strings.TrimSpace(*secretRef) == "") {
|
||||
return ErrMissingResourceSecretRef
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func IsProductionEnv(appEnv string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(appEnv)) {
|
||||
case "prod", "production":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func ResourceProtocolRequiresSecretRef(protocol string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(protocol)) {
|
||||
case "rdp", "vnc", "ssh":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func PlaintextCredentialMetadataPaths(raw json.RawMessage) ([]string, error) {
|
||||
if len(raw) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var value any
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
return nil, errors.New("metadata must be valid json")
|
||||
}
|
||||
metadata, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("metadata must be a json object")
|
||||
}
|
||||
var paths []string
|
||||
collectCredentialPaths(metadata, "", &paths)
|
||||
sort.Strings(paths)
|
||||
return slices.Compact(paths), nil
|
||||
}
|
||||
|
||||
func collectCredentialPaths(value any, prefix string, paths *[]string) {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
for key, child := range typed {
|
||||
path := key
|
||||
if prefix != "" {
|
||||
path = prefix + "." + key
|
||||
}
|
||||
if isCredentialMetadataKey(key) {
|
||||
*paths = append(*paths, path)
|
||||
}
|
||||
collectCredentialPaths(child, path, paths)
|
||||
}
|
||||
case []any:
|
||||
for index, child := range typed {
|
||||
collectCredentialPaths(child, fmt.Sprintf("%s[%d]", prefix, index), paths)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isCredentialMetadataKey(key string) bool {
|
||||
normalized := normalizeMetadataKey(key)
|
||||
if slices.Contains(safeReferenceKeys, normalized) {
|
||||
return false
|
||||
}
|
||||
for _, fragment := range credentialKeyFragments {
|
||||
if normalized == fragment || strings.HasSuffix(normalized, fragment) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeMetadataKey(key string) string {
|
||||
key = strings.ToLower(strings.TrimSpace(key))
|
||||
replacer := strings.NewReplacer("_", "", "-", "", " ", "", ".", "")
|
||||
return replacer.Replace(key)
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package secrets
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidateResourceSecretReadinessAllowsPlaintextInDevelopment(t *testing.T) {
|
||||
metadata := json.RawMessage(`{"username":"m","password":"secret"}`)
|
||||
if err := ValidateResourceSecretReadiness("rdp", nil, metadata, "development"); err != nil {
|
||||
t.Fatalf("development metadata should remain allowed for smoke/dev: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateResourceSecretReadinessRejectsPlaintextCredentialsInProduction(t *testing.T) {
|
||||
metadata := json.RawMessage(`{"rdp_host":"host","credentials":{"username":"m","password":"secret"}}`)
|
||||
err := ValidateResourceSecretReadiness("rdp", stringPtr("vault://org/resource"), metadata, "production")
|
||||
if !errors.Is(err, ErrPlaintextResourceCredentials) {
|
||||
t.Fatalf("expected plaintext credential rejection, got %v", err)
|
||||
}
|
||||
|
||||
paths, err := PlaintextCredentialMetadataPaths(metadata)
|
||||
if err != nil {
|
||||
t.Fatalf("metadata paths: %v", err)
|
||||
}
|
||||
for _, expected := range []string{"credentials", "credentials.password", "credentials.username"} {
|
||||
if !slices.Contains(paths, expected) {
|
||||
t.Fatalf("expected sensitive path %q in %v", expected, paths)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateResourceSecretReadinessRequiresSecretRefForProductionRDP(t *testing.T) {
|
||||
metadata := json.RawMessage(`{"rdp_host":"host","rdp_port":3389}`)
|
||||
err := ValidateResourceSecretReadiness("rdp", nil, metadata, "production")
|
||||
if !errors.Is(err, ErrMissingResourceSecretRef) {
|
||||
t.Fatalf("expected missing secret_ref rejection, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateResourceSecretReadinessAllowsProductionSecretRef(t *testing.T) {
|
||||
metadata := json.RawMessage(`{"rdp_host":"host","rdp_port":3389,"secret_ref":"vault://org/resource"}`)
|
||||
if err := ValidateResourceSecretReadiness("rdp", stringPtr("vault://org/resource"), metadata, "production"); err != nil {
|
||||
t.Fatalf("production secret_ref metadata should be accepted: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func stringPtr(value string) *string {
|
||||
return &value
|
||||
}
|
||||
@@ -0,0 +1,259 @@
|
||||
package secrets
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrResourceSecretNotFound = errors.New("resource secret not found")
|
||||
ErrSecretAccessDenied = errors.New("resource secret access denied")
|
||||
ErrSecretLeaseRequired = errors.New("resource secret resolution requires lease proof")
|
||||
)
|
||||
|
||||
type ResourceSecretStore struct {
|
||||
db postgresplatform.DBTX
|
||||
encryptor *Encryptor
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
type ResourceSecretResolver interface {
|
||||
ResolveForSession(ctx context.Context, req ResolveResourceSecretRequest) (*ResolvedResourceSecret, error)
|
||||
}
|
||||
|
||||
type ResourceSecretDescriptor struct {
|
||||
ID string `json:"id"`
|
||||
OrganizationID string `json:"organization_id"`
|
||||
ResourceID string `json:"resource_id"`
|
||||
SecretRef string `json:"secret_ref"`
|
||||
Protocol string `json:"protocol"`
|
||||
Version int `json:"version"`
|
||||
KeyID string `json:"key_id"`
|
||||
Algorithm string `json:"algorithm"`
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
RotatedAt *time.Time `json:"rotated_at,omitempty"`
|
||||
}
|
||||
|
||||
type UpsertResourceSecretCommand struct {
|
||||
OrganizationID string
|
||||
ResourceID string
|
||||
Protocol string
|
||||
SecretRef string
|
||||
Payload json.RawMessage
|
||||
Metadata json.RawMessage
|
||||
ActorUserID string
|
||||
}
|
||||
|
||||
type ResolveResourceSecretRequest struct {
|
||||
SecretRef string
|
||||
OrganizationID string
|
||||
ResourceID string
|
||||
SessionID string
|
||||
WorkerID string
|
||||
LeaseID string
|
||||
}
|
||||
|
||||
type ResolvedResourceSecret struct {
|
||||
Descriptor ResourceSecretDescriptor
|
||||
Payload json.RawMessage
|
||||
}
|
||||
|
||||
func NewResourceSecretStore(db postgresplatform.DBTX, encryptor *Encryptor) *ResourceSecretStore {
|
||||
return &ResourceSecretStore{db: db, encryptor: encryptor, now: time.Now}
|
||||
}
|
||||
|
||||
func (s *ResourceSecretStore) WithDB(db postgresplatform.DBTX) *ResourceSecretStore {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return &ResourceSecretStore{db: db, encryptor: s.encryptor, now: s.now}
|
||||
}
|
||||
|
||||
func DefaultResourceSecretRef(organizationID, resourceID string) string {
|
||||
return "rap-secret://org/" + strings.TrimSpace(organizationID) + "/resources/" + strings.TrimSpace(resourceID) + "/primary"
|
||||
}
|
||||
|
||||
func (s *ResourceSecretStore) Upsert(ctx context.Context, cmd UpsertResourceSecretCommand) (*ResourceSecretDescriptor, error) {
|
||||
if s == nil || s.encryptor == nil {
|
||||
return nil, ErrSecretEncryptionKeyMissing
|
||||
}
|
||||
payload, err := normalizeJSONObject(cmd.Payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
metadata, err := normalizeJSONObjectAllowEmpty(cmd.Metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
secretRef := strings.TrimSpace(cmd.SecretRef)
|
||||
if secretRef == "" {
|
||||
secretRef = DefaultResourceSecretRef(cmd.OrganizationID, cmd.ResourceID)
|
||||
}
|
||||
protocol := strings.ToLower(strings.TrimSpace(cmd.Protocol))
|
||||
encrypted, err := s.encryptor.Encrypt(payload, ResourceSecretAAD(cmd.OrganizationID, cmd.ResourceID, secretRef, protocol))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
now := s.now().UTC()
|
||||
const query = `
|
||||
INSERT INTO resource_secrets (
|
||||
organization_id, resource_id, secret_ref, protocol, version, key_id,
|
||||
algorithm, nonce, ciphertext, payload_sha256, metadata, created_by_user_id,
|
||||
created_at, rotated_at
|
||||
) VALUES (
|
||||
$1::uuid, $2::uuid, $3, $4, 1, $5,
|
||||
$6, $7, $8, $9, $10::jsonb, NULLIF($11, '')::uuid,
|
||||
$12, NULL
|
||||
)
|
||||
ON CONFLICT (resource_id) DO UPDATE SET
|
||||
secret_ref = EXCLUDED.secret_ref,
|
||||
protocol = EXCLUDED.protocol,
|
||||
version = resource_secrets.version + 1,
|
||||
key_id = EXCLUDED.key_id,
|
||||
algorithm = EXCLUDED.algorithm,
|
||||
nonce = EXCLUDED.nonce,
|
||||
ciphertext = EXCLUDED.ciphertext,
|
||||
payload_sha256 = EXCLUDED.payload_sha256,
|
||||
metadata = EXCLUDED.metadata,
|
||||
created_by_user_id = EXCLUDED.created_by_user_id,
|
||||
rotated_at = EXCLUDED.created_at
|
||||
RETURNING id::text, organization_id::text, resource_id::text, secret_ref,
|
||||
protocol, version, key_id, algorithm, metadata, created_at, rotated_at
|
||||
`
|
||||
var descriptor ResourceSecretDescriptor
|
||||
if err := s.db.QueryRow(ctx, query,
|
||||
cmd.OrganizationID,
|
||||
cmd.ResourceID,
|
||||
secretRef,
|
||||
protocol,
|
||||
encrypted.KeyID,
|
||||
encrypted.Algorithm,
|
||||
encrypted.Nonce,
|
||||
encrypted.Ciphertext,
|
||||
encrypted.PayloadSHA256,
|
||||
metadata,
|
||||
cmd.ActorUserID,
|
||||
now,
|
||||
).Scan(
|
||||
&descriptor.ID,
|
||||
&descriptor.OrganizationID,
|
||||
&descriptor.ResourceID,
|
||||
&descriptor.SecretRef,
|
||||
&descriptor.Protocol,
|
||||
&descriptor.Version,
|
||||
&descriptor.KeyID,
|
||||
&descriptor.Algorithm,
|
||||
&descriptor.Metadata,
|
||||
&descriptor.CreatedAt,
|
||||
&descriptor.RotatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("upsert resource secret: %w", err)
|
||||
}
|
||||
return &descriptor, nil
|
||||
}
|
||||
|
||||
func (s *ResourceSecretStore) ResolveForSession(ctx context.Context, req ResolveResourceSecretRequest) (*ResolvedResourceSecret, error) {
|
||||
if s == nil || s.encryptor == nil {
|
||||
return nil, ErrSecretEncryptionKeyMissing
|
||||
}
|
||||
if strings.TrimSpace(req.LeaseID) == "" {
|
||||
return nil, ErrSecretLeaseRequired
|
||||
}
|
||||
const query = `
|
||||
SELECT sec.id::text, sec.organization_id::text, sec.resource_id::text, sec.secret_ref,
|
||||
sec.protocol, sec.version, sec.key_id, sec.algorithm, sec.metadata,
|
||||
sec.created_at, sec.rotated_at, sec.nonce, sec.ciphertext,
|
||||
rs.organization_id::text, rs.resource_id::text, COALESCE(rs.worker_id, ''), rs.state
|
||||
FROM resource_secrets sec
|
||||
JOIN remote_sessions rs ON rs.resource_id = sec.resource_id
|
||||
WHERE sec.secret_ref = $1 AND rs.id = $2::uuid
|
||||
`
|
||||
var descriptor ResourceSecretDescriptor
|
||||
var nonce, ciphertext []byte
|
||||
var sessionOrganizationID, sessionResourceID, sessionWorkerID, sessionState string
|
||||
if err := s.db.QueryRow(ctx, query, req.SecretRef, req.SessionID).Scan(
|
||||
&descriptor.ID,
|
||||
&descriptor.OrganizationID,
|
||||
&descriptor.ResourceID,
|
||||
&descriptor.SecretRef,
|
||||
&descriptor.Protocol,
|
||||
&descriptor.Version,
|
||||
&descriptor.KeyID,
|
||||
&descriptor.Algorithm,
|
||||
&descriptor.Metadata,
|
||||
&descriptor.CreatedAt,
|
||||
&descriptor.RotatedAt,
|
||||
&nonce,
|
||||
&ciphertext,
|
||||
&sessionOrganizationID,
|
||||
&sessionResourceID,
|
||||
&sessionWorkerID,
|
||||
&sessionState,
|
||||
); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrResourceSecretNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("resolve resource secret: %w", err)
|
||||
}
|
||||
if descriptor.OrganizationID != req.OrganizationID ||
|
||||
descriptor.ResourceID != req.ResourceID ||
|
||||
sessionOrganizationID != req.OrganizationID ||
|
||||
sessionResourceID != req.ResourceID ||
|
||||
sessionWorkerID != req.WorkerID ||
|
||||
!secretResolvableSessionState(sessionState) {
|
||||
return nil, ErrSecretAccessDenied
|
||||
}
|
||||
plaintext, err := s.encryptor.Decrypt(EncryptedPayload{
|
||||
Algorithm: descriptor.Algorithm,
|
||||
KeyID: descriptor.KeyID,
|
||||
Nonce: nonce,
|
||||
Ciphertext: ciphertext,
|
||||
}, ResourceSecretAAD(descriptor.OrganizationID, descriptor.ResourceID, descriptor.SecretRef, descriptor.Protocol))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ResolvedResourceSecret{
|
||||
Descriptor: descriptor,
|
||||
Payload: json.RawMessage(plaintext),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeJSONObject(raw json.RawMessage) (json.RawMessage, error) {
|
||||
if len(raw) == 0 || !json.Valid(raw) {
|
||||
return nil, ErrSecretPayloadInvalid
|
||||
}
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(raw, &decoded); err != nil {
|
||||
return nil, ErrSecretPayloadInvalid
|
||||
}
|
||||
encoded, err := json.Marshal(decoded)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.RawMessage(encoded), nil
|
||||
}
|
||||
|
||||
func normalizeJSONObjectAllowEmpty(raw json.RawMessage) (json.RawMessage, error) {
|
||||
if len(raw) == 0 {
|
||||
return json.RawMessage(`{}`), nil
|
||||
}
|
||||
return normalizeJSONObject(raw)
|
||||
}
|
||||
|
||||
func secretResolvableSessionState(state string) bool {
|
||||
switch state {
|
||||
case "starting", "active", "reconnecting":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
DROP TABLE IF EXISTS audit_logs;
|
||||
DROP TABLE IF EXISTS secrets;
|
||||
DROP TABLE IF EXISTS sessions;
|
||||
DROP TABLE IF EXISTS resources;
|
||||
DROP TABLE IF EXISTS devices;
|
||||
DROP TABLE IF EXISTS users;
|
||||
@@ -0,0 +1,65 @@
|
||||
CREATE EXTENSION IF NOT EXISTS "pgcrypto";
|
||||
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
email TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT NOT NULL,
|
||||
mfa_enabled BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS devices (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
device_fingerprint TEXT NOT NULL,
|
||||
trusted_at TIMESTAMPTZ,
|
||||
last_seen_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS resources (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
name TEXT NOT NULL,
|
||||
address TEXT NOT NULL,
|
||||
protocol TEXT NOT NULL,
|
||||
secret_ref TEXT,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
resource_id UUID NOT NULL REFERENCES resources(id) ON DELETE RESTRICT,
|
||||
controller_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
worker_id TEXT,
|
||||
state TEXT NOT NULL,
|
||||
detached_until TIMESTAMPTZ,
|
||||
last_heartbeat_at TIMESTAMPTZ,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS secrets (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
scope TEXT NOT NULL,
|
||||
encrypted_payload BYTEA NOT NULL,
|
||||
key_version TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS audit_logs (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
actor_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
action TEXT NOT NULL,
|
||||
target_type TEXT NOT NULL,
|
||||
target_id TEXT NOT NULL,
|
||||
payload JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_devices_user_id ON devices(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_resource_state ON sessions(resource_id, state);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_logs_created_at ON audit_logs(created_at DESC);
|
||||
@@ -0,0 +1,9 @@
|
||||
DROP TABLE IF EXISTS auth_sessions;
|
||||
DROP INDEX IF EXISTS idx_devices_user_fingerprint;
|
||||
|
||||
ALTER TABLE devices
|
||||
DROP COLUMN IF EXISTS updated_at,
|
||||
DROP COLUMN IF EXISTS revoked_reason,
|
||||
DROP COLUMN IF EXISTS revoked_at,
|
||||
DROP COLUMN IF EXISTS trust_status,
|
||||
DROP COLUMN IF EXISTS device_label;
|
||||
@@ -0,0 +1,32 @@
|
||||
ALTER TABLE devices
|
||||
ADD COLUMN IF NOT EXISTS device_label TEXT,
|
||||
ADD COLUMN IF NOT EXISTS trust_status TEXT NOT NULL DEFAULT 'pending',
|
||||
ADD COLUMN IF NOT EXISTS revoked_at TIMESTAMPTZ,
|
||||
ADD COLUMN IF NOT EXISTS revoked_reason TEXT,
|
||||
ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW();
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_devices_user_fingerprint
|
||||
ON devices(user_id, device_fingerprint);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS auth_sessions (
|
||||
id UUID PRIMARY KEY,
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
device_id UUID NOT NULL REFERENCES devices(id) ON DELETE RESTRICT,
|
||||
refresh_token_hash TEXT NOT NULL,
|
||||
refresh_expires_at TIMESTAMPTZ NOT NULL,
|
||||
last_seen_at TIMESTAMPTZ,
|
||||
last_rotated_at TIMESTAMPTZ,
|
||||
revoked_at TIMESTAMPTZ,
|
||||
revoked_reason TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_auth_sessions_user_id
|
||||
ON auth_sessions(user_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_auth_sessions_device_id
|
||||
ON auth_sessions(device_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_auth_sessions_revoked_at
|
||||
ON auth_sessions(revoked_at);
|
||||
@@ -0,0 +1,4 @@
|
||||
DROP TABLE IF EXISTS audit_events;
|
||||
DROP TABLE IF EXISTS session_attachments;
|
||||
DROP TABLE IF EXISTS remote_sessions;
|
||||
DROP TABLE IF EXISTS resource_policies;
|
||||
@@ -0,0 +1,79 @@
|
||||
CREATE TABLE IF NOT EXISTS resource_policies (
|
||||
resource_id UUID PRIMARY KEY REFERENCES resources(id) ON DELETE CASCADE,
|
||||
max_concurrent_sessions INTEGER NOT NULL DEFAULT 1,
|
||||
takeover_policy TEXT NOT NULL DEFAULT 'trusted_device',
|
||||
require_trusted_device BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
detach_grace_period_seconds INTEGER NOT NULL DEFAULT 1800,
|
||||
clipboard_enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
file_transfer_enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS remote_sessions (
|
||||
id UUID PRIMARY KEY,
|
||||
resource_id UUID NOT NULL REFERENCES resources(id) ON DELETE RESTRICT,
|
||||
protocol TEXT NOT NULL,
|
||||
state TEXT NOT NULL,
|
||||
worker_id TEXT,
|
||||
controller_user_id UUID NOT NULL REFERENCES users(id) ON DELETE RESTRICT,
|
||||
detach_deadline_at TIMESTAMPTZ,
|
||||
last_heartbeat_at TIMESTAMPTZ,
|
||||
takeover_version INTEGER NOT NULL DEFAULT 1,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_remote_sessions_resource_id
|
||||
ON remote_sessions(resource_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_remote_sessions_controller_user_id
|
||||
ON remote_sessions(controller_user_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_remote_sessions_state
|
||||
ON remote_sessions(state);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS session_attachments (
|
||||
id UUID PRIMARY KEY,
|
||||
remote_session_id UUID NOT NULL REFERENCES remote_sessions(id) ON DELETE CASCADE,
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE RESTRICT,
|
||||
device_id UUID NOT NULL REFERENCES devices(id) ON DELETE RESTRICT,
|
||||
role TEXT NOT NULL,
|
||||
state TEXT NOT NULL,
|
||||
superseded_by UUID REFERENCES session_attachments(id) ON DELETE SET NULL,
|
||||
takeover_of UUID REFERENCES session_attachments(id) ON DELETE SET NULL,
|
||||
attached_at TIMESTAMPTZ,
|
||||
detached_at TIMESTAMPTZ,
|
||||
last_input_at TIMESTAMPTZ,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_session_attachments_remote_session_id
|
||||
ON session_attachments(remote_session_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_session_attachments_user_id
|
||||
ON session_attachments(user_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_session_attachments_state
|
||||
ON session_attachments(state);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS audit_events (
|
||||
id UUID PRIMARY KEY,
|
||||
actor_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
actor_device_id UUID REFERENCES devices(id) ON DELETE SET NULL,
|
||||
event_type TEXT NOT NULL,
|
||||
target_type TEXT NOT NULL,
|
||||
target_id TEXT NOT NULL,
|
||||
remote_session_id UUID REFERENCES remote_sessions(id) ON DELETE SET NULL,
|
||||
payload JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_events_created_at
|
||||
ON audit_events(created_at DESC);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_events_remote_session_id
|
||||
ON audit_events(remote_session_id);
|
||||
@@ -0,0 +1,5 @@
|
||||
ALTER TABLE resources
|
||||
DROP CONSTRAINT IF EXISTS resources_certificate_verification_mode_check;
|
||||
|
||||
ALTER TABLE resources
|
||||
DROP COLUMN IF EXISTS certificate_verification_mode;
|
||||
@@ -0,0 +1,6 @@
|
||||
ALTER TABLE resources
|
||||
ADD COLUMN certificate_verification_mode TEXT NOT NULL DEFAULT 'strict';
|
||||
|
||||
ALTER TABLE resources
|
||||
ADD CONSTRAINT resources_certificate_verification_mode_check
|
||||
CHECK (certificate_verification_mode IN ('strict', 'ignore'));
|
||||
@@ -0,0 +1,26 @@
|
||||
DROP TABLE IF EXISTS node_agent_update_runs;
|
||||
DROP TABLE IF EXISTS node_partition_states;
|
||||
DROP TABLE IF EXISTS node_update_policies;
|
||||
DROP TABLE IF EXISTS node_services;
|
||||
DROP TABLE IF EXISTS node_capabilities;
|
||||
DROP TABLE IF EXISTS nodes;
|
||||
DROP TABLE IF EXISTS identity_mappings;
|
||||
DROP TABLE IF EXISTS identity_sources;
|
||||
|
||||
DROP INDEX IF EXISTS idx_remote_sessions_organization_id;
|
||||
ALTER TABLE remote_sessions
|
||||
DROP COLUMN IF EXISTS organization_id;
|
||||
|
||||
DROP INDEX IF EXISTS idx_resources_organization_id;
|
||||
ALTER TABLE resources
|
||||
DROP COLUMN IF EXISTS organization_id;
|
||||
|
||||
DROP TABLE IF EXISTS organization_memberships;
|
||||
DROP TABLE IF EXISTS organization_roles;
|
||||
DROP TABLE IF EXISTS organizations;
|
||||
|
||||
ALTER TABLE users
|
||||
DROP CONSTRAINT IF EXISTS users_platform_role_check;
|
||||
|
||||
ALTER TABLE users
|
||||
DROP COLUMN IF EXISTS platform_role;
|
||||
@@ -0,0 +1,219 @@
|
||||
ALTER TABLE users
|
||||
ADD COLUMN IF NOT EXISTS platform_role TEXT NOT NULL DEFAULT 'user';
|
||||
|
||||
ALTER TABLE users
|
||||
ADD CONSTRAINT users_platform_role_check
|
||||
CHECK (platform_role IN ('user', 'platform_admin', 'platform_recovery_admin'));
|
||||
|
||||
CREATE TABLE IF NOT EXISTS organizations (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
slug TEXT NOT NULL UNIQUE,
|
||||
name TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'active',
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT organizations_status_check
|
||||
CHECK (status IN ('active', 'suspended', 'archived'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS organization_roles (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
scope TEXT NOT NULL DEFAULT 'organization',
|
||||
description TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT organization_roles_scope_check
|
||||
CHECK (scope IN ('organization'))
|
||||
);
|
||||
|
||||
INSERT INTO organization_roles (id, name, scope, description)
|
||||
VALUES
|
||||
('org_owner', 'Organization Owner', 'organization', 'Full organization control including membership administration.'),
|
||||
('org_admin', 'Organization Admin', 'organization', 'Administrative access within one organization.'),
|
||||
('org_operator', 'Organization Operator', 'organization', 'Operational access within one organization.'),
|
||||
('org_member', 'Organization Member', 'organization', 'Standard organization member access.'),
|
||||
('org_viewer', 'Organization Viewer', 'organization', 'Read-only organization visibility.')
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
name = EXCLUDED.name,
|
||||
description = EXCLUDED.description;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS organization_memberships (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
organization_id UUID NOT NULL REFERENCES organizations(id) ON DELETE CASCADE,
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
role_id TEXT NOT NULL REFERENCES organization_roles(id) ON DELETE RESTRICT,
|
||||
status TEXT NOT NULL DEFAULT 'active',
|
||||
invited_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
UNIQUE (organization_id, user_id),
|
||||
CONSTRAINT organization_memberships_status_check
|
||||
CHECK (status IN ('active', 'suspended', 'revoked'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_organization_memberships_user_id
|
||||
ON organization_memberships(user_id);
|
||||
|
||||
INSERT INTO organizations (slug, name, status, metadata)
|
||||
VALUES ('default', 'Default Organization', 'active', '{"bootstrap":true}'::jsonb)
|
||||
ON CONFLICT (slug) DO NOTHING;
|
||||
|
||||
ALTER TABLE resources
|
||||
ADD COLUMN IF NOT EXISTS organization_id UUID REFERENCES organizations(id) ON DELETE RESTRICT;
|
||||
|
||||
UPDATE resources
|
||||
SET organization_id = organizations.id
|
||||
FROM organizations
|
||||
WHERE organizations.slug = 'default'
|
||||
AND resources.organization_id IS NULL;
|
||||
|
||||
ALTER TABLE resources
|
||||
ALTER COLUMN organization_id SET NOT NULL;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_resources_organization_id
|
||||
ON resources(organization_id);
|
||||
|
||||
ALTER TABLE remote_sessions
|
||||
ADD COLUMN IF NOT EXISTS organization_id UUID REFERENCES organizations(id) ON DELETE RESTRICT;
|
||||
|
||||
UPDATE remote_sessions
|
||||
SET organization_id = resources.organization_id
|
||||
FROM resources
|
||||
WHERE resources.id = remote_sessions.resource_id
|
||||
AND remote_sessions.organization_id IS NULL;
|
||||
|
||||
ALTER TABLE remote_sessions
|
||||
ALTER COLUMN organization_id SET NOT NULL;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_remote_sessions_organization_id
|
||||
ON remote_sessions(organization_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS identity_sources (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
organization_id UUID NOT NULL REFERENCES organizations(id) ON DELETE CASCADE,
|
||||
kind TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'active',
|
||||
config JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT identity_sources_kind_check
|
||||
CHECK (kind IN ('local', 'ldap_ad', 'oidc')),
|
||||
CONSTRAINT identity_sources_status_check
|
||||
CHECK (status IN ('active', 'disabled'))
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_identity_sources_org_name
|
||||
ON identity_sources(organization_id, name);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS identity_mappings (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
identity_source_id UUID NOT NULL REFERENCES identity_sources(id) ON DELETE CASCADE,
|
||||
mapping_type TEXT NOT NULL,
|
||||
external_selector JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
internal_target JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT identity_mappings_type_check
|
||||
CHECK (mapping_type IN ('group_binding', 'claim_binding', 'user_binding'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_identity_mappings_source_id
|
||||
ON identity_mappings(identity_source_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS nodes (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
owner_organization_id UUID REFERENCES organizations(id) ON DELETE SET NULL,
|
||||
node_key TEXT NOT NULL UNIQUE,
|
||||
name TEXT NOT NULL,
|
||||
ownership_type TEXT NOT NULL,
|
||||
registration_status TEXT NOT NULL DEFAULT 'pending',
|
||||
health_status TEXT NOT NULL DEFAULT 'unknown',
|
||||
version_state TEXT NOT NULL DEFAULT 'unknown',
|
||||
partition_state TEXT NOT NULL DEFAULT 'healthy',
|
||||
desired_version TEXT,
|
||||
reported_version TEXT,
|
||||
last_seen_at TIMESTAMPTZ,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT nodes_ownership_type_check
|
||||
CHECK (ownership_type IN ('platform_managed', 'customer_managed')),
|
||||
CONSTRAINT nodes_registration_status_check
|
||||
CHECK (registration_status IN ('pending', 'active', 'disabled', 'revoked')),
|
||||
CONSTRAINT nodes_health_status_check
|
||||
CHECK (health_status IN ('unknown', 'healthy', 'warning', 'critical')),
|
||||
CONSTRAINT nodes_version_state_check
|
||||
CHECK (version_state IN ('unknown', 'current', 'outdated', 'updating', 'rollback', 'failed')),
|
||||
CONSTRAINT nodes_partition_state_check
|
||||
CHECK (partition_state IN ('healthy', 'degraded', 'recovery', 'isolated'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS node_capabilities (
|
||||
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
capability TEXT NOT NULL,
|
||||
value JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
PRIMARY KEY (node_id, capability)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS node_services (
|
||||
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
service_type TEXT NOT NULL,
|
||||
enabled BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
desired_state TEXT NOT NULL DEFAULT 'disabled',
|
||||
reported_state TEXT NOT NULL DEFAULT 'unknown',
|
||||
last_reported_at TIMESTAMPTZ,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
PRIMARY KEY (node_id, service_type),
|
||||
CONSTRAINT node_services_desired_state_check
|
||||
CHECK (desired_state IN ('enabled', 'disabled', 'drain')),
|
||||
CONSTRAINT node_services_reported_state_check
|
||||
CHECK (reported_state IN ('unknown', 'starting', 'running', 'degraded', 'stopped', 'failed'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS node_update_policies (
|
||||
node_id UUID PRIMARY KEY REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
mode TEXT NOT NULL DEFAULT 'manual',
|
||||
channel TEXT NOT NULL DEFAULT 'stable',
|
||||
maintenance_window JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
canary BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
automatic_rollout BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT node_update_policies_mode_check
|
||||
CHECK (mode IN ('manual', 'automatic', 'staged', 'canary'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS node_partition_states (
|
||||
node_id UUID PRIMARY KEY REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
cluster_state TEXT NOT NULL DEFAULT 'healthy',
|
||||
recovery_mode TEXT NOT NULL DEFAULT 'normal',
|
||||
notes TEXT,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT node_partition_states_cluster_state_check
|
||||
CHECK (cluster_state IN ('healthy', 'degraded', 'recovery', 'isolated')),
|
||||
CONSTRAINT node_partition_states_recovery_mode_check
|
||||
CHECK (recovery_mode IN ('normal', 'manual_recovery', 'emergency'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS node_agent_update_runs (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
action TEXT NOT NULL,
|
||||
target_version TEXT,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
requested_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
acknowledged_at TIMESTAMPTZ,
|
||||
completed_at TIMESTAMPTZ,
|
||||
payload JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
CONSTRAINT node_agent_update_runs_action_check
|
||||
CHECK (action IN ('update', 'rollback')),
|
||||
CONSTRAINT node_agent_update_runs_status_check
|
||||
CHECK (status IN ('pending', 'acknowledged', 'succeeded', 'failed'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_node_agent_update_runs_node_id
|
||||
ON node_agent_update_runs(node_id, requested_at DESC);
|
||||
@@ -0,0 +1,7 @@
|
||||
DELETE FROM organization_memberships
|
||||
WHERE organization_id IN (
|
||||
SELECT id
|
||||
FROM organizations
|
||||
WHERE slug = 'default'
|
||||
)
|
||||
AND invited_by_user_id IS NULL;
|
||||
@@ -0,0 +1,29 @@
|
||||
INSERT INTO organization_memberships (
|
||||
id,
|
||||
organization_id,
|
||||
user_id,
|
||||
role_id,
|
||||
status,
|
||||
invited_by_user_id,
|
||||
created_at,
|
||||
updated_at
|
||||
)
|
||||
SELECT
|
||||
gen_random_uuid(),
|
||||
org.id,
|
||||
u.id,
|
||||
CASE
|
||||
WHEN u.platform_role IN ('platform_admin', 'platform_recovery_admin') THEN 'org_owner'
|
||||
ELSE 'org_member'
|
||||
END,
|
||||
'active',
|
||||
NULL,
|
||||
NOW(),
|
||||
NOW()
|
||||
FROM users u
|
||||
CROSS JOIN organizations org
|
||||
LEFT JOIN organization_memberships om
|
||||
ON om.organization_id = org.id
|
||||
AND om.user_id = u.id
|
||||
WHERE org.slug = 'default'
|
||||
AND om.id IS NULL;
|
||||
@@ -0,0 +1,5 @@
|
||||
ALTER TABLE resource_policies
|
||||
DROP CONSTRAINT IF EXISTS resource_policies_clipboard_mode_check;
|
||||
|
||||
ALTER TABLE resource_policies
|
||||
DROP COLUMN IF EXISTS clipboard_mode;
|
||||
@@ -0,0 +1,16 @@
|
||||
ALTER TABLE resource_policies
|
||||
ADD COLUMN IF NOT EXISTS clipboard_mode TEXT NOT NULL DEFAULT 'disabled';
|
||||
|
||||
UPDATE resource_policies
|
||||
SET clipboard_mode = CASE
|
||||
WHEN clipboard_enabled THEN 'bidirectional'
|
||||
ELSE 'disabled'
|
||||
END
|
||||
WHERE clipboard_mode = 'disabled';
|
||||
|
||||
ALTER TABLE resource_policies
|
||||
DROP CONSTRAINT IF EXISTS resource_policies_clipboard_mode_check;
|
||||
|
||||
ALTER TABLE resource_policies
|
||||
ADD CONSTRAINT resource_policies_clipboard_mode_check
|
||||
CHECK (clipboard_mode IN ('disabled', 'client_to_server', 'server_to_client', 'bidirectional'));
|
||||
@@ -0,0 +1,5 @@
|
||||
ALTER TABLE resource_policies
|
||||
DROP CONSTRAINT IF EXISTS resource_policies_file_transfer_mode_check;
|
||||
|
||||
ALTER TABLE resource_policies
|
||||
DROP COLUMN IF EXISTS file_transfer_mode;
|
||||
@@ -0,0 +1,16 @@
|
||||
ALTER TABLE resource_policies
|
||||
ADD COLUMN IF NOT EXISTS file_transfer_mode TEXT NOT NULL DEFAULT 'disabled';
|
||||
|
||||
UPDATE resource_policies
|
||||
SET file_transfer_mode = 'disabled',
|
||||
file_transfer_enabled = FALSE
|
||||
WHERE file_transfer_mode IS NULL
|
||||
OR file_transfer_mode = ''
|
||||
OR file_transfer_mode = 'disabled';
|
||||
|
||||
ALTER TABLE resource_policies
|
||||
DROP CONSTRAINT IF EXISTS resource_policies_file_transfer_mode_check;
|
||||
|
||||
ALTER TABLE resource_policies
|
||||
ADD CONSTRAINT resource_policies_file_transfer_mode_check
|
||||
CHECK (file_transfer_mode IN ('disabled', 'client_to_server', 'server_to_client', 'bidirectional'));
|
||||
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS resource_secrets;
|
||||
@@ -0,0 +1,27 @@
|
||||
CREATE TABLE IF NOT EXISTS resource_secrets (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
organization_id UUID NOT NULL REFERENCES organizations(id) ON DELETE CASCADE,
|
||||
resource_id UUID NOT NULL REFERENCES resources(id) ON DELETE CASCADE,
|
||||
secret_ref TEXT NOT NULL UNIQUE,
|
||||
protocol TEXT NOT NULL,
|
||||
version INTEGER NOT NULL DEFAULT 1,
|
||||
key_id TEXT NOT NULL,
|
||||
algorithm TEXT NOT NULL DEFAULT 'AES-256-GCM',
|
||||
nonce BYTEA NOT NULL,
|
||||
ciphertext BYTEA NOT NULL,
|
||||
payload_sha256 TEXT NOT NULL,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
created_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
rotated_at TIMESTAMPTZ,
|
||||
UNIQUE (resource_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_resource_secrets_organization_id
|
||||
ON resource_secrets(organization_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_resource_secrets_resource_id
|
||||
ON resource_secrets(resource_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_resource_secrets_secret_ref
|
||||
ON resource_secrets(secret_ref);
|
||||
@@ -0,0 +1,9 @@
|
||||
DROP TABLE IF EXISTS cluster_audit_events;
|
||||
DROP TABLE IF EXISTS node_latest_heartbeats;
|
||||
DROP TABLE IF EXISTS node_heartbeats;
|
||||
DROP TABLE IF EXISTS node_role_assignments;
|
||||
DROP TABLE IF EXISTS node_join_requests;
|
||||
DROP TABLE IF EXISTS node_join_tokens;
|
||||
DROP TABLE IF EXISTS node_identities;
|
||||
DROP TABLE IF EXISTS cluster_memberships;
|
||||
DROP TABLE IF EXISTS clusters;
|
||||
@@ -0,0 +1,186 @@
|
||||
CREATE TABLE IF NOT EXISTS clusters (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
slug TEXT NOT NULL UNIQUE,
|
||||
name TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'active',
|
||||
region TEXT,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT clusters_status_check
|
||||
CHECK (status IN ('active', 'disabled', 'archived', 'degraded'))
|
||||
);
|
||||
|
||||
INSERT INTO clusters (slug, name, status, region, metadata)
|
||||
VALUES ('default', 'Default Cluster', 'active', NULL, '{"bootstrap":true}'::jsonb)
|
||||
ON CONFLICT (slug) DO NOTHING;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS cluster_memberships (
|
||||
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
membership_status TEXT NOT NULL DEFAULT 'active',
|
||||
joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
last_seen_at TIMESTAMPTZ,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
PRIMARY KEY (cluster_id, node_id),
|
||||
CONSTRAINT cluster_memberships_status_check
|
||||
CHECK (membership_status IN ('active', 'draining', 'disabled', 'revoked'))
|
||||
);
|
||||
|
||||
INSERT INTO cluster_memberships (cluster_id, node_id, membership_status, joined_at, last_seen_at, metadata)
|
||||
SELECT c.id, n.id, 'active', COALESCE(n.created_at, NOW()), n.last_seen_at, '{"backfilled":true}'::jsonb
|
||||
FROM clusters c
|
||||
CROSS JOIN nodes n
|
||||
WHERE c.slug = 'default'
|
||||
ON CONFLICT (cluster_id, node_id) DO NOTHING;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS node_identities (
|
||||
node_id UUID PRIMARY KEY REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
public_key TEXT NOT NULL,
|
||||
certificate_serial TEXT,
|
||||
certificate_not_before TIMESTAMPTZ,
|
||||
certificate_not_after TIMESTAMPTZ,
|
||||
identity_status TEXT NOT NULL DEFAULT 'pending',
|
||||
rotated_at TIMESTAMPTZ,
|
||||
revoked_at TIMESTAMPTZ,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT node_identities_status_check
|
||||
CHECK (identity_status IN ('pending', 'active', 'rotating', 'revoked'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS node_join_tokens (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
token_hash TEXT NOT NULL UNIQUE,
|
||||
scope JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
max_uses INTEGER NOT NULL DEFAULT 1,
|
||||
used_count INTEGER NOT NULL DEFAULT 0,
|
||||
status TEXT NOT NULL DEFAULT 'active',
|
||||
created_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
revoked_at TIMESTAMPTZ,
|
||||
CONSTRAINT node_join_tokens_status_check
|
||||
CHECK (status IN ('active', 'revoked', 'expired')),
|
||||
CONSTRAINT node_join_tokens_max_uses_check
|
||||
CHECK (max_uses > 0),
|
||||
CONSTRAINT node_join_tokens_used_count_check
|
||||
CHECK (used_count >= 0)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_node_join_tokens_cluster_status
|
||||
ON node_join_tokens(cluster_id, status, expires_at);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS node_join_requests (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
join_token_id UUID REFERENCES node_join_tokens(id) ON DELETE SET NULL,
|
||||
node_name TEXT NOT NULL,
|
||||
node_fingerprint TEXT NOT NULL,
|
||||
public_key TEXT NOT NULL,
|
||||
reported_capabilities JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
reported_facts JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
requested_roles JSONB NOT NULL DEFAULT '[]'::JSONB,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
reviewed_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
reviewed_at TIMESTAMPTZ,
|
||||
approved_node_id UUID REFERENCES nodes(id) ON DELETE SET NULL,
|
||||
rejection_reason TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT node_join_requests_status_check
|
||||
CHECK (status IN ('pending', 'approved', 'rejected', 'cancelled'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_node_join_requests_cluster_status
|
||||
ON node_join_requests(cluster_id, status, created_at DESC);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_node_join_requests_pending_fingerprint
|
||||
ON node_join_requests(cluster_id, node_fingerprint)
|
||||
WHERE status = 'pending';
|
||||
|
||||
CREATE TABLE IF NOT EXISTS node_role_assignments (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
organization_id UUID REFERENCES organizations(id) ON DELETE CASCADE,
|
||||
role TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'active',
|
||||
policy JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
assigned_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
assigned_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
revoked_at TIMESTAMPTZ,
|
||||
CONSTRAINT node_role_assignments_status_check
|
||||
CHECK (status IN ('active', 'disabled', 'revoked')),
|
||||
CONSTRAINT node_role_assignments_role_check
|
||||
CHECK (role IN (
|
||||
'entry-node',
|
||||
'relay-node',
|
||||
'core-mesh',
|
||||
'rdp-worker',
|
||||
'vnc-worker',
|
||||
'vpn-exit',
|
||||
'vpn-connector',
|
||||
'file-storage-cache',
|
||||
'update-cache',
|
||||
'video-relay'
|
||||
))
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_node_role_assignments_unique_active
|
||||
ON node_role_assignments(cluster_id, node_id, role, COALESCE(organization_id, '00000000-0000-0000-0000-000000000000'::uuid))
|
||||
WHERE status = 'active';
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_node_role_assignments_cluster
|
||||
ON node_role_assignments(cluster_id, role, status);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS node_heartbeats (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
health_status TEXT NOT NULL DEFAULT 'unknown',
|
||||
reported_version TEXT,
|
||||
capabilities JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
service_states JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
observed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT node_heartbeats_health_status_check
|
||||
CHECK (health_status IN ('unknown', 'healthy', 'warning', 'critical'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_node_heartbeats_cluster_node_observed
|
||||
ON node_heartbeats(cluster_id, node_id, observed_at DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS node_latest_heartbeats (
|
||||
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
heartbeat_id UUID REFERENCES node_heartbeats(id) ON DELETE SET NULL,
|
||||
health_status TEXT NOT NULL DEFAULT 'unknown',
|
||||
reported_version TEXT,
|
||||
capabilities JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
service_states JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
observed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
PRIMARY KEY (cluster_id, node_id),
|
||||
CONSTRAINT node_latest_heartbeats_health_status_check
|
||||
CHECK (health_status IN ('unknown', 'healthy', 'warning', 'critical'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS cluster_audit_events (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
cluster_id UUID REFERENCES clusters(id) ON DELETE SET NULL,
|
||||
actor_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
event_type TEXT NOT NULL,
|
||||
target_type TEXT NOT NULL,
|
||||
target_id TEXT,
|
||||
payload JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_cluster_audit_events_cluster_created
|
||||
ON cluster_audit_events(cluster_id, created_at DESC);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_cluster_audit_events_type_created
|
||||
ON cluster_audit_events(event_type, created_at DESC);
|
||||
@@ -0,0 +1,3 @@
|
||||
DROP TABLE IF EXISTS node_workload_latest_statuses;
|
||||
DROP TABLE IF EXISTS node_workload_status_reports;
|
||||
DROP TABLE IF EXISTS node_workload_desired_states;
|
||||
@@ -0,0 +1,57 @@
|
||||
CREATE TABLE IF NOT EXISTS node_workload_desired_states (
|
||||
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
service_type TEXT NOT NULL,
|
||||
desired_state TEXT NOT NULL DEFAULT 'disabled',
|
||||
version TEXT,
|
||||
runtime_mode TEXT NOT NULL DEFAULT 'container',
|
||||
artifact_ref TEXT,
|
||||
config JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
environment JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
updated_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
PRIMARY KEY (cluster_id, node_id, service_type),
|
||||
CONSTRAINT node_workload_desired_states_desired_state_check
|
||||
CHECK (desired_state IN ('enabled', 'disabled', 'drain')),
|
||||
CONSTRAINT node_workload_desired_states_runtime_mode_check
|
||||
CHECK (runtime_mode IN ('native', 'container'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_node_workload_desired_states_cluster
|
||||
ON node_workload_desired_states(cluster_id, service_type, desired_state);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS node_workload_status_reports (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
service_type TEXT NOT NULL,
|
||||
reported_state TEXT NOT NULL DEFAULT 'unknown',
|
||||
runtime_mode TEXT NOT NULL DEFAULT 'container',
|
||||
version TEXT,
|
||||
status_payload JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
observed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT node_workload_status_reports_reported_state_check
|
||||
CHECK (reported_state IN ('unknown', 'starting', 'running', 'degraded', 'stopped', 'failed', 'not_implemented')),
|
||||
CONSTRAINT node_workload_status_reports_runtime_mode_check
|
||||
CHECK (runtime_mode IN ('native', 'container'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_node_workload_status_reports_node_observed
|
||||
ON node_workload_status_reports(cluster_id, node_id, observed_at DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS node_workload_latest_statuses (
|
||||
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
service_type TEXT NOT NULL,
|
||||
status_report_id UUID REFERENCES node_workload_status_reports(id) ON DELETE SET NULL,
|
||||
reported_state TEXT NOT NULL DEFAULT 'unknown',
|
||||
runtime_mode TEXT NOT NULL DEFAULT 'container',
|
||||
version TEXT,
|
||||
status_payload JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
observed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
PRIMARY KEY (cluster_id, node_id, service_type),
|
||||
CONSTRAINT node_workload_latest_statuses_reported_state_check
|
||||
CHECK (reported_state IN ('unknown', 'starting', 'running', 'degraded', 'stopped', 'failed', 'not_implemented')),
|
||||
CONSTRAINT node_workload_latest_statuses_runtime_mode_check
|
||||
CHECK (runtime_mode IN ('native', 'container'))
|
||||
);
|
||||
@@ -0,0 +1,4 @@
|
||||
DROP TABLE IF EXISTS mesh_qos_policies;
|
||||
DROP TABLE IF EXISTS mesh_route_intents;
|
||||
DROP TABLE IF EXISTS mesh_latest_links;
|
||||
DROP TABLE IF EXISTS mesh_link_observations;
|
||||
@@ -0,0 +1,94 @@
|
||||
CREATE TABLE IF NOT EXISTS mesh_link_observations (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
source_node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
target_node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
link_status TEXT NOT NULL DEFAULT 'unknown',
|
||||
latency_ms INTEGER,
|
||||
quality_score INTEGER,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
observed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT mesh_link_observations_status_check
|
||||
CHECK (link_status IN ('unknown', 'reachable', 'degraded', 'unreachable')),
|
||||
CONSTRAINT mesh_link_observations_quality_check
|
||||
CHECK (quality_score IS NULL OR (quality_score >= 0 AND quality_score <= 100))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_mesh_link_observations_cluster_observed
|
||||
ON mesh_link_observations(cluster_id, observed_at DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS mesh_latest_links (
|
||||
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
source_node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
target_node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
observation_id UUID REFERENCES mesh_link_observations(id) ON DELETE SET NULL,
|
||||
link_status TEXT NOT NULL DEFAULT 'unknown',
|
||||
latency_ms INTEGER,
|
||||
quality_score INTEGER,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
observed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
PRIMARY KEY (cluster_id, source_node_id, target_node_id),
|
||||
CONSTRAINT mesh_latest_links_status_check
|
||||
CHECK (link_status IN ('unknown', 'reachable', 'degraded', 'unreachable')),
|
||||
CONSTRAINT mesh_latest_links_quality_check
|
||||
CHECK (quality_score IS NULL OR (quality_score >= 0 AND quality_score <= 100))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS mesh_route_intents (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
source_selector JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
destination_selector JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
service_class TEXT NOT NULL,
|
||||
priority INTEGER NOT NULL DEFAULT 100,
|
||||
status TEXT NOT NULL DEFAULT 'active',
|
||||
policy JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
created_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT mesh_route_intents_service_class_check
|
||||
CHECK (service_class IN ('input', 'control', 'render', 'clipboard', 'file_transfer', 'vpn_packets', 'telemetry')),
|
||||
CONSTRAINT mesh_route_intents_status_check
|
||||
CHECK (status IN ('active', 'disabled'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_mesh_route_intents_cluster_class
|
||||
ON mesh_route_intents(cluster_id, service_class, status);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS mesh_qos_policies (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
service_class TEXT NOT NULL,
|
||||
priority INTEGER NOT NULL,
|
||||
reliability_mode TEXT NOT NULL,
|
||||
drop_policy TEXT NOT NULL,
|
||||
bandwidth_policy JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
UNIQUE (cluster_id, service_class),
|
||||
CONSTRAINT mesh_qos_policies_service_class_check
|
||||
CHECK (service_class IN ('input', 'control', 'render', 'clipboard', 'file_transfer', 'vpn_packets', 'telemetry')),
|
||||
CONSTRAINT mesh_qos_policies_reliability_check
|
||||
CHECK (reliability_mode IN ('reliable', 'droppable', 'adaptive')),
|
||||
CONSTRAINT mesh_qos_policies_drop_policy_check
|
||||
CHECK (drop_policy IN ('never', 'latest_only', 'adaptive'))
|
||||
);
|
||||
|
||||
INSERT INTO mesh_qos_policies (
|
||||
cluster_id, service_class, priority, reliability_mode, drop_policy, bandwidth_policy, metadata
|
||||
)
|
||||
SELECT c.id, defaults.service_class, defaults.priority, defaults.reliability_mode,
|
||||
defaults.drop_policy, '{}'::jsonb, '{"default":true}'::jsonb
|
||||
FROM clusters c
|
||||
CROSS JOIN (
|
||||
VALUES
|
||||
('input', 10, 'reliable', 'never'),
|
||||
('control', 20, 'reliable', 'never'),
|
||||
('clipboard', 40, 'reliable', 'never'),
|
||||
('render', 60, 'droppable', 'latest_only'),
|
||||
('file_transfer', 80, 'reliable', 'never'),
|
||||
('telemetry', 120, 'adaptive', 'adaptive'),
|
||||
('vpn_packets', 160, 'adaptive', 'adaptive')
|
||||
) AS defaults(service_class, priority, reliability_mode, drop_policy)
|
||||
ON CONFLICT (cluster_id, service_class) DO NOTHING;
|
||||
@@ -0,0 +1,2 @@
|
||||
DROP VIEW IF EXISTS cluster_admin_summaries;
|
||||
DROP TABLE IF EXISTS cluster_authority_states;
|
||||
@@ -0,0 +1,40 @@
|
||||
CREATE TABLE IF NOT EXISTS cluster_authority_states (
|
||||
cluster_id UUID PRIMARY KEY REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
authority_state TEXT NOT NULL DEFAULT 'authoritative',
|
||||
mutation_mode TEXT NOT NULL DEFAULT 'normal',
|
||||
term BIGINT NOT NULL DEFAULT 1,
|
||||
notes TEXT,
|
||||
updated_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT cluster_authority_states_authority_check
|
||||
CHECK (authority_state IN ('authoritative', 'minority', 'isolated', 'recovery')),
|
||||
CONSTRAINT cluster_authority_states_mutation_check
|
||||
CHECK (mutation_mode IN ('normal', 'read_only', 'recovery_override'))
|
||||
);
|
||||
|
||||
INSERT INTO cluster_authority_states (cluster_id, authority_state, mutation_mode, term, notes)
|
||||
SELECT id, 'authoritative', 'normal', 1, 'default authority state'
|
||||
FROM clusters
|
||||
ON CONFLICT (cluster_id) DO NOTHING;
|
||||
|
||||
CREATE OR REPLACE VIEW cluster_admin_summaries AS
|
||||
SELECT
|
||||
c.id AS cluster_id,
|
||||
c.slug,
|
||||
c.name,
|
||||
c.status,
|
||||
c.region,
|
||||
COALESCE(cas.authority_state, 'authoritative') AS authority_state,
|
||||
COALESCE(cas.mutation_mode, 'normal') AS mutation_mode,
|
||||
COUNT(DISTINCT cm.node_id) AS node_count,
|
||||
COUNT(DISTINCT CASE WHEN n.health_status = 'healthy' THEN n.id END) AS healthy_node_count,
|
||||
COUNT(DISTINCT CASE WHEN njr.status = 'pending' THEN njr.id END) AS pending_join_count,
|
||||
COUNT(DISTINCT nra.id) AS active_role_assignment_count,
|
||||
MAX(n.last_seen_at) AS last_node_seen_at
|
||||
FROM clusters c
|
||||
LEFT JOIN cluster_authority_states cas ON cas.cluster_id = c.id
|
||||
LEFT JOIN cluster_memberships cm ON cm.cluster_id = c.id
|
||||
LEFT JOIN nodes n ON n.id = cm.node_id
|
||||
LEFT JOIN node_join_requests njr ON njr.cluster_id = c.id
|
||||
LEFT JOIN node_role_assignments nra ON nra.cluster_id = c.id AND nra.status = 'active'
|
||||
GROUP BY c.id, c.slug, c.name, c.status, c.region, cas.authority_state, cas.mutation_mode;
|
||||
@@ -0,0 +1,4 @@
|
||||
DROP TABLE IF EXISTS vpn_connection_leases;
|
||||
DROP TABLE IF EXISTS vpn_connection_route_policies;
|
||||
DROP TABLE IF EXISTS vpn_connection_allowed_nodes;
|
||||
DROP TABLE IF EXISTS vpn_connections;
|
||||
@@ -0,0 +1,125 @@
|
||||
CREATE TABLE IF NOT EXISTS vpn_connections (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
organization_id UUID NOT NULL REFERENCES organizations(id) ON DELETE CASCADE,
|
||||
name TEXT NOT NULL,
|
||||
target_endpoint JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
protocol_family TEXT NOT NULL DEFAULT 'generic',
|
||||
credential_ref TEXT,
|
||||
mode TEXT NOT NULL DEFAULT 'single_active',
|
||||
desired_state TEXT NOT NULL DEFAULT 'disabled',
|
||||
allowed_node_policy JSONB NOT NULL DEFAULT '{"mode":"explicit","node_ids":[]}'::JSONB,
|
||||
routing_usage JSONB NOT NULL DEFAULT '[]'::JSONB,
|
||||
route_policy JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
qos_policy JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
placement_policy JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
status TEXT NOT NULL DEFAULT 'disabled',
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
created_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
updated_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT vpn_connections_mode_check
|
||||
CHECK (mode IN ('single_active')),
|
||||
CONSTRAINT vpn_connections_desired_state_check
|
||||
CHECK (desired_state IN ('enabled', 'disabled')),
|
||||
CONSTRAINT vpn_connections_status_check
|
||||
CHECK (status IN ('disabled', 'enabled', 'connecting', 'active', 'degraded', 'failed')),
|
||||
CONSTRAINT vpn_connections_protocol_family_check
|
||||
CHECK (protocol_family IN ('generic', 'wireguard', 'ipsec', 'openvpn'))
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_vpn_connections_cluster_org_name
|
||||
ON vpn_connections(cluster_id, organization_id, lower(name));
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_vpn_connections_cluster_org_state
|
||||
ON vpn_connections(cluster_id, organization_id, desired_state, status);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS vpn_connection_allowed_nodes (
|
||||
vpn_connection_id UUID NOT NULL REFERENCES vpn_connections(id) ON DELETE CASCADE,
|
||||
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
role_preference TEXT NOT NULL DEFAULT 'candidate',
|
||||
status TEXT NOT NULL DEFAULT 'active',
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
created_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
PRIMARY KEY (vpn_connection_id, node_id),
|
||||
CONSTRAINT vpn_connection_allowed_nodes_membership_fk
|
||||
FOREIGN KEY (cluster_id, node_id)
|
||||
REFERENCES cluster_memberships(cluster_id, node_id)
|
||||
ON DELETE CASCADE,
|
||||
CONSTRAINT vpn_connection_allowed_nodes_preference_check
|
||||
CHECK (role_preference IN ('candidate', 'standby', 'preferred')),
|
||||
CONSTRAINT vpn_connection_allowed_nodes_status_check
|
||||
CHECK (status IN ('active', 'disabled'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_vpn_connection_allowed_nodes_cluster_node
|
||||
ON vpn_connection_allowed_nodes(cluster_id, node_id, status);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS vpn_connection_route_policies (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
vpn_connection_id UUID NOT NULL REFERENCES vpn_connections(id) ON DELETE CASCADE,
|
||||
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
organization_id UUID NOT NULL REFERENCES organizations(id) ON DELETE CASCADE,
|
||||
route_type TEXT NOT NULL,
|
||||
destination TEXT NOT NULL,
|
||||
action TEXT NOT NULL DEFAULT 'allow',
|
||||
service_type TEXT,
|
||||
priority INTEGER NOT NULL DEFAULT 100,
|
||||
policy JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
status TEXT NOT NULL DEFAULT 'active',
|
||||
created_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT vpn_connection_route_policies_route_type_check
|
||||
CHECK (route_type IN ('cidr', 'dns_suffix', 'service', 'resource')),
|
||||
CONSTRAINT vpn_connection_route_policies_action_check
|
||||
CHECK (action IN ('allow', 'deny')),
|
||||
CONSTRAINT vpn_connection_route_policies_status_check
|
||||
CHECK (status IN ('active', 'disabled'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_vpn_connection_route_policies_connection
|
||||
ON vpn_connection_route_policies(vpn_connection_id, status, priority);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_vpn_connection_route_policies_cluster_org
|
||||
ON vpn_connection_route_policies(cluster_id, organization_id, route_type, status);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS vpn_connection_leases (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
vpn_connection_id UUID NOT NULL REFERENCES vpn_connections(id) ON DELETE CASCADE,
|
||||
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
owner_node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
lease_generation BIGINT NOT NULL,
|
||||
fencing_token TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'active',
|
||||
acquired_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
renewed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
released_at TIMESTAMPTZ,
|
||||
fenced_at TIMESTAMPTZ,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
CONSTRAINT vpn_connection_leases_membership_fk
|
||||
FOREIGN KEY (cluster_id, owner_node_id)
|
||||
REFERENCES cluster_memberships(cluster_id, node_id)
|
||||
ON DELETE CASCADE,
|
||||
CONSTRAINT vpn_connection_leases_status_check
|
||||
CHECK (status IN ('active', 'released', 'expired', 'fenced')),
|
||||
CONSTRAINT vpn_connection_leases_expiry_check
|
||||
CHECK (expires_at > acquired_at)
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_vpn_connection_leases_generation
|
||||
ON vpn_connection_leases(vpn_connection_id, lease_generation);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_vpn_connection_leases_single_active
|
||||
ON vpn_connection_leases(vpn_connection_id)
|
||||
WHERE status = 'active';
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_vpn_connection_leases_owner
|
||||
ON vpn_connection_leases(cluster_id, owner_node_id, status, expires_at);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_vpn_connection_leases_connection_time
|
||||
ON vpn_connection_leases(vpn_connection_id, acquired_at DESC);
|
||||
@@ -0,0 +1,2 @@
|
||||
DROP TABLE IF EXISTS vpn_connection_assignment_latest_statuses;
|
||||
DROP TABLE IF EXISTS vpn_connection_assignment_status_reports;
|
||||
@@ -0,0 +1,42 @@
|
||||
CREATE TABLE IF NOT EXISTS vpn_connection_assignment_status_reports (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
vpn_connection_id UUID NOT NULL REFERENCES vpn_connections(id) ON DELETE CASCADE,
|
||||
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
observed_status TEXT NOT NULL DEFAULT 'unknown',
|
||||
status_payload JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
observed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT vpn_connection_assignment_status_membership_fk
|
||||
FOREIGN KEY (cluster_id, node_id)
|
||||
REFERENCES cluster_memberships(cluster_id, node_id)
|
||||
ON DELETE CASCADE,
|
||||
CONSTRAINT vpn_connection_assignment_status_check
|
||||
CHECK (observed_status IN ('not_started', 'assigned', 'lease_required', 'blocked', 'unknown'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_vpn_assignment_status_reports_node
|
||||
ON vpn_connection_assignment_status_reports(cluster_id, node_id, observed_at DESC);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_vpn_assignment_status_reports_connection
|
||||
ON vpn_connection_assignment_status_reports(vpn_connection_id, observed_at DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS vpn_connection_assignment_latest_statuses (
|
||||
vpn_connection_id UUID NOT NULL REFERENCES vpn_connections(id) ON DELETE CASCADE,
|
||||
cluster_id UUID NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
node_id UUID NOT NULL REFERENCES nodes(id) ON DELETE CASCADE,
|
||||
report_id UUID NOT NULL REFERENCES vpn_connection_assignment_status_reports(id) ON DELETE CASCADE,
|
||||
observed_status TEXT NOT NULL,
|
||||
status_payload JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||
observed_at TIMESTAMPTZ NOT NULL,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
PRIMARY KEY (vpn_connection_id, node_id),
|
||||
CONSTRAINT vpn_connection_assignment_latest_membership_fk
|
||||
FOREIGN KEY (cluster_id, node_id)
|
||||
REFERENCES cluster_memberships(cluster_id, node_id)
|
||||
ON DELETE CASCADE,
|
||||
CONSTRAINT vpn_connection_assignment_latest_status_check
|
||||
CHECK (observed_status IN ('not_started', 'assigned', 'lease_required', 'blocked', 'unknown'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_vpn_assignment_latest_node
|
||||
ON vpn_connection_assignment_latest_statuses(cluster_id, node_id, updated_at DESC);
|
||||
@@ -0,0 +1,5 @@
|
||||
DROP INDEX IF EXISTS idx_node_telemetry_cluster_node_observed;
|
||||
DROP TABLE IF EXISTS node_telemetry_observations;
|
||||
|
||||
DROP INDEX IF EXISTS idx_fabric_testing_flags_unique_scope;
|
||||
DROP TABLE IF EXISTS fabric_testing_flags;
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user