Refactor RDP proxy handling and update related tests
This commit is contained in:
@@ -0,0 +1,504 @@
|
||||
package fabricvpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/fabricproto"
|
||||
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/mesh"
|
||||
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/vpnruntime"
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
type endpointConfig struct {
|
||||
EndpointID string `json:"endpoint_id"`
|
||||
NodeID string `json:"node_id"`
|
||||
Transport string `json:"transport"`
|
||||
Address string `json:"address"`
|
||||
PeerCertSHA256 string `json:"peer_cert_sha256"`
|
||||
TLSCertSHA256 string `json:"tls_cert_sha256"`
|
||||
Priority int `json:"priority"`
|
||||
}
|
||||
|
||||
type runtimeConfig struct {
|
||||
ClusterID string `json:"cluster_id"`
|
||||
LocalNodeID string `json:"local_node_id"`
|
||||
ExitNodeID string `json:"exit_node_id"`
|
||||
VPNConnectionID string `json:"vpn_connection_id"`
|
||||
Endpoints []endpointConfig `json:"endpoints"`
|
||||
RouteBundle routeBundleConfig `json:"route_bundle"`
|
||||
ServiceChannelRequest serviceChannelRequest `json:"service_channel_request"`
|
||||
StreamShards int `json:"stream_shards"`
|
||||
}
|
||||
|
||||
type routeBundleConfig struct {
|
||||
SchemaVersion string `json:"schema_version"`
|
||||
RouteAuthority string `json:"route_authority"`
|
||||
SelectedTargetNode string `json:"selected_target_node_id"`
|
||||
EndpointCandidates []endpointConfig `json:"endpoint_candidates"`
|
||||
TargetCandidates []endpointConfig `json:"target_candidates"`
|
||||
RouteLease routeLeaseConfig `json:"route_lease"`
|
||||
}
|
||||
|
||||
type routeLeaseConfig struct {
|
||||
SchemaVersion string `json:"schema_version"`
|
||||
LeaseID string `json:"lease_id"`
|
||||
SelectedTargetNode string `json:"selected_target_node"`
|
||||
PrimaryPath routeLeasePath `json:"primary_path"`
|
||||
WarmStandbyPaths []routeLeasePath `json:"warm_standby_paths"`
|
||||
Multipath map[string]any `json:"multipath"`
|
||||
RebuildPolicy map[string]any `json:"rebuild_policy"`
|
||||
}
|
||||
|
||||
type routeLeasePath struct {
|
||||
PathID string `json:"path_id"`
|
||||
TargetNodeID string `json:"target_node_id"`
|
||||
Status string `json:"status"`
|
||||
EndpointCandidates []endpointConfig `json:"endpoint_candidates"`
|
||||
}
|
||||
|
||||
type serviceChannelRequest struct {
|
||||
SchemaVersion string `json:"schema_version"`
|
||||
ChannelID string `json:"channel_id"`
|
||||
ServiceClass string `json:"service_class"`
|
||||
SourceRole string `json:"source_role"`
|
||||
}
|
||||
|
||||
type SocketProtector interface {
|
||||
Protect(fd int64) bool
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
opMu sync.Mutex
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
transport *mesh.QUICFabricTransport
|
||||
session mesh.FabricTransportSession
|
||||
packet *vpnruntime.FabricSessionPacketTransport
|
||||
inbox *vpnruntime.FabricPacketInbox
|
||||
cfg runtimeConfig
|
||||
lastErr string
|
||||
endpoint string
|
||||
protector SocketProtector
|
||||
|
||||
uplinkPackets atomic.Uint64
|
||||
uplinkBytes atomic.Uint64
|
||||
downlinkPackets atomic.Uint64
|
||||
downlinkBytes atomic.Uint64
|
||||
}
|
||||
|
||||
func NewManager() *Manager {
|
||||
return &Manager{}
|
||||
}
|
||||
|
||||
func (m *Manager) SetSocketProtector(protector SocketProtector) {
|
||||
m.mu.Lock()
|
||||
m.protector = protector
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *Manager) Start(configJSON string) error {
|
||||
var cfg runtimeConfig
|
||||
if err := json.Unmarshal([]byte(configJSON), &cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.ClusterID = strings.TrimSpace(cfg.ClusterID)
|
||||
cfg.LocalNodeID = strings.TrimSpace(cfg.LocalNodeID)
|
||||
cfg.ExitNodeID = strings.TrimSpace(cfg.ExitNodeID)
|
||||
cfg.VPNConnectionID = strings.TrimSpace(cfg.VPNConnectionID)
|
||||
cfg.Endpoints = fabricRuntimeEndpoints(cfg)
|
||||
cfg.ExitNodeID = firstNonEmpty(cfg.ExitNodeID, fabricRuntimeTargetNodeID(cfg))
|
||||
if cfg.ClusterID == "" || cfg.LocalNodeID == "" || cfg.VPNConnectionID == "" {
|
||||
return fmt.Errorf("cluster, local node and vpn connection id are required")
|
||||
}
|
||||
if strings.TrimSpace(cfg.ServiceChannelRequest.SchemaVersion) == "" {
|
||||
return fmt.Errorf("fabric service channel request is required")
|
||||
}
|
||||
if len(cfg.Endpoints) == 0 {
|
||||
return fmt.Errorf("fabric route lease has no QUIC candidates")
|
||||
}
|
||||
if cfg.StreamShards <= 0 {
|
||||
cfg.StreamShards = 4
|
||||
}
|
||||
if cfg.StreamShards > 32 {
|
||||
cfg.StreamShards = 32
|
||||
}
|
||||
|
||||
m.Stop()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
if err := m.connect(ctx, cfg, cancel); err != nil {
|
||||
cancel()
|
||||
m.setErr(err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func fabricRuntimeEndpoints(cfg runtimeConfig) []endpointConfig {
|
||||
if len(cfg.RouteBundle.RouteLease.PrimaryPath.EndpointCandidates) > 0 {
|
||||
return cfg.RouteBundle.RouteLease.PrimaryPath.EndpointCandidates
|
||||
}
|
||||
for _, path := range cfg.RouteBundle.RouteLease.WarmStandbyPaths {
|
||||
if len(path.EndpointCandidates) > 0 {
|
||||
return path.EndpointCandidates
|
||||
}
|
||||
}
|
||||
if len(cfg.RouteBundle.EndpointCandidates) > 0 {
|
||||
return cfg.RouteBundle.EndpointCandidates
|
||||
}
|
||||
if len(cfg.RouteBundle.TargetCandidates) > 0 {
|
||||
return cfg.RouteBundle.TargetCandidates
|
||||
}
|
||||
return cfg.Endpoints
|
||||
}
|
||||
|
||||
func fabricRuntimeTargetNodeID(cfg runtimeConfig) string {
|
||||
if cfg.RouteBundle.RouteLease.PrimaryPath.TargetNodeID != "" {
|
||||
return cfg.RouteBundle.RouteLease.PrimaryPath.TargetNodeID
|
||||
}
|
||||
if cfg.RouteBundle.RouteLease.SelectedTargetNode != "" {
|
||||
return cfg.RouteBundle.RouteLease.SelectedTargetNode
|
||||
}
|
||||
return cfg.RouteBundle.SelectedTargetNode
|
||||
}
|
||||
|
||||
func (m *Manager) connect(ctx context.Context, cfg runtimeConfig, cancel context.CancelFunc) error {
|
||||
quicTransport := mesh.NewQUICFabricTransport(nil)
|
||||
quicTransport.SetLocalPeerID(cfg.LocalNodeID)
|
||||
quicTransport.DialAddr = m.protectedQUICDialer()
|
||||
inbox := vpnruntime.NewFabricPacketInbox(4096)
|
||||
quicTransport.SetInboundHandlers(func(ctx context.Context, envelope mesh.ProductionEnvelope) (mesh.ProductionForwardResult, error) {
|
||||
if err := inbox.DeliverProductionEnvelope(ctx, envelope); err != nil {
|
||||
return mesh.ProductionForwardResult{}, err
|
||||
}
|
||||
return mesh.ProductionForwardResult{Delivered: true, MessageID: envelope.MessageID}, nil
|
||||
}, nil, nil)
|
||||
|
||||
var lastErr error
|
||||
for _, endpoint := range cfg.Endpoints {
|
||||
target := mesh.FabricTransportTarget{
|
||||
EndpointID: firstNonEmpty(endpoint.EndpointID, endpoint.Address),
|
||||
PeerID: firstNonEmpty(endpoint.NodeID, cfg.ExitNodeID),
|
||||
Endpoint: endpoint.Address,
|
||||
Transport: firstNonEmpty(endpoint.Transport, "direct_quic"),
|
||||
PeerCertSHA256: firstNonEmpty(endpoint.PeerCertSHA256, endpoint.TLSCertSHA256),
|
||||
Timeout: 5 * time.Second,
|
||||
OutboundBuffer: 512,
|
||||
InboundBuffer: 512,
|
||||
ErrorBuffer: 32,
|
||||
}
|
||||
carrier, selected, err := mesh.FabricTransportForTarget(target, quicTransport)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
session, err := carrier.Connect(dialCtx, selected)
|
||||
if err != nil {
|
||||
dialCancel()
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
streamIDs, streamID, err := openStreams(dialCtx, session, cfg.StreamShards)
|
||||
dialCancel()
|
||||
if err != nil {
|
||||
_ = session.Close()
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.cancel = cancel
|
||||
m.transport = quicTransport
|
||||
m.session = session
|
||||
m.inbox = inbox
|
||||
m.cfg = cfg
|
||||
m.endpoint = endpoint.Address
|
||||
m.lastErr = ""
|
||||
m.packet = &vpnruntime.FabricSessionPacketTransport{
|
||||
Sender: session,
|
||||
Receiver: session,
|
||||
Inbox: inbox,
|
||||
StreamID: streamID,
|
||||
StreamIDsByTrafficClass: streamIDs,
|
||||
VPNConnectionID: cfg.VPNConnectionID,
|
||||
SendDirection: vpnruntime.FabricDirectionClientToGateway,
|
||||
ReceiveDirection: vpnruntime.FabricDirectionGatewayToClient,
|
||||
}
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
if lastErr == nil {
|
||||
lastErr = fmt.Errorf("no QUIC exit endpoints available")
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func (m *Manager) protectedQUICDialer() func(context.Context, string, *tls.Config, *quic.Config) (*quic.Conn, error) {
|
||||
m.mu.Lock()
|
||||
protector := m.protector
|
||||
m.mu.Unlock()
|
||||
if protector == nil {
|
||||
return nil
|
||||
}
|
||||
return func(ctx context.Context, endpoint string, tlsConfig *tls.Config, config *quic.Config) (*quic.Conn, error) {
|
||||
network := "udp4"
|
||||
if strings.Contains(endpoint, "[") {
|
||||
network = "udp6"
|
||||
}
|
||||
conn, err := net.ListenPacket(network, ":0")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raw, ok := conn.(interface {
|
||||
SyscallConn() (syscall.RawConn, error)
|
||||
})
|
||||
if !ok {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("udp socket does not expose raw connection for vpn protection")
|
||||
}
|
||||
rawConn, err := raw.SyscallConn()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
var protectErr error
|
||||
if err := rawConn.Control(func(fd uintptr) {
|
||||
if !protector.Protect(int64(fd)) {
|
||||
protectErr = fmt.Errorf("android vpn socket protect failed")
|
||||
}
|
||||
}); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
if protectErr != nil {
|
||||
_ = conn.Close()
|
||||
return nil, protectErr
|
||||
}
|
||||
return mesh.DialQUICAddrWithPacketConn(ctx, endpoint, conn, tlsConfig, config)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Stop() {
|
||||
m.opMu.Lock()
|
||||
defer m.opMu.Unlock()
|
||||
m.stopLocked()
|
||||
}
|
||||
|
||||
func (m *Manager) stopLocked() {
|
||||
m.mu.Lock()
|
||||
cancel := m.cancel
|
||||
session := m.session
|
||||
transport := m.transport
|
||||
m.cancel = nil
|
||||
m.session = nil
|
||||
m.transport = nil
|
||||
m.packet = nil
|
||||
m.mu.Unlock()
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
if session != nil {
|
||||
_ = session.Close()
|
||||
}
|
||||
if transport != nil {
|
||||
_ = transport.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) SendPacket(packet []byte) error {
|
||||
if len(packet) == 0 {
|
||||
return nil
|
||||
}
|
||||
m.opMu.Lock()
|
||||
defer m.opMu.Unlock()
|
||||
if err := m.ensureConnectedLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
m.mu.Lock()
|
||||
transport := m.packet
|
||||
m.mu.Unlock()
|
||||
if transport == nil {
|
||||
return fmt.Errorf("fabric vpn runtime is not connected")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := transport.SendGatewayPacketBatch(ctx, [][]byte{append([]byte(nil), packet...)}); err != nil {
|
||||
m.setErr(err)
|
||||
if reconnectErr := m.reconnectLocked(); reconnectErr != nil {
|
||||
return err
|
||||
}
|
||||
m.mu.Lock()
|
||||
transport = m.packet
|
||||
m.mu.Unlock()
|
||||
if transport == nil {
|
||||
return err
|
||||
}
|
||||
retryCtx, retryCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer retryCancel()
|
||||
if retryErr := transport.SendGatewayPacketBatch(retryCtx, [][]byte{append([]byte(nil), packet...)}); retryErr != nil {
|
||||
m.setErr(retryErr)
|
||||
return retryErr
|
||||
}
|
||||
}
|
||||
m.uplinkPackets.Add(1)
|
||||
m.uplinkBytes.Add(uint64(len(packet)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) ReceivePacket(timeoutMillis int) ([]byte, error) {
|
||||
m.opMu.Lock()
|
||||
defer m.opMu.Unlock()
|
||||
if err := m.ensureConnectedLocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.mu.Lock()
|
||||
transport := m.packet
|
||||
m.mu.Unlock()
|
||||
if transport == nil {
|
||||
return nil, fmt.Errorf("fabric vpn runtime is not connected")
|
||||
}
|
||||
timeout := time.Duration(timeoutMillis) * time.Millisecond
|
||||
if timeout <= 0 {
|
||||
timeout = 100 * time.Millisecond
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout+time.Second)
|
||||
defer cancel()
|
||||
packets, err := transport.ReceiveGatewayPacketBatch(ctx, timeout)
|
||||
if err != nil {
|
||||
m.setErr(err)
|
||||
_ = m.reconnectLocked()
|
||||
return nil, err
|
||||
}
|
||||
if len(packets) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
packet := append([]byte(nil), packets[0]...)
|
||||
m.downlinkPackets.Add(1)
|
||||
m.downlinkBytes.Add(uint64(len(packet)))
|
||||
return packet, nil
|
||||
}
|
||||
|
||||
func (m *Manager) Reconnect() error {
|
||||
m.opMu.Lock()
|
||||
defer m.opMu.Unlock()
|
||||
return m.reconnectLocked()
|
||||
}
|
||||
|
||||
func (m *Manager) ensureConnectedLocked() error {
|
||||
m.mu.Lock()
|
||||
connected := m.packet != nil
|
||||
cancel := m.cancel
|
||||
m.mu.Unlock()
|
||||
if connected {
|
||||
return nil
|
||||
}
|
||||
if cancel == nil {
|
||||
return fmt.Errorf("fabric vpn runtime is stopped")
|
||||
}
|
||||
return m.reconnectLocked()
|
||||
}
|
||||
|
||||
func (m *Manager) reconnectLocked() error {
|
||||
m.mu.Lock()
|
||||
cfg := m.cfg
|
||||
oldSession := m.session
|
||||
oldTransport := m.transport
|
||||
cancel := m.cancel
|
||||
m.session = nil
|
||||
m.transport = nil
|
||||
m.packet = nil
|
||||
m.mu.Unlock()
|
||||
if oldSession != nil {
|
||||
_ = oldSession.Close()
|
||||
}
|
||||
if oldTransport != nil {
|
||||
_ = oldTransport.Close()
|
||||
}
|
||||
if cancel == nil {
|
||||
return fmt.Errorf("fabric vpn runtime is stopped")
|
||||
}
|
||||
ctx, ctxCancel := context.WithTimeout(context.Background(), 8*time.Second)
|
||||
defer ctxCancel()
|
||||
if err := m.connect(ctx, cfg, cancel); err != nil {
|
||||
m.setErr(err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) SnapshotJSON() string {
|
||||
m.mu.Lock()
|
||||
connected := m.packet != nil
|
||||
endpoint := m.endpoint
|
||||
lastErr := m.lastErr
|
||||
vpnConnectionID := m.cfg.VPNConnectionID
|
||||
localNodeID := m.cfg.LocalNodeID
|
||||
exitNodeID := m.cfg.ExitNodeID
|
||||
m.mu.Unlock()
|
||||
payload, _ := json.Marshal(map[string]any{
|
||||
"schema_version": "rap.android_fabric_vpn_runtime.v1",
|
||||
"connected": connected,
|
||||
"endpoint": endpoint,
|
||||
"last_error": lastErr,
|
||||
"vpn_connection": vpnConnectionID,
|
||||
"local_node_id": localNodeID,
|
||||
"exit_node_id": exitNodeID,
|
||||
"uplink_packets": m.uplinkPackets.Load(),
|
||||
"uplink_bytes": m.uplinkBytes.Load(),
|
||||
"downlink_packets": m.downlinkPackets.Load(),
|
||||
"downlink_bytes": m.downlinkBytes.Load(),
|
||||
})
|
||||
return string(payload)
|
||||
}
|
||||
|
||||
func (m *Manager) setErr(err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.lastErr = err.Error()
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func openStreams(ctx context.Context, session mesh.FabricTransportSession, shards int) (map[string][]uint64, uint64, error) {
|
||||
base := uint64(time.Now().UnixNano())
|
||||
classes := []struct {
|
||||
name string
|
||||
trafficClass fabricproto.TrafficClass
|
||||
}{
|
||||
{name: vpnruntime.FabricTrafficClassInteractive, trafficClass: fabricproto.TrafficClassInteractive},
|
||||
{name: vpnruntime.FabricTrafficClassBulk, trafficClass: fabricproto.TrafficClassBulk},
|
||||
}
|
||||
out := make(map[string][]uint64, len(classes))
|
||||
var primary uint64
|
||||
for classIndex, class := range classes {
|
||||
for shard := 0; shard < shards; shard++ {
|
||||
streamID := base + uint64(classIndex*shards+shard)
|
||||
if err := session.Send(ctx, fabricproto.Frame{Type: fabricproto.FrameOpenStream, StreamID: streamID, TrafficClass: class.trafficClass}); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if primary == 0 {
|
||||
primary = streamID
|
||||
}
|
||||
out[class.name] = append(out[class.name], streamID)
|
||||
}
|
||||
}
|
||||
return out, primary, nil
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
package fabricvpn
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFabricRuntimeEndpointsPreferRouteBundle(t *testing.T) {
|
||||
cfg := runtimeConfig{
|
||||
Endpoints: []endpointConfig{{EndpointID: "legacy", Address: "quic://legacy.example:19131"}},
|
||||
RouteBundle: routeBundleConfig{
|
||||
EndpointCandidates: []endpointConfig{{EndpointID: "bundle", Address: "quic://bundle.example:19131"}},
|
||||
},
|
||||
}
|
||||
got := fabricRuntimeEndpoints(cfg)
|
||||
if len(got) != 1 || got[0].EndpointID != "bundle" {
|
||||
t.Fatalf("endpoints = %+v, want route bundle endpoint", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFabricRuntimeEndpointsPreferRouteLease(t *testing.T) {
|
||||
cfg := runtimeConfig{
|
||||
Endpoints: []endpointConfig{{EndpointID: "legacy", Address: "quic://legacy.example:19131"}},
|
||||
RouteBundle: routeBundleConfig{
|
||||
EndpointCandidates: []endpointConfig{{EndpointID: "bundle", Address: "quic://bundle.example:19131"}},
|
||||
RouteLease: routeLeaseConfig{
|
||||
SelectedTargetNode: "exit-1",
|
||||
PrimaryPath: routeLeasePath{
|
||||
TargetNodeID: "exit-1",
|
||||
EndpointCandidates: []endpointConfig{{EndpointID: "lease-primary", Address: "quic://lease.example:19131"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
got := fabricRuntimeEndpoints(cfg)
|
||||
if len(got) != 1 || got[0].EndpointID != "lease-primary" {
|
||||
t.Fatalf("endpoints = %+v, want route lease primary endpoint", got)
|
||||
}
|
||||
if target := fabricRuntimeTargetNodeID(cfg); target != "exit-1" {
|
||||
t.Fatalf("target = %q, want exit-1", target)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFabricRuntimeEndpointsFallbackToLegacyEndpoints(t *testing.T) {
|
||||
cfg := runtimeConfig{
|
||||
Endpoints: []endpointConfig{{EndpointID: "legacy", Address: "quic://legacy.example:19131"}},
|
||||
}
|
||||
got := fabricRuntimeEndpoints(cfg)
|
||||
if len(got) != 1 || got[0].EndpointID != "legacy" {
|
||||
t.Fatalf("endpoints = %+v, want legacy endpoint fallback", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLiveFabricVPNRuntimeStartsFromRouteLease(t *testing.T) {
|
||||
raw := os.Getenv("RAP_LIVE_FABRICVPN_CONFIG")
|
||||
if raw == "" {
|
||||
t.Skip("RAP_LIVE_FABRICVPN_CONFIG is not set")
|
||||
}
|
||||
manager := NewManager()
|
||||
if err := manager.Start(raw); err != nil {
|
||||
t.Fatalf("start live fabric vpn runtime: %v", err)
|
||||
}
|
||||
defer manager.Stop()
|
||||
if snapshot := manager.SnapshotJSON(); snapshot == "" {
|
||||
t.Fatal("empty live fabric vpn snapshot")
|
||||
}
|
||||
if os.Getenv("RAP_LIVE_FABRICVPN_PACKET_PROBE") == "" {
|
||||
return
|
||||
}
|
||||
if err := manager.SendPacket(testDNSIPv4Packet()); err != nil {
|
||||
t.Fatalf("send live dns packet: %v", err)
|
||||
}
|
||||
for i := 0; i < 20; i++ {
|
||||
packet, err := manager.ReceivePacket(500)
|
||||
if err != nil {
|
||||
t.Fatalf("receive live dns packet: %v", err)
|
||||
}
|
||||
if len(packet) > 0 {
|
||||
if packet[9] != 17 || packet[12] != 1 || packet[13] != 1 || packet[14] != 1 || packet[15] != 1 {
|
||||
t.Fatalf("unexpected response packet header: %v", packet[:min(20, len(packet))])
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatal("timed out waiting for live dns response through fabric vpn")
|
||||
}
|
||||
|
||||
func testDNSIPv4Packet() []byte {
|
||||
dns := []byte{
|
||||
0x12, 0x34, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x07, 'e', 'x', 'a',
|
||||
'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00,
|
||||
0x00, 0x01, 0x00, 0x01,
|
||||
}
|
||||
udpLen := 8 + len(dns)
|
||||
totalLen := 20 + udpLen
|
||||
packet := make([]byte, totalLen)
|
||||
packet[0] = 0x45
|
||||
packet[2] = byte(totalLen >> 8)
|
||||
packet[3] = byte(totalLen)
|
||||
packet[8] = 64
|
||||
packet[9] = 17
|
||||
copy(packet[12:16], []byte{10, 77, 0, 2})
|
||||
copy(packet[16:20], []byte{1, 1, 1, 1})
|
||||
packet[20] = 0xcf
|
||||
packet[21] = 0x08
|
||||
packet[22] = 0x00
|
||||
packet[23] = 0x35
|
||||
packet[24] = byte(udpLen >> 8)
|
||||
packet[25] = byte(udpLen)
|
||||
copy(packet[28:], dns)
|
||||
sum := ipv4HeaderChecksum(packet[:20])
|
||||
packet[10] = byte(sum >> 8)
|
||||
packet[11] = byte(sum)
|
||||
return packet
|
||||
}
|
||||
|
||||
func ipv4HeaderChecksum(header []byte) uint16 {
|
||||
var sum uint32
|
||||
for i := 0; i+1 < len(header); i += 2 {
|
||||
if i == 10 {
|
||||
continue
|
||||
}
|
||||
sum += uint32(header[i])<<8 | uint32(header[i+1])
|
||||
}
|
||||
for sum > 0xffff {
|
||||
sum = (sum & 0xffff) + (sum >> 16)
|
||||
}
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
Reference in New Issue
Block a user