package fabricvpn

import (
	"os"
	"testing"
	"time"

	"github.com/example/remote-access-platform/agents/rap-node-agent/internal/vpnruntime"
)

// TestFabricRuntimeEndpointsPreferRouteBundle checks that, when both the
// top-level Endpoints list and the route bundle's EndpointCandidates are set,
// fabricRuntimeEndpoints returns the route bundle candidate.
func TestFabricRuntimeEndpointsPreferRouteBundle(t *testing.T) {
	cfg := runtimeConfig{
		Endpoints: []endpointConfig{{EndpointID: "compat", Address: "quic://compat.example:19131"}},
		RouteBundle: routeBundleConfig{
			EndpointCandidates: []endpointConfig{{EndpointID: "bundle", Address: "quic://bundle.example:19131"}},
		},
	}

	got := fabricRuntimeEndpoints(cfg)
	if len(got) != 1 || got[0].EndpointID != "bundle" {
		t.Fatalf("endpoints = %+v, want route bundle endpoint", got)
	}
}

// TestFabricRuntimeEndpointsPreferRouteLease checks that the route lease's
// primary-path candidates win over both the bundle candidates and the
// top-level Endpoints list, and that fabricRuntimeTargetNodeID reports the
// lease's target node.
func TestFabricRuntimeEndpointsPreferRouteLease(t *testing.T) {
	cfg := runtimeConfig{
		Endpoints: []endpointConfig{{EndpointID: "compat", Address: "quic://compat.example:19131"}},
		RouteBundle: routeBundleConfig{
			EndpointCandidates: []endpointConfig{{EndpointID: "bundle", Address: "quic://bundle.example:19131"}},
			RouteLease: routeLeaseConfig{
				SelectedTargetNode: "exit-1",
				PrimaryPath: routeLeasePath{
					TargetNodeID:       "exit-1",
					EndpointCandidates: []endpointConfig{{EndpointID: "lease-primary", Address: "quic://lease.example:19131"}},
				},
			},
		},
	}

	got := fabricRuntimeEndpoints(cfg)
	if len(got) != 1 || got[0].EndpointID != "lease-primary" {
		t.Fatalf("endpoints = %+v, want route lease primary endpoint", got)
	}
	if target := fabricRuntimeTargetNodeID(cfg); target != "exit-1" {
		t.Fatalf("target = %q, want exit-1", target)
	}
}

// TestFabricRuntimePacketTargetIsLongLived checks that the packet target
// built by fabricRuntimePacketTarget has a zero Timeout and carries the
// endpoint's node ID, address and peer certificate hash.
func TestFabricRuntimePacketTargetIsLongLived(t *testing.T) {
	cfg := runtimeConfig{
		RouteBundle: routeBundleConfig{RouteLease: routeLeaseConfig{
			PrimaryPath: routeLeasePath{TargetNodeID: "exit-1"},
		}},
	}

	target := fabricRuntimePacketTarget(cfg, endpointConfig{
		EndpointID:     "exit-public",
		NodeID:         "exit-1",
		Address:        "quic://203.0.113.10:19131",
		Transport:      "direct_quic",
		PeerCertSHA256: "abc123",
	})
	if target.Timeout != 0 {
		t.Fatalf("packet target timeout = %s, want 0 for long-lived vpn stream", target.Timeout)
	}
	if target.PeerID != "exit-1" || target.Endpoint != "quic://203.0.113.10:19131" || target.PeerCertSHA256 != "abc123" {
		t.Fatalf("unexpected packet target: %+v", target)
	}
}

