package sessionbroker

import (
	"context"
	"io"
	"log/slog"
	"testing"
	"time"

	"github.com/example/remote-access-platform/backend/internal/platform/module"
	sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
	workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
)

// TestHandleWorkerConnectedIgnoresTerminalSession checks that a worker
// "connected" event for a session stored as StateTerminated returns no error,
// leaves the stored state unchanged, and triggers neither a store UpdateState
// call nor a live-state upsert.
func TestHandleWorkerConnectedIgnoresTerminalSession(t *testing.T) {
	service, store, live, _ := newStaleWorkerEventTestService()
	store.remote.sessions["session-1"] = RemoteSession{
		ID:       "session-1",
		State:    sessioncontracts.StateTerminated,
		WorkerID: "worker-1",
	}

	if err := service.HandleWorkerConnected(context.Background(), "session-1"); err != nil {
		t.Fatalf("HandleWorkerConnected returned error for stale terminal event: %v", err)
	}

	if got := store.remote.sessions["session-1"].State; got != sessioncontracts.StateTerminated {
		t.Fatalf("stale connected event changed terminal state to %q", got)
	}
	if store.remote.updateCount != 0 {
		t.Fatalf("stale connected event updated authoritative session %d times", store.remote.updateCount)
	}
	if live.upsertCount != 0 {
		t.Fatalf("stale connected event recreated live state %d times", live.upsertCount)
	}
}

// TestUpdateWorkerRenderTelemetryIgnoresTerminalSession checks that render
// telemetry arriving for a session stored as StateTerminated returns no error
// and does not create or leave behind any live-state entry.
func TestUpdateWorkerRenderTelemetryIgnoresTerminalSession(t *testing.T) {
	service, store, live, _ := newStaleWorkerEventTestService()
	store.remote.sessions["session-1"] = RemoteSession{
		ID:       "session-1",
		State:    sessioncontracts.StateTerminated,
		WorkerID: "worker-1",
	}

	err := service.UpdateWorkerRenderTelemetry(context.Background(), "session-1", map[string]any{
		"render_state":   "ready",
		"width":          1280,
		"height":         720,
		"frame_sequence": int64(99),
		"frame_data":     "stale-frame",
	})
	if err != nil {
		t.Fatalf("UpdateWorkerRenderTelemetry returned error for stale terminal event: %v", err)
	}

	if live.upsertCount != 0 {
		t.Fatalf("stale render event recreated live state %d times", live.upsertCount)
	}
	if live.sessions["session-1"] != nil {
		t.Fatalf("stale render event left live state behind: %#v", live.sessions["session-1"])
	}
}

// TestMarkSessionFailedTransitionsActiveSession checks the full failure path
// for an active session: the session becomes StateFailed, its attachment is
// closed, exactly one audit event is recorded, the live state is deleted, and
// the session lease is released exactly once.
func TestMarkSessionFailedTransitionsActiveSession(t *testing.T) {
	service, store, live, orchestrator := newStaleWorkerEventTestService()
	store.remote.sessions["session-1"] = RemoteSession{
		ID:              "session-1",
		State:           sessioncontracts.StateActive,
		WorkerID:        "worker-1",
		TakeoverVersion: 3,
	}
	store.attachments.items["attachment-1"] = SessionAttachment{
		ID:              "attachment-1",
		RemoteSessionID: "session-1",
		State:           AttachmentStateActive,
	}
	live.sessions["session-1"] = &LiveSessionState{SessionID: "session-1", State: sessioncontracts.StateActive}

	if err := service.MarkSessionFailed(context.Background(), MarkSessionFailedCommand{SessionID: "session-1", Reason: "worker_lost"}); err != nil {
		t.Fatalf("MarkSessionFailed returned error: %v", err)
	}

	if got := store.remote.sessions["session-1"].State; got != sessioncontracts.StateFailed {
		t.Fatalf("expected failed state, got %q", got)
	}
	if got := store.attachments.items["attachment-1"].State; got != AttachmentStateClosed {
		t.Fatalf("expected attachment closed, got %q", got)
	}
	if store.audit.createCount != 1 {
		t.Fatalf("expected one audit event, got %d", store.audit.createCount)
	}
	if live.sessions["session-1"] != nil {
		t.Fatal("expected failed session live state to be deleted")
	}
	if orchestrator.releaseCount != 1 {
		t.Fatalf("expected session lease release, got %d", orchestrator.releaseCount)
	}
}

