Infer VPN packet traffic class

This commit is contained in:
2026-05-15 16:38:00 +03:00
parent 50db5e7a0d
commit af85f6e309
4 changed files with 96 additions and 3 deletions
@@ -7,7 +7,7 @@ import (
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/state"
)
const Version = "0.2.272-vpnprio"
const Version = "0.2.273-vpnclass"
func EnrollmentPayload(clusterID, joinToken string, identity state.Identity) client.EnrollRequest {
return client.EnrollRequest{
@@ -833,7 +833,7 @@ func (s Server) handleVPNPacketHTTP(w http.ResponseWriter, r *http.Request, clus
http.Error(w, ErrRouteNotFound.Error(), vpnIngressStatusCode(ErrRouteNotFound))
return true
}
trafficClass := r.Header.Get("X-RAP-Traffic-Class")
trafficClass := inferVPNPacketTrafficClass(r.Header.Get("X-RAP-Traffic-Class"), packets)
var sendErr error
if classIngress, ok := s.VPNPacketIngress.(VPNPacketIngressTrafficClass); ok {
sendErr = classIngress.SendClientPacketBatchWithTrafficClass(r.Context(), clusterID, vpnConnectionID, trafficClass, packets)
@@ -951,7 +951,7 @@ func (s Server) readVPNPacketWebSocket(ctx context.Context, conn *websocket.Conn
}
continue
}
sendErr := s.sendVPNPacketWebSocketBatch(ctx, clusterID, vpnConnectionID, trafficClass, packets, !backendFallbackAllowed)
sendErr := s.sendVPNPacketWebSocketBatch(ctx, clusterID, vpnConnectionID, inferVPNPacketTrafficClass(trafficClass, packets), packets, !backendFallbackAllowed)
if sendErr != nil {
if !backendFallbackAllowed {
s.logFabricServiceChannelViolation(nil, clusterID, channelID, vpnConnectionID, backendRelayPolicy, "fabric_route_send_failed_backend_fallback_blocked", sendErr.Error())
@@ -4523,6 +4523,55 @@ func TestFabricServiceChannelVPNPacketWebSocketPreservesTrafficClass(t *testing.
}
}
func TestFabricServiceChannelVPNPacketWebSocketInfersInteractiveTrafficClass(t *testing.T) {
ingress := &recordingVPNPacketIngress{
receive: [][]byte{[]byte("reply")},
}
server := httptest.NewServer(Server{
Local: PeerIdentity{ClusterID: "cluster-1", NodeID: "entry-1"},
VPNPacketIngress: ingress,
}.Handler())
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/api/v1/clusters/cluster-1/fabric/service-channels/channel-1/vpn-connections/vpn-1/packets/ws"
headers := http.Header{}
headers.Set("Authorization", "Bearer rap_fsc_testtoken")
headers.Set("X-RAP-Service-Class", FabricServiceClassVPNPackets)
headers.Set("X-RAP-Channel-Class", ProductionChannelVPNPacket)
headers.Set("X-RAP-Fabric-Channel-ID", "channel-1")
conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers)
if err != nil {
t.Fatalf("dial websocket: %v", err)
}
defer conn.Close()
packet := testVPNIPv4TCPPacket([4]byte{10, 77, 0, 2}, [4]byte{192, 168, 200, 95}, 51000, 3389, 0x02)
if err := conn.WriteMessage(websocket.BinaryMessage, encodeVPNIngressPacketBatch([][]byte{packet})); err != nil {
t.Fatalf("write packet batch: %v", err)
}
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatalf("set read deadline: %v", err)
}
if _, _, err := conn.ReadMessage(); err != nil {
t.Fatalf("read packet batch: %v", err)
}
deadline := time.Now().Add(2 * time.Second)
for {
ingress.mu.Lock()
trafficClass := ingress.trafficClass
sent := append([][]byte(nil), ingress.sent...)
ingress.mu.Unlock()
if trafficClass == "interactive" && len(sent) == 1 {
break
}
if time.Now().After(deadline) {
t.Fatalf("traffic class = %q sent packets = %#v, want inferred interactive packet", trafficClass, sent)
}
time.Sleep(10 * time.Millisecond)
}
}
func TestVPNPacketIngressWebSocketFallsBackToBackendRelay(t *testing.T) {
var backendBody []byte
postSeen := make(chan struct{}, 1)
@@ -4707,6 +4756,24 @@ func (i *recordingVPNPacketIngress) ReceiveClientPacketBatch(_ context.Context,
return packets, nil
}
func testVPNIPv4TCPPacket(src [4]byte, dst [4]byte, srcPort uint16, dstPort uint16, flags byte) []byte {
packet := make([]byte, 40)
packet[0] = 0x45
packet[2] = 0
packet[3] = 40
packet[8] = 64
packet[9] = 6
copy(packet[12:16], src[:])
copy(packet[16:20], dst[:])
packet[20] = byte(srcPort >> 8)
packet[21] = byte(srcPort)
packet[22] = byte(dstPort >> 8)
packet[23] = byte(dstPort)
packet[32] = 0x50
packet[33] = flags
return packet
}
func hasProductionForwardEvent(events []ProductionForwardLogEntry, event string) bool {
for _, item := range events {
if item.Event == event {
@@ -5,6 +5,7 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
"strings"
"time"
)
@@ -119,3 +120,28 @@ func cleanProductionVPNPacketBatch(packets [][]byte) [][]byte {
}
return cleaned
}
func inferVPNPacketTrafficClass(explicit string, packets [][]byte) string {
explicit = strings.TrimSpace(strings.ToLower(explicit))
if explicit != "" && explicit != "bulk" {
return explicit
}
for _, packet := range packets {
if isVPNPacketTCPControl(packet) {
return "interactive"
}
}
return explicit
}
func isVPNPacketTCPControl(packet []byte) bool {
if len(packet) < 20 || packet[0]>>4 != 4 {
return false
}
ihl := int(packet[0]&0x0f) * 4
if ihl < 20 || len(packet) < ihl+20 || packet[9] != 6 {
return false
}
flags := packet[ihl+13]
return flags&0x17 != 0
}