package vpnruntime

import (
	"context"
	"fmt"
	"io"
	"log"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/example/remote-access-platform/agents/rap-node-agent/internal/client"
)

// Gateway moves IP packets between a local TUN device and a PacketTransport.
//
// The exported fields are configuration; normalize fills in defaults for the
// ones left empty when EnsureStarted is called. The unexported fields hold the
// runtime state and traffic counters reported by Status and Snapshot.
type Gateway struct {
	// Configuration. API is not referenced anywhere in this file.
	API             *client.Client
	Transport       PacketTransport
	ClusterID       string
	VPNConnectionID string
	ServiceTunnel   FabricServiceTunnel
	InterfaceName   string
	AddressCIDR     string
	RouteCIDR       string
	PollTimeout     time.Duration

	// Lifecycle state, guarded by mu. runSeq is incremented on every
	// successful start so that a run goroutine from an earlier start cannot
	// overwrite the state of a newer one when it finally exits.
	mu      sync.Mutex
	running bool
	lastErr error
	cancel  context.CancelFunc
	runSeq  uint64

	// Traffic counters, updated atomically by the copy loops.
	clientToGatewayBatches atomic.Uint64
	clientToGatewayPackets atomic.Uint64
	clientToGatewayBytes   atomic.Uint64
	gatewayToClientBatches atomic.Uint64
	gatewayToClientPackets atomic.Uint64
	gatewayToClientBytes   atomic.Uint64
	tunReadPackets         atomic.Uint64
	tunReadBytes           atomic.Uint64
	tunWritePackets        atomic.Uint64
	tunWriteBytes          atomic.Uint64
	uploadQueueDrops       atomic.Uint64
	downloadErrors         atomic.Uint64
	uploadErrors           atomic.Uint64

	// Most recent activity summaries, guarded by mu.
	lastClientToGatewayPacket string
	lastGatewayToClientPacket string
	lastRuntimeActivityAt     time.Time
}

// Batching limits for packets uploaded from the TUN device to the transport.
//
// A batch is flushed once it holds vpnGatewayBatchMaxPackets packets or
// vpnGatewayBatchMaxBytes framed bytes (each packet is accounted as its length
// plus 4), or vpnGatewayBatchFlushTimeout after its first packet arrived.
// vpnGatewayPriorityBatchWait is not used in this file.
// vpnGatewayQueueDropLogEvery throttles the "upload queue full" log line once
// the first few drops have been reported.
const (
	vpnGatewayBatchMaxPackets   = 2048
	vpnGatewayBatchMaxBytes     = 8 * 1024 * 1024
	vpnGatewayBatchFlushTimeout = 3 * time.Millisecond
	vpnGatewayPriorityBatchWait = 20 * time.Millisecond
	vpnGatewayQueueDropLogEvery = 1024
)

// readWriteCloser is the subset of the TUN device used by the runtime.
type readWriteCloser interface {
	io.Reader
	io.Writer
	io.Closer
}

// PacketTransport carries batches of raw packets between the gateway and its
// peer.
type PacketTransport interface {
	// SendGatewayPacketBatch sends one batch of packets read from the TUN
	// device.
	SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error
	// ReceiveGatewayPacketBatch returns the next batch of packets destined
	// for the TUN device; timeout is the Gateway's PollTimeout.
	ReceiveGatewayPacketBatch(ctx context.Context, timeout time.Duration) ([][]byte, error)
}

// packetTransportSnapshotter is an optional interface a transport may
// implement to contribute diagnostics to Gateway.Snapshot.
type packetTransportSnapshotter interface {
	Snapshot() map[string]any
}

// packetTransportCloser is an optional interface a transport may implement to
// be closed when the gateway runtime exits.
type packetTransportCloser interface {
	Close() error
}

