Files
rdp-proxy/agents/rap-node-agent/internal/vpnruntime/gateway.go
T
m 20d361a886
build / backend (push) Has been cancelled
build / node-agent (push) Has been cancelled
build / worker (push) Has been cancelled
рабочий вариант, но скорость 10 МБит
2026-05-22 21:46:49 +03:00

533 lines
14 KiB
Go

package vpnruntime
import (
"context"
"fmt"
"io"
"log"
"net"
"sync"
"sync/atomic"
"time"
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/client"
)
// Gateway holds the configuration, run state and traffic counters of one VPN
// gateway runtime, which copies packets between a TUN device and a
// PacketTransport (see run, copyGatewayToClient and copyClientToGateway).
type Gateway struct {
// API is not referenced in the visible part of this file.
API *client.Client
// Transport carries packet batches in both directions; normalize rejects nil.
Transport PacketTransport
// ClusterID is required; normalize fails when it is empty.
ClusterID string
// VPNConnectionID defaults to ServiceTunnel.TunnelID when empty.
VPNConnectionID string
// ServiceTunnel is replaced by NormalizeServiceTunnel's result in normalize.
ServiceTunnel FabricServiceTunnel
// InterfaceName defaults to "rapvpn0".
InterfaceName string
// AddressCIDR defaults to "10.77.0.1/24".
AddressCIDR string
// RouteCIDR defaults to "10.77.0.0/24".
RouteCIDR string
// PollTimeout is passed to ReceiveGatewayPacketBatch; defaults to 25s when <= 0.
PollTimeout time.Duration
// mu guards running, lastErr, cancel and the last* fields at the end of the struct.
mu sync.Mutex
running bool
lastErr error
cancel context.CancelFunc
// Traffic counters; updated atomically and exported by Snapshot.
clientToGatewayBatches atomic.Uint64
clientToGatewayPackets atomic.Uint64
clientToGatewayBytes atomic.Uint64
gatewayToClientBatches atomic.Uint64
gatewayToClientPackets atomic.Uint64
gatewayToClientBytes atomic.Uint64
tunReadPackets atomic.Uint64
tunReadBytes atomic.Uint64
tunWritePackets atomic.Uint64
tunWriteBytes atomic.Uint64
// uploadQueueDrops counts packets discarded because the upload queue was full.
uploadQueueDrops atomic.Uint64
downloadErrors atomic.Uint64
uploadErrors atomic.Uint64
// Summaries (see summarizePacket) of the first packet of the latest batch in each direction.
lastClientToGatewayPacket string
lastGatewayToClientPacket string
// lastRuntimeActivityAt is the UTC time of the most recently recorded batch.
lastRuntimeActivityAt time.Time
}
// Batching limits used by uploadGatewayPackets.
const (
// vpnGatewayBatchMaxPackets is the packet count at which a batch is flushed.
vpnGatewayBatchMaxPackets = 2048
// vpnGatewayBatchMaxBytes is the byte budget of one batch; each packet is
// charged its length plus 4 bytes.
vpnGatewayBatchMaxBytes = 8 * 1024 * 1024
// vpnGatewayBatchFlushTimeout is how long a partially filled batch may wait
// before it is flushed.
vpnGatewayBatchFlushTimeout = 3 * time.Millisecond
// vpnGatewayPriorityBatchWait is not referenced in the visible part of this
// file. NOTE(review): confirm it is still used elsewhere.
vpnGatewayPriorityBatchWait = 20 * time.Millisecond
)
// readWriteCloser is the device abstraction the gateway runtime needs from a
// TUN handle: it must be readable, writable and closable.
type readWriteCloser interface {
	io.ReadWriteCloser
}
// PacketTransport is the packet channel the gateway uses to exchange batches
// of packets with the client side.
type PacketTransport interface {
// SendGatewayPacketBatch sends one batch of packets that the gateway read
// from its TUN device.
SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error
// ReceiveGatewayPacketBatch returns the next batch of packets to be written
// to the gateway's TUN device. The gateway passes Gateway.PollTimeout as
// timeout; presumably this bounds the wait for data — confirm against the
// implementations.
ReceiveGatewayPacketBatch(ctx context.Context, timeout time.Duration) ([][]byte, error)
}
// packetTransportSnapshotter is an optional interface a PacketTransport may
// implement; when it does, Gateway.Snapshot includes a non-empty result under
// the "transport_snapshot" key.
type packetTransportSnapshotter interface {
Snapshot() map[string]any
}
// packetTransportCloser is an optional interface a PacketTransport may
// implement; when it does, Gateway.run closes the transport on exit.
type packetTransportCloser interface {
Close() error
}
// EnsureStarted starts the gateway runtime unless it is already running.
//
// It normalizes the configuration, opens the TUN device and launches the
// packet pumps in a background goroutine bound to ctx. It returns nil when the
// runtime is running on return (started by this call or by a concurrent one)
// and the set-up error otherwise. Errors of the background runtime are
// reported through Status and Snapshot, not through the return value.
func (g *Gateway) EnsureStarted(ctx context.Context) error {
	g.mu.Lock()
	alreadyRunning := g.running
	g.mu.Unlock()
	if alreadyRunning {
		return nil
	}

	// fail records a set-up error. If a concurrent caller managed to start the
	// runtime in the meantime, its running/cancel state must not be wiped, so
	// the error is dropped and the call reports success instead.
	fail := func(err error) error {
		g.mu.Lock()
		defer g.mu.Unlock()
		if g.running {
			return nil
		}
		g.lastErr = err
		g.cancel = nil
		return err
	}

	if err := g.normalize(); err != nil {
		return fail(err)
	}
	// The device is opened outside the lock; the second running check below
	// resolves the race between concurrent callers.
	tun, err := openGatewayTun(g.InterfaceName, g.AddressCIDR, g.RouteCIDR)
	if err != nil {
		return fail(err)
	}

	runCtx, cancel := context.WithCancel(ctx)
	g.mu.Lock()
	if g.running {
		// Lost the race: another caller started the runtime first.
		g.mu.Unlock()
		cancel()
		_ = tun.Close()
		return nil
	}
	g.running = true
	g.lastErr = nil
	g.cancel = cancel
	g.mu.Unlock()

	go func() {
		// Release the run context even when the runtime exits on its own.
		defer cancel()
		runErr := g.run(runCtx, tun)
		switch {
		case runCtx.Err() == nil:
			// The runtime ended by itself; this goroutine still owns the state.
			if runErr != nil {
				log.Printf("vpn gateway runtime stopped: tunnel_id=%s error=%v", g.tunnelID(), runErr)
			}
			g.setStopped(runErr)
		case ctx.Err() != nil:
			// The caller's context ended; running was still true until now,
			// so no newer run can have been started.
			g.setStopped(runCtx.Err())
		default:
			// Cancelled through Stop, which already cleared running/cancel.
			// A newer run may own the state by now, so only record the error
			// when nothing is running.
			g.mu.Lock()
			if !g.running {
				g.lastErr = runCtx.Err()
			}
			g.mu.Unlock()
		}
	}()
	return nil
}
// Stop marks the gateway as not running and cancels the active run context,
// if there is one. It does not wait for the runtime goroutines to exit.
func (g *Gateway) Stop() {
	g.mu.Lock()
	stop := g.cancel
	g.running, g.cancel = false, nil
	g.mu.Unlock()

	if stop == nil {
		return
	}
	stop()
}
// Status reports whether the gateway is running together with the text of the
// last recorded error, or an empty string when no error is recorded.
func (g *Gateway) Status() (bool, string) {
	g.mu.Lock()
	defer g.mu.Unlock()

	message := ""
	if err := g.lastErr; err != nil {
		message = err.Error()
	}
	return g.running, message
}
// IsReadyForConnection reports whether the gateway is running and serves the
// given connection: a non-empty vpnConnectionID that equals either the
// configured VPNConnectionID or the effective tunnel ID.
func (g *Gateway) IsReadyForConnection(vpnConnectionID string) bool {
	g.mu.Lock()
	defer g.mu.Unlock()

	effectiveTunnelID := g.tunnelIDLocked()
	if !g.running || vpnConnectionID == "" {
		return false
	}
	return vpnConnectionID == g.VPNConnectionID || vpnConnectionID == effectiveTunnelID
}
// Snapshot returns a diagnostic view of the gateway: run state, service tunnel
// identity, transport name, traffic counters and, when available, the last
// error, the last activity time, platform details and the transport's own
// snapshot.
func (g *Gateway) Snapshot() map[string]any {
	// Copy the mutex-guarded fields first so the lock is not held while the
	// map is assembled.
	g.mu.Lock()
	isRunning := g.running
	var errText string
	if g.lastErr != nil {
		errText = g.lastErr.Error()
	}
	lastInbound := g.lastClientToGatewayPacket
	lastOutbound := g.lastGatewayToClientPacket
	activityAt := g.lastRuntimeActivityAt
	g.mu.Unlock()

	snap := make(map[string]any, 32)

	// Identity of the runtime and its service tunnel.
	snap["running"] = isRunning
	snap["tunnel_id"] = g.ServiceTunnel.TunnelID
	snap["pool_id"] = g.ServiceTunnel.PoolID
	snap["service_id"] = g.ServiceTunnel.ServiceID
	snap["local_service_id"] = g.ServiceTunnel.LocalServiceID
	snap["remote_service_id"] = g.ServiceTunnel.RemoteServiceID
	snap["service_kind"] = g.ServiceTunnel.ServiceKind
	snap["service_role"] = firstNonEmptyTunnelString(g.ServiceTunnel.ServiceRole, DefaultFabricTunnelRole)
	snap["service_class"] = firstNonEmptyTunnelString(g.ServiceTunnel.ServiceClass, DefaultFabricTunnelClass)
	snap["adapter_contract"] = "fabric_channel_to_ipv4_nat"
	snap["transport"] = g.transportName()
	snap["poll_timeout_ms"] = g.PollTimeout.Milliseconds()

	// Traffic counters.
	snap["client_to_gateway_batches"] = g.clientToGatewayBatches.Load()
	snap["client_to_gateway_packets"] = g.clientToGatewayPackets.Load()
	snap["client_to_gateway_bytes"] = g.clientToGatewayBytes.Load()
	snap["gateway_to_client_batches"] = g.gatewayToClientBatches.Load()
	snap["gateway_to_client_packets"] = g.gatewayToClientPackets.Load()
	snap["gateway_to_client_bytes"] = g.gatewayToClientBytes.Load()
	snap["tun_read_packets"] = g.tunReadPackets.Load()
	snap["tun_read_bytes"] = g.tunReadBytes.Load()
	snap["tun_write_packets"] = g.tunWritePackets.Load()
	snap["tun_write_bytes"] = g.tunWriteBytes.Load()
	snap["upload_queue_drops"] = g.uploadQueueDrops.Load()
	snap["download_errors"] = g.downloadErrors.Load()
	snap["upload_errors"] = g.uploadErrors.Load()
	snap["last_client_to_gateway"] = lastInbound
	snap["last_gateway_to_client"] = lastOutbound

	// Optional entries are only present when they carry information.
	if errText != "" {
		snap["last_error"] = errText
	}
	if !activityAt.IsZero() {
		snap["last_runtime_activity_at"] = activityAt.UTC().Format(time.RFC3339Nano)
	}
	snap["service_tunnel"] = g.ServiceTunnel.Snapshot()
	platform := gatewayPlatformSnapshot(g.InterfaceName, g.RouteCIDR)
	if len(platform) > 0 {
		snap["platform"] = platform
	}
	if reporter, ok := g.Transport.(packetTransportSnapshotter); ok {
		if transportState := reporter.Snapshot(); len(transportState) > 0 {
			snap["transport_snapshot"] = transportState
		}
	}
	return snap
}
// transportName maps the configured transport to a short label: "none" for a
// nil transport, a fixed name for the known transport types, and the Go type
// name for anything else.
func (g *Gateway) transportName() string {
	if g.Transport == nil {
		return "none"
	}
	switch transport := g.Transport.(type) {
	case *FabricSessionPacketTransport:
		return "fabric_session"
	case *FabricPacketTransport:
		return "fabric_mesh"
	case *LocalPacketTransport:
		return "local_fabric_inbox"
	case *AdaptivePacketTransport:
		return "adaptive_fabric"
	default:
		return fmt.Sprintf("%T", transport)
	}
}
// setStopped marks the gateway as not running, stores err as the last error
// (nil clears it) and drops the stored cancel function without calling it.
func (g *Gateway) setStopped(err error) {
	g.mu.Lock()
	g.running, g.lastErr, g.cancel = false, err, nil
	g.mu.Unlock()
}
// normalize validates the gateway configuration and fills in defaults.
//
// It returns an error when no transport is set or when the cluster ID or the
// tunnel ID is missing. The service tunnel is normalized first, and an empty
// VPNConnectionID is taken from the normalized tunnel ID.
func (g *Gateway) normalize() error {
	if g.Transport == nil {
		return fmt.Errorf("fabric packet transport is required")
	}

	g.ServiceTunnel = NormalizeServiceTunnel(g.ServiceTunnel, g.VPNConnectionID)
	if g.VPNConnectionID == "" {
		g.VPNConnectionID = g.ServiceTunnel.TunnelID
	}
	switch {
	case g.ClusterID == "", g.VPNConnectionID == "":
		return fmt.Errorf("cluster id and tunnel id are required")
	}

	// String settings that fall back to a built-in value when left empty.
	defaults := []struct {
		field    *string
		fallback string
	}{
		{&g.InterfaceName, "rapvpn0"},
		{&g.AddressCIDR, "10.77.0.1/24"},
		{&g.RouteCIDR, "10.77.0.0/24"},
	}
	for _, entry := range defaults {
		if *entry.field == "" {
			*entry.field = entry.fallback
		}
	}

	if g.PollTimeout <= 0 {
		g.PollTimeout = 25 * time.Second
	}
	return nil
}
// tunnelIDLocked returns the service tunnel ID, falling back to the VPN
// connection ID. It does not take g.mu itself; tunnelID and
// IsReadyForConnection call it with the mutex held.
func (g *Gateway) tunnelIDLocked() string {
	preferred, fallback := g.ServiceTunnel.TunnelID, g.VPNConnectionID
	return firstNonEmptyTunnelString(preferred, fallback)
}
// tunnelID returns the effective tunnel ID under the gateway mutex. A nil
// receiver yields an empty string.
func (g *Gateway) tunnelID() string {
	if g != nil {
		g.mu.Lock()
		defer g.mu.Unlock()
		return g.tunnelIDLocked()
	}
	return ""
}
// run drives both packet pumps until one of them returns or ctx is done.
//
// It returns the first pump result, or the context error when the context
// ends first. On exit it cancels the pumps, closes the transport if it
// implements packetTransportCloser, and closes tun. It does not wait for the
// pump goroutines to finish.
func (g *Gateway) run(ctx context.Context, tun readWriteCloser) error {
	defer tun.Close()
	if transportCloser, ok := g.Transport.(packetTransportCloser); ok {
		defer transportCloser.Close()
	}

	pumpCtx, stopPumps := context.WithCancel(ctx)
	defer stopPumps()

	// Buffered for both pumps so neither blocks on send after run returns.
	results := make(chan error, 2)
	go func() {
		results <- g.copyGatewayToClient(pumpCtx, tun)
	}()
	go func() {
		results <- g.copyClientToGateway(pumpCtx, tun)
	}()

	select {
	case pumpErr := <-results:
		stopPumps()
		return pumpErr
	case <-pumpCtx.Done():
		return pumpCtx.Err()
	}
}
// copyGatewayToClient reads packets from tun and queues them for upload to the
// client side.
//
// Each packet is copied out of the read buffer, passed through
// normalizeIPv4PacketChecksums, counted, and offered to a bounded queue that
// uploadGatewayPackets drains in its own goroutine. When the queue is full the
// packet is dropped and counted in uploadQueueDrops. The function returns when
// ctx is done, when the uploader reports an error, or when reading from tun
// fails. Cancellation is only noticed between reads, so a blocked Read ends
// only once it returns.
func (g *Gateway) copyGatewayToClient(ctx context.Context, tun io.Reader) error {
	packets := make(chan []byte, 32768)
	errCh := make(chan error, 1)
	go func() {
		errCh <- g.uploadGatewayPackets(ctx, nil, packets)
	}()

	buffer := make([]byte, 65535)
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case err := <-errCh:
			if err != nil {
				return err
			}
		default:
		}

		n, err := tun.Read(buffer)
		if err != nil {
			return err
		}
		if n <= 0 {
			continue
		}

		// The read buffer is reused, so the queued packet needs its own copy.
		packet := append([]byte(nil), buffer[:n]...)
		normalizeIPv4PacketChecksums(packet)
		g.recordTunRead(packet)

		select {
		case packets <- packet:
		default:
			// A full queue means the uploader is already saturated. Logging
			// every drop (each with a mutex acquisition in tunnelID) would
			// slow this loop further and flood the log, so only the first
			// few drops and then every 1024th are reported, like the other
			// per-packet logs in this file.
			drops := g.uploadQueueDrops.Add(1)
			if drops <= 5 || drops%1024 == 0 {
				log.Printf("vpn gateway packet upload queue full; dropping packet: tunnel_id=%s drops=%d", g.tunnelID(), drops)
			}
		}
	}
}
// uploadGatewayPackets drains packets, groups them into batches and hands each
// batch to g.Transport.SendGatewayPacketBatch.
//
// A batch is sent when it reaches vpnGatewayBatchMaxPackets or
// vpnGatewayBatchMaxBytes, when the flush timer fires, or when ctx is done.
// A failed send is logged and counted in uploadErrors; the batch is discarded,
// not retried. The function only returns once ctx is done, with ctx.Err().
// The second parameter is ignored.
func (g *Gateway) uploadGatewayPackets(ctx context.Context, _ <-chan []byte, packets <-chan []byte) error {
batch := make([][]byte, 0, vpnGatewayBatchMaxPackets)
batchBytes := 0
// Create the timer and stop it immediately so it starts out idle and can be
// armed with Reset when the first packet of a batch arrives.
timer := time.NewTimer(time.Hour)
if !timer.Stop() {
<-timer.C
}
timerActive := false
// flush sends the current batch (if any) and resets the batch state.
flush := func() {
if len(batch) == 0 {
return
}
packetCount := len(batch)
byteCount := packetBytesTotal(batch)
if err := g.Transport.SendGatewayPacketBatch(ctx, batch); err != nil {
g.uploadErrors.Add(1)
log.Printf("vpn gateway packet batch upload failed: tunnel_id=%s packets=%d error=%v", g.tunnelID(), len(batch), err)
} else {
g.recordGatewayToClientBatch(packetCount, byteCount, batch[0])
}
// Drop the packet references, then reuse the backing array for the next
// batch. NOTE(review): this assumes SendGatewayPacketBatch does not keep
// the slice after it returns — confirm against the transports.
for i := range batch {
batch[i] = nil
}
batch = batch[:0]
batchBytes = 0
}
// addPacket appends packet to the batch, flushing first when the packet
// would not fit. It reports false for an empty packet, which is skipped.
addPacket := func(packet []byte) bool {
packetBytes := len(packet)
if packetBytes <= 0 {
return false
}
// Each packet is charged 4 extra bytes against the byte budget;
// presumably a length prefix in the wire format — TODO confirm.
packetFrameSize := 4 + packetBytes
if len(batch) > 0 {
if len(batch) >= vpnGatewayBatchMaxPackets || batchBytes+packetFrameSize > vpnGatewayBatchMaxBytes {
flush()
}
}
batch = append(batch, packet)
batchBytes += packetFrameSize
return true
}
for {
// With nothing pending, make sure the timer is stopped and its channel
// drained so a stale tick cannot trigger a flush later.
if len(batch) == 0 && timerActive {
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timerActive = false
}
select {
case <-ctx.Done():
// NOTE(review): this final flush passes the already-cancelled ctx to the
// transport — confirm the send can still succeed in that state.
flush()
return ctx.Err()
case packet := <-packets:
if !addPacket(packet) {
continue
}
if len(batch) >= vpnGatewayBatchMaxPackets || batchBytes >= vpnGatewayBatchMaxBytes {
flush()
continue
}
// Arm the flush timer only if it is not already running, so the deadline
// is not pushed back by later packets.
if !timerActive {
timer.Reset(vpnGatewayBatchFlushTimeout)
timerActive = true
}
case <-timer.C:
timerActive = false
flush()
}
}
}
// copyClientToGateway receives packet batches from the transport and writes
// every packet to tun.
//
// A failed receive is counted in downloadErrors, logged, and retried after one
// second. Cancellation of ctx ends the loop with ctx.Err() and is not reported
// as a failure. A failed write to tun is counted in downloadErrors and
// returned, which ends the loop.
// NOTE(review): a single packet that the device rejects therefore stops the
// whole pump — confirm that this is intended for client-supplied packets.
func (g *Gateway) copyClientToGateway(ctx context.Context, tun io.Writer) error {
	for {
		// Checked on every iteration so a transport that keeps returning
		// empty batches cannot keep the loop alive after cancellation.
		if err := ctx.Err(); err != nil {
			return err
		}

		packets, err := g.Transport.ReceiveGatewayPacketBatch(ctx, g.PollTimeout)
		if err != nil {
			// A receive that fails because the gateway is shutting down is
			// not a download failure.
			if ctxErr := ctx.Err(); ctxErr != nil {
				return ctxErr
			}
			g.downloadErrors.Add(1)
			log.Printf("vpn gateway packet download failed: tunnel_id=%s error=%v", g.tunnelID(), err)
			select {
			case <-ctx.Done():
				return ctx.Err()
			case <-time.After(time.Second):
			}
			continue
		}
		if len(packets) == 0 {
			continue
		}

		g.recordClientToGatewayBatch(len(packets), packetBytesTotal(packets), packets[0])
		for _, packet := range packets {
			normalizeIPv4PacketChecksums(packet)
			if _, err := tun.Write(packet); err != nil {
				g.downloadErrors.Add(1)
				return err
			}
			g.recordTunWrite(packet)
		}
	}
}
// recordClientToGatewayBatch accounts for one batch received from the client
// side: it bumps the batch, packet and byte counters, stores a summary of the
// first packet together with the current UTC time, and logs the first five
// batches.
func (g *Gateway) recordClientToGatewayBatch(packetCount int, byteCount int, first []byte) {
	batchNumber := g.clientToGatewayBatches.Add(1)
	g.clientToGatewayPackets.Add(uint64(packetCount))
	g.clientToGatewayBytes.Add(uint64(byteCount))

	firstSummary := summarizePacket(first)

	g.mu.Lock()
	g.lastClientToGatewayPacket = firstSummary
	g.lastRuntimeActivityAt = time.Now().UTC()
	g.mu.Unlock()

	// Only the first few batches are logged to keep the hot path quiet.
	if batchNumber > 5 {
		return
	}
	log.Printf(
		"vpn gateway client_to_gateway batch received: tunnel_id=%s batch=%d packets=%d bytes=%d first=%s",
		g.tunnelID(), batchNumber, packetCount, byteCount, firstSummary,
	)
}
// recordGatewayToClientBatch accounts for one batch uploaded to the client
// side: it bumps the batch, packet and byte counters, stores a summary of the
// first packet together with the current UTC time, and logs the first five
// batches.
func (g *Gateway) recordGatewayToClientBatch(packetCount int, byteCount int, first []byte) {
	batchNumber := g.gatewayToClientBatches.Add(1)
	g.gatewayToClientPackets.Add(uint64(packetCount))
	g.gatewayToClientBytes.Add(uint64(byteCount))

	firstSummary := summarizePacket(first)

	g.mu.Lock()
	g.lastGatewayToClientPacket = firstSummary
	g.lastRuntimeActivityAt = time.Now().UTC()
	g.mu.Unlock()

	// Only the first few batches are logged to keep the hot path quiet.
	if batchNumber > 5 {
		return
	}
	log.Printf(
		"vpn gateway gateway_to_client batch uploaded: tunnel_id=%s batch=%d packets=%d bytes=%d first=%s",
		g.tunnelID(), batchNumber, packetCount, byteCount, firstSummary,
	)
}
// recordTunWrite counts one packet written to the TUN device and logs the
// first five of them.
func (g *Gateway) recordTunWrite(packet []byte) {
	size := len(packet)
	sequence := g.tunWritePackets.Add(1)
	g.tunWriteBytes.Add(uint64(size))
	if sequence > 5 {
		return
	}
	log.Printf("vpn gateway packet written to tun: tunnel_id=%s packet=%d bytes=%d summary=%s", g.tunnelID(), sequence, size, summarizePacket(packet))
}
// recordTunRead counts one packet read from the TUN device and logs the first
// five of them.
func (g *Gateway) recordTunRead(packet []byte) {
	size := len(packet)
	sequence := g.tunReadPackets.Add(1)
	g.tunReadBytes.Add(uint64(size))
	if sequence > 5 {
		return
	}
	log.Printf("vpn gateway packet read from tun: tunnel_id=%s packet=%d bytes=%d summary=%s", g.tunnelID(), sequence, size, summarizePacket(packet))
}
// packetBytesTotal returns the combined length in bytes of all packets.
func packetBytesTotal(packets [][]byte) int {
	sum := 0
	for i := range packets {
		sum += len(packets[i])
	}
	return sum
}
// summarizePacket returns a short description of an IP packet for logs and
// snapshots. It dispatches on the version nibble of the first byte; an empty
// packet yields "empty" and an unknown version is reported with its number
// and the packet length.
func summarizePacket(packet []byte) string {
	if len(packet) == 0 {
		return "empty"
	}
	ipVersion := packet[0] >> 4
	if ipVersion == 4 {
		return summarizeIPv4(packet)
	}
	if ipVersion == 6 {
		return summarizeIPv6(packet)
	}
	return fmt.Sprintf("ip_version=%d bytes=%d", ipVersion, len(packet))
}
// summarizeIPv4 describes an IPv4 packet by source, destination, protocol
// number and length. Packets shorter than 20 bytes, or whose header length
// field is below 20 bytes or beyond the packet, get a diagnostic text instead.
func summarizeIPv4(packet []byte) string {
	const minHeaderLen = 20
	size := len(packet)
	if size < minHeaderLen {
		return fmt.Sprintf("ipv4 truncated bytes=%d", size)
	}
	// The low nibble of the first byte is the header length in 32-bit words.
	headerLen := 4 * int(packet[0]&0x0f)
	if headerLen < minHeaderLen || headerLen > size {
		return fmt.Sprintf("ipv4 invalid_ihl=%d bytes=%d", headerLen, size)
	}
	source := net.IP(packet[12:16]).String()
	destination := net.IP(packet[16:20]).String()
	protocol := packet[9]
	return fmt.Sprintf("ipv4 %s -> %s proto=%d bytes=%d", source, destination, protocol, size)
}
// summarizeIPv6 describes an IPv6 packet by source, destination, next-header
// value and length. Packets shorter than the 40-byte fixed header get a
// diagnostic text instead.
func summarizeIPv6(packet []byte) string {
	const fixedHeaderLen = 40
	size := len(packet)
	if size < fixedHeaderLen {
		return fmt.Sprintf("ipv6 truncated bytes=%d", size)
	}
	source := net.IP(packet[8:24]).String()
	destination := net.IP(packet[24:40]).String()
	return fmt.Sprintf("ipv6 %s -> %s next=%d bytes=%d", source, destination, packet[6], size)
}