Files
rdp-proxy/backend/internal/modules/sessionbroker/stale_worker_event_test.go
T
2026-04-28 22:29:50 +03:00

392 lines
13 KiB
Go

package sessionbroker
import (
"context"
"io"
"log/slog"
"testing"
"time"
"github.com/example/remote-access-platform/backend/internal/platform/module"
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
)
// TestHandleWorkerConnectedIgnoresTerminalSession checks that a worker
// "connected" event for a session already stored as StateTerminated is
// accepted without error and has no effect: the stored state is unchanged, the
// remote-session fake sees no UpdateState call, and the live-state fake sees no
// UpsertSession call.
func TestHandleWorkerConnectedIgnoresTerminalSession(t *testing.T) {
	const sessionID = "session-1"

	service, store, live, _ := newStaleWorkerEventTestService()
	store.remote.sessions[sessionID] = RemoteSession{
		ID:       sessionID,
		State:    sessioncontracts.StateTerminated,
		WorkerID: "worker-1",
	}

	err := service.HandleWorkerConnected(context.Background(), sessionID)
	if err != nil {
		t.Fatalf("HandleWorkerConnected returned error for stale terminal event: %v", err)
	}

	// The fakes count UpdateState and UpsertSession calls, so zero counts show
	// that the stale event wrote nothing.
	if state := store.remote.sessions[sessionID].State; state != sessioncontracts.StateTerminated {
		t.Fatalf("stale connected event changed terminal state to %q", state)
	}
	if updates := store.remote.updateCount; updates != 0 {
		t.Fatalf("stale connected event updated authoritative session %d times", updates)
	}
	if upserts := live.upsertCount; upserts != 0 {
		t.Fatalf("stale connected event recreated live state %d times", upserts)
	}
}
// TestUpdateWorkerRenderTelemetryIgnoresTerminalSession checks that render
// telemetry reported for a session already stored as StateTerminated returns no
// error and does not write live state: no UpsertSession call is recorded and no
// live entry exists for the session afterwards.
func TestUpdateWorkerRenderTelemetryIgnoresTerminalSession(t *testing.T) {
	const sessionID = "session-1"

	service, store, live, _ := newStaleWorkerEventTestService()
	store.remote.sessions[sessionID] = RemoteSession{
		ID:       sessionID,
		State:    sessioncontracts.StateTerminated,
		WorkerID: "worker-1",
	}

	// A complete-looking frame report, so the only reason to drop it is the
	// terminal session state.
	telemetry := map[string]any{
		"render_state":   "ready",
		"width":          1280,
		"height":         720,
		"frame_sequence": int64(99),
		"frame_data":     "stale-frame",
	}
	if err := service.UpdateWorkerRenderTelemetry(context.Background(), sessionID, telemetry); err != nil {
		t.Fatalf("UpdateWorkerRenderTelemetry returned error for stale terminal event: %v", err)
	}

	if upserts := live.upsertCount; upserts != 0 {
		t.Fatalf("stale render event recreated live state %d times", upserts)
	}
	if leftover := live.sessions[sessionID]; leftover != nil {
		t.Fatalf("stale render event left live state behind: %#v", leftover)
	}
}
// TestMarkSessionFailedTransitionsActiveSession checks the full failure path
// for an active session with one active attachment: the session ends up in
// StateFailed, the attachment in AttachmentStateClosed, exactly one audit event
// is created, the live-state entry is removed, and the session lease is
// released exactly once.
func TestMarkSessionFailedTransitionsActiveSession(t *testing.T) {
	const (
		sessionID    = "session-1"
		attachmentID = "attachment-1"
	)

	service, store, live, orchestrator := newStaleWorkerEventTestService()
	store.remote.sessions[sessionID] = RemoteSession{
		ID:              sessionID,
		State:           sessioncontracts.StateActive,
		WorkerID:        "worker-1",
		TakeoverVersion: 3,
	}
	store.attachments.items[attachmentID] = SessionAttachment{
		ID:              attachmentID,
		RemoteSessionID: sessionID,
		State:           AttachmentStateActive,
	}
	live.sessions[sessionID] = &LiveSessionState{SessionID: sessionID, State: sessioncontracts.StateActive}

	command := MarkSessionFailedCommand{SessionID: sessionID, Reason: "worker_lost"}
	if err := service.MarkSessionFailed(context.Background(), command); err != nil {
		t.Fatalf("MarkSessionFailed returned error: %v", err)
	}

	// Authoritative store: session and attachment both reach their end states.
	if state := store.remote.sessions[sessionID].State; state != sessioncontracts.StateFailed {
		t.Fatalf("expected failed state, got %q", state)
	}
	if state := store.attachments.items[attachmentID].State; state != AttachmentStateClosed {
		t.Fatalf("expected attachment closed, got %q", state)
	}
	if audits := store.audit.createCount; audits != 1 {
		t.Fatalf("expected one audit event, got %d", audits)
	}

	// Side effects outside the store: live state dropped, lease released.
	if live.sessions[sessionID] != nil {
		t.Fatal("expected failed session live state to be deleted")
	}
	if releases := orchestrator.releaseCount; releases != 1 {
		t.Fatalf("expected session lease release, got %d", releases)
	}
}
// TestMarkSessionFailedAlreadyFailedIsIdempotent checks that a second
// MarkSessionFailed for a session already stored as StateFailed returns no
// error and leaves the stored state as StateFailed.
func TestMarkSessionFailedAlreadyFailedIsIdempotent(t *testing.T) {
	const sessionID = "session-1"

	service, store, _, _ := newStaleWorkerEventTestService()
	store.remote.sessions[sessionID] = RemoteSession{
		ID:              sessionID,
		State:           sessioncontracts.StateFailed,
		WorkerID:        "worker-1",
		TakeoverVersion: 1,
	}

	command := MarkSessionFailedCommand{SessionID: sessionID, Reason: "duplicate_worker_failure"}
	err := service.MarkSessionFailed(context.Background(), command)
	if err != nil {
		t.Fatalf("duplicate MarkSessionFailed returned error: %v", err)
	}

	if state := store.remote.sessions[sessionID].State; state != sessioncontracts.StateFailed {
		t.Fatalf("duplicate terminal event changed state to %q", state)
	}
}
// newStaleWorkerEventTestService builds a Service wired entirely to in-memory
// fakes and returns it together with the fakes so tests can seed state and
// inspect call counters.
//
// The logger writes to io.Discard, and the service clock is pinned to Unix
// second 100 (UTC) so any timestamps the service records are deterministic.
func newStaleWorkerEventTestService() (*Service, *staleWorkerEventTestStore, *staleWorkerEventLiveState, *staleWorkerEventOrchestrator) {
	store := &staleWorkerEventTestStore{
		remote:      &staleWorkerEventRemoteSessions{sessions: map[string]RemoteSession{}},
		attachments: &staleWorkerEventAttachments{items: map[string]SessionAttachment{}},
		policies:    &staleWorkerEventPolicies{},
		audit:       &staleWorkerEventAudit{},
	}
	live := &staleWorkerEventLiveState{sessions: map[string]*LiveSessionState{}}
	orchestrator := &staleWorkerEventOrchestrator{}

	deps := module.Dependencies{
		Infra: module.Infra{
			Logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
		},
	}
	// The same store instance is used directly and through the transactor, so
	// writes made inside a "transaction" are visible to the test.
	service := NewService(deps, store, staleWorkerEventTransactor{store: store}, live, orchestrator)

	fixedNow := time.Unix(100, 0).UTC()
	service.now = func() time.Time { return fixedNow }

	return service, store, live, orchestrator
}
// staleWorkerEventTransactor is a transaction runner fake: it has no commit or
// rollback and simply hands the wrapped store to the callback.
type staleWorkerEventTransactor struct {
	store Store
}

