Shard VPN fabric session streams

This commit is contained in:
2026-05-16 12:13:59 +03:00
parent 9a170c83c2
commit e50070c005
12 changed files with 242 additions and 36 deletions
@@ -3,6 +3,7 @@ package vpnruntime
import (
"context"
"errors"
"sync"
"sync/atomic"
"time"
@@ -30,7 +31,12 @@ type FabricSessionPacketTransport struct {
ReceiveDirection string
TrafficClass string
sequence uint64
StreamIDsByTrafficClass map[string][]uint64
StreamIDs []uint64
sequence uint64
sequenceMu sync.Mutex
sequenceByStream map[uint64]uint64
}
func (t *FabricSessionPacketTransport) SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error {
@@ -41,19 +47,20 @@ func (t *FabricSessionPacketTransport) SendGatewayPacketBatch(ctx context.Contex
if t == nil || t.Sender == nil {
return mesh.ErrForwardRuntimeUnavailable
}
if t.StreamID == 0 || t.VPNConnectionID == "" {
if !t.hasSendStream() || t.VPNConnectionID == "" {
return errors.New("fabric session packet transport identity is incomplete")
}
direction := t.SendDirection
if direction == "" {
direction = FabricDirectionGatewayToClient
}
streamID, trafficClass := t.selectStreamForPackets(packets)
frame, err := NewFabricVPNPacketDataFrame(FabricVPNPacketFrameInput{
StreamID: t.StreamID,
Sequence: atomic.AddUint64(&t.sequence, 1),
StreamID: streamID,
Sequence: t.nextSequence(streamID),
VPNConnectionID: t.VPNConnectionID,
Direction: direction,
TrafficClass: t.TrafficClass,
TrafficClass: trafficClass,
Packets: packets,
})
if err != nil {
@@ -101,7 +108,7 @@ func (t *FabricSessionPacketTransport) ReceiveGatewayPacketBatch(ctx context.Con
if !ok {
return t.Inbox.Receive(ctx, t.VPNConnectionID, direction, 5*time.Millisecond)
}
if frame.Type != fabricproto.FrameData || (t.StreamID != 0 && frame.StreamID != t.StreamID) {
if frame.Type != fabricproto.FrameData || !t.acceptsStream(frame.StreamID) {
continue
}
payload, err := DecodeFabricVPNPacketDataFrame(frame)
@@ -140,10 +147,7 @@ func (t *FabricSessionPacketTransport) RunFrameIngress(ctx context.Context) erro
if !ok {
return nil
}
if frame.Type != fabricproto.FrameData {
continue
}
if t.StreamID != 0 && frame.StreamID != t.StreamID {
if frame.Type != fabricproto.FrameData || !t.acceptsStream(frame.StreamID) {
continue
}
if err := t.Inbox.DeliverFabricSessionFrame(ctx, frame); err != nil {
@@ -152,3 +156,95 @@ func (t *FabricSessionPacketTransport) RunFrameIngress(ctx context.Context) erro
}
}
}
func (t *FabricSessionPacketTransport) selectStreamForPackets(packets [][]byte) (uint64, string) {
trafficClass := fabricSessionTrafficClassForPackets(t.TrafficClass, packets)
if ids := t.streamIDsForTrafficClass(trafficClass); len(ids) > 0 {
if len(ids) == 1 || len(packets) == 0 {
return ids[0], trafficClass
}
_, shard := classifyPacketFlow(packets[0], len(ids))
return ids[shard], trafficClass
}
if len(t.StreamIDs) > 0 {
if len(t.StreamIDs) == 1 || len(packets) == 0 {
return t.StreamIDs[0], trafficClass
}
_, shard := classifyPacketFlow(packets[0], len(t.StreamIDs))
return t.StreamIDs[shard], trafficClass
}
return t.StreamID, trafficClass
}
func (t *FabricSessionPacketTransport) hasSendStream() bool {
if t == nil {
return false
}
if t.StreamID != 0 || len(t.StreamIDs) > 0 {
return true
}
for _, ids := range t.StreamIDsByTrafficClass {
if len(ids) > 0 {
return true
}
}
return false
}
func (t *FabricSessionPacketTransport) streamIDsForTrafficClass(trafficClass string) []uint64 {
if t == nil || len(t.StreamIDsByTrafficClass) == 0 {
return nil
}
if ids := t.StreamIDsByTrafficClass[normalizeFabricTrafficClass(trafficClass)]; len(ids) > 0 {
return ids
}
if normalizeFabricTrafficClass(trafficClass) == FabricTrafficClassReliable {
return t.StreamIDsByTrafficClass[FabricTrafficClassBulk]
}
return nil
}
func (t *FabricSessionPacketTransport) acceptsStream(streamID uint64) bool {
if t == nil || streamID == 0 {
return false
}
if t.StreamID != 0 && streamID == t.StreamID {
return true
}
for _, id := range t.StreamIDs {
if id == streamID {
return true
}
}
for _, ids := range t.StreamIDsByTrafficClass {
for _, id := range ids {
if id == streamID {
return true
}
}
}
return t.StreamID == 0 && len(t.StreamIDs) == 0 && len(t.StreamIDsByTrafficClass) == 0
}
func (t *FabricSessionPacketTransport) nextSequence(streamID uint64) uint64 {
if streamID == 0 {
return atomic.AddUint64(&t.sequence, 1)
}
t.sequenceMu.Lock()
defer t.sequenceMu.Unlock()
if t.sequenceByStream == nil {
t.sequenceByStream = map[uint64]uint64{}
}
t.sequenceByStream[streamID]++
return t.sequenceByStream[streamID]
}
func fabricSessionTrafficClassForPackets(fallback string, packets [][]byte) string {
if fallback = normalizeFabricTrafficClass(fallback); fallback != "" && fallback != FabricTrafficClassBulk {
return fallback
}
if batchHasTCPControlPacket(packets) {
return FabricTrafficClassInteractive
}
return FabricTrafficClassBulk
}
@@ -227,6 +227,42 @@ func TestFabricSessionPacketTransportSendsDataFrame(t *testing.T) {
}
}
func TestFabricSessionPacketTransportShardsStreamsByTrafficClass(t *testing.T) {
sender := &captureFabricSessionSender{}
transport := &FabricSessionPacketTransport{
Sender: sender,
StreamID: 700,
VPNConnectionID: "vpn-1",
SendDirection: FabricDirectionClientToGateway,
StreamIDsByTrafficClass: map[string][]uint64{
FabricTrafficClassInteractive: []uint64{801, 802},
FabricTrafficClassBulk: []uint64{901, 902},
},
}
bulkPacket := testIPv4TCPPacket([4]byte{10, 77, 0, 2}, [4]byte{192, 168, 200, 95}, 51000, 443)
controlPacket := testIPv4TCPPacket([4]byte{10, 77, 0, 2}, [4]byte{192, 168, 200, 95}, 51001, 3389)
controlPacket[33] = 0x02
if err := transport.SendGatewayPacketBatch(context.Background(), [][]byte{bulkPacket}); err != nil {
t.Fatalf("send bulk packet: %v", err)
}
if err := transport.SendGatewayPacketBatch(context.Background(), [][]byte{controlPacket}); err != nil {
t.Fatalf("send control packet: %v", err)
}
if len(sender.frames) != 2 {
t.Fatalf("sent frames = %d, want 2", len(sender.frames))
}
if sender.frames[0].TrafficClass != fabricproto.TrafficClassBulk || sender.frames[0].StreamID < 901 || sender.frames[0].StreamID > 902 {
t.Fatalf("bulk frame did not use bulk shard: %+v", sender.frames[0])
}
if sender.frames[1].TrafficClass != fabricproto.TrafficClassInteractive || sender.frames[1].StreamID < 801 || sender.frames[1].StreamID > 802 {
t.Fatalf("control frame did not use interactive shard: %+v", sender.frames[1])
}
if sender.frames[0].Sequence != 1 || sender.frames[1].Sequence != 1 {
t.Fatalf("per-stream sequences = %d/%d, want 1/1", sender.frames[0].Sequence, sender.frames[1].Sequence)
}
}
func TestFabricSessionPacketTransportRunFrameIngressDeliversInbox(t *testing.T) {
inbox := NewFabricPacketInbox(4)
receiver := memoryFabricSessionReceiver{