// TestMarkSessionFailedAlreadyFailedIsIdempotent checks that marking an
// already-failed session as failed again returns no error and keeps the
// stored state at StateFailed.
func TestMarkSessionFailedAlreadyFailedIsIdempotent(t *testing.T) {
	service, store, _, _ := newStaleWorkerEventTestService()
	store.remote.sessions["session-1"] = RemoteSession{
		ID:              "session-1",
		State:           sessioncontracts.StateFailed,
		WorkerID:        "worker-1",
		TakeoverVersion: 1,
	}

	if err := service.MarkSessionFailed(context.Background(), MarkSessionFailedCommand{SessionID: "session-1", Reason: "duplicate_worker_failure"}); err != nil {
		t.Fatalf("duplicate MarkSessionFailed returned error: %v", err)
	}

	if got := store.remote.sessions["session-1"].State; got != sessioncontracts.StateFailed {
		t.Fatalf("duplicate terminal event changed state to %q", got)
	}
}

// newStaleWorkerEventTestService builds a Service wired entirely to the
// in-memory fakes defined in this file and returns it together with those
// fakes, so tests can seed state and inspect call counters.
//
// The service logger writes to io.Discard and the service clock is pinned to
// Unix second 100 (UTC).
func newStaleWorkerEventTestService() (*Service, *staleWorkerEventTestStore, *staleWorkerEventLiveState, *staleWorkerEventOrchestrator) {
	store := &staleWorkerEventTestStore{
		remote:      &staleWorkerEventRemoteSessions{sessions: map[string]RemoteSession{}},
		attachments: &staleWorkerEventAttachments{items: map[string]SessionAttachment{}},
		policies:    &staleWorkerEventPolicies{},
		audit:       &staleWorkerEventAudit{},
	}
	live := &staleWorkerEventLiveState{sessions: map[string]*LiveSessionState{}}
	orchestrator := &staleWorkerEventOrchestrator{}

	service := NewService(module.Dependencies{
		Infra: module.Infra{Logger: slog.New(slog.NewTextHandler(io.Discard, nil))},
	}, store, staleWorkerEventTransactor{store: store}, live, orchestrator)
	service.now = func() time.Time { return time.Unix(100, 0).UTC() }

	return service, store, live, orchestrator
}

// staleWorkerEventTransactor is a fake transactor that runs the callback
// directly against the wrapped store; no transaction is opened.
type staleWorkerEventTransactor struct {
	store Store
}

// WithinTransaction invokes fn with the wrapped store and returns its error
// unchanged. The context is not used.
func (t staleWorkerEventTransactor) WithinTransaction(ctx context.Context, fn func(store Store) error) error {
	return fn(t.store)
}

// staleWorkerEventTestStore is a fake Store that hands out the in-memory
// repositories below.
type staleWorkerEventTestStore struct {
	remote      *staleWorkerEventRemoteSessions
	attachments *staleWorkerEventAttachments
	policies    *staleWorkerEventPolicies
	audit       *staleWorkerEventAudit
}

// RemoteSessions returns the in-memory remote session repository.
func (s *staleWorkerEventTestStore) RemoteSessions() RemoteSessionRepository {
	return s.remote
}

// SessionAttachments returns the in-memory attachment repository.
func (s *staleWorkerEventTestStore) SessionAttachments() SessionAttachmentRepository {
	return s.attachments
}

// ResourcePolicies returns the no-op resource policy repository.
func (s *staleWorkerEventTestStore) ResourcePolicies() ResourcePolicyRepository {
	return s.policies
}

// ResourceRuntime returns a fresh stateless resource runtime fake.
func (s *staleWorkerEventTestStore) ResourceRuntime() ResourceRuntimeRepository {
	return staleWorkerEventResourceRuntime{}
}

// AuditEvents returns the counting audit event repository.
func (s *staleWorkerEventTestStore) AuditEvents() AuditEventRepository {
	return s.audit
}

// Access returns a fresh stateless access fake.
func (s *staleWorkerEventTestStore) Access() AccessRepository {
	return staleWorkerEventAccess{}
}