// WithinTransaction invokes fn once with the wrapped store and returns fn's
// error unchanged. The context is ignored.
func (t staleWorkerEventTransactor) WithinTransaction(_ context.Context, fn func(Store) error) error {
	err := fn(t.store)
	return err
}
// staleWorkerEventTestStore is the store fake handed to the service. The
// stateful repositories are held as pointers so tests can seed and inspect
// them; the stateless ones are created on demand.
type staleWorkerEventTestStore struct {
	remote      *staleWorkerEventRemoteSessions
	attachments *staleWorkerEventAttachments
	policies    *staleWorkerEventPolicies
	audit       *staleWorkerEventAudit
}

// RemoteSessions returns the shared in-memory remote-session repository.
func (s *staleWorkerEventTestStore) RemoteSessions() RemoteSessionRepository {
	return s.remote
}

// SessionAttachments returns the shared in-memory attachment repository.
func (s *staleWorkerEventTestStore) SessionAttachments() SessionAttachmentRepository {
	return s.attachments
}

// ResourcePolicies returns the shared no-op policy repository.
func (s *staleWorkerEventTestStore) ResourcePolicies() ResourcePolicyRepository {
	return s.policies
}

// ResourceRuntime returns a fresh stateless runtime repository stub.
func (s *staleWorkerEventTestStore) ResourceRuntime() ResourceRuntimeRepository {
	var runtime staleWorkerEventResourceRuntime
	return runtime
}

