package vpnruntime

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"

	"github.com/example/remote-access-platform/agents/rap-node-agent/internal/fabricproto"
	"github.com/example/remote-access-platform/agents/rap-node-agent/internal/mesh"
)

// FabricSessionFrameSender writes fabric frames onto a session.
type FabricSessionFrameSender interface {
	Send(context.Context, fabricproto.Frame) error
}

// FabricSessionFrameReceiver exposes the frames and errors read from a session.
type FabricSessionFrameReceiver interface {
	Frames() <-chan fabricproto.Frame
	Errors() <-chan error
}

// FabricSessionCloser may optionally be implemented by a
// FabricSessionFrameSender; FabricSessionPacketTransport.Close calls it after
// sending close-stream frames.
type FabricSessionCloser interface {
	Close() error
}

// FabricSessionPacketTransport carries batches of VPN packets for a single VPN
// connection as fabric VPN packet data frames.
//
// Outgoing packets are routed onto streams by traffic class
// (StreamIDsByTrafficClass), then onto the class-agnostic StreamIDs, then onto
// StreamID. Incoming data frames are accepted only from configured streams (or
// from any non-zero stream when none are configured) and are either returned
// directly by ReceiveGatewayPacketBatch or queued in Inbox.
//
// NOTE(review): the exported configuration fields are read without locking and
// are presumably fixed before the transport is used — confirm with callers.
type FabricSessionPacketTransport struct {
	Sender   FabricSessionFrameSender
	Receiver FabricSessionFrameReceiver
	Inbox    *FabricPacketInbox

	// StreamID is the default stream; 0 means "not configured".
	StreamID        uint64
	VPNConnectionID string
	// SendDirection defaults to FabricDirectionGatewayToClient when empty.
	SendDirection string
	// ReceiveDirection defaults to FabricDirectionClientToGateway when empty.
	ReceiveDirection string
	// TrafficClass, when set to a non-bulk class, is used for every packet;
	// otherwise the class is derived per packet (interactive or bulk).
	TrafficClass string

	// StreamIDsByTrafficClass maps a normalized traffic class to the streams
	// packets of that class are sharded across.
	StreamIDsByTrafficClass map[string][]uint64
	// StreamIDs are class-agnostic streams packets are sharded across.
	StreamIDs []uint64

	// sequence (stream 0) and sequenceByStream are guarded by sequenceMu.
	// A mutex is used instead of 64-bit atomics because sequence is not
	// guaranteed to be 8-byte aligned on 32-bit platforms.
	sequence         uint64
	sequenceMu       sync.Mutex
	sequenceByStream map[uint64]uint64

	// Everything from here to closeErrors is guarded by statsMu.
	statsMu                sync.Mutex
	sendFramesByClass      map[string]uint64
	sendPacketsByClass     map[string]uint64
	sendFramesByStream     map[uint64]uint64
	sendPacketsByStream    map[uint64]uint64
	splitBatchCount        uint64
	lastBatchFrameCount    uint64
	maxBatchFrameCount     uint64
	receiveFramesByClass   map[string]uint64
	receivePacketsByClass  map[string]uint64
	receiveFramesByStream  map[uint64]uint64
	receivePacketsByStream map[uint64]uint64
	closeStreamFrames      uint64
	closeErrors            uint64

	closeOnce sync.Once
	closeErr  error
}

// SendGatewayPacketBatch sends packets as one or more VPN packet data frames.
//
// The batch is first normalized by cleanPacketBatch; if nothing remains the
// call is a no-op. Packets are grouped by (stream, traffic class) and each
// group is sent as one frame carrying the next sequence number for its stream.
// Sending stops at the first error; groups already sent are not rolled back.
// It returns mesh.ErrForwardRuntimeUnavailable when the transport or its
// Sender is missing.
func (t *FabricSessionPacketTransport) SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error {
	packets = cleanPacketBatch(packets)
	if len(packets) == 0 {
		return nil
	}
	if t == nil || t.Sender == nil {
		return mesh.ErrForwardRuntimeUnavailable
	}
	if !t.hasSendStream() || t.VPNConnectionID == "" {
		return errors.New("fabric session packet transport identity is incomplete")
	}
	direction := t.SendDirection
	if direction == "" {
		direction = FabricDirectionGatewayToClient
	}
	groups := t.groupPacketsByStream(packets)
	for _, group := range groups {
		frame, err := NewFabricVPNPacketDataFrame(FabricVPNPacketFrameInput{
			StreamID:        group.StreamID,
			Sequence:        t.nextSequence(group.StreamID),
			VPNConnectionID: t.VPNConnectionID,
			Direction:       direction,
			TrafficClass:    group.TrafficClass,
			Packets:         group.Packets,
		})
		if err != nil {
			return err
		}
		if err := t.Sender.Send(ctx, frame); err != nil {
			return err
		}
		t.recordSend(group.StreamID, group.TrafficClass, len(group.Packets))
	}
	t.recordBatchFanout(len(groups))
	return nil
}