// TestServiceTunnelFromRuntimeConfigCarriesRouteEpoch checks that the route
// lease ID, route generation and stream shard count are copied from the
// runtime config into the service tunnel.
func TestServiceTunnelFromRuntimeConfigCarriesRouteEpoch(t *testing.T) {
	cfg := runtimeConfig{
		TunnelID:        "tunnel-1",
		PoolID:          "home-ipv4",
		ServiceID:       "svc-1",
		ServiceKind:     "ipv4-tunnel",
		ServiceClass:    "vpn_packets",
		RouteLeaseID:    "lease-1",
		RouteGeneration: "route-gen-1",
		StreamShards:    8,
	}

	tunnel := serviceTunnelFromRuntimeConfig(cfg)
	if tunnel.RouteLeaseID != "lease-1" || tunnel.RouteGeneration != "route-gen-1" || tunnel.StreamShards != 8 {
		t.Fatalf("service tunnel route epoch = %+v", tunnel)
	}
}

// TestManagerUpdateRuntimeConfigKeepsTunnelAndUpdatesRouteEpoch checks that
// UpdateRuntimeConfig with an unchanged tunnel ID updates the packet
// transport's route lease/generation (counting one route transition), and
// that a config naming a different tunnel ID is rejected.
func TestManagerUpdateRuntimeConfigKeepsTunnelAndUpdatesRouteEpoch(t *testing.T) {
	manager := NewManager()
	manager.cfg = runtimeConfig{
		ClusterID:       "cluster-1",
		LocalNodeID:     "android-1",
		TunnelID:        "tunnel-1",
		PoolID:          "home-ipv4",
		ServiceID:       "svc-1",
		ServiceKind:     "ipv4-tunnel",
		ServiceClass:    "vpn_packets",
		RouteLeaseID:    "lease-1",
		RouteGeneration: "route-gen-1",
		StreamShards:    4,
	}
	manager.packet = &vpnruntime.FabricSessionPacketTransport{
		TunnelID: "tunnel-1",
		ServiceTunnel: vpnruntime.FabricServiceTunnel{
			TunnelID:        "tunnel-1",
			PoolID:          "home-ipv4",
			ServiceID:       "svc-1",
			RouteLeaseID:    "lease-1",
			RouteGeneration: "route-gen-1",
		},
	}

	// Same tunnel, new lease and generation: expected to be applied in place.
	err := manager.UpdateRuntimeConfig(`{ "cluster_id":"cluster-1", "local_node_id":"android-1", "tunnel_id":"tunnel-1", "pool_id":"home-ipv4", "service_id":"svc-1", "service_kind":"ipv4-tunnel", "service_class":"vpn_packets", "route_lease_id":"lease-2", "route_generation":"route-gen-2", "stream_shards":4, "service_channel_request":{"schema_version":"rap.fabric_service_channel_request.v1"} }`)
	if err != nil {
		t.Fatalf("update runtime config: %v", err)
	}

	snapshot := manager.packet.Snapshot()
	if snapshot["route_lease_id"] != "lease-2" || snapshot["route_generation"] != "route-gen-2" || snapshot["route_transition_count"] != uint64(1) {
		t.Fatalf("packet route epoch not updated: %+v", snapshot)
	}

	// A different tunnel ID must not be accepted as an in-place update.
	if err := manager.UpdateRuntimeConfig(`{"tunnel_id":"other-tunnel"}`); err == nil {
		t.Fatal("expected changed tunnel id to be rejected")
	}
}