// AuditEvents returns the shared counting audit repository.
func (s *staleWorkerEventTestStore) AuditEvents() AuditEventRepository {
	return s.audit
}

// Access returns a fresh stateless access repository stub.
func (s *staleWorkerEventTestStore) Access() AccessRepository {
	var access staleWorkerEventAccess
	return access
}
// staleWorkerEventRemoteSessions is a map-backed RemoteSessionRepository fake.
// updateCount records how many times UpdateState was called so tests can
// assert that stale events performed no writes.
type staleWorkerEventRemoteSessions struct {
	sessions    map[string]RemoteSession
	updateCount int
}

// Create stores the session under its ID, replacing any existing entry.
func (r *staleWorkerEventRemoteSessions) Create(_ context.Context, session RemoteSession) error {
	r.sessions[session.ID] = session
	return nil
}

// GetByID returns a pointer to a copy of the stored session, or (nil, nil)
// when no session has that ID.
func (r *staleWorkerEventRemoteSessions) GetByID(_ context.Context, sessionID string) (*RemoteSession, error) {
	if found, ok := r.sessions[sessionID]; ok {
		return &found, nil
	}
	return nil, nil
}

// GetByIDForUpdate behaves exactly like GetByID; the fake takes no lock.
func (r *staleWorkerEventRemoteSessions) GetByIDForUpdate(ctx context.Context, sessionID string) (*RemoteSession, error) {
	return r.GetByID(ctx, sessionID)
}

// ListByController is a stub that always reports no sessions.
func (r *staleWorkerEventRemoteSessions) ListByController(_ context.Context, _ string) ([]RemoteSession, error) {
	return nil, nil
}

// CountLiveByResource is a stub that always reports zero.
func (r *staleWorkerEventRemoteSessions) CountLiveByResource(_ context.Context, _ string) (int, error) {
	return 0, nil
}

// ListDetachedExpired is a stub that always reports no sessions.
func (r *staleWorkerEventRemoteSessions) ListDetachedExpired(_ context.Context, _ time.Time, _ int) ([]RemoteSession, error) {
	return nil, nil
}

// UpdateState copies the mutable fields from params onto the stored session
// and increments updateCount. An unknown ID is not rejected: the zero-value
// session is updated and stored under that ID.
func (r *staleWorkerEventRemoteSessions) UpdateState(_ context.Context, params UpdateRemoteSessionStateParams) error {
	r.updateCount++

	id := params.RemoteSessionID
	updated := r.sessions[id]
	updated.State, updated.WorkerID = params.State, params.WorkerID
	updated.DetachDeadlineAt, updated.LastHeartbeatAt = params.DetachDeadlineAt, params.LastHeartbeatAt
	updated.TakeoverVersion, updated.UpdatedAt = params.TakeoverVersion, params.UpdatedAt
	r.sessions[id] = updated
	return nil
}
// staleWorkerEventAttachments is a map-backed SessionAttachmentRepository
// fake keyed by attachment ID.
type staleWorkerEventAttachments struct {
	items map[string]SessionAttachment
}

// Create stores the attachment under its ID, replacing any existing entry.
func (r *staleWorkerEventAttachments) Create(_ context.Context, attachment SessionAttachment) error {
	r.items[attachment.ID] = attachment
	return nil
}

// GetByID returns a pointer to a copy of the stored attachment, or (nil, nil)
// when no attachment has that ID.
func (r *staleWorkerEventAttachments) GetByID(_ context.Context, attachmentID string) (*SessionAttachment, error) {
	if found, ok := r.items[attachmentID]; ok {
		return &found, nil
	}
	return nil, nil
}

// GetByIDForUpdate behaves exactly like GetByID; the fake takes no lock.
func (r *staleWorkerEventAttachments) GetByIDForUpdate(ctx context.Context, attachmentID string) (*SessionAttachment, error) {
	return r.GetByID(ctx, attachmentID)
}