// ReceiveGatewayPacketBatch returns the next batch of packets for this VPN
// connection in the receive direction (FabricDirectionClientToGateway when
// ReceiveDirection is empty).
//
// It first polls Inbox for up to 5ms. Without a Receiver it then waits on Inbox
// for timeout. Otherwise it reads frames from Receiver until a non-empty
// matching batch arrives, ctx is done, the receiver reports an error, or
// timeout (25s when <= 0) elapses, in which case it returns (nil, nil). Data
// frames addressed to another connection or direction are queued in Inbox.
func (t *FabricSessionPacketTransport) ReceiveGatewayPacketBatch(ctx context.Context, timeout time.Duration) ([][]byte, error) {
	if t == nil || t.Inbox == nil {
		return nil, mesh.ErrForwardRuntimeUnavailable
	}
	direction := t.ReceiveDirection
	if direction == "" {
		direction = FabricDirectionClientToGateway
	}
	// Packets may already be queued, e.g. by RunFrameIngress or by an earlier
	// call that read a frame for this connection/direction pair.
	if packets, err := t.Inbox.Receive(ctx, t.VPNConnectionID, direction, 5*time.Millisecond); err != nil || len(packets) > 0 {
		return packets, err
	}
	if t.Receiver == nil {
		return t.Inbox.Receive(ctx, t.VPNConnectionID, direction, timeout)
	}
	if timeout <= 0 {
		timeout = 25 * time.Second
	}
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	frames := t.Receiver.Frames()
	errorsCh := t.Receiver.Errors()
	for {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-timer.C:
			return nil, nil
		case err, ok := <-errorsCh:
			if !ok {
				// A closed error channel only disables this select case.
				errorsCh = nil
				continue
			}
			if err != nil {
				return nil, err
			}
		case frame, ok := <-frames:
			if !ok {
				// No more frames; return anything that was queued meanwhile.
				return t.Inbox.Receive(ctx, t.VPNConnectionID, direction, 5*time.Millisecond)
			}
			if frame.Type != fabricproto.FrameData || !t.acceptsStream(frame.StreamID) {
				continue
			}
			payload, err := DecodeFabricVPNPacketDataFrame(frame)
			if err != nil {
				return nil, err
			}
			if payload.VPNConnectionID != t.VPNConnectionID || payload.Direction != direction {
				if err := t.deliverDecodedFabricSessionFrame(frame, payload); err != nil {
					return nil, err
				}
				continue
			}
			// Clean before counting and skip empty batches, matching
			// deliverDecodedFabricSessionFrame.
			packets := cleanPacketBatch(payload.Packets)
			if len(packets) == 0 {
				continue
			}
			t.recordReceive(frame.StreamID, fabricSessionTrafficClassName(frame.TrafficClass), len(packets))
			return packets, nil
		}
	}
}

// RunFrameIngress decodes data frames from Receiver and queues them in Inbox
// until ctx is done (returns ctx.Err()), the frame channel closes (returns
// nil), or a receiver or decode/enqueue error occurs (returns it). Non-data
// frames and frames from unaccepted streams are ignored.
func (t *FabricSessionPacketTransport) RunFrameIngress(ctx context.Context) error {
	if t == nil || t.Receiver == nil || t.Inbox == nil {
		return mesh.ErrForwardRuntimeUnavailable
	}
	frames := t.Receiver.Frames()
	errorsCh := t.Receiver.Errors()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case err, ok := <-errorsCh:
			if !ok {
				// A closed error channel only disables this select case.
				errorsCh = nil
				continue
			}
			if err != nil {
				return err
			}
		case frame, ok := <-frames:
			if !ok {
				return nil
			}
			if frame.Type != fabricproto.FrameData || !t.acceptsStream(frame.StreamID) {
				continue
			}
			payload, err := DecodeFabricVPNPacketDataFrame(frame)
			if err != nil {
				return err
			}
			if err := t.deliverDecodedFabricSessionFrame(frame, payload); err != nil {
				return err
			}
		}
	}
}

