Files
rdp-proxy/backend/internal/modules/worker/redis_store.go
T
2026-04-28 22:29:50 +03:00

265 lines
7.5 KiB
Go

package worker
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"time"
"github.com/redis/go-redis/v9"
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
)
// RedisStore keeps worker registrations, worker leases and the per-session,
// per-worker and global worker queues in Redis.
type RedisStore struct {
	// client is the go-redis client every method issues its commands on.
	client *redis.Client
}
// NewRedisStore returns a RedisStore that issues all commands on client.
// Nothing in this file closes the client; its lifecycle stays with the caller.
func NewRedisStore(client *redis.Client) *RedisStore {
	store := &RedisStore{}
	store.client = client
	return store
}
// RegisterWorker stores registration as JSON under the worker's key with the
// given ttl and adds the worker ID to the registrations set. Both writes are
// queued on one TxPipeline (MULTI/EXEC).
//
// The set entry itself has no TTL, so IDs outlive expired registrations;
// ListWorkers skips those.
func (s *RedisStore) RegisterWorker(ctx context.Context, registration workercontracts.WorkerRegistration, ttl time.Duration) error {
	encoded, marshalErr := json.Marshal(registration)
	if marshalErr != nil {
		return fmt.Errorf("marshal worker registration: %w", marshalErr)
	}
	tx := s.client.TxPipeline()
	tx.Set(ctx, workerKey(registration.WorkerID), encoded, ttl)
	tx.SAdd(ctx, workerSetKey(), registration.WorkerID)
	if _, execErr := tx.Exec(ctx); execErr != nil {
		return fmt.Errorf("register worker: %w", execErr)
	}
	return nil
}
// TouchWorkerHeartbeat copies the heartbeat's Status and LastHeartbeatAt onto
// the worker's registration and re-registers it with a fresh ttl.
//
// When no registration key exists, a minimal registration is built from the
// heartbeat with Protocol set to workercontracts.ProtocolRDP.
//
// NOTE(review): this is a GET followed by a separate SET with no WATCH, so
// concurrent updates to the same worker can overwrite each other.
func (s *RedisStore) TouchWorkerHeartbeat(ctx context.Context, heartbeat workercontracts.WorkerHeartbeat, ttl time.Duration) error {
	current, err := s.GetWorker(ctx, heartbeat.WorkerID)
	if err != nil {
		return err
	}
	var updated workercontracts.WorkerRegistration
	if current != nil {
		updated = *current
	} else {
		updated = workercontracts.WorkerRegistration{
			WorkerID: heartbeat.WorkerID,
			Protocol: workercontracts.ProtocolRDP,
		}
	}
	updated.Status = heartbeat.Status
	updated.LastHeartbeatAt = heartbeat.LastHeartbeatAt
	return s.RegisterWorker(ctx, updated, ttl)
}
// ListWorkers returns the registrations of every worker ID in the
// registrations set whose registration key still exists. IDs whose key has
// expired are skipped. Any Redis or decode error aborts the whole listing.
//
// Result order follows SMEMBERS, which Redis does not define.
func (s *RedisStore) ListWorkers(ctx context.Context) ([]workercontracts.WorkerRegistration, error) {
	workerIDs, err := s.client.SMembers(ctx, workerSetKey()).Result()
	if err != nil {
		return nil, fmt.Errorf("list worker ids: %w", err)
	}
	result := make([]workercontracts.WorkerRegistration, 0, len(workerIDs))
	for i := range workerIDs {
		registration, getErr := s.GetWorker(ctx, workerIDs[i])
		switch {
		case getErr != nil:
			return nil, getErr
		case registration == nil:
			continue
		}
		result = append(result, *registration)
	}
	return result, nil
}
// GetWorker loads and JSON-decodes the registration stored for workerID.
// It returns (nil, nil) when the key does not exist (redis.Nil).
func (s *RedisStore) GetWorker(ctx context.Context, workerID string) (*workercontracts.WorkerRegistration, error) {
	raw, err := s.client.Get(ctx, workerKey(workerID)).Result()
	switch {
	case err == redis.Nil:
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("get worker: %w", err)
	}
	registration := new(workercontracts.WorkerRegistration)
	if decodeErr := json.Unmarshal([]byte(raw), registration); decodeErr != nil {
		return nil, fmt.Errorf("decode worker registration: %w", decodeErr)
	}
	return registration, nil
}
// AcquireLease creates the lease record only if no record with the same
// LeaseID exists (SET NX with ttl). It then, in one TxPipeline, adds the ID to
// the leases set and points the session's lease mapping at it with the same
// ttl.
//
// It returns an error if the lease already exists or if any Redis call fails.
// If indexing fails after the lease record was created, the record is deleted
// again. Otherwise a failed acquisition would leave an unindexed lease that
// blocks re-acquiring this LeaseID until its TTL expires.
func (s *RedisStore) AcquireLease(ctx context.Context, lease workercontracts.WorkerLease, ttl time.Duration) error {
	payload, err := json.Marshal(lease)
	if err != nil {
		return fmt.Errorf("marshal lease: %w", err)
	}
	key := leaseKey(lease.LeaseID)
	ok, err := s.client.SetNX(ctx, key, payload, ttl).Result()
	if err != nil {
		return fmt.Errorf("acquire lease: %w", err)
	}
	if !ok {
		return fmt.Errorf("lease already exists")
	}
	pipe := s.client.TxPipeline()
	pipe.SAdd(ctx, leaseSetKey(), lease.LeaseID)
	pipe.Set(ctx, sessionLeaseKey(lease.SessionID), lease.LeaseID, ttl)
	if _, err := pipe.Exec(ctx); err != nil {
		// Roll back the half-acquired lease. The rollback is detached from
		// ctx cancellation because a cancelled ctx is a likely reason Exec
		// failed, and the rollback must still be attempted.
		if delErr := s.client.Del(context.WithoutCancel(ctx), key).Err(); delErr != nil {
			return fmt.Errorf("index lease: %w; roll back lease: %w", err, delErr)
		}
		return fmt.Errorf("index lease: %w", err)
	}
	return nil
}
// GetLease loads and JSON-decodes the lease stored under leaseID.
// It returns (nil, nil) when the key does not exist (redis.Nil).
func (s *RedisStore) GetLease(ctx context.Context, leaseID string) (*workercontracts.WorkerLease, error) {
	raw, err := s.client.Get(ctx, leaseKey(leaseID)).Result()
	if err != nil {
		if err == redis.Nil {
			return nil, nil
		}
		return nil, fmt.Errorf("get lease: %w", err)
	}
	lease := new(workercontracts.WorkerLease)
	if decodeErr := json.Unmarshal([]byte(raw), lease); decodeErr != nil {
		return nil, fmt.Errorf("decode lease: %w", decodeErr)
	}
	return lease, nil
}
// GetLeaseBySession looks up the lease ID mapped to sessionID and loads that
// lease. It returns (nil, nil) when the session mapping is missing, and also
// (via GetLease) when the mapping exists but the lease record does not.
func (s *RedisStore) GetLeaseBySession(ctx context.Context, sessionID string) (*workercontracts.WorkerLease, error) {
	mappedLeaseID, err := s.client.Get(ctx, sessionLeaseKey(sessionID)).Result()
	if err != nil {
		if err == redis.Nil {
			return nil, nil
		}
		return nil, fmt.Errorf("get lease by session: %w", err)
	}
	return s.GetLease(ctx, mappedLeaseID)
}
// RenewLease overwrites the lease record and the session's lease mapping,
// resetting both TTLs to ttl, in one TxPipeline. The writes are plain SETs,
// so a lease whose key is gone is recreated rather than rejected.
//
// NOTE(review): the leases set is not touched here, so renewing a lease that
// ReleaseLease already removed from the set leaves it out of ListLeases.
// Confirm this is intended.
func (s *RedisStore) RenewLease(ctx context.Context, lease workercontracts.WorkerLease, ttl time.Duration) error {
	encoded, err := json.Marshal(lease)
	if err != nil {
		return fmt.Errorf("marshal lease renewal: %w", err)
	}
	tx := s.client.TxPipeline()
	tx.Set(ctx, leaseKey(lease.LeaseID), encoded, ttl)
	tx.Set(ctx, sessionLeaseKey(lease.SessionID), lease.LeaseID, ttl)
	if _, err = tx.Exec(ctx); err != nil {
		return fmt.Errorf("renew lease: %w", err)
	}
	return nil
}
// ReleaseLease deletes the lease record and removes its ID from the leases
// set. If the record could still be read, it also deletes that session's
// lease mapping. Releasing a lease that no longer exists is not an error.
//
// NOTE(review): the record is read before the delete transaction runs. If
// AcquireLease repoints the same session to a new lease in between, this will
// delete the new lease's session mapping.
func (s *RedisStore) ReleaseLease(ctx context.Context, leaseID string) error {
	existing, err := s.GetLease(ctx, leaseID)
	if err != nil {
		return err
	}
	tx := s.client.TxPipeline()
	tx.Del(ctx, leaseKey(leaseID))
	tx.SRem(ctx, leaseSetKey(), leaseID)
	if existing != nil {
		tx.Del(ctx, sessionLeaseKey(existing.SessionID))
	}
	if _, execErr := tx.Exec(ctx); execErr != nil {
		return fmt.Errorf("release lease: %w", execErr)
	}
	return nil
}
// ListLeases returns every lease whose ID is in the leases set and whose
// record still exists. IDs whose record has expired are skipped. Any Redis or
// decode error aborts the whole listing.
func (s *RedisStore) ListLeases(ctx context.Context) ([]workercontracts.WorkerLease, error) {
	leaseIDs, err := s.client.SMembers(ctx, leaseSetKey()).Result()
	if err != nil {
		return nil, fmt.Errorf("list lease ids: %w", err)
	}
	active := make([]workercontracts.WorkerLease, 0, len(leaseIDs))
	for _, leaseID := range leaseIDs {
		current, getErr := s.GetLease(ctx, leaseID)
		if getErr != nil {
			return nil, getErr
		}
		if current == nil {
			continue // record expired, but its ID is still in the set
		}
		active = append(active, *current)
	}
	return active, nil
}
// AppendEnvelope RPUSHes the JSON-encoded envelope onto its session's worker
// queue and then refreshes the queue's TTL to 10 minutes.
//
// For envelopes of Type "input" whose Payload has a non-empty string
// "correlation_id", the queue length right after the push is logged with
// slog for request tracing.
func (s *RedisStore) AppendEnvelope(ctx context.Context, envelope workercontracts.RoutedEnvelope) error {
	payload, err := json.Marshal(envelope)
	if err != nil {
		return fmt.Errorf("marshal routed envelope: %w", err)
	}
	key := workerQueueKey(envelope.SessionID)
	// RPUSH replies with the list length after the push, so the trace log
	// below needs no separate LLEN round trip.
	length, err := s.client.RPush(ctx, key, payload).Result()
	if err != nil {
		return fmt.Errorf("append routed envelope: %w", err)
	}
	if envelope.Type == "input" {
		// A missing or non-string correlation_id yields "" and skips the log.
		correlationID, _ := envelope.Payload["correlation_id"].(string)
		if correlationID != "" {
			slog.Info("worker queue length after input append",
				"session_id", envelope.SessionID,
				"attachment_id", envelope.AttachmentID,
				"correlation_id", correlationID,
				"queue_key", key,
				"queue_length", length,
				"trace_stage", "redis_queue_append")
		}
	}
	if err := s.client.Expire(ctx, key, 10*time.Minute).Err(); err != nil {
		return fmt.Errorf("expire routed envelope queue: %w", err)
	}
	return nil
}
// AppendAssignment RPUSHes the JSON-encoded payload onto workerID's control
// queue and then refreshes that queue's TTL to 10 minutes. Errors from
// encoding, the push, or the TTL refresh are each wrapped with context.
func (s *RedisStore) AppendAssignment(ctx context.Context, workerID string, payload map[string]any) error {
	encoded, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("marshal worker assignment: %w", err)
	}
	key := workerControlQueueKey(workerID)
	if err := s.client.RPush(ctx, key, encoded).Err(); err != nil {
		return fmt.Errorf("append worker assignment: %w", err)
	}
	if err := s.client.Expire(ctx, key, 10*time.Minute).Err(); err != nil {
		return fmt.Errorf("expire worker assignment queue: %w", err)
	}
	return nil
}
// AppendEvent RPUSHes the JSON-encoded payload onto the shared worker events
// list and then refreshes the list's TTL to 10 minutes. Errors from encoding,
// the push, or the TTL refresh are each wrapped with context.
func (s *RedisStore) AppendEvent(ctx context.Context, payload map[string]any) error {
	encoded, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("marshal worker event: %w", err)
	}
	key := workerEventsKey()
	if err := s.client.RPush(ctx, key, encoded).Err(); err != nil {
		return fmt.Errorf("append worker event: %w", err)
	}
	if err := s.client.Expire(ctx, key, 10*time.Minute).Err(); err != nil {
		return fmt.Errorf("expire worker events queue: %w", err)
	}
	return nil
}
// workerKey names the string key holding one worker's JSON registration.
func workerKey(workerID string) string {
	const prefix = "worker:registration:"
	return prefix + workerID
}

// workerSetKey names the set of all registered worker IDs.
func workerSetKey() string {
	const key = "worker:registrations"
	return key
}

// leaseKey names the string key holding one lease's JSON record.
func leaseKey(leaseID string) string {
	const prefix = "worker:lease:"
	return prefix + leaseID
}

// leaseSetKey names the set of all acquired lease IDs.
func leaseSetKey() string {
	const key = "worker:leases"
	return key
}

// sessionLeaseKey names the string key mapping a session ID to its lease ID.
func sessionLeaseKey(sessionID string) string {
	const prefix = "worker:session-lease:"
	return prefix + sessionID
}

// workerQueueKey names the list of routed envelopes for one session.
func workerQueueKey(sessionID string) string {
	const prefix = "worker:queue:"
	return prefix + sessionID
}

// workerControlQueueKey names the list of assignments for one worker.
func workerControlQueueKey(workerID string) string {
	const prefix = "worker:control:"
	return prefix + workerID
}

// workerEventsKey names the single list of worker events.
func workerEventsKey() string {
	const key = "worker:events"
	return key
}