// EnsureStarted starts the gateway runtime if it is not already running.
//
// It normalizes the configuration, opens the TUN device and launches the copy
// loops in a background goroutine whose lifetime is bound to ctx (and to
// Stop). It returns nil when the runtime is already running or was started,
// and the normalization or TUN error otherwise; that error is also reported
// by Status.
func (g *Gateway) EnsureStarted(ctx context.Context) error {
	g.mu.Lock()
	if g.running {
		g.mu.Unlock()
		return nil
	}
	g.mu.Unlock()

	if err := g.normalize(); err != nil {
		g.recordStartFailure(err)
		return err
	}
	tun, err := openGatewayTun(g.InterfaceName, g.AddressCIDR, g.RouteCIDR)
	if err != nil {
		g.recordStartFailure(err)
		return err
	}

	runCtx, cancel := context.WithCancel(ctx)
	g.mu.Lock()
	if g.running {
		// Another caller started the runtime while the TUN was being opened.
		g.mu.Unlock()
		cancel()
		_ = tun.Close() // best effort: this device was never used
		return nil
	}
	g.running = true
	g.lastErr = nil
	g.cancel = cancel
	g.runSeq++
	seq := g.runSeq
	g.mu.Unlock()

	go func() {
		// Release the derived context even when the runtime fails on its own
		// and nobody calls Stop.
		defer cancel()
		if err := g.run(runCtx, tun); err != nil && runCtx.Err() == nil {
			log.Printf("vpn gateway runtime stopped: tunnel_id=%s error=%v", g.tunnelID(), err)
			g.finishRun(seq, err)
			return
		}
		g.finishRun(seq, runCtx.Err())
	}()
	return nil
}

// Stop cancels the running runtime, if any, and marks the gateway as stopped.
// It does not wait for the background goroutines to exit.
func (g *Gateway) Stop() {
	g.mu.Lock()
	cancel := g.cancel
	g.cancel = nil
	g.running = false
	g.mu.Unlock()

	if cancel != nil {
		cancel()
	}
}

// Status reports whether the runtime is running together with the text of the
// last recorded error ("" when there is none).
func (g *Gateway) Status() (bool, string) {
	g.mu.Lock()
	defer g.mu.Unlock()

	if g.lastErr != nil {
		return g.running, g.lastErr.Error()
	}
	return g.running, ""
}

// IsReadyForConnection reports whether the runtime is running and
// vpnConnectionID is non-empty and equal to either the configured
// VPNConnectionID or the effective tunnel ID.
func (g *Gateway) IsReadyForConnection(vpnConnectionID string) bool {
	g.mu.Lock()
	defer g.mu.Unlock()

	tunnelID := g.tunnelIDLocked()
	return g.running && (g.VPNConnectionID == vpnConnectionID || tunnelID == vpnConnectionID) && vpnConnectionID != ""
}

// Snapshot returns a diagnostic view of the gateway: service tunnel identity,
// transport name, traffic counters, the last packet summaries and, when
// available, the last error, last activity time, platform details and the
// transport's own snapshot.
//
// Only the lifecycle and "last ..." fields are read under the mutex;
// configuration fields are read directly.
func (g *Gateway) Snapshot() map[string]any {
	g.mu.Lock()
	running := g.running
	lastErr := ""
	if g.lastErr != nil {
		lastErr = g.lastErr.Error()
	}
	lastClientToGatewayPacket := g.lastClientToGatewayPacket
	lastGatewayToClientPacket := g.lastGatewayToClientPacket
	lastRuntimeActivityAt := g.lastRuntimeActivityAt
	g.mu.Unlock()

	out := map[string]any{
		"running":                   running,
		"tunnel_id":                 g.ServiceTunnel.TunnelID,
		"pool_id":                   g.ServiceTunnel.PoolID,
		"service_id":                g.ServiceTunnel.ServiceID,
		"local_service_id":          g.ServiceTunnel.LocalServiceID,
		"remote_service_id":         g.ServiceTunnel.RemoteServiceID,
		"service_kind":              g.ServiceTunnel.ServiceKind,
		"service_role":              firstNonEmptyTunnelString(g.ServiceTunnel.ServiceRole, DefaultFabricTunnelRole),
		"service_class":             firstNonEmptyTunnelString(g.ServiceTunnel.ServiceClass, DefaultFabricTunnelClass),
		"adapter_contract":          "fabric_channel_to_ipv4_nat",
		"transport":                 g.transportName(),
		"poll_timeout_ms":           g.PollTimeout.Milliseconds(),
		"client_to_gateway_batches": g.clientToGatewayBatches.Load(),
		"client_to_gateway_packets": g.clientToGatewayPackets.Load(),
		"client_to_gateway_bytes":   g.clientToGatewayBytes.Load(),
		"gateway_to_client_batches": g.gatewayToClientBatches.Load(),
		"gateway_to_client_packets": g.gatewayToClientPackets.Load(),
		"gateway_to_client_bytes":   g.gatewayToClientBytes.Load(),
		"tun_read_packets":          g.tunReadPackets.Load(),
		"tun_read_bytes":            g.tunReadBytes.Load(),
		"tun_write_packets":         g.tunWritePackets.Load(),
		"tun_write_bytes":           g.tunWriteBytes.Load(),
		"upload_queue_drops":        g.uploadQueueDrops.Load(),
		"download_errors":           g.downloadErrors.Load(),
		"upload_errors":             g.uploadErrors.Load(),
		"last_client_to_gateway":    lastClientToGatewayPacket,
		"last_gateway_to_client":    lastGatewayToClientPacket,
	}
	if lastErr != "" {
		out["last_error"] = lastErr
	}
	if !lastRuntimeActivityAt.IsZero() {
		out["last_runtime_activity_at"] = lastRuntimeActivityAt.UTC().Format(time.RFC3339Nano)
	}
	out["service_tunnel"] = g.ServiceTunnel.Snapshot()
	if platform := gatewayPlatformSnapshot(g.InterfaceName, g.RouteCIDR); len(platform) > 0 {
		out["platform"] = platform
	}
	if snapshotter, ok := g.Transport.(packetTransportSnapshotter); ok {
		if snapshot := snapshotter.Snapshot(); len(snapshot) > 0 {
			out["transport_snapshot"] = snapshot
		}
	}
	return out
}