// deliverDecodedFabricSessionFrame cleans payload's packets and, if any
// remain, records the receive and queues the payload in Inbox.
func (t *FabricSessionPacketTransport) deliverDecodedFabricSessionFrame(frame fabricproto.Frame, payload mesh.VPNPacketBatchPayload) error {
	if t == nil || t.Inbox == nil {
		return mesh.ErrForwardRuntimeUnavailable
	}
	payload.Packets = cleanPacketBatch(payload.Packets)
	if len(payload.Packets) == 0 {
		return nil
	}
	t.recordReceive(frame.StreamID, fabricSessionTrafficClassName(frame.TrafficClass), len(payload.Packets))
	return t.Inbox.enqueue(payload)
}

// Close sends a close-stream frame for every configured stream (sharing one
// 1-second deadline) and then closes Sender if it implements
// FabricSessionCloser. It runs at most once; every call returns the first
// error encountered during that single run.
func (t *FabricSessionPacketTransport) Close() error {
	if t == nil {
		return nil
	}
	t.closeOnce.Do(func() {
		keepFirst := func(err error) {
			t.recordCloseError()
			if t.closeErr == nil {
				t.closeErr = err
			}
		}
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		defer cancel()
		if t.Sender != nil {
			for _, streamID := range t.allStreamIDs() {
				err := t.Sender.Send(ctx, fabricproto.Frame{
					Type:     fabricproto.FrameCloseStream,
					StreamID: streamID,
				})
				if err != nil {
					keepFirst(err)
					continue
				}
				t.recordCloseStream()
			}
		}
		// A nil Sender yields ok == false here.
		if closer, ok := t.Sender.(FabricSessionCloser); ok {
			if err := closer.Close(); err != nil {
				keepFirst(err)
			}
		}
	})
	return t.closeErr
}

// selectStreamForPackets picks the stream and traffic class for packets.
//
// Candidate streams are tried in order: the class-specific streams (reliable
// falls back to bulk), the class-agnostic StreamIDs, then StreamID. When none
// of those apply but class-specific streams exist for other classes, the bulk
// streams are used, or else the non-empty class whose name sorts first. When a
// candidate list has several streams, classifyPacketFlow on the first packet
// picks the shard. The returned class is always the packets' own class, even
// when the stream was borrowed from another class.
func (t *FabricSessionPacketTransport) selectStreamForPackets(packets [][]byte) (uint64, string) {
	trafficClass := fabricSessionTrafficClassForPackets(t.TrafficClass, packets)
	shard := func(ids []uint64) uint64 {
		if len(ids) == 1 || len(packets) == 0 {
			return ids[0]
		}
		_, index := classifyPacketFlow(packets[0], len(ids))
		return ids[index]
	}
	if ids := t.streamIDsForTrafficClass(trafficClass); len(ids) > 0 {
		return shard(ids), trafficClass
	}
	if len(t.StreamIDs) > 0 {
		return shard(t.StreamIDs), trafficClass
	}
	if t.StreamID != 0 {
		return t.StreamID, trafficClass
	}
	// Stream 0 is "no stream" to allStreamIDs and acceptsStream, so borrow a
	// configured class stream instead of sending on it.
	fallback := t.StreamIDsByTrafficClass[FabricTrafficClassBulk]
	if len(fallback) == 0 {
		fallbackClass := ""
		for class, ids := range t.StreamIDsByTrafficClass {
			if len(ids) > 0 && (len(fallback) == 0 || class < fallbackClass) {
				fallback, fallbackClass = ids, class
			}
		}
	}
	if len(fallback) > 0 {
		return shard(fallback), trafficClass
	}
	return t.StreamID, trafficClass
}

// fabricSessionPacketGroup is the set of packets sent as one frame.
type fabricSessionPacketGroup struct {
	StreamID     uint64
	TrafficClass string
	Packets      [][]byte
}

