Files
rdp-proxy/backend/internal/modules/sessionbroker/redis_livestate.go
T
2026-04-28 22:29:50 +03:00

141 lines
4.5 KiB
Go

package sessionbroker
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/redis/go-redis/v9"

	sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
)
// RedisLiveStateStore keeps live session state in Redis: session snapshots,
// controller bindings, attach tokens, attachment heartbeats and worker routes.
// Every value lives under a key with the "live:" prefix.
type RedisLiveStateStore struct {
	// client is the Redis connection every operation of the store goes through.
	client *redis.Client
}
// NewRedisLiveStateStore returns a store that issues all of its commands
// through client. The client is stored as given; it is not validated or pinged.
func NewRedisLiveStateStore(client *redis.Client) *RedisLiveStateStore {
	store := new(RedisLiveStateStore)
	store.client = client
	return store
}
// UpsertSession JSON-encodes state and writes it under the live-session key
// derived from state.SessionID, replacing whatever was stored there before.
//
// The write uses a zero TTL, which go-redis treats as "no expiration", so the
// entry stays until it is overwritten or removed with DeleteSession.
func (s *RedisLiveStateStore) UpsertSession(ctx context.Context, state LiveSessionState) error {
	const noExpiry time.Duration = 0

	key := liveSessionKey(state.SessionID)
	return s.setJSON(ctx, key, state, noExpiry)
}
// GetSession loads the live state stored for sessionID.
//
// It returns (nil, nil) when no entry exists for the session, and (nil, err)
// when the Redis read or the JSON decoding fails.
func (s *RedisLiveStateStore) GetSession(ctx context.Context, sessionID string) (*LiveSessionState, error) {
	state := new(LiveSessionState)

	found, err := s.getJSON(ctx, liveSessionKey(sessionID), state)
	switch {
	case err != nil:
		return nil, err
	case !found:
		// A missing key is reported as "no session", not as an error.
		return nil, nil
	}
	return state, nil
}
// DeleteSession removes the live-session entry for sessionID.
//
// Redis DEL on a key that does not exist is a no-op, so deleting an unknown
// session returns nil. A failure is wrapped with the affected key, matching
// the error style of setJSON and getJSON; the underlying Redis error remains
// reachable through errors.Is / errors.As.
func (s *RedisLiveStateStore) DeleteSession(ctx context.Context, sessionID string) error {
	key := liveSessionKey(sessionID)
	if err := s.client.Del(ctx, key).Err(); err != nil {
		return fmt.Errorf("delete redis key %s: %w", key, err)
	}
	return nil
}
// BindController JSON-encodes binding and stores it under the controller key
// of binding.SessionID, replacing any previous binding for that session.
// ttl is handed to Redis as the key's expiration.
func (s *RedisLiveStateStore) BindController(ctx context.Context, binding sessioncontracts.ControllerBinding, ttl time.Duration) error {
	key := controllerBindingKey(binding.SessionID)
	return s.setJSON(ctx, key, binding, ttl)
}
// GetControllerBinding loads the controller binding stored for sessionID.
//
// It returns (nil, nil) when the session has no binding (never set, cleared,
// or expired), and (nil, err) when the Redis read or JSON decoding fails.
func (s *RedisLiveStateStore) GetControllerBinding(ctx context.Context, sessionID string) (*sessioncontracts.ControllerBinding, error) {
	binding := &sessioncontracts.ControllerBinding{}

	found, err := s.getJSON(ctx, controllerBindingKey(sessionID), binding)
	if err != nil {
		return nil, err
	}
	if !found {
		return nil, nil
	}
	return binding, nil
}
// ClearControllerBinding removes the controller binding stored for sessionID.
//
// Redis DEL on a key that does not exist is a no-op, so clearing an unbound
// session returns nil. A failure is wrapped with the affected key, matching
// the error style of setJSON and getJSON; the underlying Redis error remains
// reachable through errors.Is / errors.As.
func (s *RedisLiveStateStore) ClearControllerBinding(ctx context.Context, sessionID string) error {
	key := controllerBindingKey(sessionID)
	if err := s.client.Del(ctx, key).Err(); err != nil {
		return fmt.Errorf("delete redis key %s: %w", key, err)
	}
	return nil
}
// StoreAttachToken JSON-encodes claims and stores them under the attach-token
// key derived from claims.Token. ttl is handed to Redis as the key's
// expiration. The stored claims are read back, and removed, by
// ConsumeAttachToken.
func (s *RedisLiveStateStore) StoreAttachToken(ctx context.Context, claims sessioncontracts.AttachTokenClaims, ttl time.Duration) error {
	key := attachTokenKey(claims.Token)
	return s.setJSON(ctx, key, claims, ttl)
}
// ConsumeAttachToken fetches and deletes the claims stored for token in a
// single GETDEL command, so a given token can be redeemed at most once even
// when several callers race for it. GETDEL requires Redis 6.2 or newer.
//
// It returns (nil, nil) when the token is unknown, already consumed, or
// expired. A Redis failure or an undecodable payload is returned as a wrapped
// error; in the decode case the key has already been deleted.
func (s *RedisLiveStateStore) ConsumeAttachToken(ctx context.Context, token string) (*sessioncontracts.AttachTokenClaims, error) {
	key := attachTokenKey(token)

	payload, err := s.client.GetDel(ctx, key).Result()
	// errors.Is also matches a redis.Nil that was wrapped on its way here
	// (for example by a client hook), which a plain == comparison would
	// misreport as a hard failure.
	if errors.Is(err, redis.Nil) {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("consume attach token: %w", err)
	}

	var claims sessioncontracts.AttachTokenClaims
	if err := json.Unmarshal([]byte(payload), &claims); err != nil {
		return nil, fmt.Errorf("decode attach token: %w", err)
	}
	return &claims, nil
}
// TouchAttachmentHeartbeat records that the given attachment of a session is
// alive by writing the current UTC time (RFC 3339 with nanoseconds) to the
// attachment's heartbeat key. ttl is handed to Redis as the key's expiration,
// and every call overwrites the previous timestamp.
//
// A failure is wrapped with the affected key, using the same message format
// as setJSON; the underlying Redis error remains reachable through
// errors.Is / errors.As.
func (s *RedisLiveStateStore) TouchAttachmentHeartbeat(ctx context.Context, sessionID, attachmentID string, ttl time.Duration) error {
	key := attachmentHeartbeatKey(sessionID, attachmentID)
	stamp := time.Now().UTC().Format(time.RFC3339Nano)

	if err := s.client.Set(ctx, key, stamp, ttl).Err(); err != nil {
		return fmt.Errorf("set redis key %s: %w", key, err)
	}
	return nil
}
// UpdateWorkerRoute JSON-encodes route and stores it under the worker-route
// key of route.SessionID, replacing any route already recorded for that
// session. ttl is handed to Redis as the key's expiration.
func (s *RedisLiveStateStore) UpdateWorkerRoute(ctx context.Context, route WorkerRoute, ttl time.Duration) error {
	key := workerRouteKey(route.SessionID)
	return s.setJSON(ctx, key, route, ttl)
}
// GetWorkerRoute loads the worker route stored for sessionID.
//
// It returns (nil, nil) when no route is recorded for the session, and
// (nil, err) when the Redis read or JSON decoding fails.
func (s *RedisLiveStateStore) GetWorkerRoute(ctx context.Context, sessionID string) (*WorkerRoute, error) {
	var route WorkerRoute
	key := workerRouteKey(sessionID)

	found, err := s.getJSON(ctx, key, &route)
	if err == nil && found {
		return &route, nil
	}
	// Either the read failed (err != nil) or the key is absent (err == nil).
	return nil, err
}
// DeleteWorkerRoute removes the worker route stored for sessionID.
//
// Redis DEL on a key that does not exist is a no-op, so deleting a missing
// route returns nil. A failure is wrapped with the affected key, matching the
// error style of setJSON and getJSON; the underlying Redis error remains
// reachable through errors.Is / errors.As.
func (s *RedisLiveStateStore) DeleteWorkerRoute(ctx context.Context, sessionID string) error {
	key := workerRouteKey(sessionID)
	if err := s.client.Del(ctx, key).Err(); err != nil {
		return fmt.Errorf("delete redis key %s: %w", key, err)
	}
	return nil
}
// setJSON marshals value to JSON and writes it to key with a Redis SET,
// passing ttl as the expiration (go-redis treats a zero ttl as "no
// expiration"). Encoding and write failures are returned wrapped, the latter
// including the key name.
func (s *RedisLiveStateStore) setJSON(ctx context.Context, key string, value any, ttl time.Duration) error {
	encoded, encErr := json.Marshal(value)
	if encErr != nil {
		return fmt.Errorf("encode redis payload: %w", encErr)
	}

	setErr := s.client.Set(ctx, key, encoded, ttl).Err()
	if setErr == nil {
		return nil
	}
	return fmt.Errorf("set redis key %s: %w", key, setErr)
}
// getJSON reads key from Redis and unmarshals its JSON value into dest, which
// must be a pointer accepted by json.Unmarshal.
//
// The boolean result reports whether the key existed:
//   - (true, nil): the key was found and dest has been populated;
//   - (false, nil): the key does not exist, dest is untouched;
//   - (false, err): the read or the decoding failed; err names the key, and
//     dest must not be relied upon.
func (s *RedisLiveStateStore) getJSON(ctx context.Context, key string, dest any) (bool, error) {
	payload, err := s.client.Get(ctx, key).Result()
	// errors.Is also matches a redis.Nil that was wrapped on its way here
	// (for example by a client hook), which a plain == comparison would
	// misreport as a hard failure.
	if errors.Is(err, redis.Nil) {
		return false, nil
	}
	if err != nil {
		return false, fmt.Errorf("get redis key %s: %w", key, err)
	}

	if err := json.Unmarshal([]byte(payload), dest); err != nil {
		return false, fmt.Errorf("decode redis key %s: %w", key, err)
	}
	return true, nil
}
// liveSessionKey returns the Redis key holding the live state of a session:
// "live:session:<sessionID>".
func liveSessionKey(sessionID string) string {
	const prefix = "live:session:"
	return prefix + sessionID
}
// controllerBindingKey returns the Redis key holding the controller binding
// of a session: "live:session:<sessionID>:controller".
func controllerBindingKey(sessionID string) string {
	const (
		prefix = "live:session:"
		suffix = ":controller"
	)
	return prefix + sessionID + suffix
}
// attachTokenKey returns the Redis key under which the claims of an attach
// token are stored: "live:attach:<token>".
func attachTokenKey(token string) string {
	const prefix = "live:attach:"
	return prefix + token
}
// attachmentHeartbeatKey returns the Redis key holding the heartbeat of one
// attachment of a session:
// "live:session:<sessionID>:attachment:<attachmentID>:heartbeat".
func attachmentHeartbeatKey(sessionID, attachmentID string) string {
	sessionPart := "live:session:" + sessionID
	attachmentPart := ":attachment:" + attachmentID
	return sessionPart + attachmentPart + ":heartbeat"
}
// workerRouteKey returns the Redis key holding the worker route of a session:
// "live:session:<sessionID>:worker-route".
func workerRouteKey(sessionID string) string {
	const (
		prefix = "live:session:"
		suffix = ":worker-route"
	)
	return prefix + sessionID + suffix
}