diff --git a/agents/rap-node-agent/internal/vpnruntime/fabric_session_transport.go b/agents/rap-node-agent/internal/vpnruntime/fabric_session_transport.go index 2ff995a..95c51a2 100644 --- a/agents/rap-node-agent/internal/vpnruntime/fabric_session_transport.go +++ b/agents/rap-node-agent/internal/vpnruntime/fabric_session_transport.go @@ -21,6 +21,10 @@ type FabricSessionFrameReceiver interface { Errors() <-chan error } +type FabricSessionCloser interface { + Close() error +} + type FabricSessionPacketTransport struct { Sender FabricSessionFrameSender Receiver FabricSessionFrameReceiver @@ -42,6 +46,8 @@ type FabricSessionPacketTransport struct { sendFramesByClass map[string]uint64 sendPacketsByClass map[string]uint64 sendFramesByStream map[uint64]uint64 + closeOnce sync.Once + closeErr error } func (t *FabricSessionPacketTransport) SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error { @@ -166,6 +172,32 @@ func (t *FabricSessionPacketTransport) RunFrameIngress(ctx context.Context) erro } } +func (t *FabricSessionPacketTransport) Close() error { + if t == nil { + return nil + } + t.closeOnce.Do(func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + for _, streamID := range t.allStreamIDs() { + if t.Sender != nil { + if err := t.Sender.Send(ctx, fabricproto.Frame{ + Type: fabricproto.FrameCloseStream, + StreamID: streamID, + }); err != nil && t.closeErr == nil { + t.closeErr = err + } + } + } + if closer, ok := t.Sender.(FabricSessionCloser); ok { + if err := closer.Close(); err != nil && t.closeErr == nil { + t.closeErr = err + } + } + }) + return t.closeErr +} + func (t *FabricSessionPacketTransport) selectStreamForPackets(packets [][]byte) (uint64, string) { trafficClass := fabricSessionTrafficClassForPackets(t.TrafficClass, packets) if ids := t.streamIDsForTrafficClass(trafficClass); len(ids) > 0 { @@ -185,6 +217,34 @@ func (t *FabricSessionPacketTransport) selectStreamForPackets(packets [][]byte) return t.StreamID, trafficClass } +func (t *FabricSessionPacketTransport) allStreamIDs() []uint64 { + if t == nil { + return nil + } + seen := map[uint64]struct{}{} + var out []uint64 + add := func(streamID uint64) { + if streamID == 0 { + return + } + if _, ok := seen[streamID]; ok { + return + } + seen[streamID] = struct{}{} + out = append(out, streamID) + } + add(t.StreamID) + for _, streamID := range t.StreamIDs { + add(streamID) + } + for _, ids := range t.StreamIDsByTrafficClass { + for _, streamID := range ids { + add(streamID) + } + } + return out +} + func (t *FabricSessionPacketTransport) hasSendStream() bool { if t == nil { return false diff --git a/agents/rap-node-agent/internal/vpnruntime/fabric_transport_test.go b/agents/rap-node-agent/internal/vpnruntime/fabric_transport_test.go index 9907368..aeaa7f1 100644 --- a/agents/rap-node-agent/internal/vpnruntime/fabric_transport_test.go +++ b/agents/rap-node-agent/internal/vpnruntime/fabric_transport_test.go @@ -126,6 +126,7 @@ type memoryPacketTransport struct { type captureFabricSessionSender struct { err error frames []fabricproto.Frame + closed bool } func (s *captureFabricSessionSender) Send(_ context.Context, frame fabricproto.Frame) error { @@ -136,6 +137,11 @@ func (s *captureFabricSessionSender) Send(_ context.Context, frame fabricproto.F return nil } +func (s *captureFabricSessionSender) Close() error { + s.closed = true + return nil +} + type memoryFabricSessionReceiver struct { frames chan fabricproto.Frame errors chan error @@ -278,6 +284,35 @@ func TestFabricSessionPacketTransportShardsStreamsByTrafficClass(t *testing.T) { } } +func TestFabricSessionPacketTransportClosesAllStreamShards(t *testing.T) { + sender := &captureFabricSessionSender{} + transport := &FabricSessionPacketTransport{ + Sender: sender, + StreamID: 700, + StreamIDsByTrafficClass: map[string][]uint64{ + FabricTrafficClassInteractive: []uint64{801, 802}, + FabricTrafficClassBulk: []uint64{901, 902}, + }, + } + if err := transport.Close(); err != nil { + t.Fatalf("close transport: %v", err) + } + if !sender.closed { + t.Fatal("underlying fabric session was not closed") + } + closed := map[uint64]bool{} + for _, frame := range sender.frames { + if frame.Type == fabricproto.FrameCloseStream { + closed[frame.StreamID] = true + } + } + for _, streamID := range []uint64{700, 801, 802, 901, 902} { + if !closed[streamID] { + t.Fatalf("stream %d was not closed; frames=%+v", streamID, sender.frames) + } + } +} + func TestFabricSessionPacketTransportRunFrameIngressDeliversInbox(t *testing.T) { inbox := NewFabricPacketInbox(4) receiver := memoryFabricSessionReceiver{ diff --git a/agents/rap-node-agent/internal/vpnruntime/gateway.go b/agents/rap-node-agent/internal/vpnruntime/gateway.go index 2bb3c20..35efbc5 100644 --- a/agents/rap-node-agent/internal/vpnruntime/gateway.go +++ b/agents/rap-node-agent/internal/vpnruntime/gateway.go @@ -69,6 +69,10 @@ type packetTransportSnapshotter interface { Snapshot() map[string]any } +type packetTransportCloser interface { + Close() error +} + type BackendPacketTransport struct { API *client.Client ClusterID string @@ -259,6 +263,9 @@ func (g *Gateway) normalize() error { func (g *Gateway) run(ctx context.Context, tun readWriteCloser) error { defer tun.Close() + if closer, ok := g.Transport.(packetTransportCloser); ok { + defer closer.Close() + } errCh := make(chan error, 2) go func() { errCh <- g.copyGatewayToClient(ctx, tun) }() diff --git a/docs/architecture/DISTRIBUTED_FABRIC_NODE_PROTOCOL_PLAN.md b/docs/architecture/DISTRIBUTED_FABRIC_NODE_PROTOCOL_PLAN.md index 8063669..578ee5c 100644 --- a/docs/architecture/DISTRIBUTED_FABRIC_NODE_PROTOCOL_PLAN.md +++ b/docs/architecture/DISTRIBUTED_FABRIC_NODE_PROTOCOL_PLAN.md @@ -391,6 +391,9 @@ layout and send counters by traffic class/stream id for load-test diagnosis. Those snapshots also summarize configured stream class/shard counts and active send class/stream counts, making sharding health visible without expanding per-stream maps. +Gateway shutdown now closes all VPN fabric-session stream shards and then the +underlying fabric session, preventing stale logical streams from consuming QUIC +carrier capacity after reconnects or rollout restarts. Endpoint ranking treats `capacity_limited` observations as a soft pressure penalty instead of a hard recent failure, enabling load spreading without marking the carrier unhealthy.