265 lines
7.5 KiB
Go
265 lines
7.5 KiB
Go
package worker
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log/slog"
|
|
"time"
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
|
|
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
|
|
)
|
|
|
|
type RedisStore struct {
|
|
client *redis.Client
|
|
}
|
|
|
|
func NewRedisStore(client *redis.Client) *RedisStore {
|
|
return &RedisStore{client: client}
|
|
}
|
|
|
|
func (s *RedisStore) RegisterWorker(ctx context.Context, registration workercontracts.WorkerRegistration, ttl time.Duration) error {
|
|
payload, err := json.Marshal(registration)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal worker registration: %w", err)
|
|
}
|
|
pipe := s.client.TxPipeline()
|
|
pipe.Set(ctx, workerKey(registration.WorkerID), payload, ttl)
|
|
pipe.SAdd(ctx, workerSetKey(), registration.WorkerID)
|
|
_, err = pipe.Exec(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("register worker: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *RedisStore) TouchWorkerHeartbeat(ctx context.Context, heartbeat workercontracts.WorkerHeartbeat, ttl time.Duration) error {
|
|
registration, err := s.GetWorker(ctx, heartbeat.WorkerID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if registration == nil {
|
|
registration = &workercontracts.WorkerRegistration{
|
|
WorkerID: heartbeat.WorkerID,
|
|
Protocol: workercontracts.ProtocolRDP,
|
|
}
|
|
}
|
|
registration.Status = heartbeat.Status
|
|
registration.LastHeartbeatAt = heartbeat.LastHeartbeatAt
|
|
return s.RegisterWorker(ctx, *registration, ttl)
|
|
}
|
|
|
|
func (s *RedisStore) ListWorkers(ctx context.Context) ([]workercontracts.WorkerRegistration, error) {
|
|
ids, err := s.client.SMembers(ctx, workerSetKey()).Result()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list worker ids: %w", err)
|
|
}
|
|
workers := make([]workercontracts.WorkerRegistration, 0, len(ids))
|
|
for _, id := range ids {
|
|
worker, err := s.GetWorker(ctx, id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if worker != nil {
|
|
workers = append(workers, *worker)
|
|
}
|
|
}
|
|
return workers, nil
|
|
}
|
|
|
|
func (s *RedisStore) GetWorker(ctx context.Context, workerID string) (*workercontracts.WorkerRegistration, error) {
|
|
payload, err := s.client.Get(ctx, workerKey(workerID)).Result()
|
|
if err == redis.Nil {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get worker: %w", err)
|
|
}
|
|
var registration workercontracts.WorkerRegistration
|
|
if err := json.Unmarshal([]byte(payload), ®istration); err != nil {
|
|
return nil, fmt.Errorf("decode worker registration: %w", err)
|
|
}
|
|
return ®istration, nil
|
|
}
|
|
|
|
func (s *RedisStore) AcquireLease(ctx context.Context, lease workercontracts.WorkerLease, ttl time.Duration) error {
|
|
payload, err := json.Marshal(lease)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal lease: %w", err)
|
|
}
|
|
ok, err := s.client.SetNX(ctx, leaseKey(lease.LeaseID), payload, ttl).Result()
|
|
if err != nil {
|
|
return fmt.Errorf("acquire lease: %w", err)
|
|
}
|
|
if !ok {
|
|
return fmt.Errorf("lease already exists")
|
|
}
|
|
pipe := s.client.TxPipeline()
|
|
pipe.SAdd(ctx, leaseSetKey(), lease.LeaseID)
|
|
pipe.Set(ctx, sessionLeaseKey(lease.SessionID), lease.LeaseID, ttl)
|
|
_, err = pipe.Exec(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("index lease: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *RedisStore) GetLease(ctx context.Context, leaseID string) (*workercontracts.WorkerLease, error) {
|
|
payload, err := s.client.Get(ctx, leaseKey(leaseID)).Result()
|
|
if err == redis.Nil {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get lease: %w", err)
|
|
}
|
|
var lease workercontracts.WorkerLease
|
|
if err := json.Unmarshal([]byte(payload), &lease); err != nil {
|
|
return nil, fmt.Errorf("decode lease: %w", err)
|
|
}
|
|
return &lease, nil
|
|
}
|
|
|
|
func (s *RedisStore) GetLeaseBySession(ctx context.Context, sessionID string) (*workercontracts.WorkerLease, error) {
|
|
leaseID, err := s.client.Get(ctx, sessionLeaseKey(sessionID)).Result()
|
|
if err == redis.Nil {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get lease by session: %w", err)
|
|
}
|
|
return s.GetLease(ctx, leaseID)
|
|
}
|
|
|
|
func (s *RedisStore) RenewLease(ctx context.Context, lease workercontracts.WorkerLease, ttl time.Duration) error {
|
|
payload, err := json.Marshal(lease)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal lease renewal: %w", err)
|
|
}
|
|
pipe := s.client.TxPipeline()
|
|
pipe.Set(ctx, leaseKey(lease.LeaseID), payload, ttl)
|
|
pipe.Set(ctx, sessionLeaseKey(lease.SessionID), lease.LeaseID, ttl)
|
|
_, err = pipe.Exec(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("renew lease: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *RedisStore) ReleaseLease(ctx context.Context, leaseID string) error {
|
|
lease, err := s.GetLease(ctx, leaseID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
pipe := s.client.TxPipeline()
|
|
pipe.Del(ctx, leaseKey(leaseID))
|
|
pipe.SRem(ctx, leaseSetKey(), leaseID)
|
|
if lease != nil {
|
|
pipe.Del(ctx, sessionLeaseKey(lease.SessionID))
|
|
}
|
|
_, err = pipe.Exec(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("release lease: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *RedisStore) ListLeases(ctx context.Context) ([]workercontracts.WorkerLease, error) {
|
|
ids, err := s.client.SMembers(ctx, leaseSetKey()).Result()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list lease ids: %w", err)
|
|
}
|
|
leases := make([]workercontracts.WorkerLease, 0, len(ids))
|
|
for _, id := range ids {
|
|
lease, err := s.GetLease(ctx, id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if lease != nil {
|
|
leases = append(leases, *lease)
|
|
}
|
|
}
|
|
return leases, nil
|
|
}
|
|
|
|
func (s *RedisStore) AppendEnvelope(ctx context.Context, envelope workercontracts.RoutedEnvelope) error {
|
|
payload, err := json.Marshal(envelope)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal routed envelope: %w", err)
|
|
}
|
|
key := workerQueueKey(envelope.SessionID)
|
|
if err := s.client.RPush(ctx, key, payload).Err(); err != nil {
|
|
return fmt.Errorf("append routed envelope: %w", err)
|
|
}
|
|
if envelope.Type == "input" {
|
|
correlationID, _ := envelope.Payload["correlation_id"].(string)
|
|
if correlationID != "" {
|
|
if length, err := s.client.LLen(ctx, key).Result(); err == nil {
|
|
slog.Info("worker queue length after input append",
|
|
"session_id", envelope.SessionID,
|
|
"attachment_id", envelope.AttachmentID,
|
|
"correlation_id", correlationID,
|
|
"queue_key", key,
|
|
"queue_length", length,
|
|
"trace_stage", "redis_queue_append")
|
|
}
|
|
}
|
|
}
|
|
return s.client.Expire(ctx, key, 10*time.Minute).Err()
|
|
}
|
|
|
|
func (s *RedisStore) AppendAssignment(ctx context.Context, workerID string, payload map[string]any) error {
|
|
encoded, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal worker assignment: %w", err)
|
|
}
|
|
if err := s.client.RPush(ctx, workerControlQueueKey(workerID), encoded).Err(); err != nil {
|
|
return fmt.Errorf("append worker assignment: %w", err)
|
|
}
|
|
return s.client.Expire(ctx, workerControlQueueKey(workerID), 10*time.Minute).Err()
|
|
}
|
|
|
|
func (s *RedisStore) AppendEvent(ctx context.Context, payload map[string]any) error {
|
|
encoded, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return fmt.Errorf("marshal worker event: %w", err)
|
|
}
|
|
if err := s.client.RPush(ctx, workerEventsKey(), encoded).Err(); err != nil {
|
|
return fmt.Errorf("append worker event: %w", err)
|
|
}
|
|
return s.client.Expire(ctx, workerEventsKey(), 10*time.Minute).Err()
|
|
}
|
|
|
|
// workerKey returns the Redis key that holds the registration payload of the
// given worker.
func workerKey(workerID string) string {
	const prefix = "worker:registration:"
	return prefix + workerID
}
|
|
|
|
// workerSetKey returns the Redis key of the set that holds worker IDs.
func workerSetKey() string {
	const key = "worker:registrations"
	return key
}
|
|
|
|
// leaseKey returns the Redis key that holds the payload of the given lease.
func leaseKey(leaseID string) string {
	const prefix = "worker:lease:"
	return prefix + leaseID
}
|
|
|
|
// leaseSetKey returns the Redis key of the set that holds lease IDs.
func leaseSetKey() string {
	const key = "worker:leases"
	return key
}
|
|
|
|
// sessionLeaseKey returns the Redis key that maps the given session to its
// lease ID.
func sessionLeaseKey(sessionID string) string {
	const prefix = "worker:session-lease:"
	return prefix + sessionID
}
|
|
|
|
// workerQueueKey returns the Redis key of the routed-envelope queue for the
// given session.
func workerQueueKey(sessionID string) string {
	const prefix = "worker:queue:"
	return prefix + sessionID
}
|
|
|
|
// workerControlQueueKey returns the Redis key of the control queue for the
// given worker.
func workerControlQueueKey(workerID string) string {
	const prefix = "worker:control:"
	return prefix + workerID
}
|
|
|
|
// workerEventsKey returns the Redis key of the shared worker events list.
func workerEventsKey() string {
	const key = "worker:events"
	return key
}
|