// ListByRemoteSession returns every stored attachment whose RemoteSessionID
// matches. The result is never nil and follows map iteration order, so it is
// unordered.
func (r *staleWorkerEventAttachments) ListByRemoteSession(_ context.Context, remoteSessionID string) ([]SessionAttachment, error) {
	matches := []SessionAttachment{}
	for _, candidate := range r.items {
		if candidate.RemoteSessionID != remoteSessionID {
			continue
		}
		matches = append(matches, candidate)
	}
	return matches, nil
}

// ListActiveByRemoteSessionForUpdate delegates to ListByRemoteSession. Note
// that this fake does not filter by attachment state and takes no lock.
func (r *staleWorkerEventAttachments) ListActiveByRemoteSessionForUpdate(ctx context.Context, remoteSessionID string) ([]SessionAttachment, error) {
	return r.ListByRemoteSession(ctx, remoteSessionID)
}

// UpdateState copies the state and timestamp fields from params onto the
// stored attachment. An unknown ID is not rejected: the zero-value attachment
// is updated and stored under that ID.
func (r *staleWorkerEventAttachments) UpdateState(_ context.Context, params UpdateSessionAttachmentStateParams) error {
	id := params.AttachmentID
	updated := r.items[id]
	updated.State = params.State
	updated.DetachedAt, updated.LastInputAt = params.DetachedAt, params.LastInputAt
	updated.UpdatedAt = params.UpdatedAt
	r.items[id] = updated
	return nil
}

// Supersede marks the previous attachment as AttachmentStateSuperseded and
// records the successor ID and detach time from params.
func (r *staleWorkerEventAttachments) Supersede(_ context.Context, params SupersedeAttachmentParams) error {
	previousID := params.PreviousAttachmentID
	previous := r.items[previousID]
	previous.State = AttachmentStateSuperseded
	// params is this call's own copy, so pointers into it are not shared with
	// the caller.
	previous.SupersededBy = &params.NextAttachmentID
	previous.DetachedAt = &params.DetachedAt
	previous.UpdatedAt = params.UpdatedAt
	r.items[previousID] = previous
	return nil
}
// staleWorkerEventPolicies is a stateless ResourcePolicyRepository stub.
type staleWorkerEventPolicies struct{}

// GetByResourceID always reports that no policy exists.
func (*staleWorkerEventPolicies) GetByResourceID(_ context.Context, _ string) (*ResourcePolicy, error) {
	return nil, nil
}

// Upsert discards the policy and reports success.
func (*staleWorkerEventPolicies) Upsert(_ context.Context, _ ResourcePolicy) error {
	return nil
}
// staleWorkerEventAudit is an AuditEventRepository fake that discards events
// and only counts how many were created.
type staleWorkerEventAudit struct {
	createCount int
}

// Create increments createCount and reports success; the event is not kept.
func (r *staleWorkerEventAudit) Create(_ context.Context, _ AuditEvent) error {
	r.createCount += 1
	return nil
}
// staleWorkerEventResourceRuntime is a stateless ResourceRuntimeRepository
// stub.
type staleWorkerEventResourceRuntime struct{}

// GetByID always reports that no runtime spec exists.
func (staleWorkerEventResourceRuntime) GetByID(_ context.Context, _ string) (*ResourceRuntimeSpec, error) {
	var (
		spec *ResourceRuntimeSpec
		err  error
	)
	return spec, err
}
// staleWorkerEventAccess is a stateless AccessRepository stub: no device is
// trusted and no role is returned.
type staleWorkerEventAccess struct{}

// IsTrustedDevice always answers false with no error.
func (staleWorkerEventAccess) IsTrustedDevice(_ context.Context, _ string, _ string) (bool, error) {
	return false, nil
}

// GetPlatformRole always answers with an empty role and no error.
func (staleWorkerEventAccess) GetPlatformRole(_ context.Context, _ string) (string, error) {
	return "", nil
}

// GetOrganizationRole always answers with an empty role, false, and no error.
func (staleWorkerEventAccess) GetOrganizationRole(_ context.Context, _ string, _ string) (string, bool, error) {
	return "", false, nil
}
// staleWorkerEventLiveState is the live-state fake passed to NewService. Only
// the session operations are backed by a map; controller bindings, attach
// tokens, heartbeats and worker routes are no-op stubs. upsertCount records
// how many times UpsertSession was called.
type staleWorkerEventLiveState struct {
	sessions    map[string]*LiveSessionState
	upsertCount int
}

// UpsertSession stores a private copy of state under its SessionID and
// increments upsertCount.
func (s *staleWorkerEventLiveState) UpsertSession(_ context.Context, state LiveSessionState) error {
	s.upsertCount++

	snapshot := state
	s.sessions[snapshot.SessionID] = &snapshot
	return nil
}