// transportName returns a short label for the configured transport: a fixed
// name for the known transport types, "none" when no transport is set, and
// the Go type name otherwise.
func (g *Gateway) transportName() string {
	switch g.Transport.(type) {
	case *FabricSessionPacketTransport:
		return "fabric_session"
	case *FabricPacketTransport:
		return "fabric_mesh"
	case *LocalPacketTransport:
		return "local_fabric_inbox"
	case *AdaptivePacketTransport:
		return "adaptive_fabric"
	default:
		if g.Transport == nil {
			return "none"
		}
		return fmt.Sprintf("%T", g.Transport)
	}
}

// setStopped unconditionally marks the gateway as stopped and records err as
// the last error.
func (g *Gateway) setStopped(err error) {
	g.mu.Lock()
	defer g.mu.Unlock()

	g.running = false
	g.lastErr = err
	g.cancel = nil
}

// recordStartFailure records err as the last error of a start attempt that
// failed before a runtime was launched. It leaves the state untouched when a
// concurrent EnsureStarted call has started a runtime in the meantime.
func (g *Gateway) recordStartFailure(err error) {
	g.mu.Lock()
	defer g.mu.Unlock()

	if g.running {
		return
	}
	g.lastErr = err
	g.cancel = nil
}

// finishRun marks the run identified by seq as stopped with err. The update
// is skipped when a newer run has been started since, so a goroutine that is
// slow to exit after Stop cannot clear the state of its successor.
func (g *Gateway) finishRun(seq uint64, err error) {
	g.mu.Lock()
	defer g.mu.Unlock()

	if g.runSeq != seq {
		return
	}
	g.running = false
	g.lastErr = err
	g.cancel = nil
}

// normalize validates the configuration and fills in defaults: the service
// tunnel is normalized, VPNConnectionID falls back to the tunnel ID, and
// empty interface name, address CIDR, route CIDR and poll timeout get
// "rapvpn0", "10.77.0.1/24", "10.77.0.0/24" and 25s respectively. It returns
// an error when the transport, cluster ID or tunnel ID is missing.
func (g *Gateway) normalize() error {
	if g.Transport == nil {
		return fmt.Errorf("fabric packet transport is required")
	}
	g.ServiceTunnel = NormalizeServiceTunnel(g.ServiceTunnel, g.VPNConnectionID)
	if g.VPNConnectionID == "" {
		g.VPNConnectionID = g.ServiceTunnel.TunnelID
	}
	if g.ClusterID == "" || g.VPNConnectionID == "" {
		return fmt.Errorf("cluster id and tunnel id are required")
	}
	if g.InterfaceName == "" {
		g.InterfaceName = "rapvpn0"
	}
	if g.AddressCIDR == "" {
		g.AddressCIDR = "10.77.0.1/24"
	}
	if g.RouteCIDR == "" {
		g.RouteCIDR = "10.77.0.0/24"
	}
	if g.PollTimeout <= 0 {
		g.PollTimeout = 25 * time.Second
	}
	return nil
}