// staleWorkerEventRemoteSessions is an in-memory remote session repository
// keyed by session ID. updateCount records how many times UpdateState ran.
type staleWorkerEventRemoteSessions struct {
	sessions    map[string]RemoteSession
	updateCount int
}

// Create stores the session under its ID, replacing any existing entry.
func (r *staleWorkerEventRemoteSessions) Create(_ context.Context, session RemoteSession) error {
	r.sessions[session.ID] = session
	return nil
}

// GetByID returns a pointer to a copy of the stored session, or (nil, nil)
// when the ID is unknown. Mutating the result does not affect the store.
func (r *staleWorkerEventRemoteSessions) GetByID(_ context.Context, sessionID string) (*RemoteSession, error) {
	session, ok := r.sessions[sessionID]
	if !ok {
		return nil, nil
	}
	return &session, nil
}

// GetByIDForUpdate behaves exactly like GetByID; the fake takes no lock.
func (r *staleWorkerEventRemoteSessions) GetByIDForUpdate(ctx context.Context, sessionID string) (*RemoteSession, error) {
	return r.GetByID(ctx, sessionID)
}

// ListByController is a stub that always returns (nil, nil).
func (r *staleWorkerEventRemoteSessions) ListByController(_ context.Context, _ string) ([]RemoteSession, error) {
	return nil, nil
}

// CountLiveByResource is a stub that always returns (0, nil).
func (r *staleWorkerEventRemoteSessions) CountLiveByResource(_ context.Context, _ string) (int, error) {
	return 0, nil
}

// ListDetachedExpired is a stub that always returns (nil, nil).
func (r *staleWorkerEventRemoteSessions) ListDetachedExpired(_ context.Context, _ time.Time, _ int) ([]RemoteSession, error) {
	return nil, nil
}

// UpdateState copies the state fields from params onto the stored session and
// increments updateCount.
//
// An unknown RemoteSessionID is not rejected: the map lookup yields a zero
// RemoteSession, which is then stored under that ID with the params applied.
func (r *staleWorkerEventRemoteSessions) UpdateState(_ context.Context, params UpdateRemoteSessionStateParams) error {
	session := r.sessions[params.RemoteSessionID]
	session.State = params.State
	session.WorkerID = params.WorkerID
	session.DetachDeadlineAt = params.DetachDeadlineAt
	session.LastHeartbeatAt = params.LastHeartbeatAt
	session.TakeoverVersion = params.TakeoverVersion
	session.UpdatedAt = params.UpdatedAt
	r.sessions[params.RemoteSessionID] = session
	r.updateCount++
	return nil
}

// staleWorkerEventAttachments is an in-memory attachment repository keyed by
// attachment ID.
type staleWorkerEventAttachments struct {
	items map[string]SessionAttachment
}

// Create stores the attachment under its ID, replacing any existing entry.
func (r *staleWorkerEventAttachments) Create(_ context.Context, attachment SessionAttachment) error {
	r.items[attachment.ID] = attachment
	return nil
}

// GetByID returns a pointer to a copy of the stored attachment, or (nil, nil)
// when the ID is unknown.
func (r *staleWorkerEventAttachments) GetByID(_ context.Context, attachmentID string) (*SessionAttachment, error) {
	attachment, ok := r.items[attachmentID]
	if !ok {
		return nil, nil
	}
	return &attachment, nil
}

// GetByIDForUpdate behaves exactly like GetByID; the fake takes no lock.
func (r *staleWorkerEventAttachments) GetByIDForUpdate(ctx context.Context, attachmentID string) (*SessionAttachment, error) {
	return r.GetByID(ctx, attachmentID)
}

// ListByRemoteSession returns every stored attachment whose RemoteSessionID
// matches. The result is a non-nil slice; its order follows map iteration and
// is therefore unspecified.
func (r *staleWorkerEventAttachments) ListByRemoteSession(_ context.Context, remoteSessionID string) ([]SessionAttachment, error) {
	attachments := make([]SessionAttachment, 0)
	for _, attachment := range r.items {
		if attachment.RemoteSessionID == remoteSessionID {
			attachments = append(attachments, attachment)
		}
	}
	return attachments, nil
}