// TestRuntimeRouteReconnectDecisionTracksTargetAndEndpoints checks that
// shouldReconnectForRuntimeRoute reports false when only the lease ID and
// generation change, and true when the target node or the endpoint list
// changes.
func TestRuntimeRouteReconnectDecisionTracksTargetAndEndpoints(t *testing.T) {
	current := runtimeConfig{
		TunnelID:  "tunnel-1",
		Endpoints: []endpointConfig{{EndpointID: "exit-a", NodeID: "node-a", Address: "quic://node-a:19131", Transport: "direct_quic"}},
		RouteBundle: routeBundleConfig{RouteLease: routeLeaseConfig{
			PrimaryPath: routeLeasePath{TargetNodeID: "node-a"},
		}},
	}

	sameLeaseNewGeneration := current
	sameLeaseNewGeneration.RouteLeaseID = "lease-2"
	sameLeaseNewGeneration.RouteGeneration = "route-gen-2"
	if shouldReconnectForRuntimeRoute(current, sameLeaseNewGeneration) {
		t.Fatal("same target/endpoints should update route epoch without reconnect")
	}

	newTarget := current
	newTarget.RouteBundle.RouteLease.PrimaryPath.TargetNodeID = "node-b"
	if !shouldReconnectForRuntimeRoute(current, newTarget) {
		t.Fatal("changed target node should reconnect fabric session")
	}

	newEndpoint := current
	newEndpoint.Endpoints = []endpointConfig{{EndpointID: "exit-b", NodeID: "node-b", Address: "quic://node-b:19131", Transport: "direct_quic"}}
	if !shouldReconnectForRuntimeRoute(current, newEndpoint) {
		t.Fatal("changed endpoint candidates should reconnect fabric session")
	}
}

// TestPacketBatchSendTimeoutScalesWithPayload checks that
// packetBatchSendTimeout returns the minimum for a single small packet, grows
// for a large payload and for many packets, and is capped at the maximum.
func TestPacketBatchSendTimeoutScalesWithPayload(t *testing.T) {
	small := packetBatchSendTimeout([][]byte{make([]byte, 1200)})
	large := packetBatchSendTimeout([][]byte{make([]byte, 4*1024*1024)})
	if small != minPacketBatchSendTimeout {
		t.Fatalf("small timeout = %s, want %s", small, minPacketBatchSendTimeout)
	}
	if large <= small {
		t.Fatalf("large timeout = %s, want greater than %s", large, small)
	}

	many := make([][]byte, 2048)
	for i := range many {
		many[i] = make([]byte, 1200)
	}
	if got := packetBatchSendTimeout(many); got <= small {
		t.Fatalf("many-packet timeout = %s, want greater than %s", got, small)
	}

	if got := packetBatchSendTimeout([][]byte{make([]byte, 100*1024*1024)}); got != maxPacketBatchSendTimeout {
		t.Fatalf("capped timeout = %s, want %s", got, maxPacketBatchSendTimeout)
	}
}

// TestFabricRuntimeEndpointsFallbackToDisallowedEndpoints checks that, with no
// route bundle or route lease candidates configured, fabricRuntimeEndpoints
// falls back to the top-level ("compat") Endpoints list.
func TestFabricRuntimeEndpointsFallbackToDisallowedEndpoints(t *testing.T) {
	cfg := runtimeConfig{
		Endpoints: []endpointConfig{{EndpointID: "compat", Address: "quic://compat.example:19131"}},
	}

	got := fabricRuntimeEndpoints(cfg)
	if len(got) != 1 || got[0].EndpointID != "compat" {
		t.Fatalf("endpoints = %+v, want compat endpoint fallback", got)
	}
}

// TestLiveFabricVPNRuntimeStartsFromRouteLease is an opt-in live test. It is
// skipped unless RAP_LIVE_FABRICVPN_CONFIG holds a runtime config; with it set
// the manager is started and its snapshot must be non-empty. When
// RAP_LIVE_FABRICVPN_PACKET_PROBE is also set, a DNS query is sent through the
// runtime and the test waits for a UDP packet coming back from 1.1.1.1.
func TestLiveFabricVPNRuntimeStartsFromRouteLease(t *testing.T) {
	raw := os.Getenv("RAP_LIVE_FABRICVPN_CONFIG")
	if raw == "" {
		t.Skip("RAP_LIVE_FABRICVPN_CONFIG is not set")
	}

	manager := NewManager()
	if err := manager.Start(raw); err != nil {
		t.Fatalf("start live fabric vpn runtime: %v", err)
	}
	defer manager.Stop()

	if snapshot := manager.SnapshotJSON(); snapshot == "" {
		t.Fatal("empty live fabric vpn snapshot")
	}
	if os.Getenv("RAP_LIVE_FABRICVPN_PACKET_PROBE") == "" {
		return
	}

	if err := manager.SendPacket(testDNSIPv4Packet()); err != nil {
		t.Fatalf("send live dns packet: %v", err)
	}

	// Up to 20 receive attempts; 500 is the per-call argument to ReceivePacket
	// (presumably a timeout in milliseconds — confirm against its definition).
	for i := 0; i < 20; i++ {
		packet, err := manager.ReceivePacket(500)
		if err != nil {
			t.Fatalf("receive live dns packet: %v", err)
		}
		// Accept the first packet that is at least a bare IPv4 header long,
		// has protocol 17 (UDP) at header byte 9, and has source address
		// 1.1.1.1 at header bytes 12-15, i.e. the address the probe targeted.
		if len(packet) >= 20 && packet[9] == 17 &&
			packet[12] == 1 && packet[13] == 1 && packet[14] == 1 && packet[15] == 1 {
			return
		}
	}
	t.Fatal("timed out waiting for live dns response through fabric vpn")
}

