275 lines
8.5 KiB
Go
275 lines
8.5 KiB
Go
package worker
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"slices"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
|
|
"github.com/example/remote-access-platform/backend/internal/modules/sessionbroker"
|
|
"github.com/example/remote-access-platform/backend/internal/platform/module"
|
|
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
|
|
)
|
|
|
|
var (
	// ErrNoWorkerAvailable is returned when a reservation finds no registered
	// worker that matches the requested protocol, status, heartbeat freshness
	// and capabilities.
	ErrNoWorkerAvailable = errors.New("no worker available")
)
|
|
|
|
type Service struct {
|
|
cfg module.Config
|
|
store Store
|
|
now func() time.Time
|
|
}
|
|
|
|
func NewService(deps module.Dependencies, store Store) *Service {
|
|
return &Service{
|
|
cfg: deps.Config,
|
|
store: store,
|
|
now: time.Now,
|
|
}
|
|
}
|
|
|
|
func (s *Service) Register(ctx context.Context, registration workercontracts.WorkerRegistration) error {
|
|
if registration.WorkerID == "" {
|
|
return fmt.Errorf("worker id is required")
|
|
}
|
|
registration.LastHeartbeatAt = s.now().UTC()
|
|
return s.store.RegisterWorker(ctx, registration, s.cfg.Worker.HeartbeatTTL)
|
|
}
|
|
|
|
func (s *Service) Heartbeat(ctx context.Context, heartbeat workercontracts.WorkerHeartbeat) error {
|
|
heartbeat.LastHeartbeatAt = s.now().UTC()
|
|
return s.store.TouchWorkerHeartbeat(ctx, heartbeat, s.cfg.Worker.HeartbeatTTL)
|
|
}
|
|
|
|
func (s *Service) Reserve(ctx context.Context, request workercontracts.AttachRequest) (*workercontracts.WorkerLease, error) {
|
|
registration, err := s.reserveWorker(ctx, workercontracts.ProtocolRDP, request.RequiredCapabilities)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return s.AcquireLease(ctx, registration.WorkerID, request)
|
|
}
|
|
|
|
func (s *Service) reserveWorker(ctx context.Context, protocol workercontracts.Protocol, capabilities []string) (*workercontracts.WorkerRegistration, error) {
|
|
workers, err := s.store.ListWorkers(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
now := s.now().UTC()
|
|
for _, worker := range workers {
|
|
if worker.Protocol != protocol || worker.Status != workercontracts.StatusOnline {
|
|
continue
|
|
}
|
|
if now.Sub(worker.LastHeartbeatAt) > s.cfg.Worker.StaleLeaseGracePeriod+s.cfg.Worker.HeartbeatTTL {
|
|
continue
|
|
}
|
|
if !hasCapabilities(worker.Capabilities, capabilities) {
|
|
continue
|
|
}
|
|
return &worker, nil
|
|
}
|
|
return nil, ErrNoWorkerAvailable
|
|
}
|
|
|
|
func (s *Service) AcquireLease(ctx context.Context, workerID string, request workercontracts.AttachRequest) (*workercontracts.WorkerLease, error) {
|
|
if request.SessionID == "" {
|
|
request.SessionID = uuid.NewString()
|
|
}
|
|
now := s.now().UTC()
|
|
lease := workercontracts.WorkerLease{
|
|
LeaseID: uuid.NewString(),
|
|
WorkerID: workerID,
|
|
Protocol: workercontracts.ProtocolRDP,
|
|
ResourceID: request.ResourceID,
|
|
SessionID: request.SessionID,
|
|
Capabilities: request.RequiredCapabilities,
|
|
ControlStream: "worker://control/" + workerID,
|
|
ExpiresAt: now.Add(s.cfg.Worker.LeaseTTL),
|
|
RenderQualityProfile: normalizeRenderQualityProfile(request.RenderQualityProfile),
|
|
}
|
|
if err := s.store.AcquireLease(ctx, lease, s.cfg.Worker.LeaseTTL); err != nil {
|
|
return nil, err
|
|
}
|
|
return &lease, nil
|
|
}
|
|
|
|
func (s *Service) GetSessionLease(ctx context.Context, sessionID string) (*workercontracts.WorkerLease, error) {
|
|
return s.store.GetLeaseBySession(ctx, sessionID)
|
|
}
|
|
|
|
func (s *Service) RenewLease(ctx context.Context, leaseID string) (*workercontracts.WorkerLease, error) {
|
|
lease, err := s.store.GetLease(ctx, leaseID)
|
|
if err != nil || lease == nil {
|
|
return lease, err
|
|
}
|
|
lease.ExpiresAt = s.now().UTC().Add(s.cfg.Worker.LeaseTTL)
|
|
if err := s.store.RenewLease(ctx, *lease, s.cfg.Worker.LeaseTTL); err != nil {
|
|
return nil, err
|
|
}
|
|
return lease, nil
|
|
}
|
|
|
|
func (s *Service) ReleaseLease(ctx context.Context, leaseID string) error {
|
|
return s.store.ReleaseLease(ctx, leaseID)
|
|
}
|
|
|
|
func (s *Service) ReleaseSessionLease(ctx context.Context, sessionID string) error {
|
|
lease, err := s.store.GetLeaseBySession(ctx, sessionID)
|
|
if err != nil || lease == nil {
|
|
return err
|
|
}
|
|
return s.store.ReleaseLease(ctx, lease.LeaseID)
|
|
}
|
|
|
|
func (s *Service) RecoverStaleLeases(ctx context.Context) ([]workercontracts.WorkerLease, error) {
|
|
leases, err := s.store.ListLeases(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var stale []workercontracts.WorkerLease
|
|
now := s.now().UTC()
|
|
for _, lease := range leases {
|
|
registration, err := s.store.GetWorker(ctx, lease.WorkerID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if registration == nil || now.Sub(registration.LastHeartbeatAt) > s.cfg.Worker.StaleLeaseGracePeriod+s.cfg.Worker.HeartbeatTTL {
|
|
if err := s.store.ReleaseLease(ctx, lease.LeaseID); err != nil {
|
|
return nil, err
|
|
}
|
|
stale = append(stale, lease)
|
|
}
|
|
}
|
|
return stale, nil
|
|
}
|
|
|
|
func (s *Service) ValidateSessionRuntime(ctx context.Context, sessionID, workerID string) (bool, string, error) {
|
|
lease, err := s.store.GetLeaseBySession(ctx, sessionID)
|
|
if err != nil {
|
|
return false, "", err
|
|
}
|
|
if lease == nil {
|
|
return false, "worker_lease_missing", nil
|
|
}
|
|
if workerID != "" && lease.WorkerID != workerID {
|
|
_ = s.store.ReleaseLease(ctx, lease.LeaseID)
|
|
return false, "worker_binding_mismatch", nil
|
|
}
|
|
now := s.now().UTC()
|
|
if !lease.ExpiresAt.After(now) {
|
|
_ = s.store.ReleaseLease(ctx, lease.LeaseID)
|
|
return false, "worker_lease_expired", nil
|
|
}
|
|
registration, err := s.store.GetWorker(ctx, lease.WorkerID)
|
|
if err != nil {
|
|
return false, "", err
|
|
}
|
|
if registration == nil {
|
|
_ = s.store.ReleaseLease(ctx, lease.LeaseID)
|
|
return false, "worker_registration_missing", nil
|
|
}
|
|
if registration.Status != workercontracts.StatusOnline {
|
|
return false, "worker_not_online", nil
|
|
}
|
|
if now.Sub(registration.LastHeartbeatAt) > s.cfg.Worker.StaleLeaseGracePeriod+s.cfg.Worker.HeartbeatTTL {
|
|
_ = s.store.ReleaseLease(ctx, lease.LeaseID)
|
|
return false, "worker_heartbeat_stale", nil
|
|
}
|
|
return true, "", nil
|
|
}
|
|
|
|
func (s *Service) PublishControl(ctx context.Context, envelope workercontracts.RoutedEnvelope) error {
|
|
return s.store.AppendEnvelope(ctx, envelope)
|
|
}
|
|
|
|
func (s *Service) PublishInput(ctx context.Context, envelope workercontracts.RoutedEnvelope) error {
|
|
return s.store.AppendEnvelope(ctx, envelope)
|
|
}
|
|
|
|
// hasCapabilities reports whether workerCaps contains every entry of required.
// Matching is exact and case-sensitive; an empty or nil required list is
// always satisfied.
func hasCapabilities(workerCaps, required []string) bool {
	firstMissing := slices.IndexFunc(required, func(capability string) bool {
		return !slices.Contains(workerCaps, capability)
	})
	return firstMissing < 0
}
|
|
|
|
func (s *Service) PrepareAttachment(ctx context.Context, session sessionbroker.RemoteSession, attachment sessionbroker.SessionAttachment, runtimeMetadata map[string]any) error {
|
|
renderQualityProfile := normalizeRenderQualityProfile(session.RenderQualityProfile)
|
|
if renderQualityProfile == "balanced" {
|
|
renderQualityProfile = renderQualityProfileFromMetadata(session.Metadata)
|
|
}
|
|
if runtimeMetadata == nil {
|
|
runtimeMetadata = decodeMetadata(session.Metadata)
|
|
}
|
|
return s.store.AppendAssignment(ctx, session.WorkerID, map[string]any{
|
|
"type": "session_assignment",
|
|
"session_id": session.ID,
|
|
"worker_id": session.WorkerID,
|
|
"attachment_id": attachment.ID,
|
|
"user_id": attachment.UserID,
|
|
"device_id": attachment.DeviceID,
|
|
"takeover_of": attachment.TakeoverOf,
|
|
"state": session.State,
|
|
"render_quality_profile": renderQualityProfile,
|
|
"metadata": runtimeMetadata,
|
|
})
|
|
}
|
|
|
|
func (s *Service) NotifyDetachment(ctx context.Context, session sessionbroker.RemoteSession, attachment sessionbroker.SessionAttachment) error {
|
|
return s.PublishControl(ctx, workercontracts.RoutedEnvelope{
|
|
SessionID: session.ID,
|
|
AttachmentID: attachment.ID,
|
|
Type: "control",
|
|
Payload: map[string]any{
|
|
"action": "detach",
|
|
},
|
|
})
|
|
}
|
|
|
|
func (s *Service) TerminateRemoteSession(ctx context.Context, sessionID, attachmentID string) error {
|
|
return s.PublishControl(ctx, workercontracts.RoutedEnvelope{
|
|
SessionID: sessionID,
|
|
AttachmentID: attachmentID,
|
|
Type: "control",
|
|
Payload: map[string]any{
|
|
"action": "terminate",
|
|
},
|
|
})
|
|
}
|
|
|
|
// decodeMetadata parses payload as a JSON object and returns its fields.
//
// It never returns nil. The following all yield an empty map:
//   - an empty payload,
//   - a JSON null,
//   - invalid JSON, or
//   - JSON that is not an object.
//
// Callers can therefore index and serialize the result without nil checks.
// Decoding null into a map leaves it nil without error, so that case is
// checked explicitly.
func decodeMetadata(payload []byte) map[string]any {
	if len(payload) == 0 {
		return map[string]any{}
	}

	var out map[string]any
	if err := json.Unmarshal(payload, &out); err != nil || out == nil {
		return map[string]any{}
	}
	return out
}
|
|
|
|
// normalizeRenderQualityProfile returns profile unchanged if it is one of the
// recognized names ("low_bandwidth", "balanced", "high_quality",
// "text_priority"). Any other value, including "", becomes "balanced".
// Matching is exact: no trimming or case folding is done.
func normalizeRenderQualityProfile(profile string) string {
	known := []string{"low_bandwidth", "balanced", "high_quality", "text_priority"}
	if slices.Contains(known, profile) {
		return profile
	}
	return "balanced"
}
|
|
|
|
func renderQualityProfileFromMetadata(metadata []byte) string {
|
|
decoded := decodeMetadata(metadata)
|
|
resource, _ := decoded["resource"].(map[string]any)
|
|
if resource == nil {
|
|
return "balanced"
|
|
}
|
|
if profile, ok := resource["render_quality_profile"].(string); ok && profile != "" {
|
|
return normalizeRenderQualityProfile(profile)
|
|
}
|
|
return "balanced"
|
|
}
|