Require auth for fabric session websocket

This commit is contained in:
2026-05-16 00:19:38 +03:00
parent 8a972ea68f
commit be31798d7c
3 changed files with 238 additions and 4 deletions
@@ -198,6 +198,8 @@ type FabricSessionEventLogEntry struct {
Event string `json:"event"` Event string `json:"event"`
ClusterID string `json:"cluster_id,omitempty"` ClusterID string `json:"cluster_id,omitempty"`
NodeID string `json:"node_id,omitempty"` NodeID string `json:"node_id,omitempty"`
AcceptedBy string `json:"accepted_by,omitempty"`
SessionID string `json:"session_id,omitempty"`
SessionEvent fabricproto.SessionEventType `json:"session_event,omitempty"` SessionEvent fabricproto.SessionEventType `json:"session_event,omitempty"`
StreamID uint64 `json:"stream_id,omitempty"` StreamID uint64 `json:"stream_id,omitempty"`
Sequence uint64 `json:"sequence,omitempty"` Sequence uint64 `json:"sequence,omitempty"`
@@ -207,11 +209,31 @@ type FabricSessionEventLogEntry struct {
ObservedAt time.Time `json:"observed_at"` ObservedAt time.Time `json:"observed_at"`
} }
type fabricSessionAuthorityPayload struct {
SchemaVersion string `json:"schema_version"`
ClusterID string `json:"cluster_id"`
SessionID string `json:"session_id"`
SourceNodeID string `json:"source_node_id,omitempty"`
SelectedEntryNodeID string `json:"selected_entry_node_id,omitempty"`
TokenHash string `json:"token_hash"`
IssuedAt time.Time `json:"issued_at"`
ExpiresAt time.Time `json:"expires_at"`
}
type fabricSessionAuthDecision struct {
AcceptedBy string
SessionID string
}
func (s Server) handleFabricSessionWebSocket(w http.ResponseWriter, r *http.Request) { func (s Server) handleFabricSessionWebSocket(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
w.WriteHeader(http.StatusMethodNotAllowed) w.WriteHeader(http.StatusMethodNotAllowed)
return return
} }
decision, ok := s.validateFabricSessionRequest(w, r)
if !ok {
return
}
upgrader := websocket.Upgrader{ upgrader := websocket.Upgrader{
CheckOrigin: func(_ *http.Request) bool { return true }, CheckOrigin: func(_ *http.Request) bool { return true },
} }
@@ -225,6 +247,8 @@ func (s Server) handleFabricSessionWebSocket(w http.ResponseWriter, r *http.Requ
Event: "fabric_session_websocket_opened", Event: "fabric_session_websocket_opened",
ClusterID: s.Local.ClusterID, ClusterID: s.Local.ClusterID,
NodeID: s.Local.NodeID, NodeID: s.Local.NodeID,
AcceptedBy: decision.AcceptedBy,
SessionID: decision.SessionID,
RemoteAddr: r.RemoteAddr, RemoteAddr: r.RemoteAddr,
ObservedAt: time.Now().UTC(), ObservedAt: time.Now().UTC(),
}) })
@@ -235,6 +259,8 @@ func (s Server) handleFabricSessionWebSocket(w http.ResponseWriter, r *http.Requ
Event: "fabric_session_event", Event: "fabric_session_event",
ClusterID: s.Local.ClusterID, ClusterID: s.Local.ClusterID,
NodeID: s.Local.NodeID, NodeID: s.Local.NodeID,
AcceptedBy: decision.AcceptedBy,
SessionID: decision.SessionID,
SessionEvent: event.Type, SessionEvent: event.Type,
StreamID: event.StreamID, StreamID: event.StreamID,
Sequence: event.Sequence, Sequence: event.Sequence,
@@ -251,6 +277,8 @@ func (s Server) handleFabricSessionWebSocket(w http.ResponseWriter, r *http.Requ
Event: "fabric_session_websocket_closed", Event: "fabric_session_websocket_closed",
ClusterID: s.Local.ClusterID, ClusterID: s.Local.ClusterID,
NodeID: s.Local.NodeID, NodeID: s.Local.NodeID,
AcceptedBy: decision.AcceptedBy,
SessionID: decision.SessionID,
RemoteAddr: r.RemoteAddr, RemoteAddr: r.RemoteAddr,
Reason: err.Error(), Reason: err.Error(),
ObservedAt: time.Now().UTC(), ObservedAt: time.Now().UTC(),
@@ -261,11 +289,83 @@ func (s Server) handleFabricSessionWebSocket(w http.ResponseWriter, r *http.Requ
Event: "fabric_session_websocket_closed", Event: "fabric_session_websocket_closed",
ClusterID: s.Local.ClusterID, ClusterID: s.Local.ClusterID,
NodeID: s.Local.NodeID, NodeID: s.Local.NodeID,
AcceptedBy: decision.AcceptedBy,
SessionID: decision.SessionID,
RemoteAddr: r.RemoteAddr, RemoteAddr: r.RemoteAddr,
ObservedAt: time.Now().UTC(), ObservedAt: time.Now().UTC(),
}) })
} }
func (s Server) validateFabricSessionRequest(w http.ResponseWriter, r *http.Request) (fabricSessionAuthDecision, bool) {
var decision fabricSessionAuthDecision
token := fabricSessionBearerToken(r)
if !strings.HasPrefix(token, "rap_fsn_") {
http.Error(w, "fabric session token is required", http.StatusUnauthorized)
return decision, false
}
payload, err := s.verifyFabricSessionAuthority(r, token)
if err != nil {
http.Error(w, err.Error(), http.StatusForbidden)
return decision, false
}
decision.AcceptedBy = "legacy_unsigned"
if payload != nil {
decision.AcceptedBy = "signed"
decision.SessionID = strings.TrimSpace(payload.SessionID)
}
return decision, true
}
func (s Server) verifyFabricSessionAuthority(r *http.Request, token string) (*fabricSessionAuthorityPayload, error) {
publicKey := strings.TrimSpace(s.ClusterAuthorityPublicKey)
payloadHeader := strings.TrimSpace(r.Header.Get("X-RAP-Fabric-Session-Authority-Payload"))
signatureHeader := strings.TrimSpace(r.Header.Get("X-RAP-Fabric-Session-Authority-Signature"))
if payloadHeader == "" && signatureHeader == "" {
if publicKey != "" {
return nil, fmt.Errorf("%w: signed fabric session authority is required", ErrUnauthorizedChannel)
}
return nil, nil
}
if publicKey == "" {
return nil, ErrUnauthorizedChannel
}
if payloadHeader == "" || signatureHeader == "" {
return nil, fmt.Errorf("%w: fabric session authority payload and signature are required together", ErrUnauthorizedChannel)
}
payloadRaw, err := decodeHeaderJSON(payloadHeader)
if err != nil {
return nil, fmt.Errorf("%w: invalid fabric session authority payload", ErrUnauthorizedChannel)
}
signatureRaw, err := decodeHeaderJSON(signatureHeader)
if err != nil {
return nil, fmt.Errorf("%w: invalid fabric session authority signature", ErrUnauthorizedChannel)
}
var signature authority.Signature
if err := json.Unmarshal(signatureRaw, &signature); err != nil {
return nil, fmt.Errorf("%w: invalid fabric session authority signature", ErrUnauthorizedChannel)
}
if err := authority.VerifyRaw(publicKey, payloadRaw, signature); err != nil {
return nil, fmt.Errorf("%w: fabric session authority signature rejected", ErrUnauthorizedChannel)
}
var payload fabricSessionAuthorityPayload
if err := json.Unmarshal(payloadRaw, &payload); err != nil {
return nil, fmt.Errorf("%w: invalid fabric session authority payload", ErrUnauthorizedChannel)
}
if payload.SchemaVersion != "rap.fabric_session_authority.v1" ||
payload.ClusterID != s.Local.ClusterID ||
payload.TokenHash != fabricSessionTokenHash(token) ||
strings.TrimSpace(payload.SessionID) == "" {
return nil, fmt.Errorf("%w: fabric session authority payload mismatch", ErrUnauthorizedChannel)
}
if payload.SelectedEntryNodeID != "" && s.Local.NodeID != "" && payload.SelectedEntryNodeID != s.Local.NodeID {
return nil, fmt.Errorf("%w: fabric session entry node mismatch", ErrUnauthorizedChannel)
}
if !payload.ExpiresAt.IsZero() && !payload.ExpiresAt.After(time.Now().UTC()) {
return nil, fmt.Errorf("%w: fabric session lease expired", ErrUnauthorizedChannel)
}
return &payload, nil
}
func (s Server) logFabricSession(entry FabricSessionEventLogEntry) { func (s Server) logFabricSession(entry FabricSessionEventLogEntry) {
if s.FabricSessionLogger != nil { if s.FabricSessionLogger != nil {
s.FabricSessionLogger(entry) s.FabricSessionLogger(entry)
@@ -1693,6 +1793,25 @@ func fabricServiceChannelBearerToken(r *http.Request) string {
return strings.TrimSpace(r.URL.Query().Get("service_channel_token")) return strings.TrimSpace(r.URL.Query().Get("service_channel_token"))
} }
func fabricSessionTokenHash(token string) string {
sum := sha256.Sum256([]byte(strings.TrimSpace(token)))
return hex.EncodeToString(sum[:])
}
func fabricSessionBearerToken(r *http.Request) string {
if r == nil {
return ""
}
if token := strings.TrimSpace(r.Header.Get("X-RAP-Fabric-Session-Token")); token != "" {
return token
}
auth := strings.TrimSpace(r.Header.Get("Authorization"))
if len(auth) > len("Bearer ") && strings.EqualFold(auth[:len("Bearer ")], "Bearer ") {
return strings.TrimSpace(auth[len("Bearer "):])
}
return strings.TrimSpace(r.URL.Query().Get("fabric_session_token"))
}
func isAllowedFabricServiceVPNChannel(channel string) bool { func isAllowedFabricServiceVPNChannel(channel string) bool {
return isAllowedFabricServiceChannelForClass(FabricServiceClassVPNPackets, channel) return isAllowedFabricServiceChannelForClass(FabricServiceClassVPNPackets, channel)
} }
@@ -101,7 +101,7 @@ func TestFabricSessionWebSocketPingPongAndEvents(t *testing.T) {
defer server.Close() defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/mesh/v1/fabric/session/ws" wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/mesh/v1/fabric/session/ws"
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) conn, _, err := websocket.DefaultDialer.Dial(wsURL, fabricSessionTestHeaders("rap_fsn_smoke"))
if err != nil { if err != nil {
t.Fatalf("dial fabric session websocket: %v", err) t.Fatalf("dial fabric session websocket: %v", err)
} }
@@ -112,7 +112,7 @@ func TestFabricSessionWebSocketPingPongAndEvents(t *testing.T) {
if pong.Type != fabricproto.FramePong || pong.Sequence != 17 || string(pong.Payload) != "probe" { if pong.Type != fabricproto.FramePong || pong.Sequence != 17 || string(pong.Payload) != "probe" {
t.Fatalf("pong = %+v", pong) t.Fatalf("pong = %+v", pong)
} }
if len(events) < 2 || events[0].Event != "fabric_session_websocket_opened" || events[1].SessionEvent != fabricproto.SessionEventPing { if len(events) < 2 || events[0].Event != "fabric_session_websocket_opened" || events[0].AcceptedBy != "legacy_unsigned" || events[1].SessionEvent != fabricproto.SessionEventPing {
t.Fatalf("events = %+v", events) t.Fatalf("events = %+v", events)
} }
} }
@@ -125,7 +125,7 @@ func TestFabricSessionWebSocketOpenStreamDataAck(t *testing.T) {
defer server.Close() defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/mesh/v1/fabric/session/ws" wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/mesh/v1/fabric/session/ws"
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) conn, _, err := websocket.DefaultDialer.Dial(wsURL, fabricSessionTestHeaders("rap_fsn_smoke"))
if err != nil { if err != nil {
t.Fatalf("dial fabric session websocket: %v", err) t.Fatalf("dial fabric session websocket: %v", err)
} }
@@ -149,6 +149,89 @@ func TestFabricSessionWebSocketOpenStreamDataAck(t *testing.T) {
} }
} }
func TestFabricSessionWebSocketRequiresToken(t *testing.T) {
server := httptest.NewServer(Server{
Local: PeerIdentity{ClusterID: "cluster-1", NodeID: "node-a"},
FabricSessionEnabled: true,
}.Handler())
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/mesh/v1/fabric/session/ws"
_, resp, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err == nil {
t.Fatal("dial fabric session without token unexpectedly succeeded")
}
if resp == nil || resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("status = %v err=%v, want 401", resp, err)
}
}
func TestFabricSessionWebSocketRequiresSignedAuthorityWhenConfigured(t *testing.T) {
publicKey, _, err := ed25519.GenerateKey(nil)
if err != nil {
t.Fatalf("generate key: %v", err)
}
server := httptest.NewServer(Server{
Local: PeerIdentity{ClusterID: "cluster-1", NodeID: "node-a"},
FabricSessionEnabled: true,
ClusterAuthorityPublicKey: base64.StdEncoding.EncodeToString(publicKey),
}.Handler())
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/mesh/v1/fabric/session/ws"
_, resp, err := websocket.DefaultDialer.Dial(wsURL, fabricSessionTestHeaders("rap_fsn_unsigned"))
if err == nil {
t.Fatal("dial unsigned fabric session unexpectedly succeeded")
}
if resp == nil || resp.StatusCode != http.StatusForbidden {
t.Fatalf("status = %v err=%v, want 403", resp, err)
}
}
func TestFabricSessionWebSocketAcceptsSignedAuthority(t *testing.T) {
publicKey, privateKey, err := ed25519.GenerateKey(nil)
if err != nil {
t.Fatalf("generate key: %v", err)
}
token := "rap_fsn_signedtest"
var events []FabricSessionEventLogEntry
server := httptest.NewServer(Server{
Local: PeerIdentity{ClusterID: "cluster-1", NodeID: "node-a"},
FabricSessionEnabled: true,
ClusterAuthorityPublicKey: base64.StdEncoding.EncodeToString(publicKey),
FabricSessionLogger: func(entry FabricSessionEventLogEntry) {
events = append(events, entry)
},
}.Handler())
defer server.Close()
headers := signedFabricSessionHeaders(t, token, publicKey, privateKey, fabricSessionAuthorityPayload{
SchemaVersion: "rap.fabric_session_authority.v1",
ClusterID: "cluster-1",
SessionID: "session-1",
SourceNodeID: "phone-1",
SelectedEntryNodeID: "node-a",
TokenHash: fabricSessionTokenHash(token),
IssuedAt: time.Now().UTC().Add(-time.Minute),
ExpiresAt: time.Now().UTC().Add(time.Minute),
})
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/mesh/v1/fabric/session/ws"
conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers)
if err != nil {
t.Fatalf("dial signed fabric session websocket: %v", err)
}
defer conn.Close()
writeMeshFabricFrame(t, conn, fabricproto.Frame{Type: fabricproto.FramePing, Sequence: 23})
pong := readMeshFabricFrame(t, conn)
if pong.Type != fabricproto.FramePong || pong.Sequence != 23 {
t.Fatalf("pong = %+v", pong)
}
if len(events) < 2 || events[0].AcceptedBy != "signed" || events[0].SessionID != "session-1" {
t.Fatalf("events = %+v", events)
}
}
func TestMeshForwardingGateEnabledStillHasNoProductionRuntime(t *testing.T) { func TestMeshForwardingGateEnabledStillHasNoProductionRuntime(t *testing.T) {
local := PeerIdentity{ClusterID: "cluster-1", NodeID: "node-b"} local := PeerIdentity{ClusterID: "cluster-1", NodeID: "node-b"}
server := httptest.NewServer(Server{ server := httptest.NewServer(Server{
@@ -182,6 +265,38 @@ func writeMeshFabricFrame(t *testing.T, conn *websocket.Conn, frame fabricproto.
} }
} }
func fabricSessionTestHeaders(token string) http.Header {
headers := http.Header{}
headers.Set("X-RAP-Fabric-Session-Token", token)
return headers
}
func signedFabricSessionHeaders(t *testing.T, token string, publicKey ed25519.PublicKey, privateKey ed25519.PrivateKey, payload fabricSessionAuthorityPayload) http.Header {
t.Helper()
headers := fabricSessionTestHeaders(token)
rawPayload, err := json.Marshal(payload)
if err != nil {
t.Fatalf("marshal fabric session authority payload: %v", err)
}
canonical, err := authority.CanonicalJSON(rawPayload)
if err != nil {
t.Fatalf("canonical fabric session authority payload: %v", err)
}
signature := authority.Signature{
SchemaVersion: authority.SignatureSchemaVersion,
Algorithm: authority.AlgorithmEd25519,
KeyFingerprint: authority.Fingerprint(publicKey),
Signature: base64.StdEncoding.EncodeToString(ed25519.Sign(privateKey, canonical)),
}
rawSignature, err := json.Marshal(signature)
if err != nil {
t.Fatalf("marshal fabric session authority signature: %v", err)
}
headers.Set("X-RAP-Fabric-Session-Authority-Payload", base64.StdEncoding.EncodeToString(rawPayload))
headers.Set("X-RAP-Fabric-Session-Authority-Signature", base64.StdEncoding.EncodeToString(rawSignature))
return headers
}
func readMeshFabricFrame(t *testing.T, conn *websocket.Conn) fabricproto.Frame { func readMeshFabricFrame(t *testing.T, conn *websocket.Conn) fabricproto.Frame {
t.Helper() t.Helper()
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
@@ -258,7 +258,7 @@ Deliverables:
Status: started with a transport-neutral `io.Reader`/`io.Writer` frame loop, Status: started with a transport-neutral `io.Reader`/`io.Writer` frame loop,
WebSocket frame adapter in `agents/rap-node-agent/internal/fabricproto`, and a WebSocket frame adapter in `agents/rap-node-agent/internal/fabricproto`, and a
gated mesh smoke endpoint at `/mesh/v1/fabric/session/ws`. gated/authenticated mesh smoke endpoint at `/mesh/v1/fabric/session/ws`.
Deliverables: Deliverables: