Close VPN fabric session stream shards
This commit is contained in:
@@ -21,6 +21,10 @@ type FabricSessionFrameReceiver interface {
|
||||
Errors() <-chan error
|
||||
}
|
||||
|
||||
type FabricSessionCloser interface {
|
||||
Close() error
|
||||
}
|
||||
|
||||
type FabricSessionPacketTransport struct {
|
||||
Sender FabricSessionFrameSender
|
||||
Receiver FabricSessionFrameReceiver
|
||||
@@ -42,6 +46,8 @@ type FabricSessionPacketTransport struct {
|
||||
sendFramesByClass map[string]uint64
|
||||
sendPacketsByClass map[string]uint64
|
||||
sendFramesByStream map[uint64]uint64
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func (t *FabricSessionPacketTransport) SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error {
|
||||
@@ -166,6 +172,32 @@ func (t *FabricSessionPacketTransport) RunFrameIngress(ctx context.Context) erro
|
||||
}
|
||||
}
|
||||
|
||||
func (t *FabricSessionPacketTransport) Close() error {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
t.closeOnce.Do(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
for _, streamID := range t.allStreamIDs() {
|
||||
if t.Sender != nil {
|
||||
if err := t.Sender.Send(ctx, fabricproto.Frame{
|
||||
Type: fabricproto.FrameCloseStream,
|
||||
StreamID: streamID,
|
||||
}); err != nil && t.closeErr == nil {
|
||||
t.closeErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
if closer, ok := t.Sender.(FabricSessionCloser); ok {
|
||||
if err := closer.Close(); err != nil && t.closeErr == nil {
|
||||
t.closeErr = err
|
||||
}
|
||||
}
|
||||
})
|
||||
return t.closeErr
|
||||
}
|
||||
|
||||
func (t *FabricSessionPacketTransport) selectStreamForPackets(packets [][]byte) (uint64, string) {
|
||||
trafficClass := fabricSessionTrafficClassForPackets(t.TrafficClass, packets)
|
||||
if ids := t.streamIDsForTrafficClass(trafficClass); len(ids) > 0 {
|
||||
@@ -185,6 +217,34 @@ func (t *FabricSessionPacketTransport) selectStreamForPackets(packets [][]byte)
|
||||
return t.StreamID, trafficClass
|
||||
}
|
||||
|
||||
func (t *FabricSessionPacketTransport) allStreamIDs() []uint64 {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
seen := map[uint64]struct{}{}
|
||||
var out []uint64
|
||||
add := func(streamID uint64) {
|
||||
if streamID == 0 {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[streamID]; ok {
|
||||
return
|
||||
}
|
||||
seen[streamID] = struct{}{}
|
||||
out = append(out, streamID)
|
||||
}
|
||||
add(t.StreamID)
|
||||
for _, streamID := range t.StreamIDs {
|
||||
add(streamID)
|
||||
}
|
||||
for _, ids := range t.StreamIDsByTrafficClass {
|
||||
for _, streamID := range ids {
|
||||
add(streamID)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *FabricSessionPacketTransport) hasSendStream() bool {
|
||||
if t == nil {
|
||||
return false
|
||||
|
||||
Reference in New Issue
Block a user