Refactor RDP proxy handling and update related tests

This commit is contained in:
2026-05-17 20:38:35 +03:00
parent 8e9402580f
commit d551e57fd5
172 changed files with 22117 additions and 2509 deletions
@@ -0,0 +1,189 @@
package vpnruntime
import (
"context"
"sync"
"time"
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/fabricproto"
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/mesh"
)
type FabricSessionFrameWriter interface {
SendFrame(context.Context, fabricproto.Frame) error
}
type FabricSessionPacketPeerRegistry struct {
mu sync.RWMutex
peers map[string]FabricSessionPacketPeer
}
type FabricSessionPacketPeer struct {
VPNConnectionID string
Sender FabricSessionFrameWriter
StreamID uint64
StreamIDsByTrafficClass map[string][]uint64
RegisteredAt time.Time
LastPacketAt time.Time
}
type FabricSessionPacketPeerTransport struct {
Registry *FabricSessionPacketPeerRegistry
Inbox *FabricPacketInbox
VPNConnectionID string
}
func NewFabricSessionPacketPeerRegistry() *FabricSessionPacketPeerRegistry {
return &FabricSessionPacketPeerRegistry{peers: map[string]FabricSessionPacketPeer{}}
}
func (r *FabricSessionPacketPeerRegistry) RegisterFrame(ctx context.Context, sender FabricSessionFrameWriter, frame fabricproto.Frame) (bool, error) {
if r == nil || sender == nil || frame.Type != fabricproto.FrameData || frame.StreamID == 0 {
return false, nil
}
payload, err := DecodeFabricVPNPacketDataFrame(frame)
if err != nil {
return false, nil
}
if payload.VPNConnectionID == "" {
return false, nil
}
now := time.Now().UTC()
r.mu.Lock()
if r.peers == nil {
r.peers = map[string]FabricSessionPacketPeer{}
}
peer := r.peers[payload.VPNConnectionID]
if peer.RegisteredAt.IsZero() {
peer.RegisteredAt = now
}
peer.VPNConnectionID = payload.VPNConnectionID
peer.Sender = sender
peer.StreamID = frame.StreamID
peer.LastPacketAt = now
if peer.StreamIDsByTrafficClass == nil {
peer.StreamIDsByTrafficClass = map[string][]uint64{}
}
trafficClass := fabricSessionTrafficClassName(frame.TrafficClass)
if trafficClass != "" && !containsUint64(peer.StreamIDsByTrafficClass[trafficClass], frame.StreamID) {
peer.StreamIDsByTrafficClass[trafficClass] = append(peer.StreamIDsByTrafficClass[trafficClass], frame.StreamID)
}
r.peers[payload.VPNConnectionID] = peer
r.mu.Unlock()
return true, nil
}
func (r *FabricSessionPacketPeerRegistry) TransportFor(vpnConnectionID string, inbox *FabricPacketInbox) PacketTransport {
if r == nil || inbox == nil || vpnConnectionID == "" {
return nil
}
r.mu.RLock()
peer, ok := r.peers[vpnConnectionID]
r.mu.RUnlock()
if !ok || peer.Sender == nil || peer.StreamID == 0 {
return nil
}
return &FabricSessionPacketTransport{
Sender: fabricSessionFrameWriterAdapter{writer: peer.Sender},
Inbox: inbox,
StreamID: peer.StreamID,
StreamIDsByTrafficClass: copyStreamIDsByClass(peer.StreamIDsByTrafficClass),
VPNConnectionID: vpnConnectionID,
SendDirection: FabricDirectionGatewayToClient,
ReceiveDirection: FabricDirectionClientToGateway,
}
}
func (t *FabricSessionPacketPeerTransport) SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error {
if t == nil || t.Registry == nil || t.Inbox == nil || t.VPNConnectionID == "" {
return mesh.ErrForwardRuntimeUnavailable
}
transport := t.Registry.TransportFor(t.VPNConnectionID, t.Inbox)
if transport == nil {
return mesh.ErrForwardRuntimeUnavailable
}
return transport.SendGatewayPacketBatch(ctx, packets)
}
func (t *FabricSessionPacketPeerTransport) ReceiveGatewayPacketBatch(ctx context.Context, timeout time.Duration) ([][]byte, error) {
if t == nil || t.Inbox == nil || t.VPNConnectionID == "" {
return nil, mesh.ErrForwardRuntimeUnavailable
}
return t.Inbox.Receive(ctx, t.VPNConnectionID, FabricDirectionClientToGateway, timeout)
}
func (t *FabricSessionPacketPeerTransport) Snapshot() map[string]any {
if t == nil {
return map[string]any{
"transport": "fabric_session_peer_dynamic",
"peer_ready": false,
}
}
ready := 0
if t.Registry != nil {
if transport := t.Registry.TransportFor(t.VPNConnectionID, t.Inbox); transport != nil {
ready = 1
}
}
return map[string]any{
"transport": "fabric_session_peer_dynamic",
"vpn_connection_id": t.VPNConnectionID,
"peer_ready": ready == 1,
}
}
func (r *FabricSessionPacketPeerRegistry) Snapshot() map[string]any {
if r == nil {
return map[string]any{"ready": 0}
}
r.mu.RLock()
defer r.mu.RUnlock()
out := map[string]any{"ready": len(r.peers)}
items := make([]map[string]any, 0, len(r.peers))
for _, peer := range r.peers {
item := map[string]any{
"vpn_connection_id": peer.VPNConnectionID,
"stream_id": peer.StreamID,
}
if !peer.RegisteredAt.IsZero() {
item["registered_at"] = peer.RegisteredAt.Format(time.RFC3339Nano)
}
if !peer.LastPacketAt.IsZero() {
item["last_packet_at"] = peer.LastPacketAt.Format(time.RFC3339Nano)
}
items = append(items, item)
}
out["peers"] = items
return out
}
type fabricSessionFrameWriterAdapter struct {
writer FabricSessionFrameWriter
}
func (a fabricSessionFrameWriterAdapter) Send(ctx context.Context, frame fabricproto.Frame) error {
if a.writer == nil {
return mesh.ErrForwardRuntimeUnavailable
}
return a.writer.SendFrame(ctx, frame)
}
func containsUint64(values []uint64, value uint64) bool {
for _, item := range values {
if item == value {
return true
}
}
return false
}
func copyStreamIDsByClass(values map[string][]uint64) map[string][]uint64 {
if len(values) == 0 {
return nil
}
out := make(map[string][]uint64, len(values))
for key, ids := range values {
out[key] = append([]uint64(nil), ids...)
}
return out
}
@@ -130,11 +130,14 @@ func (t *FabricSessionPacketTransport) ReceiveGatewayPacketBatch(ctx context.Con
continue
}
if err != nil {
if packets, receiveErr := t.Inbox.Receive(ctx, t.VPNConnectionID, direction, 100*time.Millisecond); receiveErr != nil || len(packets) > 0 {
return packets, receiveErr
}
return nil, err
}
case frame, ok := <-frames:
if !ok {
return t.Inbox.Receive(ctx, t.VPNConnectionID, direction, 5*time.Millisecond)
return t.Inbox.Receive(ctx, t.VPNConnectionID, direction, 100*time.Millisecond)
}
if frame.Type != fabricproto.FrameData || !t.acceptsStream(frame.StreamID) {
continue
@@ -426,6 +426,59 @@ func TestFabricSessionPacketTransportRunFrameIngressDeliversInbox(t *testing.T)
}
}
func TestFabricSessionPacketPeerTransportSendsReplyToLatestRegisteredPeer(t *testing.T) {
inbox := NewFabricPacketInbox(4)
registry := NewFabricSessionPacketPeerRegistry()
sender := &recordingFrameSender{}
frame, err := NewFabricVPNPacketDataFrame(FabricVPNPacketFrameInput{
StreamID: 7,
VPNConnectionID: "vpn-1",
Direction: FabricDirectionClientToGateway,
Packets: [][]byte{[]byte("request")},
})
if err != nil {
t.Fatalf("frame: %v", err)
}
handled, err := registry.RegisterFrame(context.Background(), sender, frame)
if err != nil || !handled {
t.Fatalf("register frame handled=%v err=%v", handled, err)
}
if err := inbox.DeliverFabricSessionFrame(context.Background(), frame); err != nil {
t.Fatalf("deliver frame: %v", err)
}
transport := &FabricSessionPacketPeerTransport{
Registry: registry,
Inbox: inbox,
VPNConnectionID: "vpn-1",
}
requests, err := transport.ReceiveGatewayPacketBatch(context.Background(), time.Second)
if err != nil || len(requests) != 1 || string(requests[0]) != "request" {
t.Fatalf("requests=%q err=%v", requests, err)
}
if err := transport.SendGatewayPacketBatch(context.Background(), [][]byte{[]byte("reply")}); err != nil {
t.Fatalf("send reply: %v", err)
}
if len(sender.frames) != 1 {
t.Fatalf("sent frames = %d, want 1", len(sender.frames))
}
payload, err := DecodeFabricVPNPacketDataFrame(sender.frames[0])
if err != nil {
t.Fatalf("decode reply: %v", err)
}
if payload.Direction != FabricDirectionGatewayToClient || string(payload.Packets[0]) != "reply" {
t.Fatalf("reply payload = %+v", payload)
}
}
type recordingFrameSender struct {
frames []fabricproto.Frame
}
func (s *recordingFrameSender) SendFrame(_ context.Context, frame fabricproto.Frame) error {
s.frames = append(s.frames, frame)
return nil
}
func TestFabricSessionPacketTransportReceiveReadsPumpFrames(t *testing.T) {
inbox := NewFabricPacketInbox(4)
receiver := memoryFabricSessionReceiver{
@@ -169,6 +169,9 @@ func (g *Gateway) Snapshot() map[string]any {
out := map[string]any{
"running": running,
"service_role": "ipv4-egress",
"service_class": "vpn_packets",
"adapter_contract": "fabric_channel_to_ipv4_nat",
"transport": g.transportName(),
"poll_timeout_ms": g.PollTimeout.Milliseconds(),
"client_to_gateway_batches": g.clientToGatewayBatches.Load(),
@@ -234,14 +237,7 @@ func (g *Gateway) setStopped(err error) {
func (g *Gateway) normalize() error {
if g.Transport == nil {
if g.API == nil {
return fmt.Errorf("api client or packet transport is required")
}
g.Transport = BackendPacketTransport{
API: g.API,
ClusterID: g.ClusterID,
VPNConnectionID: g.VPNConnectionID,
}
return fmt.Errorf("fabric packet transport is required; backend packet relay fallback is disabled")
}
if g.ClusterID == "" || g.VPNConnectionID == "" {
return fmt.Errorf("cluster id and vpn connection id are required")
@@ -95,6 +95,30 @@ func TestGatewayRunClosesPacketTransportOnRuntimeError(t *testing.T) {
}
}
func TestGatewayNormalizeRejectsBackendPacketRelayFallback(t *testing.T) {
gateway := &Gateway{
API: nil,
ClusterID: "cluster-1",
VPNConnectionID: "vpn-1",
}
err := gateway.normalize()
if err == nil {
t.Fatal("normalize succeeded without a fabric packet transport")
}
if got, want := err.Error(), "fabric packet transport is required; backend packet relay fallback is disabled"; got != want {
t.Fatalf("normalize error = %q, want %q", got, want)
}
}
func TestGatewaySnapshotReportsIPv4EgressServiceAdapter(t *testing.T) {
gateway := &Gateway{Transport: &recordingGatewayTransport{}, VPNConnectionID: "vpn-1"}
snapshot := gateway.Snapshot()
if snapshot["service_role"] != "ipv4-egress" || snapshot["service_class"] != "vpn_packets" || snapshot["adapter_contract"] != "fabric_channel_to_ipv4_nat" {
t.Fatalf("unexpected gateway service snapshot: %#v", snapshot)
}
}
func TestGatewayUploadPrioritizesTCPControlPackets(t *testing.T) {
transport := &recordingGatewayTransport{}
gateway := &Gateway{Transport: transport, VPNConnectionID: "vpn-1"}