// ListActiveByRemoteSessionForUpdate delegates to ListByRemoteSession.
//
// NOTE(review): the fake does not filter by attachment state, so non-active
// attachments are returned too — confirm this is acceptable for the tests
// relying on it.
func (r *staleWorkerEventAttachments) ListActiveByRemoteSessionForUpdate(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error) {
	return r.ListByRemoteSession(ctx, remoteSessionID)
}

// UpdateState copies the state fields from params onto the stored attachment.
// As with the session fake, an unknown AttachmentID results in a new entry
// built from the zero SessionAttachment.
func (r *staleWorkerEventAttachments) UpdateState(_ context.Context, params UpdateSessionAttachmentStateParams) error {
	attachment := r.items[params.AttachmentID]
	attachment.State = params.State
	attachment.DetachedAt = params.DetachedAt
	attachment.LastInputAt = params.LastInputAt
	attachment.UpdatedAt = params.UpdatedAt
	r.items[params.AttachmentID] = attachment
	return nil
}

// Supersede marks the previous attachment as AttachmentStateSuperseded and
// records which attachment replaced it and when it was detached.
func (r *staleWorkerEventAttachments) Supersede(_ context.Context, params SupersedeAttachmentParams) error {
	attachment := r.items[params.PreviousAttachmentID]
	attachment.State = AttachmentStateSuperseded
	// params is received by value, so these addresses point at this call's
	// private copy and are safe to retain in the stored attachment.
	attachment.SupersededBy = &params.NextAttachmentID
	attachment.DetachedAt = &params.DetachedAt
	attachment.UpdatedAt = params.UpdatedAt
	r.items[params.PreviousAttachmentID] = attachment
	return nil
}

// staleWorkerEventPolicies is a no-op resource policy repository.
type staleWorkerEventPolicies struct{}

// GetByResourceID is a stub that always returns (nil, nil).
func (r *staleWorkerEventPolicies) GetByResourceID(_ context.Context, _ string) (*ResourcePolicy, error) {
	return nil, nil
}

// Upsert is a stub that discards the policy and returns nil.
func (r *staleWorkerEventPolicies) Upsert(_ context.Context, _ ResourcePolicy) error {
	return nil
}

// staleWorkerEventAudit is an audit event repository that only counts Create
// calls.
type staleWorkerEventAudit struct {
	createCount int
}

// Create discards the event and increments createCount.
func (r *staleWorkerEventAudit) Create(_ context.Context, _ AuditEvent) error {
	r.createCount++
	return nil
}

// staleWorkerEventResourceRuntime is a stateless resource runtime fake.
type staleWorkerEventResourceRuntime struct{}

// GetByID is a stub that always returns (nil, nil).
func (staleWorkerEventResourceRuntime) GetByID(_ context.Context, _ string) (*ResourceRuntimeSpec, error) {
	return nil, nil
}

// staleWorkerEventAccess is a stateless access fake whose lookups all return
// zero values.
type staleWorkerEventAccess struct{}

// IsTrustedDevice is a stub that always returns (false, nil).
func (staleWorkerEventAccess) IsTrustedDevice(_ context.Context, _, _ string) (bool, error) {
	return false, nil
}

// GetPlatformRole is a stub that always returns ("", nil).
func (staleWorkerEventAccess) GetPlatformRole(_ context.Context, _ string) (string, error) {
	return "", nil
}

// GetOrganizationRole is a stub that always returns ("", false, nil).
func (staleWorkerEventAccess) GetOrganizationRole(_ context.Context, _, _ string) (string, bool, error) {
	return "", false, nil
}

// staleWorkerEventLiveState is an in-memory live-state fake. Only the session
// methods keep state; upsertCount records how many times UpsertSession ran.
type staleWorkerEventLiveState struct {
	sessions    map[string]*LiveSessionState
	upsertCount int
}

// UpsertSession stores a copy of state under its SessionID and increments
// upsertCount.
func (s *staleWorkerEventLiveState) UpsertSession(_ context.Context, state LiveSessionState) error {
	copied := state
	s.sessions[state.SessionID] = &copied
	s.upsertCount++
	return nil
}

