541 lines
14 KiB
Go
541 lines
14 KiB
Go
package vpnruntime
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/client"
|
|
)
|
|
|
|
// Gateway runs the gateway side of a VPN connection: it opens a TUN interface
// and shuttles IP packets between that interface and a PacketTransport.
//
// The exported fields are configuration; normalize (called by EnsureStarted)
// validates them and fills in defaults. The unexported fields hold runtime
// state and traffic counters that are reported by Status and Snapshot.
type Gateway struct {
	// API is used by normalize to build a BackendPacketTransport when
	// Transport is nil.
	API *client.Client

	// Transport carries packet batches between this gateway and the remote
	// side. When nil, normalize defaults it to a BackendPacketTransport.
	Transport PacketTransport

	// ClusterID and VPNConnectionID are required; normalize rejects empty
	// values.
	ClusterID string

	VPNConnectionID string

	// InterfaceName, AddressCIDR and RouteCIDR are passed to openGatewayTun.
	// normalize defaults them to "rapvpn0", "10.77.0.1/24" and "10.77.0.0/24".
	InterfaceName string

	AddressCIDR string

	RouteCIDR string

	// PollTimeout is passed to Transport.ReceiveGatewayPacketBatch; normalize
	// replaces non-positive values with 25s.
	PollTimeout time.Duration

	// mu guards running, lastErr, cancel and the last*/lastRuntimeActivityAt
	// fields at the end of the struct.
	mu sync.Mutex

	// running is true from a successful EnsureStarted until Stop/setStopped.
	running bool

	// lastErr is the error recorded by the most recent setStopped, if any.
	lastErr error

	// cancel cancels the context of the current run.
	cancel context.CancelFunc

	// Traffic counters, updated atomically and reported by Snapshot.
	// clientToGateway*: batches received from Transport for writing to the TUN.
	clientToGatewayBatches atomic.Uint64

	clientToGatewayPackets atomic.Uint64

	clientToGatewayBytes atomic.Uint64

	// gatewayToClient*: batches read from the TUN and successfully sent via
	// Transport.
	gatewayToClientBatches atomic.Uint64

	gatewayToClientPackets atomic.Uint64

	gatewayToClientBytes atomic.Uint64

	// tun*: individual packets read from / written to the TUN device.
	tunReadPackets atomic.Uint64

	tunReadBytes atomic.Uint64

	tunWritePackets atomic.Uint64

	tunWriteBytes atomic.Uint64

	// uploadQueueDrops counts TUN packets dropped because an upload queue was
	// full.
	uploadQueueDrops atomic.Uint64

	// downloadErrors counts failures on the client-to-gateway (download) path.
	downloadErrors atomic.Uint64

	// uploadErrors counts failed SendGatewayPacketBatch calls.
	uploadErrors atomic.Uint64

	// Summaries of the first packet of the most recent batch in each
	// direction, and the time that batch was recorded. Guarded by mu.
	lastClientToGatewayPacket string

	lastGatewayToClientPacket string

	lastRuntimeActivityAt time.Time
}
|
|
|
|
// Upload batching limits applied by uploadGatewayPackets to packets read from
// the TUN device.
const (
	// vpnGatewayBatchMaxPackets caps how many packets one uploaded batch holds.
	vpnGatewayBatchMaxPackets = 2048

	// vpnGatewayBatchMaxBytes caps a batch's accounted size: each packet is
	// counted as its payload plus 4 bytes of framing overhead.
	vpnGatewayBatchMaxBytes = 8 << 20

	// vpnGatewayBatchFlushTimeout is how long a partial batch may wait before
	// it is flushed.
	vpnGatewayBatchFlushTimeout = 3 * time.Millisecond
)
|
|
|
|
// readWriteCloser is the view of the TUN device that run needs: packets are
// read from it, written to it, and it is closed when run exits.
type readWriteCloser interface {
	io.ReadWriteCloser
}
|
|
|
|
// PacketTransport moves batches of raw IP packets between the gateway and the
// remote client side of a VPN connection.
type PacketTransport interface {
	// SendGatewayPacketBatch delivers packets that the gateway read from its
	// TUN device.
	SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error

	// ReceiveGatewayPacketBatch returns packets to be written to the gateway's
	// TUN device. timeout presumably bounds a long-poll wait; callers in this
	// package tolerate an empty batch.
	ReceiveGatewayPacketBatch(ctx context.Context, timeout time.Duration) ([][]byte, error)
}
|
|
|
|
// BackendPacketTransport is a PacketTransport that relays packet batches
// through the backend API client, addressed by cluster and VPN connection ID.
type BackendPacketTransport struct {
	// API performs the send/receive calls; it must be non-nil.
	API *client.Client

	ClusterID string

	VPNConnectionID string
}
|
|
|
|
func (t BackendPacketTransport) SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error {
|
|
return t.API.SendVPNGatewayPacketBatch(ctx, t.ClusterID, t.VPNConnectionID, packets)
|
|
}
|
|
|
|
func (t BackendPacketTransport) ReceiveGatewayPacketBatch(ctx context.Context, timeout time.Duration) ([][]byte, error) {
|
|
return t.API.ReceiveVPNGatewayPacketBatch(ctx, t.ClusterID, t.VPNConnectionID, timeout)
|
|
}
|
|
|
|
func (g *Gateway) EnsureStarted(ctx context.Context) error {
|
|
g.mu.Lock()
|
|
if g.running {
|
|
g.mu.Unlock()
|
|
return nil
|
|
}
|
|
g.mu.Unlock()
|
|
|
|
if err := g.normalize(); err != nil {
|
|
g.setStopped(err)
|
|
return err
|
|
}
|
|
tun, err := openGatewayTun(g.InterfaceName, g.AddressCIDR, g.RouteCIDR)
|
|
if err != nil {
|
|
g.setStopped(err)
|
|
return err
|
|
}
|
|
runCtx, cancel := context.WithCancel(ctx)
|
|
|
|
g.mu.Lock()
|
|
if g.running {
|
|
g.mu.Unlock()
|
|
cancel()
|
|
_ = tun.Close()
|
|
return nil
|
|
}
|
|
g.running = true
|
|
g.lastErr = nil
|
|
g.cancel = cancel
|
|
g.mu.Unlock()
|
|
|
|
go func() {
|
|
if err := g.run(runCtx, tun); err != nil && runCtx.Err() == nil {
|
|
log.Printf("vpn gateway runtime stopped: vpn_connection_id=%s error=%v", g.VPNConnectionID, err)
|
|
g.setStopped(err)
|
|
return
|
|
}
|
|
g.setStopped(runCtx.Err())
|
|
}()
|
|
return nil
|
|
}
|
|
|
|
func (g *Gateway) Stop() {
|
|
g.mu.Lock()
|
|
cancel := g.cancel
|
|
g.cancel = nil
|
|
g.running = false
|
|
g.mu.Unlock()
|
|
if cancel != nil {
|
|
cancel()
|
|
}
|
|
}
|
|
|
|
func (g *Gateway) Status() (bool, string) {
|
|
g.mu.Lock()
|
|
defer g.mu.Unlock()
|
|
if g.lastErr != nil {
|
|
return g.running, g.lastErr.Error()
|
|
}
|
|
return g.running, ""
|
|
}
|
|
|
|
func (g *Gateway) IsReadyForConnection(vpnConnectionID string) bool {
|
|
g.mu.Lock()
|
|
defer g.mu.Unlock()
|
|
return g.running && g.VPNConnectionID == vpnConnectionID && vpnConnectionID != ""
|
|
}
|
|
|
|
func (g *Gateway) Snapshot() map[string]any {
|
|
g.mu.Lock()
|
|
running := g.running
|
|
lastErr := ""
|
|
if g.lastErr != nil {
|
|
lastErr = g.lastErr.Error()
|
|
}
|
|
lastClientToGatewayPacket := g.lastClientToGatewayPacket
|
|
lastGatewayToClientPacket := g.lastGatewayToClientPacket
|
|
lastRuntimeActivityAt := g.lastRuntimeActivityAt
|
|
g.mu.Unlock()
|
|
|
|
out := map[string]any{
|
|
"running": running,
|
|
"transport": g.transportName(),
|
|
"poll_timeout_ms": g.PollTimeout.Milliseconds(),
|
|
"client_to_gateway_batches": g.clientToGatewayBatches.Load(),
|
|
"client_to_gateway_packets": g.clientToGatewayPackets.Load(),
|
|
"client_to_gateway_bytes": g.clientToGatewayBytes.Load(),
|
|
"gateway_to_client_batches": g.gatewayToClientBatches.Load(),
|
|
"gateway_to_client_packets": g.gatewayToClientPackets.Load(),
|
|
"gateway_to_client_bytes": g.gatewayToClientBytes.Load(),
|
|
"tun_read_packets": g.tunReadPackets.Load(),
|
|
"tun_read_bytes": g.tunReadBytes.Load(),
|
|
"tun_write_packets": g.tunWritePackets.Load(),
|
|
"tun_write_bytes": g.tunWriteBytes.Load(),
|
|
"upload_queue_drops": g.uploadQueueDrops.Load(),
|
|
"download_errors": g.downloadErrors.Load(),
|
|
"upload_errors": g.uploadErrors.Load(),
|
|
"last_client_to_gateway": lastClientToGatewayPacket,
|
|
"last_gateway_to_client": lastGatewayToClientPacket,
|
|
}
|
|
if lastErr != "" {
|
|
out["last_error"] = lastErr
|
|
}
|
|
if !lastRuntimeActivityAt.IsZero() {
|
|
out["last_runtime_activity_at"] = lastRuntimeActivityAt.UTC().Format(time.RFC3339Nano)
|
|
}
|
|
if platform := gatewayPlatformSnapshot(g.InterfaceName, g.RouteCIDR); len(platform) > 0 {
|
|
out["platform"] = platform
|
|
}
|
|
return out
|
|
}
|
|
|
|
func (g *Gateway) transportName() string {
|
|
switch g.Transport.(type) {
|
|
case *FabricPacketTransport:
|
|
return "fabric_mesh"
|
|
case *LocalPacketTransport:
|
|
return "local_fabric_inbox"
|
|
case *AdaptivePacketTransport:
|
|
return "adaptive_fabric_backend"
|
|
case BackendPacketTransport:
|
|
return "backend_http_packet_relay"
|
|
default:
|
|
if g.Transport == nil {
|
|
return "none"
|
|
}
|
|
return fmt.Sprintf("%T", g.Transport)
|
|
}
|
|
}
|
|
|
|
func (g *Gateway) setStopped(err error) {
|
|
g.mu.Lock()
|
|
defer g.mu.Unlock()
|
|
g.running = false
|
|
g.lastErr = err
|
|
g.cancel = nil
|
|
}
|
|
|
|
func (g *Gateway) normalize() error {
|
|
if g.Transport == nil {
|
|
if g.API == nil {
|
|
return fmt.Errorf("api client or packet transport is required")
|
|
}
|
|
g.Transport = BackendPacketTransport{
|
|
API: g.API,
|
|
ClusterID: g.ClusterID,
|
|
VPNConnectionID: g.VPNConnectionID,
|
|
}
|
|
}
|
|
if g.ClusterID == "" || g.VPNConnectionID == "" {
|
|
return fmt.Errorf("cluster id and vpn connection id are required")
|
|
}
|
|
if g.InterfaceName == "" {
|
|
g.InterfaceName = "rapvpn0"
|
|
}
|
|
if g.AddressCIDR == "" {
|
|
g.AddressCIDR = "10.77.0.1/24"
|
|
}
|
|
if g.RouteCIDR == "" {
|
|
g.RouteCIDR = "10.77.0.0/24"
|
|
}
|
|
if g.PollTimeout <= 0 {
|
|
g.PollTimeout = 25 * time.Second
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (g *Gateway) run(ctx context.Context, tun readWriteCloser) error {
|
|
defer tun.Close()
|
|
|
|
errCh := make(chan error, 2)
|
|
go func() { errCh <- g.copyGatewayToClient(ctx, tun) }()
|
|
go func() { errCh <- g.copyClientToGateway(ctx, tun) }()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case err := <-errCh:
|
|
return err
|
|
}
|
|
}
|
|
|
|
func (g *Gateway) copyGatewayToClient(ctx context.Context, tun io.Reader) error {
|
|
priorityPackets := make(chan []byte, 1024)
|
|
packets := make(chan []byte, 32768)
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
errCh <- g.uploadGatewayPackets(ctx, priorityPackets, packets)
|
|
}()
|
|
|
|
buffer := make([]byte, 65535)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case err := <-errCh:
|
|
if err != nil {
|
|
return err
|
|
}
|
|
default:
|
|
}
|
|
n, err := tun.Read(buffer)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if n <= 0 {
|
|
continue
|
|
}
|
|
packet := append([]byte(nil), buffer[:n]...)
|
|
normalizeIPv4PacketChecksums(packet)
|
|
g.recordTunRead(packet)
|
|
if isTCPControlPacket(packet) {
|
|
select {
|
|
case priorityPackets <- packet:
|
|
default:
|
|
g.uploadQueueDrops.Add(1)
|
|
log.Printf("vpn gateway priority packet upload queue full; dropping packet: vpn_connection_id=%s", g.VPNConnectionID)
|
|
}
|
|
continue
|
|
}
|
|
select {
|
|
case packets <- packet:
|
|
default:
|
|
g.uploadQueueDrops.Add(1)
|
|
log.Printf("vpn gateway packet upload queue full; dropping packet: vpn_connection_id=%s", g.VPNConnectionID)
|
|
}
|
|
}
|
|
}
|
|
|
|
// uploadGatewayPackets drains packets read from the TUN device and sends them
// to g.Transport in batches. It runs until ctx is done, then flushes any
// pending batch and returns ctx.Err(). Send failures are logged and counted in
// uploadErrors, but they do not stop the loop.
//
// A regular batch is sent when any of these happens first:
//   - it holds vpnGatewayBatchMaxPackets packets;
//   - its accounted size (payload plus 4 bytes per packet) reaches, or would
//     exceed, vpnGatewayBatchMaxBytes;
//   - the flush timer fires. The timer is armed with
//     vpnGatewayBatchFlushTimeout when a packet is added and no timer is
//     pending.
//
// Each packet from priorityPackets first flushes any pending batch and is then
// sent in a batch of its own. priorityPackets is polled before packets on
// every iteration.
//
// NOTE(review): the batch's backing array is reused after
// SendGatewayPacketBatch returns. This assumes no PacketTransport retains the
// packets slice beyond the call; confirm for the non-backend transports.
// NOTE(review): the final flush on shutdown passes the already-cancelled ctx.
// Whether that send succeeds depends on how the transport treats a cancelled
// context.
func (g *Gateway) uploadGatewayPackets(ctx context.Context, priorityPackets <-chan []byte, packets <-chan []byte) error {
	batch := make([][]byte, 0, vpnGatewayBatchMaxPackets)
	batchBytes := 0
	// Create the flush timer stopped and drained; it is armed per batch below.
	timer := time.NewTimer(time.Hour)
	if !timer.Stop() {
		<-timer.C
	}
	timerActive := false

	// flush sends the pending batch (if any), records the outcome, and resets
	// the batch for reuse. Entries are nil'ed, presumably so the reused backing
	// array does not keep already-sent packets reachable.
	flush := func() {
		if len(batch) == 0 {
			return
		}
		packetCount := len(batch)
		byteCount := packetBytesTotal(batch)
		if err := g.Transport.SendGatewayPacketBatch(ctx, batch); err != nil {
			g.uploadErrors.Add(1)
			log.Printf("vpn gateway packet batch upload failed: vpn_connection_id=%s packets=%d error=%v", g.VPNConnectionID, len(batch), err)
		} else {
			g.recordGatewayToClientBatch(packetCount, byteCount, batch[0])
		}
		for i := range batch {
			batch[i] = nil
		}
		batch = batch[:0]
		batchBytes = 0
	}

	// addPacket appends a non-empty packet. It first flushes the current batch
	// if the packet would push it past either limit. It reports whether the
	// packet was added.
	addPacket := func(packet []byte) bool {
		packetBytes := len(packet)
		if packetBytes <= 0 {
			return false
		}
		packetFrameSize := 4 + packetBytes
		if len(batch) > 0 {
			if len(batch) >= vpnGatewayBatchMaxPackets || batchBytes+packetFrameSize > vpnGatewayBatchMaxBytes {
				flush()
			}
		}
		batch = append(batch, packet)
		batchBytes += packetFrameSize
		return true
	}

	// flushPriority sends any pending batch, then sends packet on its own.
	flushPriority := func(packet []byte) {
		flush()
		if addPacket(packet) {
			flush()
		}
	}

	for {
		// Disarm the timer once its batch has already been flushed by a
		// size limit or a priority packet.
		if len(batch) == 0 && timerActive {
			if !timer.Stop() {
				select {
				case <-timer.C:
				default:
				}
			}
			timerActive = false
		}

		// Non-blocking poll so priority packets win over the regular queue.
		select {
		case packet := <-priorityPackets:
			flushPriority(packet)
			continue
		default:
		}

		select {
		case <-ctx.Done():
			flush()
			return ctx.Err()
		case packet := <-priorityPackets:
			flushPriority(packet)
		case packet := <-packets:
			if !addPacket(packet) {
				continue
			}
			if len(batch) >= vpnGatewayBatchMaxPackets || batchBytes >= vpnGatewayBatchMaxBytes {
				flush()
				continue
			}
			if !timerActive {
				timer.Reset(vpnGatewayBatchFlushTimeout)
				timerActive = true
			}
		case <-timer.C:
			timerActive = false
			flush()
		}
	}
}
|
|
|
|
// isTCPControlPacket reports whether packet is an IPv4 TCP segment that should
// skip upload batching. That is a segment carrying SYN, FIN or RST, or a pure
// ACK with no payload.
//
// Data segments are not control packets, even though they carry the ACK flag
// (every segment after the handshake does). The previous mask (0x17) matched
// on ACK alone, which classified essentially all TCP traffic as control and
// defeated batching.
func isTCPControlPacket(packet []byte) bool {
	const (
		tcpFIN = 0x01
		tcpSYN = 0x02
		tcpRST = 0x04
		tcpACK = 0x10
	)
	if len(packet) < 20 || packet[0]>>4 != 4 {
		return false
	}
	ihl := int(packet[0]&0x0f) * 4
	if ihl < 20 || len(packet) < ihl+20 || packet[9] != 6 {
		return false
	}
	flags := packet[ihl+13]
	if flags&(tcpFIN|tcpSYN|tcpRST) != 0 {
		return true
	}
	if flags&tcpACK == 0 {
		return false
	}

	// Pure-ACK check. The segment ends at the IPv4 total length when that
	// value is plausible. Otherwise (e.g. zero from segmentation offload) use
	// the captured length. The total length also excludes trailing link-layer
	// padding.
	end := int(packet[2])<<8 | int(packet[3])
	if end < ihl+20 || end > len(packet) {
		end = len(packet)
	}
	dataOffset := int(packet[ihl+12]>>4) * 4
	if dataOffset < 20 {
		return false // malformed TCP header
	}
	return end-ihl <= dataOffset
}
|
|
|
|
func (g *Gateway) copyClientToGateway(ctx context.Context, tun io.Writer) error {
|
|
for {
|
|
packets, err := g.Transport.ReceiveGatewayPacketBatch(ctx, g.PollTimeout)
|
|
if err != nil {
|
|
log.Printf("vpn gateway packet download failed: vpn_connection_id=%s error=%v", g.VPNConnectionID, err)
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-time.After(time.Second):
|
|
}
|
|
continue
|
|
}
|
|
if len(packets) == 0 {
|
|
continue
|
|
}
|
|
g.recordClientToGatewayBatch(len(packets), packetBytesTotal(packets), packets[0])
|
|
for _, packet := range packets {
|
|
normalizeIPv4PacketChecksums(packet)
|
|
if _, err := tun.Write(packet); err != nil {
|
|
g.downloadErrors.Add(1)
|
|
return err
|
|
}
|
|
g.recordTunWrite(packet)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (g *Gateway) recordClientToGatewayBatch(packetCount int, byteCount int, first []byte) {
|
|
next := g.clientToGatewayBatches.Add(1)
|
|
g.clientToGatewayPackets.Add(uint64(packetCount))
|
|
g.clientToGatewayBytes.Add(uint64(byteCount))
|
|
summary := summarizePacket(first)
|
|
g.mu.Lock()
|
|
g.lastClientToGatewayPacket = summary
|
|
g.lastRuntimeActivityAt = time.Now().UTC()
|
|
g.mu.Unlock()
|
|
if next <= 5 {
|
|
log.Printf(
|
|
"vpn gateway client_to_gateway batch received: vpn_connection_id=%s batch=%d packets=%d bytes=%d first=%s",
|
|
g.VPNConnectionID,
|
|
next,
|
|
packetCount,
|
|
byteCount,
|
|
summary,
|
|
)
|
|
}
|
|
}
|
|
|
|
func (g *Gateway) recordGatewayToClientBatch(packetCount int, byteCount int, first []byte) {
|
|
next := g.gatewayToClientBatches.Add(1)
|
|
g.gatewayToClientPackets.Add(uint64(packetCount))
|
|
g.gatewayToClientBytes.Add(uint64(byteCount))
|
|
summary := summarizePacket(first)
|
|
g.mu.Lock()
|
|
g.lastGatewayToClientPacket = summary
|
|
g.lastRuntimeActivityAt = time.Now().UTC()
|
|
g.mu.Unlock()
|
|
if next <= 5 {
|
|
log.Printf(
|
|
"vpn gateway gateway_to_client batch uploaded: vpn_connection_id=%s batch=%d packets=%d bytes=%d first=%s",
|
|
g.VPNConnectionID,
|
|
next,
|
|
packetCount,
|
|
byteCount,
|
|
summary,
|
|
)
|
|
}
|
|
}
|
|
|
|
func (g *Gateway) recordTunWrite(packet []byte) {
|
|
next := g.tunWritePackets.Add(1)
|
|
g.tunWriteBytes.Add(uint64(len(packet)))
|
|
if next <= 5 {
|
|
log.Printf("vpn gateway packet written to tun: vpn_connection_id=%s packet=%d bytes=%d summary=%s", g.VPNConnectionID, next, len(packet), summarizePacket(packet))
|
|
}
|
|
}
|
|
|
|
func (g *Gateway) recordTunRead(packet []byte) {
|
|
next := g.tunReadPackets.Add(1)
|
|
g.tunReadBytes.Add(uint64(len(packet)))
|
|
if next <= 5 {
|
|
log.Printf("vpn gateway packet read from tun: vpn_connection_id=%s packet=%d bytes=%d summary=%s", g.VPNConnectionID, next, len(packet), summarizePacket(packet))
|
|
}
|
|
}
|
|
|
|
// packetBytesTotal returns the combined length in bytes of all packets.
func packetBytesTotal(packets [][]byte) int {
	var sum int
	for i := range packets {
		sum += len(packets[i])
	}
	return sum
}
|
|
|
|
func summarizePacket(packet []byte) string {
|
|
if len(packet) < 1 {
|
|
return "empty"
|
|
}
|
|
version := packet[0] >> 4
|
|
switch version {
|
|
case 4:
|
|
return summarizeIPv4(packet)
|
|
case 6:
|
|
return summarizeIPv6(packet)
|
|
default:
|
|
return fmt.Sprintf("ip_version=%d bytes=%d", version, len(packet))
|
|
}
|
|
}
|
|
|
|
// summarizeIPv4 describes an IPv4 packet as "ipv4 SRC -> DST proto=N bytes=N".
// It reports truncated packets and header lengths outside [20, len(packet)]
// instead of reading past the buffer.
func summarizeIPv4(packet []byte) string {
	size := len(packet)
	if size < 20 {
		return fmt.Sprintf("ipv4 truncated bytes=%d", size)
	}
	headerLen := 4 * int(packet[0]&0x0f)
	if headerLen < 20 || headerLen > size {
		return fmt.Sprintf("ipv4 invalid_ihl=%d bytes=%d", headerLen, size)
	}
	var (
		src = net.IP(packet[12:16])
		dst = net.IP(packet[16:20])
	)
	return fmt.Sprintf("ipv4 %s -> %s proto=%d bytes=%d", src.String(), dst.String(), packet[9], size)
}
|
|
|
|
// summarizeIPv6 describes an IPv6 packet as "ipv6 SRC -> DST next=N bytes=N",
// or reports it as truncated when shorter than the 40-byte fixed header.
func summarizeIPv6(packet []byte) string {
	const fixedHeaderLen = 40
	if len(packet) < fixedHeaderLen {
		return fmt.Sprintf("ipv6 truncated bytes=%d", len(packet))
	}
	src, dst := net.IP(packet[8:24]), net.IP(packet[24:40])
	return fmt.Sprintf("ipv6 %s -> %s next=%d bytes=%d", src.String(), dst.String(), packet[6], len(packet))
}
|