Files
rdp-proxy/backend/internal/modules/worker/service.go
T
2026-04-28 22:29:50 +03:00

275 lines
8.5 KiB
Go

package worker
import (
"context"
"encoding/json"
"errors"
"fmt"
"slices"
"time"
"github.com/google/uuid"
"github.com/example/remote-access-platform/backend/internal/modules/sessionbroker"
"github.com/example/remote-access-platform/backend/internal/platform/module"
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
)
var (
	// ErrNoWorkerAvailable is returned by Reserve when no registered worker
	// matches the requested protocol, status, heartbeat freshness and
	// capabilities.
	ErrNoWorkerAvailable = errors.New("no worker available")
)
// Service coordinates worker registration, heartbeats, lease lifecycle and
// message routing to workers, delegating persistence to a Store.
type Service struct {
	// cfg is the module configuration; the Worker section supplies
	// HeartbeatTTL, LeaseTTL and StaleLeaseGracePeriod.
	cfg module.Config
	// store persists workers, leases, envelopes and assignments.
	store Store
	// now is the clock used for heartbeats and lease expiry. NewService sets
	// it to time.Now; presumably replaced in tests — confirm.
	now func() time.Time
}
// NewService builds a Service that reads its configuration from deps.Config,
// persists state through store, and uses the wall clock (time.Now).
func NewService(deps module.Dependencies, store Store) *Service {
	svc := new(Service)
	svc.cfg = deps.Config
	svc.store = store
	svc.now = time.Now
	return svc
}
// Register stores a worker registration, stamping it with the current UTC
// time as its last heartbeat and keeping it for the configured HeartbeatTTL.
// It returns an error when registration.WorkerID is empty.
func (s *Service) Register(ctx context.Context, registration workercontracts.WorkerRegistration) error {
	if len(registration.WorkerID) == 0 {
		return errors.New("worker id is required")
	}
	stamped := registration
	stamped.LastHeartbeatAt = s.now().UTC()
	ttl := s.cfg.Worker.HeartbeatTTL
	return s.store.RegisterWorker(ctx, stamped, ttl)
}
// Heartbeat refreshes a worker's liveness record. The heartbeat is stamped
// with the current UTC time before being handed to the store together with
// the configured HeartbeatTTL.
func (s *Service) Heartbeat(ctx context.Context, heartbeat workercontracts.WorkerHeartbeat) error {
	touched := heartbeat
	touched.LastHeartbeatAt = s.now().UTC()
	ttl := s.cfg.Worker.HeartbeatTTL
	return s.store.TouchWorkerHeartbeat(ctx, touched, ttl)
}
// Reserve picks an eligible RDP worker that offers every capability in
// request.RequiredCapabilities and acquires a lease on it for the request.
// It returns ErrNoWorkerAvailable (from reserveWorker) when nothing matches.
func (s *Service) Reserve(ctx context.Context, request workercontracts.AttachRequest) (*workercontracts.WorkerLease, error) {
	picked, err := s.reserveWorker(ctx, workercontracts.ProtocolRDP, request.RequiredCapabilities)
	if err != nil {
		return nil, err
	}
	workerID := picked.WorkerID
	return s.AcquireLease(ctx, workerID, request)
}
// reserveWorker returns the first listed worker that serves protocol, is
// online, has heartbeated within StaleLeaseGracePeriod+HeartbeatTTL, and
// offers all requested capabilities. It returns ErrNoWorkerAvailable when no
// worker qualifies.
//
// NOTE(review): selection is "first match in ListWorkers order"; there is no
// load balancing across eligible workers — confirm this is intended.
func (s *Service) reserveWorker(ctx context.Context, protocol workercontracts.Protocol, capabilities []string) (*workercontracts.WorkerRegistration, error) {
	candidates, err := s.store.ListWorkers(ctx)
	if err != nil {
		return nil, err
	}
	maxSilence := s.cfg.Worker.StaleLeaseGracePeriod + s.cfg.Worker.HeartbeatTTL
	now := s.now().UTC()
	for _, candidate := range candidates {
		// Cheap field checks first; the capability scan only runs for
		// otherwise-eligible workers.
		eligible := candidate.Protocol == protocol &&
			candidate.Status == workercontracts.StatusOnline &&
			now.Sub(candidate.LastHeartbeatAt) <= maxSilence &&
			hasCapabilities(candidate.Capabilities, capabilities)
		if eligible {
			picked := candidate
			return &picked, nil
		}
	}
	return nil, ErrNoWorkerAvailable
}
// AcquireLease creates and stores an RDP lease binding workerID to the
// session described by request. The lease expires LeaseTTL from now.
//
// A missing request.SessionID is replaced by a freshly generated UUID. The
// requested capabilities are copied into the lease, so later mutation of the
// caller's slice does not alter the returned lease. It returns an error when
// workerID is empty: such a lease could never be routed, because its control
// stream would be "worker://control/" with no worker.
func (s *Service) AcquireLease(ctx context.Context, workerID string, request workercontracts.AttachRequest) (*workercontracts.WorkerLease, error) {
	if workerID == "" {
		return nil, errors.New("worker id is required")
	}
	if request.SessionID == "" {
		request.SessionID = uuid.NewString()
	}
	now := s.now().UTC()
	lease := workercontracts.WorkerLease{
		LeaseID:    uuid.NewString(),
		WorkerID:   workerID,
		Protocol:   workercontracts.ProtocolRDP,
		ResourceID: request.ResourceID,
		SessionID:  request.SessionID,
		// Clone to avoid aliasing the caller's backing array.
		Capabilities:         slices.Clone(request.RequiredCapabilities),
		ControlStream:        "worker://control/" + workerID,
		ExpiresAt:            now.Add(s.cfg.Worker.LeaseTTL),
		RenderQualityProfile: normalizeRenderQualityProfile(request.RenderQualityProfile),
	}
	if err := s.store.AcquireLease(ctx, lease, s.cfg.Worker.LeaseTTL); err != nil {
		return nil, err
	}
	return &lease, nil
}
// GetSessionLease returns the lease the store holds for sessionID, exactly as
// the store reports it (including a nil lease when none exists, if that is
// how the store signals absence).
func (s *Service) GetSessionLease(ctx context.Context, sessionID string) (*workercontracts.WorkerLease, error) {
	lease, err := s.store.GetLeaseBySession(ctx, sessionID)
	return lease, err
}
// RenewLease extends the lease identified by leaseID so that it expires
// LeaseTTL from now, persists the new expiry, and returns the updated lease.
//
// It returns (nil, nil) when the store has no lease with that ID. On any
// store error it returns a nil lease together with the error, so callers
// never receive a partially-populated result.
//
// NOTE(review): a lease whose ExpiresAt has already passed is still renewed;
// confirm the store drops expired leases (it is given a TTL) so they cannot
// be resurrected here.
func (s *Service) RenewLease(ctx context.Context, leaseID string) (*workercontracts.WorkerLease, error) {
	lease, err := s.store.GetLease(ctx, leaseID)
	if err != nil {
		// Discard anything the store returned alongside the error.
		return nil, err
	}
	if lease == nil {
		return nil, nil
	}
	ttl := s.cfg.Worker.LeaseTTL
	lease.ExpiresAt = s.now().UTC().Add(ttl)
	if err := s.store.RenewLease(ctx, *lease, ttl); err != nil {
		return nil, err
	}
	return lease, nil
}
// ReleaseLease asks the store to drop the lease identified by leaseID and
// returns the store's error unchanged.
func (s *Service) ReleaseLease(ctx context.Context, leaseID string) error {
	err := s.store.ReleaseLease(ctx, leaseID)
	return err
}
// ReleaseSessionLease looks up the lease bound to sessionID and releases it.
// A session with no lease is not an error; lookup and release errors are
// returned as-is.
func (s *Service) ReleaseSessionLease(ctx context.Context, sessionID string) error {
	lease, err := s.store.GetLeaseBySession(ctx, sessionID)
	switch {
	case err != nil:
		return err
	case lease == nil:
		return nil
	default:
		return s.store.ReleaseLease(ctx, lease.LeaseID)
	}
}
// RecoverStaleLeases releases every lease held by a worker that is no longer
// registered or whose last heartbeat is older than
// StaleLeaseGracePeriod+HeartbeatTTL, and returns the released leases.
//
// Each distinct worker is looked up once, so many leases on the same worker
// cost a single GetWorker call. The first store error aborts the sweep and is
// returned wrapped with context. Leases released before that point stay
// released but are not reported.
func (s *Service) RecoverStaleLeases(ctx context.Context) ([]workercontracts.WorkerLease, error) {
	leases, err := s.store.ListLeases(ctx)
	if err != nil {
		return nil, fmt.Errorf("listing worker leases: %w", err)
	}
	now := s.now().UTC()
	maxSilence := s.cfg.Worker.StaleLeaseGracePeriod + s.cfg.Worker.HeartbeatTTL
	// Memoized staleness verdict per worker ID, avoiding N+1 store lookups.
	staleByWorker := make(map[string]bool)
	var stale []workercontracts.WorkerLease
	for _, lease := range leases {
		workerStale, seen := staleByWorker[lease.WorkerID]
		if !seen {
			registration, err := s.store.GetWorker(ctx, lease.WorkerID)
			if err != nil {
				return nil, fmt.Errorf("loading worker %q for lease %q: %w", lease.WorkerID, lease.LeaseID, err)
			}
			workerStale = registration == nil || now.Sub(registration.LastHeartbeatAt) > maxSilence
			staleByWorker[lease.WorkerID] = workerStale
		}
		if !workerStale {
			continue
		}
		if err := s.store.ReleaseLease(ctx, lease.LeaseID); err != nil {
			return nil, fmt.Errorf("releasing stale lease %q: %w", lease.LeaseID, err)
		}
		stale = append(stale, lease)
	}
	return stale, nil
}
// ValidateSessionRuntime reports whether the lease bound to sessionID is still
// usable by workerID. It returns (ok, reason, err):
//   - err is non-nil only when GetLeaseBySession or GetWorker fails; ok is then
//     false and reason is empty.
//   - When ok is false and err is nil, reason is one of "worker_lease_missing",
//     "worker_binding_mismatch", "worker_lease_expired",
//     "worker_registration_missing", "worker_not_online" or
//     "worker_heartbeat_stale".
//
// An empty workerID skips the binding check. Every failure except
// "worker_lease_missing" and "worker_not_online" also releases the lease as a
// side effect; those release errors are deliberately ignored.
func (s *Service) ValidateSessionRuntime(ctx context.Context, sessionID, workerID string) (bool, string, error) {
	lease, err := s.store.GetLeaseBySession(ctx, sessionID)
	if err != nil {
		return false, "", err
	}
	if lease == nil {
		return false, "worker_lease_missing", nil
	}
	// NOTE(review): a caller presenting the wrong workerID causes the session's
	// lease to be released even though it belongs to another worker. Confirm
	// this invalidation is intended and that workerID comes from a trusted source.
	if workerID != "" && lease.WorkerID != workerID {
		// Best-effort cleanup: the release error is discarded so the
		// validation verdict still reaches the caller.
		_ = s.store.ReleaseLease(ctx, lease.LeaseID)
		return false, "worker_binding_mismatch", nil
	}
	now := s.now().UTC()
	// A lease whose expiry is at or before now counts as expired.
	if !lease.ExpiresAt.After(now) {
		_ = s.store.ReleaseLease(ctx, lease.LeaseID) // best-effort; error intentionally ignored
		return false, "worker_lease_expired", nil
	}
	registration, err := s.store.GetWorker(ctx, lease.WorkerID)
	if err != nil {
		return false, "", err
	}
	if registration == nil {
		_ = s.store.ReleaseLease(ctx, lease.LeaseID) // best-effort; error intentionally ignored
		return false, "worker_registration_missing", nil
	}
	// An offline worker fails validation but, unlike the other failures,
	// keeps its lease.
	// NOTE(review): confirm the lease should survive a non-online status.
	if registration.Status != workercontracts.StatusOnline {
		return false, "worker_not_online", nil
	}
	// Staleness window: grace period plus heartbeat TTL (the same sum used by
	// reserveWorker and RecoverStaleLeases).
	if now.Sub(registration.LastHeartbeatAt) > s.cfg.Worker.StaleLeaseGracePeriod+s.cfg.Worker.HeartbeatTTL {
		_ = s.store.ReleaseLease(ctx, lease.LeaseID) // best-effort; error intentionally ignored
		return false, "worker_heartbeat_stale", nil
	}
	return true, "", nil
}
// PublishControl appends a control envelope to the store's envelope stream
// and returns the store's error unchanged.
func (s *Service) PublishControl(ctx context.Context, envelope workercontracts.RoutedEnvelope) error {
	err := s.store.AppendEnvelope(ctx, envelope)
	return err
}
// PublishInput appends an input envelope to the store's envelope stream.
// It currently uses the same store path as PublishControl; the envelope's own
// Type field distinguishes the two.
func (s *Service) PublishInput(ctx context.Context, envelope workercontracts.RoutedEnvelope) error {
	err := s.store.AppendEnvelope(ctx, envelope)
	return err
}
// hasCapabilities reports whether workerCaps contains every entry of required.
// An empty (or nil) required list is always satisfied.
func hasCapabilities(workerCaps, required []string) bool {
	missing := func(capability string) bool {
		return !slices.Contains(workerCaps, capability)
	}
	return !slices.ContainsFunc(required, missing)
}
// PrepareAttachment appends a "session_assignment" message for session's
// worker describing the new attachment.
//
// The render quality profile comes from the session. When that normalizes to
// "balanced", it is instead taken from the session metadata's
// resource.render_quality_profile (itself defaulting to "balanced"). When
// runtimeMetadata is nil, the decoded session metadata is sent in its place.
//
// NOTE(review): an explicitly chosen "balanced" session profile is
// indistinguishable from "unset" here and gets overridden by the resource
// default — confirm this is intended.
func (s *Service) PrepareAttachment(ctx context.Context, session sessionbroker.RemoteSession, attachment sessionbroker.SessionAttachment, runtimeMetadata map[string]any) error {
	metadata := runtimeMetadata
	if metadata == nil {
		metadata = decodeMetadata(session.Metadata)
	}

	profile := normalizeRenderQualityProfile(session.RenderQualityProfile)
	if profile == "balanced" {
		profile = renderQualityProfileFromMetadata(session.Metadata)
	}

	assignment := map[string]any{
		"type":                   "session_assignment",
		"session_id":             session.ID,
		"worker_id":              session.WorkerID,
		"attachment_id":          attachment.ID,
		"user_id":                attachment.UserID,
		"device_id":              attachment.DeviceID,
		"takeover_of":            attachment.TakeoverOf,
		"state":                  session.State,
		"render_quality_profile": profile,
		"metadata":               metadata,
	}
	return s.store.AppendAssignment(ctx, session.WorkerID, assignment)
}
// NotifyDetachment publishes a "detach" control envelope for the given
// session and attachment.
func (s *Service) NotifyDetachment(ctx context.Context, session sessionbroker.RemoteSession, attachment sessionbroker.SessionAttachment) error {
	envelope := workercontracts.RoutedEnvelope{
		Type:         "control",
		SessionID:    session.ID,
		AttachmentID: attachment.ID,
		Payload:      map[string]any{"action": "detach"},
	}
	return s.PublishControl(ctx, envelope)
}
// TerminateRemoteSession publishes a "terminate" control envelope for the
// given session and attachment IDs.
func (s *Service) TerminateRemoteSession(ctx context.Context, sessionID, attachmentID string) error {
	envelope := workercontracts.RoutedEnvelope{
		Type:         "control",
		SessionID:    sessionID,
		AttachmentID: attachmentID,
		Payload:      map[string]any{"action": "terminate"},
	}
	return s.PublishControl(ctx, envelope)
}
// decodeMetadata decodes a JSON object payload into a map.
//
// It always returns a non-nil map. An empty payload, invalid JSON, a
// non-object value, or the literal `null` all yield an empty map. The `null`
// case matters: json.Unmarshal accepts it and leaves the target map nil, and
// writing to a nil map panics in any caller that adds keys.
func decodeMetadata(payload []byte) map[string]any {
	if len(payload) == 0 {
		return map[string]any{}
	}
	var out map[string]any
	if err := json.Unmarshal(payload, &out); err != nil || out == nil {
		return map[string]any{}
	}
	return out
}
// normalizeRenderQualityProfile returns profile unchanged when it is one of
// the supported names (low_bandwidth, balanced, high_quality, text_priority).
// Any other value, including the empty string or a differently-cased
// spelling, becomes "balanced".
func normalizeRenderQualityProfile(profile string) string {
	supported := [...]string{"low_bandwidth", "balanced", "high_quality", "text_priority"}
	if slices.Contains(supported[:], profile) {
		return profile
	}
	return "balanced"
}
// renderQualityProfileFromMetadata extracts resource.render_quality_profile
// from a JSON metadata payload and normalizes it. It returns "balanced" when
// the payload has no "resource" object, the field is missing, not a string,
// or empty.
func renderQualityProfileFromMetadata(metadata []byte) string {
	const fallback = "balanced"
	resource, _ := decodeMetadata(metadata)["resource"].(map[string]any)
	if resource == nil {
		return fallback
	}
	profile, _ := resource["render_quality_profile"].(string)
	if profile == "" {
		return fallback
	}
	return normalizeRenderQualityProfile(profile)
}