// testDNSIPv4Packet builds a raw IPv4/UDP packet from 10.77.0.2 to 1.1.1.1:53
// carrying a DNS A/IN query for example.com. The DNS transaction ID and the
// UDP source port are derived from the current time so successive calls
// differ. The UDP checksum is left as zero.
func testDNSIPv4Packet() []byte {
	nonce := uint16(time.Now().UnixNano())

	// DNS message: 12-byte header (ID, flags 0x0100, QDCOUNT 1, remaining
	// counts 0) followed by one question: "example.com", QTYPE 1, QCLASS 1.
	dns := []byte{
		byte(nonce >> 8), byte(nonce),
		0x01, 0x00,
		0x00, 0x01,
		0x00, 0x00,
		0x00, 0x00,
		0x00, 0x00,
		0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e',
		0x03, 'c', 'o', 'm',
		0x00,
		0x00, 0x01,
		0x00, 0x01,
	}

	udpLen := 8 + len(dns)
	totalLen := 20 + udpLen
	packet := make([]byte, totalLen)

	// IPv4 header (20 bytes, no options).
	packet[0] = 0x45 // version 4, header length 5 words
	packet[2] = byte(totalLen >> 8)
	packet[3] = byte(totalLen)
	packet[8] = 64 // TTL
	packet[9] = 17 // protocol: UDP
	copy(packet[12:16], []byte{10, 77, 0, 2})
	copy(packet[16:20], []byte{1, 1, 1, 1})

	// UDP header: source port forced into 0xc000-0xffff, destination port 53.
	packet[20] = byte(0xc0 | ((nonce >> 8) & 0x3f))
	packet[21] = byte(nonce)
	packet[22] = 0x00
	packet[23] = 0x35
	packet[24] = byte(udpLen >> 8)
	packet[25] = byte(udpLen)
	// Bytes 26-27 (UDP checksum) stay zero.
	copy(packet[28:], dns)

	// The header checksum is computed last, over the finished header.
	sum := ipv4HeaderChecksum(packet[:20])
	packet[10] = byte(sum >> 8)
	packet[11] = byte(sum)
	return packet
}

// ipv4HeaderChecksum returns the one's-complement of the one's-complement sum
// of the big-endian 16-bit words in header, skipping the word at offset 10
// (the checksum field itself). A trailing odd byte, if any, is ignored.
func ipv4HeaderChecksum(header []byte) uint16 {
	var sum uint32
	for i := 0; i+1 < len(header); i += 2 {
		if i == 10 {
			continue
		}
		sum += uint32(header[i])<<8 | uint32(header[i+1])
	}
	// Fold carries back into the low 16 bits.
	for sum > 0xffff {
		sum = (sum & 0xffff) + (sum >> 16)
	}
	return ^uint16(sum)
}

// min returns the smaller of a and b.
//
// NOTE(review): this shadows the builtin min available since Go 1.21 and is
// not referenced in this file; confirm no other file in the package relies on
// it before removing.
func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}