package fabricvpn

import (
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"strings"
	"sync"
	"sync/atomic"
	"syscall"
	"time"

	"github.com/example/remote-access-platform/agents/rap-node-agent/internal/fabricproto"
	"github.com/example/remote-access-platform/agents/rap-node-agent/internal/mesh"
	"github.com/example/remote-access-platform/agents/rap-node-agent/internal/vpnruntime"
	"github.com/quic-go/quic-go"
)

// endpointConfig describes one dialable fabric endpoint as delivered in the
// runtime configuration JSON. Priority is decoded but not consulted in this
// file; connect tries endpoints in slice order.
type endpointConfig struct {
	EndpointID     string `json:"endpoint_id"`
	NodeID         string `json:"node_id"`
	Transport      string `json:"transport"`
	Address        string `json:"address"`
	PeerCertSHA256 string `json:"peer_cert_sha256"`
	TLSCertSHA256  string `json:"tls_cert_sha256"`
	Priority       int    `json:"priority"`
}

// runtimeConfig is the JSON document accepted by Manager.Start.
type runtimeConfig struct {
	ClusterID             string                `json:"cluster_id"`
	LocalNodeID           string                `json:"local_node_id"`
	ExitNodeID            string                `json:"exit_node_id"`
	VPNConnectionID       string                `json:"vpn_connection_id"`
	Endpoints             []endpointConfig      `json:"endpoints"`
	RouteBundle           routeBundleConfig     `json:"route_bundle"`
	ServiceChannelRequest serviceChannelRequest `json:"service_channel_request"`
	StreamShards          int                   `json:"stream_shards"`
}

// controlForwardResponse is the payload of the reply frame read by
// ControlRequest: either a raw JSON payload or an error string.
type controlForwardResponse struct {
	Payload json.RawMessage `json:"payload,omitempty"`
	Error   string          `json:"error,omitempty"`
}

// routeBundleConfig carries the route selection from which endpoints and the
// exit node are resolved (see fabricRuntimeEndpoints and
// fabricRuntimeTargetNodeID).
type routeBundleConfig struct {
	SchemaVersion      string           `json:"schema_version"`
	RouteAuthority     string           `json:"route_authority"`
	SelectedTargetNode string           `json:"selected_target_node_id"`
	EndpointCandidates []endpointConfig `json:"endpoint_candidates"`
	TargetCandidates   []endpointConfig `json:"target_candidates"`
	RouteLease         routeLeaseConfig `json:"route_lease"`
}

// routeLeaseConfig is the lease section of a route bundle: a primary path plus
// ordered warm-standby paths.
type routeLeaseConfig struct {
	SchemaVersion      string           `json:"schema_version"`
	LeaseID            string           `json:"lease_id"`
	SelectedTargetNode string           `json:"selected_target_node"`
	PrimaryPath        routeLeasePath   `json:"primary_path"`
	WarmStandbyPaths   []routeLeasePath `json:"warm_standby_paths"`
	Multipath          map[string]any   `json:"multipath"`
	RebuildPolicy      map[string]any   `json:"rebuild_policy"`
}

// routeLeasePath is one candidate path of a route lease.
type routeLeasePath struct {
	PathID             string           `json:"path_id"`
	TargetNodeID       string           `json:"target_node_id"`
	Status             string           `json:"status"`
	EndpointCandidates []endpointConfig `json:"endpoint_candidates"`
}

// serviceChannelRequest must be present (non-empty SchemaVersion) for Start to
// accept a configuration.
type serviceChannelRequest struct {
	SchemaVersion string `json:"schema_version"`
	ChannelID     string `json:"channel_id"`
	ServiceClass  string `json:"service_class"`
	SourceRole    string `json:"source_role"`
}

// SocketProtector is consulted for every UDP socket the runtime dials from.
// Protect receives the socket's file descriptor; returning false aborts the
// dial. (The resulting error text refers to Android VPN socket protection.)
type SocketProtector interface {
	Protect(fd int64) bool
}

