Shard VPN fabric session streams
This commit is contained in:
@@ -3,6 +3,7 @@ package vpnruntime
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@@ -30,7 +31,12 @@ type FabricSessionPacketTransport struct {
|
||||
ReceiveDirection string
|
||||
TrafficClass string
|
||||
|
||||
sequence uint64
|
||||
StreamIDsByTrafficClass map[string][]uint64
|
||||
StreamIDs []uint64
|
||||
|
||||
sequence uint64
|
||||
sequenceMu sync.Mutex
|
||||
sequenceByStream map[uint64]uint64
|
||||
}
|
||||
|
||||
func (t *FabricSessionPacketTransport) SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error {
|
||||
@@ -41,19 +47,20 @@ func (t *FabricSessionPacketTransport) SendGatewayPacketBatch(ctx context.Contex
|
||||
if t == nil || t.Sender == nil {
|
||||
return mesh.ErrForwardRuntimeUnavailable
|
||||
}
|
||||
if t.StreamID == 0 || t.VPNConnectionID == "" {
|
||||
if !t.hasSendStream() || t.VPNConnectionID == "" {
|
||||
return errors.New("fabric session packet transport identity is incomplete")
|
||||
}
|
||||
direction := t.SendDirection
|
||||
if direction == "" {
|
||||
direction = FabricDirectionGatewayToClient
|
||||
}
|
||||
streamID, trafficClass := t.selectStreamForPackets(packets)
|
||||
frame, err := NewFabricVPNPacketDataFrame(FabricVPNPacketFrameInput{
|
||||
StreamID: t.StreamID,
|
||||
Sequence: atomic.AddUint64(&t.sequence, 1),
|
||||
StreamID: streamID,
|
||||
Sequence: t.nextSequence(streamID),
|
||||
VPNConnectionID: t.VPNConnectionID,
|
||||
Direction: direction,
|
||||
TrafficClass: t.TrafficClass,
|
||||
TrafficClass: trafficClass,
|
||||
Packets: packets,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -101,7 +108,7 @@ func (t *FabricSessionPacketTransport) ReceiveGatewayPacketBatch(ctx context.Con
|
||||
if !ok {
|
||||
return t.Inbox.Receive(ctx, t.VPNConnectionID, direction, 5*time.Millisecond)
|
||||
}
|
||||
if frame.Type != fabricproto.FrameData || (t.StreamID != 0 && frame.StreamID != t.StreamID) {
|
||||
if frame.Type != fabricproto.FrameData || !t.acceptsStream(frame.StreamID) {
|
||||
continue
|
||||
}
|
||||
payload, err := DecodeFabricVPNPacketDataFrame(frame)
|
||||
@@ -140,10 +147,7 @@ func (t *FabricSessionPacketTransport) RunFrameIngress(ctx context.Context) erro
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if frame.Type != fabricproto.FrameData {
|
||||
continue
|
||||
}
|
||||
if t.StreamID != 0 && frame.StreamID != t.StreamID {
|
||||
if frame.Type != fabricproto.FrameData || !t.acceptsStream(frame.StreamID) {
|
||||
continue
|
||||
}
|
||||
if err := t.Inbox.DeliverFabricSessionFrame(ctx, frame); err != nil {
|
||||
@@ -152,3 +156,95 @@ func (t *FabricSessionPacketTransport) RunFrameIngress(ctx context.Context) erro
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *FabricSessionPacketTransport) selectStreamForPackets(packets [][]byte) (uint64, string) {
|
||||
trafficClass := fabricSessionTrafficClassForPackets(t.TrafficClass, packets)
|
||||
if ids := t.streamIDsForTrafficClass(trafficClass); len(ids) > 0 {
|
||||
if len(ids) == 1 || len(packets) == 0 {
|
||||
return ids[0], trafficClass
|
||||
}
|
||||
_, shard := classifyPacketFlow(packets[0], len(ids))
|
||||
return ids[shard], trafficClass
|
||||
}
|
||||
if len(t.StreamIDs) > 0 {
|
||||
if len(t.StreamIDs) == 1 || len(packets) == 0 {
|
||||
return t.StreamIDs[0], trafficClass
|
||||
}
|
||||
_, shard := classifyPacketFlow(packets[0], len(t.StreamIDs))
|
||||
return t.StreamIDs[shard], trafficClass
|
||||
}
|
||||
return t.StreamID, trafficClass
|
||||
}
|
||||
|
||||
func (t *FabricSessionPacketTransport) hasSendStream() bool {
|
||||
if t == nil {
|
||||
return false
|
||||
}
|
||||
if t.StreamID != 0 || len(t.StreamIDs) > 0 {
|
||||
return true
|
||||
}
|
||||
for _, ids := range t.StreamIDsByTrafficClass {
|
||||
if len(ids) > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *FabricSessionPacketTransport) streamIDsForTrafficClass(trafficClass string) []uint64 {
|
||||
if t == nil || len(t.StreamIDsByTrafficClass) == 0 {
|
||||
return nil
|
||||
}
|
||||
if ids := t.StreamIDsByTrafficClass[normalizeFabricTrafficClass(trafficClass)]; len(ids) > 0 {
|
||||
return ids
|
||||
}
|
||||
if normalizeFabricTrafficClass(trafficClass) == FabricTrafficClassReliable {
|
||||
return t.StreamIDsByTrafficClass[FabricTrafficClassBulk]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *FabricSessionPacketTransport) acceptsStream(streamID uint64) bool {
|
||||
if t == nil || streamID == 0 {
|
||||
return false
|
||||
}
|
||||
if t.StreamID != 0 && streamID == t.StreamID {
|
||||
return true
|
||||
}
|
||||
for _, id := range t.StreamIDs {
|
||||
if id == streamID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, ids := range t.StreamIDsByTrafficClass {
|
||||
for _, id := range ids {
|
||||
if id == streamID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return t.StreamID == 0 && len(t.StreamIDs) == 0 && len(t.StreamIDsByTrafficClass) == 0
|
||||
}
|
||||
|
||||
func (t *FabricSessionPacketTransport) nextSequence(streamID uint64) uint64 {
|
||||
if streamID == 0 {
|
||||
return atomic.AddUint64(&t.sequence, 1)
|
||||
}
|
||||
t.sequenceMu.Lock()
|
||||
defer t.sequenceMu.Unlock()
|
||||
if t.sequenceByStream == nil {
|
||||
t.sequenceByStream = map[uint64]uint64{}
|
||||
}
|
||||
t.sequenceByStream[streamID]++
|
||||
return t.sequenceByStream[streamID]
|
||||
}
|
||||
|
||||
func fabricSessionTrafficClassForPackets(fallback string, packets [][]byte) string {
|
||||
if fallback = normalizeFabricTrafficClass(fallback); fallback != "" && fallback != FabricTrafficClassBulk {
|
||||
return fallback
|
||||
}
|
||||
if batchHasTCPControlPacket(packets) {
|
||||
return FabricTrafficClassInteractive
|
||||
}
|
||||
return FabricTrafficClassBulk
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user