// GetSession returns a pointer to a copy of the stored state, or (nil, nil)
// when there is no entry (or a nil entry) for sessionID.
func (s *staleWorkerEventLiveState) GetSession(_ context.Context, sessionID string) (*LiveSessionState, error) {
	if stored := s.sessions[sessionID]; stored != nil {
		snapshot := *stored
		return &snapshot, nil
	}
	return nil, nil
}

// DeleteSession removes the entry for sessionID; a missing entry is not an
// error.
func (s *staleWorkerEventLiveState) DeleteSession(_ context.Context, sessionID string) error {
	delete(s.sessions, sessionID)
	return nil
}

// BindController is a no-op stub.
func (s *staleWorkerEventLiveState) BindController(_ context.Context, _ sessioncontracts.ControllerBinding, _ time.Duration) error {
	return nil
}

// GetControllerBinding is a stub that always reports no binding.
func (s *staleWorkerEventLiveState) GetControllerBinding(_ context.Context, _ string) (*sessioncontracts.ControllerBinding, error) {
	return nil, nil
}

// ClearControllerBinding is a no-op stub.
func (s *staleWorkerEventLiveState) ClearControllerBinding(_ context.Context, _ string) error {
	return nil
}

// StoreAttachToken is a no-op stub.
func (s *staleWorkerEventLiveState) StoreAttachToken(_ context.Context, _ sessioncontracts.AttachTokenClaims, _ time.Duration) error {
	return nil
}

// ConsumeAttachToken is a stub that always reports no claims.
func (s *staleWorkerEventLiveState) ConsumeAttachToken(_ context.Context, _ string) (*sessioncontracts.AttachTokenClaims, error) {
	return nil, nil
}

// TouchAttachmentHeartbeat is a no-op stub.
func (s *staleWorkerEventLiveState) TouchAttachmentHeartbeat(_ context.Context, _ string, _ string, _ time.Duration) error {
	return nil
}

// UpdateWorkerRoute is a no-op stub.
func (s *staleWorkerEventLiveState) UpdateWorkerRoute(_ context.Context, _ WorkerRoute, _ time.Duration) error {
	return nil
}

// GetWorkerRoute is a stub that always reports no route.
func (s *staleWorkerEventLiveState) GetWorkerRoute(_ context.Context, _ string) (*WorkerRoute, error) {
	return nil, nil
}

// DeleteWorkerRoute is a no-op stub.
func (s *staleWorkerEventLiveState) DeleteWorkerRoute(_ context.Context, _ string) error {
	return nil
}
// staleWorkerEventOrchestrator is the orchestrator fake passed to NewService.
// Every operation succeeds without doing any work; releaseCount records how
// many times ReleaseSessionLease was called.
type staleWorkerEventOrchestrator struct {
	releaseCount int
}

// Reserve is a stub that returns no lease and no error.
func (o *staleWorkerEventOrchestrator) Reserve(_ context.Context, _ workercontracts.AttachRequest) (*workercontracts.WorkerLease, error) {
	return nil, nil
}

// GetSessionLease is a stub that returns no lease and no error.
func (o *staleWorkerEventOrchestrator) GetSessionLease(_ context.Context, _ string) (*workercontracts.WorkerLease, error) {
	return nil, nil
}

// ReleaseSessionLease increments releaseCount and reports success.
func (o *staleWorkerEventOrchestrator) ReleaseSessionLease(_ context.Context, _ string) error {
	o.releaseCount += 1
	return nil
}

// PrepareAttachment is a no-op stub.
func (o *staleWorkerEventOrchestrator) PrepareAttachment(_ context.Context, _ RemoteSession, _ SessionAttachment, _ map[string]any) error {
	return nil
}

// NotifyDetachment is a no-op stub.
func (o *staleWorkerEventOrchestrator) NotifyDetachment(_ context.Context, _ RemoteSession, _ SessionAttachment) error {
	return nil
}

// TerminateRemoteSession is a no-op stub.
func (o *staleWorkerEventOrchestrator) TerminateRemoteSession(_ context.Context, _ string, _ string) error {
	return nil
}

// ValidateSessionRuntime is a stub that always returns true with an empty
// string and no error.
func (o *staleWorkerEventOrchestrator) ValidateSessionRuntime(_ context.Context, _ string, _ string) (bool, string, error) {
	return true, "", nil
}