Initial project snapshot
This commit is contained in:
@@ -0,0 +1,219 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
|
||||
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
|
||||
)
|
||||
|
||||
const (
|
||||
directWorkerTLSTrustModeSmokeInsecure = "smoke_insecure"
|
||||
directWorkerTLSTrustModePublicCA = "public_ca"
|
||||
directWorkerTLSTrustModePlatformCA = "platform_ca"
|
||||
)
|
||||
|
||||
type DataPlaneTokenClaims struct {
|
||||
SessionID string `json:"session_id"`
|
||||
AttachmentID string `json:"attachment_id"`
|
||||
UserID string `json:"user_id"`
|
||||
OrganizationID string `json:"organization_id"`
|
||||
ClusterID string `json:"cluster_id,omitempty"`
|
||||
WorkerID string `json:"worker_id"`
|
||||
ResourceID string `json:"resource_id"`
|
||||
AllowedChannels []string `json:"allowed_channels"`
|
||||
ExpiresAtValue time.Time `json:"expires_at"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
func (s *Service) buildDataPlaneOffer(session RemoteSession, attachment SessionAttachment) (*sessioncontracts.DataPlaneOffer, error) {
|
||||
if s.cfg.DataPlane.TokenTTL <= 0 || s.cfg.DataPlane.TokenPrivateKeyPEM == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
now := s.now().UTC()
|
||||
expiresAt := now.Add(s.cfg.DataPlane.TokenTTL)
|
||||
allowedChannels := dataPlaneAllowedChannelsFromSession(session)
|
||||
jti := uuid.NewString()
|
||||
claims := DataPlaneTokenClaims{
|
||||
SessionID: session.ID,
|
||||
AttachmentID: attachment.ID,
|
||||
UserID: attachment.UserID,
|
||||
OrganizationID: session.OrganizationID,
|
||||
WorkerID: session.WorkerID,
|
||||
ResourceID: session.ResourceID,
|
||||
AllowedChannels: allowedChannels,
|
||||
ExpiresAtValue: expiresAt,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ID: jti,
|
||||
Issuer: s.cfg.Auth.Issuer,
|
||||
Subject: attachment.UserID,
|
||||
Audience: jwt.ClaimStrings{"rap-data-plane", "worker:" + session.WorkerID},
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
||||
},
|
||||
}
|
||||
token, err := signDataPlaneToken(claims, s.cfg.DataPlane.TokenPrivateKeyPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
candidates := s.buildDataPlaneCandidates(session)
|
||||
preferred := sessioncontracts.DataPlaneCandidateBackendGateway
|
||||
if len(candidates) > 0 {
|
||||
preferred = candidates[0].Type
|
||||
}
|
||||
|
||||
return &sessioncontracts.DataPlaneOffer{
|
||||
Preferred: preferred,
|
||||
Token: token,
|
||||
ExpiresAt: expiresAt,
|
||||
Candidates: candidates,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) buildDataPlaneCandidates(session RemoteSession) []sessioncontracts.DataPlaneCandidate {
|
||||
var candidates []sessioncontracts.DataPlaneCandidate
|
||||
if directURL := s.directWorkerWSSURL(session.WorkerID); directURL != "" && s.canAdvertiseDirectWorkerWSS() {
|
||||
metadata := map[string]any(nil)
|
||||
if s.cfg.DataPlane.DirectWorkerJSONRuntime {
|
||||
metadata = map[string]any{
|
||||
"runtime_transport": "json_v1",
|
||||
"traffic_ready": true,
|
||||
}
|
||||
s.addDirectWorkerTLSTrustMetadata(metadata)
|
||||
if s.cfg.DataPlane.DirectWorkerBinaryRender {
|
||||
metadata["render_transport"] = "binary_v1"
|
||||
metadata["binary_render"] = true
|
||||
metadata["supported_color_modes"] = []string{"full_color", "grayscale"}
|
||||
metadata["default_color_mode"] = "full_color"
|
||||
}
|
||||
}
|
||||
candidates = append(candidates, sessioncontracts.DataPlaneCandidate{
|
||||
Type: sessioncontracts.DataPlaneCandidateDirectWorkerWSS,
|
||||
URL: directURL,
|
||||
WorkerID: session.WorkerID,
|
||||
Priority: 10,
|
||||
Metadata: metadata,
|
||||
})
|
||||
}
|
||||
if s.cfg.DataPlane.BackendGatewayURL != "" {
|
||||
candidates = append(candidates, sessioncontracts.DataPlaneCandidate{
|
||||
Type: sessioncontracts.DataPlaneCandidateBackendGateway,
|
||||
URL: s.cfg.DataPlane.BackendGatewayURL,
|
||||
Priority: 100,
|
||||
})
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
func (s *Service) canAdvertiseDirectWorkerWSS() bool {
|
||||
trustMode := normalizeDirectWorkerTLSTrustMode(s.cfg.DataPlane.DirectWorkerTLSTrustMode)
|
||||
return !secrets.IsProductionEnv(s.cfg.App.Env) || directWorkerTLSTrustModeIsProductionTrusted(trustMode)
|
||||
}
|
||||
|
||||
func (s *Service) addDirectWorkerTLSTrustMetadata(metadata map[string]any) {
|
||||
trustMode := normalizeDirectWorkerTLSTrustMode(s.cfg.DataPlane.DirectWorkerTLSTrustMode)
|
||||
metadata["tls_trust_mode"] = trustMode
|
||||
metadata["production_trusted"] = directWorkerTLSTrustModeIsProductionTrusted(trustMode)
|
||||
metadata["smoke_only"] = trustMode == directWorkerTLSTrustModeSmokeInsecure
|
||||
if s.cfg.DataPlane.DirectWorkerTLSCARef != "" {
|
||||
metadata["tls_ca_ref"] = s.cfg.DataPlane.DirectWorkerTLSCARef
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeDirectWorkerTLSTrustMode(mode string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||
case directWorkerTLSTrustModePublicCA:
|
||||
return directWorkerTLSTrustModePublicCA
|
||||
case directWorkerTLSTrustModePlatformCA:
|
||||
return directWorkerTLSTrustModePlatformCA
|
||||
default:
|
||||
return directWorkerTLSTrustModeSmokeInsecure
|
||||
}
|
||||
}
|
||||
|
||||
func directWorkerTLSTrustModeIsProductionTrusted(mode string) bool {
|
||||
return mode == directWorkerTLSTrustModePublicCA || mode == directWorkerTLSTrustModePlatformCA
|
||||
}
|
||||
|
||||
func (s *Service) directWorkerWSSURL(workerID string) string {
|
||||
template := strings.TrimSpace(s.cfg.DataPlane.DirectWorkerWSSURLTemplate)
|
||||
if template == "" || workerID == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.ReplaceAll(template, "{worker_id}", workerID)
|
||||
}
|
||||
|
||||
func signDataPlaneToken(claims DataPlaneTokenClaims, privateKeyPEM string) (string, error) {
|
||||
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(privateKeyPEM))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse data-plane private key: %w", err)
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
signed, err := token.SignedString(privateKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("sign data-plane token: %w", err)
|
||||
}
|
||||
return signed, nil
|
||||
}
|
||||
|
||||
func parseDataPlaneToken(tokenValue string, publicKey *rsa.PublicKey) (*DataPlaneTokenClaims, error) {
|
||||
claims := &DataPlaneTokenClaims{}
|
||||
token, err := jwt.ParseWithClaims(tokenValue, claims, func(token *jwt.Token) (any, error) {
|
||||
if token.Method != jwt.SigningMethodRS256 {
|
||||
return nil, fmt.Errorf("unexpected data-plane signing method: %s", token.Header["alg"])
|
||||
}
|
||||
return publicKey, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !token.Valid {
|
||||
return nil, fmt.Errorf("data-plane token invalid")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func dataPlaneAllowedChannelsFromSession(session RemoteSession) []string {
|
||||
channels := []string{
|
||||
sessioncontracts.DataPlaneChannelControl,
|
||||
sessioncontracts.DataPlaneChannelInput,
|
||||
sessioncontracts.DataPlaneChannelRender,
|
||||
sessioncontracts.DataPlaneChannelTelemetry,
|
||||
}
|
||||
metadata := decodeJSONMap(session.Metadata)
|
||||
policy, _ := metadata["policy"].(map[string]any)
|
||||
if policy != nil {
|
||||
if mode, _ := policy["clipboard_mode"].(string); mode != "" && mode != string(ResourceClipboardModeDisabled) {
|
||||
channels = append(channels, sessioncontracts.DataPlaneChannelClipboard)
|
||||
}
|
||||
if mode, _ := policy["file_transfer_mode"].(string); fileTransferAllowsClientToServer(ResourceFileTransferMode(mode)) {
|
||||
channels = append(channels, sessioncontracts.DataPlaneChannelFileUpload)
|
||||
}
|
||||
if mode, _ := policy["file_transfer_mode"].(string); fileTransferAllowsServerToClient(ResourceFileTransferMode(mode)) {
|
||||
channels = append(channels, sessioncontracts.DataPlaneChannelFileDownload)
|
||||
}
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
func (s *Service) attachDataPlaneOffer(result *SessionControlResult) error {
|
||||
if result == nil || result.Attachment == nil {
|
||||
return nil
|
||||
}
|
||||
result.GatewayURL = s.cfg.DataPlane.BackendGatewayURL
|
||||
offer, err := s.buildDataPlaneOffer(result.Session, *result.Attachment)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result.DataPlane = offer
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,357 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/config"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/module"
|
||||
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
|
||||
)
|
||||
|
||||
func TestDataPlaneTokenScopeValidation(t *testing.T) {
|
||||
now := time.Now().UTC().Truncate(time.Second)
|
||||
privateKeyPEM, publicKey := testRS256Key(t)
|
||||
service := &Service{
|
||||
cfg: module.Config{
|
||||
Auth: config.AuthConfig{
|
||||
Issuer: "rap-api-test",
|
||||
},
|
||||
DataPlane: config.DataPlaneConfig{
|
||||
TokenTTL: time.Minute,
|
||||
TokenPrivateKeyPEM: privateKeyPEM,
|
||||
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
|
||||
},
|
||||
},
|
||||
now: func() time.Time { return now },
|
||||
}
|
||||
session := RemoteSession{
|
||||
ID: "session-1",
|
||||
OrganizationID: "org-1",
|
||||
ResourceID: "resource-1",
|
||||
WorkerID: "worker-1",
|
||||
Metadata: mustJSON(t, map[string]any{"policy": map[string]any{"clipboard_mode": "bidirectional", "file_transfer_mode": "client_to_server"}}),
|
||||
}
|
||||
attachment := SessionAttachment{
|
||||
ID: "attachment-1",
|
||||
UserID: "user-1",
|
||||
}
|
||||
|
||||
offer, err := service.buildDataPlaneOffer(session, attachment)
|
||||
if err != nil {
|
||||
t.Fatalf("buildDataPlaneOffer returned error: %v", err)
|
||||
}
|
||||
if offer == nil {
|
||||
t.Fatal("expected data-plane offer")
|
||||
}
|
||||
|
||||
claims, err := parseDataPlaneToken(offer.Token, publicKey)
|
||||
if err != nil {
|
||||
t.Fatalf("parseDataPlaneToken returned error: %v", err)
|
||||
}
|
||||
assertEqual(t, claims.SessionID, session.ID, "session_id")
|
||||
assertEqual(t, claims.AttachmentID, attachment.ID, "attachment_id")
|
||||
assertEqual(t, claims.UserID, attachment.UserID, "user_id")
|
||||
assertEqual(t, claims.OrganizationID, session.OrganizationID, "organization_id")
|
||||
assertEqual(t, claims.WorkerID, session.WorkerID, "worker_id")
|
||||
assertEqual(t, claims.ResourceID, session.ResourceID, "resource_id")
|
||||
if claims.ID == "" {
|
||||
t.Fatal("expected jti")
|
||||
}
|
||||
if claims.ExpiresAt == nil || !claims.ExpiresAt.Time.Equal(now.Add(time.Minute)) {
|
||||
t.Fatalf("unexpected expires_at: %v", claims.ExpiresAt)
|
||||
}
|
||||
if !claims.ExpiresAtValue.Equal(now.Add(time.Minute)) {
|
||||
t.Fatalf("unexpected expires_at claim value: %v", claims.ExpiresAtValue)
|
||||
}
|
||||
for _, channel := range []string{
|
||||
sessioncontracts.DataPlaneChannelControl,
|
||||
sessioncontracts.DataPlaneChannelInput,
|
||||
sessioncontracts.DataPlaneChannelRender,
|
||||
sessioncontracts.DataPlaneChannelTelemetry,
|
||||
sessioncontracts.DataPlaneChannelClipboard,
|
||||
sessioncontracts.DataPlaneChannelFileUpload,
|
||||
} {
|
||||
if !slices.Contains(claims.AllowedChannels, channel) {
|
||||
t.Fatalf("expected allowed channel %q in %v", channel, claims.AllowedChannels)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataPlaneOfferResponseShapeCompatibility(t *testing.T) {
|
||||
now := time.Now().UTC().Truncate(time.Second)
|
||||
privateKeyPEM, _ := testRS256Key(t)
|
||||
service := &Service{
|
||||
cfg: module.Config{
|
||||
Auth: config.AuthConfig{Issuer: "rap-api-test"},
|
||||
DataPlane: config.DataPlaneConfig{
|
||||
TokenTTL: time.Minute,
|
||||
TokenPrivateKeyPEM: privateKeyPEM,
|
||||
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
|
||||
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
|
||||
DirectWorkerJSONRuntime: true,
|
||||
DirectWorkerTLSTrustMode: "smoke_insecure",
|
||||
},
|
||||
},
|
||||
now: func() time.Time { return now },
|
||||
}
|
||||
result := &SessionControlResult{
|
||||
Session: RemoteSession{
|
||||
ID: "session-1",
|
||||
OrganizationID: "org-1",
|
||||
ResourceID: "resource-1",
|
||||
WorkerID: "worker-1",
|
||||
Metadata: mustJSON(t, map[string]any{"policy": map[string]any{"clipboard_mode": "disabled", "file_transfer_mode": "disabled"}}),
|
||||
},
|
||||
Attachment: &SessionAttachment{ID: "attachment-1", UserID: "user-1"},
|
||||
AttachToken: &sessioncontracts.AttachTokenClaims{
|
||||
Token: "existing-attach-token",
|
||||
SessionID: "session-1",
|
||||
AttachmentID: "attachment-1",
|
||||
UserID: "user-1",
|
||||
WorkerID: "worker-1",
|
||||
ExpiresAt: now.Add(2 * time.Minute),
|
||||
},
|
||||
}
|
||||
|
||||
if err := service.attachDataPlaneOffer(result); err != nil {
|
||||
t.Fatalf("attachDataPlaneOffer returned error: %v", err)
|
||||
}
|
||||
payload, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal response: %v", err)
|
||||
}
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(payload, &decoded); err != nil {
|
||||
t.Fatalf("decode response: %v", err)
|
||||
}
|
||||
if decoded["session"] == nil || decoded["attachment"] == nil || decoded["attach_token"] == nil {
|
||||
t.Fatalf("response lost existing fields: %s", payload)
|
||||
}
|
||||
if decoded["data_plane"] == nil || decoded["gateway_url"] == nil {
|
||||
t.Fatalf("response missing data-plane fields: %s", payload)
|
||||
}
|
||||
if result.DataPlane == nil {
|
||||
t.Fatal("expected data-plane offer")
|
||||
}
|
||||
if result.DataPlane.Preferred != sessioncontracts.DataPlaneCandidateDirectWorkerWSS {
|
||||
t.Fatalf("unexpected preferred candidate: %s", result.DataPlane.Preferred)
|
||||
}
|
||||
if len(result.DataPlane.Candidates) != 2 {
|
||||
t.Fatalf("expected direct and fallback candidates, got %d", len(result.DataPlane.Candidates))
|
||||
}
|
||||
if result.DataPlane.Candidates[0].URL != "wss://worker-1.worker.example.test/rap/v1/data-plane" {
|
||||
t.Fatalf("unexpected direct candidate URL: %s", result.DataPlane.Candidates[0].URL)
|
||||
}
|
||||
if result.DataPlane.Candidates[0].Metadata["runtime_transport"] != "json_v1" {
|
||||
t.Fatalf("direct candidate is missing json_v1 runtime metadata: %#v", result.DataPlane.Candidates[0].Metadata)
|
||||
}
|
||||
if result.DataPlane.Candidates[0].Metadata["traffic_ready"] != true {
|
||||
t.Fatalf("direct candidate is missing traffic_ready metadata: %#v", result.DataPlane.Candidates[0].Metadata)
|
||||
}
|
||||
if result.DataPlane.Candidates[0].Metadata["smoke_only"] != true {
|
||||
t.Fatalf("direct candidate should be marked smoke-only by default: %#v", result.DataPlane.Candidates[0].Metadata)
|
||||
}
|
||||
if result.DataPlane.Candidates[0].Metadata["production_trusted"] != false {
|
||||
t.Fatalf("smoke direct candidate must not be production-trusted: %#v", result.DataPlane.Candidates[0].Metadata)
|
||||
}
|
||||
if !strings.Contains(result.DataPlane.Candidates[1].URL, "/api/v1/gateway/ws") {
|
||||
t.Fatalf("unexpected backend candidate URL: %s", result.DataPlane.Candidates[1].URL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataPlaneDirectCandidateMetadataRequiresRuntimeFlag(t *testing.T) {
|
||||
service := &Service{
|
||||
cfg: module.Config{
|
||||
DataPlane: config.DataPlaneConfig{
|
||||
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
|
||||
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
|
||||
DirectWorkerTLSTrustMode: "smoke_insecure",
|
||||
},
|
||||
},
|
||||
}
|
||||
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
|
||||
if len(candidates) != 2 {
|
||||
t.Fatalf("expected direct and fallback candidates, got %d", len(candidates))
|
||||
}
|
||||
if candidates[0].Metadata != nil {
|
||||
t.Fatalf("direct candidate must not advertise json_v1 before runtime flag is enabled: %#v", candidates[0].Metadata)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataPlaneDirectCandidateAdvertisesBinaryRenderOnlyWhenEnabled(t *testing.T) {
|
||||
service := &Service{
|
||||
cfg: module.Config{
|
||||
DataPlane: config.DataPlaneConfig{
|
||||
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
|
||||
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
|
||||
DirectWorkerJSONRuntime: true,
|
||||
DirectWorkerBinaryRender: true,
|
||||
DirectWorkerTLSTrustMode: "platform_ca",
|
||||
DirectWorkerTLSCARef: "rap-platform-ca:v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
|
||||
if len(candidates) != 2 {
|
||||
t.Fatalf("expected direct and fallback candidates, got %d", len(candidates))
|
||||
}
|
||||
if candidates[0].Metadata["render_transport"] != "binary_v1" {
|
||||
t.Fatalf("direct candidate is missing binary render metadata: %#v", candidates[0].Metadata)
|
||||
}
|
||||
if candidates[0].Metadata["binary_render"] != true {
|
||||
t.Fatalf("direct candidate is missing binary_render metadata: %#v", candidates[0].Metadata)
|
||||
}
|
||||
if candidates[0].Metadata["default_color_mode"] != "full_color" {
|
||||
t.Fatalf("direct candidate is missing default_color_mode metadata: %#v", candidates[0].Metadata)
|
||||
}
|
||||
if candidates[0].Metadata["production_trusted"] != true || candidates[0].Metadata["tls_trust_mode"] != "platform_ca" {
|
||||
t.Fatalf("direct candidate is missing production trust metadata: %#v", candidates[0].Metadata)
|
||||
}
|
||||
if candidates[0].Metadata["tls_ca_ref"] != "rap-platform-ca:v1" {
|
||||
t.Fatalf("direct candidate is missing tls_ca_ref metadata: %#v", candidates[0].Metadata)
|
||||
}
|
||||
modes, ok := candidates[0].Metadata["supported_color_modes"].([]string)
|
||||
if !ok || !slices.Contains(modes, "full_color") || !slices.Contains(modes, "grayscale") {
|
||||
t.Fatalf("direct candidate is missing supported_color_modes metadata: %#v", candidates[0].Metadata)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataPlaneDirectCandidateOmittedInProductionWhenSmokeOnly(t *testing.T) {
|
||||
service := &Service{
|
||||
cfg: module.Config{
|
||||
App: config.AppConfig{Env: "production"},
|
||||
DataPlane: config.DataPlaneConfig{
|
||||
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
|
||||
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
|
||||
DirectWorkerJSONRuntime: true,
|
||||
DirectWorkerTLSTrustMode: "smoke_insecure",
|
||||
},
|
||||
},
|
||||
}
|
||||
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
|
||||
if len(candidates) != 1 {
|
||||
t.Fatalf("expected fallback-only candidates in production with smoke TLS, got %d", len(candidates))
|
||||
}
|
||||
if candidates[0].Type != sessioncontracts.DataPlaneCandidateBackendGateway {
|
||||
t.Fatalf("production must not advertise smoke-only direct candidate: %#v", candidates)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataPlaneDirectCandidateAdvertisedInProductionWhenTrusted(t *testing.T) {
|
||||
service := &Service{
|
||||
cfg: module.Config{
|
||||
App: config.AppConfig{Env: "production"},
|
||||
DataPlane: config.DataPlaneConfig{
|
||||
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
|
||||
DirectWorkerWSSURLTemplate: "wss://{worker_id}.worker.example.test/rap/v1/data-plane",
|
||||
DirectWorkerJSONRuntime: true,
|
||||
DirectWorkerTLSTrustMode: "public_ca",
|
||||
},
|
||||
},
|
||||
}
|
||||
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
|
||||
if len(candidates) != 2 {
|
||||
t.Fatalf("expected trusted direct and fallback candidates, got %d", len(candidates))
|
||||
}
|
||||
if candidates[0].Metadata["production_trusted"] != true || candidates[0].Metadata["tls_trust_mode"] != "public_ca" {
|
||||
t.Fatalf("trusted production direct candidate metadata mismatch: %#v", candidates[0].Metadata)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataPlaneCandidatesFallbackOnlyWhenDirectTemplateMissing(t *testing.T) {
|
||||
service := &Service{
|
||||
cfg: module.Config{
|
||||
DataPlane: config.DataPlaneConfig{
|
||||
BackendGatewayURL: "wss://backend.example.test/api/v1/gateway/ws",
|
||||
},
|
||||
},
|
||||
}
|
||||
candidates := service.buildDataPlaneCandidates(RemoteSession{WorkerID: "worker-1"})
|
||||
if len(candidates) != 1 {
|
||||
t.Fatalf("expected fallback-only candidate list, got %d", len(candidates))
|
||||
}
|
||||
if candidates[0].Type != sessioncontracts.DataPlaneCandidateBackendGateway {
|
||||
t.Fatalf("unexpected candidate type: %s", candidates[0].Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataPlaneAllowedChannelsRespectRuntimePolicy(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
policy map[string]any
|
||||
expected []string
|
||||
blocked []string
|
||||
}{
|
||||
{
|
||||
name: "disabled policies expose only control input render telemetry",
|
||||
policy: map[string]any{"clipboard_mode": "disabled", "file_transfer_mode": "disabled"},
|
||||
expected: []string{sessioncontracts.DataPlaneChannelControl, sessioncontracts.DataPlaneChannelInput, sessioncontracts.DataPlaneChannelRender, sessioncontracts.DataPlaneChannelTelemetry},
|
||||
blocked: []string{sessioncontracts.DataPlaneChannelClipboard, sessioncontracts.DataPlaneChannelFileUpload},
|
||||
},
|
||||
{
|
||||
name: "clipboard policy adds clipboard channel",
|
||||
policy: map[string]any{"clipboard_mode": "server_to_client", "file_transfer_mode": "disabled"},
|
||||
expected: []string{sessioncontracts.DataPlaneChannelClipboard},
|
||||
blocked: []string{sessioncontracts.DataPlaneChannelFileUpload},
|
||||
},
|
||||
{
|
||||
name: "client upload policy adds file upload channel",
|
||||
policy: map[string]any{"clipboard_mode": "disabled", "file_transfer_mode": "client_to_server"},
|
||||
expected: []string{sessioncontracts.DataPlaneChannelFileUpload},
|
||||
blocked: []string{sessioncontracts.DataPlaneChannelClipboard},
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session := RemoteSession{Metadata: mustJSON(t, map[string]any{"policy": tc.policy})}
|
||||
channels := dataPlaneAllowedChannelsFromSession(session)
|
||||
for _, channel := range tc.expected {
|
||||
if !slices.Contains(channels, channel) {
|
||||
t.Fatalf("expected channel %q in %v", channel, channels)
|
||||
}
|
||||
}
|
||||
for _, channel := range tc.blocked {
|
||||
if slices.Contains(channels, channel) {
|
||||
t.Fatalf("did not expect channel %q in %v", channel, channels)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustJSON(t *testing.T, value any) []byte {
|
||||
t.Helper()
|
||||
payload, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal test metadata: %v", err)
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func testRS256Key(t *testing.T) (string, *rsa.PublicKey) {
|
||||
t.Helper()
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("generate RSA key: %v", err)
|
||||
}
|
||||
encoded := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||
})
|
||||
return string(encoded), &privateKey.PublicKey
|
||||
}
|
||||
|
||||
func assertEqual(t *testing.T, got, want, name string) {
|
||||
t.Helper()
|
||||
if got != want {
|
||||
t.Fatalf("unexpected %s: got %q want %q", name, got, want)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package sessionbroker
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrSessionNotFound = errors.New("remote session not found")
|
||||
ErrAttachmentNotFound = errors.New("session attachment not found")
|
||||
ErrActiveControllerPresent = errors.New("active controller already present")
|
||||
ErrTakeoverNotAllowed = errors.New("takeover not allowed")
|
||||
ErrTrustedDeviceRequired = errors.New("trusted device required")
|
||||
ErrAccessDenied = errors.New("access denied")
|
||||
ErrSessionNotAttachable = errors.New("session is not attachable")
|
||||
ErrSessionNotTerminable = errors.New("session is not terminable")
|
||||
ErrAttachTokenInvalid = errors.New("attach token invalid or expired")
|
||||
)
|
||||
@@ -0,0 +1,65 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
|
||||
)
|
||||
|
||||
type LiveStateStore interface {
|
||||
UpsertSession(ctx context.Context, state LiveSessionState) error
|
||||
GetSession(ctx context.Context, sessionID string) (*LiveSessionState, error)
|
||||
DeleteSession(ctx context.Context, sessionID string) error
|
||||
BindController(ctx context.Context, binding sessioncontracts.ControllerBinding, ttl time.Duration) error
|
||||
GetControllerBinding(ctx context.Context, sessionID string) (*sessioncontracts.ControllerBinding, error)
|
||||
ClearControllerBinding(ctx context.Context, sessionID string) error
|
||||
StoreAttachToken(ctx context.Context, claims sessioncontracts.AttachTokenClaims, ttl time.Duration) error
|
||||
ConsumeAttachToken(ctx context.Context, token string) (*sessioncontracts.AttachTokenClaims, error)
|
||||
TouchAttachmentHeartbeat(ctx context.Context, sessionID, attachmentID string, ttl time.Duration) error
|
||||
UpdateWorkerRoute(ctx context.Context, route WorkerRoute, ttl time.Duration) error
|
||||
GetWorkerRoute(ctx context.Context, sessionID string) (*WorkerRoute, error)
|
||||
DeleteWorkerRoute(ctx context.Context, sessionID string) error
|
||||
}
|
||||
|
||||
type LiveSessionState struct {
|
||||
SessionID string `json:"session_id"`
|
||||
ResourceID string `json:"resource_id"`
|
||||
WorkerID string `json:"worker_id"`
|
||||
State sessioncontracts.State `json:"state"`
|
||||
ControllerID string `json:"controller_id"`
|
||||
AttachmentID string `json:"attachment_id"`
|
||||
TakeoverVersion int `json:"takeover_version"`
|
||||
RenderQualityProfile string `json:"render_quality_profile,omitempty"`
|
||||
RenderState string `json:"render_state,omitempty"`
|
||||
RenderWidth int `json:"render_width,omitempty"`
|
||||
RenderHeight int `json:"render_height,omitempty"`
|
||||
RenderFrameSequence int64 `json:"render_frame_sequence,omitempty"`
|
||||
RenderFrameFormat string `json:"render_frame_format,omitempty"`
|
||||
RenderFrameData string `json:"render_frame_data,omitempty"`
|
||||
LastInputCorrelationID string `json:"last_input_correlation_id,omitempty"`
|
||||
WorkerFrameCapturedAt string `json:"worker_frame_captured_at,omitempty"`
|
||||
CursorX int `json:"cursor_x,omitempty"`
|
||||
CursorY int `json:"cursor_y,omitempty"`
|
||||
CursorVisible bool `json:"cursor_visible,omitempty"`
|
||||
DirtyRectangles int `json:"dirty_rectangles,omitempty"`
|
||||
LastRenderAt *time.Time `json:"last_render_at,omitempty"`
|
||||
ClipboardSequence int64 `json:"clipboard_sequence,omitempty"`
|
||||
ClipboardText string `json:"clipboard_text,omitempty"`
|
||||
ClipboardOrigin string `json:"clipboard_origin,omitempty"`
|
||||
ClipboardContentHash string `json:"clipboard_content_hash,omitempty"`
|
||||
ClipboardUpdatedAt *time.Time `json:"clipboard_updated_at,omitempty"`
|
||||
FileDownloadSequence int64 `json:"file_download_sequence,omitempty"`
|
||||
FileDownloadType string `json:"file_download_type,omitempty"`
|
||||
FileDownloadPayload map[string]any `json:"file_download_payload,omitempty"`
|
||||
FileDownloadUpdatedAt *time.Time `json:"file_download_updated_at,omitempty"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type WorkerRoute struct {
|
||||
SessionID string `json:"session_id"`
|
||||
WorkerID string `json:"worker_id"`
|
||||
LeaseID string `json:"lease_id"`
|
||||
ControlStream string `json:"control_stream"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
|
||||
)
|
||||
|
||||
type AttachmentRole string
|
||||
|
||||
const (
|
||||
AttachmentRoleController AttachmentRole = "controller"
|
||||
)
|
||||
|
||||
type AttachmentState string
|
||||
|
||||
const (
|
||||
AttachmentStateAttaching AttachmentState = "attaching"
|
||||
AttachmentStateActive AttachmentState = "active"
|
||||
AttachmentStateDetached AttachmentState = "detached"
|
||||
AttachmentStateSuperseded AttachmentState = "superseded"
|
||||
AttachmentStateRevoked AttachmentState = "revoked"
|
||||
AttachmentStateClosed AttachmentState = "closed"
|
||||
)
|
||||
|
||||
type ResourceTakeoverPolicy string
|
||||
|
||||
const (
|
||||
ResourceTakeoverPolicyTrustedDevice ResourceTakeoverPolicy = "trusted_device"
|
||||
ResourceTakeoverPolicySameUser ResourceTakeoverPolicy = "same_user"
|
||||
ResourceTakeoverPolicyAdminOnly ResourceTakeoverPolicy = "admin_only"
|
||||
)
|
||||
|
||||
type ResourceClipboardMode string
|
||||
|
||||
const (
|
||||
ResourceClipboardModeDisabled ResourceClipboardMode = "disabled"
|
||||
ResourceClipboardModeClientToServer ResourceClipboardMode = "client_to_server"
|
||||
ResourceClipboardModeServerToClient ResourceClipboardMode = "server_to_client"
|
||||
ResourceClipboardModeBidirectional ResourceClipboardMode = "bidirectional"
|
||||
)
|
||||
|
||||
type ResourceFileTransferMode string
|
||||
|
||||
const (
|
||||
ResourceFileTransferModeDisabled ResourceFileTransferMode = "disabled"
|
||||
ResourceFileTransferModeClientToServer ResourceFileTransferMode = "client_to_server"
|
||||
ResourceFileTransferModeServerToClient ResourceFileTransferMode = "server_to_client"
|
||||
ResourceFileTransferModeBidirectional ResourceFileTransferMode = "bidirectional"
|
||||
)
|
||||
|
||||
type RemoteSession struct {
|
||||
ID string
|
||||
OrganizationID string
|
||||
ResourceID string
|
||||
Protocol string
|
||||
State sessioncontracts.State
|
||||
WorkerID string
|
||||
ControllerUserID string
|
||||
DetachDeadlineAt *time.Time
|
||||
LastHeartbeatAt *time.Time
|
||||
TakeoverVersion int
|
||||
RenderQualityProfile string
|
||||
Metadata []byte
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type SessionAttachment struct {
|
||||
ID string
|
||||
RemoteSessionID string
|
||||
UserID string
|
||||
DeviceID string
|
||||
Role AttachmentRole
|
||||
State AttachmentState
|
||||
SupersededBy *string
|
||||
TakeoverOf *string
|
||||
AttachedAt *time.Time
|
||||
DetachedAt *time.Time
|
||||
LastInputAt *time.Time
|
||||
Metadata []byte
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type ResourcePolicy struct {
|
||||
ResourceID string
|
||||
MaxConcurrentSessions int
|
||||
TakeoverPolicy ResourceTakeoverPolicy
|
||||
RequireTrustedDevice bool
|
||||
DetachGracePeriod time.Duration
|
||||
ClipboardEnabled bool
|
||||
ClipboardMode ResourceClipboardMode
|
||||
FileTransferEnabled bool
|
||||
FileTransferMode ResourceFileTransferMode
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type ResourceRuntimeSpec struct {
|
||||
ID string
|
||||
OrganizationID string
|
||||
Name string
|
||||
Address string
|
||||
Protocol string
|
||||
SecretRef *string
|
||||
CertificateVerificationMode string
|
||||
Metadata []byte
|
||||
}
|
||||
|
||||
type AuditEvent struct {
|
||||
ID string
|
||||
ActorUserID *string
|
||||
ActorDeviceID *string
|
||||
EventType string
|
||||
TargetType string
|
||||
TargetID string
|
||||
RemoteSessionID *string
|
||||
Payload []byte
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
AuditEventSessionStarted = "session_started"
|
||||
AuditEventSessionAttached = "session_attached"
|
||||
AuditEventSessionDetached = "session_detached"
|
||||
AuditEventSessionTakenOver = "session_taken_over"
|
||||
AuditEventSessionTerminated = "session_terminated"
|
||||
AuditEventSessionFailed = "session_failed"
|
||||
AuditEventSecretAccessed = "resource_secret_accessed"
|
||||
AuditEventSecretAccessDenied = "resource_secret_access_denied"
|
||||
)
|
||||
@@ -0,0 +1,164 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
|
||||
)
|
||||
|
||||
type Module struct {
|
||||
service *Service
|
||||
}
|
||||
|
||||
func NewModule(service *Service) *Module {
|
||||
return &Module{service: service}
|
||||
}
|
||||
|
||||
func (m *Module) Name() string {
|
||||
return "session-broker"
|
||||
}
|
||||
|
||||
func (m *Module) Service() *Service {
|
||||
return m.service
|
||||
}
|
||||
|
||||
func (m *Module) RegisterRoutes(router chi.Router) {
|
||||
router.Route("/sessions", func(r chi.Router) {
|
||||
r.Get("/", m.listSessions)
|
||||
r.Post("/", m.startSession)
|
||||
r.Post("/{sessionID}/attach", m.attachSession)
|
||||
r.Post("/{sessionID}/detach", m.detachSession)
|
||||
r.Post("/{sessionID}/takeover", m.takeoverSession)
|
||||
r.Post("/{sessionID}/terminate", m.terminateSession)
|
||||
r.Post("/{sessionID}/fail", m.markFailed)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) listSessions(w http.ResponseWriter, r *http.Request) {
|
||||
userID := r.URL.Query().Get("user_id")
|
||||
if userID == "" {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
|
||||
return
|
||||
}
|
||||
sessions, err := m.service.ListSessions(r.Context(), userID)
|
||||
if err != nil {
|
||||
status, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, status, message)
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, map[string]any{"sessions": sessions})
|
||||
}
|
||||
|
||||
func (m *Module) startSession(w http.ResponseWriter, r *http.Request) {
|
||||
var cmd StartRemoteSessionCommand
|
||||
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid start session payload")
|
||||
return
|
||||
}
|
||||
result, err := m.service.StartRemoteSession(r.Context(), cmd)
|
||||
if err != nil {
|
||||
status, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, status, message)
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusCreated, result)
|
||||
}
|
||||
|
||||
func (m *Module) attachSession(w http.ResponseWriter, r *http.Request) {
|
||||
var cmd AttachToSessionCommand
|
||||
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid attach session payload")
|
||||
return
|
||||
}
|
||||
cmd.SessionID = chi.URLParam(r, "sessionID")
|
||||
result, err := m.service.AttachToSession(r.Context(), cmd)
|
||||
if err != nil {
|
||||
status, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, status, message)
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (m *Module) detachSession(w http.ResponseWriter, r *http.Request) {
|
||||
var cmd DetachFromSessionCommand
|
||||
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid detach session payload")
|
||||
return
|
||||
}
|
||||
cmd.SessionID = chi.URLParam(r, "sessionID")
|
||||
result, err := m.service.DetachFromSession(r.Context(), cmd)
|
||||
if err != nil {
|
||||
status, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, status, message)
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusAccepted, result)
|
||||
}
|
||||
|
||||
func (m *Module) takeoverSession(w http.ResponseWriter, r *http.Request) {
|
||||
var cmd TakeoverSessionCommand
|
||||
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid takeover session payload")
|
||||
return
|
||||
}
|
||||
cmd.SessionID = chi.URLParam(r, "sessionID")
|
||||
result, err := m.service.TakeoverSession(r.Context(), cmd)
|
||||
if err != nil {
|
||||
status, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, status, message)
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (m *Module) terminateSession(w http.ResponseWriter, r *http.Request) {
|
||||
var cmd TerminateSessionCommand
|
||||
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid terminate session payload")
|
||||
return
|
||||
}
|
||||
cmd.SessionID = chi.URLParam(r, "sessionID")
|
||||
if err := m.service.TerminateSession(r.Context(), cmd); err != nil {
|
||||
status, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, status, message)
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
|
||||
"status": "terminated",
|
||||
"message": httpx.NewMessage(
|
||||
"session.terminated",
|
||||
"status.session.terminated",
|
||||
"Session terminated.",
|
||||
nil,
|
||||
"",
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Module) markFailed(w http.ResponseWriter, r *http.Request) {
|
||||
var cmd MarkSessionFailedCommand
|
||||
if err := json.NewDecoder(r.Body).Decode(&cmd); err != nil {
|
||||
httpx.WriteError(w, http.StatusBadRequest, "invalid fail session payload")
|
||||
return
|
||||
}
|
||||
cmd.SessionID = chi.URLParam(r, "sessionID")
|
||||
if err := m.service.MarkSessionFailed(r.Context(), cmd); err != nil {
|
||||
status, message := m.service.MapError(err)
|
||||
httpx.WriteError(w, status, message)
|
||||
return
|
||||
}
|
||||
httpx.WriteJSON(w, http.StatusAccepted, map[string]any{
|
||||
"status": "failed",
|
||||
"message": httpx.NewMessage(
|
||||
"session.failed",
|
||||
"status.session.failed",
|
||||
"Session marked as failed.",
|
||||
nil,
|
||||
"",
|
||||
),
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
|
||||
)
|
||||
|
||||
type WorkerOrchestrator interface {
|
||||
Reserve(ctx context.Context, request workercontracts.AttachRequest) (*workercontracts.WorkerLease, error)
|
||||
GetSessionLease(ctx context.Context, sessionID string) (*workercontracts.WorkerLease, error)
|
||||
ReleaseSessionLease(ctx context.Context, sessionID string) error
|
||||
PrepareAttachment(ctx context.Context, session RemoteSession, attachment SessionAttachment, runtimeMetadata map[string]any) error
|
||||
NotifyDetachment(ctx context.Context, session RemoteSession, attachment SessionAttachment) error
|
||||
TerminateRemoteSession(ctx context.Context, sessionID, attachmentID string) error
|
||||
ValidateSessionRuntime(ctx context.Context, sessionID, workerID string) (bool, string, error)
|
||||
}
|
||||
@@ -0,0 +1,607 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/authority"
|
||||
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
|
||||
)
|
||||
|
||||
type postgresStore struct {
|
||||
db postgresplatform.DBTX
|
||||
authority *authority.Verifier
|
||||
}
|
||||
|
||||
type PostgresTransactor struct {
|
||||
pool *pgxpool.Pool
|
||||
authority *authority.Verifier
|
||||
}
|
||||
|
||||
func NewPostgresStore(pool *pgxpool.Pool, verifiers ...*authority.Verifier) Store {
|
||||
var authorityVerifier *authority.Verifier
|
||||
if len(verifiers) > 0 {
|
||||
authorityVerifier = verifiers[0]
|
||||
}
|
||||
return &postgresStore{db: pool, authority: authorityVerifier}
|
||||
}
|
||||
|
||||
func NewPostgresTransactor(pool *pgxpool.Pool, verifiers ...*authority.Verifier) *PostgresTransactor {
|
||||
var authorityVerifier *authority.Verifier
|
||||
if len(verifiers) > 0 {
|
||||
authorityVerifier = verifiers[0]
|
||||
}
|
||||
return &PostgresTransactor{pool: pool, authority: authorityVerifier}
|
||||
}
|
||||
|
||||
func (t *PostgresTransactor) WithinTransaction(ctx context.Context, fn func(store Store) error) error {
|
||||
return postgresplatform.WithTransaction(ctx, t.pool, func(tx pgx.Tx) error {
|
||||
return fn(&postgresStore{db: tx, authority: t.authority})
|
||||
})
|
||||
}
|
||||
|
||||
func (s *postgresStore) RemoteSessions() RemoteSessionRepository {
|
||||
return &postgresRemoteSessionRepository{db: s.db}
|
||||
}
|
||||
|
||||
func (s *postgresStore) SessionAttachments() SessionAttachmentRepository {
|
||||
return &postgresSessionAttachmentRepository{db: s.db}
|
||||
}
|
||||
|
||||
func (s *postgresStore) ResourcePolicies() ResourcePolicyRepository {
|
||||
return &postgresResourcePolicyRepository{db: s.db}
|
||||
}
|
||||
|
||||
func (s *postgresStore) ResourceRuntime() ResourceRuntimeRepository {
|
||||
return &postgresResourceRuntimeRepository{db: s.db}
|
||||
}
|
||||
|
||||
func (s *postgresStore) AuditEvents() AuditEventRepository {
|
||||
return &postgresAuditEventRepository{db: s.db}
|
||||
}
|
||||
|
||||
func (s *postgresStore) Access() AccessRepository {
|
||||
return &postgresAccessRepository{db: s.db, authority: s.authority}
|
||||
}
|
||||
|
||||
type postgresRemoteSessionRepository struct {
|
||||
db postgresplatform.DBTX
|
||||
}
|
||||
|
||||
type postgresSessionAttachmentRepository struct {
|
||||
db postgresplatform.DBTX
|
||||
}
|
||||
|
||||
type postgresResourcePolicyRepository struct {
|
||||
db postgresplatform.DBTX
|
||||
}
|
||||
|
||||
type postgresResourceRuntimeRepository struct {
|
||||
db postgresplatform.DBTX
|
||||
}
|
||||
|
||||
type postgresAuditEventRepository struct {
|
||||
db postgresplatform.DBTX
|
||||
}
|
||||
|
||||
type postgresAccessRepository struct {
|
||||
db postgresplatform.DBTX
|
||||
authority *authority.Verifier
|
||||
}
|
||||
|
||||
func (r *postgresRemoteSessionRepository) Create(ctx context.Context, session RemoteSession) error {
|
||||
const query = `
|
||||
INSERT INTO remote_sessions (
|
||||
id, organization_id, resource_id, protocol, state, worker_id, controller_user_id, detach_deadline_at,
|
||||
last_heartbeat_at, takeover_version, metadata, created_at, updated_at
|
||||
) VALUES (
|
||||
$1::uuid, $2::uuid, $3::uuid, $4, $5, NULLIF($6, ''), $7::uuid, $8, $9, $10, $11::jsonb, $12, $13
|
||||
)
|
||||
`
|
||||
if _, err := r.db.Exec(ctx, query,
|
||||
session.ID,
|
||||
session.OrganizationID,
|
||||
session.ResourceID,
|
||||
session.Protocol,
|
||||
session.State,
|
||||
session.WorkerID,
|
||||
session.ControllerUserID,
|
||||
session.DetachDeadlineAt,
|
||||
session.LastHeartbeatAt,
|
||||
session.TakeoverVersion,
|
||||
jsonPayload(session.Metadata),
|
||||
session.CreatedAt,
|
||||
session.UpdatedAt,
|
||||
); err != nil {
|
||||
return fmt.Errorf("create remote session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *postgresRemoteSessionRepository) GetByID(ctx context.Context, sessionID string) (*RemoteSession, error) {
|
||||
return r.getByID(ctx, sessionID, "")
|
||||
}
|
||||
|
||||
func (r *postgresRemoteSessionRepository) GetByIDForUpdate(ctx context.Context, sessionID string) (*RemoteSession, error) {
|
||||
return r.getByID(ctx, sessionID, " FOR UPDATE")
|
||||
}
|
||||
|
||||
func (r *postgresRemoteSessionRepository) getByID(ctx context.Context, sessionID string, suffix string) (*RemoteSession, error) {
|
||||
query := `
|
||||
SELECT id::text, organization_id::text, resource_id::text, protocol, state, COALESCE(worker_id, ''), controller_user_id::text,
|
||||
detach_deadline_at, last_heartbeat_at, takeover_version, metadata, created_at, updated_at
|
||||
FROM remote_sessions
|
||||
WHERE id = $1::uuid` + suffix
|
||||
remoteSession, err := scanRemoteSession(r.db.QueryRow(ctx, query, sessionID))
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return remoteSession, err
|
||||
}
|
||||
|
||||
func (r *postgresRemoteSessionRepository) ListByController(ctx context.Context, userID string) ([]RemoteSession, error) {
|
||||
const query = `
|
||||
SELECT id::text, organization_id::text, resource_id::text, protocol, state, COALESCE(worker_id, ''), controller_user_id::text,
|
||||
detach_deadline_at, last_heartbeat_at, takeover_version, metadata, created_at, updated_at
|
||||
FROM remote_sessions
|
||||
WHERE controller_user_id = $1::uuid
|
||||
ORDER BY updated_at DESC
|
||||
`
|
||||
rows, err := r.db.Query(ctx, query, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list remote sessions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sessions []RemoteSession
|
||||
for rows.Next() {
|
||||
item, err := scanRemoteSession(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessions = append(sessions, *item)
|
||||
}
|
||||
return sessions, rows.Err()
|
||||
}
|
||||
|
||||
func (r *postgresRemoteSessionRepository) CountLiveByResource(ctx context.Context, resourceID string) (int, error) {
|
||||
const query = `
|
||||
SELECT COUNT(*)
|
||||
FROM remote_sessions
|
||||
WHERE resource_id = $1::uuid AND state IN ('starting', 'active', 'detached', 'reconnecting')
|
||||
`
|
||||
var count int
|
||||
if err := r.db.QueryRow(ctx, query, resourceID).Scan(&count); err != nil {
|
||||
return 0, fmt.Errorf("count live remote sessions: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *postgresRemoteSessionRepository) ListDetachedExpired(ctx context.Context, before time.Time, limit int) ([]RemoteSession, error) {
|
||||
const query = `
|
||||
SELECT id::text, organization_id::text, resource_id::text, protocol, state, COALESCE(worker_id, ''), controller_user_id::text,
|
||||
detach_deadline_at, last_heartbeat_at, takeover_version, metadata, created_at, updated_at
|
||||
FROM remote_sessions
|
||||
WHERE state = 'detached' AND detach_deadline_at IS NOT NULL AND detach_deadline_at <= $1
|
||||
ORDER BY detach_deadline_at ASC
|
||||
LIMIT $2
|
||||
`
|
||||
rows, err := r.db.Query(ctx, query, before, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list detached expired sessions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sessions []RemoteSession
|
||||
for rows.Next() {
|
||||
item, err := scanRemoteSession(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessions = append(sessions, *item)
|
||||
}
|
||||
return sessions, rows.Err()
|
||||
}
|
||||
|
||||
func (r *postgresRemoteSessionRepository) UpdateState(ctx context.Context, params UpdateRemoteSessionStateParams) error {
|
||||
const query = `
|
||||
UPDATE remote_sessions
|
||||
SET state = $2,
|
||||
worker_id = NULLIF($3, ''),
|
||||
detach_deadline_at = $4,
|
||||
last_heartbeat_at = $5,
|
||||
takeover_version = $6,
|
||||
updated_at = $7
|
||||
WHERE id = $1::uuid
|
||||
`
|
||||
if _, err := r.db.Exec(ctx, query,
|
||||
params.RemoteSessionID,
|
||||
params.State,
|
||||
params.WorkerID,
|
||||
params.DetachDeadlineAt,
|
||||
params.LastHeartbeatAt,
|
||||
params.TakeoverVersion,
|
||||
params.UpdatedAt,
|
||||
); err != nil {
|
||||
return fmt.Errorf("update remote session state: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *postgresSessionAttachmentRepository) Create(ctx context.Context, attachment SessionAttachment) error {
|
||||
const query = `
|
||||
INSERT INTO session_attachments (
|
||||
id, remote_session_id, user_id, device_id, role, state, superseded_by,
|
||||
takeover_of, attached_at, detached_at, last_input_at, metadata, created_at, updated_at
|
||||
) VALUES (
|
||||
$1::uuid, $2::uuid, $3::uuid, $4::uuid, $5, $6, NULLIF($7, '')::uuid,
|
||||
NULLIF($8, '')::uuid, $9, $10, $11, $12::jsonb, $13, $14
|
||||
)
|
||||
`
|
||||
if _, err := r.db.Exec(ctx, query,
|
||||
attachment.ID,
|
||||
attachment.RemoteSessionID,
|
||||
attachment.UserID,
|
||||
attachment.DeviceID,
|
||||
attachment.Role,
|
||||
attachment.State,
|
||||
stringValue(attachment.SupersededBy),
|
||||
stringValue(attachment.TakeoverOf),
|
||||
attachment.AttachedAt,
|
||||
attachment.DetachedAt,
|
||||
attachment.LastInputAt,
|
||||
jsonPayload(attachment.Metadata),
|
||||
attachment.CreatedAt,
|
||||
attachment.UpdatedAt,
|
||||
); err != nil {
|
||||
return fmt.Errorf("create session attachment: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *postgresSessionAttachmentRepository) GetByID(ctx context.Context, attachmentID string) (*SessionAttachment, error) {
|
||||
return r.getByID(ctx, attachmentID, "")
|
||||
}
|
||||
|
||||
func (r *postgresSessionAttachmentRepository) GetByIDForUpdate(ctx context.Context, attachmentID string) (*SessionAttachment, error) {
|
||||
return r.getByID(ctx, attachmentID, " FOR UPDATE")
|
||||
}
|
||||
|
||||
func (r *postgresSessionAttachmentRepository) getByID(ctx context.Context, attachmentID string, suffix string) (*SessionAttachment, error) {
|
||||
query := `
|
||||
SELECT id::text, remote_session_id::text, user_id::text, device_id::text, role, state,
|
||||
superseded_by::text, takeover_of::text, attached_at, detached_at, last_input_at, metadata, created_at, updated_at
|
||||
FROM session_attachments
|
||||
WHERE id = $1::uuid` + suffix
|
||||
attachment, err := scanSessionAttachment(r.db.QueryRow(ctx, query, attachmentID))
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return attachment, err
|
||||
}
|
||||
|
||||
func (r *postgresSessionAttachmentRepository) ListByRemoteSession(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error) {
|
||||
return r.listByRemoteSession(ctx, remoteSessionID, "")
|
||||
}
|
||||
|
||||
func (r *postgresSessionAttachmentRepository) ListActiveByRemoteSessionForUpdate(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error) {
|
||||
return r.listByRemoteSession(ctx, remoteSessionID, " AND state IN ('attaching', 'active', 'reconnecting') FOR UPDATE")
|
||||
}
|
||||
|
||||
func (r *postgresSessionAttachmentRepository) listByRemoteSession(ctx context.Context, remoteSessionID string, suffix string) ([]SessionAttachment, error) {
|
||||
query := `
|
||||
SELECT id::text, remote_session_id::text, user_id::text, device_id::text, role, state,
|
||||
superseded_by::text, takeover_of::text, attached_at, detached_at, last_input_at, metadata, created_at, updated_at
|
||||
FROM session_attachments
|
||||
WHERE remote_session_id = $1::uuid` + suffix
|
||||
rows, err := r.db.Query(ctx, query, remoteSessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list session attachments: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var attachments []SessionAttachment
|
||||
for rows.Next() {
|
||||
item, err := scanSessionAttachment(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attachments = append(attachments, *item)
|
||||
}
|
||||
return attachments, rows.Err()
|
||||
}
|
||||
|
||||
func (r *postgresSessionAttachmentRepository) UpdateState(ctx context.Context, params UpdateSessionAttachmentStateParams) error {
|
||||
const query = `
|
||||
UPDATE session_attachments
|
||||
SET state = $2,
|
||||
detached_at = $3,
|
||||
last_input_at = $4,
|
||||
updated_at = $5
|
||||
WHERE id = $1::uuid
|
||||
`
|
||||
if _, err := r.db.Exec(ctx, query,
|
||||
params.AttachmentID,
|
||||
params.State,
|
||||
params.DetachedAt,
|
||||
params.LastInputAt,
|
||||
params.UpdatedAt,
|
||||
); err != nil {
|
||||
return fmt.Errorf("update session attachment state: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *postgresSessionAttachmentRepository) Supersede(ctx context.Context, params SupersedeAttachmentParams) error {
|
||||
const query = `
|
||||
UPDATE session_attachments
|
||||
SET state = 'superseded',
|
||||
superseded_by = $2::uuid,
|
||||
detached_at = $3,
|
||||
updated_at = $4
|
||||
WHERE id = $1::uuid
|
||||
`
|
||||
if _, err := r.db.Exec(ctx, query,
|
||||
params.PreviousAttachmentID,
|
||||
params.NextAttachmentID,
|
||||
params.DetachedAt,
|
||||
params.UpdatedAt,
|
||||
); err != nil {
|
||||
return fmt.Errorf("supersede attachment: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *postgresResourcePolicyRepository) GetByResourceID(ctx context.Context, resourceID string) (*ResourcePolicy, error) {
|
||||
const query = `
|
||||
SELECT resource_id::text, max_concurrent_sessions, takeover_policy, require_trusted_device,
|
||||
detach_grace_period_seconds, clipboard_enabled, clipboard_mode, file_transfer_enabled,
|
||||
COALESCE(file_transfer_mode, 'disabled'), created_at, updated_at
|
||||
FROM resource_policies
|
||||
WHERE resource_id = $1::uuid
|
||||
`
|
||||
policy, err := scanResourcePolicy(r.db.QueryRow(ctx, query, resourceID))
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return policy, err
|
||||
}
|
||||
|
||||
func (r *postgresResourcePolicyRepository) Upsert(ctx context.Context, policy ResourcePolicy) error {
|
||||
const query = `
|
||||
INSERT INTO resource_policies (
|
||||
resource_id, max_concurrent_sessions, takeover_policy, require_trusted_device,
|
||||
detach_grace_period_seconds, clipboard_enabled, clipboard_mode, file_transfer_enabled, file_transfer_mode, created_at, updated_at
|
||||
) VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
ON CONFLICT (resource_id) DO UPDATE SET
|
||||
max_concurrent_sessions = EXCLUDED.max_concurrent_sessions,
|
||||
takeover_policy = EXCLUDED.takeover_policy,
|
||||
require_trusted_device = EXCLUDED.require_trusted_device,
|
||||
detach_grace_period_seconds = EXCLUDED.detach_grace_period_seconds,
|
||||
clipboard_enabled = EXCLUDED.clipboard_enabled,
|
||||
clipboard_mode = EXCLUDED.clipboard_mode,
|
||||
file_transfer_enabled = EXCLUDED.file_transfer_enabled,
|
||||
file_transfer_mode = EXCLUDED.file_transfer_mode,
|
||||
updated_at = EXCLUDED.updated_at
|
||||
`
|
||||
clipboardMode := normalizeClipboardMode(policy.ClipboardMode)
|
||||
fileTransferMode := normalizeFileTransferMode(policy.FileTransferMode)
|
||||
if _, err := r.db.Exec(ctx, query,
|
||||
policy.ResourceID,
|
||||
policy.MaxConcurrentSessions,
|
||||
policy.TakeoverPolicy,
|
||||
policy.RequireTrustedDevice,
|
||||
int(policy.DetachGracePeriod.Seconds()),
|
||||
clipboardMode != ResourceClipboardModeDisabled,
|
||||
clipboardMode,
|
||||
fileTransferAllowsClientToServer(fileTransferMode),
|
||||
fileTransferMode,
|
||||
policy.CreatedAt,
|
||||
policy.UpdatedAt,
|
||||
); err != nil {
|
||||
return fmt.Errorf("upsert resource policy: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *postgresResourceRuntimeRepository) GetByID(ctx context.Context, resourceID string) (*ResourceRuntimeSpec, error) {
|
||||
const query = `
|
||||
SELECT id::text, organization_id::text, name, address, protocol, secret_ref, certificate_verification_mode, metadata
|
||||
FROM resources
|
||||
WHERE id = $1::uuid
|
||||
`
|
||||
item := &ResourceRuntimeSpec{}
|
||||
var secretRef *string
|
||||
var metadata []byte
|
||||
if err := r.db.QueryRow(ctx, query, resourceID).Scan(
|
||||
&item.ID,
|
||||
&item.OrganizationID,
|
||||
&item.Name,
|
||||
&item.Address,
|
||||
&item.Protocol,
|
||||
&secretRef,
|
||||
&item.CertificateVerificationMode,
|
||||
&metadata,
|
||||
); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("get resource runtime spec: %w", err)
|
||||
}
|
||||
item.SecretRef = secretRef
|
||||
item.Metadata = metadata
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func (r *postgresAuditEventRepository) Create(ctx context.Context, event AuditEvent) error {
|
||||
const query = `
|
||||
INSERT INTO audit_events (
|
||||
id, actor_user_id, actor_device_id, event_type, target_type, target_id,
|
||||
remote_session_id, payload, created_at
|
||||
) VALUES (
|
||||
$1::uuid, NULLIF($2, '')::uuid, NULLIF($3, '')::uuid, $4, $5, $6,
|
||||
NULLIF($7, '')::uuid, $8::jsonb, $9
|
||||
)
|
||||
`
|
||||
if _, err := r.db.Exec(ctx, query,
|
||||
event.ID,
|
||||
stringValue(event.ActorUserID),
|
||||
stringValue(event.ActorDeviceID),
|
||||
event.EventType,
|
||||
event.TargetType,
|
||||
event.TargetID,
|
||||
stringValue(event.RemoteSessionID),
|
||||
jsonPayload(event.Payload),
|
||||
event.CreatedAt,
|
||||
); err != nil {
|
||||
return fmt.Errorf("create audit event: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *postgresAccessRepository) IsTrustedDevice(ctx context.Context, userID, deviceID string) (bool, error) {
|
||||
const query = `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM devices
|
||||
WHERE id = $1::uuid AND user_id = $2::uuid AND trust_status = 'trusted' AND revoked_at IS NULL
|
||||
)
|
||||
`
|
||||
var trusted bool
|
||||
if err := r.db.QueryRow(ctx, query, deviceID, userID).Scan(&trusted); err != nil {
|
||||
return false, fmt.Errorf("check trusted device: %w", err)
|
||||
}
|
||||
return trusted, nil
|
||||
}
|
||||
|
||||
func (r *postgresAccessRepository) GetPlatformRole(ctx context.Context, userID string) (string, error) {
|
||||
return authority.EffectivePlatformRole(ctx, r.db, r.authority, userID)
|
||||
}
|
||||
|
||||
func (r *postgresAccessRepository) GetOrganizationRole(ctx context.Context, organizationID, userID string) (string, bool, error) {
|
||||
const query = `
|
||||
SELECT role_id
|
||||
FROM organization_memberships
|
||||
WHERE organization_id = $1::uuid AND user_id = $2::uuid AND status = 'active'
|
||||
`
|
||||
var role string
|
||||
if err := r.db.QueryRow(ctx, query, organizationID, userID).Scan(&role); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", false, nil
|
||||
}
|
||||
return "", false, fmt.Errorf("get organization role: %w", err)
|
||||
}
|
||||
return role, true, nil
|
||||
}
|
||||
|
||||
type scanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func scanRemoteSession(row scanner) (*RemoteSession, error) {
|
||||
item := &RemoteSession{}
|
||||
var detachDeadlineAt, lastHeartbeatAt *time.Time
|
||||
var metadata []byte
|
||||
if err := row.Scan(
|
||||
&item.ID,
|
||||
&item.OrganizationID,
|
||||
&item.ResourceID,
|
||||
&item.Protocol,
|
||||
&item.State,
|
||||
&item.WorkerID,
|
||||
&item.ControllerUserID,
|
||||
&detachDeadlineAt,
|
||||
&lastHeartbeatAt,
|
||||
&item.TakeoverVersion,
|
||||
&metadata,
|
||||
&item.CreatedAt,
|
||||
&item.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan remote session: %w", err)
|
||||
}
|
||||
item.DetachDeadlineAt = detachDeadlineAt
|
||||
item.LastHeartbeatAt = lastHeartbeatAt
|
||||
item.Metadata = metadata
|
||||
item.RenderQualityProfile = renderQualityProfileFromSessionMetadata(metadata)
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func scanSessionAttachment(row scanner) (*SessionAttachment, error) {
|
||||
item := &SessionAttachment{}
|
||||
var supersededBy, takeoverOf *string
|
||||
var attachedAt, detachedAt, lastInputAt *time.Time
|
||||
var metadata []byte
|
||||
if err := row.Scan(
|
||||
&item.ID,
|
||||
&item.RemoteSessionID,
|
||||
&item.UserID,
|
||||
&item.DeviceID,
|
||||
&item.Role,
|
||||
&item.State,
|
||||
&supersededBy,
|
||||
&takeoverOf,
|
||||
&attachedAt,
|
||||
&detachedAt,
|
||||
&lastInputAt,
|
||||
&metadata,
|
||||
&item.CreatedAt,
|
||||
&item.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan session attachment: %w", err)
|
||||
}
|
||||
item.SupersededBy = supersededBy
|
||||
item.TakeoverOf = takeoverOf
|
||||
item.AttachedAt = attachedAt
|
||||
item.DetachedAt = detachedAt
|
||||
item.LastInputAt = lastInputAt
|
||||
item.Metadata = metadata
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func scanResourcePolicy(row scanner) (*ResourcePolicy, error) {
|
||||
item := &ResourcePolicy{}
|
||||
var detachGraceSeconds int
|
||||
if err := row.Scan(
|
||||
&item.ResourceID,
|
||||
&item.MaxConcurrentSessions,
|
||||
&item.TakeoverPolicy,
|
||||
&item.RequireTrustedDevice,
|
||||
&detachGraceSeconds,
|
||||
&item.ClipboardEnabled,
|
||||
&item.ClipboardMode,
|
||||
&item.FileTransferEnabled,
|
||||
&item.FileTransferMode,
|
||||
&item.CreatedAt,
|
||||
&item.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan resource policy: %w", err)
|
||||
}
|
||||
item.DetachGracePeriod = time.Duration(detachGraceSeconds) * time.Second
|
||||
item.ClipboardMode = normalizeClipboardMode(item.ClipboardMode)
|
||||
item.ClipboardEnabled = item.ClipboardMode != ResourceClipboardModeDisabled
|
||||
item.FileTransferMode = normalizeFileTransferMode(item.FileTransferMode)
|
||||
item.FileTransferEnabled = fileTransferAllowsClientToServer(item.FileTransferMode)
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func jsonPayload(payload []byte) []byte {
|
||||
if len(payload) == 0 {
|
||||
return []byte(`{}`)
|
||||
}
|
||||
if json.Valid(payload) {
|
||||
return payload
|
||||
}
|
||||
return []byte(`{}`)
|
||||
}
|
||||
|
||||
func stringValue(value *string) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
return *value
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
|
||||
)
|
||||
|
||||
type RedisLiveStateStore struct {
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
func NewRedisLiveStateStore(client *redis.Client) *RedisLiveStateStore {
|
||||
return &RedisLiveStateStore{client: client}
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) UpsertSession(ctx context.Context, state LiveSessionState) error {
|
||||
return s.setJSON(ctx, liveSessionKey(state.SessionID), state, 0)
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) GetSession(ctx context.Context, sessionID string) (*LiveSessionState, error) {
|
||||
var state LiveSessionState
|
||||
ok, err := s.getJSON(ctx, liveSessionKey(sessionID), &state)
|
||||
if err != nil || !ok {
|
||||
return nil, err
|
||||
}
|
||||
return &state, nil
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) DeleteSession(ctx context.Context, sessionID string) error {
|
||||
return s.client.Del(ctx, liveSessionKey(sessionID)).Err()
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) BindController(ctx context.Context, binding sessioncontracts.ControllerBinding, ttl time.Duration) error {
|
||||
return s.setJSON(ctx, controllerBindingKey(binding.SessionID), binding, ttl)
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) GetControllerBinding(ctx context.Context, sessionID string) (*sessioncontracts.ControllerBinding, error) {
|
||||
var binding sessioncontracts.ControllerBinding
|
||||
ok, err := s.getJSON(ctx, controllerBindingKey(sessionID), &binding)
|
||||
if err != nil || !ok {
|
||||
return nil, err
|
||||
}
|
||||
return &binding, nil
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) ClearControllerBinding(ctx context.Context, sessionID string) error {
|
||||
return s.client.Del(ctx, controllerBindingKey(sessionID)).Err()
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) StoreAttachToken(ctx context.Context, claims sessioncontracts.AttachTokenClaims, ttl time.Duration) error {
|
||||
return s.setJSON(ctx, attachTokenKey(claims.Token), claims, ttl)
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) ConsumeAttachToken(ctx context.Context, token string) (*sessioncontracts.AttachTokenClaims, error) {
|
||||
key := attachTokenKey(token)
|
||||
payload, err := s.client.GetDel(ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("consume attach token: %w", err)
|
||||
}
|
||||
var claims sessioncontracts.AttachTokenClaims
|
||||
if err := json.Unmarshal([]byte(payload), &claims); err != nil {
|
||||
return nil, fmt.Errorf("decode attach token: %w", err)
|
||||
}
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) TouchAttachmentHeartbeat(ctx context.Context, sessionID, attachmentID string, ttl time.Duration) error {
|
||||
return s.client.Set(ctx, attachmentHeartbeatKey(sessionID, attachmentID), time.Now().UTC().Format(time.RFC3339Nano), ttl).Err()
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) UpdateWorkerRoute(ctx context.Context, route WorkerRoute, ttl time.Duration) error {
|
||||
return s.setJSON(ctx, workerRouteKey(route.SessionID), route, ttl)
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) GetWorkerRoute(ctx context.Context, sessionID string) (*WorkerRoute, error) {
|
||||
var route WorkerRoute
|
||||
ok, err := s.getJSON(ctx, workerRouteKey(sessionID), &route)
|
||||
if err != nil || !ok {
|
||||
return nil, err
|
||||
}
|
||||
return &route, nil
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) DeleteWorkerRoute(ctx context.Context, sessionID string) error {
|
||||
return s.client.Del(ctx, workerRouteKey(sessionID)).Err()
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) setJSON(ctx context.Context, key string, value any, ttl time.Duration) error {
|
||||
payload, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode redis payload: %w", err)
|
||||
}
|
||||
if err := s.client.Set(ctx, key, payload, ttl).Err(); err != nil {
|
||||
return fmt.Errorf("set redis key %s: %w", key, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RedisLiveStateStore) getJSON(ctx context.Context, key string, dest any) (bool, error) {
|
||||
payload, err := s.client.Get(ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get redis key %s: %w", key, err)
|
||||
}
|
||||
if err := json.Unmarshal([]byte(payload), dest); err != nil {
|
||||
return false, fmt.Errorf("decode redis key %s: %w", key, err)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func liveSessionKey(sessionID string) string {
|
||||
return "live:session:" + sessionID
|
||||
}
|
||||
|
||||
func controllerBindingKey(sessionID string) string {
|
||||
return "live:session:" + sessionID + ":controller"
|
||||
}
|
||||
|
||||
func attachTokenKey(token string) string {
|
||||
return "live:attach:" + token
|
||||
}
|
||||
|
||||
func attachmentHeartbeatKey(sessionID, attachmentID string) string {
|
||||
return "live:session:" + sessionID + ":attachment:" + attachmentID + ":heartbeat"
|
||||
}
|
||||
|
||||
func workerRouteKey(sessionID string) string {
|
||||
return "live:session:" + sessionID + ":worker-route"
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
|
||||
)
|
||||
|
||||
type RemoteSessionRepository interface {
|
||||
Create(ctx context.Context, session RemoteSession) error
|
||||
GetByID(ctx context.Context, sessionID string) (*RemoteSession, error)
|
||||
GetByIDForUpdate(ctx context.Context, sessionID string) (*RemoteSession, error)
|
||||
ListByController(ctx context.Context, userID string) ([]RemoteSession, error)
|
||||
CountLiveByResource(ctx context.Context, resourceID string) (int, error)
|
||||
ListDetachedExpired(ctx context.Context, before time.Time, limit int) ([]RemoteSession, error)
|
||||
UpdateState(ctx context.Context, params UpdateRemoteSessionStateParams) error
|
||||
}
|
||||
|
||||
type SessionAttachmentRepository interface {
|
||||
Create(ctx context.Context, attachment SessionAttachment) error
|
||||
GetByID(ctx context.Context, attachmentID string) (*SessionAttachment, error)
|
||||
GetByIDForUpdate(ctx context.Context, attachmentID string) (*SessionAttachment, error)
|
||||
ListByRemoteSession(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error)
|
||||
ListActiveByRemoteSessionForUpdate(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error)
|
||||
UpdateState(ctx context.Context, params UpdateSessionAttachmentStateParams) error
|
||||
Supersede(ctx context.Context, params SupersedeAttachmentParams) error
|
||||
}
|
||||
|
||||
type ResourcePolicyRepository interface {
|
||||
GetByResourceID(ctx context.Context, resourceID string) (*ResourcePolicy, error)
|
||||
Upsert(ctx context.Context, policy ResourcePolicy) error
|
||||
}
|
||||
|
||||
type AuditEventRepository interface {
|
||||
Create(ctx context.Context, event AuditEvent) error
|
||||
}
|
||||
|
||||
type Store interface {
|
||||
RemoteSessions() RemoteSessionRepository
|
||||
SessionAttachments() SessionAttachmentRepository
|
||||
ResourcePolicies() ResourcePolicyRepository
|
||||
ResourceRuntime() ResourceRuntimeRepository
|
||||
AuditEvents() AuditEventRepository
|
||||
Access() AccessRepository
|
||||
}
|
||||
|
||||
type Transactor interface {
|
||||
WithinTransaction(ctx context.Context, fn func(store Store) error) error
|
||||
}
|
||||
|
||||
type UpdateRemoteSessionStateParams struct {
|
||||
RemoteSessionID string
|
||||
State sessioncontracts.State
|
||||
WorkerID string
|
||||
DetachDeadlineAt *time.Time
|
||||
LastHeartbeatAt *time.Time
|
||||
TakeoverVersion int
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type UpdateSessionAttachmentStateParams struct {
|
||||
AttachmentID string
|
||||
State AttachmentState
|
||||
DetachedAt *time.Time
|
||||
LastInputAt *time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type SupersedeAttachmentParams struct {
|
||||
PreviousAttachmentID string
|
||||
NextAttachmentID string
|
||||
DetachedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
type AccessRepository interface {
|
||||
IsTrustedDevice(ctx context.Context, userID, deviceID string) (bool, error)
|
||||
GetPlatformRole(ctx context.Context, userID string) (string, error)
|
||||
GetOrganizationRole(ctx context.Context, organizationID, userID string) (string, bool, error)
|
||||
}
|
||||
|
||||
type ResourceRuntimeRepository interface {
|
||||
GetByID(ctx context.Context, resourceID string) (*ResourceRuntimeSpec, error)
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/config"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/module"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
|
||||
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
|
||||
)
|
||||
|
||||
type fakeSecretResolver struct {
|
||||
response *secrets.ResolvedResourceSecret
|
||||
err error
|
||||
request secrets.ResolveResourceSecretRequest
|
||||
}
|
||||
|
||||
func testAppConfig(env string) config.AppConfig {
|
||||
return config.AppConfig{Name: "rap-api-test", Env: env}
|
||||
}
|
||||
|
||||
func (r *fakeSecretResolver) ResolveForSession(_ context.Context, req secrets.ResolveResourceSecretRequest) (*secrets.ResolvedResourceSecret, error) {
|
||||
r.request = req
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
return r.response, nil
|
||||
}
|
||||
|
||||
func TestRuntimeAssignmentMetadataMergesResolvedSecretWithoutMutatingSessionMetadata(t *testing.T) {
|
||||
resolver := &fakeSecretResolver{
|
||||
response: &secrets.ResolvedResourceSecret{
|
||||
Descriptor: secrets.ResourceSecretDescriptor{Version: 3},
|
||||
Payload: json.RawMessage(`{"username":"user","password":"secret","domain":"corp"}`),
|
||||
},
|
||||
}
|
||||
service := NewService(module.Dependencies{
|
||||
Config: module.Config{App: testAppConfig("production")},
|
||||
}, nil, nil, nil, nil, resolver)
|
||||
sessionMetadata := mustJSON(t, map[string]any{
|
||||
"resource": map[string]any{
|
||||
"id": "resource-1",
|
||||
"organization_id": "org-1",
|
||||
"secret_ref": "rap-secret://org/org-1/resources/resource-1/primary",
|
||||
"metadata": map[string]any{
|
||||
"rdp_host": "host",
|
||||
},
|
||||
},
|
||||
})
|
||||
session := RemoteSession{
|
||||
ID: "session-1",
|
||||
OrganizationID: "org-1",
|
||||
ResourceID: "resource-1",
|
||||
WorkerID: "worker-1",
|
||||
Metadata: sessionMetadata,
|
||||
}
|
||||
metadata, secretRef, version, err := service.runtimeAssignmentMetadata(context.Background(), session, &workercontracts.WorkerLease{LeaseID: "lease-1"})
|
||||
if err != nil {
|
||||
t.Fatalf("runtimeAssignmentMetadata returned error: %v", err)
|
||||
}
|
||||
if secretRef == "" || version != 3 {
|
||||
t.Fatalf("expected secret ref and version, got ref=%q version=%d", secretRef, version)
|
||||
}
|
||||
resource := metadata["resource"].(map[string]any)
|
||||
resourceMetadata := resource["metadata"].(map[string]any)
|
||||
if resourceMetadata["username"] != "user" || resourceMetadata["password"] != "secret" || resourceMetadata["domain"] != "corp" {
|
||||
t.Fatalf("resolved secret was not merged: %#v", resourceMetadata)
|
||||
}
|
||||
var persisted map[string]any
|
||||
if err := json.Unmarshal(session.Metadata, &persisted); err != nil {
|
||||
t.Fatalf("decode persisted metadata: %v", err)
|
||||
}
|
||||
persistedResource := persisted["resource"].(map[string]any)
|
||||
persistedMetadata := persistedResource["metadata"].(map[string]any)
|
||||
if _, ok := persistedMetadata["password"]; ok {
|
||||
t.Fatalf("session metadata was mutated with plaintext secret")
|
||||
}
|
||||
if resolver.request.LeaseID != "lease-1" || resolver.request.WorkerID != "worker-1" {
|
||||
t.Fatalf("resolver request missed lease/worker proof: %#v", resolver.request)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeAssignmentMetadataRequiresResolverInProduction(t *testing.T) {
|
||||
service := NewService(module.Dependencies{
|
||||
Config: module.Config{App: testAppConfig("production")},
|
||||
}, nil, nil, nil, nil)
|
||||
session := RemoteSession{
|
||||
ID: "session-1",
|
||||
OrganizationID: "org-1",
|
||||
ResourceID: "resource-1",
|
||||
WorkerID: "worker-1",
|
||||
Metadata: mustJSON(t, map[string]any{
|
||||
"resource": map[string]any{
|
||||
"secret_ref": "rap-secret://org/org-1/resources/resource-1/primary",
|
||||
},
|
||||
}),
|
||||
}
|
||||
_, _, _, err := service.runtimeAssignmentMetadata(context.Background(), session, &workercontracts.WorkerLease{LeaseID: "lease-1"})
|
||||
if !errors.Is(err, secrets.ErrSecretEncryptionKeyMissing) {
|
||||
t.Fatalf("expected missing resolver error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeAssignmentMetadataAllowsDevelopmentMetadataWithoutResolver(t *testing.T) {
|
||||
service := NewService(module.Dependencies{
|
||||
Config: module.Config{App: testAppConfig("development")},
|
||||
}, nil, nil, nil, nil)
|
||||
session := RemoteSession{
|
||||
ID: "session-1",
|
||||
OrganizationID: "org-1",
|
||||
ResourceID: "resource-1",
|
||||
WorkerID: "worker-1",
|
||||
Metadata: mustJSON(t, map[string]any{
|
||||
"resource": map[string]any{
|
||||
"secret_ref": "rap-secret://org/org-1/resources/resource-1/primary",
|
||||
"metadata": map[string]any{
|
||||
"username": "dev-user",
|
||||
"password": "dev-password",
|
||||
},
|
||||
},
|
||||
}),
|
||||
}
|
||||
metadata, secretRef, _, err := service.runtimeAssignmentMetadata(context.Background(), session, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("development metadata should not require resolver: %v", err)
|
||||
}
|
||||
if secretRef != "" {
|
||||
t.Fatalf("development fallback should not audit resolver use, got %q", secretRef)
|
||||
}
|
||||
resource := metadata["resource"].(map[string]any)
|
||||
resourceMetadata := resource["metadata"].(map[string]any)
|
||||
if resourceMetadata["password"] != "dev-password" {
|
||||
t.Fatalf("development metadata was not preserved")
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,391 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/module"
|
||||
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
|
||||
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
|
||||
)
|
||||
|
||||
func TestHandleWorkerConnectedIgnoresTerminalSession(t *testing.T) {
|
||||
service, store, live, _ := newStaleWorkerEventTestService()
|
||||
store.remote.sessions["session-1"] = RemoteSession{
|
||||
ID: "session-1",
|
||||
State: sessioncontracts.StateTerminated,
|
||||
WorkerID: "worker-1",
|
||||
}
|
||||
|
||||
if err := service.HandleWorkerConnected(context.Background(), "session-1"); err != nil {
|
||||
t.Fatalf("HandleWorkerConnected returned error for stale terminal event: %v", err)
|
||||
}
|
||||
if got := store.remote.sessions["session-1"].State; got != sessioncontracts.StateTerminated {
|
||||
t.Fatalf("stale connected event changed terminal state to %q", got)
|
||||
}
|
||||
if store.remote.updateCount != 0 {
|
||||
t.Fatalf("stale connected event updated authoritative session %d times", store.remote.updateCount)
|
||||
}
|
||||
if live.upsertCount != 0 {
|
||||
t.Fatalf("stale connected event recreated live state %d times", live.upsertCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateWorkerRenderTelemetryIgnoresTerminalSession(t *testing.T) {
|
||||
service, store, live, _ := newStaleWorkerEventTestService()
|
||||
store.remote.sessions["session-1"] = RemoteSession{
|
||||
ID: "session-1",
|
||||
State: sessioncontracts.StateTerminated,
|
||||
WorkerID: "worker-1",
|
||||
}
|
||||
|
||||
err := service.UpdateWorkerRenderTelemetry(context.Background(), "session-1", map[string]any{
|
||||
"render_state": "ready",
|
||||
"width": 1280,
|
||||
"height": 720,
|
||||
"frame_sequence": int64(99),
|
||||
"frame_data": "stale-frame",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateWorkerRenderTelemetry returned error for stale terminal event: %v", err)
|
||||
}
|
||||
if live.upsertCount != 0 {
|
||||
t.Fatalf("stale render event recreated live state %d times", live.upsertCount)
|
||||
}
|
||||
if live.sessions["session-1"] != nil {
|
||||
t.Fatalf("stale render event left live state behind: %#v", live.sessions["session-1"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkSessionFailedTransitionsActiveSession(t *testing.T) {
|
||||
service, store, live, orchestrator := newStaleWorkerEventTestService()
|
||||
store.remote.sessions["session-1"] = RemoteSession{
|
||||
ID: "session-1",
|
||||
State: sessioncontracts.StateActive,
|
||||
WorkerID: "worker-1",
|
||||
TakeoverVersion: 3,
|
||||
}
|
||||
store.attachments.items["attachment-1"] = SessionAttachment{
|
||||
ID: "attachment-1",
|
||||
RemoteSessionID: "session-1",
|
||||
State: AttachmentStateActive,
|
||||
}
|
||||
live.sessions["session-1"] = &LiveSessionState{SessionID: "session-1", State: sessioncontracts.StateActive}
|
||||
|
||||
if err := service.MarkSessionFailed(context.Background(), MarkSessionFailedCommand{SessionID: "session-1", Reason: "worker_lost"}); err != nil {
|
||||
t.Fatalf("MarkSessionFailed returned error: %v", err)
|
||||
}
|
||||
if got := store.remote.sessions["session-1"].State; got != sessioncontracts.StateFailed {
|
||||
t.Fatalf("expected failed state, got %q", got)
|
||||
}
|
||||
if got := store.attachments.items["attachment-1"].State; got != AttachmentStateClosed {
|
||||
t.Fatalf("expected attachment closed, got %q", got)
|
||||
}
|
||||
if store.audit.createCount != 1 {
|
||||
t.Fatalf("expected one audit event, got %d", store.audit.createCount)
|
||||
}
|
||||
if live.sessions["session-1"] != nil {
|
||||
t.Fatal("expected failed session live state to be deleted")
|
||||
}
|
||||
if orchestrator.releaseCount != 1 {
|
||||
t.Fatalf("expected session lease release, got %d", orchestrator.releaseCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkSessionFailedAlreadyFailedIsIdempotent(t *testing.T) {
|
||||
service, store, _, _ := newStaleWorkerEventTestService()
|
||||
store.remote.sessions["session-1"] = RemoteSession{
|
||||
ID: "session-1",
|
||||
State: sessioncontracts.StateFailed,
|
||||
WorkerID: "worker-1",
|
||||
TakeoverVersion: 1,
|
||||
}
|
||||
|
||||
if err := service.MarkSessionFailed(context.Background(), MarkSessionFailedCommand{SessionID: "session-1", Reason: "duplicate_worker_failure"}); err != nil {
|
||||
t.Fatalf("duplicate MarkSessionFailed returned error: %v", err)
|
||||
}
|
||||
if got := store.remote.sessions["session-1"].State; got != sessioncontracts.StateFailed {
|
||||
t.Fatalf("duplicate terminal event changed state to %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func newStaleWorkerEventTestService() (*Service, *staleWorkerEventTestStore, *staleWorkerEventLiveState, *staleWorkerEventOrchestrator) {
|
||||
store := &staleWorkerEventTestStore{
|
||||
remote: &staleWorkerEventRemoteSessions{sessions: map[string]RemoteSession{}},
|
||||
attachments: &staleWorkerEventAttachments{items: map[string]SessionAttachment{}},
|
||||
policies: &staleWorkerEventPolicies{},
|
||||
audit: &staleWorkerEventAudit{},
|
||||
}
|
||||
live := &staleWorkerEventLiveState{sessions: map[string]*LiveSessionState{}}
|
||||
orchestrator := &staleWorkerEventOrchestrator{}
|
||||
service := NewService(module.Dependencies{
|
||||
Infra: module.Infra{Logger: slog.New(slog.NewTextHandler(io.Discard, nil))},
|
||||
}, store, staleWorkerEventTransactor{store: store}, live, orchestrator)
|
||||
service.now = func() time.Time { return time.Unix(100, 0).UTC() }
|
||||
return service, store, live, orchestrator
|
||||
}
|
||||
|
||||
type staleWorkerEventTransactor struct {
|
||||
store Store
|
||||
}
|
||||
|
||||
func (t staleWorkerEventTransactor) WithinTransaction(ctx context.Context, fn func(store Store) error) error {
|
||||
return fn(t.store)
|
||||
}
|
||||
|
||||
type staleWorkerEventTestStore struct {
|
||||
remote *staleWorkerEventRemoteSessions
|
||||
attachments *staleWorkerEventAttachments
|
||||
policies *staleWorkerEventPolicies
|
||||
audit *staleWorkerEventAudit
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventTestStore) RemoteSessions() RemoteSessionRepository { return s.remote }
|
||||
func (s *staleWorkerEventTestStore) SessionAttachments() SessionAttachmentRepository {
|
||||
return s.attachments
|
||||
}
|
||||
func (s *staleWorkerEventTestStore) ResourcePolicies() ResourcePolicyRepository { return s.policies }
|
||||
func (s *staleWorkerEventTestStore) ResourceRuntime() ResourceRuntimeRepository {
|
||||
return staleWorkerEventResourceRuntime{}
|
||||
}
|
||||
func (s *staleWorkerEventTestStore) AuditEvents() AuditEventRepository { return s.audit }
|
||||
func (s *staleWorkerEventTestStore) Access() AccessRepository { return staleWorkerEventAccess{} }
|
||||
|
||||
type staleWorkerEventRemoteSessions struct {
|
||||
sessions map[string]RemoteSession
|
||||
updateCount int
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventRemoteSessions) Create(_ context.Context, session RemoteSession) error {
|
||||
r.sessions[session.ID] = session
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventRemoteSessions) GetByID(_ context.Context, sessionID string) (*RemoteSession, error) {
|
||||
session, ok := r.sessions[sessionID]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventRemoteSessions) GetByIDForUpdate(ctx context.Context, sessionID string) (*RemoteSession, error) {
|
||||
return r.GetByID(ctx, sessionID)
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventRemoteSessions) ListByController(_ context.Context, _ string) ([]RemoteSession, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventRemoteSessions) CountLiveByResource(_ context.Context, _ string) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventRemoteSessions) ListDetachedExpired(_ context.Context, _ time.Time, _ int) ([]RemoteSession, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventRemoteSessions) UpdateState(_ context.Context, params UpdateRemoteSessionStateParams) error {
|
||||
session := r.sessions[params.RemoteSessionID]
|
||||
session.State = params.State
|
||||
session.WorkerID = params.WorkerID
|
||||
session.DetachDeadlineAt = params.DetachDeadlineAt
|
||||
session.LastHeartbeatAt = params.LastHeartbeatAt
|
||||
session.TakeoverVersion = params.TakeoverVersion
|
||||
session.UpdatedAt = params.UpdatedAt
|
||||
r.sessions[params.RemoteSessionID] = session
|
||||
r.updateCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
type staleWorkerEventAttachments struct {
|
||||
items map[string]SessionAttachment
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventAttachments) Create(_ context.Context, attachment SessionAttachment) error {
|
||||
r.items[attachment.ID] = attachment
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventAttachments) GetByID(_ context.Context, attachmentID string) (*SessionAttachment, error) {
|
||||
attachment, ok := r.items[attachmentID]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
return &attachment, nil
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventAttachments) GetByIDForUpdate(ctx context.Context, attachmentID string) (*SessionAttachment, error) {
|
||||
return r.GetByID(ctx, attachmentID)
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventAttachments) ListByRemoteSession(_ context.Context, remoteSessionID string) ([]SessionAttachment, error) {
|
||||
attachments := make([]SessionAttachment, 0)
|
||||
for _, attachment := range r.items {
|
||||
if attachment.RemoteSessionID == remoteSessionID {
|
||||
attachments = append(attachments, attachment)
|
||||
}
|
||||
}
|
||||
return attachments, nil
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventAttachments) ListActiveByRemoteSessionForUpdate(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error) {
|
||||
return r.ListByRemoteSession(ctx, remoteSessionID)
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventAttachments) UpdateState(_ context.Context, params UpdateSessionAttachmentStateParams) error {
|
||||
attachment := r.items[params.AttachmentID]
|
||||
attachment.State = params.State
|
||||
attachment.DetachedAt = params.DetachedAt
|
||||
attachment.LastInputAt = params.LastInputAt
|
||||
attachment.UpdatedAt = params.UpdatedAt
|
||||
r.items[params.AttachmentID] = attachment
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventAttachments) Supersede(_ context.Context, params SupersedeAttachmentParams) error {
|
||||
attachment := r.items[params.PreviousAttachmentID]
|
||||
attachment.State = AttachmentStateSuperseded
|
||||
attachment.SupersededBy = ¶ms.NextAttachmentID
|
||||
attachment.DetachedAt = ¶ms.DetachedAt
|
||||
attachment.UpdatedAt = params.UpdatedAt
|
||||
r.items[params.PreviousAttachmentID] = attachment
|
||||
return nil
|
||||
}
|
||||
|
||||
type staleWorkerEventPolicies struct{}
|
||||
|
||||
func (r *staleWorkerEventPolicies) GetByResourceID(_ context.Context, _ string) (*ResourcePolicy, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventPolicies) Upsert(_ context.Context, _ ResourcePolicy) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type staleWorkerEventAudit struct {
|
||||
createCount int
|
||||
}
|
||||
|
||||
func (r *staleWorkerEventAudit) Create(_ context.Context, _ AuditEvent) error {
|
||||
r.createCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
type staleWorkerEventResourceRuntime struct{}
|
||||
|
||||
func (staleWorkerEventResourceRuntime) GetByID(_ context.Context, _ string) (*ResourceRuntimeSpec, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type staleWorkerEventAccess struct{}
|
||||
|
||||
func (staleWorkerEventAccess) IsTrustedDevice(_ context.Context, _, _ string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (staleWorkerEventAccess) GetPlatformRole(_ context.Context, _ string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (staleWorkerEventAccess) GetOrganizationRole(_ context.Context, _, _ string) (string, bool, error) {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
type staleWorkerEventLiveState struct {
|
||||
sessions map[string]*LiveSessionState
|
||||
upsertCount int
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) UpsertSession(_ context.Context, state LiveSessionState) error {
|
||||
copied := state
|
||||
s.sessions[state.SessionID] = &copied
|
||||
s.upsertCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) GetSession(_ context.Context, sessionID string) (*LiveSessionState, error) {
|
||||
state := s.sessions[sessionID]
|
||||
if state == nil {
|
||||
return nil, nil
|
||||
}
|
||||
copied := *state
|
||||
return &copied, nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) DeleteSession(_ context.Context, sessionID string) error {
|
||||
delete(s.sessions, sessionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) BindController(_ context.Context, _ sessioncontracts.ControllerBinding, _ time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) GetControllerBinding(_ context.Context, _ string) (*sessioncontracts.ControllerBinding, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) ClearControllerBinding(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) StoreAttachToken(_ context.Context, _ sessioncontracts.AttachTokenClaims, _ time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) ConsumeAttachToken(_ context.Context, _ string) (*sessioncontracts.AttachTokenClaims, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) TouchAttachmentHeartbeat(_ context.Context, _, _ string, _ time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) UpdateWorkerRoute(_ context.Context, _ WorkerRoute, _ time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) GetWorkerRoute(_ context.Context, _ string) (*WorkerRoute, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *staleWorkerEventLiveState) DeleteWorkerRoute(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type staleWorkerEventOrchestrator struct {
|
||||
releaseCount int
|
||||
}
|
||||
|
||||
func (o *staleWorkerEventOrchestrator) Reserve(_ context.Context, _ workercontracts.AttachRequest) (*workercontracts.WorkerLease, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (o *staleWorkerEventOrchestrator) GetSessionLease(_ context.Context, _ string) (*workercontracts.WorkerLease, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (o *staleWorkerEventOrchestrator) ReleaseSessionLease(_ context.Context, _ string) error {
|
||||
o.releaseCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *staleWorkerEventOrchestrator) PrepareAttachment(_ context.Context, _ RemoteSession, _ SessionAttachment, _ map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *staleWorkerEventOrchestrator) NotifyDetachment(_ context.Context, _ RemoteSession, _ SessionAttachment) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *staleWorkerEventOrchestrator) TerminateRemoteSession(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *staleWorkerEventOrchestrator) ValidateSessionRuntime(_ context.Context, _, _ string) (bool, string, error) {
|
||||
return true, "", nil
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package sessionbroker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
|
||||
)
|
||||
|
||||
var allowedTransitions = map[sessioncontracts.State]map[sessioncontracts.State]struct{}{
|
||||
sessioncontracts.StateStarting: {
|
||||
sessioncontracts.StateActive: {},
|
||||
sessioncontracts.StateFailed: {},
|
||||
sessioncontracts.StateTerminated: {},
|
||||
},
|
||||
sessioncontracts.StateActive: {
|
||||
sessioncontracts.StateDetached: {},
|
||||
sessioncontracts.StateReconnecting: {},
|
||||
sessioncontracts.StateFailed: {},
|
||||
sessioncontracts.StateTerminated: {},
|
||||
},
|
||||
sessioncontracts.StateDetached: {
|
||||
sessioncontracts.StateReconnecting: {},
|
||||
sessioncontracts.StateTerminated: {},
|
||||
sessioncontracts.StateFailed: {},
|
||||
},
|
||||
sessioncontracts.StateReconnecting: {
|
||||
sessioncontracts.StateActive: {},
|
||||
sessioncontracts.StateDetached: {},
|
||||
sessioncontracts.StateFailed: {},
|
||||
sessioncontracts.StateTerminated: {},
|
||||
},
|
||||
}
|
||||
|
||||
func validateTransition(from, to sessioncontracts.State) error {
|
||||
if from == to {
|
||||
return nil
|
||||
}
|
||||
if allowed, ok := allowedTransitions[from]; ok {
|
||||
if _, ok := allowed[to]; ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("invalid session state transition: %s -> %s", from, to)
|
||||
}
|
||||
Reference in New Issue
Block a user