Files
m 20d361a886
build / backend (push) Has been cancelled
build / node-agent (push) Has been cancelled
build / worker (push) Has been cancelled
рабочий вариант, но скорость 10 МБит
2026-05-22 21:46:49 +03:00

707 lines
21 KiB
Go

package vpnruntime
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/fabricproto"
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/mesh"
)
// FabricSessionFrameSender is the outbound side of a fabric session.
// Send transmits a single frame; the context is handed to the implementation.
type FabricSessionFrameSender interface {
Send(context.Context, fabricproto.Frame) error
}
// FabricSessionFrameReceiver is the inbound side of a fabric session.
// Received frames and receive errors are exposed on two separate channels;
// consumers in this file treat a closed channel as "no more values".
type FabricSessionFrameReceiver interface {
// Frames returns the channel on which inbound frames are delivered.
Frames() <-chan fabricproto.Frame
// Errors returns the channel on which receive errors are delivered.
Errors() <-chan error
}
// FabricSessionCloser is an optional capability of a frame sender.
// FabricSessionPacketTransport.Close type-asserts its Sender to this interface
// and calls Close when the assertion succeeds.
type FabricSessionCloser interface {
Close() error
}
// FabricSessionPacketTransport moves VPN packet batches over a fabric session.
// Outgoing batches are wrapped into data frames and handed to Sender; incoming
// frames from Receiver are decoded and either returned directly or queued in
// Inbox. The transport also keeps per-class and per-stream counters that
// Snapshot exposes.
//
// The struct contains mutexes and a sync.Once, so it must be used via pointer.
type FabricSessionPacketTransport struct {
// Sender transmits outbound frames; SendGatewayPacketBatch fails without it.
Sender FabricSessionFrameSender
// Receiver supplies inbound frames. It is optional for
// ReceiveGatewayPacketBatch and required by RunFrameIngress.
Receiver FabricSessionFrameReceiver
// Inbox queues decoded packet batches; required by all receive paths.
Inbox *FabricPacketInbox
// StreamID is the stream used when neither StreamIDsByTrafficClass nor
// StreamIDs yields a stream for a packet.
StreamID uint64
// ServiceStreams, when non-nil, is informed about streams on send and close.
ServiceStreams *FabricServiceStreamRegistry
// ServiceTunnel is attached to outgoing frames; its TunnelID has the highest
// priority when resolving the packet tunnel ID.
ServiceTunnel FabricServiceTunnel
// TunnelID, PoolID and ServiceID are back-filled from ServiceTunnel by
// normalizeServiceTunnel when left empty.
TunnelID string
PoolID string
ServiceID string
// VPNConnectionID is the last fallback when resolving the packet tunnel ID.
VPNConnectionID string
// SendDirection and ReceiveDirection override the built-in defaults
// (gateway-to-client for send, client-to-gateway for receive) when non-empty.
SendDirection string
ReceiveDirection string
// TrafficClass is the class requested for outgoing packets; an empty or
// unrecognised value falls back to the bulk class.
TrafficClass string
// StreamIDsByTrafficClass lists shard streams per traffic class; it is
// consulted first when choosing a stream for a packet.
StreamIDsByTrafficClass map[string][]uint64
// StreamIDs lists shard streams used when no per-class list applies.
StreamIDs []uint64
// routeMu guards the route* fields below.
routeMu sync.Mutex
routeLeaseID string
routeGeneration string
routeTransitionCount uint64
routeUpdatedAt time.Time
// sequence is the frame sequence counter for stream ID 0; it is updated
// with atomic operations in nextSequence.
sequence uint64
// sequenceMu guards sequenceByStream, the per-stream sequence counters.
sequenceMu sync.Mutex
sequenceByStream map[uint64]uint64
// statsMu guards every counter from here down to closeErrors.
statsMu sync.Mutex
sendFramesByClass map[string]uint64
sendPacketsByClass map[string]uint64
sendFramesByStream map[uint64]uint64
sendPacketsByStream map[uint64]uint64
// splitBatchCount counts send batches that fanned out into more than one frame.
splitBatchCount uint64
lastBatchFrameCount uint64
maxBatchFrameCount uint64
receiveFramesByClass map[string]uint64
receivePacketsByClass map[string]uint64
receiveFramesByStream map[uint64]uint64
receivePacketsByStream map[uint64]uint64
closeStreamFrames uint64
closeErrors uint64
// closeOnce makes the shutdown sequence in Close run at most once;
// closeErr remembers the first error that sequence encountered.
closeOnce sync.Once
closeErr error
}
// SendGatewayPacketBatch wraps packets into fabric data frames and passes them
// to Sender, one frame per (stream, traffic class) group.
//
// The batch is first filtered through cleanPacketBatch; if nothing remains the
// call is a no-op and returns nil. A nil transport or nil Sender yields
// mesh.ErrForwardRuntimeUnavailable. An empty SendDirection defaults to
// FabricDirectionGatewayToClient. The first frame-construction or send error
// aborts the batch and is returned unchanged; frames already sent are not
// rolled back, and the batch fan-out statistic is only updated on full success.
func (t *FabricSessionPacketTransport) SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error {
	batch := cleanPacketBatch(packets)
	if len(batch) == 0 {
		return nil
	}
	if t == nil || t.Sender == nil {
		return mesh.ErrForwardRuntimeUnavailable
	}

	t.normalizeServiceTunnel()
	tunnelID := t.packetTunnelID()
	if t.VPNConnectionID == "" {
		// Keep the connection-ID alias populated once a tunnel ID is known.
		t.VPNConnectionID = tunnelID
	}
	if tunnelID == "" || !t.hasSendStream() {
		return errors.New("fabric session packet transport identity is incomplete")
	}

	dir := t.SendDirection
	if dir == "" {
		dir = FabricDirectionGatewayToClient
	}

	grouped := t.groupPacketsByStream(batch)
	for i := range grouped {
		g := grouped[i]
		t.registerServiceStream(g.StreamID, g.TrafficClass, dir)

		input := FabricVPNPacketFrameInput{
			StreamID:        g.StreamID,
			Sequence:        t.nextSequence(g.StreamID),
			VPNConnectionID: tunnelID,
			Direction:       dir,
			TrafficClass:    g.TrafficClass,
			ServiceTunnel:   t.ServiceTunnel,
			Packets:         g.Packets,
		}
		frame, err := NewFabricVPNPacketDataFrame(input)
		if err != nil {
			return err
		}
		if err = t.Sender.Send(ctx, frame); err != nil {
			return err
		}
		t.recordSend(g.StreamID, g.TrafficClass, len(g.Packets))
	}

	t.recordBatchFanout(len(grouped))
	return nil
}
// ReceiveGatewayPacketBatch returns the next packet batch addressed to this
// transport's tunnel and receive direction.
//
// Order of operations:
//  1. Poll Inbox briefly (5ms) for an already-queued batch.
//  2. Without a Receiver, fall back to a plain Inbox.Receive with timeout.
//  3. Otherwise read frames from Receiver until a matching batch arrives, the
//     timeout fires (returns nil, nil), ctx is cancelled (returns ctx.Err()),
//     or the receiver reports an error.
//
// Frames for other tunnels/directions are queued into Inbox. Non-data frames
// and frames on streams this transport does not accept are skipped. A
// non-positive timeout is replaced by 25 seconds when a Receiver is present.
// A nil transport or nil Inbox yields mesh.ErrForwardRuntimeUnavailable.
//
// Receive statistics are recorded after cleanPacketBatch and only for
// non-empty batches, matching deliverDecodedFabricSessionFrame so both paths
// feed the counters with the same convention.
//
// NOTE(review): the initial 5ms Inbox poll runs on every call; assuming
// Inbox.Receive blocks for the given duration when the inbox is empty, it may
// add up to 5ms of latency per batch when traffic arrives only through
// Receiver — confirm against FabricPacketInbox.Receive.
func (t *FabricSessionPacketTransport) ReceiveGatewayPacketBatch(ctx context.Context, timeout time.Duration) ([][]byte, error) {
	if t == nil || t.Inbox == nil {
		return nil, mesh.ErrForwardRuntimeUnavailable
	}
	t.normalizeServiceTunnel()
	packetTunnelID := t.packetTunnelID()
	direction := t.ReceiveDirection
	if direction == "" {
		direction = FabricDirectionClientToGateway
	}

	// Fast path: something may already be queued by a previous call or by
	// RunFrameIngress.
	if packets, err := t.Inbox.Receive(ctx, packetTunnelID, direction, 5*time.Millisecond); err != nil || len(packets) > 0 {
		return packets, err
	}
	if t.Receiver == nil {
		return t.Inbox.Receive(ctx, packetTunnelID, direction, timeout)
	}
	if timeout <= 0 {
		timeout = 25 * time.Second
	}
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	frames := t.Receiver.Frames()
	errorsCh := t.Receiver.Errors()
	for {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-timer.C:
			// Timeout is reported as "no packets", not as an error.
			return nil, nil
		case err, ok := <-errorsCh:
			if !ok {
				// A nil channel blocks forever, which disables this case.
				errorsCh = nil
				continue
			}
			if err != nil {
				// Give already-queued packets one last chance before
				// surfacing the receiver error.
				if packets, receiveErr := t.Inbox.Receive(ctx, packetTunnelID, direction, 100*time.Millisecond); receiveErr != nil || len(packets) > 0 {
					return packets, receiveErr
				}
				return nil, err
			}
		case frame, ok := <-frames:
			if !ok {
				// Receiver is done; drain whatever the inbox still holds.
				return t.Inbox.Receive(ctx, packetTunnelID, direction, 100*time.Millisecond)
			}
			if frame.Type != fabricproto.FrameData || !t.acceptsStream(frame.StreamID) {
				continue
			}
			payload, err := DecodeFabricVPNPacketDataFrame(frame)
			if err != nil {
				return nil, err
			}
			if payload.VPNConnectionID == packetTunnelID && payload.Direction == direction {
				// Count only what is actually handed to the caller.
				batch := cleanPacketBatch(payload.Packets)
				if len(batch) > 0 {
					t.recordReceive(frame.StreamID, fabricSessionTrafficClassName(frame.TrafficClass), len(batch))
				}
				return batch, nil
			}
			if err := t.deliverDecodedFabricSessionFrame(frame, payload); err != nil {
				return nil, err
			}
		}
	}
}
// RunFrameIngress pumps frames from Receiver into Inbox until something stops
// it. It returns ctx.Err() on cancellation, the first non-nil receiver error,
// the first decode or delivery error, or nil once the frames channel is
// closed. Non-data frames and frames on streams this transport does not accept
// are ignored. A nil transport, Receiver or Inbox yields
// mesh.ErrForwardRuntimeUnavailable.
func (t *FabricSessionPacketTransport) RunFrameIngress(ctx context.Context) error {
	if t == nil || t.Receiver == nil || t.Inbox == nil {
		return mesh.ErrForwardRuntimeUnavailable
	}

	frameCh := t.Receiver.Frames()
	errCh := t.Receiver.Errors()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()

		case recvErr, open := <-errCh:
			if !open {
				// Stop selecting on a closed error channel.
				errCh = nil
				continue
			}
			if recvErr != nil {
				return recvErr
			}

		case frame, open := <-frameCh:
			if !open {
				return nil
			}
			isData := frame.Type == fabricproto.FrameData
			if !isData || !t.acceptsStream(frame.StreamID) {
				continue
			}
			payload, decodeErr := DecodeFabricVPNPacketDataFrame(frame)
			if decodeErr != nil {
				return decodeErr
			}
			if deliverErr := t.deliverDecodedFabricSessionFrame(frame, payload); deliverErr != nil {
				return deliverErr
			}
		}
	}
}
// deliverDecodedFabricSessionFrame queues an already-decoded payload into
// Inbox. The packet list is filtered through cleanPacketBatch first; an empty
// result is dropped silently (nil error, no statistics). Otherwise the frame
// is counted in the receive statistics and the payload is enqueued.
func (t *FabricSessionPacketTransport) deliverDecodedFabricSessionFrame(frame fabricproto.Frame, payload mesh.VPNPacketBatchPayload) error {
	if t == nil {
		return mesh.ErrForwardRuntimeUnavailable
	}
	inbox := t.Inbox
	if inbox == nil {
		return mesh.ErrForwardRuntimeUnavailable
	}

	cleaned := cleanPacketBatch(payload.Packets)
	if len(cleaned) == 0 {
		return nil
	}
	payload.Packets = cleaned

	className := fabricSessionTrafficClassName(frame.TrafficClass)
	t.recordReceive(frame.StreamID, className, len(cleaned))
	return inbox.enqueue(payload)
}
// Close shuts the transport down once; later calls return the same result.
//
// When a Sender is present, a close-stream frame is sent for every known
// stream ID under a shared one-second deadline. Each stream whose frame was
// sent is marked closed in the service stream registry and counted; each
// failure is counted as a close error. If Sender also implements
// FabricSessionCloser it is closed afterwards. The first error encountered is
// returned. Calling Close on a nil transport returns nil.
func (t *FabricSessionPacketTransport) Close() error {
	if t == nil {
		return nil
	}
	t.closeOnce.Do(func() {
		// noteFailure counts the error and remembers only the first one.
		noteFailure := func(err error) {
			t.recordCloseError()
			if t.closeErr == nil {
				t.closeErr = err
			}
		}

		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		defer cancel()

		if sender := t.Sender; sender != nil {
			for _, id := range t.allStreamIDs() {
				closeFrame := fabricproto.Frame{
					Type:     fabricproto.FrameCloseStream,
					StreamID: id,
				}
				if err := sender.Send(ctx, closeFrame); err != nil {
					noteFailure(err)
					continue
				}
				t.markServiceStreamClosed(id)
				t.recordCloseStream()
			}
		}

		closer, ok := t.Sender.(FabricSessionCloser)
		if !ok {
			return
		}
		if err := closer.Close(); err != nil {
			noteFailure(err)
		}
	})
	return t.closeErr
}
// selectStreamForPackets picks the stream ID and traffic class for packets.
//
// Candidate streams come from the per-class list for the resolved traffic
// class, or from StreamIDs when that list is empty. With no candidates the
// single StreamID is used. With several candidates and at least one packet,
// the shard index is taken from classifyPacketFlow on the first packet.
func (t *FabricSessionPacketTransport) selectStreamForPackets(packets [][]byte) (uint64, string) {
	class := fabricSessionTrafficClassForPackets(t.TrafficClass, packets)

	candidates := t.streamIDsForTrafficClass(class)
	if len(candidates) == 0 {
		candidates = t.StreamIDs
	}

	switch {
	case len(candidates) == 0:
		return t.StreamID, class
	case len(candidates) == 1 || len(packets) == 0:
		return candidates[0], class
	}

	_, shard := classifyPacketFlow(packets[0], len(candidates))
	return candidates[shard], class
}
// fabricSessionPacketGroup is the set of packets from one send batch that
// share a stream ID and traffic class; SendGatewayPacketBatch emits one frame
// per group.
type fabricSessionPacketGroup struct {
StreamID uint64
TrafficClass string
Packets [][]byte
}
// groupPacketsByStream partitions packets into groups that share the same
// stream ID and traffic class, as chosen per packet by selectStreamForPackets.
// Groups appear in order of first occurrence and packets keep their relative
// order within a group. The result is never nil.
//
// Groups are indexed by a comparable struct key rather than a formatted
// string: this runs once per packet on the send path, and formatting a key
// there costs an allocation per packet for no benefit.
func (t *FabricSessionPacketTransport) groupPacketsByStream(packets [][]byte) []fabricSessionPacketGroup {
	type groupKey struct {
		streamID     uint64
		trafficClass string
	}

	groups := []fabricSessionPacketGroup{}
	indexByKey := map[groupKey]int{}
	for _, packet := range packets {
		streamID, trafficClass := t.selectStreamForPackets([][]byte{packet})
		key := groupKey{streamID: streamID, trafficClass: trafficClass}
		index, ok := indexByKey[key]
		if !ok {
			index = len(groups)
			indexByKey[key] = index
			groups = append(groups, fabricSessionPacketGroup{
				StreamID:     streamID,
				TrafficClass: trafficClass,
			})
		}
		groups[index].Packets = append(groups[index].Packets, packet)
	}
	return groups
}
// allStreamIDs returns every distinct non-zero stream ID configured on the
// transport: StreamID first, then StreamIDs, then the per-class lists (in map
// iteration order). It returns nil for a nil transport or when no ID is set.
func (t *FabricSessionPacketTransport) allStreamIDs() []uint64 {
	if t == nil {
		return nil
	}

	var unique []uint64
	known := make(map[uint64]struct{})
	collect := func(ids ...uint64) {
		for _, id := range ids {
			if id == 0 {
				continue
			}
			if _, dup := known[id]; dup {
				continue
			}
			known[id] = struct{}{}
			unique = append(unique, id)
		}
	}

	collect(t.StreamID)
	collect(t.StreamIDs...)
	for _, classIDs := range t.StreamIDsByTrafficClass {
		collect(classIDs...)
	}
	return unique
}
// hasSendStream reports whether any stream is configured for sending: a
// non-zero StreamID, a non-empty StreamIDs list, or at least one non-empty
// per-class list.
func (t *FabricSessionPacketTransport) hasSendStream() bool {
	if t == nil {
		return false
	}
	if t.StreamID != 0 {
		return true
	}
	if len(t.StreamIDs) != 0 {
		return true
	}
	for _, classIDs := range t.StreamIDsByTrafficClass {
		if len(classIDs) != 0 {
			return true
		}
	}
	return false
}
// streamIDsForTrafficClass returns the stream list for the normalized traffic
// class. When the class has no streams of its own, DNS falls back to the
// reliable list and then the bulk list, and reliable falls back to the bulk
// list. Every other class without streams yields nil.
func (t *FabricSessionPacketTransport) streamIDsForTrafficClass(trafficClass string) []uint64 {
	if t == nil {
		return nil
	}
	byClass := t.StreamIDsByTrafficClass
	if len(byClass) == 0 {
		return nil
	}

	class := normalizeFabricTrafficClass(trafficClass)
	if own := byClass[class]; len(own) > 0 {
		return own
	}

	if class == FabricTrafficClassDNS {
		if reliable := byClass[FabricTrafficClassReliable]; len(reliable) > 0 {
			return reliable
		}
		return byClass[FabricTrafficClassBulk]
	}
	if class == FabricTrafficClassReliable {
		return byClass[FabricTrafficClassBulk]
	}
	return nil
}
// acceptsStream reports whether an inbound frame on streamID belongs to this
// transport. Stream 0 is never accepted. A transport with no stream
// configuration at all accepts every non-zero stream; otherwise the ID must
// match StreamID or appear in StreamIDs or one of the per-class lists.
func (t *FabricSessionPacketTransport) acceptsStream(streamID uint64) bool {
	if t == nil || streamID == 0 {
		return false
	}

	unconfigured := t.StreamID == 0 && len(t.StreamIDs) == 0 && len(t.StreamIDsByTrafficClass) == 0
	if unconfigured {
		return true
	}

	// streamID is non-zero here, so equality implies t.StreamID is non-zero too.
	if streamID == t.StreamID {
		return true
	}
	for _, candidate := range t.StreamIDs {
		if candidate == streamID {
			return true
		}
	}
	for _, classIDs := range t.StreamIDsByTrafficClass {
		for _, candidate := range classIDs {
			if candidate == streamID {
				return true
			}
		}
	}
	return false
}
// nextSequence returns the next frame sequence number for streamID, starting
// at 1. Stream 0 uses the shared atomic counter; every other stream has its
// own counter in sequenceByStream, guarded by sequenceMu.
//
// NOTE(review): atomic.AddUint64 on a plain uint64 struct field needs 64-bit
// alignment on 32-bit platforms, which the field's position in the struct may
// not guarantee — confirm if 32-bit targets are supported.
func (t *FabricSessionPacketTransport) nextSequence(streamID uint64) uint64 {
	if streamID != 0 {
		t.sequenceMu.Lock()
		defer t.sequenceMu.Unlock()
		if t.sequenceByStream == nil {
			t.sequenceByStream = make(map[uint64]uint64)
		}
		next := t.sequenceByStream[streamID] + 1
		t.sequenceByStream[streamID] = next
		return next
	}
	return atomic.AddUint64(&t.sequence, 1)
}
// recordSend accounts for one sent frame carrying packetCount packets, both
// under the normalized traffic class and under the stream ID. The counter maps
// are created on first use. It is a no-op on a nil transport.
func (t *FabricSessionPacketTransport) recordSend(streamID uint64, trafficClass string, packetCount int) {
	if t == nil {
		return
	}
	class := normalizeFabricTrafficClass(trafficClass)
	sent := uint64(packetCount)

	t.statsMu.Lock()
	defer t.statsMu.Unlock()

	addByClass := func(counters *map[string]uint64, delta uint64) {
		if *counters == nil {
			*counters = make(map[string]uint64)
		}
		(*counters)[class] += delta
	}
	addByStream := func(counters *map[uint64]uint64, delta uint64) {
		if *counters == nil {
			*counters = make(map[uint64]uint64)
		}
		(*counters)[streamID] += delta
	}

	addByClass(&t.sendFramesByClass, 1)
	addByClass(&t.sendPacketsByClass, sent)
	addByStream(&t.sendFramesByStream, 1)
	addByStream(&t.sendPacketsByStream, sent)
}
// recordBatchFanout notes how many frames one send batch was split into. It
// stores the latest value, raises the maximum if exceeded, and counts the
// batch as "split" when it produced more than one frame. Non-positive counts
// and a nil transport are ignored.
func (t *FabricSessionPacketTransport) recordBatchFanout(frameCount int) {
	if t == nil || frameCount <= 0 {
		return
	}
	frames := uint64(frameCount)

	t.statsMu.Lock()
	defer t.statsMu.Unlock()

	t.lastBatchFrameCount = frames
	if frames > t.maxBatchFrameCount {
		t.maxBatchFrameCount = frames
	}
	if frames > 1 {
		t.splitBatchCount++
	}
}
// recordReceive accounts for one received frame carrying packetCount packets,
// both under the normalized traffic class and under the stream ID. The counter
// maps are created on first use. It is a no-op on a nil transport.
func (t *FabricSessionPacketTransport) recordReceive(streamID uint64, trafficClass string, packetCount int) {
	if t == nil {
		return
	}
	class := normalizeFabricTrafficClass(trafficClass)
	received := uint64(packetCount)

	t.statsMu.Lock()
	defer t.statsMu.Unlock()

	// Lazily allocate whichever counter maps are still missing.
	if t.receiveFramesByClass == nil {
		t.receiveFramesByClass = make(map[string]uint64)
	}
	if t.receivePacketsByClass == nil {
		t.receivePacketsByClass = make(map[string]uint64)
	}
	if t.receiveFramesByStream == nil {
		t.receiveFramesByStream = make(map[uint64]uint64)
	}
	if t.receivePacketsByStream == nil {
		t.receivePacketsByStream = make(map[uint64]uint64)
	}

	t.receiveFramesByClass[class] += 1
	t.receiveFramesByStream[streamID] += 1
	t.receivePacketsByClass[class] += received
	t.receivePacketsByStream[streamID] += received
}
// Snapshot returns a point-in-time description of the transport as a generic
// map: identity fields, route lease state, stream configuration, and copies of
// all send/receive/close counters. Per-stream counters are keyed by the
// decimal stream ID. The service stream registry entries are included when a
// registry is attached, and "route_updated_at" only once a route transition
// has been recorded. It returns nil for a nil transport.
//
// Counters are copied under statsMu and route fields are read under routeMu,
// so the returned maps are not shared with the transport.
func (t *FabricSessionPacketTransport) Snapshot() map[string]any {
	if t == nil {
		return nil
	}
	t.normalizeServiceTunnel()

	// keyedByStreamID converts a per-stream counter map to string keys.
	keyedByStreamID := func(src map[uint64]uint64) map[string]uint64 {
		dst := make(map[string]uint64, len(src))
		for id, n := range src {
			dst[fmt.Sprintf("%d", id)] = n
		}
		return dst
	}

	t.statsMu.Lock()
	sendFramesByClass := copyStringUint64Map(t.sendFramesByClass)
	sendPacketsByClass := copyStringUint64Map(t.sendPacketsByClass)
	receiveFramesByClass := copyStringUint64Map(t.receiveFramesByClass)
	receivePacketsByClass := copyStringUint64Map(t.receivePacketsByClass)
	sendFramesByStream := keyedByStreamID(t.sendFramesByStream)
	sendPacketsByStream := keyedByStreamID(t.sendPacketsByStream)
	receiveFramesByStream := keyedByStreamID(t.receiveFramesByStream)
	receivePacketsByStream := keyedByStreamID(t.receivePacketsByStream)
	lastBatchFrames, maxBatchFrames, splitBatches := t.lastBatchFrameCount, t.maxBatchFrameCount, t.splitBatchCount
	closeFrames, closeErrs := t.closeStreamFrames, t.closeErrors
	t.statsMu.Unlock()

	t.routeMu.Lock()
	leaseID := firstNonEmptyTunnelString(t.routeLeaseID, t.ServiceTunnel.RouteLeaseID)
	generation := firstNonEmptyTunnelString(t.routeGeneration, t.ServiceTunnel.RouteGeneration)
	transitions := t.routeTransitionCount
	updatedAt := t.routeUpdatedAt
	t.routeMu.Unlock()

	idsByClass := copyStreamIDsByTrafficClass(t.StreamIDsByTrafficClass)
	shardCount := countStreamIDs(idsByClass) + len(t.StreamIDs)

	snapshot := map[string]any{
		"schema_version":               "rap.vpn_fabric_session_packet_transport.v1",
		"tunnel_id":                    t.packetTunnelID(),
		"pool_id":                      t.PoolID,
		"service_id":                   t.ServiceID,
		"route_lease_id":               leaseID,
		"route_generation":             generation,
		"route_transition_count":       transitions,
		"vpn_connection_id_alias":      t.VPNConnectionID,
		"service_tunnel":               t.ServiceTunnel.Snapshot(),
		"stream_id":                    t.StreamID,
		"stream_ids_by_class":          idsByClass,
		"stream_class_count":           len(idsByClass),
		"stream_shard_count":           shardCount,
		"send_class_count":             countNonZeroStringUint64Values(sendFramesByClass),
		"send_stream_count":            countNonZeroStringUint64Values(sendFramesByStream),
		"sharding_active":              len(idsByClass) > 1 || shardCount > 1,
		"split_batch_count":            splitBatches,
		"last_batch_frame_count":       lastBatchFrames,
		"max_batch_frame_count":        maxBatchFrames,
		"close_stream_frames":          closeFrames,
		"close_errors":                 closeErrs,
		"send_frames_by_class":         sendFramesByClass,
		"send_packets_by_class":        sendPacketsByClass,
		"send_frames_by_stream_id":     sendFramesByStream,
		"send_packets_by_stream_id":    sendPacketsByStream,
		"receive_frames_by_class":      receiveFramesByClass,
		"receive_packets_by_class":     receivePacketsByClass,
		"receive_frames_by_stream_id":  receiveFramesByStream,
		"receive_packets_by_stream_id": receivePacketsByStream,
	}

	if registry := t.ServiceStreams; registry != nil {
		snapshot["service_stream_registry"] = registry.Snapshot()
		snapshot["service_streams"] = serviceStreamsSnapshotItems(registry.StreamsForTunnel(t.packetTunnelID()))
	}
	if !updatedAt.IsZero() {
		snapshot["route_updated_at"] = updatedAt.UTC().Format(time.RFC3339Nano)
	}
	return snapshot
}
// UpdateServiceTunnel installs a new service tunnel description and reports
// whether the route lease ID or route generation changed as a result.
//
// The tunnel is normalized against the current packet tunnel ID; an attempt to
// switch to a different non-empty tunnel ID is rejected with an error and
// leaves the transport untouched. An existing TunnelID is kept, while PoolID
// and ServiceID prefer the new tunnel's values. On a route change the
// transition counter is incremented and the update time is set to now (UTC).
// A nil transport yields mesh.ErrForwardRuntimeUnavailable.
func (t *FabricSessionPacketTransport) UpdateServiceTunnel(tunnel FabricServiceTunnel) (bool, error) {
	if t == nil {
		return false, mesh.ErrForwardRuntimeUnavailable
	}

	existingID := t.packetTunnelID()
	next := NormalizeServiceTunnel(tunnel, existingID)
	bothSet := existingID != "" && next.TunnelID != ""
	if bothSet && next.TunnelID != existingID {
		return false, fmt.Errorf("service tunnel id changed from %q to %q", existingID, next.TunnelID)
	}

	t.routeMu.Lock()
	defer t.routeMu.Unlock()

	// Capture the previous route identity before overwriting ServiceTunnel.
	oldLeaseID := firstNonEmptyTunnelString(t.routeLeaseID, t.ServiceTunnel.RouteLeaseID)
	oldGeneration := firstNonEmptyTunnelString(t.routeGeneration, t.ServiceTunnel.RouteGeneration)

	t.ServiceTunnel = next
	t.TunnelID = firstNonEmptyTunnelString(t.TunnelID, next.TunnelID)
	t.PoolID = firstNonEmptyTunnelString(next.PoolID, t.PoolID)
	t.ServiceID = firstNonEmptyTunnelString(next.ServiceID, t.ServiceID)
	t.routeLeaseID = next.RouteLeaseID
	t.routeGeneration = next.RouteGeneration

	if oldLeaseID == next.RouteLeaseID && oldGeneration == next.RouteGeneration {
		return false, nil
	}
	t.routeTransitionCount++
	t.routeUpdatedAt = time.Now().UTC()
	return true, nil
}
// normalizeServiceTunnel reconciles ServiceTunnel with the transport's own
// identity fields. ServiceTunnel is normalized using the current packet tunnel
// ID as fallback; empty TunnelID, PoolID and ServiceID are then filled from
// it, and empty route lease ID / generation are seeded from it under routeMu.
// Values that are already set are never overwritten.
func (t *FabricSessionPacketTransport) normalizeServiceTunnel() {
	if t == nil {
		return
	}

	t.ServiceTunnel = NormalizeServiceTunnel(t.ServiceTunnel, t.packetTunnelID())
	t.TunnelID = firstNonEmptyTunnelString(t.TunnelID, t.ServiceTunnel.TunnelID)
	t.PoolID = firstNonEmptyTunnelString(t.PoolID, t.ServiceTunnel.PoolID)
	t.ServiceID = firstNonEmptyTunnelString(t.ServiceID, t.ServiceTunnel.ServiceID)

	t.routeMu.Lock()
	defer t.routeMu.Unlock()
	if t.routeLeaseID == "" {
		t.routeLeaseID = t.ServiceTunnel.RouteLeaseID
	}
	if t.routeGeneration == "" {
		t.routeGeneration = t.ServiceTunnel.RouteGeneration
	}
}
// packetTunnelID resolves the tunnel ID stamped on and matched against packet
// frames: ServiceTunnel.TunnelID, then TunnelID, then VPNConnectionID — the
// first non-empty one wins. A nil transport yields "".
func (t *FabricSessionPacketTransport) packetTunnelID() string {
	if t != nil {
		return firstNonEmptyTunnelString(t.ServiceTunnel.TunnelID, t.TunnelID, t.VPNConnectionID)
	}
	return ""
}
// registerServiceStream announces a stream to the service stream registry,
// tagged with the "vpn" adapter. It does nothing without a registry, on a nil
// transport, or for stream ID 0. The service tunnel is normalized first so the
// registered entry carries current identity fields.
func (t *FabricSessionPacketTransport) registerServiceStream(streamID uint64, trafficClass string, direction string) {
	if t == nil || t.ServiceStreams == nil || streamID == 0 {
		return
	}
	t.normalizeServiceTunnel()

	entry := FabricServiceStream{
		TunnelID:      t.packetTunnelID(),
		ServiceID:     t.ServiceID,
		StreamID:      streamID,
		TrafficClass:  trafficClass,
		Direction:     direction,
		ServiceTunnel: t.ServiceTunnel,
		Metadata:      map[string]string{"adapter": "vpn"},
	}
	t.ServiceStreams.Register(entry)
}
// markServiceStreamClosed tells the service stream registry that streamID of
// this transport's tunnel is closed. It does nothing without a registry, on a
// nil transport, or for stream ID 0.
func (t *FabricSessionPacketTransport) markServiceStreamClosed(streamID uint64) {
	if t == nil || streamID == 0 {
		return
	}
	registry := t.ServiceStreams
	if registry == nil {
		return
	}
	registry.MarkClosed(t.packetTunnelID(), streamID)
}
// recordCloseStream counts one close-stream frame that was sent successfully.
// It is a no-op on a nil transport.
func (t *FabricSessionPacketTransport) recordCloseStream() {
	if t != nil {
		t.statsMu.Lock()
		defer t.statsMu.Unlock()
		t.closeStreamFrames += 1
	}
}
// recordCloseError counts one error encountered while closing the transport.
// It is a no-op on a nil transport.
func (t *FabricSessionPacketTransport) recordCloseError() {
	if t != nil {
		t.statsMu.Lock()
		defer t.statsMu.Unlock()
		t.closeErrors += 1
	}
}
// fabricSessionTrafficClassForPackets resolves the traffic class for a packet
// batch: the normalized fallback when that is non-empty, otherwise the bulk
// class. The packets argument is currently not inspected.
func fabricSessionTrafficClassForPackets(fallback string, packets [][]byte) string {
	class := normalizeFabricTrafficClass(fallback)
	if class == "" {
		return FabricTrafficClassBulk
	}
	return class
}
// fabricSessionTrafficClassName maps a wire-level traffic class to the name
// used in this package's statistics. Control, interactive, reliable and
// droppable map to their namesakes; every other value is reported as bulk.
func fabricSessionTrafficClassName(value fabricproto.TrafficClass) string {
	if value == fabricproto.TrafficClassControl {
		return FabricTrafficClassControl
	}
	if value == fabricproto.TrafficClassInteractive {
		return FabricTrafficClassInteractive
	}
	if value == fabricproto.TrafficClassReliable {
		return FabricTrafficClassReliable
	}
	if value == fabricproto.TrafficClassDroppable {
		return FabricTrafficClassDroppable
	}
	return FabricTrafficClassBulk
}
// copyStringUint64Map returns an independent copy of values. The result is
// never nil: a nil or empty input yields an empty map.
func copyStringUint64Map(values map[string]uint64) map[string]uint64 {
	cloned := make(map[string]uint64, len(values))
	for name, count := range values {
		cloned[name] = count
	}
	return cloned
}
// copyStreamIDsByTrafficClass returns a deep copy of values: the map and every
// ID slice are duplicated. The result is never nil; an empty ID list is copied
// as a nil slice.
func copyStreamIDsByTrafficClass(values map[string][]uint64) map[string][]uint64 {
	cloned := make(map[string][]uint64, len(values))
	for class, ids := range values {
		var dup []uint64
		if len(ids) > 0 {
			dup = make([]uint64, len(ids))
			copy(dup, ids)
		}
		cloned[class] = dup
	}
	return cloned
}
// countStreamIDs returns the total number of stream IDs across all classes.
// Duplicates are counted each time they appear.
func countStreamIDs(values map[string][]uint64) int {
	var sum int
	for class := range values {
		sum += len(values[class])
	}
	return sum
}
// countNonZeroStringUint64Values returns how many entries of values hold a
// counter greater than zero.
func countNonZeroStringUint64Values(values map[string]uint64) int {
	var nonZero int
	for _, count := range values {
		if count == 0 {
			continue
		}
		nonZero++
	}
	return nonZero
}