Split VPN fabric batches by stream

This commit is contained in:
2026-05-16 12:33:27 +03:00
parent bbd9f8c257
commit bd70ca6342
3 changed files with 83 additions and 15 deletions
@@ -67,22 +67,24 @@ func (t *FabricSessionPacketTransport) SendGatewayPacketBatch(ctx context.Contex
if direction == "" {
direction = FabricDirectionGatewayToClient
}
streamID, trafficClass := t.selectStreamForPackets(packets)
frame, err := NewFabricVPNPacketDataFrame(FabricVPNPacketFrameInput{
StreamID: streamID,
Sequence: t.nextSequence(streamID),
VPNConnectionID: t.VPNConnectionID,
Direction: direction,
TrafficClass: trafficClass,
Packets: packets,
})
if err != nil {
return err
groups := t.groupPacketsByStream(packets)
for _, group := range groups {
frame, err := NewFabricVPNPacketDataFrame(FabricVPNPacketFrameInput{
StreamID: group.StreamID,
Sequence: t.nextSequence(group.StreamID),
VPNConnectionID: t.VPNConnectionID,
Direction: direction,
TrafficClass: group.TrafficClass,
Packets: group.Packets,
})
if err != nil {
return err
}
if err := t.Sender.Send(ctx, frame); err != nil {
return err
}
t.recordSend(group.StreamID, group.TrafficClass, len(group.Packets))
}
if err := t.Sender.Send(ctx, frame); err != nil {
return err
}
t.recordSend(streamID, trafficClass, len(packets))
return nil
}
@@ -227,6 +229,32 @@ func (t *FabricSessionPacketTransport) selectStreamForPackets(packets [][]byte)
return t.StreamID, trafficClass
}
type fabricSessionPacketGroup struct {
StreamID uint64
TrafficClass string
Packets [][]byte
}
func (t *FabricSessionPacketTransport) groupPacketsByStream(packets [][]byte) []fabricSessionPacketGroup {
groups := []fabricSessionPacketGroup{}
indexByKey := map[string]int{}
for _, packet := range packets {
streamID, trafficClass := t.selectStreamForPackets([][]byte{packet})
key := fmt.Sprintf("%d\x00%s", streamID, trafficClass)
index, ok := indexByKey[key]
if !ok {
index = len(groups)
indexByKey[key] = index
groups = append(groups, fabricSessionPacketGroup{
StreamID: streamID,
TrafficClass: trafficClass,
})
}
groups[index].Packets = append(groups[index].Packets, packet)
}
return groups
}
func (t *FabricSessionPacketTransport) allStreamIDs() []uint64 {
if t == nil {
return nil
@@ -284,6 +284,43 @@ func TestFabricSessionPacketTransportShardsStreamsByTrafficClass(t *testing.T) {
}
}
func TestFabricSessionPacketTransportSplitsMixedBatchByStream(t *testing.T) {
sender := &captureFabricSessionSender{}
transport := &FabricSessionPacketTransport{
Sender: sender,
VPNConnectionID: "vpn-1",
SendDirection: FabricDirectionClientToGateway,
StreamIDsByTrafficClass: map[string][]uint64{
FabricTrafficClassInteractive: []uint64{801},
FabricTrafficClassBulk: []uint64{901, 902},
},
}
bulkA := testIPv4TCPPacket([4]byte{10, 77, 0, 2}, [4]byte{192, 168, 200, 95}, 51000, 443)
bulkB := packetWithDifferentShard(bulkA, 2)
control := testIPv4TCPPacket([4]byte{10, 77, 0, 2}, [4]byte{192, 168, 200, 95}, 51001, 3389)
control[33] = 0x02
if err := transport.SendGatewayPacketBatch(context.Background(), [][]byte{bulkA, bulkB, control}); err != nil {
t.Fatalf("send mixed batch: %v", err)
}
if len(sender.frames) != 3 {
t.Fatalf("sent frames = %d, want 3: %+v", len(sender.frames), sender.frames)
}
streams := map[uint64]fabricproto.TrafficClass{}
for _, frame := range sender.frames {
streams[frame.StreamID] = frame.TrafficClass
}
if streams[801] != fabricproto.TrafficClassInteractive ||
streams[901] != fabricproto.TrafficClassBulk ||
streams[902] != fabricproto.TrafficClassBulk {
t.Fatalf("unexpected stream/class split: %+v", sender.frames)
}
snapshot := transport.Snapshot()
if snapshot["send_stream_count"] != 3 || snapshot["send_class_count"] != 2 {
t.Fatalf("unexpected mixed-batch shard summary: %+v", snapshot)
}
}
func TestFabricSessionPacketTransportClosesAllStreamShards(t *testing.T) {
sender := &captureFabricSessionSender{}
transport := &FabricSessionPacketTransport{