// groupPacketsByStream routes each packet individually and groups packets that
// share a (stream, traffic class) pair, preserving first-seen group order and
// packet order within a group.
func (t *FabricSessionPacketTransport) groupPacketsByStream(packets [][]byte) []fabricSessionPacketGroup {
	type groupKey struct {
		streamID     uint64
		trafficClass string
	}
	groups := []fabricSessionPacketGroup{}
	indexByKey := make(map[groupKey]int)
	for i, packet := range packets {
		// A capped one-element subslice avoids allocating a new [][]byte per packet.
		streamID, trafficClass := t.selectStreamForPackets(packets[i : i+1 : i+1])
		key := groupKey{streamID: streamID, trafficClass: trafficClass}
		index, ok := indexByKey[key]
		if !ok {
			index = len(groups)
			indexByKey[key] = index
			groups = append(groups, fabricSessionPacketGroup{
				StreamID:     streamID,
				TrafficClass: trafficClass,
			})
		}
		groups[index].Packets = append(groups[index].Packets, packet)
	}
	return groups
}

// allStreamIDs returns every configured non-zero stream ID once: StreamID,
// then StreamIDs, then the class streams. The order among class streams
// follows map iteration and is therefore not deterministic.
func (t *FabricSessionPacketTransport) allStreamIDs() []uint64 {
	if t == nil {
		return nil
	}
	seen := map[uint64]struct{}{}
	var out []uint64
	add := func(streamID uint64) {
		if streamID == 0 {
			return
		}
		if _, ok := seen[streamID]; ok {
			return
		}
		seen[streamID] = struct{}{}
		out = append(out, streamID)
	}
	add(t.StreamID)
	for _, streamID := range t.StreamIDs {
		add(streamID)
	}
	for _, ids := range t.StreamIDsByTrafficClass {
		for _, streamID := range ids {
			add(streamID)
		}
	}
	return out
}

// hasSendStream reports whether any send stream is configured.
func (t *FabricSessionPacketTransport) hasSendStream() bool {
	if t == nil {
		return false
	}
	if t.StreamID != 0 || len(t.StreamIDs) > 0 {
		return true
	}
	for _, ids := range t.StreamIDsByTrafficClass {
		if len(ids) > 0 {
			return true
		}
	}
	return false
}

// streamIDsForTrafficClass returns the streams configured for the normalized
// trafficClass; reliable traffic falls back to the bulk streams.
func (t *FabricSessionPacketTransport) streamIDsForTrafficClass(trafficClass string) []uint64 {
	if t == nil || len(t.StreamIDsByTrafficClass) == 0 {
		return nil
	}
	normalized := normalizeFabricTrafficClass(trafficClass)
	if ids := t.StreamIDsByTrafficClass[normalized]; len(ids) > 0 {
		return ids
	}
	if normalized == FabricTrafficClassReliable {
		return t.StreamIDsByTrafficClass[FabricTrafficClassBulk]
	}
	return nil
}

// acceptsStream reports whether incoming frames on streamID belong to this
// transport. Stream 0 is never accepted; any other stream is accepted when no
// streams are configured at all.
func (t *FabricSessionPacketTransport) acceptsStream(streamID uint64) bool {
	if t == nil || streamID == 0 {
		return false
	}
	if t.StreamID != 0 && streamID == t.StreamID {
		return true
	}
	for _, id := range t.StreamIDs {
		if id == streamID {
			return true
		}
	}
	for _, ids := range t.StreamIDsByTrafficClass {
		for _, id := range ids {
			if id == streamID {
				return true
			}
		}
	}
	return t.StreamID == 0 && len(t.StreamIDs) == 0 && len(t.StreamIDsByTrafficClass) == 0
}

// nextSequence returns the next frame sequence number for streamID, starting
// at 1; stream 0 uses the transport-wide counter.
func (t *FabricSessionPacketTransport) nextSequence(streamID uint64) uint64 {
	t.sequenceMu.Lock()
	defer t.sequenceMu.Unlock()
	if streamID == 0 {
		t.sequence++
		return t.sequence
	}
	if t.sequenceByStream == nil {
		t.sequenceByStream = map[uint64]uint64{}
	}
	t.sequenceByStream[streamID]++
	return t.sequenceByStream[streamID]
}