// Manager owns one QUIC fabric session toward an exit node and exposes
// packet-level send/receive plus a control request channel over it.
//
// Locking: opMu serializes the public operations (Start, Stop, SendPacket,
// ReceivePacket, ControlRequest, Reconnect); methods suffixed "Locked" require
// it to be held. mu guards the connection fields and is only held briefly.
type Manager struct {
	opMu sync.Mutex
	mu   sync.Mutex

	cancel         context.CancelFunc
	transport      *mesh.QUICFabricTransport
	session        mesh.FabricTransportSession
	packet         *vpnruntime.FabricSessionPacketTransport
	inbox          *vpnruntime.FabricPacketInbox
	cfg            runtimeConfig
	lastErr        string
	endpoint       string
	activeEndpoint endpointConfig // endpoint the current session was dialed with
	protector      SocketProtector

	// pendingDownlink holds packets that arrived in a received batch but have
	// not been handed out by ReceivePacket yet. Guarded by opMu.
	pendingDownlink [][]byte

	uplinkPackets   atomic.Uint64
	uplinkBytes     atomic.Uint64
	downlinkPackets atomic.Uint64
	downlinkBytes   atomic.Uint64
}

// NewManager returns an idle Manager; call Start to connect.
func NewManager() *Manager {
	return &Manager{}
}

// SetSocketProtector installs the protector used for sockets dialed by later
// (re)connects; sessions that are already open are not affected.
func (m *Manager) SetSocketProtector(protector SocketProtector) {
	m.mu.Lock()
	m.protector = protector
	m.mu.Unlock()
}

// Start parses configJSON, tears down any existing session and connects to the
// first reachable endpoint. Stop and connect run under opMu so concurrent
// Start calls cannot both install (and leak) a session.
func (m *Manager) Start(configJSON string) error {
	cfg, err := decodeRuntimeConfig(configJSON)
	if err != nil {
		return err
	}

	m.opMu.Lock()
	defer m.opMu.Unlock()
	m.stopLocked()

	ctx, cancel := context.WithCancel(context.Background())
	if err := m.connect(ctx, cfg, cancel); err != nil {
		cancel()
		m.setErr(err)
		return err
	}
	return nil
}

// decodeRuntimeConfig unmarshals and normalizes a Start configuration: it
// trims identifiers, resolves endpoints and the exit node from the route
// bundle, validates required fields, and maps StreamShards <= 0 to 4 and
// values above 32 to 32.
func decodeRuntimeConfig(configJSON string) (runtimeConfig, error) {
	var cfg runtimeConfig
	if err := json.Unmarshal([]byte(configJSON), &cfg); err != nil {
		return runtimeConfig{}, err
	}
	cfg.ClusterID = strings.TrimSpace(cfg.ClusterID)
	cfg.LocalNodeID = strings.TrimSpace(cfg.LocalNodeID)
	cfg.ExitNodeID = strings.TrimSpace(cfg.ExitNodeID)
	cfg.VPNConnectionID = strings.TrimSpace(cfg.VPNConnectionID)
	cfg.Endpoints = fabricRuntimeEndpoints(cfg)
	cfg.ExitNodeID = firstNonEmpty(cfg.ExitNodeID, fabricRuntimeTargetNodeID(cfg))

	if cfg.ClusterID == "" || cfg.LocalNodeID == "" || cfg.VPNConnectionID == "" {
		return runtimeConfig{}, fmt.Errorf("cluster, local node and vpn connection id are required")
	}
	if strings.TrimSpace(cfg.ServiceChannelRequest.SchemaVersion) == "" {
		return runtimeConfig{}, fmt.Errorf("fabric service channel request is required")
	}
	if len(cfg.Endpoints) == 0 {
		return runtimeConfig{}, fmt.Errorf("fabric route lease has no QUIC candidates")
	}

	switch {
	case cfg.StreamShards <= 0:
		cfg.StreamShards = 4
	case cfg.StreamShards > 32:
		cfg.StreamShards = 32
	}
	return cfg, nil
}

