Files
rdp-proxy/agents/rap-node-agent/internal/vpnruntime/gateway.go
T

541 lines
14 KiB
Go

package vpnruntime
import (
"context"
"fmt"
"io"
"log"
"net"
"sync"
"sync/atomic"
"time"
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/client"
)
// Gateway bridges a local TUN interface and a PacketTransport for a single VPN
// connection. Packets read from the TUN device are batched and sent through
// Transport ("gateway to client"); packet batches received from Transport are
// written to the TUN device ("client to gateway").
//
// The exported fields are configuration; normalize validates them and fills in
// defaults when EnsureStarted runs. The unexported fields are runtime state:
// running, lastErr, cancel and the last-activity fields are guarded by mu, while
// the traffic counters are atomics that are read without holding mu.
type Gateway struct {
	// API is the backend client. It is only needed when Transport is nil, in
	// which case normalize builds a BackendPacketTransport from it.
	API *client.Client
	// Transport carries packet batches between this gateway and the client side.
	Transport PacketTransport
	// ClusterID and VPNConnectionID identify the connection; normalize
	// rejects empty values.
	ClusterID       string
	VPNConnectionID string
	// InterfaceName, AddressCIDR and RouteCIDR configure the TUN device. When
	// empty, normalize defaults them to "rapvpn0", "10.77.0.1/24" and
	// "10.77.0.0/24".
	InterfaceName string
	AddressCIDR   string
	RouteCIDR     string
	// PollTimeout is passed to Transport.ReceiveGatewayPacketBatch; normalize
	// replaces values <= 0 with 25s.
	PollTimeout time.Duration

	// mu guards running, lastErr, cancel, the last*Packet summaries and
	// lastRuntimeActivityAt.
	mu      sync.Mutex
	running bool
	lastErr error
	// cancel stops the active run; Stop and setStopped reset it to nil.
	cancel context.CancelFunc

	// Traffic counters, updated atomically by the relay goroutines.
	// clientToGateway* count batches received from Transport (to be written to
	// the TUN device); gatewayToClient* count batches read from the TUN device
	// and successfully sent through Transport.
	clientToGatewayBatches atomic.Uint64
	clientToGatewayPackets atomic.Uint64
	clientToGatewayBytes   atomic.Uint64
	gatewayToClientBatches atomic.Uint64
	gatewayToClientPackets atomic.Uint64
	gatewayToClientBytes   atomic.Uint64
	tunReadPackets         atomic.Uint64
	tunReadBytes           atomic.Uint64
	tunWritePackets        atomic.Uint64
	tunWriteBytes          atomic.Uint64
	// uploadQueueDrops counts TUN packets dropped because an upload queue was full.
	uploadQueueDrops atomic.Uint64
	// downloadErrors counts failures on the client-to-gateway (download) path;
	// see copyClientToGateway.
	downloadErrors atomic.Uint64
	// uploadErrors counts failed Transport.SendGatewayPacketBatch calls.
	uploadErrors atomic.Uint64
	// Summaries (produced by summarizePacket) of the first packet of the most
	// recent batch in each direction, and the UTC time of the most recent batch
	// in either direction. Guarded by mu.
	lastClientToGatewayPacket string
	lastGatewayToClientPacket string
	lastRuntimeActivityAt     time.Time
}
// Limits for batching TUN packets before they are uploaded through the
// transport (see uploadGatewayPackets).
const (
	// vpnGatewayBatchMaxPackets caps how many packets one batch may hold.
	vpnGatewayBatchMaxPackets = 1 << 11
	// vpnGatewayBatchMaxBytes caps a batch's framed size (payload plus four
	// bytes per packet).
	vpnGatewayBatchMaxBytes = 8 << 20
	// vpnGatewayBatchFlushTimeout bounds how long a pending batch waits for
	// more packets before it is sent.
	vpnGatewayBatchFlushTimeout = 3 * time.Millisecond
)
// readWriteCloser is the TUN device behaviour the relay loop needs: packet
// reads, packet writes, and closing the device on shutdown.
type readWriteCloser interface {
	io.ReadWriteCloser
}
// PacketTransport moves batches of raw IP packets between the gateway and the
// remote client side.
type PacketTransport interface {
	// ReceiveGatewayPacketBatch returns the next batch of packets to write to
	// the gateway's TUN device. wait is the caller's poll timeout; presumably
	// an empty batch means nothing arrived in time (confirm per implementation).
	ReceiveGatewayPacketBatch(ctx context.Context, wait time.Duration) ([][]byte, error)
	// SendGatewayPacketBatch sends a batch of packets read from the gateway's
	// TUN device.
	SendGatewayPacketBatch(ctx context.Context, batch [][]byte) error
}
// BackendPacketTransport is a PacketTransport that relays packet batches through
// the backend API client for one cluster and VPN connection. normalize uses it
// as the default transport when Gateway.Transport is nil.
type BackendPacketTransport struct {
	// API performs the actual send/receive calls.
	API *client.Client
	// ClusterID and VPNConnectionID are passed to every API call.
	ClusterID       string
	VPNConnectionID string
}
// SendGatewayPacketBatch forwards packets to the backend for t's cluster and VPN
// connection via API.SendVPNGatewayPacketBatch.
func (t BackendPacketTransport) SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error {
	api, cluster, connection := t.API, t.ClusterID, t.VPNConnectionID
	return api.SendVPNGatewayPacketBatch(ctx, cluster, connection, packets)
}
// ReceiveGatewayPacketBatch fetches the next packet batch for t's cluster and VPN
// connection via API.ReceiveVPNGatewayPacketBatch, passing timeout through.
func (t BackendPacketTransport) ReceiveGatewayPacketBatch(ctx context.Context, timeout time.Duration) ([][]byte, error) {
	api := t.API
	batch, err := api.ReceiveVPNGatewayPacketBatch(ctx, t.ClusterID, t.VPNConnectionID, timeout)
	return batch, err
}
// EnsureStarted opens the gateway's TUN device and launches the packet relay if
// the gateway is not already running. It returns nil when a run is already
// active, including one started concurrently by another caller. The relay runs
// on a context derived from ctx, so cancelling ctx also stops the gateway.
//
// Configuration errors (from normalize) and TUN setup errors are returned and
// recorded as the last error. If another caller has started the gateway in the
// meantime, that run's state is left untouched.
func (g *Gateway) EnsureStarted(ctx context.Context) error {
	g.mu.Lock()
	if g.running {
		g.mu.Unlock()
		return nil
	}
	g.mu.Unlock()

	// recordStartFailure stores err as the last error. It does not clobber a
	// run that a concurrent EnsureStarted call may have launched meanwhile.
	recordStartFailure := func(err error) error {
		g.mu.Lock()
		defer g.mu.Unlock()
		if !g.running {
			g.lastErr = err
			g.cancel = nil
		}
		return err
	}
	if err := g.normalize(); err != nil {
		return recordStartFailure(err)
	}
	tun, err := openGatewayTun(g.InterfaceName, g.AddressCIDR, g.RouteCIDR)
	if err != nil {
		return recordStartFailure(err)
	}

	runCtx, cancel := context.WithCancel(ctx)
	// stopped is set when this run is cancelled through g.cancel (i.e. by
	// Stop). After that, Stop has already reset the gateway state and a newer
	// run may own it, so this run's goroutine must not overwrite it on exit.
	var stopped atomic.Bool
	stop := func() {
		stopped.Store(true)
		cancel()
	}

	g.mu.Lock()
	if g.running {
		g.mu.Unlock()
		cancel()
		_ = tun.Close()
		return nil
	}
	g.running = true
	g.lastErr = nil
	g.cancel = stop
	g.mu.Unlock()

	go func() {
		// Always release runCtx, also when run fails on its own. This detaches
		// it from ctx and stops anything still watching it.
		defer cancel()
		runErr := g.run(runCtx, tun)
		if runErr != nil && runCtx.Err() == nil {
			log.Printf("vpn gateway runtime stopped: vpn_connection_id=%s error=%v", g.VPNConnectionID, runErr)
		} else {
			runErr = runCtx.Err()
		}
		g.mu.Lock()
		defer g.mu.Unlock()
		// NOTE(review): Stop invokes the stored func after releasing g.mu, so
		// a run that fails on its own in that narrow window can still race a
		// restart.
		if stopped.Load() {
			return
		}
		g.running = false
		g.lastErr = runErr
		g.cancel = nil
	}()
	return nil
}
// Stop marks the gateway as not running and cancels the active run, if any.
// The cancel function is invoked after the lock is released.
func (g *Gateway) Stop() {
	g.mu.Lock()
	cancelRun := g.cancel
	g.running, g.cancel = false, nil
	g.mu.Unlock()
	if cancelRun == nil {
		return
	}
	cancelRun()
}
// Status reports whether the gateway is running and the text of the last
// recorded error ("" when there is none).
func (g *Gateway) Status() (bool, string) {
	g.mu.Lock()
	running, lastErr := g.running, g.lastErr
	g.mu.Unlock()
	if lastErr == nil {
		return running, ""
	}
	return running, lastErr.Error()
}
// IsReadyForConnection reports whether the gateway is running for the given,
// non-empty VPN connection ID.
func (g *Gateway) IsReadyForConnection(vpnConnectionID string) bool {
	if vpnConnectionID == "" {
		return false
	}
	g.mu.Lock()
	defer g.mu.Unlock()
	return g.running && g.VPNConnectionID == vpnConnectionID
}
// Snapshot reports the gateway's state for diagnostics. It includes:
//   - whether the gateway is running and which transport it uses;
//   - all traffic and error counters;
//   - a summary of the most recent packet batch in each direction.
//
// The keys "last_error", "last_runtime_activity_at" (RFC 3339, UTC) and
// "platform" are present only when there is something to report.
func (g *Gateway) Snapshot() map[string]any {
	g.mu.Lock()
	isRunning := g.running
	var errText string
	if err := g.lastErr; err != nil {
		errText = err.Error()
	}
	recentToGateway := g.lastClientToGatewayPacket
	recentToClient := g.lastGatewayToClientPacket
	activityAt := g.lastRuntimeActivityAt
	g.mu.Unlock()

	snapshot := map[string]any{}
	snapshot["running"] = isRunning
	snapshot["transport"] = g.transportName()
	snapshot["poll_timeout_ms"] = g.PollTimeout.Milliseconds()
	snapshot["client_to_gateway_batches"] = g.clientToGatewayBatches.Load()
	snapshot["client_to_gateway_packets"] = g.clientToGatewayPackets.Load()
	snapshot["client_to_gateway_bytes"] = g.clientToGatewayBytes.Load()
	snapshot["gateway_to_client_batches"] = g.gatewayToClientBatches.Load()
	snapshot["gateway_to_client_packets"] = g.gatewayToClientPackets.Load()
	snapshot["gateway_to_client_bytes"] = g.gatewayToClientBytes.Load()
	snapshot["tun_read_packets"] = g.tunReadPackets.Load()
	snapshot["tun_read_bytes"] = g.tunReadBytes.Load()
	snapshot["tun_write_packets"] = g.tunWritePackets.Load()
	snapshot["tun_write_bytes"] = g.tunWriteBytes.Load()
	snapshot["upload_queue_drops"] = g.uploadQueueDrops.Load()
	snapshot["download_errors"] = g.downloadErrors.Load()
	snapshot["upload_errors"] = g.uploadErrors.Load()
	snapshot["last_client_to_gateway"] = recentToGateway
	snapshot["last_gateway_to_client"] = recentToClient

	if errText != "" {
		snapshot["last_error"] = errText
	}
	if !activityAt.IsZero() {
		snapshot["last_runtime_activity_at"] = activityAt.UTC().Format(time.RFC3339Nano)
	}
	if platform := gatewayPlatformSnapshot(g.InterfaceName, g.RouteCIDR); len(platform) > 0 {
		snapshot["platform"] = platform
	}
	return snapshot
}
// transportName returns a short label for the configured transport: "none"
// when unset, a fixed name for the known transport types, and the Go type
// name otherwise.
func (g *Gateway) transportName() string {
	if g.Transport == nil {
		return "none"
	}
	switch g.Transport.(type) {
	case *FabricPacketTransport:
		return "fabric_mesh"
	case *LocalPacketTransport:
		return "local_fabric_inbox"
	case *AdaptivePacketTransport:
		return "adaptive_fabric_backend"
	case BackendPacketTransport:
		return "backend_http_packet_relay"
	}
	return fmt.Sprintf("%T", g.Transport)
}
// setStopped marks the gateway as not running, records err as the last error
// (nil clears it) and drops the stored cancel function.
func (g *Gateway) setStopped(err error) {
	g.mu.Lock()
	g.running, g.lastErr, g.cancel = false, err, nil
	g.mu.Unlock()
}
// normalize validates the gateway configuration and fills in defaults before a
// run starts.
//
// It returns an error when neither Transport nor API is set, or when ClusterID
// or VPNConnectionID is empty. Otherwise it applies these defaults:
//   - Transport: a BackendPacketTransport built from API;
//   - InterfaceName: "rapvpn0";
//   - AddressCIDR: "10.77.0.1/24";
//   - RouteCIDR: "10.77.0.0/24";
//   - PollTimeout: 25s.
//
// All validation happens before any field is modified, so a failed call leaves
// the gateway untouched. In particular, the default transport is built only
// once the IDs it captures are known to be non-empty. Building it earlier would
// pin a BackendPacketTransport with empty IDs that later calls never replace.
func (g *Gateway) normalize() error {
	if g.Transport == nil && g.API == nil {
		return fmt.Errorf("api client or packet transport is required")
	}
	if g.ClusterID == "" || g.VPNConnectionID == "" {
		return fmt.Errorf("cluster id and vpn connection id are required")
	}
	if g.Transport == nil {
		g.Transport = BackendPacketTransport{
			API:             g.API,
			ClusterID:       g.ClusterID,
			VPNConnectionID: g.VPNConnectionID,
		}
	}
	if g.InterfaceName == "" {
		g.InterfaceName = "rapvpn0"
	}
	if g.AddressCIDR == "" {
		g.AddressCIDR = "10.77.0.1/24"
	}
	if g.RouteCIDR == "" {
		g.RouteCIDR = "10.77.0.0/24"
	}
	if g.PollTimeout <= 0 {
		g.PollTimeout = 25 * time.Second
	}
	return nil
}
// run relays packets in both directions between tun and g.Transport. It returns
// the first error from either direction, or ctx.Err() once ctx is done, and it
// always closes tun before returning.
//
// Both relay goroutines run on a child context that is cancelled when run
// returns. A failure in one direction therefore also stops the other direction
// and the upload worker. Without this, they would keep polling the transport
// for a gateway that is no longer running.
func (g *Gateway) run(ctx context.Context, tun readWriteCloser) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	defer tun.Close()
	// Buffered for both senders so neither goroutine blocks after run returns.
	errCh := make(chan error, 2)
	go func() { errCh <- g.copyGatewayToClient(ctx, tun) }()
	go func() { errCh <- g.copyClientToGateway(ctx, tun) }()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case err := <-errCh:
		return err
	}
}
// copyGatewayToClient reads packets from tun and queues them for upload by
// uploadGatewayPackets, which runs in its own goroutine. Each packet is copied
// out of the read buffer, passed through normalizeIPv4PacketChecksums and
// recorded. Packets matched by isTCPControlPacket go to the priority queue;
// all others go to the bulk queue. When a queue is full, the packet is dropped
// and counted in uploadQueueDrops.
//
// It returns when ctx is done, when tun.Read fails, or when the upload worker
// exits (after that nothing would drain the queues).
func (g *Gateway) copyGatewayToClient(ctx context.Context, tun io.Reader) error {
	priorityPackets := make(chan []byte, 1024)
	packets := make(chan []byte, 32768)
	errCh := make(chan error, 1)
	go func() {
		errCh <- g.uploadGatewayPackets(ctx, priorityPackets, packets)
	}()

	// enqueue does a non-blocking send, so a slow upload path never stalls TUN
	// reads. Drops are logged only when the running total reaches a power of
	// two, so sustained overload cannot flood the log from this hot loop.
	enqueue := func(queue chan<- []byte, packet []byte, label string) {
		select {
		case queue <- packet:
		default:
			drops := g.uploadQueueDrops.Add(1)
			if drops&(drops-1) == 0 {
				log.Printf("vpn gateway %s upload queue full; dropping packet: vpn_connection_id=%s total_drops=%d", label, g.VPNConnectionID, drops)
			}
		}
	}

	buffer := make([]byte, 65535)
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case err := <-errCh:
			if err == nil {
				err = fmt.Errorf("vpn gateway packet uploader exited")
			}
			return err
		default:
		}
		n, err := tun.Read(buffer)
		if err != nil {
			return err
		}
		if n <= 0 {
			continue
		}
		packet := append([]byte(nil), buffer[:n]...)
		normalizeIPv4PacketChecksums(packet)
		g.recordTunRead(packet)
		if isTCPControlPacket(packet) {
			enqueue(priorityPackets, packet, "priority packet")
		} else {
			enqueue(packets, packet, "packet")
		}
	}
}
// uploadGatewayPackets drains the priority and bulk packet queues filled by
// copyGatewayToClient and sends them through g.Transport.SendGatewayPacketBatch.
//
// Bulk packets accumulate into a batch. The batch is sent when any of these
// happens:
//   - it reaches vpnGatewayBatchMaxPackets packets;
//   - it reaches vpnGatewayBatchMaxBytes framed bytes;
//   - the flush timer fires. The timer runs for vpnGatewayBatchFlushTimeout
//     and is armed when a bulk packet is added while it is idle.
//
// Each priority packet first flushes the pending batch and is then sent on its
// own. The priority queue is polled before the bulk queue on every iteration.
//
// Send failures are counted in uploadErrors and logged. The failed batch is
// discarded and not retried. The function returns only once ctx is done: it
// makes a final flush attempt (with the already-cancelled ctx) and returns
// ctx.Err().
func (g *Gateway) uploadGatewayPackets(ctx context.Context, priorityPackets <-chan []byte, packets <-chan []byte) error {
	batch := make([][]byte, 0, vpnGatewayBatchMaxPackets)
	// batchBytes is the framed size of batch: payload plus 4 bytes per packet
	// (presumably a per-packet length prefix in the transport encoding —
	// NOTE(review): confirm against the transport implementations).
	batchBytes := 0
	// Create the flush timer in a stopped state; it is armed only while a
	// batch is pending.
	timer := time.NewTimer(time.Hour)
	if !timer.Stop() {
		<-timer.C
	}
	timerActive := false
	// flush sends the pending batch (if any), records the outcome and resets
	// the batch. Slots are cleared so the reused backing array does not keep
	// sent packets alive.
	flush := func() {
		if len(batch) == 0 {
			return
		}
		packetCount := len(batch)
		byteCount := packetBytesTotal(batch)
		if err := g.Transport.SendGatewayPacketBatch(ctx, batch); err != nil {
			g.uploadErrors.Add(1)
			log.Printf("vpn gateway packet batch upload failed: vpn_connection_id=%s packets=%d error=%v", g.VPNConnectionID, len(batch), err)
		} else {
			g.recordGatewayToClientBatch(packetCount, byteCount, batch[0])
		}
		for i := range batch {
			batch[i] = nil
		}
		batch = batch[:0]
		batchBytes = 0
	}
	// addPacket appends packet to batch and reports whether it was added;
	// empty packets are ignored. If the batch is already full by count, or the
	// packet would push it over the byte limit, the batch is flushed first. A
	// single packet larger than the byte limit is still added to an empty batch.
	addPacket := func(packet []byte) bool {
		packetBytes := len(packet)
		if packetBytes <= 0 {
			return false
		}
		packetFrameSize := 4 + packetBytes
		if len(batch) > 0 {
			if len(batch) >= vpnGatewayBatchMaxPackets || batchBytes+packetFrameSize > vpnGatewayBatchMaxBytes {
				flush()
			}
		}
		batch = append(batch, packet)
		batchBytes += packetFrameSize
		return true
	}
	// flushPriority sends any pending bulk packets first, then packet as a
	// batch of its own.
	flushPriority := func(packet []byte) {
		flush()
		if addPacket(packet) {
			flush()
		}
	}
	for {
		// Disarm a timer left running after the batch was emptied by a
		// size-triggered or priority flush; drain C non-blockingly in case it
		// already fired.
		if len(batch) == 0 && timerActive {
			if !timer.Stop() {
				select {
				case <-timer.C:
				default:
				}
			}
			timerActive = false
		}
		// Non-blocking first pass so a waiting priority packet is always
		// handled before any bulk packet.
		select {
		case packet := <-priorityPackets:
			flushPriority(packet)
			continue
		default:
		}
		select {
		case <-ctx.Done():
			flush()
			return ctx.Err()
		case packet := <-priorityPackets:
			flushPriority(packet)
		case packet := <-packets:
			if !addPacket(packet) {
				continue
			}
			if len(batch) >= vpnGatewayBatchMaxPackets || batchBytes >= vpnGatewayBatchMaxBytes {
				flush()
				continue
			}
			if !timerActive {
				timer.Reset(vpnGatewayBatchFlushTimeout)
				timerActive = true
			}
		case <-timer.C:
			timerActive = false
			flush()
		}
	}
}
// isTCPControlPacket reports whether packet is an IPv4 TCP segment with any of
// the FIN, SYN, RST or ACK flags set (mask 0x17). Because ACK is part of the
// mask, this matches practically every TCP segment after the handshake, not
// only pure control segments. Non-IPv4, non-TCP and truncated packets (shorter
// than the IPv4 header plus a 20-byte TCP header) return false.
func isTCPControlPacket(packet []byte) bool {
	const (
		minIPv4Header = 20
		minTCPHeader  = 20
		protoTCP      = 6
		flagsOffset   = 13
		flagMask      = 0x01 | 0x02 | 0x04 | 0x10 // FIN | SYN | RST | ACK
	)
	if len(packet) < minIPv4Header {
		return false
	}
	if packet[0]>>4 != 4 {
		return false
	}
	headerLen := int(packet[0]&0x0f) << 2
	if headerLen < minIPv4Header {
		return false
	}
	if len(packet) < headerLen+minTCPHeader {
		return false
	}
	if packet[9] != protoTCP {
		return false
	}
	return packet[headerLen+flagsOffset]&flagMask != 0
}
// copyClientToGateway repeatedly receives packet batches from g.Transport
// (passing g.PollTimeout as the receive timeout) and writes each packet to tun
// after passing it through normalizeIPv4PacketChecksums.
//
// Receive failures are counted in downloadErrors, logged and retried after one
// second. A TUN write failure is counted and returned. Once ctx is done it
// returns ctx.Err() without logging the cancellation as a download failure.
func (g *Gateway) copyClientToGateway(ctx context.Context, tun io.Writer) error {
	for {
		// Check ctx on every iteration. Otherwise a transport that keeps
		// returning empty batches without honouring ctx would never let this
		// loop stop.
		if err := ctx.Err(); err != nil {
			return err
		}
		packets, err := g.Transport.ReceiveGatewayPacketBatch(ctx, g.PollTimeout)
		if err != nil {
			if ctxErr := ctx.Err(); ctxErr != nil {
				// Shutting down: the receive failed because ctx was cancelled.
				return ctxErr
			}
			g.downloadErrors.Add(1)
			log.Printf("vpn gateway packet download failed: vpn_connection_id=%s error=%v", g.VPNConnectionID, err)
			select {
			case <-ctx.Done():
				return ctx.Err()
			case <-time.After(time.Second):
			}
			continue
		}
		if len(packets) == 0 {
			continue
		}
		g.recordClientToGatewayBatch(len(packets), packetBytesTotal(packets), packets[0])
		for _, packet := range packets {
			normalizeIPv4PacketChecksums(packet)
			if _, err := tun.Write(packet); err != nil {
				g.downloadErrors.Add(1)
				return err
			}
			g.recordTunWrite(packet)
		}
	}
}
// recordClientToGatewayBatch updates the client-to-gateway counters for a batch
// received from the transport, stores a summary of its first packet and the
// activity time, and logs the first five batches.
func (g *Gateway) recordClientToGatewayBatch(packetCount int, byteCount int, first []byte) {
	batchNo := g.clientToGatewayBatches.Add(1)
	g.clientToGatewayPackets.Add(uint64(packetCount))
	g.clientToGatewayBytes.Add(uint64(byteCount))
	firstSummary := summarizePacket(first)

	g.mu.Lock()
	g.lastClientToGatewayPacket, g.lastRuntimeActivityAt = firstSummary, time.Now().UTC()
	g.mu.Unlock()

	if batchNo > 5 {
		return
	}
	log.Printf(
		"vpn gateway client_to_gateway batch received: vpn_connection_id=%s batch=%d packets=%d bytes=%d first=%s",
		g.VPNConnectionID, batchNo, packetCount, byteCount, firstSummary,
	)
}
// recordGatewayToClientBatch updates the gateway-to-client counters for a batch
// that was sent successfully, stores a summary of its first packet and the
// activity time, and logs the first five batches.
func (g *Gateway) recordGatewayToClientBatch(packetCount int, byteCount int, first []byte) {
	batchNo := g.gatewayToClientBatches.Add(1)
	g.gatewayToClientPackets.Add(uint64(packetCount))
	g.gatewayToClientBytes.Add(uint64(byteCount))
	firstSummary := summarizePacket(first)

	g.mu.Lock()
	g.lastGatewayToClientPacket, g.lastRuntimeActivityAt = firstSummary, time.Now().UTC()
	g.mu.Unlock()

	if batchNo > 5 {
		return
	}
	log.Printf(
		"vpn gateway gateway_to_client batch uploaded: vpn_connection_id=%s batch=%d packets=%d bytes=%d first=%s",
		g.VPNConnectionID, batchNo, packetCount, byteCount, firstSummary,
	)
}
// recordTunWrite counts a packet written to the TUN device and logs the first
// five such packets.
func (g *Gateway) recordTunWrite(packet []byte) {
	size := len(packet)
	count := g.tunWritePackets.Add(1)
	g.tunWriteBytes.Add(uint64(size))
	if count > 5 {
		return
	}
	log.Printf("vpn gateway packet written to tun: vpn_connection_id=%s packet=%d bytes=%d summary=%s", g.VPNConnectionID, count, size, summarizePacket(packet))
}
// recordTunRead counts a packet read from the TUN device and logs the first
// five such packets.
func (g *Gateway) recordTunRead(packet []byte) {
	size := len(packet)
	count := g.tunReadPackets.Add(1)
	g.tunReadBytes.Add(uint64(size))
	if count > 5 {
		return
	}
	log.Printf("vpn gateway packet read from tun: vpn_connection_id=%s packet=%d bytes=%d summary=%s", g.VPNConnectionID, count, size, summarizePacket(packet))
}
// packetBytesTotal returns the summed length of all packets (0 for none).
func packetBytesTotal(packets [][]byte) int {
	var sum int
	for i := range packets {
		sum += len(packets[i])
	}
	return sum
}
// summarizePacket renders a short, log-friendly description of a raw IP packet.
// It dispatches on the IP version nibble to the IPv4/IPv6 summarizers and
// falls back to the version and length for anything else.
func summarizePacket(packet []byte) string {
	if len(packet) == 0 {
		return "empty"
	}
	version := packet[0] >> 4
	if version == 4 {
		return summarizeIPv4(packet)
	}
	if version == 6 {
		return summarizeIPv6(packet)
	}
	return fmt.Sprintf("ip_version=%d bytes=%d", version, len(packet))
}
// summarizeIPv4 describes an IPv4 packet as "ipv4 SRC -> DST proto=N bytes=N".
// Packets shorter than 20 bytes, or whose IHL is below 20 or exceeds the
// packet length, get a truncated/invalid description instead.
func summarizeIPv4(packet []byte) string {
	size := len(packet)
	if size < 20 {
		return fmt.Sprintf("ipv4 truncated bytes=%d", size)
	}
	if headerLen := int(packet[0]&0x0f) << 2; headerLen < 20 || size < headerLen {
		return fmt.Sprintf("ipv4 invalid_ihl=%d bytes=%d", headerLen, size)
	}
	source, destination := net.IP(packet[12:16]), net.IP(packet[16:20])
	return fmt.Sprintf("ipv4 %s -> %s proto=%d bytes=%d", source.String(), destination.String(), packet[9], size)
}
// summarizeIPv6 describes an IPv6 packet as "ipv6 SRC -> DST next=N bytes=N",
// using the fixed 40-byte header. Shorter packets are reported as truncated.
func summarizeIPv6(packet []byte) string {
	const fixedHeaderLen = 40
	if size := len(packet); size < fixedHeaderLen {
		return fmt.Sprintf("ipv6 truncated bytes=%d", size)
	}
	source, destination := net.IP(packet[8:24]), net.IP(packet[24:40])
	return fmt.Sprintf("ipv6 %s -> %s next=%d bytes=%d", source.String(), destination.String(), packet[6], len(packet))
}