1119 lines
37 KiB
Go
1119 lines
37 KiB
Go
package sessiongateway
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"hash/fnv"
|
|
"log/slog"
|
|
"net/http"
|
|
"path/filepath"
|
|
"regexp"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/websocket"
|
|
|
|
"github.com/example/remote-access-platform/backend/internal/modules/sessionbroker"
|
|
"github.com/example/remote-access-platform/backend/internal/modules/worker"
|
|
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
|
|
"github.com/example/remote-access-platform/backend/internal/platform/module"
|
|
messagecontracts "github.com/example/remote-access-platform/backend/pkg/contracts/message"
|
|
sessioncontracts "github.com/example/remote-access-platform/backend/pkg/contracts/session"
|
|
workercontracts "github.com/example/remote-access-platform/backend/pkg/contracts/worker"
|
|
)
|
|
|
|
// Transport-level limits and timings for the session gateway.
const (
	// maxClipboardTextBytes bounds clipboard text accepted from a client (1 MiB).
	maxClipboardTextBytes = 1 << 20
	// maxFileUploadBytes bounds the declared size of one uploaded file (25 MiB).
	maxFileUploadBytes = 25 << 20
	// maxFileUploadChunkBytes bounds the decoded size of one upload chunk (256 KiB).
	maxFileUploadChunkBytes = 256 << 10
	// liveSyncInterval is how often live session state is pushed to a client.
	liveSyncInterval = time.Second
)