// fabricRuntimeEndpoints picks the endpoint list to dial, preferring (in
// order) the lease primary path, the first warm-standby path with candidates,
// the bundle's endpoint candidates, its target candidates, and finally the
// top-level Endpoints field.
func fabricRuntimeEndpoints(cfg runtimeConfig) []endpointConfig {
	if len(cfg.RouteBundle.RouteLease.PrimaryPath.EndpointCandidates) > 0 {
		return cfg.RouteBundle.RouteLease.PrimaryPath.EndpointCandidates
	}
	for _, path := range cfg.RouteBundle.RouteLease.WarmStandbyPaths {
		if len(path.EndpointCandidates) > 0 {
			return path.EndpointCandidates
		}
	}
	if len(cfg.RouteBundle.EndpointCandidates) > 0 {
		return cfg.RouteBundle.EndpointCandidates
	}
	if len(cfg.RouteBundle.TargetCandidates) > 0 {
		return cfg.RouteBundle.TargetCandidates
	}
	return cfg.Endpoints
}

// fabricRuntimeTargetNodeID returns the target node named by the route
// bundle, preferring the lease primary path, then the lease selection, then
// the bundle-level selection.
func fabricRuntimeTargetNodeID(cfg runtimeConfig) string {
	if cfg.RouteBundle.RouteLease.PrimaryPath.TargetNodeID != "" {
		return cfg.RouteBundle.RouteLease.PrimaryPath.TargetNodeID
	}
	if cfg.RouteBundle.RouteLease.SelectedTargetNode != "" {
		return cfg.RouteBundle.RouteLease.SelectedTargetNode
	}
	return cfg.RouteBundle.SelectedTargetNode
}

// fabricTargetForEndpoint builds the mesh dial target for endpoint. The peer ID
// falls back to the configured exit node, the transport to "direct_quic", and
// the pinned certificate to TLSCertSHA256. streamBuffer sizes both the
// outbound and inbound buffers.
func fabricTargetForEndpoint(endpoint endpointConfig, cfg runtimeConfig, timeout time.Duration, streamBuffer, errorBuffer int) mesh.FabricTransportTarget {
	return mesh.FabricTransportTarget{
		EndpointID:     firstNonEmpty(endpoint.EndpointID, endpoint.Address),
		PeerID:         firstNonEmpty(endpoint.NodeID, cfg.ExitNodeID),
		Endpoint:       endpoint.Address,
		Transport:      firstNonEmpty(endpoint.Transport, "direct_quic"),
		PeerCertSHA256: firstNonEmpty(endpoint.PeerCertSHA256, endpoint.TLSCertSHA256),
		Timeout:        timeout,
		OutboundBuffer: streamBuffer,
		InboundBuffer:  streamBuffer,
		ErrorBuffer:    errorBuffer,
	}
}

