579 lines
16 KiB
Go
579 lines
16 KiB
Go
package vpnruntime
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/fabricproto"
|
|
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/mesh"
|
|
)
|
|
|
|
type FabricSessionFrameSender interface {
|
|
Send(context.Context, fabricproto.Frame) error
|
|
}
|
|
|
|
type FabricSessionFrameReceiver interface {
|
|
Frames() <-chan fabricproto.Frame
|
|
Errors() <-chan error
|
|
}
|
|
|
|
// FabricSessionCloser is an optional capability of a frame sender. When the
// transport's Sender implements it, FabricSessionPacketTransport.Close calls
// Close after the per-stream close frames have been sent.
type FabricSessionCloser interface {
	// Close releases the sender; a non-nil error is reported by the transport.
	Close() error
}
|
|
|
|
type FabricSessionPacketTransport struct {
|
|
Sender FabricSessionFrameSender
|
|
Receiver FabricSessionFrameReceiver
|
|
Inbox *FabricPacketInbox
|
|
|
|
StreamID uint64
|
|
VPNConnectionID string
|
|
SendDirection string
|
|
ReceiveDirection string
|
|
TrafficClass string
|
|
|
|
StreamIDsByTrafficClass map[string][]uint64
|
|
StreamIDs []uint64
|
|
|
|
sequence uint64
|
|
sequenceMu sync.Mutex
|
|
sequenceByStream map[uint64]uint64
|
|
statsMu sync.Mutex
|
|
sendFramesByClass map[string]uint64
|
|
sendPacketsByClass map[string]uint64
|
|
sendFramesByStream map[uint64]uint64
|
|
sendPacketsByStream map[uint64]uint64
|
|
splitBatchCount uint64
|
|
lastBatchFrameCount uint64
|
|
maxBatchFrameCount uint64
|
|
receiveFramesByClass map[string]uint64
|
|
receivePacketsByClass map[string]uint64
|
|
receiveFramesByStream map[uint64]uint64
|
|
receivePacketsByStream map[uint64]uint64
|
|
closeStreamFrames uint64
|
|
closeErrors uint64
|
|
closeOnce sync.Once
|
|
closeErr error
|
|
}
|
|
|
|
func (t *FabricSessionPacketTransport) SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error {
|
|
packets = cleanPacketBatch(packets)
|
|
if len(packets) == 0 {
|
|
return nil
|
|
}
|
|
if t == nil || t.Sender == nil {
|
|
return mesh.ErrForwardRuntimeUnavailable
|
|
}
|
|
if !t.hasSendStream() || t.VPNConnectionID == "" {
|
|
return errors.New("fabric session packet transport identity is incomplete")
|
|
}
|
|
direction := t.SendDirection
|
|
if direction == "" {
|
|
direction = FabricDirectionGatewayToClient
|
|
}
|
|
groups := t.groupPacketsByStream(packets)
|
|
for _, group := range groups {
|
|
frame, err := NewFabricVPNPacketDataFrame(FabricVPNPacketFrameInput{
|
|
StreamID: group.StreamID,
|
|
Sequence: t.nextSequence(group.StreamID),
|
|
VPNConnectionID: t.VPNConnectionID,
|
|
Direction: direction,
|
|
TrafficClass: group.TrafficClass,
|
|
Packets: group.Packets,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := t.Sender.Send(ctx, frame); err != nil {
|
|
return err
|
|
}
|
|
t.recordSend(group.StreamID, group.TrafficClass, len(group.Packets))
|
|
}
|
|
t.recordBatchFanout(len(groups))
|
|
return nil
|
|
}
|
|
|
|
func (t *FabricSessionPacketTransport) ReceiveGatewayPacketBatch(ctx context.Context, timeout time.Duration) ([][]byte, error) {
|
|
if t == nil || t.Inbox == nil {
|
|
return nil, mesh.ErrForwardRuntimeUnavailable
|
|
}
|
|
direction := t.ReceiveDirection
|
|
if direction == "" {
|
|
direction = FabricDirectionClientToGateway
|
|
}
|
|
if packets, err := t.Inbox.Receive(ctx, t.VPNConnectionID, direction, 5*time.Millisecond); err != nil || len(packets) > 0 {
|
|
return packets, err
|
|
}
|
|
if t.Receiver == nil {
|
|
return t.Inbox.Receive(ctx, t.VPNConnectionID, direction, timeout)
|
|
}
|
|
if timeout <= 0 {
|
|
timeout = 25 * time.Second
|
|
}
|
|
timer := time.NewTimer(timeout)
|
|
defer timer.Stop()
|
|
frames := t.Receiver.Frames()
|
|
errorsCh := t.Receiver.Errors()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-timer.C:
|
|
return nil, nil
|
|
case err, ok := <-errorsCh:
|
|
if !ok {
|
|
errorsCh = nil
|
|
continue
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
case frame, ok := <-frames:
|
|
if !ok {
|
|
return t.Inbox.Receive(ctx, t.VPNConnectionID, direction, 5*time.Millisecond)
|
|
}
|
|
if frame.Type != fabricproto.FrameData || !t.acceptsStream(frame.StreamID) {
|
|
continue
|
|
}
|
|
payload, err := DecodeFabricVPNPacketDataFrame(frame)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if payload.VPNConnectionID == t.VPNConnectionID && payload.Direction == direction {
|
|
t.recordReceive(frame.StreamID, fabricSessionTrafficClassName(frame.TrafficClass), len(payload.Packets))
|
|
return cleanPacketBatch(payload.Packets), nil
|
|
}
|
|
if err := t.deliverDecodedFabricSessionFrame(frame, payload); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (t *FabricSessionPacketTransport) RunFrameIngress(ctx context.Context) error {
|
|
if t == nil || t.Receiver == nil || t.Inbox == nil {
|
|
return mesh.ErrForwardRuntimeUnavailable
|
|
}
|
|
frames := t.Receiver.Frames()
|
|
errorsCh := t.Receiver.Errors()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case err, ok := <-errorsCh:
|
|
if !ok {
|
|
errorsCh = nil
|
|
continue
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
case frame, ok := <-frames:
|
|
if !ok {
|
|
return nil
|
|
}
|
|
if frame.Type != fabricproto.FrameData || !t.acceptsStream(frame.StreamID) {
|
|
continue
|
|
}
|
|
payload, err := DecodeFabricVPNPacketDataFrame(frame)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := t.deliverDecodedFabricSessionFrame(frame, payload); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (t *FabricSessionPacketTransport) deliverDecodedFabricSessionFrame(frame fabricproto.Frame, payload mesh.VPNPacketBatchPayload) error {
|
|
if t == nil || t.Inbox == nil {
|
|
return mesh.ErrForwardRuntimeUnavailable
|
|
}
|
|
payload.Packets = cleanPacketBatch(payload.Packets)
|
|
if len(payload.Packets) == 0 {
|
|
return nil
|
|
}
|
|
t.recordReceive(frame.StreamID, fabricSessionTrafficClassName(frame.TrafficClass), len(payload.Packets))
|
|
return t.Inbox.enqueue(payload)
|
|
}
|
|
|
|
func (t *FabricSessionPacketTransport) Close() error {
|
|
if t == nil {
|
|
return nil
|
|
}
|
|
t.closeOnce.Do(func() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
for _, streamID := range t.allStreamIDs() {
|
|
if t.Sender != nil {
|
|
if err := t.Sender.Send(ctx, fabricproto.Frame{
|
|
Type: fabricproto.FrameCloseStream,
|
|
StreamID: streamID,
|
|
}); err != nil {
|
|
t.recordCloseError()
|
|
if t.closeErr == nil {
|
|
t.closeErr = err
|
|
}
|
|
} else if err == nil {
|
|
t.recordCloseStream()
|
|
}
|
|
}
|
|
}
|
|
if closer, ok := t.Sender.(FabricSessionCloser); ok {
|
|
if err := closer.Close(); err != nil {
|
|
t.recordCloseError()
|
|
if t.closeErr == nil {
|
|
t.closeErr = err
|
|
}
|
|
}
|
|
}
|
|
})
|
|
return t.closeErr
|
|
}
|
|
|
|
func (t *FabricSessionPacketTransport) selectStreamForPackets(packets [][]byte) (uint64, string) {
|
|
trafficClass := fabricSessionTrafficClassForPackets(t.TrafficClass, packets)
|
|
if ids := t.streamIDsForTrafficClass(trafficClass); len(ids) > 0 {
|
|
if len(ids) == 1 || len(packets) == 0 {
|
|
return ids[0], trafficClass
|
|
}
|
|
_, shard := classifyPacketFlow(packets[0], len(ids))
|
|
return ids[shard], trafficClass
|
|
}
|
|
if len(t.StreamIDs) > 0 {
|
|
if len(t.StreamIDs) == 1 || len(packets) == 0 {
|
|
return t.StreamIDs[0], trafficClass
|
|
}
|
|
_, shard := classifyPacketFlow(packets[0], len(t.StreamIDs))
|
|
return t.StreamIDs[shard], trafficClass
|
|
}
|
|
return t.StreamID, trafficClass
|
|
}
|
|
|
|
// fabricSessionPacketGroup is the set of packets from one outbound batch that
// share a destination stream and traffic class; each group becomes one frame.
type fabricSessionPacketGroup struct {
	StreamID     uint64   // destination stream
	TrafficClass string   // class name used for the frame and stats
	Packets      [][]byte // packets in original relative order
}
|
|
|
|
func (t *FabricSessionPacketTransport) groupPacketsByStream(packets [][]byte) []fabricSessionPacketGroup {
|
|
groups := []fabricSessionPacketGroup{}
|
|
indexByKey := map[string]int{}
|
|
for _, packet := range packets {
|
|
streamID, trafficClass := t.selectStreamForPackets([][]byte{packet})
|
|
key := fmt.Sprintf("%d\x00%s", streamID, trafficClass)
|
|
index, ok := indexByKey[key]
|
|
if !ok {
|
|
index = len(groups)
|
|
indexByKey[key] = index
|
|
groups = append(groups, fabricSessionPacketGroup{
|
|
StreamID: streamID,
|
|
TrafficClass: trafficClass,
|
|
})
|
|
}
|
|
groups[index].Packets = append(groups[index].Packets, packet)
|
|
}
|
|
return groups
|
|
}
|
|
|
|
func (t *FabricSessionPacketTransport) allStreamIDs() []uint64 {
|
|
if t == nil {
|
|
return nil
|
|
}
|
|
seen := map[uint64]struct{}{}
|
|
var out []uint64
|
|
add := func(streamID uint64) {
|
|
if streamID == 0 {
|
|
return
|
|
}
|
|
if _, ok := seen[streamID]; ok {
|
|
return
|
|
}
|
|
seen[streamID] = struct{}{}
|
|
out = append(out, streamID)
|
|
}
|
|
add(t.StreamID)
|
|
for _, streamID := range t.StreamIDs {
|
|
add(streamID)
|
|
}
|
|
for _, ids := range t.StreamIDsByTrafficClass {
|
|
for _, streamID := range ids {
|
|
add(streamID)
|
|
}
|
|
}
|
|
return out
|
|
}
|
|
|
|
func (t *FabricSessionPacketTransport) hasSendStream() bool {
|
|
if t == nil {
|
|
return false
|
|
}
|
|
if t.StreamID != 0 || len(t.StreamIDs) > 0 {
|
|
return true
|
|
}
|
|
for _, ids := range t.StreamIDsByTrafficClass {
|
|
if len(ids) > 0 {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (t *FabricSessionPacketTransport) streamIDsForTrafficClass(trafficClass string) []uint64 {
|
|
if t == nil || len(t.StreamIDsByTrafficClass) == 0 {
|
|
return nil
|
|
}
|
|
if ids := t.StreamIDsByTrafficClass[normalizeFabricTrafficClass(trafficClass)]; len(ids) > 0 {
|
|
return ids
|
|
}
|
|
if normalizeFabricTrafficClass(trafficClass) == FabricTrafficClassReliable {
|
|
return t.StreamIDsByTrafficClass[FabricTrafficClassBulk]
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (t *FabricSessionPacketTransport) acceptsStream(streamID uint64) bool {
|
|
if t == nil || streamID == 0 {
|
|
return false
|
|
}
|
|
if t.StreamID != 0 && streamID == t.StreamID {
|
|
return true
|
|
}
|
|
for _, id := range t.StreamIDs {
|
|
if id == streamID {
|
|
return true
|
|
}
|
|
}
|
|
for _, ids := range t.StreamIDsByTrafficClass {
|
|
for _, id := range ids {
|
|
if id == streamID {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
return t.StreamID == 0 && len(t.StreamIDs) == 0 && len(t.StreamIDsByTrafficClass) == 0
|
|
}
|
|
|
|
func (t *FabricSessionPacketTransport) nextSequence(streamID uint64) uint64 {
|
|
if streamID == 0 {
|
|
return atomic.AddUint64(&t.sequence, 1)
|
|
}
|
|
t.sequenceMu.Lock()
|
|
defer t.sequenceMu.Unlock()
|
|
if t.sequenceByStream == nil {
|
|
t.sequenceByStream = map[uint64]uint64{}
|
|
}
|
|
t.sequenceByStream[streamID]++
|
|
return t.sequenceByStream[streamID]
|
|
}
|
|
|
|
func (t *FabricSessionPacketTransport) recordSend(streamID uint64, trafficClass string, packetCount int) {
|
|
if t == nil {
|
|
return
|
|
}
|
|
trafficClass = normalizeFabricTrafficClass(trafficClass)
|
|
t.statsMu.Lock()
|
|
defer t.statsMu.Unlock()
|
|
if t.sendFramesByClass == nil {
|
|
t.sendFramesByClass = map[string]uint64{}
|
|
}
|
|
if t.sendPacketsByClass == nil {
|
|
t.sendPacketsByClass = map[string]uint64{}
|
|
}
|
|
if t.sendFramesByStream == nil {
|
|
t.sendFramesByStream = map[uint64]uint64{}
|
|
}
|
|
if t.sendPacketsByStream == nil {
|
|
t.sendPacketsByStream = map[uint64]uint64{}
|
|
}
|
|
t.sendFramesByClass[trafficClass]++
|
|
t.sendPacketsByClass[trafficClass] += uint64(packetCount)
|
|
t.sendFramesByStream[streamID]++
|
|
t.sendPacketsByStream[streamID] += uint64(packetCount)
|
|
}
|
|
|
|
func (t *FabricSessionPacketTransport) recordBatchFanout(frameCount int) {
|
|
if t == nil || frameCount <= 0 {
|
|
return
|
|
}
|
|
t.statsMu.Lock()
|
|
defer t.statsMu.Unlock()
|
|
t.lastBatchFrameCount = uint64(frameCount)
|
|
if frameCount > 1 {
|
|
t.splitBatchCount++
|
|
}
|
|
if uint64(frameCount) > t.maxBatchFrameCount {
|
|
t.maxBatchFrameCount = uint64(frameCount)
|
|
}
|
|
}
|
|
|
|
func (t *FabricSessionPacketTransport) recordReceive(streamID uint64, trafficClass string, packetCount int) {
|
|
if t == nil {
|
|
return
|
|
}
|
|
trafficClass = normalizeFabricTrafficClass(trafficClass)
|
|
t.statsMu.Lock()
|
|
defer t.statsMu.Unlock()
|
|
if t.receiveFramesByClass == nil {
|
|
t.receiveFramesByClass = map[string]uint64{}
|
|
}
|
|
if t.receivePacketsByClass == nil {
|
|
t.receivePacketsByClass = map[string]uint64{}
|
|
}
|
|
if t.receiveFramesByStream == nil {
|
|
t.receiveFramesByStream = map[uint64]uint64{}
|
|
}
|
|
if t.receivePacketsByStream == nil {
|
|
t.receivePacketsByStream = map[uint64]uint64{}
|
|
}
|
|
t.receiveFramesByClass[trafficClass]++
|
|
t.receivePacketsByClass[trafficClass] += uint64(packetCount)
|
|
t.receiveFramesByStream[streamID]++
|
|
t.receivePacketsByStream[streamID] += uint64(packetCount)
|
|
}
|
|
|
|
func (t *FabricSessionPacketTransport) Snapshot() map[string]any {
|
|
if t == nil {
|
|
return nil
|
|
}
|
|
t.statsMu.Lock()
|
|
sendFramesByClass := copyStringUint64Map(t.sendFramesByClass)
|
|
sendPacketsByClass := copyStringUint64Map(t.sendPacketsByClass)
|
|
receiveFramesByClass := copyStringUint64Map(t.receiveFramesByClass)
|
|
receivePacketsByClass := copyStringUint64Map(t.receivePacketsByClass)
|
|
lastBatchFrameCount := t.lastBatchFrameCount
|
|
maxBatchFrameCount := t.maxBatchFrameCount
|
|
splitBatchCount := t.splitBatchCount
|
|
closeStreamFrames := t.closeStreamFrames
|
|
closeErrors := t.closeErrors
|
|
sendFramesByStream := make(map[string]uint64, len(t.sendFramesByStream))
|
|
for streamID, count := range t.sendFramesByStream {
|
|
sendFramesByStream[fmt.Sprintf("%d", streamID)] = count
|
|
}
|
|
sendPacketsByStream := make(map[string]uint64, len(t.sendPacketsByStream))
|
|
for streamID, count := range t.sendPacketsByStream {
|
|
sendPacketsByStream[fmt.Sprintf("%d", streamID)] = count
|
|
}
|
|
receiveFramesByStream := make(map[string]uint64, len(t.receiveFramesByStream))
|
|
for streamID, count := range t.receiveFramesByStream {
|
|
receiveFramesByStream[fmt.Sprintf("%d", streamID)] = count
|
|
}
|
|
receivePacketsByStream := make(map[string]uint64, len(t.receivePacketsByStream))
|
|
for streamID, count := range t.receivePacketsByStream {
|
|
receivePacketsByStream[fmt.Sprintf("%d", streamID)] = count
|
|
}
|
|
t.statsMu.Unlock()
|
|
streamIDsByClass := copyStreamIDsByTrafficClass(t.StreamIDsByTrafficClass)
|
|
return map[string]any{
|
|
"schema_version": "rap.vpn_fabric_session_packet_transport.v1",
|
|
"stream_id": t.StreamID,
|
|
"stream_ids_by_class": streamIDsByClass,
|
|
"stream_class_count": len(streamIDsByClass),
|
|
"stream_shard_count": countStreamIDs(streamIDsByClass) + len(t.StreamIDs),
|
|
"send_class_count": countNonZeroStringUint64Values(sendFramesByClass),
|
|
"send_stream_count": countNonZeroStringUint64Values(sendFramesByStream),
|
|
"sharding_active": len(streamIDsByClass) > 1 || countStreamIDs(streamIDsByClass)+len(t.StreamIDs) > 1,
|
|
"split_batch_count": splitBatchCount,
|
|
"last_batch_frame_count": lastBatchFrameCount,
|
|
"max_batch_frame_count": maxBatchFrameCount,
|
|
"close_stream_frames": closeStreamFrames,
|
|
"close_errors": closeErrors,
|
|
"send_frames_by_class": sendFramesByClass,
|
|
"send_packets_by_class": sendPacketsByClass,
|
|
"send_frames_by_stream_id": sendFramesByStream,
|
|
"send_packets_by_stream_id": sendPacketsByStream,
|
|
"receive_frames_by_class": receiveFramesByClass,
|
|
"receive_packets_by_class": receivePacketsByClass,
|
|
"receive_frames_by_stream_id": receiveFramesByStream,
|
|
"receive_packets_by_stream_id": receivePacketsByStream,
|
|
}
|
|
}
|
|
|
|
func (t *FabricSessionPacketTransport) recordCloseStream() {
|
|
if t == nil {
|
|
return
|
|
}
|
|
t.statsMu.Lock()
|
|
t.closeStreamFrames++
|
|
t.statsMu.Unlock()
|
|
}
|
|
|
|
func (t *FabricSessionPacketTransport) recordCloseError() {
|
|
if t == nil {
|
|
return
|
|
}
|
|
t.statsMu.Lock()
|
|
t.closeErrors++
|
|
t.statsMu.Unlock()
|
|
}
|
|
|
|
func fabricSessionTrafficClassForPackets(fallback string, packets [][]byte) string {
|
|
if fallback = normalizeFabricTrafficClass(fallback); fallback != "" && fallback != FabricTrafficClassBulk {
|
|
return fallback
|
|
}
|
|
if batchHasTCPControlPacket(packets) {
|
|
return FabricTrafficClassInteractive
|
|
}
|
|
return FabricTrafficClassBulk
|
|
}
|
|
|
|
func fabricSessionTrafficClassName(value fabricproto.TrafficClass) string {
|
|
switch value {
|
|
case fabricproto.TrafficClassControl:
|
|
return FabricTrafficClassControl
|
|
case fabricproto.TrafficClassInteractive:
|
|
return FabricTrafficClassInteractive
|
|
case fabricproto.TrafficClassReliable:
|
|
return FabricTrafficClassReliable
|
|
case fabricproto.TrafficClassDroppable:
|
|
return FabricTrafficClassDroppable
|
|
default:
|
|
return FabricTrafficClassBulk
|
|
}
|
|
}
|
|
|
|
// copyStringUint64Map returns an independent copy of values. The result is
// never nil, even when values is nil or empty.
func copyStringUint64Map(values map[string]uint64) map[string]uint64 {
	out := make(map[string]uint64, len(values))
	for key, value := range values {
		out[key] = value
	}
	return out
}
|
|
|
|
// copyStreamIDsByTrafficClass deep-copies a class→stream-ID map so callers can
// keep the result without aliasing the transport's configuration. The result
// map is never nil. Empty per-class lists are copied as nil slices.
func copyStreamIDsByTrafficClass(values map[string][]uint64) map[string][]uint64 {
	out := make(map[string][]uint64, len(values))
	for class, ids := range values {
		var cloned []uint64
		if len(ids) > 0 {
			cloned = make([]uint64, len(ids))
			copy(cloned, ids)
		}
		out[class] = cloned
	}
	return out
}
|
|
|
|
// countStreamIDs returns the total number of stream IDs across all classes,
// counting duplicates.
func countStreamIDs(values map[string][]uint64) int {
	var n int
	for class := range values {
		n += len(values[class])
	}
	return n
}
|
|
|
|
// countNonZeroStringUint64Values returns how many entries in values are
// non-zero.
func countNonZeroStringUint64Values(values map[string]uint64) int {
	var n int
	for _, v := range values {
		if v == 0 {
			continue
		}
		n++
	}
	return n
}
|