141 lines
4.5 KiB
Go
141 lines
4.5 KiB
Go
package sessionbroker
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
|
|
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
|
|
)
|
|
|
|
type RedisLiveStateStore struct {
|
|
client *redis.Client
|
|
}
|
|
|
|
func NewRedisLiveStateStore(client *redis.Client) *RedisLiveStateStore {
|
|
return &RedisLiveStateStore{client: client}
|
|
}
|
|
|
|
func (s *RedisLiveStateStore) UpsertSession(ctx context.Context, state LiveSessionState) error {
|
|
return s.setJSON(ctx, liveSessionKey(state.SessionID), state, 0)
|
|
}
|
|
|
|
func (s *RedisLiveStateStore) GetSession(ctx context.Context, sessionID string) (*LiveSessionState, error) {
|
|
var state LiveSessionState
|
|
ok, err := s.getJSON(ctx, liveSessionKey(sessionID), &state)
|
|
if err != nil || !ok {
|
|
return nil, err
|
|
}
|
|
return &state, nil
|
|
}
|
|
|
|
func (s *RedisLiveStateStore) DeleteSession(ctx context.Context, sessionID string) error {
|
|
return s.client.Del(ctx, liveSessionKey(sessionID)).Err()
|
|
}
|
|
|
|
func (s *RedisLiveStateStore) BindController(ctx context.Context, binding sessioncontracts.ControllerBinding, ttl time.Duration) error {
|
|
return s.setJSON(ctx, controllerBindingKey(binding.SessionID), binding, ttl)
|
|
}
|
|
|
|
func (s *RedisLiveStateStore) GetControllerBinding(ctx context.Context, sessionID string) (*sessioncontracts.ControllerBinding, error) {
|
|
var binding sessioncontracts.ControllerBinding
|
|
ok, err := s.getJSON(ctx, controllerBindingKey(sessionID), &binding)
|
|
if err != nil || !ok {
|
|
return nil, err
|
|
}
|
|
return &binding, nil
|
|
}
|
|
|
|
func (s *RedisLiveStateStore) ClearControllerBinding(ctx context.Context, sessionID string) error {
|
|
return s.client.Del(ctx, controllerBindingKey(sessionID)).Err()
|
|
}
|
|
|
|
func (s *RedisLiveStateStore) StoreAttachToken(ctx context.Context, claims sessioncontracts.AttachTokenClaims, ttl time.Duration) error {
|
|
return s.setJSON(ctx, attachTokenKey(claims.Token), claims, ttl)
|
|
}
|
|
|
|
func (s *RedisLiveStateStore) ConsumeAttachToken(ctx context.Context, token string) (*sessioncontracts.AttachTokenClaims, error) {
|
|
key := attachTokenKey(token)
|
|
payload, err := s.client.GetDel(ctx, key).Result()
|
|
if err == redis.Nil {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("consume attach token: %w", err)
|
|
}
|
|
var claims sessioncontracts.AttachTokenClaims
|
|
if err := json.Unmarshal([]byte(payload), &claims); err != nil {
|
|
return nil, fmt.Errorf("decode attach token: %w", err)
|
|
}
|
|
return &claims, nil
|
|
}
|
|
|
|
func (s *RedisLiveStateStore) TouchAttachmentHeartbeat(ctx context.Context, sessionID, attachmentID string, ttl time.Duration) error {
|
|
return s.client.Set(ctx, attachmentHeartbeatKey(sessionID, attachmentID), time.Now().UTC().Format(time.RFC3339Nano), ttl).Err()
|
|
}
|
|
|
|
func (s *RedisLiveStateStore) UpdateWorkerRoute(ctx context.Context, route WorkerRoute, ttl time.Duration) error {
|
|
return s.setJSON(ctx, workerRouteKey(route.SessionID), route, ttl)
|
|
}
|
|
|
|
func (s *RedisLiveStateStore) GetWorkerRoute(ctx context.Context, sessionID string) (*WorkerRoute, error) {
|
|
var route WorkerRoute
|
|
ok, err := s.getJSON(ctx, workerRouteKey(sessionID), &route)
|
|
if err != nil || !ok {
|
|
return nil, err
|
|
}
|
|
return &route, nil
|
|
}
|
|
|
|
func (s *RedisLiveStateStore) DeleteWorkerRoute(ctx context.Context, sessionID string) error {
|
|
return s.client.Del(ctx, workerRouteKey(sessionID)).Err()
|
|
}
|
|
|
|
func (s *RedisLiveStateStore) setJSON(ctx context.Context, key string, value any, ttl time.Duration) error {
|
|
payload, err := json.Marshal(value)
|
|
if err != nil {
|
|
return fmt.Errorf("encode redis payload: %w", err)
|
|
}
|
|
if err := s.client.Set(ctx, key, payload, ttl).Err(); err != nil {
|
|
return fmt.Errorf("set redis key %s: %w", key, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *RedisLiveStateStore) getJSON(ctx context.Context, key string, dest any) (bool, error) {
|
|
payload, err := s.client.Get(ctx, key).Result()
|
|
if err == redis.Nil {
|
|
return false, nil
|
|
}
|
|
if err != nil {
|
|
return false, fmt.Errorf("get redis key %s: %w", key, err)
|
|
}
|
|
if err := json.Unmarshal([]byte(payload), dest); err != nil {
|
|
return false, fmt.Errorf("decode redis key %s: %w", key, err)
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
// liveSessionKey returns the Redis key holding the live state of sessionID.
func liveSessionKey(sessionID string) string {
	const prefix = "live:session:"
	return prefix + sessionID
}
|
|
|
|
// controllerBindingKey returns the Redis key holding the controller binding of
// sessionID. The key is nested under the session's "live:session:<id>" prefix.
func controllerBindingKey(sessionID string) string {
	const prefix, suffix = "live:session:", ":controller"
	return prefix + sessionID + suffix
}
|
|
|
|
// attachTokenKey returns the Redis key under which the claims for an attach
// token are stored.
func attachTokenKey(token string) string {
	const prefix = "live:attach:"
	return prefix + token
}
|
|
|
|
// attachmentHeartbeatKey returns the Redis key holding the last heartbeat of
// attachmentID within sessionID.
func attachmentHeartbeatKey(sessionID, attachmentID string) string {
	const (
		sessionPrefix   = "live:session:"
		attachmentInfix = ":attachment:"
		heartbeatSuffix = ":heartbeat"
	)
	return sessionPrefix + sessionID + attachmentInfix + attachmentID + heartbeatSuffix
}
|
|
|
|
// workerRouteKey returns the Redis key holding the worker route of sessionID.
func workerRouteKey(sessionID string) string {
	const prefix, suffix = "live:session:", ":worker-route"
	return prefix + sessionID + suffix
}
|