Add tracked vpnruntime implementation for CI guard tests
This commit is contained in:
@@ -0,0 +1,77 @@
|
||||
package vpnruntime
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
func normalizeIPv4PacketChecksums(packet []byte) bool {
|
||||
if len(packet) < 20 || packet[0]>>4 != 4 {
|
||||
return false
|
||||
}
|
||||
ihl := int(packet[0]&0x0f) * 4
|
||||
if ihl < 20 || len(packet) < ihl {
|
||||
return false
|
||||
}
|
||||
totalLen := int(binary.BigEndian.Uint16(packet[2:4]))
|
||||
if totalLen <= 0 || totalLen > len(packet) {
|
||||
totalLen = len(packet)
|
||||
}
|
||||
if totalLen < ihl {
|
||||
return false
|
||||
}
|
||||
|
||||
packet[10], packet[11] = 0, 0
|
||||
binary.BigEndian.PutUint16(packet[10:12], checksum(packet[:ihl]))
|
||||
|
||||
proto := packet[9]
|
||||
payload := packet[ihl:totalLen]
|
||||
switch proto {
|
||||
case 6:
|
||||
if len(payload) < 20 {
|
||||
return true
|
||||
}
|
||||
payload[16], payload[17] = 0, 0
|
||||
binary.BigEndian.PutUint16(payload[16:18], transportChecksum(packet, payload, proto))
|
||||
case 17:
|
||||
if len(payload) < 8 {
|
||||
return true
|
||||
}
|
||||
payload[6], payload[7] = 0, 0
|
||||
sum := transportChecksum(packet, payload, proto)
|
||||
if sum == 0 {
|
||||
sum = 0xffff
|
||||
}
|
||||
binary.BigEndian.PutUint16(payload[6:8], sum)
|
||||
case 1:
|
||||
if len(payload) < 4 {
|
||||
return true
|
||||
}
|
||||
payload[2], payload[3] = 0, 0
|
||||
binary.BigEndian.PutUint16(payload[2:4], checksum(payload))
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func transportChecksum(ipHeader []byte, payload []byte, proto byte) uint16 {
|
||||
pseudo := make([]byte, 12+len(payload))
|
||||
copy(pseudo[0:4], ipHeader[12:16])
|
||||
copy(pseudo[4:8], ipHeader[16:20])
|
||||
pseudo[8] = 0
|
||||
pseudo[9] = proto
|
||||
binary.BigEndian.PutUint16(pseudo[10:12], uint16(len(payload)))
|
||||
copy(pseudo[12:], payload)
|
||||
return checksum(pseudo)
|
||||
}
|
||||
|
||||
func checksum(data []byte) uint16 {
|
||||
var sum uint32
|
||||
for len(data) >= 2 {
|
||||
sum += uint32(binary.BigEndian.Uint16(data[:2]))
|
||||
data = data[2:]
|
||||
}
|
||||
if len(data) == 1 {
|
||||
sum += uint32(data[0]) << 8
|
||||
}
|
||||
for (sum >> 16) != 0 {
|
||||
sum = (sum & 0xffff) + (sum >> 16)
|
||||
}
|
||||
return ^uint16(sum)
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package vpnruntime
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeIPv4PacketChecksumsRepairsTCP(t *testing.T) {
|
||||
packet := []byte{
|
||||
0x45, 0x00, 0x00, 0x28, 0x00, 0x00, 0x40, 0x00, 0x40, 0x06, 0x12, 0x34, 192, 168, 200, 61, 10, 77, 0, 2,
|
||||
0x46, 0xa0, 0xdd, 0x78, 0, 0, 0, 1, 0, 0, 0, 0, 0x50, 0x12, 0x72, 0x10, 0xab, 0xcd, 0, 0,
|
||||
}
|
||||
if !normalizeIPv4PacketChecksums(packet) {
|
||||
t.Fatal("normalize returned false")
|
||||
}
|
||||
if got := checksum(packet[:20]); got != 0 {
|
||||
t.Fatalf("ip checksum verification = %#x, want 0", got)
|
||||
}
|
||||
tcp := packet[20:]
|
||||
pseudo := make([]byte, 12+len(tcp))
|
||||
copy(pseudo[0:4], packet[12:16])
|
||||
copy(pseudo[4:8], packet[16:20])
|
||||
pseudo[9] = 6
|
||||
binary.BigEndian.PutUint16(pseudo[10:12], uint16(len(tcp)))
|
||||
copy(pseudo[12:], tcp)
|
||||
if got := checksum(pseudo); got != 0 {
|
||||
t.Fatalf("tcp checksum verification = %#x, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeIPv4PacketChecksumsRepairsUDP(t *testing.T) {
|
||||
packet := []byte{
|
||||
0x45, 0x00, 0x00, 0x20, 0, 0, 0x40, 0, 0x40, 0x11, 0x12, 0x34, 10, 77, 0, 2, 192, 168, 200, 210,
|
||||
0x30, 0x39, 0x00, 0x35, 0x00, 0x0c, 0xab, 0xcd, 0xde, 0xad, 0xbe, 0xef,
|
||||
}
|
||||
if !normalizeIPv4PacketChecksums(packet) {
|
||||
t.Fatal("normalize returned false")
|
||||
}
|
||||
if got := checksum(packet[:20]); got != 0 {
|
||||
t.Fatalf("ip checksum verification = %#x, want 0", got)
|
||||
}
|
||||
udp := packet[20:]
|
||||
pseudo := make([]byte, 12+len(udp))
|
||||
copy(pseudo[0:4], packet[12:16])
|
||||
copy(pseudo[4:8], packet[16:20])
|
||||
pseudo[9] = 17
|
||||
binary.BigEndian.PutUint16(pseudo[10:12], uint16(len(udp)))
|
||||
copy(pseudo[12:], udp)
|
||||
if got := checksum(pseudo); got != 0 {
|
||||
t.Fatalf("udp checksum verification = %#x, want 0", got)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,495 @@
|
||||
package vpnruntime
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/client"
|
||||
)
|
||||
|
||||
type Gateway struct {
|
||||
API *client.Client
|
||||
Transport PacketTransport
|
||||
ClusterID string
|
||||
VPNConnectionID string
|
||||
InterfaceName string
|
||||
AddressCIDR string
|
||||
RouteCIDR string
|
||||
PollTimeout time.Duration
|
||||
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
lastErr error
|
||||
cancel context.CancelFunc
|
||||
|
||||
clientToGatewayBatches atomic.Uint64
|
||||
clientToGatewayPackets atomic.Uint64
|
||||
clientToGatewayBytes atomic.Uint64
|
||||
gatewayToClientBatches atomic.Uint64
|
||||
gatewayToClientPackets atomic.Uint64
|
||||
gatewayToClientBytes atomic.Uint64
|
||||
tunReadPackets atomic.Uint64
|
||||
tunReadBytes atomic.Uint64
|
||||
tunWritePackets atomic.Uint64
|
||||
tunWriteBytes atomic.Uint64
|
||||
uploadQueueDrops atomic.Uint64
|
||||
downloadErrors atomic.Uint64
|
||||
uploadErrors atomic.Uint64
|
||||
|
||||
lastClientToGatewayPacket string
|
||||
lastGatewayToClientPacket string
|
||||
lastRuntimeActivityAt time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
vpnGatewayBatchMaxPackets = 2048
|
||||
vpnGatewayBatchMaxBytes = 8 * 1024 * 1024
|
||||
vpnGatewayBatchFlushTimeout = 3 * time.Millisecond
|
||||
)
|
||||
|
||||
type readWriteCloser interface {
|
||||
io.Reader
|
||||
io.Writer
|
||||
io.Closer
|
||||
}
|
||||
|
||||
type PacketTransport interface {
|
||||
SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error
|
||||
ReceiveGatewayPacketBatch(ctx context.Context, timeout time.Duration) ([][]byte, error)
|
||||
}
|
||||
|
||||
type BackendPacketTransport struct {
|
||||
API *client.Client
|
||||
ClusterID string
|
||||
VPNConnectionID string
|
||||
}
|
||||
|
||||
func (t BackendPacketTransport) SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error {
|
||||
return t.API.SendVPNGatewayPacketBatch(ctx, t.ClusterID, t.VPNConnectionID, packets)
|
||||
}
|
||||
|
||||
func (t BackendPacketTransport) ReceiveGatewayPacketBatch(ctx context.Context, timeout time.Duration) ([][]byte, error) {
|
||||
return t.API.ReceiveVPNGatewayPacketBatch(ctx, t.ClusterID, t.VPNConnectionID, timeout)
|
||||
}
|
||||
|
||||
func (g *Gateway) EnsureStarted(ctx context.Context) error {
|
||||
g.mu.Lock()
|
||||
if g.running {
|
||||
g.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
g.mu.Unlock()
|
||||
|
||||
if err := g.normalize(); err != nil {
|
||||
g.setStopped(err)
|
||||
return err
|
||||
}
|
||||
tun, err := openGatewayTun(g.InterfaceName, g.AddressCIDR, g.RouteCIDR)
|
||||
if err != nil {
|
||||
g.setStopped(err)
|
||||
return err
|
||||
}
|
||||
runCtx, cancel := context.WithCancel(ctx)
|
||||
|
||||
g.mu.Lock()
|
||||
if g.running {
|
||||
g.mu.Unlock()
|
||||
cancel()
|
||||
_ = tun.Close()
|
||||
return nil
|
||||
}
|
||||
g.running = true
|
||||
g.lastErr = nil
|
||||
g.cancel = cancel
|
||||
g.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
if err := g.run(runCtx, tun); err != nil && runCtx.Err() == nil {
|
||||
log.Printf("vpn gateway runtime stopped: vpn_connection_id=%s error=%v", g.VPNConnectionID, err)
|
||||
g.setStopped(err)
|
||||
return
|
||||
}
|
||||
g.setStopped(runCtx.Err())
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Gateway) Stop() {
|
||||
g.mu.Lock()
|
||||
cancel := g.cancel
|
||||
g.cancel = nil
|
||||
g.running = false
|
||||
g.mu.Unlock()
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) Status() (bool, string) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
if g.lastErr != nil {
|
||||
return g.running, g.lastErr.Error()
|
||||
}
|
||||
return g.running, ""
|
||||
}
|
||||
|
||||
func (g *Gateway) IsReadyForConnection(vpnConnectionID string) bool {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
return g.running && g.VPNConnectionID == vpnConnectionID && vpnConnectionID != ""
|
||||
}
|
||||
|
||||
func (g *Gateway) Snapshot() map[string]any {
|
||||
g.mu.Lock()
|
||||
running := g.running
|
||||
lastErr := ""
|
||||
if g.lastErr != nil {
|
||||
lastErr = g.lastErr.Error()
|
||||
}
|
||||
lastClientToGatewayPacket := g.lastClientToGatewayPacket
|
||||
lastGatewayToClientPacket := g.lastGatewayToClientPacket
|
||||
lastRuntimeActivityAt := g.lastRuntimeActivityAt
|
||||
g.mu.Unlock()
|
||||
|
||||
out := map[string]any{
|
||||
"running": running,
|
||||
"transport": g.transportName(),
|
||||
"poll_timeout_ms": g.PollTimeout.Milliseconds(),
|
||||
"client_to_gateway_batches": g.clientToGatewayBatches.Load(),
|
||||
"client_to_gateway_packets": g.clientToGatewayPackets.Load(),
|
||||
"client_to_gateway_bytes": g.clientToGatewayBytes.Load(),
|
||||
"gateway_to_client_batches": g.gatewayToClientBatches.Load(),
|
||||
"gateway_to_client_packets": g.gatewayToClientPackets.Load(),
|
||||
"gateway_to_client_bytes": g.gatewayToClientBytes.Load(),
|
||||
"tun_read_packets": g.tunReadPackets.Load(),
|
||||
"tun_read_bytes": g.tunReadBytes.Load(),
|
||||
"tun_write_packets": g.tunWritePackets.Load(),
|
||||
"tun_write_bytes": g.tunWriteBytes.Load(),
|
||||
"upload_queue_drops": g.uploadQueueDrops.Load(),
|
||||
"download_errors": g.downloadErrors.Load(),
|
||||
"upload_errors": g.uploadErrors.Load(),
|
||||
"last_client_to_gateway": lastClientToGatewayPacket,
|
||||
"last_gateway_to_client": lastGatewayToClientPacket,
|
||||
}
|
||||
if lastErr != "" {
|
||||
out["last_error"] = lastErr
|
||||
}
|
||||
if !lastRuntimeActivityAt.IsZero() {
|
||||
out["last_runtime_activity_at"] = lastRuntimeActivityAt.UTC().Format(time.RFC3339Nano)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (g *Gateway) transportName() string {
|
||||
switch g.Transport.(type) {
|
||||
case *FabricPacketTransport:
|
||||
return "fabric_mesh"
|
||||
case *LocalPacketTransport:
|
||||
return "local_fabric_inbox"
|
||||
case *AdaptivePacketTransport:
|
||||
return "adaptive_fabric_backend"
|
||||
case BackendPacketTransport:
|
||||
return "backend_http_packet_relay"
|
||||
default:
|
||||
if g.Transport == nil {
|
||||
return "none"
|
||||
}
|
||||
return fmt.Sprintf("%T", g.Transport)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) setStopped(err error) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
g.running = false
|
||||
g.lastErr = err
|
||||
g.cancel = nil
|
||||
}
|
||||
|
||||
func (g *Gateway) normalize() error {
|
||||
if g.Transport == nil {
|
||||
if g.API == nil {
|
||||
return fmt.Errorf("api client or packet transport is required")
|
||||
}
|
||||
g.Transport = BackendPacketTransport{
|
||||
API: g.API,
|
||||
ClusterID: g.ClusterID,
|
||||
VPNConnectionID: g.VPNConnectionID,
|
||||
}
|
||||
}
|
||||
if g.ClusterID == "" || g.VPNConnectionID == "" {
|
||||
return fmt.Errorf("cluster id and vpn connection id are required")
|
||||
}
|
||||
if g.InterfaceName == "" {
|
||||
g.InterfaceName = "rapvpn0"
|
||||
}
|
||||
if g.AddressCIDR == "" {
|
||||
g.AddressCIDR = "10.77.0.1/24"
|
||||
}
|
||||
if g.RouteCIDR == "" {
|
||||
g.RouteCIDR = "10.77.0.0/24"
|
||||
}
|
||||
if g.PollTimeout <= 0 {
|
||||
g.PollTimeout = 25 * time.Second
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Gateway) run(ctx context.Context, tun readWriteCloser) error {
|
||||
defer tun.Close()
|
||||
|
||||
errCh := make(chan error, 2)
|
||||
go func() { errCh <- g.copyGatewayToClient(ctx, tun) }()
|
||||
go func() { errCh <- g.copyClientToGateway(ctx, tun) }()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-errCh:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) copyGatewayToClient(ctx context.Context, tun io.Reader) error {
|
||||
packets := make(chan []byte, 32768)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- g.uploadGatewayPackets(ctx, packets)
|
||||
}()
|
||||
|
||||
buffer := make([]byte, 65535)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
}
|
||||
n, err := tun.Read(buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n <= 0 {
|
||||
continue
|
||||
}
|
||||
packet := append([]byte(nil), buffer[:n]...)
|
||||
normalizeIPv4PacketChecksums(packet)
|
||||
g.recordTunRead(packet)
|
||||
select {
|
||||
case packets <- packet:
|
||||
default:
|
||||
g.uploadQueueDrops.Add(1)
|
||||
log.Printf("vpn gateway packet upload queue full; dropping packet: vpn_connection_id=%s", g.VPNConnectionID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) uploadGatewayPackets(ctx context.Context, packets <-chan []byte) error {
|
||||
batch := make([][]byte, 0, vpnGatewayBatchMaxPackets)
|
||||
batchBytes := 0
|
||||
timer := time.NewTimer(time.Hour)
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
timerActive := false
|
||||
flush := func() {
|
||||
if len(batch) == 0 {
|
||||
return
|
||||
}
|
||||
packetCount := len(batch)
|
||||
byteCount := packetBytesTotal(batch)
|
||||
if err := g.Transport.SendGatewayPacketBatch(ctx, batch); err != nil {
|
||||
g.uploadErrors.Add(1)
|
||||
log.Printf("vpn gateway packet batch upload failed: vpn_connection_id=%s packets=%d error=%v", g.VPNConnectionID, len(batch), err)
|
||||
} else {
|
||||
g.recordGatewayToClientBatch(packetCount, byteCount, batch[0])
|
||||
}
|
||||
for i := range batch {
|
||||
batch[i] = nil
|
||||
}
|
||||
batch = batch[:0]
|
||||
batchBytes = 0
|
||||
}
|
||||
for {
|
||||
if len(batch) == 0 && timerActive {
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
timerActive = false
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
flush()
|
||||
return ctx.Err()
|
||||
case packet := <-packets:
|
||||
packetBytes := len(packet)
|
||||
if packetBytes <= 0 {
|
||||
continue
|
||||
}
|
||||
packetFrameSize := 4 + packetBytes
|
||||
if len(batch) > 0 {
|
||||
if len(batch) >= vpnGatewayBatchMaxPackets || batchBytes+packetFrameSize > vpnGatewayBatchMaxBytes {
|
||||
flush()
|
||||
}
|
||||
}
|
||||
batch = append(batch, packet)
|
||||
batchBytes += packetFrameSize
|
||||
if len(batch) >= vpnGatewayBatchMaxPackets || batchBytes >= vpnGatewayBatchMaxBytes {
|
||||
flush()
|
||||
continue
|
||||
}
|
||||
if !timerActive {
|
||||
timer.Reset(vpnGatewayBatchFlushTimeout)
|
||||
timerActive = true
|
||||
}
|
||||
case <-timer.C:
|
||||
timerActive = false
|
||||
flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) copyClientToGateway(ctx context.Context, tun io.Writer) error {
|
||||
for {
|
||||
packets, err := g.Transport.ReceiveGatewayPacketBatch(ctx, g.PollTimeout)
|
||||
if err != nil {
|
||||
log.Printf("vpn gateway packet download failed: vpn_connection_id=%s error=%v", g.VPNConnectionID, err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
continue
|
||||
}
|
||||
if len(packets) == 0 {
|
||||
continue
|
||||
}
|
||||
g.recordClientToGatewayBatch(len(packets), packetBytesTotal(packets), packets[0])
|
||||
for _, packet := range packets {
|
||||
normalizeIPv4PacketChecksums(packet)
|
||||
if _, err := tun.Write(packet); err != nil {
|
||||
g.downloadErrors.Add(1)
|
||||
return err
|
||||
}
|
||||
g.recordTunWrite(packet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) recordClientToGatewayBatch(packetCount int, byteCount int, first []byte) {
|
||||
next := g.clientToGatewayBatches.Add(1)
|
||||
g.clientToGatewayPackets.Add(uint64(packetCount))
|
||||
g.clientToGatewayBytes.Add(uint64(byteCount))
|
||||
summary := summarizePacket(first)
|
||||
g.mu.Lock()
|
||||
g.lastClientToGatewayPacket = summary
|
||||
g.lastRuntimeActivityAt = time.Now().UTC()
|
||||
g.mu.Unlock()
|
||||
if next <= 5 {
|
||||
log.Printf(
|
||||
"vpn gateway client_to_gateway batch received: vpn_connection_id=%s batch=%d packets=%d bytes=%d first=%s",
|
||||
g.VPNConnectionID,
|
||||
next,
|
||||
packetCount,
|
||||
byteCount,
|
||||
summary,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) recordGatewayToClientBatch(packetCount int, byteCount int, first []byte) {
|
||||
next := g.gatewayToClientBatches.Add(1)
|
||||
g.gatewayToClientPackets.Add(uint64(packetCount))
|
||||
g.gatewayToClientBytes.Add(uint64(byteCount))
|
||||
summary := summarizePacket(first)
|
||||
g.mu.Lock()
|
||||
g.lastGatewayToClientPacket = summary
|
||||
g.lastRuntimeActivityAt = time.Now().UTC()
|
||||
g.mu.Unlock()
|
||||
if next <= 5 {
|
||||
log.Printf(
|
||||
"vpn gateway gateway_to_client batch uploaded: vpn_connection_id=%s batch=%d packets=%d bytes=%d first=%s",
|
||||
g.VPNConnectionID,
|
||||
next,
|
||||
packetCount,
|
||||
byteCount,
|
||||
summary,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) recordTunWrite(packet []byte) {
|
||||
next := g.tunWritePackets.Add(1)
|
||||
g.tunWriteBytes.Add(uint64(len(packet)))
|
||||
if next <= 5 {
|
||||
log.Printf("vpn gateway packet written to tun: vpn_connection_id=%s packet=%d bytes=%d summary=%s", g.VPNConnectionID, next, len(packet), summarizePacket(packet))
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Gateway) recordTunRead(packet []byte) {
|
||||
next := g.tunReadPackets.Add(1)
|
||||
g.tunReadBytes.Add(uint64(len(packet)))
|
||||
if next <= 5 {
|
||||
log.Printf("vpn gateway packet read from tun: vpn_connection_id=%s packet=%d bytes=%d summary=%s", g.VPNConnectionID, next, len(packet), summarizePacket(packet))
|
||||
}
|
||||
}
|
||||
|
||||
func packetBytesTotal(packets [][]byte) int {
|
||||
total := 0
|
||||
for _, packet := range packets {
|
||||
total += len(packet)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func summarizePacket(packet []byte) string {
|
||||
if len(packet) < 1 {
|
||||
return "empty"
|
||||
}
|
||||
version := packet[0] >> 4
|
||||
switch version {
|
||||
case 4:
|
||||
return summarizeIPv4(packet)
|
||||
case 6:
|
||||
return summarizeIPv6(packet)
|
||||
default:
|
||||
return fmt.Sprintf("ip_version=%d bytes=%d", version, len(packet))
|
||||
}
|
||||
}
|
||||
|
||||
func summarizeIPv4(packet []byte) string {
|
||||
if len(packet) < 20 {
|
||||
return fmt.Sprintf("ipv4 truncated bytes=%d", len(packet))
|
||||
}
|
||||
ihl := int(packet[0]&0x0f) * 4
|
||||
if ihl < 20 || len(packet) < ihl {
|
||||
return fmt.Sprintf("ipv4 invalid_ihl=%d bytes=%d", ihl, len(packet))
|
||||
}
|
||||
proto := packet[9]
|
||||
src := net.IP(packet[12:16]).String()
|
||||
dst := net.IP(packet[16:20]).String()
|
||||
return fmt.Sprintf("ipv4 %s -> %s proto=%d bytes=%d", src, dst, proto, len(packet))
|
||||
}
|
||||
|
||||
func summarizeIPv6(packet []byte) string {
|
||||
if len(packet) < 40 {
|
||||
return fmt.Sprintf("ipv6 truncated bytes=%d", len(packet))
|
||||
}
|
||||
nextHeader := packet[6]
|
||||
src := net.IP(packet[8:24]).String()
|
||||
dst := net.IP(packet[24:40]).String()
|
||||
return fmt.Sprintf("ipv6 %s -> %s next=%d bytes=%d", src, dst, nextHeader, len(packet))
|
||||
}
|
||||
@@ -0,0 +1,206 @@
|
||||
//go:build linux
|
||||
|
||||
package vpnruntime
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
tunDevicePath = "/dev/net/tun"
|
||||
iffTun = 0x0001
|
||||
iffNoPI = 0x1000
|
||||
tunSetIFF = 0x400454ca
|
||||
ifNameSize = 16
|
||||
)
|
||||
|
||||
type tunDevice struct {
|
||||
file *os.File
|
||||
fd int
|
||||
name string
|
||||
}
|
||||
|
||||
func openGatewayTun(name, addressCIDR, routeCIDR string) (*tunDevice, error) {
|
||||
dev, err := openGatewayTunDevice(name)
|
||||
if errors.Is(err, syscall.EBUSY) {
|
||||
cleanupStaleGatewayInterface(name)
|
||||
dev, err = openGatewayTunDevice(name)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := configureGatewayInterface(name, addressCIDR, routeCIDR); err != nil {
|
||||
_ = dev.Close()
|
||||
return nil, err
|
||||
}
|
||||
return dev, nil
|
||||
}
|
||||
|
||||
func openGatewayTunDevice(name string) (*tunDevice, error) {
|
||||
file, err := os.OpenFile(tunDevicePath, os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open %s: %w", tunDevicePath, err)
|
||||
}
|
||||
ifr := make([]byte, 40)
|
||||
copy(ifr[:ifNameSize], []byte(name))
|
||||
*(*uint16)(unsafe.Pointer(&ifr[ifNameSize])) = iffTun | iffNoPI
|
||||
if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, file.Fd(), uintptr(tunSetIFF), uintptr(unsafe.Pointer(&ifr[0]))); errno != 0 {
|
||||
file.Close()
|
||||
return nil, fmt.Errorf("configure tun %s: %w", name, errno)
|
||||
}
|
||||
return &tunDevice{file: file, fd: int(file.Fd()), name: name}, nil
|
||||
}
|
||||
|
||||
func cleanupStaleGatewayInterface(name string) {
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return
|
||||
}
|
||||
_ = runCommand("ip", "link", "set", name, "down")
|
||||
_ = runCommand("ip", "link", "delete", name)
|
||||
}
|
||||
|
||||
func (d *tunDevice) Read(packet []byte) (int, error) {
|
||||
return syscall.Read(d.fd, packet)
|
||||
}
|
||||
|
||||
func (d *tunDevice) Write(packet []byte) (int, error) {
|
||||
return syscall.Write(d.fd, packet)
|
||||
}
|
||||
|
||||
func (d *tunDevice) Close() error {
|
||||
_ = runCommand("ip", "link", "set", d.name, "down")
|
||||
return d.file.Close()
|
||||
}
|
||||
|
||||
func configureGatewayInterface(name, addressCIDR, routeCIDR string) error {
|
||||
if _, _, err := net.ParseCIDR(addressCIDR); err != nil {
|
||||
return fmt.Errorf("invalid vpn gateway address %q: %w", addressCIDR, err)
|
||||
}
|
||||
if err := runCommand("ip", "addr", "replace", addressCIDR, "dev", name); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := runCommand("ip", "link", "set", name, "up"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := enableIPv4Forwarding(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := disableReversePathFiltering(name); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ensureForwardingRules(name); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ensureMasqueradeRules(routeCIDR); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ensureMSSClampRule(name); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureMasqueradeRules(routeCIDR string) error {
|
||||
egress, _ := defaultIPv4Interface()
|
||||
if egress != "" {
|
||||
if err := ensureIPTablesRule("nat", "POSTROUTING", "-s", routeCIDR, "-o", egress, "-j", "MASQUERADE"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return ensureIPTablesRule("nat", "POSTROUTING", "-s", routeCIDR, "-j", "MASQUERADE")
|
||||
}
|
||||
|
||||
func ensureMSSClampRule(interfaceName string) error {
|
||||
err := ensureIPTablesRule("mangle", "FORWARD", "-i", interfaceName, "-p", "tcp", "--tcp-flags", "SYN,RST", "SYN", "-j", "TCPMSS", "--clamp-mss-to-pmtu")
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func defaultIPv4Interface() (string, error) {
|
||||
out, err := exec.Command("ip", "-o", "-4", "route", "show", "default").CombinedOutput()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ip default route failed: %w: %s", err, string(out))
|
||||
}
|
||||
fields := strings.Fields(string(out))
|
||||
for i := 0; i+1 < len(fields); i++ {
|
||||
if fields[i] == "dev" {
|
||||
return fields[i+1], nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func ensureForwardingRules(interfaceName string) error {
|
||||
if err := ensureIPTablesRule("filter", "FORWARD", "-i", interfaceName, "-j", "ACCEPT"); err != nil {
|
||||
return err
|
||||
}
|
||||
err := ensureIPTablesRule("filter", "FORWARD", "-o", interfaceName, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT")
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return ensureIPTablesRule("filter", "FORWARD", "-o", interfaceName, "-j", "ACCEPT")
|
||||
}
|
||||
|
||||
func ensureIPTablesRule(table, chain string, rule ...string) error {
|
||||
checkArgs := append([]string{"-t", table, "-C", chain}, rule...)
|
||||
if err := runCommand("iptables", checkArgs...); err == nil {
|
||||
return nil
|
||||
}
|
||||
addArgs := append([]string{"-t", table, "-I", chain, "1"}, rule...)
|
||||
return runCommand("iptables", addArgs...)
|
||||
}
|
||||
|
||||
func enableIPv4Forwarding() error {
|
||||
if current, err := os.ReadFile("/proc/sys/net/ipv4/ip_forward"); err == nil && len(current) > 0 && current[0] == '1' {
|
||||
return nil
|
||||
}
|
||||
if err := os.WriteFile("/proc/sys/net/ipv4/ip_forward", []byte("1\n"), 0o644); err == nil {
|
||||
return nil
|
||||
}
|
||||
return runCommand("sysctl", "-w", "net.ipv4.ip_forward=1")
|
||||
}
|
||||
|
||||
func disableReversePathFiltering(interfaceName string) error {
|
||||
keys := []string{"all", "default", interfaceName}
|
||||
if entries, err := os.ReadDir("/proc/sys/net/ipv4/conf"); err == nil {
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
keys = append(keys, entry.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
seen := make(map[string]bool)
|
||||
for _, key := range keys {
|
||||
if seen[key] {
|
||||
continue
|
||||
}
|
||||
seen[key] = true
|
||||
path := fmt.Sprintf("/proc/sys/net/ipv4/conf/%s/rp_filter", key)
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
continue
|
||||
}
|
||||
if err := os.WriteFile(path, []byte("0\n"), 0o644); err != nil {
|
||||
if sysctlErr := runCommand("sysctl", "-w", fmt.Sprintf("net.ipv4.conf.%s.rp_filter=0", key)); sysctlErr != nil {
|
||||
return sysctlErr
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runCommand(name string, args ...string) error {
|
||||
cmd := exec.Command(name, args...)
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("%s %v failed: %w: %s", name, args, err, string(out))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
//go:build !linux && !windows
|
||||
|
||||
package vpnruntime
|
||||
|
||||
import "fmt"
|
||||
|
||||
type tunDevice struct{}
|
||||
|
||||
func openGatewayTun(name, addressCIDR, routeCIDR string) (*tunDevice, error) {
|
||||
return nil, fmt.Errorf("vpn gateway runtime is currently supported only on linux")
|
||||
}
|
||||
|
||||
func (d *tunDevice) Read(packet []byte) (int, error) {
|
||||
return 0, fmt.Errorf("vpn gateway runtime is currently supported only on linux")
|
||||
}
|
||||
|
||||
func (d *tunDevice) Write(packet []byte) (int, error) {
|
||||
return 0, fmt.Errorf("vpn gateway runtime is currently supported only on linux")
|
||||
}
|
||||
|
||||
func (d *tunDevice) Close() error {
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user