// recordSend counts one sent frame of packetCount packets per class and stream.
func (t *FabricSessionPacketTransport) recordSend(streamID uint64, trafficClass string, packetCount int) {
	if t == nil {
		return
	}
	trafficClass = normalizeFabricTrafficClass(trafficClass)
	t.statsMu.Lock()
	defer t.statsMu.Unlock()
	if t.sendFramesByClass == nil {
		t.sendFramesByClass = map[string]uint64{}
	}
	if t.sendPacketsByClass == nil {
		t.sendPacketsByClass = map[string]uint64{}
	}
	if t.sendFramesByStream == nil {
		t.sendFramesByStream = map[uint64]uint64{}
	}
	if t.sendPacketsByStream == nil {
		t.sendPacketsByStream = map[uint64]uint64{}
	}
	t.sendFramesByClass[trafficClass]++
	t.sendPacketsByClass[trafficClass] += uint64(packetCount)
	t.sendFramesByStream[streamID]++
	t.sendPacketsByStream[streamID] += uint64(packetCount)
}

// recordBatchFanout records how many frames one send batch produced.
func (t *FabricSessionPacketTransport) recordBatchFanout(frameCount int) {
	if t == nil || frameCount <= 0 {
		return
	}
	t.statsMu.Lock()
	defer t.statsMu.Unlock()
	t.lastBatchFrameCount = uint64(frameCount)
	if frameCount > 1 {
		t.splitBatchCount++
	}
	if uint64(frameCount) > t.maxBatchFrameCount {
		t.maxBatchFrameCount = uint64(frameCount)
	}
}

// recordReceive counts one received frame of packetCount packets per class and stream.
func (t *FabricSessionPacketTransport) recordReceive(streamID uint64, trafficClass string, packetCount int) {
	if t == nil {
		return
	}
	trafficClass = normalizeFabricTrafficClass(trafficClass)
	t.statsMu.Lock()
	defer t.statsMu.Unlock()
	if t.receiveFramesByClass == nil {
		t.receiveFramesByClass = map[string]uint64{}
	}
	if t.receivePacketsByClass == nil {
		t.receivePacketsByClass = map[string]uint64{}
	}
	if t.receiveFramesByStream == nil {
		t.receiveFramesByStream = map[uint64]uint64{}
	}
	if t.receivePacketsByStream == nil {
		t.receivePacketsByStream = map[uint64]uint64{}
	}
	t.receiveFramesByClass[trafficClass]++
	t.receivePacketsByClass[trafficClass] += uint64(packetCount)
	t.receiveFramesByStream[streamID]++
	t.receivePacketsByStream[streamID] += uint64(packetCount)
}

// Snapshot returns a copy of the transport's configuration and counters as a
// map keyed by stable snake_case names. Per-stream maps use the decimal stream
// ID as key. It returns nil for a nil transport.
func (t *FabricSessionPacketTransport) Snapshot() map[string]any {
	if t == nil {
		return nil
	}
	byStreamID := func(values map[uint64]uint64) map[string]uint64 {
		out := make(map[string]uint64, len(values))
		for streamID, count := range values {
			out[fmt.Sprintf("%d", streamID)] = count
		}
		return out
	}

	t.statsMu.Lock()
	sendFramesByClass := copyStringUint64Map(t.sendFramesByClass)
	sendPacketsByClass := copyStringUint64Map(t.sendPacketsByClass)
	receiveFramesByClass := copyStringUint64Map(t.receiveFramesByClass)
	receivePacketsByClass := copyStringUint64Map(t.receivePacketsByClass)
	lastBatchFrameCount := t.lastBatchFrameCount
	maxBatchFrameCount := t.maxBatchFrameCount
	splitBatchCount := t.splitBatchCount
	closeStreamFrames := t.closeStreamFrames
	closeErrors := t.closeErrors
	sendFramesByStream := byStreamID(t.sendFramesByStream)
	sendPacketsByStream := byStreamID(t.sendPacketsByStream)
	receiveFramesByStream := byStreamID(t.receiveFramesByStream)
	receivePacketsByStream := byStreamID(t.receivePacketsByStream)
	t.statsMu.Unlock()

	streamIDsByClass := copyStreamIDsByTrafficClass(t.StreamIDsByTrafficClass)
	shardCount := countStreamIDs(streamIDsByClass) + len(t.StreamIDs)
	return map[string]any{
		"schema_version":               "rap.vpn_fabric_session_packet_transport.v1",
		"stream_id":                    t.StreamID,
		"stream_ids_by_class":          streamIDsByClass,
		"stream_class_count":           len(streamIDsByClass),
		"stream_shard_count":           shardCount,
		"send_class_count":             countNonZeroStringUint64Values(sendFramesByClass),
		"send_stream_count":            countNonZeroStringUint64Values(sendFramesByStream),
		"sharding_active":              len(streamIDsByClass) > 1 || shardCount > 1,
		"split_batch_count":            splitBatchCount,
		"last_batch_frame_count":       lastBatchFrameCount,
		"max_batch_frame_count":        maxBatchFrameCount,
		"close_stream_frames":          closeStreamFrames,
		"close_errors":                 closeErrors,
		"send_frames_by_class":         sendFramesByClass,
		"send_packets_by_class":        sendPacketsByClass,
		"send_frames_by_stream_id":     sendFramesByStream,
		"send_packets_by_stream_id":    sendPacketsByStream,
		"receive_frames_by_class":      receiveFramesByClass,
		"receive_packets_by_class":     receivePacketsByClass,
		"receive_frames_by_stream_id":  receiveFramesByStream,
		"receive_packets_by_stream_id": receivePacketsByStream,
	}
}

