Initial project snapshot

This commit is contained in:
2026-04-28 22:29:50 +03:00
commit 8ba0561f4f
365 changed files with 91832 additions and 0 deletions
@@ -0,0 +1,219 @@
package sessionbroker
import (
"crypto/rsa"
"fmt"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
const (
directWorkerTLSTrustModeSmokeInsecure = "smoke_insecure"
directWorkerTLSTrustModePublicCA = "public_ca"
directWorkerTLSTrustModePlatformCA = "platform_ca"
)
type DataPlaneTokenClaims struct {
SessionID string `json:"session_id"`
AttachmentID string `json:"attachment_id"`
UserID string `json:"user_id"`
OrganizationID string `json:"organization_id"`
ClusterID string `json:"cluster_id,omitempty"`
WorkerID string `json:"worker_id"`
ResourceID string `json:"resource_id"`
AllowedChannels []string `json:"allowed_channels"`
ExpiresAtValue time.Time `json:"expires_at"`
jwt.RegisteredClaims
}
func (s *Service) buildDataPlaneOffer(session RemoteSession, attachment SessionAttachment) (*sessioncontracts.DataPlaneOffer, error) {
if s.cfg.DataPlane.TokenTTL <= 0 || s.cfg.DataPlane.TokenPrivateKeyPEM == "" {
return nil, nil
}
now := s.now().UTC()
expiresAt := now.Add(s.cfg.DataPlane.TokenTTL)
allowedChannels := dataPlaneAllowedChannelsFromSession(session)
jti := uuid.NewString()
claims := DataPlaneTokenClaims{
SessionID: session.ID,
AttachmentID: attachment.ID,
UserID: attachment.UserID,
OrganizationID: session.OrganizationID,
WorkerID: session.WorkerID,
ResourceID: session.ResourceID,
AllowedChannels: allowedChannels,
ExpiresAtValue: expiresAt,
RegisteredClaims: jwt.RegisteredClaims{
ID: jti,
Issuer: s.cfg.Auth.Issuer,
Subject: attachment.UserID,
Audience: jwt.ClaimStrings{"rap-data-plane", "worker:" + session.WorkerID},
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(expiresAt),
},
}
token, err := signDataPlaneToken(claims, s.cfg.DataPlane.TokenPrivateKeyPEM)
if err != nil {
return nil, err
}
candidates := s.buildDataPlaneCandidates(session)
preferred := sessioncontracts.DataPlaneCandidateBackendGateway
if len(candidates) > 0 {
preferred = candidates[0].Type
}
return &sessioncontracts.DataPlaneOffer{
Preferred: preferred,
Token: token,
ExpiresAt: expiresAt,
Candidates: candidates,
}, nil
}
func (s *Service) buildDataPlaneCandidates(session RemoteSession) []sessioncontracts.DataPlaneCandidate {
var candidates []sessioncontracts.DataPlaneCandidate
if directURL := s.directWorkerWSSURL(session.WorkerID); directURL != "" && s.canAdvertiseDirectWorkerWSS() {
metadata := map[string]any(nil)
if s.cfg.DataPlane.DirectWorkerJSONRuntime {
metadata = map[string]any{
"runtime_transport": "json_v1",
"traffic_ready": true,
}
s.addDirectWorkerTLSTrustMetadata(metadata)
if s.cfg.DataPlane.DirectWorkerBinaryRender {
metadata["render_transport"] = "binary_v1"
metadata["binary_render"] = true
metadata["supported_color_modes"] = []string{"full_color", "grayscale"}
metadata["default_color_mode"] = "full_color"
}
}
candidates = append(candidates, sessioncontracts.DataPlaneCandidate{
Type: sessioncontracts.DataPlaneCandidateDirectWorkerWSS,
URL: directURL,
WorkerID: session.WorkerID,
Priority: 10,
Metadata: metadata,
})
}
if s.cfg.DataPlane.BackendGatewayURL != "" {
candidates = append(candidates, sessioncontracts.DataPlaneCandidate{
Type: sessioncontracts.DataPlaneCandidateBackendGateway,
URL: s.cfg.DataPlane.BackendGatewayURL,
Priority: 100,
})
}
return candidates
}
func (s *Service) canAdvertiseDirectWorkerWSS() bool {
trustMode := normalizeDirectWorkerTLSTrustMode(s.cfg.DataPlane.DirectWorkerTLSTrustMode)
return !secrets.IsProductionEnv(s.cfg.App.Env) || directWorkerTLSTrustModeIsProductionTrusted(trustMode)
}
func (s *Service) addDirectWorkerTLSTrustMetadata(metadata map[string]any) {
trustMode := normalizeDirectWorkerTLSTrustMode(s.cfg.DataPlane.DirectWorkerTLSTrustMode)
metadata["tls_trust_mode"] = trustMode
metadata["production_trusted"] = directWorkerTLSTrustModeIsProductionTrusted(trustMode)
metadata["smoke_only"] = trustMode == directWorkerTLSTrustModeSmokeInsecure
if s.cfg.DataPlane.DirectWorkerTLSCARef != "" {
metadata["tls_ca_ref"] = s.cfg.DataPlane.DirectWorkerTLSCARef
}
}
func normalizeDirectWorkerTLSTrustMode(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case directWorkerTLSTrustModePublicCA:
return directWorkerTLSTrustModePublicCA
case directWorkerTLSTrustModePlatformCA:
return directWorkerTLSTrustModePlatformCA
default:
return directWorkerTLSTrustModeSmokeInsecure
}
}
func directWorkerTLSTrustModeIsProductionTrusted(mode string) bool {
return mode == directWorkerTLSTrustModePublicCA || mode == directWorkerTLSTrustModePlatformCA
}
func (s *Service) directWorkerWSSURL(workerID string) string {
template := strings.TrimSpace(s.cfg.DataPlane.DirectWorkerWSSURLTemplate)
if template == "" || workerID == "" {
return ""
}
return strings.ReplaceAll(template, "{worker_id}", workerID)
}
func signDataPlaneToken(claims DataPlaneTokenClaims, privateKeyPEM string) (string, error) {
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(privateKeyPEM))
if err != nil {
return "", fmt.Errorf("parse data-plane private key: %w", err)
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
signed, err := token.SignedString(privateKey)
if err != nil {
return "", fmt.Errorf("sign data-plane token: %w", err)
}
return signed, nil
}
func parseDataPlaneToken(tokenValue string, publicKey *rsa.PublicKey) (*DataPlaneTokenClaims, error) {
claims := &DataPlaneTokenClaims{}
token, err := jwt.ParseWithClaims(tokenValue, claims, func(token *jwt.Token) (any, error) {
if token.Method != jwt.SigningMethodRS256 {
return nil, fmt.Errorf("unexpected data-plane signing method: %s", token.Header["alg"])
}
return publicKey, nil
})
if err != nil {
return nil, err
}
if !token.Valid {
return nil, fmt.Errorf("data-plane token invalid")
}
return claims, nil
}
func dataPlaneAllowedChannelsFromSession(session RemoteSession) []string {
channels := []string{
sessioncontracts.DataPlaneChannelControl,
sessioncontracts.DataPlaneChannelInput,
sessioncontracts.DataPlaneChannelRender,
sessioncontracts.DataPlaneChannelTelemetry,
}
metadata := decodeJSONMap(session.Metadata)
policy, _ := metadata["policy"].(map[string]any)
if policy != nil {
if mode, _ := policy["clipboard_mode"].(string); mode != "" && mode != string(ResourceClipboardModeDisabled) {
channels = append(channels, sessioncontracts.DataPlaneChannelClipboard)
}
if mode, _ := policy["file_transfer_mode"].(string); fileTransferAllowsClientToServer(ResourceFileTransferMode(mode)) {
channels = append(channels, sessioncontracts.DataPlaneChannelFileUpload)
}
if mode, _ := policy["file_transfer_mode"].(string); fileTransferAllowsServerToClient(ResourceFileTransferMode(mode)) {
channels = append(channels, sessioncontracts.DataPlaneChannelFileDownload)
}
}
return channels
}
func (s *Service) attachDataPlaneOffer(result *SessionControlResult) error {
if result == nil || result.Attachment == nil {
return nil
}
result.GatewayURL = s.cfg.DataPlane.BackendGatewayURL
offer, err := s.buildDataPlaneOffer(result.Session, *result.Attachment)
if err != nil {
return err
}
result.DataPlane = offer
return nil
}
@@ -0,0 +1,357 @@
package sessionbroker
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"slices"
"strings"
"testing"
"time"
"github.com/example/remote-access-platform/backend/internal/platform/config"
"github.com/example/remote-access-platform/backend/internal/platform/module"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
func TestDataPlaneTokenScopeValidation(t *testing.T) {
now := time.Now().UTC().Truncate(time.Second)
privateKeyPEM, publicKey := testRS256Key(t)
service := &Service{
cfg: module.Config{
Auth: config.AuthConfig{
Issuer: "rap-api-test",
},
DataPlane: config.DataPlaneConfig{
TokenTTL: time.Minute,
TokenPrivateKeyPEM: privateKeyPEM,
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
},
},
now: func() time.Time { return now },
}
session := RemoteSession{
ID: "session-1",
OrganizationID: "org-1",
ResourceID: "resource-1",
WorkerID: "worker-1",
Metadata: mustJSON(t, map[string]any{"policy": map[string]any{"clipboard_mode": "bidirectional", "file_transfer_mode": "client_to_server"}}),
}
attachment := SessionAttachment{
ID: "attachment-1",
UserID: "user-1",
}
offer, err := service.buildDataPlaneOffer(session, attachment)
if err != nil {
t.Fatalf("buildDataPlaneOffer returned error: %v", err)
}
if offer == nil {
t.Fatal("expected data-plane offer")
}
claims, err := parseDataPlaneToken(offer.Token, publicKey)
if err != nil {
t.Fatalf("parseDataPlaneToken returned error: %v", err)
}
assertEqual(t, claims.SessionID, session.ID, "session_id")
assertEqual(t, claims.AttachmentID, attachment.ID, "attachment_id")
assertEqual(t, claims.UserID, attachment.UserID, "user_id")
assertEqual(t, claims.OrganizationID, session.OrganizationID, "organization_id")
assertEqual(t, claims.WorkerID, session.WorkerID, "worker_id")
assertEqual(t, claims.ResourceID, session.ResourceID, "resource_id")
if claims.ID == "" {
t.Fatal("expected jti")
}
if claims.ExpiresAt == nil || !claims.ExpiresAt.Time.Equal(now.Add(time.Minute)) {
t.Fatalf("unexpected expires_at: %v", claims.ExpiresAt)
}
if !claims.ExpiresAtValue.Equal(now.Add(time.Minute)) {
t.Fatalf("unexpected expires_at claim value: %v", claims.ExpiresAtValue)
}
for _, channel := range []string{
sessioncontracts.DataPlaneChannelControl,
sessioncontracts.DataPlaneChannelInput,
sessioncontracts.DataPlaneChannelRender,
sessioncontracts.DataPlaneChannelTelemetry,
sessioncontracts.DataPlaneChannelClipboard,
sessioncontracts.DataPlaneChannelFileUpload,
} {
if !slices.Contains(claims.AllowedChannels, channel) {
t.Fatalf("expected allowed channel %q in %v", channel, claims.AllowedChannels)
}
}
}
func TestDataPlaneOfferResponseShapeCompatibility(t *testing.T) {
now := time.Now().UTC().Truncate(time.Second)
privateKeyPEM, _ := testRS256Key(t)
service := &Service{
cfg: module.Config{
Auth: config.AuthConfig{Issuer: "rap-api-test"},
DataPlane: config.DataPlaneConfig{
TokenTTL: time.Minute,
TokenPrivateKeyPEM: privateKeyPEM,
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
DirectWorkerJSONRuntime: true,
DirectWorkerTLSTrustMode: "smoke_insecure",
},
},
now: func() time.Time { return now },
}
result := &SessionControlResult{
Session: RemoteSession{
ID: "session-1",
OrganizationID: "org-1",
ResourceID: "resource-1",
WorkerID: "worker-1",
Metadata: mustJSON(t, map[string]any{"policy": map[string]any{"clipboard_mode": "disabled", "file_transfer_mode": "disabled"}}),
},
Attachment: &SessionAttachment{ID: "attachment-1", UserID: "user-1"},
AttachToken: &sessioncontracts.AttachTokenClaims{
Token: "existing-attach-token",
SessionID: "session-1",
AttachmentID: "attachment-1",
UserID: "user-1",
WorkerID: "worker-1",
ExpiresAt: now.Add(2 * time.Minute),
},
}
if err := service.attachDataPlaneOffer(result); err != nil {
t.Fatalf("attachDataPlaneOffer returned error: %v", err)
}
payload, err := json.Marshal(result)
if err != nil {
t.Fatalf("marshal response: %v", err)
}
var decoded map[string]any
if err := json.Unmarshal(payload, &decoded); err != nil {
t.Fatalf("decode response: %v", err)
}
if decoded["session"] == nil || decoded["attachment"] == nil || decoded["attach_token"] == nil {
t.Fatalf("response lost existing fields: %s", payload)
}
if decoded["data_plane"] == nil || decoded["gateway_url"] == nil {
t.Fatalf("response missing data-plane fields: %s", payload)
}
if result.DataPlane == nil {
t.Fatal("expected data-plane offer")
}
if result.DataPlane.Preferred != sessioncontracts.DataPlaneCandidateDirectWorkerWSS {
t.Fatalf("unexpected preferred candidate: %s", result.DataPlane.Preferred)
}
if len(result.DataPlane.Candidates) != 2 {
t.Fatalf("expected direct and fallback candidates, got %d", len(result.DataPlane.Candidates))
}
if result.DataPlane.Candidates[0].URL != "wss://worker-1.worker.example.test/rap/v1/data-plane" {
t.Fatalf("unexpected direct candidate URL: %s", result.DataPlane.Candidates[0].URL)
}
if result.DataPlane.Candidates[0].Metadata["runtime_transport"] != "json_v1" {
t.Fatalf("direct candidate is missing json_v1 runtime metadata: %#v", result.DataPlane.Candidates[0].Metadata)
}
if result.DataPlane.Candidates[0].Metadata["traffic_ready"] != true {
t.Fatalf("direct candidate is missing traffic_ready metadata: %#v", result.DataPlane.Candidates[0].Metadata)
}
if result.DataPlane.Candidates[0].Metadata["smoke_only"] != true {
t.Fatalf("direct candidate should be marked smoke-only by default: %#v", result.DataPlane.Candidates[0].Metadata)
}
if result.DataPlane.Candidates[0].Metadata["production_trusted"] != false {
t.Fatalf("smoke direct candidate must not be production-trusted: %#v", result.DataPlane.Candidates[0].Metadata)
}
if !strings.Contains(result.DataPlane.Candidates[1].URL, "/api/v1/gateway/ws") {
t.Fatalf("unexpected backend candidate URL: %s", result.DataPlane.Candidates[1].URL)
}
}
func TestDataPlaneDirectCandidateMetadataRequiresRuntimeFlag(t *testing.T) {
service := &Service{
cfg: module.Config{
DataPlane: config.DataPlaneConfig{
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
DirectWorkerTLSTrustMode: "smoke_insecure",
},
},
}
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
if len(candidates) != 2 {
t.Fatalf("expected direct and fallback candidates, got %d", len(candidates))
}
if candidates[0].Metadata != nil {
t.Fatalf("direct candidate must not advertise json_v1 before runtime flag is enabled: %#v", candidates[0].Metadata)
}
}
func TestDataPlaneDirectCandidateAdvertisesBinaryRenderOnlyWhenEnabled(t *testing.T) {
service := &Service{
cfg: module.Config{
DataPlane: config.DataPlaneConfig{
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
DirectWorkerJSONRuntime: true,
DirectWorkerBinaryRender: true,
DirectWorkerTLSTrustMode: "platform_ca",
DirectWorkerTLSCARef: "rap-platform-ca:v1",
},
},
}
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
if len(candidates) != 2 {
t.Fatalf("expected direct and fallback candidates, got %d", len(candidates))
}
if candidates[0].Metadata["render_transport"] != "binary_v1" {
t.Fatalf("direct candidate is missing binary render metadata: %#v", candidates[0].Metadata)
}
if candidates[0].Metadata["binary_render"] != true {
t.Fatalf("direct candidate is missing binary_render metadata: %#v", candidates[0].Metadata)
}
if candidates[0].Metadata["default_color_mode"] != "full_color" {
t.Fatalf("direct candidate is missing default_color_mode metadata: %#v", candidates[0].Metadata)
}
if candidates[0].Metadata["production_trusted"] != true || candidates[0].Metadata["tls_trust_mode"] != "platform_ca" {
t.Fatalf("direct candidate is missing production trust metadata: %#v", candidates[0].Metadata)
}
if candidates[0].Metadata["tls_ca_ref"] != "rap-platform-ca:v1" {
t.Fatalf("direct candidate is missing tls_ca_ref metadata: %#v", candidates[0].Metadata)
}
modes, ok := candidates[0].Metadata["supported_color_modes"].([]string)
if !ok || !slices.Contains(modes, "full_color") || !slices.Contains(modes, "grayscale") {
t.Fatalf("direct candidate is missing supported_color_modes metadata: %#v", candidates[0].Metadata)
}
}
func TestDataPlaneDirectCandidateOmittedInProductionWhenSmokeOnly(t *testing.T) {
service := &Service{
cfg: module.Config{
App: config.AppConfig{Env: "production"},
DataPlane: config.DataPlaneConfig{
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
DirectWorkerJSONRuntime: true,
DirectWorkerTLSTrustMode: "smoke_insecure",
},
},
}
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
if len(candidates) != 1 {
t.Fatalf("expected fallback-only candidates in production with smoke TLS, got %d", len(candidates))
}
if candidates[0].Type != sessioncontracts.DataPlaneCandidateBackendGateway {
t.Fatalf("production must not advertise smoke-only direct candidate: %#v", candidates)
}
}
func TestDataPlaneDirectCandidateAdvertisedInProductionWhenTrusted(t *testing.T) {
service := &Service{
cfg: module.Config{
App: config.AppConfig{Env: "production"},
DataPlane: config.DataPlaneConfig{
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
DirectWorkerJSONRuntime: true,
DirectWorkerTLSTrustMode: "public_ca",
},
},
}
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
if len(candidates) != 2 {
t.Fatalf("expected trusted direct and fallback candidates, got %d", len(candidates))
}
if candidates[0].Metadata["production_trusted"] != true || candidates[0].Metadata["tls_trust_mode"] != "public_ca" {
t.Fatalf("trusted production direct candidate metadata mismatch: %#v", candidates[0].Metadata)
}
}
func TestDataPlaneCandidatesFallbackOnlyWhenDirectTemplateMissing(t *testing.T) {
service := &Service{
cfg: module.Config{
DataPlane: config.DataPlaneConfig{
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
},
},
}
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
if len(candidates) != 1 {
t.Fatalf("expected fallback-only candidate list, got %d", len(candidates))
}
if candidates[0].Type != sessioncontracts.DataPlaneCandidateBackendGateway {
t.Fatalf("unexpected candidate type: %s", candidates[0].Type)
}
}
func TestDataPlaneAllowedChannelsRespectRuntimePolicy(t *testing.T) {
cases := []struct {
name string
policy map[string]any
expected []string
blocked []string
}{
{
name: "disabled policies expose only control input render telemetry",
policy: map[string]any{"clipboard_mode": "disabled", "file_transfer_mode": "disabled"},
expected: []string{sessioncontracts.DataPlaneChannelControl, sessioncontracts.DataPlaneChannelInput, sessioncontracts.DataPlaneChannelRender, sessioncontracts.DataPlaneChannelTelemetry},
blocked: []string{sessioncontracts.DataPlaneChannelClipboard, sessioncontracts.DataPlaneChannelFileUpload},
},
{
name: "clipboard policy adds clipboard channel",
policy: map[string]any{"clipboard_mode": "server_to_client", "file_transfer_mode": "disabled"},
expected: []string{sessioncontracts.DataPlaneChannelClipboard},
blocked: []string{sessioncontracts.DataPlaneChannelFileUpload},
},
{
name: "client upload policy adds file upload channel",
policy: map[string]any{"clipboard_mode": "disabled", "file_transfer_mode": "client_to_server"},
expected: []string{sessioncontracts.DataPlaneChannelFileUpload},
blocked: []string{sessioncontracts.DataPlaneChannelClipboard},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
session := RemoteSession{Metadata: mustJSON(t, map[string]any{"policy": tc.policy})}
channels := dataPlaneAllowedChannelsFromSession(session)
for _, channel := range tc.expected {
if !slices.Contains(channels, channel) {
t.Fatalf("expected channel %q in %v", channel, channels)
}
}
for _, channel := range tc.blocked {
if slices.Contains(channels, channel) {
t.Fatalf("did not expect channel %q in %v", channel, channels)
}
}
})
}
}
func mustJSON(t *testing.T, value any) []byte {
t.Helper()
payload, err := json.Marshal(value)
if err != nil {
t.Fatalf("marshal test metadata: %v", err)
}
return payload
}
func testRS256Key(t *testing.T) (string, *rsa.PublicKey) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("generate RSA key: %v", err)
}
encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
})
return string(encoded), &privateKey.PublicKey
}
func assertEqual(t *testing.T, got, want, name string) {
t.Helper()
if got != want {
t.Fatalf("unexpected %s: got %q want %q", name, got, want)
}
}
@@ -0,0 +1,15 @@
package sessionbroker
import "errors"
var (
ErrSessionNotFound = errors.New("remote session not found")
ErrAttachmentNotFound = errors.New("session attachment not found")
ErrActiveControllerPresent = errors.New("active controller already present")
ErrTakeoverNotAllowed = errors.New("takeover not allowed")
ErrTrustedDeviceRequired = errors.New("trusted device required")
ErrAccessDenied = errors.New("access denied")
ErrSessionNotAttachable = errors.New("session is not attachable")
ErrSessionNotTerminable = errors.New("session is not terminable")
ErrAttachTokenInvalid = errors.New("attach token invalid or expired")
)
@@ -0,0 +1,65 @@
package sessionbroker
import (
"context"
"time"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
type LiveStateStore interface {
UpsertSession(ctx context.Context, state LiveSessionState) error
GetSession(ctx context.Context, sessionID string) (*LiveSessionState, error)
DeleteSession(ctx context.Context, sessionID string) error
BindController(ctx context.Context, binding sessioncontracts.ControllerBinding, ttl time.Duration) error
GetControllerBinding(ctx context.Context, sessionID string) (*sessioncontracts.ControllerBinding, error)
ClearControllerBinding(ctx context.Context, sessionID string) error
StoreAttachToken(ctx context.Context, claims sessioncontracts.AttachTokenClaims, ttl time.Duration) error
ConsumeAttachToken(ctx context.Context, token string) (*sessioncontracts.AttachTokenClaims, error)
TouchAttachmentHeartbeat(ctx context.Context, sessionID, attachmentID string, ttl time.Duration) error
UpdateWorkerRoute(ctx context.Context, route WorkerRoute, ttl time.Duration) error
GetWorkerRoute(ctx context.Context, sessionID string) (*WorkerRoute, error)
DeleteWorkerRoute(ctx context.Context, sessionID string) error
}
type LiveSessionState struct {
SessionID string `json:"session_id"`
ResourceID string `json:"resource_id"`
WorkerID string `json:"worker_id"`
State sessioncontracts.State `json:"state"`
ControllerID string `json:"controller_id"`
AttachmentID string `json:"attachment_id"`
TakeoverVersion int `json:"takeover_version"`
RenderQualityProfile string `json:"render_quality_profile,omitempty"`
RenderState string `json:"render_state,omitempty"`
RenderWidth int `json:"render_width,omitempty"`
RenderHeight int `json:"render_height,omitempty"`
RenderFrameSequence int64 `json:"render_frame_sequence,omitempty"`
RenderFrameFormat string `json:"render_frame_format,omitempty"`
RenderFrameData string `json:"render_frame_data,omitempty"`
LastInputCorrelationID string `json:"last_input_correlation_id,omitempty"`
WorkerFrameCapturedAt string `json:"worker_frame_captured_at,omitempty"`
CursorX int `json:"cursor_x,omitempty"`
CursorY int `json:"cursor_y,omitempty"`
CursorVisible bool `json:"cursor_visible,omitempty"`
DirtyRectangles int `json:"dirty_rectangles,omitempty"`
LastRenderAt *time.Time `json:"last_render_at,omitempty"`
ClipboardSequence int64 `json:"clipboard_sequence,omitempty"`
ClipboardText string `json:"clipboard_text,omitempty"`
ClipboardOrigin string `json:"clipboard_origin,omitempty"`
ClipboardContentHash string `json:"clipboard_content_hash,omitempty"`
ClipboardUpdatedAt *time.Time `json:"clipboard_updated_at,omitempty"`
FileDownloadSequence int64 `json:"file_download_sequence,omitempty"`
FileDownloadType string `json:"file_download_type,omitempty"`
FileDownloadPayload map[string]any `json:"file_download_payload,omitempty"`
FileDownloadUpdatedAt *time.Time `json:"file_download_updated_at,omitempty"`
UpdatedAt time.Time `json:"updated_at"`
}
type WorkerRoute struct {
SessionID string `json:"session_id"`
WorkerID string `json:"worker_id"`
LeaseID string `json:"lease_id"`
ControlStream string `json:"control_stream"`
UpdatedAt time.Time `json:"updated_at"`
}
@@ -0,0 +1,132 @@
package sessionbroker
import (
"time"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
type AttachmentRole string
const (
AttachmentRoleController AttachmentRole = "controller"
)
type AttachmentState string
const (
AttachmentStateAttaching AttachmentState = "attaching"
AttachmentStateActive AttachmentState = "active"
AttachmentStateDetached AttachmentState = "detached"
AttachmentStateSuperseded AttachmentState = "superseded"
AttachmentStateRevoked AttachmentState = "revoked"
AttachmentStateClosed AttachmentState = "closed"
)
type ResourceTakeoverPolicy string
const (
ResourceTakeoverPolicyTrustedDevice ResourceTakeoverPolicy = "trusted_device"
ResourceTakeoverPolicySameUser ResourceTakeoverPolicy = "same_user"
ResourceTakeoverPolicyAdminOnly ResourceTakeoverPolicy = "admin_only"
)
type ResourceClipboardMode string
const (
ResourceClipboardModeDisabled ResourceClipboardMode = "disabled"
ResourceClipboardModeClientToServer ResourceClipboardMode = "client_to_server"
ResourceClipboardModeServerToClient ResourceClipboardMode = "server_to_client"
ResourceClipboardModeBidirectional ResourceClipboardMode = "bidirectional"
)
type ResourceFileTransferMode string
const (
ResourceFileTransferModeDisabled ResourceFileTransferMode = "disabled"
ResourceFileTransferModeClientToServer ResourceFileTransferMode = "client_to_server"
ResourceFileTransferModeServerToClient ResourceFileTransferMode = "server_to_client"
ResourceFileTransferModeBidirectional ResourceFileTransferMode = "bidirectional"
)
type RemoteSession struct {
ID string
OrganizationID string
ResourceID string
Protocol string
State sessioncontracts.State
WorkerID string
ControllerUserID string
DetachDeadlineAt *time.Time
LastHeartbeatAt *time.Time
TakeoverVersion int
RenderQualityProfile string
Metadata []byte
CreatedAt time.Time
UpdatedAt time.Time
}
type SessionAttachment struct {
ID string
RemoteSessionID string
UserID string
DeviceID string
Role AttachmentRole
State AttachmentState
SupersededBy *string
TakeoverOf *string
AttachedAt *time.Time
DetachedAt *time.Time
LastInputAt *time.Time
Metadata []byte
CreatedAt time.Time
UpdatedAt time.Time
}
type ResourcePolicy struct {
ResourceID string
MaxConcurrentSessions int
TakeoverPolicy ResourceTakeoverPolicy
RequireTrustedDevice bool
DetachGracePeriod time.Duration
ClipboardEnabled bool
ClipboardMode ResourceClipboardMode
FileTransferEnabled bool
FileTransferMode ResourceFileTransferMode
CreatedAt time.Time
UpdatedAt time.Time
}
type ResourceRuntimeSpec struct {
ID string
OrganizationID string
Name string
Address string
Protocol string
SecretRef *string
CertificateVerificationMode string
Metadata []byte
}
type AuditEvent struct {
ID string
ActorUserID *string
ActorDeviceID *string
EventType string
TargetType string
TargetID string
RemoteSessionID *string
Payload []byte
CreatedAt time.Time
}
const (
AuditEventSessionStarted = "session_started"
AuditEventSessionAttached = "session_attached"
AuditEventSessionDetached = "session_detached"
AuditEventSessionTakenOver = "session_taken_over"
AuditEventSessionTerminated = "session_terminated"
AuditEventSessionFailed = "session_failed"
AuditEventSecretAccessed = "resource_secret_accessed"
AuditEventSecretAccessDenied = "resource_secret_access_denied"
)
@@ -0,0 +1,164 @@
package sessionbroker
import (
"encoding/json"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
)
type Module struct {
service *Service
}
func NewModule(service *Service) *Module {
return &Module{service: service}
}
func (m *Module) Name() string {
return "session-broker"
}
func (m *Module) Service() *Service {
return m.service
}
func (m *Module) RegisterRoutes(router chi.Router) {
router.Route("/sessions", func(r chi.Router) {
r.Get("/", m.listSessions)
r.Post("/", m.startSession)
r.Post("/{sessionID}/attach", m.attachSession)
r.Post("/{sessionID}/detach", m.detachSession)
r.Post("/{sessionID}/takeover", m.takeoverSession)
r.Post("/{sessionID}/terminate", m.terminateSession)
r.Post("/{sessionID}/fail", m.markFailed)
})
}
func (m *Module) listSessions(w http.ResponseWriter, r *http.Request) {
userID := r.URL.Query().Get("user_id")
if userID == "" {
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
return
}
sessions, err := m.service.ListSessions(r.Context(), userID)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusOK, map[string]any{"sessions": sessions})
}
func (m *Module) startSession(w http.ResponseWriter, r *http.Request) {
var cmd StartRemoteSessionCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid start session payload")
return
}
result, err := m.service.StartRemoteSession(r.Context(), cmd)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusCreated, result)
}
func (m *Module) attachSession(w http.ResponseWriter, r *http.Request) {
var cmd AttachToSessionCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid attach session payload")
return
}
cmd.SessionID = chi.URLParam(r, "sessionID")
result, err := m.service.AttachToSession(r.Context(), cmd)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusOK, result)
}
func (m *Module) detachSession(w http.ResponseWriter, r *http.Request) {
var cmd DetachFromSessionCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid detach session payload")
return
}
cmd.SessionID = chi.URLParam(r, "sessionID")
result, err := m.service.DetachFromSession(r.Context(), cmd)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusAccepted, result)
}
func (m *Module) takeoverSession(w http.ResponseWriter, r *http.Request) {
var cmd TakeoverSessionCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid takeover session payload")
return
}
cmd.SessionID = chi.URLParam(r, "sessionID")
result, err := m.service.TakeoverSession(r.Context(), cmd)
if err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusOK, result)
}
func (m *Module) terminateSession(w http.ResponseWriter, r *http.Request) {
var cmd TerminateSessionCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid terminate session payload")
return
}
cmd.SessionID = chi.URLParam(r, "sessionID")
if err := m.service.TerminateSession(r.Context(), cmd); err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
"status": "terminated",
"message": httpx.NewMessage(
"session.terminated",
"status.session.terminated",
"Session terminated.",
nil,
"",
),
})
}
func (m *Module) markFailed(w http.ResponseWriter, r *http.Request) {
var cmd MarkSessionFailedCommand
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
httpx.WriteError(w, http.StatusBadRequest, "invalid fail session payload")
return
}
cmd.SessionID = chi.URLParam(r, "sessionID")
if err := m.service.MarkSessionFailed(r.Context(), cmd); err != nil {
status, message := m.service.MapError(err)
httpx.WriteError(w, status, message)
return
}
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
"status": "failed",
"message": httpx.NewMessage(
"session.failed",
"status.session.failed",
"Session marked as failed.",
nil,
"",
),
})
}
@@ -0,0 +1,17 @@
package sessionbroker
import (
"context"
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
)
type WorkerOrchestrator interface {
Reserve(ctx context.Context, request workercontracts.AttachRequest) (*workercontracts.WorkerLease, error)
GetSessionLease(ctx context.Context, sessionID string) (*workercontracts.WorkerLease, error)
ReleaseSessionLease(ctx context.Context, sessionID string) error
PrepareAttachment(ctx context.Context, session RemoteSession, attachment SessionAttachment, runtimeMetadata map[string]any) error
NotifyDetachment(ctx context.Context, session RemoteSession, attachment SessionAttachment) error
TerminateRemoteSession(ctx context.Context, sessionID, attachmentID string) error
ValidateSessionRuntime(ctx context.Context, sessionID, workerID string) (bool, string, error)
}
@@ -0,0 +1,607 @@
package sessionbroker
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/example/remote-access-platform/backend/internal/platform/authority"
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
)
type postgresStore struct {
db postgresplatform.DBTX
authority *authority.Verifier
}
type PostgresTransactor struct {
pool *pgxpool.Pool
authority *authority.Verifier
}
func NewPostgresStore(pool *pgxpool.Pool, verifiers ...*authority.Verifier) Store {
var authorityVerifier *authority.Verifier
if len(verifiers) > 0 {
authorityVerifier = verifiers[0]
}
return &postgresStore{db: pool, authority: authorityVerifier}
}
func NewPostgresTransactor(pool *pgxpool.Pool, verifiers ...*authority.Verifier) *PostgresTransactor {
var authorityVerifier *authority.Verifier
if len(verifiers) > 0 {
authorityVerifier = verifiers[0]
}
return &PostgresTransactor{pool: pool, authority: authorityVerifier}
}
func (t *PostgresTransactor) WithinTransaction(ctx context.Context, fn func(store Store) error) error {
return postgresplatform.WithTransaction(ctx, t.pool, func(tx pgx.Tx) error {
return fn(&postgresStore{db: tx, authority: t.authority})
})
}
func (s *postgresStore) RemoteSessions() RemoteSessionRepository {
return &postgresRemoteSessionRepository{db: s.db}
}
func (s *postgresStore) SessionAttachments() SessionAttachmentRepository {
return &postgresSessionAttachmentRepository{db: s.db}
}
func (s *postgresStore) ResourcePolicies() ResourcePolicyRepository {
return &postgresResourcePolicyRepository{db: s.db}
}
func (s *postgresStore) ResourceRuntime() ResourceRuntimeRepository {
return &postgresResourceRuntimeRepository{db: s.db}
}
func (s *postgresStore) AuditEvents() AuditEventRepository {
return &postgresAuditEventRepository{db: s.db}
}
func (s *postgresStore) Access() AccessRepository {
return &postgresAccessRepository{db: s.db, authority: s.authority}
}
type postgresRemoteSessionRepository struct {
db postgresplatform.DBTX
}
type postgresSessionAttachmentRepository struct {
db postgresplatform.DBTX
}
type postgresResourcePolicyRepository struct {
db postgresplatform.DBTX
}
type postgresResourceRuntimeRepository struct {
db postgresplatform.DBTX
}
type postgresAuditEventRepository struct {
db postgresplatform.DBTX
}
type postgresAccessRepository struct {
db postgresplatform.DBTX
authority *authority.Verifier
}
func (r *postgresRemoteSessionRepository) Create(ctx context.Context, session RemoteSession) error {
const query = `
INSERT INTO remote_sessions (
id, organization_id, resource_id, protocol, state, worker_id, controller_user_id, detach_deadline_at,
last_heartbeat_at, takeover_version, metadata, created_at, updated_at
) VALUES (
$1::uuid, $2::uuid, $3::uuid, $4, $5, NULLIF($6, ''), $7::uuid, $8, $9, $10, $11::jsonb, $12, $13
)
`
if _, err := r.db.Exec(ctx, query,
session.ID,
session.OrganizationID,
session.ResourceID,
session.Protocol,
session.State,
session.WorkerID,
session.ControllerUserID,
session.DetachDeadlineAt,
session.LastHeartbeatAt,
session.TakeoverVersion,
jsonPayload(session.Metadata),
session.CreatedAt,
session.UpdatedAt,
); err != nil {
return fmt.Errorf("create remote session: %w", err)
}
return nil
}
func (r *postgresRemoteSessionRepository) GetByID(ctx context.Context, sessionID string) (*RemoteSession, error) {
return r.getByID(ctx, sessionID, "")
}
func (r *postgresRemoteSessionRepository) GetByIDForUpdate(ctx context.Context, sessionID string) (*RemoteSession, error) {
return r.getByID(ctx, sessionID, " FOR UPDATE")
}
func (r *postgresRemoteSessionRepository) getByID(ctx context.Context, sessionID string, suffix string) (*RemoteSession, error) {
query := `
SELECT id::text, organization_id::text, resource_id::text, protocol, state, COALESCE(worker_id, ''), controller_user_id::text,
detach_deadline_at, last_heartbeat_at, takeover_version, metadata, created_at, updated_at
FROM remote_sessions
WHERE id = $1::uuid` + suffix
remoteSession, err := scanRemoteSession(r.db.QueryRow(ctx, query, sessionID))
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return remoteSession, err
}
func (r *postgresRemoteSessionRepository) ListByController(ctx context.Context, userID string) ([]RemoteSession, error) {
const query = `
SELECT id::text, organization_id::text, resource_id::text, protocol, state, COALESCE(worker_id, ''), controller_user_id::text,
detach_deadline_at, last_heartbeat_at, takeover_version, metadata, created_at, updated_at
FROM remote_sessions
WHERE controller_user_id = $1::uuid
ORDER BY updated_at DESC
`
rows, err := r.db.Query(ctx, query, userID)
if err != nil {
return nil, fmt.Errorf("list remote sessions: %w", err)
}
defer rows.Close()
var sessions []RemoteSession
for rows.Next() {
item, err := scanRemoteSession(rows)
if err != nil {
return nil, err
}
sessions = append(sessions, *item)
}
return sessions, rows.Err()
}
func (r *postgresRemoteSessionRepository) CountLiveByResource(ctx context.Context, resourceID string) (int, error) {
const query = `
SELECT COUNT(*)
FROM remote_sessions
WHERE resource_id = $1::uuid AND state IN ('starting', 'active', 'detached', 'reconnecting')
`
var count int
if err := r.db.QueryRow(ctx, query, resourceID).Scan(&count); err != nil {
return 0, fmt.Errorf("count live remote sessions: %w", err)
}
return count, nil
}
func (r *postgresRemoteSessionRepository) ListDetachedExpired(ctx context.Context, before time.Time, limit int) ([]RemoteSession, error) {
const query = `
SELECT id::text, organization_id::text, resource_id::text, protocol, state, COALESCE(worker_id, ''), controller_user_id::text,
detach_deadline_at, last_heartbeat_at, takeover_version, metadata, created_at, updated_at
FROM remote_sessions
WHERE state = 'detached' AND detach_deadline_at IS NOT NULL AND detach_deadline_at <= $1
ORDER BY detach_deadline_at ASC
LIMIT $2
`
rows, err := r.db.Query(ctx, query, before, limit)
if err != nil {
return nil, fmt.Errorf("list detached expired sessions: %w", err)
}
defer rows.Close()
var sessions []RemoteSession
for rows.Next() {
item, err := scanRemoteSession(rows)
if err != nil {
return nil, err
}
sessions = append(sessions, *item)
}
return sessions, rows.Err()
}
func (r *postgresRemoteSessionRepository) UpdateState(ctx context.Context, params UpdateRemoteSessionStateParams) error {
const query = `
UPDATE remote_sessions
SET state = $2,
worker_id = NULLIF($3, ''),
detach_deadline_at = $4,
last_heartbeat_at = $5,
takeover_version = $6,
updated_at = $7
WHERE id = $1::uuid
`
if _, err := r.db.Exec(ctx, query,
params.RemoteSessionID,
params.State,
params.WorkerID,
params.DetachDeadlineAt,
params.LastHeartbeatAt,
params.TakeoverVersion,
params.UpdatedAt,
); err != nil {
return fmt.Errorf("update remote session state: %w", err)
}
return nil
}
func (r *postgresSessionAttachmentRepository) Create(ctx context.Context, attachment SessionAttachment) error {
const query = `
INSERT INTO session_attachments (
id, remote_session_id, user_id, device_id, role, state, superseded_by,
takeover_of, attached_at, detached_at, last_input_at, metadata, created_at, updated_at
) VALUES (
$1::uuid, $2::uuid, $3::uuid, $4::uuid, $5, $6, NULLIF($7, '')::uuid,
NULLIF($8, '')::uuid, $9, $10, $11, $12::jsonb, $13, $14
)
`
if _, err := r.db.Exec(ctx, query,
attachment.ID,
attachment.RemoteSessionID,
attachment.UserID,
attachment.DeviceID,
attachment.Role,
attachment.State,
stringValue(attachment.SupersededBy),
stringValue(attachment.TakeoverOf),
attachment.AttachedAt,
attachment.DetachedAt,
attachment.LastInputAt,
jsonPayload(attachment.Metadata),
attachment.CreatedAt,
attachment.UpdatedAt,
); err != nil {
return fmt.Errorf("create session attachment: %w", err)
}
return nil
}
func (r *postgresSessionAttachmentRepository) GetByID(ctx context.Context, attachmentID string) (*SessionAttachment, error) {
return r.getByID(ctx, attachmentID, "")
}
func (r *postgresSessionAttachmentRepository) GetByIDForUpdate(ctx context.Context, attachmentID string) (*SessionAttachment, error) {
return r.getByID(ctx, attachmentID, " FOR UPDATE")
}
func (r *postgresSessionAttachmentRepository) getByID(ctx context.Context, attachmentID string, suffix string) (*SessionAttachment, error) {
query := `
SELECT id::text, remote_session_id::text, user_id::text, device_id::text, role, state,
superseded_by::text, takeover_of::text, attached_at, detached_at, last_input_at, metadata, created_at, updated_at
FROM session_attachments
WHERE id = $1::uuid` + suffix
attachment, err := scanSessionAttachment(r.db.QueryRow(ctx, query, attachmentID))
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return attachment, err
}
func (r *postgresSessionAttachmentRepository) ListByRemoteSession(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error) {
return r.listByRemoteSession(ctx, remoteSessionID, "")
}
func (r *postgresSessionAttachmentRepository) ListActiveByRemoteSessionForUpdate(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error) {
return r.listByRemoteSession(ctx, remoteSessionID, " AND state IN ('attaching', 'active', 'reconnecting') FOR UPDATE")
}
func (r *postgresSessionAttachmentRepository) listByRemoteSession(ctx context.Context, remoteSessionID string, suffix string) ([]SessionAttachment, error) {
query := `
SELECT id::text, remote_session_id::text, user_id::text, device_id::text, role, state,
superseded_by::text, takeover_of::text, attached_at, detached_at, last_input_at, metadata, created_at, updated_at
FROM session_attachments
WHERE remote_session_id = $1::uuid` + suffix
rows, err := r.db.Query(ctx, query, remoteSessionID)
if err != nil {
return nil, fmt.Errorf("list session attachments: %w", err)
}
defer rows.Close()
var attachments []SessionAttachment
for rows.Next() {
item, err := scanSessionAttachment(rows)
if err != nil {
return nil, err
}
attachments = append(attachments, *item)
}
return attachments, rows.Err()
}
func (r *postgresSessionAttachmentRepository) UpdateState(ctx context.Context, params UpdateSessionAttachmentStateParams) error {
const query = `
UPDATE session_attachments
SET state = $2,
detached_at = $3,
last_input_at = $4,
updated_at = $5
WHERE id = $1::uuid
`
if _, err := r.db.Exec(ctx, query,
params.AttachmentID,
params.State,
params.DetachedAt,
params.LastInputAt,
params.UpdatedAt,
); err != nil {
return fmt.Errorf("update session attachment state: %w", err)
}
return nil
}
func (r *postgresSessionAttachmentRepository) Supersede(ctx context.Context, params SupersedeAttachmentParams) error {
const query = `
UPDATE session_attachments
SET state = 'superseded',
superseded_by = $2::uuid,
detached_at = $3,
updated_at = $4
WHERE id = $1::uuid
`
if _, err := r.db.Exec(ctx, query,
params.PreviousAttachmentID,
params.NextAttachmentID,
params.DetachedAt,
params.UpdatedAt,
); err != nil {
return fmt.Errorf("supersede attachment: %w", err)
}
return nil
}
func (r *postgresResourcePolicyRepository) GetByResourceID(ctx context.Context, resourceID string) (*ResourcePolicy, error) {
const query = `
SELECT resource_id::text, max_concurrent_sessions, takeover_policy, require_trusted_device,
detach_grace_period_seconds, clipboard_enabled, clipboard_mode, file_transfer_enabled,
COALESCE(file_transfer_mode, 'disabled'), created_at, updated_at
FROM resource_policies
WHERE resource_id = $1::uuid
`
policy, err := scanResourcePolicy(r.db.QueryRow(ctx, query, resourceID))
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return policy, err
}
func (r *postgresResourcePolicyRepository) Upsert(ctx context.Context, policy ResourcePolicy) error {
const query = `
INSERT INTO resource_policies (
resource_id, max_concurrent_sessions, takeover_policy, require_trusted_device,
detach_grace_period_seconds, clipboard_enabled, clipboard_mode, file_transfer_enabled, file_transfer_mode, created_at, updated_at
) VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
ON CONFLICT (resource_id) DO UPDATE SET
max_concurrent_sessions = EXCLUDED.max_concurrent_sessions,
takeover_policy = EXCLUDED.takeover_policy,
require_trusted_device = EXCLUDED.require_trusted_device,
detach_grace_period_seconds = EXCLUDED.detach_grace_period_seconds,
clipboard_enabled = EXCLUDED.clipboard_enabled,
clipboard_mode = EXCLUDED.clipboard_mode,
file_transfer_enabled = EXCLUDED.file_transfer_enabled,
file_transfer_mode = EXCLUDED.file_transfer_mode,
updated_at = EXCLUDED.updated_at
`
clipboardMode := normalizeClipboardMode(policy.ClipboardMode)
fileTransferMode := normalizeFileTransferMode(policy.FileTransferMode)
if _, err := r.db.Exec(ctx, query,
policy.ResourceID,
policy.MaxConcurrentSessions,
policy.TakeoverPolicy,
policy.RequireTrustedDevice,
int(policy.DetachGracePeriod.Seconds()),
clipboardMode != ResourceClipboardModeDisabled,
clipboardMode,
fileTransferAllowsClientToServer(fileTransferMode),
fileTransferMode,
policy.CreatedAt,
policy.UpdatedAt,
); err != nil {
return fmt.Errorf("upsert resource policy: %w", err)
}
return nil
}
func (r *postgresResourceRuntimeRepository) GetByID(ctx context.Context, resourceID string) (*ResourceRuntimeSpec, error) {
const query = `
SELECT id::text, organization_id::text, name, address, protocol, secret_ref, certificate_verification_mode, metadata
FROM resources
WHERE id = $1::uuid
`
item := &ResourceRuntimeSpec{}
var secretRef *string
var metadata []byte
if err := r.db.QueryRow(ctx, query, resourceID).Scan(
&item.ID,
&item.OrganizationID,
&item.Name,
&item.Address,
&item.Protocol,
&secretRef,
&item.CertificateVerificationMode,
&metadata,
); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, nil
}
return nil, fmt.Errorf("get resource runtime spec: %w", err)
}
item.SecretRef = secretRef
item.Metadata = metadata
return item, nil
}
func (r *postgresAuditEventRepository) Create(ctx context.Context, event AuditEvent) error {
const query = `
INSERT INTO audit_events (
id, actor_user_id, actor_device_id, event_type, target_type, target_id,
remote_session_id, payload, created_at
) VALUES (
$1::uuid, NULLIF($2, '')::uuid, NULLIF($3, '')::uuid, $4, $5, $6,
NULLIF($7, '')::uuid, $8::jsonb, $9
)
`
if _, err := r.db.Exec(ctx, query,
event.ID,
stringValue(event.ActorUserID),
stringValue(event.ActorDeviceID),
event.EventType,
event.TargetType,
event.TargetID,
stringValue(event.RemoteSessionID),
jsonPayload(event.Payload),
event.CreatedAt,
); err != nil {
return fmt.Errorf("create audit event: %w", err)
}
return nil
}
func (r *postgresAccessRepository) IsTrustedDevice(ctx context.Context, userID, deviceID string) (bool, error) {
const query = `
SELECT EXISTS(
SELECT 1 FROM devices
WHERE id = $1::uuid AND user_id = $2::uuid AND trust_status = 'trusted' AND revoked_at IS NULL
)
`
var trusted bool
if err := r.db.QueryRow(ctx, query, deviceID, userID).Scan(&trusted); err != nil {
return false, fmt.Errorf("check trusted device: %w", err)
}
return trusted, nil
}
func (r *postgresAccessRepository) GetPlatformRole(ctx context.Context, userID string) (string, error) {
return authority.EffectivePlatformRole(ctx, r.db, r.authority, userID)
}
func (r *postgresAccessRepository) GetOrganizationRole(ctx context.Context, organizationID, userID string) (string, bool, error) {
const query = `
SELECT role_id
FROM organization_memberships
WHERE organization_id = $1::uuid AND user_id = $2::uuid AND status = 'active'
`
var role string
if err := r.db.QueryRow(ctx, query, organizationID, userID).Scan(&role); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return "", false, nil
}
return "", false, fmt.Errorf("get organization role: %w", err)
}
return role, true, nil
}
type scanner interface {
Scan(dest ...any) error
}
func scanRemoteSession(row scanner) (*RemoteSession, error) {
item := &RemoteSession{}
var detachDeadlineAt, lastHeartbeatAt *time.Time
var metadata []byte
if err := row.Scan(
&item.ID,
&item.OrganizationID,
&item.ResourceID,
&item.Protocol,
&item.State,
&item.WorkerID,
&item.ControllerUserID,
&detachDeadlineAt,
&lastHeartbeatAt,
&item.TakeoverVersion,
&metadata,
&item.CreatedAt,
&item.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan remote session: %w", err)
}
item.DetachDeadlineAt = detachDeadlineAt
item.LastHeartbeatAt = lastHeartbeatAt
item.Metadata = metadata
item.RenderQualityProfile = renderQualityProfileFromSessionMetadata(metadata)
return item, nil
}
func scanSessionAttachment(row scanner) (*SessionAttachment, error) {
item := &SessionAttachment{}
var supersededBy, takeoverOf *string
var attachedAt, detachedAt, lastInputAt *time.Time
var metadata []byte
if err := row.Scan(
&item.ID,
&item.RemoteSessionID,
&item.UserID,
&item.DeviceID,
&item.Role,
&item.State,
&supersededBy,
&takeoverOf,
&attachedAt,
&detachedAt,
&lastInputAt,
&metadata,
&item.CreatedAt,
&item.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan session attachment: %w", err)
}
item.SupersededBy = supersededBy
item.TakeoverOf = takeoverOf
item.AttachedAt = attachedAt
item.DetachedAt = detachedAt
item.LastInputAt = lastInputAt
item.Metadata = metadata
return item, nil
}
func scanResourcePolicy(row scanner) (*ResourcePolicy, error) {
item := &ResourcePolicy{}
var detachGraceSeconds int
if err := row.Scan(
&item.ResourceID,
&item.MaxConcurrentSessions,
&item.TakeoverPolicy,
&item.RequireTrustedDevice,
&detachGraceSeconds,
&item.ClipboardEnabled,
&item.ClipboardMode,
&item.FileTransferEnabled,
&item.FileTransferMode,
&item.CreatedAt,
&item.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan resource policy: %w", err)
}
item.DetachGracePeriod = time.Duration(detachGraceSeconds) * time.Second
item.ClipboardMode = normalizeClipboardMode(item.ClipboardMode)
item.ClipboardEnabled = item.ClipboardMode != ResourceClipboardModeDisabled
item.FileTransferMode = normalizeFileTransferMode(item.FileTransferMode)
item.FileTransferEnabled = fileTransferAllowsClientToServer(item.FileTransferMode)
return item, nil
}
func jsonPayload(payload []byte) []byte {
if len(payload) == 0 {
return []byte(`{}`)
}
if json.Valid(payload) {
return payload
}
return []byte(`{}`)
}
func stringValue(value *string) string {
if value == nil {
return ""
}
return *value
}
@@ -0,0 +1,140 @@
package sessionbroker
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
type RedisLiveStateStore struct {
client *redis.Client
}
func NewRedisLiveStateStore(client *redis.Client) *RedisLiveStateStore {
return &RedisLiveStateStore{client: client}
}
func (s *RedisLiveStateStore) UpsertSession(ctx context.Context, state LiveSessionState) error {
return s.setJSON(ctx, liveSessionKey(state.SessionID), state, 0)
}
func (s *RedisLiveStateStore) GetSession(ctx context.Context, sessionID string) (*LiveSessionState, error) {
var state LiveSessionState
ok, err := s.getJSON(ctx, liveSessionKey(sessionID), &state)
if err != nil || !ok {
return nil, err
}
return &state, nil
}
func (s *RedisLiveStateStore) DeleteSession(ctx context.Context, sessionID string) error {
return s.client.Del(ctx, liveSessionKey(sessionID)).Err()
}
func (s *RedisLiveStateStore) BindController(ctx context.Context, binding sessioncontracts.ControllerBinding, ttl time.Duration) error {
return s.setJSON(ctx, controllerBindingKey(binding.SessionID), binding, ttl)
}
func (s *RedisLiveStateStore) GetControllerBinding(ctx context.Context, sessionID string) (*sessioncontracts.ControllerBinding, error) {
var binding sessioncontracts.ControllerBinding
ok, err := s.getJSON(ctx, controllerBindingKey(sessionID), &binding)
if err != nil || !ok {
return nil, err
}
return &binding, nil
}
func (s *RedisLiveStateStore) ClearControllerBinding(ctx context.Context, sessionID string) error {
return s.client.Del(ctx, controllerBindingKey(sessionID)).Err()
}
func (s *RedisLiveStateStore) StoreAttachToken(ctx context.Context, claims sessioncontracts.AttachTokenClaims, ttl time.Duration) error {
return s.setJSON(ctx, attachTokenKey(claims.Token), claims, ttl)
}
func (s *RedisLiveStateStore) ConsumeAttachToken(ctx context.Context, token string) (*sessioncontracts.AttachTokenClaims, error) {
key := attachTokenKey(token)
payload, err := s.client.GetDel(ctx, key).Result()
if err == redis.Nil {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("consume attach token: %w", err)
}
var claims sessioncontracts.AttachTokenClaims
if err := json.Unmarshal([]byte(payload), &claims); err != nil {
return nil, fmt.Errorf("decode attach token: %w", err)
}
return &claims, nil
}
func (s *RedisLiveStateStore) TouchAttachmentHeartbeat(ctx context.Context, sessionID, attachmentID string, ttl time.Duration) error {
return s.client.Set(ctx, attachmentHeartbeatKey(sessionID, attachmentID), time.Now().UTC().Format(time.RFC3339Nano), ttl).Err()
}
func (s *RedisLiveStateStore) UpdateWorkerRoute(ctx context.Context, route WorkerRoute, ttl time.Duration) error {
return s.setJSON(ctx, workerRouteKey(route.SessionID), route, ttl)
}
func (s *RedisLiveStateStore) GetWorkerRoute(ctx context.Context, sessionID string) (*WorkerRoute, error) {
var route WorkerRoute
ok, err := s.getJSON(ctx, workerRouteKey(sessionID), &route)
if err != nil || !ok {
return nil, err
}
return &route, nil
}
func (s *RedisLiveStateStore) DeleteWorkerRoute(ctx context.Context, sessionID string) error {
return s.client.Del(ctx, workerRouteKey(sessionID)).Err()
}
func (s *RedisLiveStateStore) setJSON(ctx context.Context, key string, value any, ttl time.Duration) error {
payload, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("encode redis payload: %w", err)
}
if err := s.client.Set(ctx, key, payload, ttl).Err(); err != nil {
return fmt.Errorf("set redis key %s: %w", key, err)
}
return nil
}
func (s *RedisLiveStateStore) getJSON(ctx context.Context, key string, dest any) (bool, error) {
payload, err := s.client.Get(ctx, key).Result()
if err == redis.Nil {
return false, nil
}
if err != nil {
return false, fmt.Errorf("get redis key %s: %w", key, err)
}
if err := json.Unmarshal([]byte(payload), dest); err != nil {
return false, fmt.Errorf("decode redis key %s: %w", key, err)
}
return true, nil
}
func liveSessionKey(sessionID string) string {
return "live:session:" + sessionID
}
func controllerBindingKey(sessionID string) string {
return "live:session:" + sessionID + ":controller"
}
func attachTokenKey(token string) string {
return "live:attach:" + token
}
func attachmentHeartbeatKey(sessionID, attachmentID string) string {
return "live:session:" + sessionID + ":attachment:" + attachmentID + ":heartbeat"
}
func workerRouteKey(sessionID string) string {
return "live:session:" + sessionID + ":worker-route"
}
@@ -0,0 +1,85 @@
package sessionbroker
import (
"context"
"time"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
type RemoteSessionRepository interface {
Create(ctx context.Context, session RemoteSession) error
GetByID(ctx context.Context, sessionID string) (*RemoteSession, error)
GetByIDForUpdate(ctx context.Context, sessionID string) (*RemoteSession, error)
ListByController(ctx context.Context, userID string) ([]RemoteSession, error)
CountLiveByResource(ctx context.Context, resourceID string) (int, error)
ListDetachedExpired(ctx context.Context, before time.Time, limit int) ([]RemoteSession, error)
UpdateState(ctx context.Context, params UpdateRemoteSessionStateParams) error
}
type SessionAttachmentRepository interface {
Create(ctx context.Context, attachment SessionAttachment) error
GetByID(ctx context.Context, attachmentID string) (*SessionAttachment, error)
GetByIDForUpdate(ctx context.Context, attachmentID string) (*SessionAttachment, error)
ListByRemoteSession(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error)
ListActiveByRemoteSessionForUpdate(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error)
UpdateState(ctx context.Context, params UpdateSessionAttachmentStateParams) error
Supersede(ctx context.Context, params SupersedeAttachmentParams) error
}
type ResourcePolicyRepository interface {
GetByResourceID(ctx context.Context, resourceID string) (*ResourcePolicy, error)
Upsert(ctx context.Context, policy ResourcePolicy) error
}
type AuditEventRepository interface {
Create(ctx context.Context, event AuditEvent) error
}
type Store interface {
RemoteSessions() RemoteSessionRepository
SessionAttachments() SessionAttachmentRepository
ResourcePolicies() ResourcePolicyRepository
ResourceRuntime() ResourceRuntimeRepository
AuditEvents() AuditEventRepository
Access() AccessRepository
}
type Transactor interface {
WithinTransaction(ctx context.Context, fn func(store Store) error) error
}
type UpdateRemoteSessionStateParams struct {
RemoteSessionID string
State sessioncontracts.State
WorkerID string
DetachDeadlineAt *time.Time
LastHeartbeatAt *time.Time
TakeoverVersion int
UpdatedAt time.Time
}
type UpdateSessionAttachmentStateParams struct {
AttachmentID string
State AttachmentState
DetachedAt *time.Time
LastInputAt *time.Time
UpdatedAt time.Time
}
type SupersedeAttachmentParams struct {
PreviousAttachmentID string
NextAttachmentID string
DetachedAt time.Time
UpdatedAt time.Time
}
type AccessRepository interface {
IsTrustedDevice(ctx context.Context, userID, deviceID string) (bool, error)
GetPlatformRole(ctx context.Context, userID string) (string, error)
GetOrganizationRole(ctx context.Context, organizationID, userID string) (string, bool, error)
}
type ResourceRuntimeRepository interface {
GetByID(ctx context.Context, resourceID string) (*ResourceRuntimeSpec, error)
}
@@ -0,0 +1,138 @@
package sessionbroker
import (
"context"
"encoding/json"
"errors"
"testing"
"github.com/example/remote-access-platform/backend/internal/platform/config"
"github.com/example/remote-access-platform/backend/internal/platform/module"
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
)
type fakeSecretResolver struct {
response *secrets.ResolvedResourceSecret
err error
request secrets.ResolveResourceSecretRequest
}
func testAppConfig(env string) config.AppConfig {
return config.AppConfig{Name: "rap-api-test", Env: env}
}
func (r *fakeSecretResolver) ResolveForSession(_ context.Context, req secrets.ResolveResourceSecretRequest) (*secrets.ResolvedResourceSecret, error) {
r.request = req
if r.err != nil {
return nil, r.err
}
return r.response, nil
}
func TestRuntimeAssignmentMetadataMergesResolvedSecretWithoutMutatingSessionMetadata(t *testing.T) {
resolver := &fakeSecretResolver{
response: &secrets.ResolvedResourceSecret{
Descriptor: secrets.ResourceSecretDescriptor{Version: 3},
Payload: json.RawMessage(`{"username":"user","password":"secret","domain":"corp"}`),
},
}
service := NewService(module.Dependencies{
Config: module.Config{App: testAppConfig("production")},
}, nil, nil, nil, nil, resolver)
sessionMetadata := mustJSON(t, map[string]any{
"resource": map[string]any{
"id": "resource-1",
"organization_id": "org-1",
"secret_ref": "rap-secret://org/org-1/resources/resource-1/primary",
"metadata": map[string]any{
"rdp_host": "host",
},
},
})
session := RemoteSession{
ID: "session-1",
OrganizationID: "org-1",
ResourceID: "resource-1",
WorkerID: "worker-1",
Metadata: sessionMetadata,
}
metadata, secretRef, version, err := service.runtimeAssignmentMetadata(context.Background(), session, &workercontracts.WorkerLease{LeaseID: "lease-1"})
if err != nil {
t.Fatalf("runtimeAssignmentMetadata returned error: %v", err)
}
if secretRef == "" || version != 3 {
t.Fatalf("expected secret ref and version, got ref=%q version=%d", secretRef, version)
}
resource := metadata["resource"].(map[string]any)
resourceMetadata := resource["metadata"].(map[string]any)
if resourceMetadata["username"] != "user" || resourceMetadata["password"] != "secret" || resourceMetadata["domain"] != "corp" {
t.Fatalf("resolved secret was not merged: %#v", resourceMetadata)
}
var persisted map[string]any
if err := json.Unmarshal(session.Metadata, &persisted); err != nil {
t.Fatalf("decode persisted metadata: %v", err)
}
persistedResource := persisted["resource"].(map[string]any)
persistedMetadata := persistedResource["metadata"].(map[string]any)
if _, ok := persistedMetadata["password"]; ok {
t.Fatalf("session metadata was mutated with plaintext secret")
}
if resolver.request.LeaseID != "lease-1" || resolver.request.WorkerID != "worker-1" {
t.Fatalf("resolver request missed lease/worker proof: %#v", resolver.request)
}
}
func TestRuntimeAssignmentMetadataRequiresResolverInProduction(t *testing.T) {
service := NewService(module.Dependencies{
Config: module.Config{App: testAppConfig("production")},
}, nil, nil, nil, nil)
session := RemoteSession{
ID: "session-1",
OrganizationID: "org-1",
ResourceID: "resource-1",
WorkerID: "worker-1",
Metadata: mustJSON(t, map[string]any{
"resource": map[string]any{
"secret_ref": "rap-secret://org/org-1/resources/resource-1/primary",
},
}),
}
_, _, _, err := service.runtimeAssignmentMetadata(context.Background(), session, &workercontracts.WorkerLease{LeaseID: "lease-1"})
if !errors.Is(err, secrets.ErrSecretEncryptionKeyMissing) {
t.Fatalf("expected missing resolver error, got %v", err)
}
}
func TestRuntimeAssignmentMetadataAllowsDevelopmentMetadataWithoutResolver(t *testing.T) {
service := NewService(module.Dependencies{
Config: module.Config{App: testAppConfig("development")},
}, nil, nil, nil, nil)
session := RemoteSession{
ID: "session-1",
OrganizationID: "org-1",
ResourceID: "resource-1",
WorkerID: "worker-1",
Metadata: mustJSON(t, map[string]any{
"resource": map[string]any{
"secret_ref": "rap-secret://org/org-1/resources/resource-1/primary",
"metadata": map[string]any{
"username": "dev-user",
"password": "dev-password",
},
},
}),
}
metadata, secretRef, _, err := service.runtimeAssignmentMetadata(context.Background(), session, nil)
if err != nil {
t.Fatalf("development metadata should not require resolver: %v", err)
}
if secretRef != "" {
t.Fatalf("development fallback should not audit resolver use, got %q", secretRef)
}
resource := metadata["resource"].(map[string]any)
resourceMetadata := resource["metadata"].(map[string]any)
if resourceMetadata["password"] != "dev-password" {
t.Fatalf("development metadata was not preserved")
}
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,391 @@
package sessionbroker
import (
"context"
"io"
"log/slog"
"testing"
"time"
"github.com/example/remote-access-platform/backend/internal/platform/module"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
)
func TestHandleWorkerConnectedIgnoresTerminalSession(t *testing.T) {
service, store, live, _ := newStaleWorkerEventTestService()
store.remote.sessions["session-1"] = RemoteSession{
ID: "session-1",
State: sessioncontracts.StateTerminated,
WorkerID: "worker-1",
}
if err := service.HandleWorkerConnected(context.Background(), "session-1"); err != nil {
t.Fatalf("HandleWorkerConnected returned error for stale terminal event: %v", err)
}
if got := store.remote.sessions["session-1"].State; got != sessioncontracts.StateTerminated {
t.Fatalf("stale connected event changed terminal state to %q", got)
}
if store.remote.updateCount != 0 {
t.Fatalf("stale connected event updated authoritative session %d times", store.remote.updateCount)
}
if live.upsertCount != 0 {
t.Fatalf("stale connected event recreated live state %d times", live.upsertCount)
}
}
func TestUpdateWorkerRenderTelemetryIgnoresTerminalSession(t *testing.T) {
service, store, live, _ := newStaleWorkerEventTestService()
store.remote.sessions["session-1"] = RemoteSession{
ID: "session-1",
State: sessioncontracts.StateTerminated,
WorkerID: "worker-1",
}
err := service.UpdateWorkerRenderTelemetry(context.Background(), "session-1", map[string]any{
"render_state": "ready",
"width": 1280,
"height": 720,
"frame_sequence": int64(99),
"frame_data": "stale-frame",
})
if err != nil {
t.Fatalf("UpdateWorkerRenderTelemetry returned error for stale terminal event: %v", err)
}
if live.upsertCount != 0 {
t.Fatalf("stale render event recreated live state %d times", live.upsertCount)
}
if live.sessions["session-1"] != nil {
t.Fatalf("stale render event left live state behind: %#v", live.sessions["session-1"])
}
}
func TestMarkSessionFailedTransitionsActiveSession(t *testing.T) {
service, store, live, orchestrator := newStaleWorkerEventTestService()
store.remote.sessions["session-1"] = RemoteSession{
ID: "session-1",
State: sessioncontracts.StateActive,
WorkerID: "worker-1",
TakeoverVersion: 3,
}
store.attachments.items["attachment-1"] = SessionAttachment{
ID: "attachment-1",
RemoteSessionID: "session-1",
State: AttachmentStateActive,
}
live.sessions["session-1"] = &LiveSessionState{SessionID: "session-1", State: sessioncontracts.StateActive}
if err := service.MarkSessionFailed(context.Background(), MarkSessionFailedCommand{SessionID: "session-1", Reason: "worker_lost"}); err != nil {
t.Fatalf("MarkSessionFailed returned error: %v", err)
}
if got := store.remote.sessions["session-1"].State; got != sessioncontracts.StateFailed {
t.Fatalf("expected failed state, got %q", got)
}
if got := store.attachments.items["attachment-1"].State; got != AttachmentStateClosed {
t.Fatalf("expected attachment closed, got %q", got)
}
if store.audit.createCount != 1 {
t.Fatalf("expected one audit event, got %d", store.audit.createCount)
}
if live.sessions["session-1"] != nil {
t.Fatal("expected failed session live state to be deleted")
}
if orchestrator.releaseCount != 1 {
t.Fatalf("expected session lease release, got %d", orchestrator.releaseCount)
}
}
func TestMarkSessionFailedAlreadyFailedIsIdempotent(t *testing.T) {
service, store, _, _ := newStaleWorkerEventTestService()
store.remote.sessions["session-1"] = RemoteSession{
ID: "session-1",
State: sessioncontracts.StateFailed,
WorkerID: "worker-1",
TakeoverVersion: 1,
}
if err := service.MarkSessionFailed(context.Background(), MarkSessionFailedCommand{SessionID: "session-1", Reason: "duplicate_worker_failure"}); err != nil {
t.Fatalf("duplicate MarkSessionFailed returned error: %v", err)
}
if got := store.remote.sessions["session-1"].State; got != sessioncontracts.StateFailed {
t.Fatalf("duplicate terminal event changed state to %q", got)
}
}
func newStaleWorkerEventTestService() (*Service, *staleWorkerEventTestStore, *staleWorkerEventLiveState, *staleWorkerEventOrchestrator) {
store := &staleWorkerEventTestStore{
remote: &staleWorkerEventRemoteSessions{sessions: map[string]RemoteSession{}},
attachments: &staleWorkerEventAttachments{items: map[string]SessionAttachment{}},
policies: &staleWorkerEventPolicies{},
audit: &staleWorkerEventAudit{},
}
live := &staleWorkerEventLiveState{sessions: map[string]*LiveSessionState{}}
orchestrator := &staleWorkerEventOrchestrator{}
service := NewService(module.Dependencies{
Infra: module.Infra{Logger: slog.New(slog.NewTextHandler(io.Discard, nil))},
}, store, staleWorkerEventTransactor{store: store}, live, orchestrator)
service.now = func() time.Time { return time.Unix(100, 0).UTC() }
return service, store, live, orchestrator
}
type staleWorkerEventTransactor struct {
store Store
}
func (t staleWorkerEventTransactor) WithinTransaction(ctx context.Context, fn func(store Store) error) error {
return fn(t.store)
}
type staleWorkerEventTestStore struct {
remote *staleWorkerEventRemoteSessions
attachments *staleWorkerEventAttachments
policies *staleWorkerEventPolicies
audit *staleWorkerEventAudit
}
func (s *staleWorkerEventTestStore) RemoteSessions() RemoteSessionRepository { return s.remote }
func (s *staleWorkerEventTestStore) SessionAttachments() SessionAttachmentRepository {
return s.attachments
}
func (s *staleWorkerEventTestStore) ResourcePolicies() ResourcePolicyRepository { return s.policies }
func (s *staleWorkerEventTestStore) ResourceRuntime() ResourceRuntimeRepository {
return staleWorkerEventResourceRuntime{}
}
func (s *staleWorkerEventTestStore) AuditEvents() AuditEventRepository { return s.audit }
func (s *staleWorkerEventTestStore) Access() AccessRepository { return staleWorkerEventAccess{} }
type staleWorkerEventRemoteSessions struct {
sessions map[string]RemoteSession
updateCount int
}
func (r *staleWorkerEventRemoteSessions) Create(_ context.Context, session RemoteSession) error {
r.sessions[session.ID] = session
return nil
}
func (r *staleWorkerEventRemoteSessions) GetByID(_ context.Context, sessionID string) (*RemoteSession, error) {
session, ok := r.sessions[sessionID]
if !ok {
return nil, nil
}
return &session, nil
}
func (r *staleWorkerEventRemoteSessions) GetByIDForUpdate(ctx context.Context, sessionID string) (*RemoteSession, error) {
return r.GetByID(ctx, sessionID)
}
func (r *staleWorkerEventRemoteSessions) ListByController(_ context.Context, _ string) ([]RemoteSession, error) {
return nil, nil
}
func (r *staleWorkerEventRemoteSessions) CountLiveByResource(_ context.Context, _ string) (int, error) {
return 0, nil
}
func (r *staleWorkerEventRemoteSessions) ListDetachedExpired(_ context.Context, _ time.Time, _ int) ([]RemoteSession, error) {
return nil, nil
}
func (r *staleWorkerEventRemoteSessions) UpdateState(_ context.Context, params UpdateRemoteSessionStateParams) error {
session := r.sessions[params.RemoteSessionID]
session.State = params.State
session.WorkerID = params.WorkerID
session.DetachDeadlineAt = params.DetachDeadlineAt
session.LastHeartbeatAt = params.LastHeartbeatAt
session.TakeoverVersion = params.TakeoverVersion
session.UpdatedAt = params.UpdatedAt
r.sessions[params.RemoteSessionID] = session
r.updateCount++
return nil
}
type staleWorkerEventAttachments struct {
items map[string]SessionAttachment
}
func (r *staleWorkerEventAttachments) Create(_ context.Context, attachment SessionAttachment) error {
r.items[attachment.ID] = attachment
return nil
}
func (r *staleWorkerEventAttachments) GetByID(_ context.Context, attachmentID string) (*SessionAttachment, error) {
attachment, ok := r.items[attachmentID]
if !ok {
return nil, nil
}
return &attachment, nil
}
func (r *staleWorkerEventAttachments) GetByIDForUpdate(ctx context.Context, attachmentID string) (*SessionAttachment, error) {
return r.GetByID(ctx, attachmentID)
}
func (r *staleWorkerEventAttachments) ListByRemoteSession(_ context.Context, remoteSessionID string) ([]SessionAttachment, error) {
attachments := make([]SessionAttachment, 0)
for _, attachment := range r.items {
if attachment.RemoteSessionID == remoteSessionID {
attachments = append(attachments, attachment)
}
}
return attachments, nil
}
func (r *staleWorkerEventAttachments) ListActiveByRemoteSessionForUpdate(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error) {
return r.ListByRemoteSession(ctx, remoteSessionID)
}
func (r *staleWorkerEventAttachments) UpdateState(_ context.Context, params UpdateSessionAttachmentStateParams) error {
attachment := r.items[params.AttachmentID]
attachment.State = params.State
attachment.DetachedAt = params.DetachedAt
attachment.LastInputAt = params.LastInputAt
attachment.UpdatedAt = params.UpdatedAt
r.items[params.AttachmentID] = attachment
return nil
}
func (r *staleWorkerEventAttachments) Supersede(_ context.Context, params SupersedeAttachmentParams) error {
attachment := r.items[params.PreviousAttachmentID]
attachment.State = AttachmentStateSuperseded
attachment.SupersededBy = &params.NextAttachmentID
attachment.DetachedAt = &params.DetachedAt
attachment.UpdatedAt = params.UpdatedAt
r.items[params.PreviousAttachmentID] = attachment
return nil
}
type staleWorkerEventPolicies struct{}
func (r *staleWorkerEventPolicies) GetByResourceID(_ context.Context, _ string) (*ResourcePolicy, error) {
return nil, nil
}
func (r *staleWorkerEventPolicies) Upsert(_ context.Context, _ ResourcePolicy) error {
return nil
}
type staleWorkerEventAudit struct {
createCount int
}
func (r *staleWorkerEventAudit) Create(_ context.Context, _ AuditEvent) error {
r.createCount++
return nil
}
type staleWorkerEventResourceRuntime struct{}
func (staleWorkerEventResourceRuntime) GetByID(_ context.Context, _ string) (*ResourceRuntimeSpec, error) {
return nil, nil
}
type staleWorkerEventAccess struct{}
func (staleWorkerEventAccess) IsTrustedDevice(_ context.Context, _, _ string) (bool, error) {
return false, nil
}
func (staleWorkerEventAccess) GetPlatformRole(_ context.Context, _ string) (string, error) {
return "", nil
}
func (staleWorkerEventAccess) GetOrganizationRole(_ context.Context, _, _ string) (string, bool, error) {
return "", false, nil
}
type staleWorkerEventLiveState struct {
sessions map[string]*LiveSessionState
upsertCount int
}
func (s *staleWorkerEventLiveState) UpsertSession(_ context.Context, state LiveSessionState) error {
copied := state
s.sessions[state.SessionID] = &copied
s.upsertCount++
return nil
}
func (s *staleWorkerEventLiveState) GetSession(_ context.Context, sessionID string) (*LiveSessionState, error) {
state := s.sessions[sessionID]
if state == nil {
return nil, nil
}
copied := *state
return &copied, nil
}
func (s *staleWorkerEventLiveState) DeleteSession(_ context.Context, sessionID string) error {
delete(s.sessions, sessionID)
return nil
}
func (s *staleWorkerEventLiveState) BindController(_ context.Context, _ sessioncontracts.ControllerBinding, _ time.Duration) error {
return nil
}
func (s *staleWorkerEventLiveState) GetControllerBinding(_ context.Context, _ string) (*sessioncontracts.ControllerBinding, error) {
return nil, nil
}
func (s *staleWorkerEventLiveState) ClearControllerBinding(_ context.Context, _ string) error {
return nil
}
func (s *staleWorkerEventLiveState) StoreAttachToken(_ context.Context, _ sessioncontracts.AttachTokenClaims, _ time.Duration) error {
return nil
}
func (s *staleWorkerEventLiveState) ConsumeAttachToken(_ context.Context, _ string) (*sessioncontracts.AttachTokenClaims, error) {
return nil, nil
}
func (s *staleWorkerEventLiveState) TouchAttachmentHeartbeat(_ context.Context, _, _ string, _ time.Duration) error {
return nil
}
func (s *staleWorkerEventLiveState) UpdateWorkerRoute(_ context.Context, _ WorkerRoute, _ time.Duration) error {
return nil
}
func (s *staleWorkerEventLiveState) GetWorkerRoute(_ context.Context, _ string) (*WorkerRoute, error) {
return nil, nil
}
func (s *staleWorkerEventLiveState) DeleteWorkerRoute(_ context.Context, _ string) error {
return nil
}
type staleWorkerEventOrchestrator struct {
releaseCount int
}
func (o *staleWorkerEventOrchestrator) Reserve(_ context.Context, _ workercontracts.AttachRequest) (*workercontracts.WorkerLease, error) {
return nil, nil
}
func (o *staleWorkerEventOrchestrator) GetSessionLease(_ context.Context, _ string) (*workercontracts.WorkerLease, error) {
return nil, nil
}
func (o *staleWorkerEventOrchestrator) ReleaseSessionLease(_ context.Context, _ string) error {
o.releaseCount++
return nil
}
func (o *staleWorkerEventOrchestrator) PrepareAttachment(_ context.Context, _ RemoteSession, _ SessionAttachment, _ map[string]any) error {
return nil
}
func (o *staleWorkerEventOrchestrator) NotifyDetachment(_ context.Context, _ RemoteSession, _ SessionAttachment) error {
return nil
}
func (o *staleWorkerEventOrchestrator) TerminateRemoteSession(_ context.Context, _, _ string) error {
return nil
}
func (o *staleWorkerEventOrchestrator) ValidateSessionRuntime(_ context.Context, _, _ string) (bool, string, error) {
return true, "", nil
}
@@ -0,0 +1,44 @@
package sessionbroker
import (
"fmt"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
var allowedTransitions = map[sessioncontracts.State]map[sessioncontracts.State]struct{}{
sessioncontracts.StateStarting: {
sessioncontracts.StateActive: {},
sessioncontracts.StateFailed: {},
sessioncontracts.StateTerminated: {},
},
sessioncontracts.StateActive: {
sessioncontracts.StateDetached: {},
sessioncontracts.StateReconnecting: {},
sessioncontracts.StateFailed: {},
sessioncontracts.StateTerminated: {},
},
sessioncontracts.StateDetached: {
sessioncontracts.StateReconnecting: {},
sessioncontracts.StateTerminated: {},
sessioncontracts.StateFailed: {},
},
sessioncontracts.StateReconnecting: {
sessioncontracts.StateActive: {},
sessioncontracts.StateDetached: {},
sessioncontracts.StateFailed: {},
sessioncontracts.StateTerminated: {},
},
}
func validateTransition(from, to sessioncontracts.State) error {
if from == to {
return nil
}
if allowed, ok := allowedTransitions[from]; ok {
if _, ok := allowed[to]; ok {
return nil
}
}
return fmt.Errorf("invalid session state transition: %s -> %s", from, to)
}