Close VPN fabric session stream shards

This commit is contained in:
2026-05-16 12:26:51 +03:00
parent a5b91113bf
commit d170820445
4 changed files with 105 additions and 0 deletions
@@ -21,6 +21,10 @@ type FabricSessionFrameReceiver interface {
Errors() <-chan error
}
type FabricSessionCloser interface {
Close() error
}
type FabricSessionPacketTransport struct {
Sender FabricSessionFrameSender
Receiver FabricSessionFrameReceiver
@@ -42,6 +46,8 @@ type FabricSessionPacketTransport struct {
sendFramesByClass map[string]uint64
sendPacketsByClass map[string]uint64
sendFramesByStream map[uint64]uint64
closeOnce sync.Once
closeErr error
}
func (t *FabricSessionPacketTransport) SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error {
@@ -166,6 +172,32 @@ func (t *FabricSessionPacketTransport) RunFrameIngress(ctx context.Context) erro
}
}
func (t *FabricSessionPacketTransport) Close() error {
if t == nil {
return nil
}
t.closeOnce.Do(func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
for _, streamID := range t.allStreamIDs() {
if t.Sender != nil {
if err := t.Sender.Send(ctx, fabricproto.Frame{
Type: fabricproto.FrameCloseStream,
StreamID: streamID,
}); err != nil && t.closeErr == nil {
t.closeErr = err
}
}
}
if closer, ok := t.Sender.(FabricSessionCloser); ok {
if err := closer.Close(); err != nil && t.closeErr == nil {
t.closeErr = err
}
}
})
return t.closeErr
}
func (t *FabricSessionPacketTransport) selectStreamForPackets(packets [][]byte) (uint64, string) {
trafficClass := fabricSessionTrafficClassForPackets(t.TrafficClass, packets)
if ids := t.streamIDsForTrafficClass(trafficClass); len(ids) > 0 {
@@ -185,6 +217,34 @@ func (t *FabricSessionPacketTransport) selectStreamForPackets(packets [][]byte)
return t.StreamID, trafficClass
}
func (t *FabricSessionPacketTransport) allStreamIDs() []uint64 {
if t == nil {
return nil
}
seen := map[uint64]struct{}{}
var out []uint64
add := func(streamID uint64) {
if streamID == 0 {
return
}
if _, ok := seen[streamID]; ok {
return
}
seen[streamID] = struct{}{}
out = append(out, streamID)
}
add(t.StreamID)
for _, streamID := range t.StreamIDs {
add(streamID)
}
for _, ids := range t.StreamIDsByTrafficClass {
for _, streamID := range ids {
add(streamID)
}
}
return out
}
func (t *FabricSessionPacketTransport) hasSendStream() bool {
if t == nil {
return false