// connect builds a fresh QUIC fabric transport and tries cfg.Endpoints in
// order, installing the first session whose data streams open successfully.
// On success the Manager takes ownership of the transport and session, and
// cancel is stored for Stop. If every endpoint fails, the transport is closed
// and the last dial error is returned.
func (m *Manager) connect(ctx context.Context, cfg runtimeConfig, cancel context.CancelFunc) error {
	quicTransport := mesh.NewQUICFabricTransport(nil)
	quicTransport.SetLocalPeerID(cfg.LocalNodeID)
	quicTransport.DialAddr = m.protectedQUICDialer()
	inbox := vpnruntime.NewFabricPacketInbox(4096)
	quicTransport.SetInboundHandlers(func(ctx context.Context, envelope mesh.ProductionEnvelope) (mesh.ProductionForwardResult, error) {
		if err := inbox.DeliverProductionEnvelope(ctx, envelope); err != nil {
			return mesh.ProductionForwardResult{}, err
		}
		return mesh.ProductionForwardResult{Delivered: true, MessageID: envelope.MessageID}, nil
	}, nil, nil)

	var lastErr error
	for _, endpoint := range cfg.Endpoints {
		target := fabricTargetForEndpoint(endpoint, cfg, 5*time.Second, 512, 32)
		carrier, selected, err := mesh.FabricTransportForTarget(target, quicTransport)
		if err != nil {
			lastErr = err
			continue
		}

		dialCtx, dialCancel := context.WithTimeout(ctx, 5*time.Second)
		session, err := carrier.Connect(dialCtx, selected)
		if err != nil {
			dialCancel()
			lastErr = err
			continue
		}
		streamIDs, streamID, err := openStreams(dialCtx, session, cfg.StreamShards)
		dialCancel()
		if err != nil {
			_ = session.Close()
			lastErr = err
			continue
		}

		m.mu.Lock()
		m.cancel = cancel
		m.transport = quicTransport
		m.session = session
		m.inbox = inbox
		m.cfg = cfg
		m.endpoint = endpoint.Address
		m.activeEndpoint = endpoint
		m.lastErr = ""
		m.packet = &vpnruntime.FabricSessionPacketTransport{
			Sender:                  session,
			Receiver:                session,
			Inbox:                   inbox,
			StreamID:                streamID,
			StreamIDsByTrafficClass: streamIDs,
			VPNConnectionID:         cfg.VPNConnectionID,
			SendDirection:           vpnruntime.FabricDirectionClientToGateway,
			ReceiveDirection:        vpnruntime.FabricDirectionGatewayToClient,
		}
		m.mu.Unlock()
		return nil
	}

	if lastErr == nil {
		lastErr = fmt.Errorf("no QUIC exit endpoints available")
	}
	// Nothing took ownership of the transport; without this close, every
	// failed (re)connect attempt would leak one.
	_ = quicTransport.Close()
	return lastErr
}

// protectedQUICDialer returns a dial function that opens a UDP socket, passes
// its file descriptor to the installed SocketProtector, and dials QUIC over
// it. It returns nil when no protector is installed.
func (m *Manager) protectedQUICDialer() func(context.Context, string, *tls.Config, *quic.Config) (*quic.Conn, error) {
	m.mu.Lock()
	protector := m.protector
	m.mu.Unlock()
	if protector == nil {
		return nil
	}
	return func(ctx context.Context, endpoint string, tlsConfig *tls.Config, config *quic.Config) (*quic.Conn, error) {
		// Bracketed addresses are IPv6 literals; everything else uses IPv4.
		network := "udp4"
		if strings.Contains(endpoint, "[") {
			network = "udp6"
		}
		conn, err := net.ListenPacket(network, ":0")
		if err != nil {
			return nil, err
		}
		raw, ok := conn.(interface {
			SyscallConn() (syscall.RawConn, error)
		})
		if !ok {
			_ = conn.Close()
			return nil, fmt.Errorf("udp socket does not expose raw connection for vpn protection")
		}
		rawConn, err := raw.SyscallConn()
		if err != nil {
			_ = conn.Close()
			return nil, err
		}
		var protectErr error
		if err := rawConn.Control(func(fd uintptr) {
			if !protector.Protect(int64(fd)) {
				protectErr = fmt.Errorf("android vpn socket protect failed")
			}
		}); err != nil {
			_ = conn.Close()
			return nil, err
		}
		if protectErr != nil {
			_ = conn.Close()
			return nil, protectErr
		}
		quicConn, err := mesh.DialQUICAddrWithPacketConn(ctx, endpoint, conn, tlsConfig, config)
		if err != nil {
			// Release the protected socket so a failed dial does not leak it.
			// A repeated Close only returns an error, which is ignored.
			_ = conn.Close()
			return nil, err
		}
		return quicConn, nil
	}
}