// Validation patterns for client-supplied identifiers.
var (
	// safeTransferIDPattern accepts exactly 36 hex digits and dashes (the
	// textual length of a UUID); it does not enforce UUID grouping itself.
	safeTransferIDPattern = regexp.MustCompile(`^[a-fA-F0-9-]{36}$`)
	// safeFileIDPattern accepts 1-160 characters from a conservative set.
	safeFileIDPattern = regexp.MustCompile(`^[a-zA-Z0-9._:-]{1,160}$`)
)
|
|
|
|
type gatewayConnection struct {
|
|
conn *websocket.Conn
|
|
writeMu sync.Mutex
|
|
uploads map[string]*fileUploadState
|
|
}
|
|
|
|
// fileUploadState tracks one in-flight client-to-server upload.
type fileUploadState struct {
	// TransferID and FileName come from the validated file_upload.start envelope.
	TransferID, FileName string
	// FileSize and TotalChunks are the client's declared totals; Received and
	// NextIndex advance as chunks are accepted.
	FileSize, TotalChunks, Received, NextIndex int64
	// ContentHash is the client-supplied hash forwarded to the worker.
	ContentHash string
}
|
|
|
|
func (c *gatewayConnection) writeJSON(timeout time.Duration, value any) error {
|
|
c.writeMu.Lock()
|
|
defer c.writeMu.Unlock()
|
|
_ = c.conn.SetWriteDeadline(time.Now().Add(timeout))
|
|
return c.conn.WriteJSON(value)
|
|
}
|
|
|
|
func (c *gatewayConnection) writeMessage(timeout time.Duration, messageType int, data []byte) error {
|
|
c.writeMu.Lock()
|
|
defer c.writeMu.Unlock()
|
|
_ = c.conn.SetWriteDeadline(time.Now().Add(timeout))
|
|
return c.conn.WriteMessage(messageType, data)
|
|
}
|
|
|
|
type Module struct {
|
|
cfg module.Config
|
|
logger *slog.Logger
|
|
broker *sessionbroker.Service
|
|
workers *worker.Service
|
|
upgrader websocket.Upgrader
|
|
}
|
|
|
|
func NewModule(deps module.Dependencies, broker *sessionbroker.Service, workers *worker.Service) *Module {
|
|
return &Module{
|
|
cfg: deps.Config,
|
|
logger: deps.Infra.Logger,
|
|
broker: broker,
|
|
workers: workers,
|
|
upgrader: websocket.Upgrader{
|
|
CheckOrigin: func(_ *http.Request) bool { return true },
|
|
},
|
|
}
|
|
}
|
|
|
|
func (m *Module) Name() string {
|
|
return "session-gateway"
|
|
}
|
|
|
|
func (m *Module) RegisterRoutes(router chi.Router) {
|
|
router.Route("/gateway", func(r chi.Router) {
|
|
r.Get("/status", m.status)
|
|
r.Get("/ws", m.handleWebSocket)
|
|
})
|
|
}
|
|
|
|
func (m *Module) status(w http.ResponseWriter, _ *http.Request) {
|
|
httpx.WriteJSON(w, http.StatusOK, map[string]any{
|
|
"module": m.Name(),
|
|
"transport": "websocket",
|
|
"attach_handshake": "query.attach_token",
|
|
"ping_every": m.cfg.WebSocket.PingInterval.String(),
|
|
"pong_deadby": m.cfg.WebSocket.PongWait.String(),
|
|
})
|
|
}
|
|
|
|
func (m *Module) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
|
token := r.URL.Query().Get("attach_token")
|
|
if token == "" {
|
|
httpx.WriteError(w, http.StatusBadRequest, "attach_token is required")
|
|
return
|
|
}
|
|
claims, state, err := m.broker.ConsumeAttachToken(r.Context(), token)
|
|
if err != nil {
|
|
httpx.WriteError(w, http.StatusUnauthorized, err.Error())
|
|
return
|
|
}
|
|
conn, err := m.upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
httpx.WriteError(w, http.StatusBadRequest, "websocket upgrade failed")
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
gatewayConn := &gatewayConnection{conn: conn, uploads: make(map[string]*fileUploadState)}
|
|
|
|
sessionCtx, cancel := context.WithCancel(r.Context())
|
|
defer cancel()
|
|
|
|
_ = conn.SetReadDeadline(time.Now().Add(m.cfg.WebSocket.PongWait))
|
|
conn.SetPongHandler(func(string) error {
|
|
if err := m.broker.TouchAttachmentHeartbeat(sessionCtx, claims.SessionID, claims.AttachmentID); err != nil {
|
|
return err
|
|
}
|
|
return conn.SetReadDeadline(time.Now().Add(m.cfg.WebSocket.PongWait))
|
|
})
|
|
|
|
if err := m.writeEnvelope(gatewayConn, sessioncontracts.TransportEnvelope{
|
|
Type: "session.state",
|
|
SessionID: claims.SessionID,
|
|
Payload: map[string]any{
|
|
"state": state.State,
|
|
"attachment_id": claims.AttachmentID,
|
|
"takeover_version": state.TakeoverVersion,
|
|
"reconnectable": claims.Reconnectable,
|
|
"worker_routing_id": claims.WorkerID,
|
|
"render": renderPayloadFromLiveState(state),
|
|
},
|
|
Event: m.newEventMessage(
|
|
"session.state."+string(state.State),
|
|
"events.session.state."+string(state.State),
|
|
"Session state updated.",
|
|
map[string]any{
|
|
"state": state.State,
|
|
},
|
|
),
|
|
}); err != nil {
|
|
return
|
|
}
|
|
lastFrameSequence := int64(0)
|
|
lastClipboardSequence := int64(0)
|
|
lastFileDownloadSequence := int64(0)
|
|
if state != nil {
|
|
lastFrameSequence = state.RenderFrameSequence
|
|
lastClipboardSequence = state.ClipboardSequence
|
|
if state.RenderFrameSequence > 0 && state.RenderFrameData != "" {
|
|
_ = m.writeFrameEnvelope(gatewayConn, claims.SessionID, *state)
|
|
}
|
|
if state.ClipboardSequence > 0 && state.ClipboardText != "" {
|
|
if err := m.ensureClipboardAllowed(sessionCtx, claims.SessionID, "server_to_client"); err != nil {
|
|
lastClipboardSequence = state.ClipboardSequence
|
|
} else {
|
|
_ = m.writeClipboardEnvelope(gatewayConn, claims.SessionID, *state)
|
|
}
|
|
}
|
|
}
|
|
|
|
readErr := make(chan error, 1)
|
|
go func() {
|
|
readErr <- m.readLoop(sessionCtx, gatewayConn, claims)
|
|
}()
|
|
|
|
pingTicker := time.NewTicker(m.cfg.WebSocket.PingInterval)
|
|
defer pingTicker.Stop()
|
|
syncTicker := time.NewTicker(liveSyncInterval)
|
|
defer syncTicker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case err := <-readErr:
|
|
if err == nil {
|
|
return
|
|
}
|
|
_ = m.writeEnvelope(gatewayConn, sessioncontracts.TransportEnvelope{
|
|
Type: "transport.closed",
|
|
SessionID: claims.SessionID,
|
|
Payload: map[string]any{
|
|
"reason": err.Error(),
|
|
},
|
|
Event: m.newEventMessage(
|
|
"transport.closed",
|
|
"events.session.transport_closed",
|
|
"The session transport closed.",
|
|
map[string]any{
|
|
"reason": err.Error(),
|
|
},
|
|
),
|
|
})
|
|
return
|
|
case <-syncTicker.C:
|
|
if err := m.syncConnection(sessionCtx, gatewayConn, claims, &lastFrameSequence, &lastClipboardSequence, &lastFileDownloadSequence); err != nil {
|
|
return
|
|
}
|
|
case <-pingTicker.C:
|
|
if err := gatewayConn.writeMessage(m.cfg.WebSocket.WriteTimeout, websocket.PingMessage, []byte("ping")); err != nil {
|
|
return
|
|
}
|
|
case <-sessionCtx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *Module) readLoop(ctx context.Context, conn *gatewayConnection, claims *sessioncontracts.AttachTokenClaims) error {
|
|
for {
|
|
var envelope sessioncontracts.TransportEnvelope
|
|
if err := conn.conn.ReadJSON(&envelope); err != nil {
|
|
return err
|
|
}
|
|
if err := m.handleEnvelope(ctx, conn, claims, envelope); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *Module) handleEnvelope(ctx context.Context, conn *gatewayConnection, claims *sessioncontracts.AttachTokenClaims, envelope sessioncontracts.TransportEnvelope) error {
|
|
switch envelope.Type {
|
|
case "heartbeat":
|
|
if err := m.broker.TouchAttachmentHeartbeat(ctx, claims.SessionID, claims.AttachmentID); err != nil {
|
|
return err
|
|
}
|
|
return m.writeEnvelope(conn, sessioncontracts.TransportEnvelope{
|
|
Type: "heartbeat.ack",
|
|
SessionID: claims.SessionID,
|
|
Event: m.newEventMessage(
|
|
"session.heartbeat_ack",
|
|
"events.session.heartbeat_ack",
|
|
"Session heartbeat acknowledged.",
|
|
nil,
|
|
),
|
|
})
|
|
case "control":
|
|
action, _ := envelope.Payload["action"].(string)
|
|
if action == "detach" {
|
|
_, err := m.broker.DetachFromSession(ctx, sessionbroker.DetachFromSessionCommand{
|
|
SessionID: claims.SessionID,
|
|
AttachmentID: claims.AttachmentID,
|
|
UserID: claims.UserID,
|
|
Reason: "client_requested_detach",
|
|
})
|
|
return err
|
|
}
|
|
return m.workers.PublishControl(ctx, workercontracts.RoutedEnvelope{
|
|
SessionID: claims.SessionID,
|
|
AttachmentID: claims.AttachmentID,
|
|
Type: "control",
|
|
Payload: envelope.Payload,
|
|
})
|
|
case "input":
|
|
session, err := m.broker.GetSessionSnapshot(ctx, claims.SessionID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if session == nil {
|
|
return sessionbroker.ErrSessionNotFound
|
|
}
|
|
if session.State != sessioncontracts.StateActive {
|
|
m.logger.Warn("session gateway input envelope rejected because session is not active",
|
|
"session_id", claims.SessionID,
|
|
"attachment_id", claims.AttachmentID,
|
|
"state", session.State)
|
|
return sessionbroker.ErrSessionNotAttachable
|
|
}
|
|
action, _ := envelope.Payload["action"].(string)
|
|
kind, _ := envelope.Payload["kind"].(string)
|
|
correlationID, _ := envelope.Payload["correlation_id"].(string)
|
|
m.logger.Info("session gateway input envelope received",
|
|
"session_id", claims.SessionID,
|
|
"attachment_id", claims.AttachmentID,
|
|
"takeover_version", claims.TakeoverVersion,
|
|
"kind", kind,
|
|
"action", action,
|
|
"correlation_id", correlationID,
|
|
"trace_stage", "backend_receive")
|
|
err = m.workers.PublishInput(ctx, workercontracts.RoutedEnvelope{
|
|
SessionID: claims.SessionID,
|
|
AttachmentID: claims.AttachmentID,
|
|
Type: "input",
|
|
Payload: envelope.Payload,
|
|
})
|
|
if err != nil {
|
|
m.logger.Warn("session gateway input envelope rejected",
|
|
"session_id", claims.SessionID,
|
|
"attachment_id", claims.AttachmentID,
|
|
"kind", kind,
|
|
"action", action,
|
|
"reason", err.Error())
|
|
return err
|
|
}
|
|
m.logger.Info("session gateway input envelope routed to worker",
|
|
"session_id", claims.SessionID,
|
|
"attachment_id", claims.AttachmentID,
|
|
"kind", kind,
|
|
"action", action,
|
|
"correlation_id", correlationID,
|
|
"trace_stage", "backend_route")
|
|
return nil
|
|
case "clipboard":
|
|
if ok, err := m.isCurrentController(ctx, claims); err != nil {
|
|
return err
|
|
} else if !ok {
|
|
return m.writeTakenOver(conn, claims.SessionID)
|
|
}
|
|
if err := m.ensureClipboardAllowed(ctx, claims.SessionID, "client_to_server"); err != nil {
|
|
return m.writeClipboardBlocked(conn, claims.SessionID, err.Error())
|
|
}
|
|
text, _ := envelope.Payload["text"].(string)
|
|
if text == "" || len([]byte(text)) > maxClipboardTextBytes {
|
|
return m.writeClipboardBlocked(conn, claims.SessionID, "clipboard text is required")
|
|
}
|
|
contentHash, _ := envelope.Payload["content_hash"].(string)
|
|
if contentHash == "" {
|
|
contentHash = clipboardContentHash(text)
|
|
}
|
|
return m.workers.PublishControl(ctx, workercontracts.RoutedEnvelope{
|
|
SessionID: claims.SessionID,
|
|
AttachmentID: claims.AttachmentID,
|
|
Type: "clipboard",
|
|
Payload: map[string]any{
|
|
"direction": "client_to_server",
|
|
"text": text,
|
|
"sequence_id": time.Now().UTC().UnixNano(),
|
|
"origin": claims.AttachmentID,
|
|
"content_hash": contentHash,
|
|
"attachment_id": claims.AttachmentID,
|
|
},
|
|
})
|
|
case "file_upload.start":
|
|
return m.handleFileUploadStart(ctx, conn, claims, envelope)
|
|
case "file_upload.chunk":
|
|
return m.handleFileUploadChunk(ctx, conn, claims, envelope)
|
|
case "file_upload.cancel":
|
|
return m.handleFileUploadCancel(ctx, conn, claims, envelope)
|
|
case "file_download.start":
|
|
return m.handleFileDownloadControl(ctx, conn, claims, envelope, "start")
|
|
case "file_download.ack":
|
|
return m.handleFileDownloadControl(ctx, conn, claims, envelope, "ack")
|
|
case "file_download.cancel":
|
|
return m.handleFileDownloadControl(ctx, conn, claims, envelope, "cancel")
|
|
case "gui":
|
|
return m.workers.PublishControl(ctx, workercontracts.RoutedEnvelope{
|
|
SessionID: claims.SessionID,
|
|
AttachmentID: claims.AttachmentID,
|
|
Type: "gui",
|
|
Payload: envelope.Payload,
|
|
})
|
|
default:
|
|
return m.writeEnvelope(conn, sessioncontracts.TransportEnvelope{
|
|
Type: "error",
|
|
SessionID: claims.SessionID,
|
|
Payload: map[string]any{
|
|
"message": "unsupported envelope type",
|
|
"type": envelope.Type,
|
|
},
|
|
Event: m.newEventMessage(
|
|
"session.unsupported_envelope",
|
|
"events.session.unsupported_envelope",
|
|
"Unsupported session envelope type.",
|
|
map[string]any{
|
|
"type": envelope.Type,
|
|
},
|
|
),
|
|
})
|
|
}
|
|
}
|
|
|
|
func (m *Module) handleFileUploadStart(ctx context.Context, conn *gatewayConnection, claims *sessioncontracts.AttachTokenClaims, envelope sessioncontracts.TransportEnvelope) error {
|
|
if ok, err := m.isCurrentController(ctx, claims); err != nil {
|
|
return err
|
|
} else if !ok {
|
|
return m.writeTakenOver(conn, claims.SessionID)
|
|
}
|
|
if err := m.ensureFileUploadAllowed(ctx, claims.SessionID); err != nil {
|
|
return m.writeFileUploadBlocked(conn, claims.SessionID, stringValue(envelope.Payload, "transfer_id"), err.Error())
|
|
}
|
|
transferID, err := validateTransferID(stringValue(envelope.Payload, "transfer_id"))
|
|
if err != nil {
|
|
return m.writeFileUploadBlocked(conn, claims.SessionID, "", err.Error())
|
|
}
|
|
fileName, err := sanitizeUploadFileName(stringValue(envelope.Payload, "file_name"))
|
|
if err != nil {
|
|
return m.writeFileUploadBlocked(conn, claims.SessionID, transferID, err.Error())
|
|
}
|
|
fileSize, ok := numberValue(envelope.Payload, "file_size")
|
|
if !ok || fileSize <= 0 || fileSize > maxFileUploadBytes {
|
|
return m.writeFileUploadBlocked(conn, claims.SessionID, transferID, "file size exceeds policy")
|
|
}
|
|
totalChunks, ok := numberValue(envelope.Payload, "total_chunks")
|
|
if !ok || totalChunks <= 0 || totalChunks > (maxFileUploadBytes/maxFileUploadChunkBytes)+1 {
|
|
return m.writeFileUploadBlocked(conn, claims.SessionID, transferID, "invalid upload chunk count")
|
|
}
|
|
if _, exists := conn.uploads[transferID]; exists {
|
|
return m.writeFileUploadBlocked(conn, claims.SessionID, transferID, "duplicate transfer_id")
|
|
}
|
|
contentHash := stringValue(envelope.Payload, "content_hash")
|
|
conn.uploads[transferID] = &fileUploadState{
|
|
TransferID: transferID,
|
|
FileName: fileName,
|
|
FileSize: fileSize,
|
|
TotalChunks: totalChunks,
|
|
ContentHash: contentHash,
|
|
}
|
|
m.logger.Info("session gateway file upload start accepted",
|
|
"session_id", claims.SessionID,
|
|
"attachment_id", claims.AttachmentID,
|
|
"transfer_id", transferID,
|
|
"file_name", fileName,
|
|
"file_size", fileSize,
|
|
"total_chunks", totalChunks)
|
|
if err := m.workers.PublishControl(ctx, workercontracts.RoutedEnvelope{
|
|
SessionID: claims.SessionID,
|
|
AttachmentID: claims.AttachmentID,
|
|
Type: "file_upload",
|
|
Payload: map[string]any{
|
|
"action": "start",
|
|
"direction": "client_to_server",
|
|
"transfer_id": transferID,
|
|
"file_name": fileName,
|
|
"file_size": fileSize,
|
|
"total_chunks": totalChunks,
|
|
"content_hash": contentHash,
|
|
"attachment_id": claims.AttachmentID,
|
|
},
|
|
}); err != nil {
|
|
delete(conn.uploads, transferID)
|
|
return err
|
|
}
|
|
return m.writeFileUploadProgress(conn, claims.SessionID, transferID, 0, fileSize, "started")
|
|
}
|
|
|
|
func (m *Module) handleFileUploadChunk(ctx context.Context, conn *gatewayConnection, claims *sessioncontracts.AttachTokenClaims, envelope sessioncontracts.TransportEnvelope) error {
|
|
if ok, err := m.isCurrentController(ctx, claims); err != nil {
|
|
return err
|
|
} else if !ok {
|
|
return m.writeTakenOver(conn, claims.SessionID)
|
|
}
|
|
if err := m.ensureFileUploadAllowed(ctx, claims.SessionID); err != nil {
|
|
return m.writeFileUploadBlocked(conn, claims.SessionID, stringValue(envelope.Payload, "transfer_id"), err.Error())
|
|
}
|
|
transferID, err := validateTransferID(stringValue(envelope.Payload, "transfer_id"))
|
|
if err != nil {
|
|
return m.writeFileUploadBlocked(conn, claims.SessionID, "", err.Error())
|
|
}
|
|
state := conn.uploads[transferID]
|
|
if state == nil {
|
|
return m.writeFileUploadBlocked(conn, claims.SessionID, transferID, "unknown transfer_id")
|
|
}
|
|
chunkIndex, ok := numberValue(envelope.Payload, "chunk_index")
|
|
if !ok || chunkIndex != state.NextIndex || chunkIndex >= state.TotalChunks {
|
|
return m.writeFileUploadBlocked(conn, claims.SessionID, transferID, "invalid chunk index")
|
|
}
|
|
offset, ok := numberValue(envelope.Payload, "offset")
|
|
if !ok || offset != state.Received {
|
|
return m.writeFileUploadBlocked(conn, claims.SessionID, transferID, "invalid chunk offset")
|
|
}
|
|
chunkBytes := stringValue(envelope.Payload, "chunk_bytes")
|
|
decoded, err := base64.StdEncoding.DecodeString(chunkBytes)
|
|
if err != nil || len(decoded) == 0 || len(decoded) > maxFileUploadChunkBytes {
|
|
return m.writeFileUploadBlocked(conn, claims.SessionID, transferID, "invalid chunk payload")
|
|
}
|
|
if state.Received+int64(len(decoded)) > state.FileSize {
|
|
return m.writeFileUploadBlocked(conn, claims.SessionID, transferID, "chunk exceeds declared file size")
|
|
}
|
|
state.Received += int64(len(decoded))
|
|
state.NextIndex++
|
|
status := "transferring"
|
|
if state.Received == state.FileSize && state.NextIndex == state.TotalChunks {
|
|
status = "completed"
|
|
}
|
|
m.logger.Info("session gateway file upload chunk accepted",
|
|
"session_id", claims.SessionID,
|
|
"attachment_id", claims.AttachmentID,
|
|
"transfer_id", transferID,
|
|
"chunk_index", chunkIndex,
|
|
"chunk_size", len(decoded),
|
|
"received", state.Received,
|
|
"file_size", state.FileSize)
|
|
if err := m.workers.PublishControl(ctx, workercontracts.RoutedEnvelope{
|
|
SessionID: claims.SessionID,
|
|
AttachmentID: claims.AttachmentID,
|
|
Type: "file_upload",
|
|
Payload: map[string]any{
|
|
"action": "chunk",
|
|
"direction": "client_to_server",
|
|
"transfer_id": transferID,
|
|
"chunk_index": chunkIndex,
|
|
"offset": offset,
|
|
"chunk_size": len(decoded),
|
|
"chunk_bytes": chunkBytes,
|
|
"attachment_id": claims.AttachmentID,
|
|
},
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
if status == "completed" {
|
|
delete(conn.uploads, transferID)
|
|
}
|
|
return m.writeFileUploadProgress(conn, claims.SessionID, transferID, state.Received, state.FileSize, status)
|
|
}
|
|
|
|
func (m *Module) handleFileUploadCancel(ctx context.Context, conn *gatewayConnection, claims *sessioncontracts.AttachTokenClaims, envelope sessioncontracts.TransportEnvelope) error {
|
|
transferID, err := validateTransferID(stringValue(envelope.Payload, "transfer_id"))
|
|
if err != nil {
|
|
return m.writeFileUploadBlocked(conn, claims.SessionID, "", err.Error())
|
|
}
|
|
delete(conn.uploads, transferID)
|
|
if err := m.workers.PublishControl(ctx, workercontracts.RoutedEnvelope{
|
|
SessionID: claims.SessionID,
|
|
AttachmentID: claims.AttachmentID,
|
|
Type: "file_upload",
|
|
Payload: map[string]any{
|
|
"action": "cancel",
|
|
"direction": "client_to_server",
|
|
"transfer_id": transferID,
|
|
"attachment_id": claims.AttachmentID,
|
|
},
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
return m.writeFileUploadProgress(conn, claims.SessionID, transferID, 0, 0, "cancelled")
|
|
}
|
|
|
|
func (m *Module) handleFileDownloadControl(ctx context.Context, conn *gatewayConnection, claims *sessioncontracts.AttachTokenClaims, envelope sessioncontracts.TransportEnvelope, action string) error {
|
|
if ok, err := m.isCurrentController(ctx, claims); err != nil {
|
|
return err
|
|
} else if !ok {
|
|
return m.writeTakenOver(conn, claims.SessionID)
|
|
}
|
|
if err := m.ensureFileDownloadAllowed(ctx, claims.SessionID); err != nil {
|
|
return m.writeFileDownloadBlocked(conn, claims.SessionID, stringValue(envelope.Payload, "transfer_id"), err.Error())
|
|
}
|
|
transferID, err := validateTransferID(stringValue(envelope.Payload, "transfer_id"))
|
|
if err != nil {
|
|
return m.writeFileDownloadBlocked(conn, claims.SessionID, "", err.Error())
|
|
}
|
|
payload := map[string]any{
|
|
"action": action,
|
|
"direction": "server_to_client",
|
|
"transfer_id": transferID,
|
|
"attachment_id": claims.AttachmentID,
|
|
}
|
|
switch action {
|
|
case "start":
|
|
fileID := strings.TrimSpace(stringValue(envelope.Payload, "file_id"))
|
|
if fileID == "" || !safeFileIDPattern.MatchString(fileID) {
|
|
return m.writeFileDownloadBlocked(conn, claims.SessionID, transferID, "invalid file_id")
|
|
}
|
|
payload["file_id"] = fileID
|
|
case "ack":
|
|
if sequence, ok := numberValue(envelope.Payload, "sequence"); ok && sequence >= 0 {
|
|
payload["sequence"] = sequence
|
|
}
|
|
if offset, ok := numberValue(envelope.Payload, "offset"); ok && offset >= 0 {
|
|
payload["offset"] = offset
|
|
}
|
|
case "cancel":
|
|
default:
|
|
return m.writeFileDownloadBlocked(conn, claims.SessionID, transferID, "invalid download action")
|
|
}
|
|
m.logger.Info("session gateway file download control accepted",
|
|
"session_id", claims.SessionID,
|
|
"attachment_id", claims.AttachmentID,
|
|
"transfer_id", transferID,
|
|
"action", action,
|
|
"file_id", payload["file_id"])
|
|
return m.workers.PublishControl(ctx, workercontracts.RoutedEnvelope{
|
|
SessionID: claims.SessionID,
|
|
AttachmentID: claims.AttachmentID,
|
|
Type: "file_download",
|
|
Payload: payload,
|
|
})
|
|
}
|
|
|
|
func (m *Module) syncConnection(ctx context.Context, conn *gatewayConnection, claims *sessioncontracts.AttachTokenClaims, lastFrameSequence *int64, lastClipboardSequence *int64, lastFileDownloadSequence *int64) error {
|
|
binding, err := m.broker.GetControllerBinding(ctx, claims.SessionID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if binding == nil {
|
|
session, err := m.broker.GetSessionSnapshot(ctx, claims.SessionID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if session != nil {
|
|
switch session.State {
|
|
case sessioncontracts.StateFailed, sessioncontracts.StateTerminated:
|
|
if err := m.writeEnvelope(conn, sessioncontracts.TransportEnvelope{
|
|
Type: "session.state",
|
|
SessionID: claims.SessionID,
|
|
Payload: map[string]any{
|
|
"state": session.State,
|
|
"takeover_version": session.TakeoverVersion,
|
|
"attachment_id": claims.AttachmentID,
|
|
"render": renderPayloadFromSessionSnapshot(session),
|
|
},
|
|
Event: m.newEventMessage(
|
|
"session.state."+string(session.State),
|
|
"events.session.state."+string(session.State),
|
|
"Session state updated.",
|
|
map[string]any{
|
|
"state": session.State,
|
|
},
|
|
),
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
_ = m.writeEnvelope(conn, sessioncontracts.TransportEnvelope{
|
|
Type: "transport.closed",
|
|
SessionID: claims.SessionID,
|
|
Payload: map[string]any{
|
|
"reason": string(session.State),
|
|
},
|
|
Event: m.newEventMessage(
|
|
"transport.closed",
|
|
"events.session.transport_closed",
|
|
"The session transport closed.",
|
|
map[string]any{
|
|
"reason": session.State,
|
|
},
|
|
),
|
|
})
|
|
return websocket.ErrCloseSent
|
|
}
|
|
}
|
|
_ = m.writeTakenOver(conn, claims.SessionID)
|
|
return websocket.ErrCloseSent
|
|
}
|
|
if binding.AttachmentID != claims.AttachmentID || binding.TakeoverVersion != claims.TakeoverVersion {
|
|
_ = m.writeTakenOver(conn, claims.SessionID)
|
|
return websocket.ErrCloseSent
|
|
}
|
|
state, err := m.broker.GetLiveSession(ctx, claims.SessionID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if state != nil {
|
|
if err := m.writeEnvelope(conn, sessioncontracts.TransportEnvelope{
|
|
Type: "session.state",
|
|
SessionID: claims.SessionID,
|
|
Payload: map[string]any{
|
|
"state": state.State,
|
|
"takeover_version": state.TakeoverVersion,
|
|
"attachment_id": state.AttachmentID,
|
|
"render": renderPayloadFromLiveState(state),
|
|
},
|
|
Event: m.newEventMessage(
|
|
"session.state."+string(state.State),
|
|
"events.session.state."+string(state.State),
|
|
"Session state updated.",
|
|
map[string]any{
|
|
"state": state.State,
|
|
},
|
|
),
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
if state.RenderFrameSequence > 0 && state.RenderFrameSequence != *lastFrameSequence && state.RenderFrameData != "" {
|
|
if err := m.writeFrameEnvelope(conn, claims.SessionID, *state); err != nil {
|
|
return err
|
|
}
|
|
*lastFrameSequence = state.RenderFrameSequence
|
|
}
|
|
m.logger.Info("session gateway clipboard sync evaluation",
|
|
"session_id", claims.SessionID,
|
|
"attachment_id", claims.AttachmentID,
|
|
"clipboard_sequence", state.ClipboardSequence,
|
|
"last_clipboard_sequence", *lastClipboardSequence,
|
|
"clipboard_text_bytes", len(state.ClipboardText),
|
|
"clipboard_origin", state.ClipboardOrigin)
|
|
if state.ClipboardSequence > 0 && state.ClipboardSequence != *lastClipboardSequence && state.ClipboardText != "" {
|
|
if err := m.ensureClipboardAllowed(ctx, claims.SessionID, "server_to_client"); err == nil {
|
|
m.logger.Info("session gateway clipboard sync allowed",
|
|
"session_id", claims.SessionID,
|
|
"attachment_id", claims.AttachmentID,
|
|
"sequence_id", state.ClipboardSequence)
|
|
if err := m.writeClipboardEnvelope(conn, claims.SessionID, *state); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
m.logger.Warn("session gateway clipboard sync blocked",
|
|
"session_id", claims.SessionID,
|
|
"attachment_id", claims.AttachmentID,
|
|
"sequence_id", state.ClipboardSequence,
|
|
"reason", err.Error())
|
|
}
|
|
*lastClipboardSequence = state.ClipboardSequence
|
|
}
|
|
if state.FileDownloadSequence > 0 && state.FileDownloadSequence != *lastFileDownloadSequence && state.FileDownloadType != "" {
|
|
if err := m.ensureFileDownloadAllowed(ctx, claims.SessionID); err == nil {
|
|
if err := m.writeFileDownloadEnvelope(conn, claims.SessionID, *state); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
m.logger.Warn("session gateway file download sync blocked",
|
|
"session_id", claims.SessionID,
|
|
"attachment_id", claims.AttachmentID,
|
|
"sequence", state.FileDownloadSequence,
|
|
"reason", err.Error())
|
|
}
|
|
*lastFileDownloadSequence = state.FileDownloadSequence
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Module) isCurrentController(ctx context.Context, claims *sessioncontracts.AttachTokenClaims) (bool, error) {
|
|
binding, err := m.broker.GetControllerBinding(ctx, claims.SessionID)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return binding != nil && binding.AttachmentID == claims.AttachmentID && binding.TakeoverVersion == claims.TakeoverVersion, nil
|
|
}
|
|
|
|
func (m *Module) writeTakenOver(conn *gatewayConnection, sessionID string) error {
|
|
return m.writeEnvelope(conn, sessioncontracts.TransportEnvelope{
|
|
Type: "session.taken_over",
|
|
SessionID: sessionID,
|
|
Payload: map[string]any{
|
|
"message": "controller binding changed",
|
|
},
|
|
Event: m.newEventMessage(
|
|
"session.taken_over",
|
|
"events.session.taken_over",
|
|
"This session was taken over from another device.",
|
|
map[string]any{
|
|
"reason": "controller_binding_changed",
|
|
},
|
|
),
|
|
})
|
|
}
|
|
|
|
func (m *Module) writeClipboardEnvelope(conn *gatewayConnection, sessionID string, state sessionbroker.LiveSessionState) error {
|
|
m.logger.Info("session gateway writing clipboard envelope",
|
|
"session_id", sessionID,
|
|
"sequence_id", state.ClipboardSequence,
|
|
"origin", state.ClipboardOrigin,
|
|
"content_hash", state.ClipboardContentHash,
|
|
"text_bytes", len(state.ClipboardText))
|
|
return m.writeEnvelope(conn, sessioncontracts.TransportEnvelope{
|
|
Type: "clipboard.text",
|
|
SessionID: sessionID,
|
|
Payload: map[string]any{
|
|
"direction": "server_to_client",
|
|
"text": state.ClipboardText,
|
|
"clipboard_sequence": state.ClipboardSequence,
|
|
"sequence_id": state.ClipboardSequence,
|
|
"origin": state.ClipboardOrigin,
|
|
"content_hash": state.ClipboardContentHash,
|
|
},
|
|
Event: m.newEventMessage(
|
|
"clipboard.text",
|
|
"events.clipboard.text_received",
|
|
"Clipboard text received from the remote session.",
|
|
map[string]any{"direction": "server_to_client"},
|
|
),
|
|
})
|
|
}
|
|
|
|
// clipboardContentHash returns the 64-bit FNV-1a hash of text as 16 lowercase
// hex digits. It is used as a fallback when the client sends no content_hash.
func clipboardContentHash(text string) string {
	digest := fnv.New64a()
	_, _ = digest.Write([]byte(text)) // hash.Hash writes never fail
	// Sum appends the big-endian bytes of the 64-bit sum; %x renders each byte
	// as two hex digits, so leading zeros are kept (same as %016x of Sum64).
	return fmt.Sprintf("%x", digest.Sum(nil))
}
|
|
|
|
func (m *Module) writeClipboardBlocked(conn *gatewayConnection, sessionID, reason string) error {
|
|
return m.writeEnvelope(conn, sessioncontracts.TransportEnvelope{
|
|
Type: "clipboard.blocked",
|
|
SessionID: sessionID,
|
|
Payload: map[string]any{
|
|
"reason": reason,
|
|
},
|
|
Event: m.newEventMessage(
|
|
"clipboard.blocked",
|
|
"events.clipboard.blocked",
|
|
"Clipboard transfer is blocked by session state or resource policy.",
|
|
map[string]any{"reason": reason},
|
|
),
|
|
})
|
|
}
|
|
|
|
func (m *Module) writeFileUploadProgress(conn *gatewayConnection, sessionID, transferID string, received, total int64, status string) error {
|
|
return m.writeEnvelope(conn, sessioncontracts.TransportEnvelope{
|
|
Type: "file_upload.progress",
|
|
SessionID: sessionID,
|
|
Payload: map[string]any{
|
|
"transfer_id": transferID,
|
|
"received": received,
|
|
"total": total,
|
|
"status": status,
|
|
},
|
|
Event: m.newEventMessage(
|
|
"file_upload."+status,
|
|
"events.file_upload."+status,
|
|
"File upload status updated.",
|
|
map[string]any{"transfer_id": transferID, "status": status},
|
|
),
|
|
})
|
|
}
|
|
|
|
func (m *Module) writeFileUploadBlocked(conn *gatewayConnection, sessionID, transferID, reason string) error {
|
|
return m.writeEnvelope(conn, sessioncontracts.TransportEnvelope{
|
|
Type: "file_upload.blocked",
|
|
SessionID: sessionID,
|
|
Payload: map[string]any{
|
|
"transfer_id": transferID,
|
|
"reason": reason,
|
|
},
|
|
Event: m.newEventMessage(
|
|
"file_upload.blocked",
|
|
"events.file_upload.blocked",
|
|
"File upload is blocked by session state or resource policy.",
|
|
map[string]any{"reason": reason, "transfer_id": transferID},
|
|
),
|
|
})
|
|
}
|
|
|
|
func (m *Module) writeFileDownloadEnvelope(conn *gatewayConnection, sessionID string, state sessionbroker.LiveSessionState) error {
|
|
payload := map[string]any{}
|
|
for key, value := range state.FileDownloadPayload {
|
|
payload[key] = value
|
|
}
|
|
payload["sequence"] = state.FileDownloadSequence
|
|
payload["direction"] = "server_to_client"
|
|
envelopeType := fileDownloadEnvelopeType(state.FileDownloadType)
|
|
m.logger.Info("session gateway writing file download envelope",
|
|
"session_id", sessionID,
|
|
"type", envelopeType,
|
|
"sequence", state.FileDownloadSequence,
|
|
"transfer_id", payload["transfer_id"],
|
|
"file_id", payload["file_id"],
|
|
"file_name", payload["file_name"],
|
|
"status", payload["status"])
|
|
return m.writeEnvelope(conn, sessioncontracts.TransportEnvelope{
|
|
Type: envelopeType,
|
|
SessionID: sessionID,
|
|
Payload: payload,
|
|
Event: m.newEventMessage(
|
|
fileDownloadEventCode(envelopeType),
|
|
fileDownloadMessageKey(envelopeType),
|
|
fileDownloadFallback(envelopeType),
|
|
map[string]any{
|
|
"transfer_id": payload["transfer_id"],
|
|
"file_id": payload["file_id"],
|
|
"file_name": payload["file_name"],
|
|
"status": payload["status"],
|
|
},
|
|
),
|
|
})
|
|
}
|
|
|
|
func (m *Module) writeFileDownloadBlocked(conn *gatewayConnection, sessionID, transferID, reason string) error {
|
|
return m.writeEnvelope(conn, sessioncontracts.TransportEnvelope{
|
|
Type: "file_download.blocked",
|
|
SessionID: sessionID,
|
|
Payload: map[string]any{
|
|
"transfer_id": transferID,
|
|
"reason": reason,
|
|
"direction": "server_to_client",
|
|
},
|
|
Event: m.newEventMessage(
|
|
"file_download.blocked",
|
|
"events.file_download.blocked",
|
|
"File download is blocked by session state or resource policy.",
|
|
map[string]any{"reason": reason, "transfer_id": transferID},
|
|
),
|
|
})
|
|
}
|
|
|
|
func (m *Module) ensureClipboardAllowed(ctx context.Context, sessionID, direction string) error {
|
|
mode, state, err := m.broker.GetSessionClipboardPolicy(ctx, sessionID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if state != sessioncontracts.StateActive {
|
|
return sessionbroker.ErrSessionNotAttachable
|
|
}
|
|
switch direction {
|
|
case "client_to_server":
|
|
if mode == sessionbroker.ResourceClipboardModeClientToServer || mode == sessionbroker.ResourceClipboardModeBidirectional {
|
|
return nil
|
|
}
|
|
case "server_to_client":
|
|
if mode == sessionbroker.ResourceClipboardModeServerToClient || mode == sessionbroker.ResourceClipboardModeBidirectional {
|
|
return nil
|
|
}
|
|
}
|
|
return sessionbroker.ErrAccessDenied
|
|
}
|
|
|
|
func (m *Module) ensureFileUploadAllowed(ctx context.Context, sessionID string) error {
|
|
mode, state, err := m.broker.GetSessionFileTransferPolicy(ctx, sessionID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if state != sessioncontracts.StateActive {
|
|
return sessionbroker.ErrSessionNotAttachable
|
|
}
|
|
if mode == sessionbroker.ResourceFileTransferModeClientToServer || mode == sessionbroker.ResourceFileTransferModeBidirectional {
|
|
return nil
|
|
}
|
|
return sessionbroker.ErrAccessDenied
|
|
}
|
|
|
|
func (m *Module) ensureFileDownloadAllowed(ctx context.Context, sessionID string) error {
|
|
mode, state, err := m.broker.GetSessionFileTransferPolicy(ctx, sessionID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if state != sessioncontracts.StateActive {
|
|
return sessionbroker.ErrSessionNotAttachable
|
|
}
|
|
if mode == sessionbroker.ResourceFileTransferModeServerToClient || mode == sessionbroker.ResourceFileTransferModeBidirectional {
|
|
return nil
|
|
}
|
|
return sessionbroker.ErrAccessDenied
|
|
}
|
|
|
|
func fileDownloadEnvelopeType(workerEventType string) string {
|
|
switch workerEventType {
|
|
case worker.SessionEventFileDownloadAvailable:
|
|
return "file_download.available"
|
|
case worker.SessionEventFileDownloadChunk:
|
|
return "file_download.chunk"
|
|
case worker.SessionEventFileDownloadCompleted:
|
|
return "file_download.completed"
|
|
case worker.SessionEventFileDownloadFailed:
|
|
return "file_download.failed"
|
|
case worker.SessionEventFileDownloadBlocked:
|
|
return "file_download.blocked"
|
|
default:
|
|
return "file_download.progress"
|
|
}
|
|
}
|
|
|
|
// fileDownloadEventCode turns an envelope type such as "file_download.chunk"
// into an event code by replacing every "." with "_" (for example,
// "file_download_chunk").
//
// NOTE(review): writeFileDownloadBlocked and writeFrameEnvelope pass dotted
// codes ("file_download.blocked", "session.frame") directly. The two styles
// differ for the same "file_download.blocked" event; confirm which one clients
// expect.
func fileDownloadEventCode(envelopeType string) string {
	parts := strings.Split(envelopeType, ".")
	return strings.Join(parts, "_")
}
|
|
|
|
// fileDownloadMessageKey returns the localisation key for an envelope type,
// which is the type prefixed with "events." (for example,
// "events.file_download.chunk").
func fileDownloadMessageKey(envelopeType string) string {
	const keyPrefix = "events."
	return keyPrefix + envelopeType
}
|
|
|
|
// fileDownloadFallback returns the English fallback text for a file-download
// envelope type. Unrecognised types get a generic "status updated" message.
func fileDownloadFallback(envelopeType string) string {
	message := "File download status updated."
	switch envelopeType {
	case "file_download.available":
		message = "A file is available for download from the remote session."
	case "file_download.chunk":
		message = "File download chunk received."
	case "file_download.completed":
		message = "File download completed."
	case "file_download.failed":
		message = "File download failed."
	case "file_download.blocked":
		message = "File download is blocked by session state or resource policy."
	}
	return message
}
|
|
|
|
func validateTransferID(value string) (string, error) {
|
|
if value == "" || !safeTransferIDPattern.MatchString(value) {
|
|
return "", errors.New("invalid transfer_id")
|
|
}
|
|
parsed, err := uuid.Parse(value)
|
|
if err != nil {
|
|
return "", errors.New("invalid transfer_id")
|
|
}
|
|
return parsed.String(), nil
|
|
}
|
|
|
|
// sanitizeUploadFileName validates a client-supplied upload file name and
// returns it with surrounding whitespace trimmed.
//
// The name must be a single, printable path component. It is rejected with
// "invalid file name" when it:
//   - is empty, "." or "..";
//   - contains "..";
//   - contains a path or drive separator (/, \ or :);
//   - is absolute;
//   - is not valid UTF-8;
//   - contains a C0/C1 control character or DEL.
//
// It is rejected with "file name is too long" when it exceeds 255 bytes.
func sanitizeUploadFileName(value string) (string, error) {
	name := strings.TrimSpace(value)
	if name == "" || name == "." || name == ".." {
		return "", errors.New("invalid file name")
	}
	if strings.Contains(name, "..") || strings.ContainsAny(name, `/\:`) || filepath.IsAbs(name) {
		return "", errors.New("invalid file name")
	}
	// The name is untrusted client input. Embedded NUL can truncate names in
	// C-level filesystem APIs, CR/LF enable log or header injection, and
	// malformed UTF-8 breaks downstream encoders. TrimSpace only strips these
	// characters at the edges, so they are checked for explicitly here.
	if strings.ToValidUTF8(name, "") != name {
		return "", errors.New("invalid file name")
	}
	if strings.IndexFunc(name, func(r rune) bool {
		return r < 0x20 || r == 0x7f || (r >= 0x80 && r <= 0x9f)
	}) >= 0 {
		return "", errors.New("invalid file name")
	}
	// len on a string is its byte length; no []byte conversion is needed.
	if len(name) > 255 {
		return "", errors.New("file name is too long")
	}
	return name, nil
}
|
|
|
|
// stringValue returns payload[key] if it holds a string. It returns "" for a
// nil payload, a missing key, or a non-string value.
func stringValue(payload map[string]any, key string) string {
	// Indexing a nil map yields the zero value in Go, so no nil guard is needed.
	if s, ok := payload[key].(string); ok {
		return s
	}
	return ""
}
|
|
|
|
// numberValue extracts an integer from payload[key].
//
// It accepts int, int64, and float64 values that hold an exact integer within
// int64 range; float64 is what encoding/json produces for numbers decoded into
// any. ok is false for a nil payload, a missing key, any other type, a
// non-integral float, NaN, ±Inf, or a float outside int64 range. The returned
// number is meaningful only when ok is true.
func numberValue(payload map[string]any, key string) (int64, bool) {
	// Indexing a nil map is safe and yields nil, which falls to the default case.
	switch value := payload[key].(type) {
	case int:
		return int64(value), true
	case int64:
		return value, true
	case float64:
		// The Go spec leaves out-of-range float->int conversions
		// implementation-dependent. On arm64, for example, int64(2^63)
		// saturates and would round-trip as "equal". Range-check first; the
		// negated form also rejects NaN, which fails every comparison.
		if !(value >= -(1<<63) && value < 1<<63) {
			return 0, false
		}
		converted := int64(value)
		return converted, float64(converted) == value
	default:
		return 0, false
	}
}
|
|
|
|
func (m *Module) writeFrameEnvelope(conn *gatewayConnection, sessionID string, state sessionbroker.LiveSessionState) error {
|
|
if state.LastInputCorrelationID != "" {
|
|
m.logger.Info("session gateway frame envelope writing",
|
|
"session_id", sessionID,
|
|
"frame_sequence", state.RenderFrameSequence,
|
|
"correlation_id", state.LastInputCorrelationID,
|
|
"worker_frame_captured_at", state.WorkerFrameCapturedAt,
|
|
"trace_stage", "backend_frame_to_client")
|
|
}
|
|
return m.writeEnvelope(conn, sessioncontracts.TransportEnvelope{
|
|
Type: "session.frame",
|
|
SessionID: sessionID,
|
|
Payload: map[string]any{
|
|
"frame_sequence": state.RenderFrameSequence,
|
|
"frame_format": state.RenderFrameFormat,
|
|
"frame_data": state.RenderFrameData,
|
|
"frame_width": state.RenderWidth,
|
|
"frame_height": state.RenderHeight,
|
|
"frame_stride": state.RenderWidth * 4,
|
|
"frame_update_kind": "full",
|
|
"desktop_width": state.RenderWidth,
|
|
"desktop_height": state.RenderHeight,
|
|
"region_x": 0,
|
|
"region_y": 0,
|
|
"region_width": state.RenderWidth,
|
|
"region_height": state.RenderHeight,
|
|
"input_correlation_id": state.LastInputCorrelationID,
|
|
"worker_frame_captured_at": state.WorkerFrameCapturedAt,
|
|
"render": renderPayloadFromLiveState(&state),
|
|
},
|
|
Event: m.newEventMessage(
|
|
"session.frame",
|
|
"events.session.frame",
|
|
"The session desktop frame updated.",
|
|
map[string]any{
|
|
"frame_sequence": state.RenderFrameSequence,
|
|
},
|
|
),
|
|
})
|
|
}
|
|
|
|
func renderPayloadFromLiveState(state *sessionbroker.LiveSessionState) map[string]any {
|
|
if state == nil {
|
|
return map[string]any{}
|
|
}
|
|
return map[string]any{
|
|
"quality_profile": state.RenderQualityProfile,
|
|
"render_quality_profile": state.RenderQualityProfile,
|
|
"state": state.RenderState,
|
|
"render_state": state.RenderState,
|
|
"width": state.RenderWidth,
|
|
"height": state.RenderHeight,
|
|
"frame_sequence": state.RenderFrameSequence,
|
|
"frame_format": state.RenderFrameFormat,
|
|
"cursor_x": state.CursorX,
|
|
"cursor_y": state.CursorY,
|
|
"cursor_visible": state.CursorVisible,
|
|
"dirty_rectangles": state.DirtyRectangles,
|
|
"last_render_at": state.LastRenderAt,
|
|
}
|
|
}
|
|
|
|
func renderPayloadFromSessionSnapshot(session *sessionbroker.RemoteSession) map[string]any {
|
|
if session == nil {
|
|
return map[string]any{}
|
|
}
|
|
return map[string]any{
|
|
"quality_profile": session.RenderQualityProfile,
|
|
"render_quality_profile": session.RenderQualityProfile,
|
|
"state": session.State,
|
|
"render_state": session.State,
|
|
"width": 0,
|
|
"height": 0,
|
|
"cursor_x": 0,
|
|
"cursor_y": 0,
|
|
"cursor_visible": true,
|
|
"dirty_rectangles": 0,
|
|
"last_render_at": session.LastHeartbeatAt,
|
|
}
|
|
}
|
|
|
|
func (m *Module) writeEnvelope(conn *gatewayConnection, envelope sessioncontracts.TransportEnvelope) error {
|
|
return conn.writeJSON(m.cfg.WebSocket.WriteTimeout, envelope)
|
|
}
|
|
|
|
func (m *Module) newEventMessage(code, messageKey, fallback string, details map[string]any) *messagecontracts.Message {
|
|
event := httpx.NewMessage(code, messageKey, fallback, details, "")
|
|
return &event
|
|
}
|