// recordCloseStream counts one successfully sent close-stream frame.
func (t *FabricSessionPacketTransport) recordCloseStream() {
	if t == nil {
		return
	}
	t.statsMu.Lock()
	t.closeStreamFrames++
	t.statsMu.Unlock()
}

// recordCloseError counts one failure during Close.
func (t *FabricSessionPacketTransport) recordCloseError() {
	if t == nil {
		return
	}
	t.statsMu.Lock()
	t.closeErrors++
	t.statsMu.Unlock()
}

// fabricSessionTrafficClassForPackets returns the normalized fallback class
// when it is set and not bulk; otherwise interactive if the batch contains a
// TCP control packet, else bulk.
func fabricSessionTrafficClassForPackets(fallback string, packets [][]byte) string {
	if fallback = normalizeFabricTrafficClass(fallback); fallback != "" && fallback != FabricTrafficClassBulk {
		return fallback
	}
	if batchHasTCPControlPacket(packets) {
		return FabricTrafficClassInteractive
	}
	return FabricTrafficClassBulk
}

// fabricSessionTrafficClassName maps a wire traffic class to its string name;
// unknown values map to bulk.
func fabricSessionTrafficClassName(value fabricproto.TrafficClass) string {
	switch value {
	case fabricproto.TrafficClassControl:
		return FabricTrafficClassControl
	case fabricproto.TrafficClassInteractive:
		return FabricTrafficClassInteractive
	case fabricproto.TrafficClassReliable:
		return FabricTrafficClassReliable
	case fabricproto.TrafficClassDroppable:
		return FabricTrafficClassDroppable
	default:
		return FabricTrafficClassBulk
	}
}

// copyStringUint64Map returns a shallow copy of values (never nil).
func copyStringUint64Map(values map[string]uint64) map[string]uint64 {
	if len(values) == 0 {
		return map[string]uint64{}
	}
	out := make(map[string]uint64, len(values))
	for key, value := range values {
		out[key] = value
	}
	return out
}

// copyStreamIDsByTrafficClass returns a deep copy of values (never nil).
func copyStreamIDsByTrafficClass(values map[string][]uint64) map[string][]uint64 {
	if len(values) == 0 {
		return map[string][]uint64{}
	}
	out := make(map[string][]uint64, len(values))
	for key, ids := range values {
		out[key] = append([]uint64(nil), ids...)
	}
	return out
}

// countStreamIDs returns the total number of stream IDs across all classes.
func countStreamIDs(values map[string][]uint64) int {
	total := 0
	for _, ids := range values {
		total += len(ids)
	}
	return total
}

// countNonZeroStringUint64Values returns how many entries of values are non-zero.
func countNonZeroStringUint64Values(values map[string]uint64) int {
	total := 0
	for _, value := range values {
		if value > 0 {
			total++
		}
	}
	return total
}