From af85f6e309d8b53b9f075af7cf5fc57ec227f93b Mon Sep 17 00:00:00 2001 From: Mikhail Date: Fri, 15 May 2026 16:38:00 +0300 Subject: [PATCH] Infer VPN packet traffic class --- .../rap-node-agent/internal/agent/payload.go | 2 +- agents/rap-node-agent/internal/mesh/server.go | 4 +- .../internal/mesh/server_test.go | 67 +++++++++++++++++++ .../internal/mesh/vpn_packet.go | 26 +++++++ 4 files changed, 96 insertions(+), 3 deletions(-) diff --git a/agents/rap-node-agent/internal/agent/payload.go b/agents/rap-node-agent/internal/agent/payload.go index 2d6864f..628632e 100644 --- a/agents/rap-node-agent/internal/agent/payload.go +++ b/agents/rap-node-agent/internal/agent/payload.go @@ -7,7 +7,7 @@ import ( "github.com/example/remote-access-platform/agents/rap-node-agent/internal/state" ) -const Version = "0.2.272-vpnprio" +const Version = "0.2.273-vpnclass" func EnrollmentPayload(clusterID, joinToken string, identity state.Identity) client.EnrollRequest { return client.EnrollRequest{ diff --git a/agents/rap-node-agent/internal/mesh/server.go b/agents/rap-node-agent/internal/mesh/server.go index 8bd1a47..fe20730 100644 --- a/agents/rap-node-agent/internal/mesh/server.go +++ b/agents/rap-node-agent/internal/mesh/server.go @@ -833,7 +833,7 @@ func (s Server) handleVPNPacketHTTP(w http.ResponseWriter, r *http.Request, clus http.Error(w, ErrRouteNotFound.Error(), vpnIngressStatusCode(ErrRouteNotFound)) return true } - trafficClass := r.Header.Get("X-RAP-Traffic-Class") + trafficClass := inferVPNPacketTrafficClass(r.Header.Get("X-RAP-Traffic-Class"), packets) var sendErr error if classIngress, ok := s.VPNPacketIngress.(VPNPacketIngressTrafficClass); ok { sendErr = classIngress.SendClientPacketBatchWithTrafficClass(r.Context(), clusterID, vpnConnectionID, trafficClass, packets) @@ -951,7 +951,7 @@ func (s Server) readVPNPacketWebSocket(ctx context.Context, conn *websocket.Conn } continue } - sendErr := s.sendVPNPacketWebSocketBatch(ctx, clusterID, vpnConnectionID, trafficClass, packets, !backendFallbackAllowed) + sendErr := s.sendVPNPacketWebSocketBatch(ctx, clusterID, vpnConnectionID, inferVPNPacketTrafficClass(trafficClass, packets), packets, !backendFallbackAllowed) if sendErr != nil { if !backendFallbackAllowed { s.logFabricServiceChannelViolation(nil, clusterID, channelID, vpnConnectionID, backendRelayPolicy, "fabric_route_send_failed_backend_fallback_blocked", sendErr.Error()) diff --git a/agents/rap-node-agent/internal/mesh/server_test.go b/agents/rap-node-agent/internal/mesh/server_test.go index 26790c3..eb44e60 100644 --- a/agents/rap-node-agent/internal/mesh/server_test.go +++ b/agents/rap-node-agent/internal/mesh/server_test.go @@ -4523,6 +4523,55 @@ func TestFabricServiceChannelVPNPacketWebSocketPreservesTrafficClass(t *testing. } } +func TestFabricServiceChannelVPNPacketWebSocketInfersInteractiveTrafficClass(t *testing.T) { + ingress := &recordingVPNPacketIngress{ + receive: [][]byte{[]byte("reply")}, + } + server := httptest.NewServer(Server{ + Local: PeerIdentity{ClusterID: "cluster-1", NodeID: "entry-1"}, + VPNPacketIngress: ingress, + }.Handler()) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/api/v1/clusters/cluster-1/fabric/service-channels/channel-1/vpn-connections/vpn-1/packets/ws" + headers := http.Header{} + headers.Set("Authorization", "Bearer rap_fsc_testtoken") + headers.Set("X-RAP-Service-Class", FabricServiceClassVPNPackets) + headers.Set("X-RAP-Channel-Class", ProductionChannelVPNPacket) + headers.Set("X-RAP-Fabric-Channel-ID", "channel-1") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + defer conn.Close() + + packet := testVPNIPv4TCPPacket([4]byte{10, 77, 0, 2}, [4]byte{192, 168, 200, 95}, 51000, 3389, 0x02) + if err := conn.WriteMessage(websocket.BinaryMessage, encodeVPNIngressPacketBatch([][]byte{packet})); err != nil { + t.Fatalf("write packet batch: %v", err) + } + if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("set read deadline: %v", err) + } + if _, _, err := conn.ReadMessage(); err != nil { + t.Fatalf("read packet batch: %v", err) + } + + deadline := time.Now().Add(2 * time.Second) + for { + ingress.mu.Lock() + trafficClass := ingress.trafficClass + sent := append([][]byte(nil), ingress.sent...) + ingress.mu.Unlock() + if trafficClass == "interactive" && len(sent) == 1 { + break + } + if time.Now().After(deadline) { + t.Fatalf("traffic class = %q sent packets = %#v, want inferred interactive packet", trafficClass, sent) + } + time.Sleep(10 * time.Millisecond) + } +} + func TestVPNPacketIngressWebSocketFallsBackToBackendRelay(t *testing.T) { var backendBody []byte postSeen := make(chan struct{}, 1) @@ -4707,6 +4756,24 @@ func (i *recordingVPNPacketIngress) ReceiveClientPacketBatch(_ context.Context, return packets, nil } +func testVPNIPv4TCPPacket(src [4]byte, dst [4]byte, srcPort uint16, dstPort uint16, flags byte) []byte { + packet := make([]byte, 40) + packet[0] = 0x45 + packet[2] = 0 + packet[3] = 40 + packet[8] = 64 + packet[9] = 6 + copy(packet[12:16], src[:]) + copy(packet[16:20], dst[:]) + packet[20] = byte(srcPort >> 8) + packet[21] = byte(srcPort) + packet[22] = byte(dstPort >> 8) + packet[23] = byte(dstPort) + packet[32] = 0x50 + packet[33] = flags + return packet +} + func hasProductionForwardEvent(events []ProductionForwardLogEntry, event string) bool { for _, item := range events { if item.Event == event { diff --git a/agents/rap-node-agent/internal/mesh/vpn_packet.go b/agents/rap-node-agent/internal/mesh/vpn_packet.go index d19bf78..1e33162 100644 --- a/agents/rap-node-agent/internal/mesh/vpn_packet.go +++ b/agents/rap-node-agent/internal/mesh/vpn_packet.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "encoding/json" "fmt" + "strings" "time" ) @@ -119,3 +120,28 @@ func cleanProductionVPNPacketBatch(packets [][]byte) [][]byte { } return cleaned } + +func inferVPNPacketTrafficClass(explicit string, packets [][]byte) string { + explicit = strings.TrimSpace(strings.ToLower(explicit)) + if explicit != "" && explicit != "bulk" { + return explicit + } + for _, packet := range packets { + if isVPNPacketTCPControl(packet) { + return "interactive" + } + } + return explicit +} + +func isVPNPacketTCPControl(packet []byte) bool { + if len(packet) < 20 || packet[0]>>4 != 4 { + return false + } + ihl := int(packet[0]&0x0f) * 4 + if ihl < 20 || len(packet) < ihl+20 || packet[9] != 6 { + return false + } + flags := packet[ihl+13] + return flags&0x17 != 0 +}