// tunnelIDLocked returns the service tunnel ID, falling back to
// VPNConnectionID. Callers that need a consistent view hold g.mu.
func (g *Gateway) tunnelIDLocked() string {
	return firstNonEmptyTunnelString(g.ServiceTunnel.TunnelID, g.VPNConnectionID)
}

// tunnelID is the locking variant of tunnelIDLocked. It is safe to call on a
// nil Gateway, for which it returns "".
func (g *Gateway) tunnelID() string {
	if g == nil {
		return ""
	}
	g.mu.Lock()
	defer g.mu.Unlock()
	return g.tunnelIDLocked()
}

// run executes both copy loops until ctx is cancelled or one of them returns,
// and returns the context error or that loop's error. On exit the TUN device
// is closed, as is the transport when it implements Close.
func (g *Gateway) run(ctx context.Context, tun readWriteCloser) error {
	defer tun.Close()
	if closer, ok := g.Transport.(packetTransportCloser); ok {
		defer closer.Close()
	}

	runCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Buffered for both loops so the one that finishes last never blocks.
	errCh := make(chan error, 2)
	go func() { errCh <- g.copyGatewayToClient(runCtx, tun) }()
	go func() { errCh <- g.copyClientToGateway(runCtx, tun) }()

	select {
	case <-runCtx.Done():
		return runCtx.Err()
	case err := <-errCh:
		cancel()
		return err
	}
}

// copyGatewayToClient reads packets from the TUN device and queues them for
// the uploader goroutine it starts. Packets are dropped, and counted in
// uploadQueueDrops, when the queue is full. It returns on a TUN read error,
// an uploader error or context cancellation.
//
// Cancellation is only observed between reads; a blocked tun.Read returns
// once run closes the device.
func (g *Gateway) copyGatewayToClient(ctx context.Context, tun io.Reader) error {
	packets := make(chan []byte, 32768)
	errCh := make(chan error, 1)
	go func() { errCh <- g.uploadGatewayPackets(ctx, nil, packets) }()

	buffer := make([]byte, 65535)
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case err := <-errCh:
			if err != nil {
				return err
			}
		default:
		}

		n, err := tun.Read(buffer)
		if err != nil {
			return err
		}
		if n <= 0 {
			continue
		}

		// Copy out of the shared read buffer before handing the packet off.
		packet := append([]byte(nil), buffer[:n]...)
		normalizeIPv4PacketChecksums(packet)
		g.recordTunRead(packet)

		select {
		case packets <- packet:
		default:
			drops := g.uploadQueueDrops.Add(1)
			// Report the first few drops, then only periodically, so a
			// saturated queue does not stall this loop on logging.
			if drops <= 5 || drops%vpnGatewayQueueDropLogEvery == 0 {
				log.Printf("vpn gateway packet upload queue full; dropping packet: tunnel_id=%s drops=%d", g.tunnelID(), drops)
			}
		}
	}
}