// Stop tears down the current session and transport and marks the runtime
// stopped, so later operations fail instead of reconnecting.
func (m *Manager) Stop() {
	m.opMu.Lock()
	defer m.opMu.Unlock()
	m.stopLocked()
}

// stopLocked is Stop without acquiring opMu; the caller must hold it.
func (m *Manager) stopLocked() {
	m.mu.Lock()
	cancel := m.cancel
	session := m.session
	transport := m.transport
	m.cancel = nil
	m.session = nil
	m.transport = nil
	m.packet = nil
	m.mu.Unlock()
	// Packets buffered from the torn-down session must not leak into the next one.
	m.pendingDownlink = nil

	if cancel != nil {
		cancel()
	}
	if session != nil {
		_ = session.Close()
	}
	if transport != nil {
		_ = transport.Close()
	}
}

// SendPacket sends one packet toward the gateway. An empty packet is a no-op.
// If the send fails, it reconnects once and retries. If the reconnect fails,
// the original send error is returned. Uplink counters are updated only after
// a successful send.
func (m *Manager) SendPacket(packet []byte) error {
	if len(packet) == 0 {
		return nil
	}
	m.opMu.Lock()
	defer m.opMu.Unlock()
	if err := m.ensureConnectedLocked(); err != nil {
		return err
	}
	m.mu.Lock()
	transport := m.packet
	m.mu.Unlock()
	if transport == nil {
		return fmt.Errorf("fabric vpn runtime is not connected")
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := transport.SendGatewayPacketBatch(ctx, [][]byte{append([]byte(nil), packet...)}); err != nil {
		m.setErr(err)
		if reconnectErr := m.reconnectLocked(); reconnectErr != nil {
			return err
		}
		m.mu.Lock()
		transport = m.packet
		m.mu.Unlock()
		if transport == nil {
			return err
		}
		retryCtx, retryCancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer retryCancel()
		if retryErr := transport.SendGatewayPacketBatch(retryCtx, [][]byte{append([]byte(nil), packet...)}); retryErr != nil {
			m.setErr(retryErr)
			return retryErr
		}
	}
	m.uplinkPackets.Add(1)
	m.uplinkBytes.Add(uint64(len(packet)))
	return nil
}

// ReceivePacket returns the next downlink packet. Packets left over from an
// earlier batch are returned first, without touching the network. Otherwise
// it waits up to timeoutMillis (default 100ms when <= 0) for a new batch and
// returns (nil, nil) if none arrives. A receive error triggers a reconnect
// attempt, and the receive error is returned either way.
func (m *Manager) ReceivePacket(timeoutMillis int) ([]byte, error) {
	m.opMu.Lock()
	defer m.opMu.Unlock()
	if packet, ok := m.popPendingDownlinkLocked(); ok {
		return packet, nil
	}
	if err := m.ensureConnectedLocked(); err != nil {
		return nil, err
	}
	m.mu.Lock()
	transport := m.packet
	m.mu.Unlock()
	if transport == nil {
		return nil, fmt.Errorf("fabric vpn runtime is not connected")
	}

	timeout := time.Duration(timeoutMillis) * time.Millisecond
	if timeout <= 0 {
		timeout = 100 * time.Millisecond
	}
	ctx, cancel := context.WithTimeout(context.Background(), timeout+time.Second)
	defer cancel()
	packets, err := transport.ReceiveGatewayPacketBatch(ctx, timeout)
	if err != nil {
		m.setErr(err)
		_ = m.reconnectLocked()
		return nil, err
	}
	// Queue the whole batch so packets after the first are not dropped.
	for _, p := range packets {
		m.pendingDownlink = append(m.pendingDownlink, append([]byte(nil), p...))
	}
	packet, _ := m.popPendingDownlinkLocked()
	return packet, nil
}

// popPendingDownlinkLocked removes and returns the oldest queued downlink
// packet, updating the downlink counters. The caller must hold opMu.
func (m *Manager) popPendingDownlinkLocked() ([]byte, bool) {
	if len(m.pendingDownlink) == 0 {
		return nil, false
	}
	packet := m.pendingDownlink[0]
	m.pendingDownlink[0] = nil // drop the reference so the backing array does not pin it
	m.pendingDownlink = m.pendingDownlink[1:]
	m.downlinkPackets.Add(1)
	m.downlinkBytes.Add(uint64(len(packet)))
	return packet, true
}

// ControlRequest opens a short-lived session to the currently connected
// endpoint, sends payloadJSON on the control-forward stream, and returns the
// payload of the first matching reply. A non-empty "error" field in the reply
// is returned as an error. The whole exchange is bounded by 10 seconds.
func (m *Manager) ControlRequest(payloadJSON string) (string, error) {
	m.opMu.Lock()
	defer m.opMu.Unlock()
	if err := m.ensureConnectedLocked(); err != nil {
		return "", err
	}
	m.mu.Lock()
	transport := m.transport
	cfg := m.cfg
	endpoint := m.activeEndpoint
	m.mu.Unlock()
	if transport == nil || endpoint.Address == "" {
		return "", fmt.Errorf("fabric control runtime is not connected")
	}

	target := fabricTargetForEndpoint(endpoint, cfg, 8*time.Second, 16, 8)
	carrier, selected, err := mesh.FabricTransportForTarget(target, transport)
	if err != nil {
		return "", err
	}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	session, err := carrier.Connect(ctx, selected)
	if err != nil {
		return "", err
	}
	defer session.Close()

	if err := session.Send(ctx, fabricproto.Frame{
		Type:         fabricproto.FrameData,
		TrafficClass: fabricproto.TrafficClassReliable,
		StreamID:     mesh.FabricControlForwardQUICStreamID,
		Sequence:     uint64(time.Now().UnixNano()),
		Payload:      []byte(payloadJSON),
	}); err != nil {
		return "", err
	}

	frames := session.Frames()
	errs := session.Errors()
	for {
		select {
		case <-ctx.Done():
			return "", ctx.Err()
		case err, ok := <-errs:
			if !ok {
				// A closed error channel would otherwise be selected on every
				// iteration; stop listening to it.
				errs = nil
				continue
			}
			if err != nil {
				return "", err
			}
		case frame, ok := <-frames:
			if !ok {
				return "", fmt.Errorf("fabric control session closed before a response arrived")
			}
			if frame.Type != fabricproto.FrameData || frame.StreamID != mesh.FabricControlForwardQUICStreamID {
				continue
			}
			var response controlForwardResponse
			if err := json.Unmarshal(frame.Payload, &response); err != nil {
				return "", err
			}
			if response.Error != "" {
				// Remote text is data, not a format string.
				return "", errors.New(response.Error)
			}
			return string(response.Payload), nil
		}
	}
}

// Reconnect drops the current session and dials again with the last
// configuration. It fails if the runtime has been stopped.
func (m *Manager) Reconnect() error {
	m.opMu.Lock()
	defer m.opMu.Unlock()
	return m.reconnectLocked()
}

// ensureConnectedLocked returns nil when a packet transport is installed,
// reconnects when the runtime was started but lost its session, and errors
// when the runtime is stopped. The caller must hold opMu.
func (m *Manager) ensureConnectedLocked() error {
	m.mu.Lock()
	connected := m.packet != nil
	cancel := m.cancel
	m.mu.Unlock()
	if connected {
		return nil
	}
	if cancel == nil {
		return fmt.Errorf("fabric vpn runtime is stopped")
	}
	return m.reconnectLocked()
}

// reconnectLocked closes the current session and transport, then reconnects
// with the stored config within 8 seconds. The caller must hold opMu.
func (m *Manager) reconnectLocked() error {
	m.mu.Lock()
	cfg := m.cfg
	oldSession := m.session
	oldTransport := m.transport
	cancel := m.cancel
	m.session = nil
	m.transport = nil
	m.packet = nil
	m.mu.Unlock()

	if oldSession != nil {
		_ = oldSession.Close()
	}
	if oldTransport != nil {
		_ = oldTransport.Close()
	}
	if cancel == nil {
		return fmt.Errorf("fabric vpn runtime is stopped")
	}
	ctx, ctxCancel := context.WithTimeout(context.Background(), 8*time.Second)
	defer ctxCancel()
	if err := m.connect(ctx, cfg, cancel); err != nil {
		m.setErr(err)
		return err
	}
	return nil
}

// SnapshotJSON reports connection state and traffic counters as a JSON object
// with schema "rap.android_fabric_vpn_runtime.v1".
func (m *Manager) SnapshotJSON() string {
	m.mu.Lock()
	connected := m.packet != nil
	endpoint := m.endpoint
	lastErr := m.lastErr
	vpnConnectionID := m.cfg.VPNConnectionID
	localNodeID := m.cfg.LocalNodeID
	exitNodeID := m.cfg.ExitNodeID
	m.mu.Unlock()

	// Values are strings, bools and integers only, so Marshal cannot fail.
	payload, _ := json.Marshal(map[string]any{
		"schema_version":   "rap.android_fabric_vpn_runtime.v1",
		"connected":        connected,
		"endpoint":         endpoint,
		"last_error":       lastErr,
		"vpn_connection":   vpnConnectionID,
		"local_node_id":    localNodeID,
		"exit_node_id":     exitNodeID,
		"uplink_packets":   m.uplinkPackets.Load(),
		"uplink_bytes":     m.uplinkBytes.Load(),
		"downlink_packets": m.downlinkPackets.Load(),
		"downlink_bytes":   m.downlinkBytes.Load(),
	})
	return string(payload)
}

// setErr records err's text as the last error; a nil err is ignored.
func (m *Manager) setErr(err error) {
	if err == nil {
		return
	}
	m.mu.Lock()
	m.lastErr = err.Error()
	m.mu.Unlock()
}

// openStreams opens shards streams for each of the interactive and bulk
// traffic classes. Stream IDs are consecutive from a base derived from the
// current time. It returns the IDs grouped by class name and the first opened
// (interactive) stream ID as the primary.
func openStreams(ctx context.Context, session mesh.FabricTransportSession, shards int) (map[string][]uint64, uint64, error) {
	base := uint64(time.Now().UnixNano())
	classes := []struct {
		name         string
		trafficClass fabricproto.TrafficClass
	}{
		{name: vpnruntime.FabricTrafficClassInteractive, trafficClass: fabricproto.TrafficClassInteractive},
		{name: vpnruntime.FabricTrafficClassBulk, trafficClass: fabricproto.TrafficClassBulk},
	}
	out := make(map[string][]uint64, len(classes))
	var primary uint64
	for classIndex, class := range classes {
		for shard := 0; shard < shards; shard++ {
			streamID := base + uint64(classIndex*shards+shard)
			if err := session.Send(ctx, fabricproto.Frame{Type: fabricproto.FrameOpenStream, StreamID: streamID, TrafficClass: class.trafficClass}); err != nil {
				return nil, 0, err
			}
			if primary == 0 {
				primary = streamID
			}
			out[class.name] = append(out[class.name], streamID)
		}
	}
	return out, primary, nil
}

// firstNonEmpty returns the first value that is non-blank after trimming,
// trimmed, or "" if all are blank.
func firstNonEmpty(values ...string) string {
	for _, value := range values {
		if strings.TrimSpace(value) != "" {
			return strings.TrimSpace(value)
		}
	}
	return ""
}