// GetSession returns a pointer to a copy of the stored state, or (nil, nil)
// when no entry (or a nil entry) exists for sessionID.
func (s *staleWorkerEventLiveState) GetSession(_ context.Context, sessionID string) (*LiveSessionState, error) {
	state := s.sessions[sessionID]
	if state == nil {
		return nil, nil
	}
	copied := *state
	return &copied, nil
}

// DeleteSession removes the entry for sessionID, if any.
func (s *staleWorkerEventLiveState) DeleteSession(_ context.Context, sessionID string) error {
	delete(s.sessions, sessionID)
	return nil
}

// BindController is a stub that always returns nil.
func (s *staleWorkerEventLiveState) BindController(_ context.Context, _ sessioncontracts.ControllerBinding, _ time.Duration) error {
	return nil
}

// GetControllerBinding is a stub that always returns (nil, nil).
func (s *staleWorkerEventLiveState) GetControllerBinding(_ context.Context, _ string) (*sessioncontracts.ControllerBinding, error) {
	return nil, nil
}

// ClearControllerBinding is a stub that always returns nil.
func (s *staleWorkerEventLiveState) ClearControllerBinding(_ context.Context, _ string) error {
	return nil
}

// StoreAttachToken is a stub that always returns nil.
func (s *staleWorkerEventLiveState) StoreAttachToken(_ context.Context, _ sessioncontracts.AttachTokenClaims, _ time.Duration) error {
	return nil
}

// ConsumeAttachToken is a stub that always returns (nil, nil).
func (s *staleWorkerEventLiveState) ConsumeAttachToken(_ context.Context, _ string) (*sessioncontracts.AttachTokenClaims, error) {
	return nil, nil
}

// TouchAttachmentHeartbeat is a stub that always returns nil.
func (s *staleWorkerEventLiveState) TouchAttachmentHeartbeat(_ context.Context, _, _ string, _ time.Duration) error {
	return nil
}

// UpdateWorkerRoute is a stub that always returns nil.
func (s *staleWorkerEventLiveState) UpdateWorkerRoute(_ context.Context, _ WorkerRoute, _ time.Duration) error {
	return nil
}

// GetWorkerRoute is a stub that always returns (nil, nil).
func (s *staleWorkerEventLiveState) GetWorkerRoute(_ context.Context, _ string) (*WorkerRoute, error) {
	return nil, nil
}

// DeleteWorkerRoute is a stub that always returns nil.
func (s *staleWorkerEventLiveState) DeleteWorkerRoute(_ context.Context, _ string) error {
	return nil
}

// staleWorkerEventOrchestrator is an orchestrator fake whose only recorded
// interaction is the number of ReleaseSessionLease calls.
type staleWorkerEventOrchestrator struct {
	releaseCount int
}

// Reserve is a stub that always returns (nil, nil).
func (o *staleWorkerEventOrchestrator) Reserve(_ context.Context, _ workercontracts.AttachRequest) (*workercontracts.WorkerLease, error) {
	return nil, nil
}

// GetSessionLease is a stub that always returns (nil, nil).
func (o *staleWorkerEventOrchestrator) GetSessionLease(_ context.Context, _ string) (*workercontracts.WorkerLease, error) {
	return nil, nil
}

// ReleaseSessionLease increments releaseCount and returns nil.
func (o *staleWorkerEventOrchestrator) ReleaseSessionLease(_ context.Context, _ string) error {
	o.releaseCount++
	return nil
}

// PrepareAttachment is a stub that always returns nil.
func (o *staleWorkerEventOrchestrator) PrepareAttachment(_ context.Context, _ RemoteSession, _ SessionAttachment, _ map[string]any) error {
	return nil
}

// NotifyDetachment is a stub that always returns nil.
func (o *staleWorkerEventOrchestrator) NotifyDetachment(_ context.Context, _ RemoteSession, _ SessionAttachment) error {
	return nil
}

// TerminateRemoteSession is a stub that always returns nil.
func (o *staleWorkerEventOrchestrator) TerminateRemoteSession(_ context.Context, _, _ string) error {
	return nil
}

// ValidateSessionRuntime is a stub that always returns (true, "", nil).
func (o *staleWorkerEventOrchestrator) ValidateSessionRuntime(_ context.Context, _, _ string) (bool, string, error) {
	return true, "", nil
}