// uploadGatewayPackets drains packets, groups them into batches and sends
// each batch through the transport. A batch is flushed when it reaches the
// packet or byte limit, when the flush timeout expires after its first
// packet, and once more on context cancellation. Send failures are logged and
// counted in uploadErrors; the failed batch is discarded. The function only
// returns when ctx is done, with ctx.Err().
//
// The second parameter is unused and kept for signature compatibility.
func (g *Gateway) uploadGatewayPackets(ctx context.Context, _ <-chan []byte, packets <-chan []byte) error {
	batch := make([][]byte, 0, vpnGatewayBatchMaxPackets)
	batchBytes := 0

	// Create the timer in a stopped, drained state; it is armed when the
	// first packet of a batch arrives.
	timer := time.NewTimer(time.Hour)
	if !timer.Stop() {
		<-timer.C
	}
	defer timer.Stop()
	timerActive := false

	flush := func() {
		if len(batch) == 0 {
			return
		}
		packetCount := len(batch)
		byteCount := packetBytesTotal(batch)
		if err := g.Transport.SendGatewayPacketBatch(ctx, batch); err != nil {
			g.uploadErrors.Add(1)
			log.Printf("vpn gateway packet batch upload failed: tunnel_id=%s packets=%d error=%v", g.tunnelID(), packetCount, err)
		} else {
			g.recordGatewayToClientBatch(packetCount, byteCount, batch[0])
		}
		// Drop the packet references so the reused backing array does not
		// keep them alive.
		for i := range batch {
			batch[i] = nil
		}
		batch = batch[:0]
		batchBytes = 0
	}

	// addPacket appends packet to the batch, flushing first when the packet
	// would push the batch past a limit. It reports false for empty packets,
	// which are ignored.
	addPacket := func(packet []byte) bool {
		packetBytes := len(packet)
		if packetBytes <= 0 {
			return false
		}
		packetFrameSize := 4 + packetBytes
		if len(batch) > 0 {
			if len(batch) >= vpnGatewayBatchMaxPackets || batchBytes+packetFrameSize > vpnGatewayBatchMaxBytes {
				flush()
			}
		}
		batch = append(batch, packet)
		batchBytes += packetFrameSize
		return true
	}

	for {
		// Nothing is pending, so a still-armed timer is stale: stop it and
		// drain a tick that may already have been delivered.
		if len(batch) == 0 && timerActive {
			if !timer.Stop() {
				select {
				case <-timer.C:
				default:
				}
			}
			timerActive = false
		}

		select {
		case <-ctx.Done():
			flush()
			return ctx.Err()
		case packet := <-packets:
			if !addPacket(packet) {
				continue
			}
			if len(batch) >= vpnGatewayBatchMaxPackets || batchBytes >= vpnGatewayBatchMaxBytes {
				flush()
				continue
			}
			if !timerActive {
				timer.Reset(vpnGatewayBatchFlushTimeout)
				timerActive = true
			}
		case <-timer.C:
			timerActive = false
			flush()
		}
	}
}

// copyClientToGateway receives packet batches from the transport and writes
// every packet to the TUN device. A receive failure is logged, counted in
// downloadErrors and retried after one second; a TUN write failure is counted
// in downloadErrors and ends the loop with that error. It returns ctx.Err()
// when the context is cancelled.
func (g *Gateway) copyClientToGateway(ctx context.Context, tun io.Writer) error {
	for {
		// Observe cancellation even if the transport keeps returning empty
		// batches without an error.
		if err := ctx.Err(); err != nil {
			return err
		}

		packets, err := g.Transport.ReceiveGatewayPacketBatch(ctx, g.PollTimeout)
		if err != nil {
			// A failure caused by shutdown is not a download error.
			if ctxErr := ctx.Err(); ctxErr != nil {
				return ctxErr
			}
			g.downloadErrors.Add(1)
			log.Printf("vpn gateway packet download failed: tunnel_id=%s error=%v", g.tunnelID(), err)
			select {
			case <-ctx.Done():
				return ctx.Err()
			case <-time.After(time.Second):
			}
			continue
		}
		if len(packets) == 0 {
			continue
		}

		g.recordClientToGatewayBatch(len(packets), packetBytesTotal(packets), packets[0])
		for _, packet := range packets {
			normalizeIPv4PacketChecksums(packet)
			if _, err := tun.Write(packet); err != nil {
				g.downloadErrors.Add(1)
				return err
			}
			g.recordTunWrite(packet)
		}
	}
}

// recordClientToGatewayBatch updates the client-to-gateway counters, stores a
// summary of the batch's first packet and the activity time, and logs the
// first five batches.
func (g *Gateway) recordClientToGatewayBatch(packetCount int, byteCount int, first []byte) {
	next := g.clientToGatewayBatches.Add(1)
	g.clientToGatewayPackets.Add(uint64(packetCount))
	g.clientToGatewayBytes.Add(uint64(byteCount))
	summary := summarizePacket(first)

	g.mu.Lock()
	g.lastClientToGatewayPacket = summary
	g.lastRuntimeActivityAt = time.Now().UTC()
	g.mu.Unlock()

	if next <= 5 {
		log.Printf(
			"vpn gateway client_to_gateway batch received: tunnel_id=%s batch=%d packets=%d bytes=%d first=%s",
			g.tunnelID(), next, packetCount, byteCount, summary,
		)
	}
}

