Files
rdp-proxy/backend/internal/modules/sessionbroker/service.go
T
2026-04-28 22:29:50 +03:00

1961 lines
61 KiB
Go

package sessionbroker
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
"time"
"github.com/google/uuid"
"github.com/example/remote-access-platform/backend/internal/platform/module"
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
)
// Service implements the session broker use cases. It combines durable
// session records (store, transactor), live session state (liveState), and
// worker lease management (orchestrator).
type Service struct {
	cfg          module.Config
	logger       *slog.Logger
	store        Store
	transactor   Transactor
	liveState    LiveStateStore
	orchestrator WorkerOrchestrator
	// secretResolver is optional; NewService leaves it nil when no resolver is passed.
	secretResolver secrets.ResourceSecretResolver
	// now is the clock used for every timestamp; NewService sets it to time.Now.
	now func() time.Time
}
// StartRemoteSessionCommand is the input to StartRemoteSession: the resource
// to connect to and the user and device requesting the session.
type StartRemoteSessionCommand struct {
	ResourceID string `json:"resource_id"`
	UserID     string `json:"user_id"`
	DeviceID   string `json:"device_id"`
}
// AttachToSessionCommand is the input to AttachToSession: the session to
// attach to and the user and device that become its controller.
type AttachToSessionCommand struct {
	SessionID string `json:"session_id"`
	UserID    string `json:"user_id"`
	DeviceID  string `json:"device_id"`
}
// DetachFromSessionCommand is the input to DetachFromSession. AttachmentID is
// expected to belong to SessionID and UserID; Reason is recorded in the
// detach audit event.
type DetachFromSessionCommand struct {
	SessionID    string `json:"session_id"`
	AttachmentID string `json:"attachment_id"`
	UserID       string `json:"user_id"`
	Reason       string `json:"reason"`
}
// TakeoverSessionCommand is the input to TakeoverSession: the session, the
// user and device taking control, and a Reason recorded in the takeover
// audit event.
type TakeoverSessionCommand struct {
	SessionID string `json:"session_id"`
	UserID    string `json:"user_id"`
	DeviceID  string `json:"device_id"`
	Reason    string `json:"reason"`
}
// TerminateSessionCommand is the input to TerminateSession. UserID is the
// acting user, subject to the admin-or-controller access check; Reason is
// recorded in the termination audit event.
type TerminateSessionCommand struct {
	SessionID string `json:"session_id"`
	UserID    string `json:"user_id"`
	Reason    string `json:"reason"`
}
// MarkSessionFailedCommand is the input to MarkSessionFailed. It carries no
// actor because failure marking is performed without an access check.
type MarkSessionFailedCommand struct {
	SessionID string `json:"session_id"`
	Reason    string `json:"reason"`
}
// SessionControlResult is returned by the start, attach, takeover and detach
// operations. Session and Attachment reflect the post-operation records.
// AttachToken comes from publishing live state (not set by detach).
// GatewayURL and DataPlane are presumably filled by attachDataPlaneOffer,
// which is not visible here — confirm.
type SessionControlResult struct {
	Session     RemoteSession                        `json:"session"`
	Attachment  *SessionAttachment                   `json:"attachment,omitempty"`
	AttachToken *sessioncontracts.AttachTokenClaims `json:"attach_token,omitempty"`
	GatewayURL  string                               `json:"gateway_url,omitempty"`
	DataPlane   *sessioncontracts.DataPlaneOffer     `json:"data_plane,omitempty"`
}
// NewService wires a Service from its dependencies. The optional trailing
// secretResolvers argument supplies a resource secret resolver; only the
// first one is used, and the resolver stays nil when none is given. The
// service clock defaults to time.Now.
func NewService(deps module.Dependencies, store Store, transactor Transactor, liveState LiveStateStore, orchestrator WorkerOrchestrator, secretResolvers ...secrets.ResourceSecretResolver) *Service {
	svc := &Service{
		cfg:          deps.Config,
		logger:       deps.Infra.Logger,
		store:        store,
		transactor:   transactor,
		liveState:    liveState,
		orchestrator: orchestrator,
		now:          time.Now,
	}
	// Take the first resolver, if any; extra resolvers are ignored.
	for _, resolver := range secretResolvers {
		svc.secretResolver = resolver
		break
	}
	return svc
}
// StartRemoteSession opens a new controller session on a resource for the
// given user and device. If findReusableSession returns an existing session,
// it delegates to AttachToSession instead.
//
// Steps:
//  1. Load the resource, check organization membership and secret readiness.
//  2. Reuse an existing session when one is found.
//  3. Enforce the trusted-device and session-capacity policy checks.
//  4. Reserve a worker lease, then persist the session, its controller
//     attachment and the started/attached audit events in one transaction.
//  5. Prepare the worker attachment, publish live state (which yields the
//     attach token) and attach the data-plane offer.
//
// If the transaction fails, the reserved lease is released on a best-effort
// basis. A release failure is logged and the transaction error is returned.
func (s *Service) StartRemoteSession(ctx context.Context, cmd StartRemoteSessionCommand) (*SessionControlResult, error) {
	resource, err := s.store.ResourceRuntime().GetByID(ctx, cmd.ResourceID)
	if err != nil {
		return nil, err
	}
	if resource == nil {
		return nil, ErrSessionNotFound
	}
	if err := s.ensureOrganizationMemberAccess(ctx, resource.OrganizationID, cmd.UserID); err != nil {
		return nil, err
	}
	if err := secrets.ValidateResourceSecretReadiness(resource.Protocol, resource.SecretRef, resource.Metadata, s.cfg.App.Env); err != nil {
		return nil, err
	}
	// Prefer re-attaching to an existing session over starting a second one.
	if reusable, err := s.findReusableSession(ctx, cmd.ResourceID, cmd.UserID); err != nil {
		return nil, err
	} else if reusable != nil {
		return s.AttachToSession(ctx, AttachToSessionCommand{
			SessionID: reusable.ID,
			UserID:    cmd.UserID,
			DeviceID:  cmd.DeviceID,
		})
	}
	policy, err := s.loadPolicy(ctx, cmd.ResourceID)
	if err != nil {
		return nil, err
	}
	if err := s.ensureTrustedDevice(ctx, policy, cmd.UserID, cmd.DeviceID); err != nil {
		return nil, err
	}
	if err := s.ensureSessionCapacity(ctx, policy, cmd.ResourceID); err != nil {
		return nil, err
	}
	resourceMetadata := decodeJSONMap(resource.Metadata)
	// Certificate verification mode: the resource's own setting wins, then a
	// metadata override, and finally the "strict" default.
	resourceCertificateVerificationMode := resource.CertificateVerificationMode
	if resourceCertificateVerificationMode == "" {
		if mode, ok := resourceMetadata["certificate_verification_mode"].(string); ok && mode != "" {
			resourceCertificateVerificationMode = mode
		} else {
			resourceCertificateVerificationMode = "strict"
		}
	}
	renderQualityProfile := renderQualityProfileFromMetadata(resourceMetadata)
	// The session ID is generated before the row exists so the worker lease
	// can be reserved against it.
	sessionID := uuid.NewString()
	lease, err := s.orchestrator.Reserve(ctx, workercontracts.AttachRequest{
		ResourceID:           cmd.ResourceID,
		ActorID:              cmd.UserID,
		SessionID:            sessionID,
		RequiredCapabilities: requiredCapabilities(policy),
		RenderQualityProfile: renderQualityProfile,
	})
	if err != nil {
		return nil, err
	}
	now := s.now().UTC()
	attachmentID := uuid.NewString()
	var result SessionControlResult
	err = s.transactor.WithinTransaction(ctx, func(store Store) error {
		// Snapshot of the resource and policy at start time, kept on the
		// session record.
		sessionMetadata, err := json.Marshal(map[string]any{
			"resource": map[string]any{
				"id":                            resource.ID,
				"organization_id":               resource.OrganizationID,
				"name":                          resource.Name,
				"address":                       resource.Address,
				"protocol":                      resource.Protocol,
				"secret_ref":                    resource.SecretRef,
				"certificate_verification_mode": resourceCertificateVerificationMode,
				"render_quality_profile":        renderQualityProfile,
				"metadata":                      resourceMetadata,
			},
			"policy": map[string]any{
				"detach_grace_period_seconds": int(policy.DetachGracePeriod.Seconds()),
				"clipboard_enabled":           policy.ClipboardMode != ResourceClipboardModeDisabled,
				"clipboard_mode":              policy.ClipboardMode,
				"file_transfer_enabled":       fileTransferAllowsClientToServer(policy.FileTransferMode),
				"file_transfer_mode":          policy.FileTransferMode,
			},
		})
		if err != nil {
			return err
		}
		session := RemoteSession{
			ID:                   sessionID,
			OrganizationID:       resource.OrganizationID,
			ResourceID:           cmd.ResourceID,
			Protocol:             string(lease.Protocol),
			State:                sessioncontracts.StateStarting,
			WorkerID:             lease.WorkerID,
			ControllerUserID:     cmd.UserID,
			TakeoverVersion:      1,
			RenderQualityProfile: renderQualityProfile,
			Metadata:             sessionMetadata,
			LastHeartbeatAt:      &now,
			CreatedAt:            now,
			UpdatedAt:            now,
		}
		attachment := SessionAttachment{
			ID:              attachmentID,
			RemoteSessionID: sessionID,
			UserID:          cmd.UserID,
			DeviceID:        cmd.DeviceID,
			Role:            AttachmentRoleController,
			State:           AttachmentStateAttaching,
			AttachedAt:      &now,
			LastInputAt:     &now,
			Metadata:        []byte(`{}`),
			CreatedAt:       now,
			UpdatedAt:       now,
		}
		if err := store.RemoteSessions().Create(ctx, session); err != nil {
			return err
		}
		if err := store.SessionAttachments().Create(ctx, attachment); err != nil {
			return err
		}
		if err := s.writeAuditEvent(ctx, store, AuditEventSessionStarted, cmd.UserID, cmd.DeviceID, session.ID, session.ID, map[string]any{
			"resource_id": cmd.ResourceID,
			"worker_id":   lease.WorkerID,
		}); err != nil {
			return err
		}
		if err := s.writeAuditEvent(ctx, store, AuditEventSessionAttached, cmd.UserID, cmd.DeviceID, attachment.ID, session.ID, map[string]any{
			"takeover": false,
		}); err != nil {
			return err
		}
		result.Session = session
		result.Attachment = &attachment
		return nil
	})
	if err != nil {
		// Best-effort cleanup: without a persisted session the reservation is
		// useless. The transaction error is still what the caller gets, but a
		// failed release must not disappear silently.
		if releaseErr := s.orchestrator.ReleaseSessionLease(ctx, sessionID); releaseErr != nil {
			s.logger.Warn("session lease release failed after session creation error",
				"session_id", sessionID,
				"resource_id", cmd.ResourceID,
				"error", releaseErr)
		}
		return nil, err
	}
	// NOTE(review): failures below happen after the session row is committed.
	// They leave the session in the starting state with its lease reserved.
	// Confirm that the runtime-consistency checks reclaim such sessions.
	if err := s.prepareAttachment(ctx, result.Session, *result.Attachment, lease); err != nil {
		return nil, err
	}
	attachToken, err := s.publishLiveState(ctx, result.Session, *result.Attachment, lease)
	if err != nil {
		return nil, err
	}
	result.AttachToken = attachToken
	if err := s.attachDataPlaneOffer(&result); err != nil {
		return nil, err
	}
	return &result, nil
}
// AttachToSession attaches cmd.DeviceID as the controller of an existing
// session on behalf of cmd.UserID.
//
// Inside one transaction (session row and its active attachments locked for
// update):
//   - organization membership and trusted-device policy are checked;
//   - any attachment returned by ListActiveByRemoteSessionForUpdate that
//     belongs to a different user or device yields ErrActiveControllerPresent;
//   - if the caller's own attachment is already active on an active session,
//     it is returned as-is (existingAttachment) and nothing is written;
//   - otherwise a new active attachment is created, the caller's previous
//     attachment (if any) is superseded by it, and the session is moved to
//     active, passing through reconnecting when it was detached.
//
// After commit, a newly created attachment gets worker preparation using the
// orchestrator's lease for the session, or a lease synthesized from the
// session record when the orchestrator has none. Live state is then
// published, with render telemetry from the previous live entry merged back
// in, and the data-plane offer is attached.
func (s *Service) AttachToSession(ctx context.Context, cmd AttachToSessionCommand) (*SessionControlResult, error) {
	if _, err := s.ensureSessionRuntimeConsistencyByID(ctx, cmd.SessionID); err != nil {
		return nil, err
	}
	now := s.now().UTC()
	var result SessionControlResult
	// Set only when the caller already holds an active attachment; it
	// suppresses worker preparation after the transaction.
	var existingAttachment *SessionAttachment
	err := s.transactor.WithinTransaction(ctx, func(store Store) error {
		session, err := store.RemoteSessions().GetByIDForUpdate(ctx, cmd.SessionID)
		if err != nil {
			return err
		}
		if session == nil {
			return ErrSessionNotFound
		}
		if err := s.ensureOrganizationMemberAccessWithStore(ctx, store, session.OrganizationID, cmd.UserID); err != nil {
			return err
		}
		policy, err := s.loadPolicyFromStore(ctx, store, session.ResourceID)
		if err != nil {
			return err
		}
		if err := s.ensureTrustedDeviceWithStore(ctx, store, policy, cmd.UserID, cmd.DeviceID); err != nil {
			return err
		}
		activeAttachments, err := store.SessionAttachments().ListActiveByRemoteSessionForUpdate(ctx, cmd.SessionID)
		if err != nil {
			return err
		}
		var prior *SessionAttachment
		for _, attachment := range activeAttachments {
			// Another user or device currently controls the session.
			if attachment.UserID != cmd.UserID || attachment.DeviceID != cmd.DeviceID {
				return ErrActiveControllerPresent
			}
			// Idempotent re-attach: return the caller's live attachment as-is.
			if attachment.State == AttachmentStateActive && session.State == sessioncontracts.StateActive {
				copy := attachment
				existingAttachment = &copy
				// In-memory only; this path does not persist the heartbeat.
				session.LastHeartbeatAt = &now
				result.Session = *session
				result.Attachment = existingAttachment
				return nil
			}
			// NOTE(review): if several matching attachments exist, only the
			// last one is remembered and superseded below.
			copy := attachment
			prior = &copy
		}
		targetState := sessioncontracts.StateActive
		if session.State == sessioncontracts.StateDetached {
			targetState = sessioncontracts.StateReconnecting
		}
		// An already-active session is accepted even if validateTransition
		// rejects Active -> Active.
		if err := validateTransition(session.State, targetState); err != nil && session.State != sessioncontracts.StateActive {
			return err
		}
		if targetState == sessioncontracts.StateReconnecting {
			// Record the intermediate reconnecting state before it is moved to
			// active further down in the same transaction.
			if err := store.RemoteSessions().UpdateState(ctx, UpdateRemoteSessionStateParams{
				RemoteSessionID: session.ID,
				State:           sessioncontracts.StateReconnecting,
				WorkerID:        session.WorkerID,
				TakeoverVersion: session.TakeoverVersion,
				UpdatedAt:       now,
			}); err != nil {
				return err
			}
			session.State = sessioncontracts.StateReconnecting
		}
		attachment := SessionAttachment{
			ID:              uuid.NewString(),
			RemoteSessionID: session.ID,
			UserID:          cmd.UserID,
			DeviceID:        cmd.DeviceID,
			Role:            AttachmentRoleController,
			State:           AttachmentStateActive,
			AttachedAt:      &now,
			LastInputAt:     &now,
			CreatedAt:       now,
			UpdatedAt:       now,
			Metadata:        []byte(`{}`),
		}
		if err := store.SessionAttachments().Create(ctx, attachment); err != nil {
			return err
		}
		if prior != nil {
			// TakeoverOf is set on the returned copy after Create has already
			// persisted the row; the link is recorded in storage via Supersede.
			attachment.TakeoverOf = &prior.ID
			if err := store.SessionAttachments().Supersede(ctx, SupersedeAttachmentParams{
				PreviousAttachmentID: prior.ID,
				NextAttachmentID:     attachment.ID,
				DetachedAt:           now,
				UpdatedAt:            now,
			}); err != nil {
				return err
			}
		}
		if err := store.RemoteSessions().UpdateState(ctx, UpdateRemoteSessionStateParams{
			RemoteSessionID:  session.ID,
			State:            sessioncontracts.StateActive,
			WorkerID:         session.WorkerID,
			DetachDeadlineAt: nil,
			LastHeartbeatAt:  &now,
			TakeoverVersion:  session.TakeoverVersion,
			UpdatedAt:        now,
		}); err != nil {
			return err
		}
		if err := s.writeAuditEvent(ctx, store, AuditEventSessionAttached, cmd.UserID, cmd.DeviceID, attachment.ID, session.ID, map[string]any{
			"reconnect": prior != nil,
		}); err != nil {
			return err
		}
		session.State = sessioncontracts.StateActive
		session.DetachDeadlineAt = nil
		session.LastHeartbeatAt = &now
		session.UpdatedAt = now
		result.Session = *session
		result.Attachment = &attachment
		return nil
	})
	if err != nil {
		return nil, err
	}
	// Fallback lease synthesized from the session record; replaced by the
	// orchestrator's lease when one exists.
	lease := &workercontracts.WorkerLease{
		WorkerID:             result.Session.WorkerID,
		SessionID:            result.Session.ID,
		Protocol:             workercontracts.Protocol(result.Session.Protocol),
		ExpiresAt:            now.Add(s.cfg.Worker.LeaseTTL),
		ControlStream:        "worker://control/" + result.Session.WorkerID,
		RenderQualityProfile: result.Session.RenderQualityProfile,
	}
	if existingAttachment == nil {
		actualLease, err := s.orchestrator.GetSessionLease(ctx, result.Session.ID)
		if err != nil {
			return nil, err
		}
		if actualLease != nil {
			lease = actualLease
		}
		if err := s.prepareAttachment(ctx, result.Session, *result.Attachment, lease); err != nil {
			return nil, err
		}
	} else {
		// Existing attachment: no worker preparation; live state is published
		// with a nil lease.
		lease = nil
	}
	// Snapshot the live entry before publishing so its render telemetry can
	// be merged into the freshly published state below.
	currentLive, err := s.liveState.GetSession(ctx, result.Session.ID)
	if err != nil {
		return nil, err
	}
	attachToken, err := s.publishLiveState(ctx, result.Session, *result.Attachment, lease)
	if err != nil {
		return nil, err
	}
	if currentLive != nil {
		live, liveErr := s.liveState.GetSession(ctx, result.Session.ID)
		if liveErr != nil {
			return nil, liveErr
		}
		if live != nil {
			mergeLiveRenderTelemetry(live, currentLive)
			if err := s.liveState.UpsertSession(ctx, *live); err != nil {
				return nil, err
			}
		}
	}
	result.AttachToken = attachToken
	if err := s.attachDataPlaneOffer(&result); err != nil {
		return nil, err
	}
	return &result, nil
}
// DetachFromSession detaches the caller's attachment from a session and
// moves the session to detached with a deadline of now +
// cfg.Session.DetachGracePeriod.
//
// The caller must pass the session-controller access check, and the
// attachment must belong to cmd.SessionID and cmd.UserID; otherwise
// ErrAttachmentNotFound is returned. Detaching an already-detached session or
// attachment is idempotent: the current records are returned unchanged. A
// session that cannot transition to detached yields ErrSessionNotAttachable.
//
// After commit, the worker is notified (best effort, failures are logged),
// the controller binding is cleared, and a detached live-state entry is
// written with existing render telemetry carried over.
func (s *Service) DetachFromSession(ctx context.Context, cmd DetachFromSessionCommand) (*SessionControlResult, error) {
	now := s.now().UTC()
	var result SessionControlResult
	err := s.transactor.WithinTransaction(ctx, func(store Store) error {
		session, err := store.RemoteSessions().GetByIDForUpdate(ctx, cmd.SessionID)
		if err != nil {
			return err
		}
		if session == nil {
			return ErrSessionNotFound
		}
		if err := s.ensureSessionControllerAccessWithStore(ctx, store, session, cmd.UserID); err != nil {
			return err
		}
		if session.State == sessioncontracts.StateDetached {
			// Repeat detach. The attachment must still belong to this session
			// and caller. Otherwise an attachment from another session or user
			// would be echoed back, and its ID written into this session's live
			// state after commit.
			attachment, err := store.SessionAttachments().GetByID(ctx, cmd.AttachmentID)
			if err != nil {
				return err
			}
			if attachment == nil || attachment.RemoteSessionID != cmd.SessionID || attachment.UserID != cmd.UserID {
				return ErrAttachmentNotFound
			}
			result.Session = *session
			result.Attachment = attachment
			return nil
		}
		if err := validateTransition(session.State, sessioncontracts.StateDetached); err != nil {
			return ErrSessionNotAttachable
		}
		attachment, err := store.SessionAttachments().GetByIDForUpdate(ctx, cmd.AttachmentID)
		if err != nil {
			return err
		}
		if attachment == nil || attachment.RemoteSessionID != cmd.SessionID || attachment.UserID != cmd.UserID {
			return ErrAttachmentNotFound
		}
		if attachment.State == AttachmentStateDetached {
			result.Session = *session
			result.Attachment = attachment
			return nil
		}
		detachedAt := now
		if err := store.SessionAttachments().UpdateState(ctx, UpdateSessionAttachmentStateParams{
			AttachmentID: attachment.ID,
			State:        AttachmentStateDetached,
			DetachedAt:   &detachedAt,
			LastInputAt:  attachment.LastInputAt,
			UpdatedAt:    now,
		}); err != nil {
			return err
		}
		// NOTE(review): the grace period comes from service config, not from
		// the resource policy's DetachGracePeriod recorded in the session
		// metadata at start — confirm this is intended.
		deadline := now.Add(s.cfg.Session.DetachGracePeriod)
		if err := store.RemoteSessions().UpdateState(ctx, UpdateRemoteSessionStateParams{
			RemoteSessionID:  session.ID,
			State:            sessioncontracts.StateDetached,
			WorkerID:         session.WorkerID,
			DetachDeadlineAt: &deadline,
			LastHeartbeatAt:  session.LastHeartbeatAt,
			TakeoverVersion:  session.TakeoverVersion,
			UpdatedAt:        now,
		}); err != nil {
			return err
		}
		if err := s.writeAuditEvent(ctx, store, AuditEventSessionDetached, cmd.UserID, attachment.DeviceID, attachment.ID, session.ID, map[string]any{
			"reason": cmd.Reason,
		}); err != nil {
			return err
		}
		attachment.State = AttachmentStateDetached
		attachment.DetachedAt = &detachedAt
		session.State = sessioncontracts.StateDetached
		session.DetachDeadlineAt = &deadline
		session.UpdatedAt = now
		result.Session = *session
		result.Attachment = attachment
		return nil
	})
	if err != nil {
		return nil, err
	}
	// Best effort: the durable detach is already committed, so a notification
	// failure must not fail the call, but it should be visible.
	if notifyErr := s.orchestrator.NotifyDetachment(ctx, result.Session, *result.Attachment); notifyErr != nil {
		s.logger.Warn("worker detachment notification failed",
			"session_id", result.Session.ID,
			"attachment_id", result.Attachment.ID,
			"error", notifyErr)
	}
	if err := s.liveState.ClearControllerBinding(ctx, result.Session.ID); err != nil {
		return nil, err
	}
	currentLive, err := s.liveState.GetSession(ctx, result.Session.ID)
	if err != nil {
		return nil, err
	}
	detachedLive := LiveSessionState{
		SessionID:       result.Session.ID,
		ResourceID:      result.Session.ResourceID,
		WorkerID:        result.Session.WorkerID,
		State:           result.Session.State,
		ControllerID:    result.Session.ControllerUserID,
		AttachmentID:    result.Attachment.ID,
		TakeoverVersion: result.Session.TakeoverVersion,
		UpdatedAt:       now,
	}
	mergeLiveRenderTelemetry(&detachedLive, currentLive)
	if err := s.liveState.UpsertSession(ctx, detachedLive); err != nil {
		return nil, err
	}
	return &result, nil
}
// TakeoverSession makes cmd.DeviceID, on behalf of cmd.UserID, the controller
// of a session that currently has active attachments. All of those
// attachments are superseded and TakeoverVersion is incremented.
//
// Access is gated by organization membership and ensureTakeoverAllowed. If
// the caller's own device already holds an active attachment, that
// attachment is returned unchanged and no worker preparation happens. A
// session with no active attachments yields ErrSessionNotAttachable.
//
// After commit, a new attachment is prepared on the worker. The lease used is
// the orchestrator's lease for the session or, when it has none, one
// synthesized from the session record, including its RenderQualityProfile
// just like AttachToSession. Live state is then published with previous
// render telemetry merged back in, and the data-plane offer is attached.
func (s *Service) TakeoverSession(ctx context.Context, cmd TakeoverSessionCommand) (*SessionControlResult, error) {
	if _, err := s.ensureSessionRuntimeConsistencyByID(ctx, cmd.SessionID); err != nil {
		return nil, err
	}
	now := s.now().UTC()
	var result SessionControlResult
	// Set only when the caller already controls the session; it suppresses
	// worker preparation after the transaction.
	var existingAttachment *SessionAttachment
	err := s.transactor.WithinTransaction(ctx, func(store Store) error {
		session, err := store.RemoteSessions().GetByIDForUpdate(ctx, cmd.SessionID)
		if err != nil {
			return err
		}
		if session == nil {
			return ErrSessionNotFound
		}
		if err := s.ensureOrganizationMemberAccessWithStore(ctx, store, session.OrganizationID, cmd.UserID); err != nil {
			return err
		}
		policy, err := s.loadPolicyFromStore(ctx, store, session.ResourceID)
		if err != nil {
			return err
		}
		if err := s.ensureTakeoverAllowed(ctx, store, policy, session, cmd.UserID, cmd.DeviceID); err != nil {
			return err
		}
		activeAttachments, err := store.SessionAttachments().ListActiveByRemoteSessionForUpdate(ctx, cmd.SessionID)
		if err != nil {
			return err
		}
		if len(activeAttachments) == 0 {
			return ErrSessionNotAttachable
		}
		for _, attachment := range activeAttachments {
			if attachment.UserID == cmd.UserID && attachment.DeviceID == cmd.DeviceID && attachment.State == AttachmentStateActive {
				current := attachment
				existingAttachment = &current
				result.Session = *session
				result.Attachment = existingAttachment
				return nil
			}
		}
		attachment := SessionAttachment{
			ID:              uuid.NewString(),
			RemoteSessionID: session.ID,
			UserID:          cmd.UserID,
			DeviceID:        cmd.DeviceID,
			Role:            AttachmentRoleController,
			State:           AttachmentStateActive,
			AttachedAt:      &now,
			LastInputAt:     &now,
			CreatedAt:       now,
			UpdatedAt:       now,
			Metadata:        []byte(`{}`),
		}
		if err := store.SessionAttachments().Create(ctx, attachment); err != nil {
			return err
		}
		// The returned copy points at the first superseded attachment; the
		// persisted links come from the Supersede calls below.
		firstPrevious := activeAttachments[0].ID
		attachment.TakeoverOf = &firstPrevious
		for _, prior := range activeAttachments {
			if err := store.SessionAttachments().Supersede(ctx, SupersedeAttachmentParams{
				PreviousAttachmentID: prior.ID,
				NextAttachmentID:     attachment.ID,
				DetachedAt:           now,
				UpdatedAt:            now,
			}); err != nil {
				return err
			}
		}
		nextVersion := session.TakeoverVersion + 1
		if err := store.RemoteSessions().UpdateState(ctx, UpdateRemoteSessionStateParams{
			RemoteSessionID:  session.ID,
			State:            sessioncontracts.StateActive,
			WorkerID:         session.WorkerID,
			DetachDeadlineAt: nil,
			LastHeartbeatAt:  &now,
			TakeoverVersion:  nextVersion,
			UpdatedAt:        now,
		}); err != nil {
			return err
		}
		// NOTE(review): ControllerUserID is updated in memory only. The
		// UpdateState call above carries no controller field — confirm the
		// store persists the new controller some other way.
		session.ControllerUserID = cmd.UserID
		session.TakeoverVersion = nextVersion
		session.State = sessioncontracts.StateActive
		session.DetachDeadlineAt = nil
		session.LastHeartbeatAt = &now
		session.UpdatedAt = now
		if err := s.writeAuditEvent(ctx, store, AuditEventSessionTakenOver, cmd.UserID, cmd.DeviceID, attachment.ID, session.ID, map[string]any{
			"reason": cmd.Reason,
		}); err != nil {
			return err
		}
		result.Session = *session
		result.Attachment = &attachment
		return nil
	})
	if err != nil {
		return nil, err
	}
	// Fallback lease synthesized from the session record. It carries the
	// session's render quality profile so worker preparation does not get an
	// empty profile when the orchestrator has no lease on record.
	lease := &workercontracts.WorkerLease{
		WorkerID:             result.Session.WorkerID,
		SessionID:            result.Session.ID,
		Protocol:             workercontracts.Protocol(result.Session.Protocol),
		ExpiresAt:            now.Add(s.cfg.Worker.LeaseTTL),
		ControlStream:        "worker://control/" + result.Session.WorkerID,
		RenderQualityProfile: result.Session.RenderQualityProfile,
	}
	if existingAttachment == nil {
		actualLease, err := s.orchestrator.GetSessionLease(ctx, result.Session.ID)
		if err != nil {
			return nil, err
		}
		if actualLease != nil {
			lease = actualLease
		}
		if err := s.prepareAttachment(ctx, result.Session, *result.Attachment, lease); err != nil {
			return nil, err
		}
	} else {
		// Existing attachment: no worker preparation; publish with a nil lease.
		lease = nil
	}
	// Snapshot the live entry before publishing so its render telemetry can be
	// merged into the freshly published state.
	currentLive, err := s.liveState.GetSession(ctx, result.Session.ID)
	if err != nil {
		return nil, err
	}
	attachToken, err := s.publishLiveState(ctx, result.Session, *result.Attachment, lease)
	if err != nil {
		return nil, err
	}
	if currentLive != nil {
		live, liveErr := s.liveState.GetSession(ctx, result.Session.ID)
		if liveErr != nil {
			return nil, liveErr
		}
		if live != nil {
			mergeLiveRenderTelemetry(live, currentLive)
			if err := s.liveState.UpsertSession(ctx, *live); err != nil {
				return nil, err
			}
		}
	}
	result.AttachToken = attachToken
	if err := s.attachDataPlaneOffer(&result); err != nil {
		return nil, err
	}
	return &result, nil
}
// TerminateSession finalizes a session in the terminated state on behalf of
// cmd.UserID, who must pass the admin-or-controller access check. On success
// it then releases the session's worker lease.
func (s *Service) TerminateSession(ctx context.Context, cmd TerminateSessionCommand) error {
	err := s.finalizeSession(ctx, cmd.SessionID, cmd.UserID, "", AuditEventSessionTerminated, cmd.Reason, sessioncontracts.StateTerminated, true)
	if err == nil {
		err = s.orchestrator.ReleaseSessionLease(ctx, cmd.SessionID)
	}
	return err
}
// MarkSessionFailed finalizes a session in the failed state without an
// acting user (no access check) and, on success, releases its worker lease.
func (s *Service) MarkSessionFailed(ctx context.Context, cmd MarkSessionFailedCommand) error {
	const (
		noActorUser   = ""
		noActorDevice = ""
		enforceActor  = false
	)
	err := s.finalizeSession(ctx, cmd.SessionID, noActorUser, noActorDevice, AuditEventSessionFailed, cmd.Reason, sessioncontracts.StateFailed, enforceActor)
	if err != nil {
		return err
	}
	return s.orchestrator.ReleaseSessionLease(ctx, cmd.SessionID)
}
// RecoverDetachedSessions terminates detached sessions whose detach deadline
// has passed, up to cfg.Session.RecoveryBatchSize per call. Sessions are
// selected by ListDetachedExpired.
//
// Recovery is system-initiated. Sessions are finalized without an acting
// user and without the actor access check, the same way MarkSessionFailed
// does it, and the audit reason is "detach_grace_period_expired". The worker
// lease is released after each successful finalization.
//
// A failure on one session does not stop the batch. Every session is
// attempted and the individual errors are returned joined (errors.Is still
// matches them). Sessions that were already finalized or removed
// concurrently (ErrSessionNotTerminable / ErrSessionNotFound) are skipped.
// The loop stops early if ctx is cancelled.
func (s *Service) RecoverDetachedSessions(ctx context.Context) error {
	sessions, err := s.store.RemoteSessions().ListDetachedExpired(ctx, s.now().UTC(), s.cfg.Session.RecoveryBatchSize)
	if err != nil {
		return err
	}
	var errs []error
	for _, session := range sessions {
		if ctxErr := ctx.Err(); ctxErr != nil {
			errs = append(errs, ctxErr)
			break
		}
		// NOTE(review): finalizeSession does not re-check that the session is
		// still detached and past its deadline. A user who re-attaches between
		// ListDetachedExpired and this call would have the session terminated.
		err := s.finalizeSession(ctx, session.ID, "", "", AuditEventSessionTerminated, "detach_grace_period_expired", sessioncontracts.StateTerminated, false)
		if err == nil {
			err = s.orchestrator.ReleaseSessionLease(ctx, session.ID)
		}
		switch {
		case err == nil:
		case errors.Is(err, ErrSessionNotTerminable), errors.Is(err, ErrSessionNotFound):
			// Finalized or deleted by a concurrent request; nothing to recover.
		default:
			errs = append(errs, fmt.Errorf("recover detached session %s: %w", session.ID, err))
		}
	}
	return errors.Join(errs...)
}
// ConsumeAttachToken redeems an attach token from live state and returns its
// claims together with the session's live state.
//
// A missing or expired token yields ErrAttachTokenInvalid. If the token is
// valid but no live session state can be resolved, ErrSessionNotFound is
// returned.
func (s *Service) ConsumeAttachToken(ctx context.Context, token string) (*sessioncontracts.AttachTokenClaims, *LiveSessionState, error) {
	claims, err := s.liveState.ConsumeAttachToken(ctx, token)
	if err != nil {
		return nil, nil, err
	}
	if claims == nil {
		return nil, nil, ErrAttachTokenInvalid
	}
	if expired := claims.ExpiresAt.Before(s.now().UTC()); expired {
		return nil, nil, ErrAttachTokenInvalid
	}
	live, err := s.GetLiveSession(ctx, claims.SessionID)
	switch {
	case err != nil:
		return nil, nil, err
	case live == nil:
		return nil, nil, ErrSessionNotFound
	default:
		return claims, live, nil
	}
}
// GetControllerBinding returns the controller binding that the live-state
// store holds for sessionID.
func (s *Service) GetControllerBinding(ctx context.Context, sessionID string) (*sessioncontracts.ControllerBinding, error) {
	binding, err := s.liveState.GetControllerBinding(ctx, sessionID)
	return binding, err
}
// GetSessionSnapshot returns the session record for sessionID after it has
// passed through ensureSessionRuntimeConsistencyByID.
func (s *Service) GetSessionSnapshot(ctx context.Context, sessionID string) (*RemoteSession, error) {
	snapshot, err := s.ensureSessionRuntimeConsistencyByID(ctx, sessionID)
	return snapshot, err
}
// TouchAttachmentHeartbeat records a client heartbeat for an attachment.
//
// It touches the attachment heartbeat with HeartbeatTTL. If this attachment
// is the bound controller, it re-binds the controller with LiveStateTTL. If a
// worker route exists, it re-writes the route with LiveStateTTL.
func (s *Service) TouchAttachmentHeartbeat(ctx context.Context, sessionID, attachmentID string) error {
	sessionCfg := s.cfg.Session
	if err := s.liveState.TouchAttachmentHeartbeat(ctx, sessionID, attachmentID, sessionCfg.HeartbeatTTL); err != nil {
		return err
	}
	binding, err := s.liveState.GetControllerBinding(ctx, sessionID)
	if err != nil {
		return err
	}
	isController := binding != nil && binding.AttachmentID == attachmentID
	if isController {
		if err := s.liveState.BindController(ctx, *binding, sessionCfg.LiveStateTTL); err != nil {
			return err
		}
	}
	route, err := s.liveState.GetWorkerRoute(ctx, sessionID)
	if err != nil || route == nil {
		return err
	}
	return s.liveState.UpdateWorkerRoute(ctx, *route, sessionCfg.LiveStateTTL)
}
// GetLiveSession returns the live state for sessionID.
//
// The session first goes through the runtime-consistency check. Live state
// already in the store is returned directly. Otherwise it is rebuilt from the
// durable store, unless the consistency check produced no session, in which
// case (nil, nil) is returned.
func (s *Service) GetLiveSession(ctx context.Context, sessionID string) (*LiveSessionState, error) {
	session, err := s.ensureSessionRuntimeConsistencyByID(ctx, sessionID)
	if err != nil {
		return nil, err
	}
	live, err := s.liveState.GetSession(ctx, sessionID)
	if err != nil {
		return live, err
	}
	if live != nil {
		return live, nil
	}
	if session == nil {
		return nil, nil
	}
	return s.RebuildLiveStateFromStore(ctx, sessionID)
}
// ListSessions returns the sessions whose controller is userID, after the
// runtime-consistency check and an organization-membership check.
//
// Sessions for which ensureSessionRuntimeConsistency yields nil are omitted,
// as are sessions in organizations userID can no longer access. Any error
// aborts the listing.
func (s *Service) ListSessions(ctx context.Context, userID string) ([]RemoteSession, error) {
	candidates, err := s.store.RemoteSessions().ListByController(ctx, userID)
	if err != nil {
		return nil, err
	}
	visible := make([]RemoteSession, 0, len(candidates))
	for _, candidate := range candidates {
		consistent, err := s.ensureSessionRuntimeConsistency(ctx, &candidate)
		if err != nil {
			return nil, err
		}
		if consistent == nil {
			continue
		}
		allowed, err := s.hasOrganizationMemberAccess(ctx, consistent.OrganizationID, userID)
		if err != nil {
			return nil, err
		}
		if !allowed {
			continue
		}
		visible = append(visible, *consistent)
	}
	return visible, nil
}
// finalizeSession moves a session into the terminal targetState (callers use
// terminated or failed) and tears down its worker and live state.
//
// Inside one transaction it:
//   - optionally (enforceActor) checks that actorUserID is an admin or the
//     controller;
//   - rejects invalid transitions with ErrSessionNotTerminable;
//   - closes every attachment that is not already closed or superseded;
//   - updates the session state and writes an eventType audit event carrying
//     reason.
//
// After commit, it asks the worker to terminate the session. This is best
// effort and failures are logged. It then clears the controller binding, the
// worker route and the live session entry, returning any error from those
// steps. The worker lease is not released here; callers do that.
func (s *Service) finalizeSession(ctx context.Context, sessionID, actorUserID, actorDeviceID, eventType, reason string, targetState sessioncontracts.State, enforceActor bool) error {
	now := s.now().UTC()
	// ID of the last open attachment (in ListByRemoteSession order), handed to
	// the worker termination call after commit.
	lastAttachmentID := ""
	if err := s.transactor.WithinTransaction(ctx, func(store Store) error {
		session, err := store.RemoteSessions().GetByIDForUpdate(ctx, sessionID)
		if err != nil {
			return err
		}
		if session == nil {
			return ErrSessionNotFound
		}
		if enforceActor {
			if err := s.ensureSessionAdminOrControllerAccessWithStore(ctx, store, session, actorUserID); err != nil {
				return err
			}
		}
		if err := validateTransition(session.State, targetState); err != nil {
			return ErrSessionNotTerminable
		}
		// All attachments of the session, not only active ones; already closed
		// or superseded ones are skipped below.
		attachments, err := store.SessionAttachments().ListByRemoteSession(ctx, sessionID)
		if err != nil {
			return err
		}
		for _, attachment := range attachments {
			if attachment.State == AttachmentStateClosed || attachment.State == AttachmentStateSuperseded {
				continue
			}
			lastAttachmentID = attachment.ID
			detachedAt := now
			if err := store.SessionAttachments().UpdateState(ctx, UpdateSessionAttachmentStateParams{
				AttachmentID: attachment.ID,
				State:        AttachmentStateClosed,
				DetachedAt:   &detachedAt,
				LastInputAt:  attachment.LastInputAt,
				UpdatedAt:    now,
			}); err != nil {
				return err
			}
		}
		if err := store.RemoteSessions().UpdateState(ctx, UpdateRemoteSessionStateParams{
			RemoteSessionID: session.ID,
			State:           targetState,
			WorkerID:        session.WorkerID,
			TakeoverVersion: session.TakeoverVersion,
			UpdatedAt:       now,
		}); err != nil {
			return err
		}
		return s.writeAuditEvent(ctx, store, eventType, actorUserID, actorDeviceID, session.ID, session.ID, map[string]any{
			"reason": reason,
		})
	}); err != nil {
		return err
	}
	if lastAttachmentID != "" {
		// Best effort: the terminal state is already committed, so a worker
		// teardown failure must not fail the call, but it must be visible.
		if termErr := s.orchestrator.TerminateRemoteSession(ctx, sessionID, lastAttachmentID); termErr != nil {
			s.logger.Warn("worker session termination failed",
				"session_id", sessionID,
				"attachment_id", lastAttachmentID,
				"target_state", targetState,
				"error", termErr)
		}
	}
	if err := s.liveState.ClearControllerBinding(ctx, sessionID); err != nil {
		return err
	}
	if err := s.liveState.DeleteWorkerRoute(ctx, sessionID); err != nil {
		return err
	}
	return s.liveState.DeleteSession(ctx, sessionID)
}
// HandleWorkerConnected processes a worker "connected" event for sessionID.
//
// Inside one transaction it moves the session to active (clearing the detach
// deadline and stamping the heartbeat) and promotes the first attaching or
// active attachment, in ListByRemoteSession order, to active. Events for
// sessions already in a terminal state are reported via logStaleWorkerEvent
// and otherwise ignored.
//
// After commit, existing live state only has its state and timestamp bumped.
// When there is none, live state is published from the session and attachment
// snapshots, which requires that an attachment was promoted.
func (s *Service) HandleWorkerConnected(ctx context.Context, sessionID string) error {
	now := s.now().UTC()
	var sessionSnapshot *RemoteSession
	var attachmentSnapshot *SessionAttachment
	// Non-empty when the session was already terminal; the event is then stale.
	var staleTerminalState sessioncontracts.State
	if err := s.transactor.WithinTransaction(ctx, func(store Store) error {
		session, err := store.RemoteSessions().GetByIDForUpdate(ctx, sessionID)
		if err != nil {
			return err
		}
		if session == nil {
			return ErrSessionNotFound
		}
		if isTerminalSessionState(session.State) {
			staleTerminalState = session.State
			return nil
		}
		// An already-active session is accepted even if validateTransition
		// rejects Active -> Active.
		if err := validateTransition(session.State, sessioncontracts.StateActive); err != nil && session.State != sessioncontracts.StateActive {
			return err
		}
		attachments, err := store.SessionAttachments().ListByRemoteSession(ctx, sessionID)
		if err != nil {
			return err
		}
		for _, attachment := range attachments {
			if attachment.State == AttachmentStateAttaching || attachment.State == AttachmentStateActive {
				if err := store.SessionAttachments().UpdateState(ctx, UpdateSessionAttachmentStateParams{
					AttachmentID: attachment.ID,
					State:        AttachmentStateActive,
					DetachedAt:   attachment.DetachedAt,
					LastInputAt:  attachment.LastInputAt,
					UpdatedAt:    now,
				}); err != nil {
					return err
				}
				attachment.State = AttachmentStateActive
				attachment.UpdatedAt = now
				// Taking the loop variable's address is safe: the loop exits
				// right after.
				attachmentSnapshot = &attachment
				break
			}
		}
		if err := store.RemoteSessions().UpdateState(ctx, UpdateRemoteSessionStateParams{
			RemoteSessionID:  session.ID,
			State:            sessioncontracts.StateActive,
			WorkerID:         session.WorkerID,
			DetachDeadlineAt: nil,
			LastHeartbeatAt:  &now,
			TakeoverVersion:  session.TakeoverVersion,
			UpdatedAt:        now,
		}); err != nil {
			return err
		}
		session.State = sessioncontracts.StateActive
		session.DetachDeadlineAt = nil
		session.LastHeartbeatAt = &now
		session.UpdatedAt = now
		sessionSnapshot = session
		return nil
	}); err != nil {
		return err
	}
	if staleTerminalState != "" {
		s.logStaleWorkerEvent("connected", sessionID, staleTerminalState)
		return nil
	}
	live, err := s.liveState.GetSession(ctx, sessionID)
	if err != nil {
		return err
	}
	if live != nil {
		live.State = sessioncontracts.StateActive
		live.UpdatedAt = now
		return s.liveState.UpsertSession(ctx, *live)
	}
	// No live entry yet: publish one from the committed snapshots. The
	// returned attach token is discarded here.
	if sessionSnapshot != nil && attachmentSnapshot != nil {
		_, err = s.publishLiveState(ctx, *sessionSnapshot, *attachmentSnapshot, nil)
	}
	return err
}
// HandleWorkerHeartbeat records a worker heartbeat for sessionID.
//
// The durable LastHeartbeatAt and UpdatedAt are stamped without changing the
// session state, and the live-state entry (if any) gets a fresh UpdatedAt.
// Heartbeats for sessions already in a terminal state are reported via
// logStaleWorkerEvent and otherwise ignored.
func (s *Service) HandleWorkerHeartbeat(ctx context.Context, sessionID string) error {
	heartbeatAt := s.now().UTC()
	var terminalState sessioncontracts.State
	txErr := s.transactor.WithinTransaction(ctx, func(store Store) error {
		session, err := store.RemoteSessions().GetByIDForUpdate(ctx, sessionID)
		switch {
		case err != nil:
			return err
		case session == nil:
			return ErrSessionNotFound
		case isTerminalSessionState(session.State):
			terminalState = session.State
			return nil
		}
		return store.RemoteSessions().UpdateState(ctx, UpdateRemoteSessionStateParams{
			RemoteSessionID:  session.ID,
			State:            session.State,
			WorkerID:         session.WorkerID,
			DetachDeadlineAt: session.DetachDeadlineAt,
			LastHeartbeatAt:  &heartbeatAt,
			TakeoverVersion:  session.TakeoverVersion,
			UpdatedAt:        heartbeatAt,
		})
	})
	if txErr != nil {
		return txErr
	}
	if terminalState != "" {
		s.logStaleWorkerEvent("heartbeat", sessionID, terminalState)
		return nil
	}
	live, err := s.liveState.GetSession(ctx, sessionID)
	if err != nil || live == nil {
		return err
	}
	live.UpdatedAt = heartbeatAt
	return s.liveState.UpsertSession(ctx, *live)
}
// UpdateWorkerRenderTelemetry merges a worker render-telemetry payload into
// the session's live state and stores it.
//
// Events for sessions in a terminal state are logged and ignored. If no live
// state can be resolved, ErrSessionNotFound is returned. All payload keys are
// optional; a missing or wrongly typed key leaves the field unchanged:
//   - render_quality_profile (string, normalized)
//   - render_state, falling back to state (string)
//   - width / height, overridden by positive desktop_width / desktop_height
//   - frame_sequence, frame_format
//   - frame_data and region metadata (handled by applyRenderFrameData)
//   - input_correlation_id, worker_frame_captured_at
//   - cursor_x, cursor_y, cursor_visible, dirty_rectangles
//
// LastRenderAt and UpdatedAt are always stamped with the current time.
func (s *Service) UpdateWorkerRenderTelemetry(ctx context.Context, sessionID string, payload map[string]any) error {
	if stale, terminalState, err := s.isStaleTerminalWorkerEvent(ctx, sessionID); err != nil {
		return err
	} else if stale {
		s.logStaleWorkerEvent("render_telemetry", sessionID, terminalState)
		return nil
	}
	live, err := s.GetLiveSession(ctx, sessionID)
	if err != nil {
		return err
	}
	if live == nil {
		// GetLiveSession returns nil when it could not resolve the session;
		// try one explicit rebuild before giving up.
		live, err = s.RebuildLiveStateFromStore(ctx, sessionID)
		if err != nil {
			return err
		}
		if live == nil {
			return ErrSessionNotFound
		}
	}
	now := s.now().UTC()
	if profile, ok := payload["render_quality_profile"].(string); ok && profile != "" {
		live.RenderQualityProfile = normalizeRenderQualityProfile(profile)
	}
	if renderState, ok := payload["render_state"].(string); ok && renderState != "" {
		live.RenderState = renderState
	} else if renderState, ok := payload["state"].(string); ok && renderState != "" {
		live.RenderState = renderState
	}
	if width, ok := toInt(payload["width"]); ok {
		live.RenderWidth = width
	}
	if height, ok := toInt(payload["height"]); ok {
		live.RenderHeight = height
	}
	if sequence, ok := toInt64(payload["frame_sequence"]); ok {
		live.RenderFrameSequence = sequence
	}
	if frameFormat, ok := payload["frame_format"].(string); ok && frameFormat != "" {
		live.RenderFrameFormat = frameFormat
	}
	// Positive desktop dimensions take precedence over width/height.
	if desktopWidth, ok := toInt(payload["desktop_width"]); ok && desktopWidth > 0 {
		live.RenderWidth = desktopWidth
	}
	if desktopHeight, ok := toInt(payload["desktop_height"]); ok && desktopHeight > 0 {
		live.RenderHeight = desktopHeight
	}
	// applyRenderFrameData falls back to live.RenderWidth/RenderHeight as the
	// desktop size, so it must run after the dimensions above are applied.
	s.applyRenderFrameData(live, payload)
	if correlationID, ok := payload["input_correlation_id"].(string); ok && correlationID != "" {
		live.LastInputCorrelationID = correlationID
	}
	if capturedAt, ok := payload["worker_frame_captured_at"].(string); ok && capturedAt != "" {
		live.WorkerFrameCapturedAt = capturedAt
	}
	if cursorX, ok := toInt(payload["cursor_x"]); ok {
		live.CursorX = cursorX
	}
	if cursorY, ok := toInt(payload["cursor_y"]); ok {
		live.CursorY = cursorY
	}
	if visible, ok := payload["cursor_visible"].(bool); ok {
		live.CursorVisible = visible
	}
	if dirty, ok := toInt(payload["dirty_rectangles"]); ok {
		live.DirtyRectangles = dirty
	}
	live.LastRenderAt = &now
	live.UpdatedAt = now
	return s.liveState.UpsertSession(ctx, *live)
}
// applyRenderFrameData folds the frame carried by a worker render event into
// the live session state.
//
// A payload without a non-empty "frame_data" string leaves the state untouched.
// Unless "frame_update_kind" is "region", the frame data replaces
// live.RenderFrameData verbatim. A "region" update is treated as a patch: the
// stored frame and the region are base64-decoded, the region rows are copied
// into the stored frame at (region_x, region_y) assuming 4 bytes per pixel,
// and the re-encoded frame becomes the new baseline labelled "bgra32".
//
// Region metadata is supplied by the worker, so every dimension is checked in
// an overflow-safe way; an invalid update is logged and dropped instead of
// corrupting the stored frame or panicking on an out-of-range slice.
func (s *Service) applyRenderFrameData(live *LiveSessionState, payload map[string]any) {
	const (
		bytesPerPixel = 4
		// Upper bound for any pixel dimension accepted from a worker. It is
		// above any realistic desktop size and keeps stride*height products
		// far inside the int range.
		maxRenderDimension = 1 << 15
	)

	frameData, ok := payload["frame_data"].(string)
	if !ok || frameData == "" {
		return
	}
	updateKind, _ := payload["frame_update_kind"].(string)
	if updateKind != "region" {
		live.RenderFrameData = frameData
		return
	}
	desktopWidth := live.RenderWidth
	desktopHeight := live.RenderHeight
	if width, ok := toInt(payload["desktop_width"]); ok && width > 0 {
		desktopWidth = width
	}
	if height, ok := toInt(payload["desktop_height"]); ok && height > 0 {
		desktopHeight = height
	}
	frameWidth, frameWidthOK := toInt(payload["frame_width"])
	frameHeight, frameHeightOK := toInt(payload["frame_height"])
	regionX, regionXOK := toInt(payload["region_x"])
	regionY, regionYOK := toInt(payload["region_y"])
	regionWidth, regionWidthOK := toInt(payload["region_width"])
	regionHeight, regionHeightOK := toInt(payload["region_height"])
	if !frameWidthOK || !frameHeightOK {
		frameWidth = regionWidth
		frameHeight = regionHeight
	}
	// Region bounds are compared as x > limit-w rather than x+w > limit: with
	// the limit capped and w positive the subtraction cannot overflow, whereas
	// a huge worker-supplied x+w could wrap negative and slip past the check.
	if !regionXOK || !regionYOK || !regionWidthOK || !regionHeightOK ||
		desktopWidth <= 0 || desktopHeight <= 0 ||
		desktopWidth > maxRenderDimension || desktopHeight > maxRenderDimension ||
		frameWidth <= 0 || frameHeight <= 0 ||
		frameWidth > maxRenderDimension || frameHeight > maxRenderDimension ||
		regionWidth <= 0 || regionHeight <= 0 ||
		regionX < 0 || regionY < 0 ||
		regionX > desktopWidth-regionWidth ||
		regionY > desktopHeight-regionHeight {
		s.logger.Warn("render live-state region frame ignored because metadata is invalid",
			"session_id", live.SessionID,
			"desktop_width", desktopWidth,
			"desktop_height", desktopHeight,
			"frame_width", frameWidth,
			"frame_height", frameHeight,
			"region_x", regionX,
			"region_y", regionY,
			"region_width", regionWidth,
			"region_height", regionHeight)
		return
	}
	if live.RenderFrameData == "" {
		s.logger.Warn("render live-state region frame ignored because no full baseline exists",
			"session_id", live.SessionID,
			"frame_sequence", live.RenderFrameSequence)
		return
	}
	fullBytes, err := base64.StdEncoding.DecodeString(live.RenderFrameData)
	if err != nil {
		s.logger.Warn("render live-state region frame ignored because baseline decode failed",
			"session_id", live.SessionID,
			"error", err)
		return
	}
	regionBytes, err := base64.StdEncoding.DecodeString(frameData)
	if err != nil {
		s.logger.Warn("render live-state region frame ignored because region decode failed",
			"session_id", live.SessionID,
			"error", err)
		return
	}
	fullStride := desktopWidth * bytesPerPixel
	regionRowBytes := regionWidth * bytesPerPixel
	regionStride := frameWidth * bytesPerPixel
	if stride, ok := toInt(payload["frame_stride"]); ok && stride > 0 {
		regionStride = stride
	}
	requiredFullBytes := fullStride * desktopHeight
	// With regionHeight >= 1 a stride longer than the whole region buffer can
	// never be satisfied; rejecting it before multiplying keeps
	// stride*height from overflowing for absurd worker-supplied strides.
	strideTooLarge := regionStride > len(regionBytes)
	requiredRegionBytes := 0
	if !strideTooLarge {
		requiredRegionBytes = regionStride * regionHeight
	}
	if strideTooLarge || regionStride < regionRowBytes ||
		len(fullBytes) < requiredFullBytes || len(regionBytes) < requiredRegionBytes {
		s.logger.Warn("render live-state region frame ignored because byte lengths are invalid",
			"session_id", live.SessionID,
			"full_bytes", len(fullBytes),
			"required_full_bytes", requiredFullBytes,
			"region_bytes", len(regionBytes),
			"required_region_bytes", requiredRegionBytes,
			"region_stride", regionStride)
		return
	}
	for row := 0; row < regionHeight; row++ {
		srcOffset := row * regionStride
		dstOffset := (regionY+row)*fullStride + regionX*bytesPerPixel
		copy(fullBytes[dstOffset:dstOffset+regionRowBytes], regionBytes[srcOffset:srcOffset+regionRowBytes])
	}
	live.RenderFrameData = base64.StdEncoding.EncodeToString(fullBytes[:requiredFullBytes])
	live.RenderWidth = desktopWidth
	live.RenderHeight = desktopHeight
	live.RenderFrameFormat = "bgra32"
}
// UpdateWorkerClipboardText stores clipboard text reported by the worker in
// the session's live state.
//
// The event is dropped with a nil error in these cases:
//   - the stored session is already terminal;
//   - the session is not active;
//   - the resource's clipboard mode does not allow server-to-client transfer;
//   - the payload's "text" is empty.
//
// Otherwise the live state is loaded, or rebuilt from the store when it is
// missing. The clipboard sequence is advanced by one, or jumps to the
// payload's "sequence_id" when that is larger. The text, origin and content
// hash are then persisted. ErrSessionNotFound is returned when no live state
// can be produced.
func (s *Service) UpdateWorkerClipboardText(ctx context.Context, sessionID string, payload map[string]any) error {
	stale, terminalState, err := s.isStaleTerminalWorkerEvent(ctx, sessionID)
	if err != nil {
		return err
	}
	if stale {
		s.logStaleWorkerEvent("clipboard_text", sessionID, terminalState)
		return nil
	}

	clipboardMode, sessionState, err := s.GetSessionClipboardPolicy(ctx, sessionID)
	if err != nil {
		return err
	}
	deliverable := sessionState == sessioncontracts.StateActive && clipboardAllowsServerToClient(clipboardMode)
	if !deliverable {
		s.logger.Info("worker clipboard text ignored by policy or state",
			"session_id", sessionID,
			"state", sessionState,
			"clipboard_mode", clipboardMode)
		return nil
	}

	text, _ := payload["text"].(string)
	if len(text) == 0 {
		s.logger.Info("worker clipboard text ignored because payload text is empty",
			"session_id", sessionID)
		return nil
	}
	origin, _ := payload["origin"].(string)
	contentHash, _ := payload["content_hash"].(string)

	live, err := s.GetLiveSession(ctx, sessionID)
	if err != nil {
		return err
	}
	if live == nil {
		if live, err = s.RebuildLiveStateFromStore(ctx, sessionID); err != nil {
			return err
		}
		if live == nil {
			return ErrSessionNotFound
		}
	}

	updatedAt := s.now().UTC()
	nextSequence := live.ClipboardSequence + 1
	if reported, ok := toInt64(payload["sequence_id"]); ok && reported > nextSequence {
		nextSequence = reported
	}
	live.ClipboardSequence = nextSequence
	live.ClipboardText = text
	live.ClipboardOrigin = origin
	live.ClipboardContentHash = contentHash
	live.ClipboardUpdatedAt = &updatedAt
	live.UpdatedAt = updatedAt

	s.logger.Info("worker clipboard text persisted to live state",
		"session_id", sessionID,
		"origin", origin,
		"sequence_id", live.ClipboardSequence,
		"content_hash", contentHash,
		"text_bytes", len(text))
	return s.liveState.UpsertSession(ctx, *live)
}
// UpdateWorkerFileDownloadEvent records a worker file-download (server to
// client) event in the session's live state.
//
// The event is dropped with a nil error in these cases:
//   - the stored session is already terminal;
//   - the session is not active;
//   - the resource's file-transfer mode does not allow server-to-client
//     transfers.
//
// Otherwise a copy of payload, tagged with "direction": "server_to_client",
// is stored as the latest download event together with eventType. The
// download sequence is advanced by one, or jumps to the payload's "sequence"
// when that is larger. ErrSessionNotFound is returned when no live state can
// be produced.
//
// The caller's payload map is neither modified nor retained; the stored event
// is a shallow copy.
func (s *Service) UpdateWorkerFileDownloadEvent(ctx context.Context, sessionID, eventType string, payload map[string]any) error {
	if stale, state, err := s.isStaleTerminalWorkerEvent(ctx, sessionID); err != nil {
		return err
	} else if stale {
		s.logStaleWorkerEvent("file_download", sessionID, state)
		return nil
	}
	mode, state, err := s.GetSessionFileTransferPolicy(ctx, sessionID)
	if err != nil {
		return err
	}
	if state != sessioncontracts.StateActive || !fileTransferAllowsServerToClient(mode) {
		s.logger.Info("worker file download event ignored by policy or state",
			"session_id", sessionID,
			"state", state,
			"file_transfer_mode", mode,
			"event_type", eventType)
		return nil
	}
	// Copy instead of writing into payload. The caller still owns that map, and
	// the stored event outlives this call. Aliasing would mutate the caller's
	// data and let its later writes leak into live state.
	event := make(map[string]any, len(payload)+1)
	for key, value := range payload {
		event[key] = value
	}
	event["direction"] = "server_to_client"
	live, err := s.GetLiveSession(ctx, sessionID)
	if err != nil {
		return err
	}
	if live == nil {
		live, err = s.RebuildLiveStateFromStore(ctx, sessionID)
		if err != nil {
			return err
		}
		if live == nil {
			return ErrSessionNotFound
		}
	}
	now := s.now().UTC()
	live.FileDownloadSequence++
	if sequence, ok := toInt64(event["sequence"]); ok && sequence > live.FileDownloadSequence {
		live.FileDownloadSequence = sequence
	}
	live.FileDownloadType = eventType
	live.FileDownloadPayload = event
	live.FileDownloadUpdatedAt = &now
	live.UpdatedAt = now
	s.logger.Info("worker file download event persisted to live state",
		"session_id", sessionID,
		"event_type", eventType,
		"sequence", live.FileDownloadSequence,
		"transfer_id", event["transfer_id"],
		"file_id", event["file_id"],
		"file_name", event["file_name"],
		"status", event["status"])
	return s.liveState.UpsertSession(ctx, *live)
}
// prepareAttachment resolves the runtime assignment metadata for session and
// asks the orchestrator to prepare the given attachment with it.
//
// When the session metadata references a secret, an audit event is written:
//   - AuditEventSecretAccessDenied if metadata resolution fails;
//   - AuditEventSecretAccessed after the orchestrator call succeeds.
//
// Audit writes are best-effort; their errors are discarded. Any resolution or
// orchestrator error is returned unchanged.
func (s *Service) prepareAttachment(ctx context.Context, session RemoteSession, attachment SessionAttachment, lease *workercontracts.WorkerLease) error {
	metadata, secretRef, secretVersion, resolveErr := s.runtimeAssignmentMetadata(ctx, session, lease)
	if resolveErr != nil {
		if secretRef != "" {
			_ = s.writeAuditEvent(ctx, s.store, AuditEventSecretAccessDenied, attachment.UserID, attachment.DeviceID, secretRef, session.ID, map[string]any{
				"resource_id":     session.ResourceID,
				"organization_id": session.OrganizationID,
				"worker_id":       session.WorkerID,
				"reason":          resolveErr.Error(),
			})
		}
		return resolveErr
	}

	if err := s.orchestrator.PrepareAttachment(ctx, session, attachment, metadata); err != nil {
		return err
	}
	if secretRef == "" {
		return nil
	}
	_ = s.writeAuditEvent(ctx, s.store, AuditEventSecretAccessed, attachment.UserID, attachment.DeviceID, secretRef, session.ID, map[string]any{
		"resource_id":     session.ResourceID,
		"organization_id": session.OrganizationID,
		"worker_id":       session.WorkerID,
		"secret_ref":      secretRef,
		"version":         secretVersion,
	})
	return nil
}
// runtimeAssignmentMetadata decodes the session's stored metadata and, when it
// references a resource secret, merges the resolved secret into it.
//
// It returns the metadata to hand to the worker, the secret reference that was
// resolved (empty when no resolution took place), the resolved secret's
// version, and an error.
//
// Without a secret resolver:
//   - production environments fail with secrets.ErrSecretEncryptionKeyMissing;
//   - other environments return the decoded metadata unchanged and an empty
//     reference.
//
// On resolution or merge failure the metadata result is nil but the reference
// is still returned, so the caller can audit the denied access.
func (s *Service) runtimeAssignmentMetadata(ctx context.Context, session RemoteSession, lease *workercontracts.WorkerLease) (map[string]any, string, int, error) {
	metadata := decodeJSONMap(session.Metadata)
	secretRef := secretRefFromAssignmentMetadata(metadata)
	switch {
	case secretRef == "":
		return metadata, "", 0, nil
	case s.secretResolver == nil && secrets.IsProductionEnv(s.cfg.App.Env):
		return nil, secretRef, 0, secrets.ErrSecretEncryptionKeyMissing
	case s.secretResolver == nil:
		return metadata, "", 0, nil
	}

	var leaseID string
	if lease != nil {
		leaseID = lease.LeaseID
	}
	request := secrets.ResolveResourceSecretRequest{
		SecretRef:      secretRef,
		OrganizationID: session.OrganizationID,
		ResourceID:     session.ResourceID,
		SessionID:      session.ID,
		WorkerID:       session.WorkerID,
		LeaseID:        leaseID,
	}
	resolved, err := s.secretResolver.ResolveForSession(ctx, request)
	if err != nil {
		return nil, secretRef, 0, err
	}
	version := resolved.Descriptor.Version
	merged, err := secrets.MergeResourceSecretIntoAssignmentMetadata(metadata, resolved.Payload)
	if err != nil {
		return nil, secretRef, version, err
	}
	return merged.Metadata, secretRef, version, nil
}
// publishLiveState writes the live-state records for a newly prepared
// attachment and issues an attach token for it.
//
// Steps, in order:
//  1. Upsert the live session. Render state is "connecting"; the quality
//     profile comes from the session, or from its metadata when empty.
//  2. Bind the attachment as the controller.
//  3. Record the worker route, when a lease is provided.
//  4. Store a fresh, reconnectable attach token.
//  5. Touch the attachment heartbeat.
//
// If any of steps 1-4 fails, the function returns a nil token and that error.
// A heartbeat error is returned together with the already-stored token.
func (s *Service) publishLiveState(ctx context.Context, session RemoteSession, attachment SessionAttachment, lease *workercontracts.WorkerLease) (*sessioncontracts.AttachTokenClaims, error) {
	now := s.now().UTC()
	liveTTL := s.cfg.Session.LiveStateTTL

	profile := session.RenderQualityProfile
	if profile == "" {
		profile = renderQualityProfileFromSessionMetadata(session.Metadata)
	}
	liveSession := LiveSessionState{
		SessionID:            session.ID,
		ResourceID:           session.ResourceID,
		WorkerID:             session.WorkerID,
		State:                session.State,
		ControllerID:         session.ControllerUserID,
		AttachmentID:         attachment.ID,
		TakeoverVersion:      session.TakeoverVersion,
		RenderQualityProfile: profile,
		RenderState:          "connecting",
		UpdatedAt:            now,
	}
	if err := s.liveState.UpsertSession(ctx, liveSession); err != nil {
		return nil, err
	}

	binding := sessioncontracts.ControllerBinding{
		SessionID:       session.ID,
		AttachmentID:    attachment.ID,
		UserID:          attachment.UserID,
		DeviceID:        attachment.DeviceID,
		TakeoverVersion: session.TakeoverVersion,
		BoundAt:         now,
	}
	if err := s.liveState.BindController(ctx, binding, liveTTL); err != nil {
		return nil, err
	}

	if lease != nil {
		route := WorkerRoute{
			SessionID:     session.ID,
			WorkerID:      lease.WorkerID,
			LeaseID:       lease.LeaseID,
			ControlStream: lease.ControlStream,
			UpdatedAt:     now,
		}
		if err := s.liveState.UpdateWorkerRoute(ctx, route, liveTTL); err != nil {
			return nil, err
		}
	}

	tokenTTL := s.cfg.Session.AttachTokenTTL
	claims := sessioncontracts.AttachTokenClaims{
		Token:           uuid.NewString(),
		SessionID:       session.ID,
		AttachmentID:    attachment.ID,
		UserID:          attachment.UserID,
		DeviceID:        attachment.DeviceID,
		WorkerID:        session.WorkerID,
		TakeoverVersion: session.TakeoverVersion,
		ExpiresAt:       now.Add(tokenTTL),
		Reconnectable:   true,
	}
	if err := s.liveState.StoreAttachToken(ctx, claims, tokenTTL); err != nil {
		return nil, err
	}
	heartbeatErr := s.liveState.TouchAttachmentHeartbeat(ctx, session.ID, attachment.ID, s.cfg.Session.HeartbeatTTL)
	return &claims, heartbeatErr
}
// loadPolicy loads the normalised resource policy for resourceID using the
// service's own store.
func (s *Service) loadPolicy(ctx context.Context, resourceID string) (*ResourcePolicy, error) {
	store := s.store
	return s.loadPolicyFromStore(ctx, store, resourceID)
}
// loadPolicyFromStore fetches the resource policy for resourceID from store
// and normalises it.
//
// If the store has no policy, a restrictive default is used:
//   - one concurrent session;
//   - trusted-device takeover, with a trusted device required;
//   - the configured detach grace period;
//   - clipboard and file transfer disabled.
//
// In every case the clipboard and file-transfer modes are normalised. The
// derived flags are then recomputed:
//   - ClipboardEnabled is true for any non-disabled clipboard mode;
//   - FileTransferEnabled is true only when client-to-server transfer is
//     allowed.
func (s *Service) loadPolicyFromStore(ctx context.Context, store Store, resourceID string) (*ResourcePolicy, error) {
	policy, err := store.ResourcePolicies().GetByResourceID(ctx, resourceID)
	if err != nil {
		return nil, err
	}
	if policy == nil {
		policy = &ResourcePolicy{
			ResourceID:            resourceID,
			MaxConcurrentSessions: 1,
			TakeoverPolicy:        ResourceTakeoverPolicyTrustedDevice,
			RequireTrustedDevice:  true,
			DetachGracePeriod:     s.cfg.Session.DetachGracePeriod,
			ClipboardMode:         ResourceClipboardModeDisabled,
			FileTransferMode:      ResourceFileTransferModeDisabled,
		}
	}

	clipboardMode := normalizeClipboardMode(policy.ClipboardMode)
	fileTransferMode := normalizeFileTransferMode(policy.FileTransferMode)
	policy.ClipboardMode = clipboardMode
	policy.ClipboardEnabled = clipboardMode != ResourceClipboardModeDisabled
	policy.FileTransferMode = fileTransferMode
	policy.FileTransferEnabled = fileTransferAllowsClientToServer(fileTransferMode)
	return policy, nil
}
// GetSessionClipboardPolicy returns the normalised clipboard mode of the
// session's resource together with the session's stored state.
//
// It returns ErrSessionNotFound when the session does not exist. On any error
// the mode is ResourceClipboardModeDisabled and the state is empty.
func (s *Service) GetSessionClipboardPolicy(ctx context.Context, sessionID string) (ResourceClipboardMode, sessioncontracts.State, error) {
	session, err := s.store.RemoteSessions().GetByID(ctx, sessionID)
	switch {
	case err != nil:
		return ResourceClipboardModeDisabled, "", err
	case session == nil:
		return ResourceClipboardModeDisabled, "", ErrSessionNotFound
	}
	policy, err := s.loadPolicy(ctx, session.ResourceID)
	if err != nil {
		return ResourceClipboardModeDisabled, "", err
	}
	return policy.ClipboardMode, session.State, nil
}
// normalizeClipboardMode returns mode when it is one of the three enabled
// clipboard modes and ResourceClipboardModeDisabled for anything else.
func normalizeClipboardMode(mode ResourceClipboardMode) ResourceClipboardMode {
	if mode == ResourceClipboardModeClientToServer ||
		mode == ResourceClipboardModeServerToClient ||
		mode == ResourceClipboardModeBidirectional {
		return mode
	}
	return ResourceClipboardModeDisabled
}
// clipboardAllowsClientToServer reports whether mode permits clipboard data to
// flow from the client into the remote session.
func clipboardAllowsClientToServer(mode ResourceClipboardMode) bool {
	switch mode {
	case ResourceClipboardModeClientToServer, ResourceClipboardModeBidirectional:
		return true
	default:
		return false
	}
}
// clipboardAllowsServerToClient reports whether mode permits clipboard data to
// flow from the remote session back to the client.
func clipboardAllowsServerToClient(mode ResourceClipboardMode) bool {
	switch mode {
	case ResourceClipboardModeServerToClient, ResourceClipboardModeBidirectional:
		return true
	default:
		return false
	}
}
// GetSessionFileTransferPolicy returns the normalised file-transfer mode of the
// session's resource together with the session's stored state.
//
// It returns ErrSessionNotFound when the session does not exist. On any error
// the mode is ResourceFileTransferModeDisabled and the state is empty.
func (s *Service) GetSessionFileTransferPolicy(ctx context.Context, sessionID string) (ResourceFileTransferMode, sessioncontracts.State, error) {
	session, err := s.store.RemoteSessions().GetByID(ctx, sessionID)
	switch {
	case err != nil:
		return ResourceFileTransferModeDisabled, "", err
	case session == nil:
		return ResourceFileTransferModeDisabled, "", ErrSessionNotFound
	}
	policy, err := s.loadPolicy(ctx, session.ResourceID)
	if err != nil {
		return ResourceFileTransferModeDisabled, "", err
	}
	return policy.FileTransferMode, session.State, nil
}
// normalizeFileTransferMode returns mode when it is one of the three enabled
// file-transfer modes and ResourceFileTransferModeDisabled for anything else.
func normalizeFileTransferMode(mode ResourceFileTransferMode) ResourceFileTransferMode {
	if mode == ResourceFileTransferModeClientToServer ||
		mode == ResourceFileTransferModeServerToClient ||
		mode == ResourceFileTransferModeBidirectional {
		return mode
	}
	return ResourceFileTransferModeDisabled
}
// fileTransferAllowsClientToServer reports whether mode permits uploading files
// from the client into the remote session.
func fileTransferAllowsClientToServer(mode ResourceFileTransferMode) bool {
	switch mode {
	case ResourceFileTransferModeClientToServer, ResourceFileTransferModeBidirectional:
		return true
	default:
		return false
	}
}
// fileTransferAllowsServerToClient reports whether mode permits downloading
// files from the remote session to the client.
func fileTransferAllowsServerToClient(mode ResourceFileTransferMode) bool {
	switch mode {
	case ResourceFileTransferModeServerToClient, ResourceFileTransferModeBidirectional:
		return true
	default:
		return false
	}
}
// ensureTrustedDevice applies the policy's trusted-device requirement using
// the service's own store.
func (s *Service) ensureTrustedDevice(ctx context.Context, policy *ResourcePolicy, userID, deviceID string) error {
	store := s.store
	return s.ensureTrustedDeviceWithStore(ctx, store, policy, userID, deviceID)
}
// ensureTrustedDeviceWithStore returns ErrTrustedDeviceRequired when the
// policy demands a trusted device and store does not report deviceID as
// trusted for userID. Policies without that requirement always pass, and no
// store lookup is made for them.
func (s *Service) ensureTrustedDeviceWithStore(ctx context.Context, store Store, policy *ResourcePolicy, userID, deviceID string) error {
	if policy.RequireTrustedDevice {
		trusted, err := store.Access().IsTrustedDevice(ctx, userID, deviceID)
		switch {
		case err != nil:
			return err
		case !trusted:
			return ErrTrustedDeviceRequired
		}
	}
	return nil
}
// ensureSessionCapacity refuses a new session once the store's count of live
// sessions for resourceID has reached policy.MaxConcurrentSessions. In that
// case it returns ErrActiveControllerPresent.
func (s *Service) ensureSessionCapacity(ctx context.Context, policy *ResourcePolicy, resourceID string) error {
	liveCount, err := s.store.RemoteSessions().CountLiveByResource(ctx, resourceID)
	switch {
	case err != nil:
		return err
	case liveCount >= policy.MaxConcurrentSessions:
		return ErrActiveControllerPresent
	default:
		return nil
	}
}
// ensureTakeoverAllowed decides whether userID, on deviceID, may take control
// of session.
//
// Checks are applied in this order:
//  1. Platform admins are always allowed.
//  2. Non-members of the session's organization get ErrAccessDenied.
//  3. Members must satisfy the policy's trusted-device requirement.
//  4. Organization admins are then allowed.
//  5. Otherwise the resource takeover policy decides:
//     - trusted-device policy allows;
//     - same-user policy allows only the current controller;
//     - admin-only and unknown policies yield ErrTakeoverNotAllowed.
func (s *Service) ensureTakeoverAllowed(ctx context.Context, store Store, policy *ResourcePolicy, session *RemoteSession, userID, deviceID string) error {
	platformRole, err := store.Access().GetPlatformRole(ctx, userID)
	if err != nil {
		return err
	}
	if isPlatformAdminRole(platformRole) {
		return nil
	}

	orgRole, isMember, err := store.Access().GetOrganizationRole(ctx, session.OrganizationID, userID)
	if err != nil {
		return err
	}
	if !isMember {
		return ErrAccessDenied
	}
	if err := s.ensureTrustedDeviceWithStore(ctx, store, policy, userID, deviceID); err != nil {
		return err
	}
	if isOrganizationAdminRole(orgRole) {
		return nil
	}

	switch policy.TakeoverPolicy {
	case ResourceTakeoverPolicyTrustedDevice:
		return nil
	case ResourceTakeoverPolicySameUser:
		if session.ControllerUserID == userID {
			return nil
		}
	}
	// ResourceTakeoverPolicyAdminOnly, a same-user mismatch, and any unknown
	// policy all deny the takeover.
	return ErrTakeoverNotAllowed
}
// writeAuditEvent persists an audit record of eventType against targetID
// through store.
//
// The payload is JSON-encoded. The record's target type is always
// "remote_session", and the record always carries a reference to sessionID.
// Actor user and device references are left nil when their IDs are empty. A
// payload that cannot be marshalled yields a wrapped error, and nothing is
// written.
func (s *Service) writeAuditEvent(ctx context.Context, store Store, eventType, actorUserID, actorDeviceID, targetID, sessionID string, payload map[string]any) error {
	body, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("marshal audit payload: %w", err)
	}
	event := AuditEvent{
		ID:              uuid.NewString(),
		EventType:       eventType,
		TargetType:      "remote_session",
		TargetID:        targetID,
		RemoteSessionID: &sessionID,
		Payload:         body,
		CreatedAt:       s.now().UTC(),
	}
	if actorUserID != "" {
		event.ActorUserID = &actorUserID
	}
	if actorDeviceID != "" {
		event.ActorDeviceID = &actorDeviceID
	}
	return store.AuditEvents().Create(ctx, event)
}
// requiredCapabilities lists the worker capabilities a session under policy
// needs.
//
// "adaptive-quality" and "dirty-rects" are always required. "clipboard" is
// added when the clipboard is enabled, and "file-transfer" is added when
// client-to-server file transfer is allowed.
func requiredCapabilities(policy *ResourcePolicy) []string {
	capabilities := make([]string, 0, 4)
	capabilities = append(capabilities, "adaptive-quality", "dirty-rects")
	if policy.ClipboardEnabled {
		capabilities = append(capabilities, "clipboard")
	}
	if fileTransferAllowsClientToServer(policy.FileTransferMode) {
		capabilities = append(capabilities, "file-transfer")
	}
	return capabilities
}
func renderQualityProfileFromMetadata(metadata map[string]any) string {
if profile, ok := metadata["render_quality_profile"].(string); ok && profile != "" {
switch profile {
case "low_bandwidth", "balanced", "high_quality", "text_priority":
return profile
}
}
return "balanced"
}
// renderQualityProfileFromSessionMetadata extracts the render quality profile
// from the "resource" object of a session's raw JSON metadata.
//
// It returns "balanced" in these cases:
//   - the metadata is empty or is not valid JSON;
//   - there is no "resource" object.
//
// Otherwise the result of renderQualityProfileFromMetadata is returned.
func renderQualityProfileFromSessionMetadata(metadata []byte) string {
	const fallback = "balanced"
	if len(metadata) == 0 {
		return fallback
	}
	var decoded map[string]any
	if json.Unmarshal(metadata, &decoded) != nil {
		return fallback
	}
	if resource, ok := decoded["resource"].(map[string]any); ok && resource != nil {
		return renderQualityProfileFromMetadata(resource)
	}
	return fallback
}
// toInt converts a loosely typed numeric payload value into an int.
//
// Accepted inputs are int, int32, int64, float32, float64 and json.Number.
// Fractional parts are truncated toward zero.
//
// The second result is false when:
//   - the value has any other type;
//   - it is NaN or ±Inf;
//   - it does not fit into an int on this platform.
//
// Go's float-to-integer conversion is implementation-defined for such values,
// so they are rejected rather than silently mangled.
func toInt(value any) (int, bool) {
	// Platform int bounds, widened to int64 for the range checks below.
	const maxInt = int64(^uint(0) >> 1)
	const minInt = -maxInt - 1

	var (
		wide    int64
		f       float64
		isFloat bool
	)
	switch v := value.(type) {
	case int:
		return v, true
	case int32:
		return int(v), true
	case int64:
		wide = v
	case float64:
		f, isFloat = v, true
	case float32:
		f, isFloat = float64(v), true
	case json.Number:
		if i, err := v.Int64(); err == nil {
			wide = i
		} else if parsed, err := v.Float64(); err == nil {
			f, isFloat = parsed, true
		} else {
			return 0, false
		}
	default:
		return 0, false
	}
	if isFloat {
		// minInt is a power of two, so both bounds are exact float64 values.
		// NaN fails both comparisons and is rejected along with ±Inf and
		// out-of-range magnitudes.
		if !(f >= float64(minInt) && f < -float64(minInt)) {
			return 0, false
		}
		wide = int64(f)
	}
	if wide < minInt || wide > maxInt {
		return 0, false
	}
	return int(wide), true
}
// toInt64 converts a loosely typed numeric payload value into an int64.
//
// Accepted inputs are int, int32, int64, float32, float64 and json.Number.
// Fractional parts are truncated toward zero.
//
// The second result is false when:
//   - the value has any other type;
//   - it is NaN or ±Inf;
//   - it lies outside the int64 range.
//
// Go's float-to-integer conversion is implementation-defined for such values,
// so they are rejected rather than silently mangled.
func toInt64(value any) (int64, bool) {
	// 2^63 is exactly representable as a float64, so [-2^63, 2^63) is the
	// float range that converts to int64 without overflow.
	const limit = 1 << 63

	var f float64
	switch v := value.(type) {
	case int:
		return int64(v), true
	case int32:
		return int64(v), true
	case int64:
		return v, true
	case float64:
		f = v
	case float32:
		f = float64(v)
	case json.Number:
		if i, err := v.Int64(); err == nil {
			return i, true
		}
		parsed, err := v.Float64()
		if err != nil {
			return 0, false
		}
		f = parsed
	default:
		return 0, false
	}
	// NaN fails both comparisons and is rejected along with ±Inf and
	// out-of-range magnitudes.
	if !(f >= -limit && f < limit) {
		return 0, false
	}
	return int64(f), true
}
// decodeJSONMap decodes payload as a JSON object.
//
// It always returns a non-nil map, so callers may safely add keys to the
// result. Empty input, invalid JSON, and a literal JSON null all yield an
// empty map.
func decodeJSONMap(payload []byte) map[string]any {
	if len(payload) == 0 {
		return map[string]any{}
	}
	var out map[string]any
	// json.Unmarshal leaves out nil for the JSON literal null, so that case
	// is normalised together with decode errors.
	if err := json.Unmarshal(payload, &out); err != nil || out == nil {
		return map[string]any{}
	}
	return out
}
// secretRefFromAssignmentMetadata returns the string at
// metadata["resource"]["secret_ref"]. It returns "" when the resource object
// or the string value is absent or has a different type.
func secretRefFromAssignmentMetadata(metadata map[string]any) string {
	resource, ok := metadata["resource"].(map[string]any)
	if !ok || resource == nil {
		return ""
	}
	if ref, isString := resource["secret_ref"].(string); isString {
		return ref
	}
	return ""
}
// RebuildLiveStateFromStore reconstructs the live state of sessionID from the
// durable store.
//
// Steps:
//  1. Load the session and run the runtime consistency check on it. This may
//     mark it failed and reload it. If either step errors or yields no
//     session, a minimal live state (or nil) is returned with that error.
//  2. Look for the most recent attachment, scanning from the end of the list,
//     that is active, attaching or detached.
//  3. If there is none, return a minimal live state built from the session,
//     without persisting it.
//  4. Otherwise upsert a live state for that attachment, rebind it as the
//     controller, and return that live state.
func (s *Service) RebuildLiveStateFromStore(ctx context.Context, sessionID string) (*LiveSessionState, error) {
	session, err := s.store.RemoteSessions().GetByID(ctx, sessionID)
	if err == nil && session != nil {
		session, err = s.ensureSessionRuntimeConsistency(ctx, session)
	}
	if err != nil || session == nil {
		return sessionToLiveState(session), err
	}

	attachments, err := s.store.SessionAttachments().ListByRemoteSession(ctx, sessionID)
	if err != nil {
		return nil, err
	}
	var controller *SessionAttachment
	for idx := len(attachments) - 1; idx >= 0 && controller == nil; idx-- {
		switch attachments[idx].State {
		case AttachmentStateActive, AttachmentStateAttaching, AttachmentStateDetached:
			candidate := attachments[idx]
			controller = &candidate
		}
	}
	if controller == nil {
		return sessionToLiveState(session), nil
	}

	rebuilt := LiveSessionState{
		SessionID:       session.ID,
		ResourceID:      session.ResourceID,
		WorkerID:        session.WorkerID,
		State:           session.State,
		ControllerID:    session.ControllerUserID,
		AttachmentID:    controller.ID,
		TakeoverVersion: session.TakeoverVersion,
		UpdatedAt:       s.now().UTC(),
	}
	if err := s.liveState.UpsertSession(ctx, rebuilt); err != nil {
		return nil, err
	}
	binding := sessioncontracts.ControllerBinding{
		SessionID:       session.ID,
		AttachmentID:    controller.ID,
		UserID:          controller.UserID,
		DeviceID:        controller.DeviceID,
		TakeoverVersion: session.TakeoverVersion,
		BoundAt:         s.now().UTC(),
	}
	if err := s.liveState.BindController(ctx, binding, s.cfg.Session.LiveStateTTL); err != nil {
		return nil, err
	}
	return &rebuilt, nil
}
// ensureSessionRuntimeConsistencyByID loads sessionID and runs the runtime
// consistency check on it. A lookup error, or a missing session, is returned
// as-is without running the check.
func (s *Service) ensureSessionRuntimeConsistencyByID(ctx context.Context, sessionID string) (*RemoteSession, error) {
	session, err := s.store.RemoteSessions().GetByID(ctx, sessionID)
	if err == nil && session != nil {
		return s.ensureSessionRuntimeConsistency(ctx, session)
	}
	return session, err
}
// ensureSessionRuntimeConsistency asks the orchestrator whether the runtime
// behind a session that should be live is still healthy.
//
// Sessions that are nil, or in a state that does not require a live runtime,
// are returned untouched. A healthy runtime also returns the session as-is.
// An unhealthy one is logged and marked failed with the orchestrator's
// reason; ErrSessionNotFound and ErrSessionNotTerminable from that step are
// tolerated. The session is then re-read from the store so the caller sees
// its updated state.
func (s *Service) ensureSessionRuntimeConsistency(ctx context.Context, session *RemoteSession) (*RemoteSession, error) {
	if session == nil || !requiresLiveRuntime(session.State) {
		return session, nil
	}
	healthy, reason, err := s.orchestrator.ValidateSessionRuntime(ctx, session.ID, session.WorkerID)
	switch {
	case err != nil:
		return nil, err
	case healthy:
		return session, nil
	}

	s.logger.Warn("session runtime consistency check failed; marking session failed",
		"session_id", session.ID,
		"worker_id", session.WorkerID,
		"state", session.State,
		"reason", reason)
	markErr := s.MarkSessionFailed(ctx, MarkSessionFailedCommand{
		SessionID: session.ID,
		Reason:    reason,
	})
	tolerated := errors.Is(markErr, ErrSessionNotFound) || errors.Is(markErr, ErrSessionNotTerminable)
	if markErr != nil && !tolerated {
		return nil, markErr
	}
	return s.store.RemoteSessions().GetByID(ctx, session.ID)
}
// requiresLiveRuntime reports whether a session in state is expected to have
// a running worker runtime: starting, active, detached or reconnecting.
func requiresLiveRuntime(state sessioncontracts.State) bool {
	return state == sessioncontracts.StateStarting ||
		state == sessioncontracts.StateActive ||
		state == sessioncontracts.StateDetached ||
		state == sessioncontracts.StateReconnecting
}
// isTerminalSessionState reports whether state is final: terminated or failed.
func isTerminalSessionState(state sessioncontracts.State) bool {
	return state == sessioncontracts.StateTerminated || state == sessioncontracts.StateFailed
}
// isStaleTerminalWorkerEvent reports whether a worker event for sessionID
// should be ignored because the stored session is already terminal. It also
// returns the stored state.
//
// A lookup error, or a missing session, yields false with an empty state;
// the lookup error (if any) is returned.
func (s *Service) isStaleTerminalWorkerEvent(ctx context.Context, sessionID string) (bool, sessioncontracts.State, error) {
	session, err := s.store.RemoteSessions().GetByID(ctx, sessionID)
	if err != nil || session == nil {
		return false, "", err
	}
	stored := session.State
	return isTerminalSessionState(stored), stored, nil
}
// logStaleWorkerEvent logs that a worker event was dropped because the
// session is already terminal.
//
// High-frequency "render_telemetry" events are logged at debug level; all
// other event types are logged at info level. Nothing is logged when the
// service has no logger.
func (s *Service) logStaleWorkerEvent(eventType, sessionID string, state sessioncontracts.State) {
	logger := s.logger
	if logger == nil {
		return
	}
	const message = "stale worker event ignored because authoritative session state is terminal"
	attrs := []any{
		"event_type", eventType,
		"session_id", sessionID,
		"state", state,
	}
	if eventType != "render_telemetry" {
		logger.Info(message, attrs...)
	} else {
		logger.Debug(message, attrs...)
	}
}
// sessionToLiveState builds a minimal live state from a stored session. It
// carries no attachment, uses render state "connecting", and normalises the
// quality profile. A nil session maps to nil.
func sessionToLiveState(session *RemoteSession) *LiveSessionState {
	if session == nil {
		return nil
	}
	live := LiveSessionState{
		SessionID:            session.ID,
		ResourceID:           session.ResourceID,
		WorkerID:             session.WorkerID,
		State:                session.State,
		ControllerID:         session.ControllerUserID,
		TakeoverVersion:      session.TakeoverVersion,
		RenderQualityProfile: normalizeRenderQualityProfile(session.RenderQualityProfile),
		RenderState:          "connecting",
	}
	return &live
}
// mergeLiveRenderTelemetry copies render telemetry from src onto dst. It does
// nothing if either is nil.
//
// Rules:
//   - String fields are copied only when non-empty in src.
//   - Sizes, frame sequence and dirty-rectangle count are copied only when
//     positive.
//   - Cursor coordinates are copied only when non-zero.
//   - LastRenderAt is copied only when set.
//   - Cursor visibility is always copied.
func mergeLiveRenderTelemetry(dst *LiveSessionState, src *LiveSessionState) {
	if src == nil || dst == nil {
		return
	}

	// Textual telemetry.
	if src.RenderQualityProfile != "" {
		dst.RenderQualityProfile = src.RenderQualityProfile
	}
	if src.RenderState != "" {
		dst.RenderState = src.RenderState
	}
	if src.RenderFrameFormat != "" {
		dst.RenderFrameFormat = src.RenderFrameFormat
	}
	if src.RenderFrameData != "" {
		dst.RenderFrameData = src.RenderFrameData
	}

	// Counters and dimensions.
	if src.RenderWidth > 0 {
		dst.RenderWidth = src.RenderWidth
	}
	if src.RenderHeight > 0 {
		dst.RenderHeight = src.RenderHeight
	}
	if src.RenderFrameSequence > 0 {
		dst.RenderFrameSequence = src.RenderFrameSequence
	}
	if src.DirtyRectangles > 0 {
		dst.DirtyRectangles = src.DirtyRectangles
	}

	// Cursor.
	if src.CursorX != 0 {
		dst.CursorX = src.CursorX
	}
	if src.CursorY != 0 {
		dst.CursorY = src.CursorY
	}
	dst.CursorVisible = src.CursorVisible

	if src.LastRenderAt != nil {
		dst.LastRenderAt = src.LastRenderAt
	}
}
// normalizeRenderQualityProfile returns profile when it is one of the known
// render quality profiles ("low_bandwidth", "balanced", "high_quality",
// "text_priority") and "balanced" otherwise.
func normalizeRenderQualityProfile(profile string) string {
	switch profile {
	case "low_bandwidth", "balanced", "high_quality", "text_priority":
		// Known profile: keep as-is.
	default:
		profile = "balanced"
	}
	return profile
}
// findReusableSession returns the first session in the order returned by the
// store's controller listing for userID that targets resourceID and is still
// in a live state (starting, active, detached or reconnecting). It returns
// nil when there is none.
//
// The returned value is a copy and does not alias the listing.
func (s *Service) findReusableSession(ctx context.Context, resourceID, userID string) (*RemoteSession, error) {
	sessions, err := s.store.RemoteSessions().ListByController(ctx, userID)
	if err != nil {
		return nil, err
	}
	for i := range sessions {
		// requiresLiveRuntime is the single definition of the "still live"
		// states, shared with the runtime consistency check.
		if sessions[i].ResourceID != resourceID || !requiresLiveRuntime(sessions[i].State) {
			continue
		}
		// Named match rather than "copy", which would shadow the builtin.
		match := sessions[i]
		return &match, nil
	}
	return nil, nil
}
// MapError translates a service error into an HTTP status code and message.
// Sentinels are matched with errors.Is, in this precedence order:
//   - nil: 0 and "";
//   - session or attachment not found: 404;
//   - access denied or trusted device required: 403;
//   - controller, takeover and session-state conflicts: 409;
//   - invalid attach token: 401;
//   - anything else: 500.
func (s *Service) MapError(err error) (int, string) {
	if err == nil {
		return 0, ""
	}
	status := 500
	switch {
	case errors.Is(err, ErrSessionNotFound), errors.Is(err, ErrAttachmentNotFound):
		status = 404
	case errors.Is(err, ErrAccessDenied), errors.Is(err, ErrTrustedDeviceRequired):
		status = 403
	case errors.Is(err, ErrActiveControllerPresent),
		errors.Is(err, ErrTakeoverNotAllowed),
		errors.Is(err, ErrSessionNotAttachable),
		errors.Is(err, ErrSessionNotTerminable):
		status = 409
	case errors.Is(err, ErrAttachTokenInvalid):
		status = 401
	}
	return status, err.Error()
}
// ensureOrganizationMemberAccess checks organization membership (or platform
// admin rights) for userID using the service's own store.
func (s *Service) ensureOrganizationMemberAccess(ctx context.Context, organizationID, userID string) error {
	store := s.store
	return s.ensureOrganizationMemberAccessWithStore(ctx, store, organizationID, userID)
}
// ensureOrganizationMemberAccessWithStore returns ErrAccessDenied when userID
// is empty or is neither a platform admin nor a member of organizationID
// according to store. Lookup errors are returned unchanged.
func (s *Service) ensureOrganizationMemberAccessWithStore(ctx context.Context, store Store, organizationID, userID string) error {
	if userID == "" {
		return ErrAccessDenied
	}
	allowed, err := s.hasOrganizationMemberAccessWithStore(ctx, store, organizationID, userID)
	if err != nil {
		return err
	}
	if allowed {
		return nil
	}
	return ErrAccessDenied
}
// hasOrganizationMemberAccess reports, using the service's own store, whether
// userID is a platform admin or a member of organizationID.
func (s *Service) hasOrganizationMemberAccess(ctx context.Context, organizationID, userID string) (bool, error) {
	store := s.store
	return s.hasOrganizationMemberAccessWithStore(ctx, store, organizationID, userID)
}
// hasOrganizationMemberAccessWithStore reports whether userID is a platform
// admin or holds any role in organizationID, according to store. Platform
// admins short-circuit, so the organization lookup is skipped for them.
func (s *Service) hasOrganizationMemberAccessWithStore(ctx context.Context, store Store, organizationID, userID string) (bool, error) {
	platformRole, err := store.Access().GetPlatformRole(ctx, userID)
	if err != nil {
		return false, err
	}
	if isPlatformAdminRole(platformRole) {
		return true, nil
	}
	_, isMember, err := store.Access().GetOrganizationRole(ctx, organizationID, userID)
	if err != nil {
		return false, err
	}
	return isMember, nil
}
// ensureSessionControllerAccessWithStore requires userID to pass the
// organization-membership check for the session's organization and to be the
// session's current controller. It returns ErrAccessDenied otherwise.
func (s *Service) ensureSessionControllerAccessWithStore(ctx context.Context, store Store, session *RemoteSession, userID string) error {
	if err := s.ensureOrganizationMemberAccessWithStore(ctx, store, session.OrganizationID, userID); err != nil {
		return err
	}
	if session.ControllerUserID != userID {
		return ErrAccessDenied
	}
	return nil
}
// ensureSessionAdminOrControllerAccessWithStore allows platform admins, and
// members of the session's organization who are either organization admins or
// the session's current controller. Everyone else gets ErrAccessDenied.
func (s *Service) ensureSessionAdminOrControllerAccessWithStore(ctx context.Context, store Store, session *RemoteSession, userID string) error {
	platformRole, err := store.Access().GetPlatformRole(ctx, userID)
	if err != nil {
		return err
	}
	if isPlatformAdminRole(platformRole) {
		return nil
	}
	orgRole, isMember, err := store.Access().GetOrganizationRole(ctx, session.OrganizationID, userID)
	switch {
	case err != nil:
		return err
	case !isMember:
		return ErrAccessDenied
	case isOrganizationAdminRole(orgRole), session.ControllerUserID == userID:
		return nil
	default:
		return ErrAccessDenied
	}
}
// isPlatformAdminRole reports whether role is one of the platform-wide admin
// roles: "platform_admin" or "platform_recovery_admin".
func isPlatformAdminRole(role string) bool {
	switch role {
	case "platform_admin", "platform_recovery_admin":
		return true
	default:
		return false
	}
}
// isOrganizationAdminRole reports whether role grants admin rights inside an
// organization: "org_owner" or "org_admin".
func isOrganizationAdminRole(role string) bool {
	switch role {
	case "org_owner", "org_admin":
		return true
	default:
		return false
	}
}