// recordGatewayToClientBatch updates the gateway-to-client counters, stores a
// summary of the batch's first packet and the activity time, and logs the
// first five batches.
func (g *Gateway) recordGatewayToClientBatch(packetCount int, byteCount int, first []byte) {
	next := g.gatewayToClientBatches.Add(1)
	g.gatewayToClientPackets.Add(uint64(packetCount))
	g.gatewayToClientBytes.Add(uint64(byteCount))
	summary := summarizePacket(first)

	g.mu.Lock()
	g.lastGatewayToClientPacket = summary
	g.lastRuntimeActivityAt = time.Now().UTC()
	g.mu.Unlock()

	if next <= 5 {
		log.Printf(
			"vpn gateway gateway_to_client batch uploaded: tunnel_id=%s batch=%d packets=%d bytes=%d first=%s",
			g.tunnelID(), next, packetCount, byteCount, summary,
		)
	}
}

// recordTunWrite counts one packet written to the TUN device and logs the
// first five.
func (g *Gateway) recordTunWrite(packet []byte) {
	next := g.tunWritePackets.Add(1)
	g.tunWriteBytes.Add(uint64(len(packet)))
	if next <= 5 {
		log.Printf("vpn gateway packet written to tun: tunnel_id=%s packet=%d bytes=%d summary=%s", g.tunnelID(), next, len(packet), summarizePacket(packet))
	}
}

// recordTunRead counts one packet read from the TUN device and logs the first
// five.
func (g *Gateway) recordTunRead(packet []byte) {
	next := g.tunReadPackets.Add(1)
	g.tunReadBytes.Add(uint64(len(packet)))
	if next <= 5 {
		log.Printf("vpn gateway packet read from tun: tunnel_id=%s packet=%d bytes=%d summary=%s", g.tunnelID(), next, len(packet), summarizePacket(packet))
	}
}

// packetBytesTotal returns the summed length of all packets.
func packetBytesTotal(packets [][]byte) int {
	total := 0
	for _, packet := range packets {
		total += len(packet)
	}
	return total
}

// summarizePacket returns a one-line description of an IP packet for logs and
// snapshots, dispatching on the version nibble of the first byte. Empty input
// yields "empty"; versions other than 4 and 6 yield the version and length.
func summarizePacket(packet []byte) string {
	if len(packet) < 1 {
		return "empty"
	}
	version := packet[0] >> 4
	switch version {
	case 4:
		return summarizeIPv4(packet)
	case 6:
		return summarizeIPv6(packet)
	default:
		return fmt.Sprintf("ip_version=%d bytes=%d", version, len(packet))
	}
}

// summarizeIPv4 describes an IPv4 packet by source, destination, protocol
// number and length. Packets shorter than 20 bytes, or whose header length
// field is below 20 or exceeds the packet, are reported as truncated or
// invalid instead.
func summarizeIPv4(packet []byte) string {
	if len(packet) < 20 {
		return fmt.Sprintf("ipv4 truncated bytes=%d", len(packet))
	}
	// The low nibble of the first byte is the header length in 32-bit words.
	ihl := int(packet[0]&0x0f) * 4
	if ihl < 20 || len(packet) < ihl {
		return fmt.Sprintf("ipv4 invalid_ihl=%d bytes=%d", ihl, len(packet))
	}
	proto := packet[9]
	src := net.IP(packet[12:16]).String()
	dst := net.IP(packet[16:20]).String()
	return fmt.Sprintf("ipv4 %s -> %s proto=%d bytes=%d", src, dst, proto, len(packet))
}

// summarizeIPv6 describes an IPv6 packet by source, destination, next-header
// value and length. Packets shorter than the 40-byte fixed header are
// reported as truncated.
func summarizeIPv6(packet []byte) string {
	if len(packet) < 40 {
		return fmt.Sprintf("ipv6 truncated bytes=%d", len(packet))
	}
	nextHeader := packet[6]
	src := net.IP(packet[8:24]).String()
	dst := net.IP(packet[24:40]).String()
	return fmt.Sprintf("ipv6 %s -> %s next=%d bytes=%d", src, dst, nextHeader, len(packet))
}