Add tracked vpnruntime implementation for CI guard tests
This commit is contained in:
@@ -0,0 +1,206 @@
|
||||
//go:build linux
|
||||
|
||||
package vpnruntime
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
tunDevicePath = "/dev/net/tun"
|
||||
iffTun = 0x0001
|
||||
iffNoPI = 0x1000
|
||||
tunSetIFF = 0x400454ca
|
||||
ifNameSize = 16
|
||||
)
|
||||
|
||||
type tunDevice struct {
|
||||
file *os.File
|
||||
fd int
|
||||
name string
|
||||
}
|
||||
|
||||
func openGatewayTun(name, addressCIDR, routeCIDR string) (*tunDevice, error) {
|
||||
dev, err := openGatewayTunDevice(name)
|
||||
if errors.Is(err, syscall.EBUSY) {
|
||||
cleanupStaleGatewayInterface(name)
|
||||
dev, err = openGatewayTunDevice(name)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := configureGatewayInterface(name, addressCIDR, routeCIDR); err != nil {
|
||||
_ = dev.Close()
|
||||
return nil, err
|
||||
}
|
||||
return dev, nil
|
||||
}
|
||||
|
||||
func openGatewayTunDevice(name string) (*tunDevice, error) {
|
||||
file, err := os.OpenFile(tunDevicePath, os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open %s: %w", tunDevicePath, err)
|
||||
}
|
||||
ifr := make([]byte, 40)
|
||||
copy(ifr[:ifNameSize], []byte(name))
|
||||
*(*uint16)(unsafe.Pointer(&ifr[ifNameSize])) = iffTun | iffNoPI
|
||||
if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, file.Fd(), uintptr(tunSetIFF), uintptr(unsafe.Pointer(&ifr[0]))); errno != 0 {
|
||||
file.Close()
|
||||
return nil, fmt.Errorf("configure tun %s: %w", name, errno)
|
||||
}
|
||||
return &tunDevice{file: file, fd: int(file.Fd()), name: name}, nil
|
||||
}
|
||||
|
||||
func cleanupStaleGatewayInterface(name string) {
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return
|
||||
}
|
||||
_ = runCommand("ip", "link", "set", name, "down")
|
||||
_ = runCommand("ip", "link", "delete", name)
|
||||
}
|
||||
|
||||
func (d *tunDevice) Read(packet []byte) (int, error) {
|
||||
return syscall.Read(d.fd, packet)
|
||||
}
|
||||
|
||||
func (d *tunDevice) Write(packet []byte) (int, error) {
|
||||
return syscall.Write(d.fd, packet)
|
||||
}
|
||||
|
||||
func (d *tunDevice) Close() error {
|
||||
_ = runCommand("ip", "link", "set", d.name, "down")
|
||||
return d.file.Close()
|
||||
}
|
||||
|
||||
func configureGatewayInterface(name, addressCIDR, routeCIDR string) error {
|
||||
if _, _, err := net.ParseCIDR(addressCIDR); err != nil {
|
||||
return fmt.Errorf("invalid vpn gateway address %q: %w", addressCIDR, err)
|
||||
}
|
||||
if err := runCommand("ip", "addr", "replace", addressCIDR, "dev", name); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := runCommand("ip", "link", "set", name, "up"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := enableIPv4Forwarding(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := disableReversePathFiltering(name); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ensureForwardingRules(name); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ensureMasqueradeRules(routeCIDR); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ensureMSSClampRule(name); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureMasqueradeRules(routeCIDR string) error {
|
||||
egress, _ := defaultIPv4Interface()
|
||||
if egress != "" {
|
||||
if err := ensureIPTablesRule("nat", "POSTROUTING", "-s", routeCIDR, "-o", egress, "-j", "MASQUERADE"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return ensureIPTablesRule("nat", "POSTROUTING", "-s", routeCIDR, "-j", "MASQUERADE")
|
||||
}
|
||||
|
||||
func ensureMSSClampRule(interfaceName string) error {
|
||||
err := ensureIPTablesRule("mangle", "FORWARD", "-i", interfaceName, "-p", "tcp", "--tcp-flags", "SYN,RST", "SYN", "-j", "TCPMSS", "--clamp-mss-to-pmtu")
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func defaultIPv4Interface() (string, error) {
|
||||
out, err := exec.Command("ip", "-o", "-4", "route", "show", "default").CombinedOutput()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ip default route failed: %w: %s", err, string(out))
|
||||
}
|
||||
fields := strings.Fields(string(out))
|
||||
for i := 0; i+1 < len(fields); i++ {
|
||||
if fields[i] == "dev" {
|
||||
return fields[i+1], nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func ensureForwardingRules(interfaceName string) error {
|
||||
if err := ensureIPTablesRule("filter", "FORWARD", "-i", interfaceName, "-j", "ACCEPT"); err != nil {
|
||||
return err
|
||||
}
|
||||
err := ensureIPTablesRule("filter", "FORWARD", "-o", interfaceName, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT")
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return ensureIPTablesRule("filter", "FORWARD", "-o", interfaceName, "-j", "ACCEPT")
|
||||
}
|
||||
|
||||
func ensureIPTablesRule(table, chain string, rule ...string) error {
|
||||
checkArgs := append([]string{"-t", table, "-C", chain}, rule...)
|
||||
if err := runCommand("iptables", checkArgs...); err == nil {
|
||||
return nil
|
||||
}
|
||||
addArgs := append([]string{"-t", table, "-I", chain, "1"}, rule...)
|
||||
return runCommand("iptables", addArgs...)
|
||||
}
|
||||
|
||||
func enableIPv4Forwarding() error {
|
||||
if current, err := os.ReadFile("/proc/sys/net/ipv4/ip_forward"); err == nil && len(current) > 0 && current[0] == '1' {
|
||||
return nil
|
||||
}
|
||||
if err := os.WriteFile("/proc/sys/net/ipv4/ip_forward", []byte("1\n"), 0o644); err == nil {
|
||||
return nil
|
||||
}
|
||||
return runCommand("sysctl", "-w", "net.ipv4.ip_forward=1")
|
||||
}
|
||||
|
||||
func disableReversePathFiltering(interfaceName string) error {
|
||||
keys := []string{"all", "default", interfaceName}
|
||||
if entries, err := os.ReadDir("/proc/sys/net/ipv4/conf"); err == nil {
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
keys = append(keys, entry.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
seen := make(map[string]bool)
|
||||
for _, key := range keys {
|
||||
if seen[key] {
|
||||
continue
|
||||
}
|
||||
seen[key] = true
|
||||
path := fmt.Sprintf("/proc/sys/net/ipv4/conf/%s/rp_filter", key)
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
continue
|
||||
}
|
||||
if err := os.WriteFile(path, []byte("0\n"), 0o644); err != nil {
|
||||
if sysctlErr := runCommand("sysctl", "-w", fmt.Sprintf("net.ipv4.conf.%s.rp_filter=0", key)); sysctlErr != nil {
|
||||
return sysctlErr
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runCommand(name string, args ...string) error {
|
||||
cmd := exec.Command(name, args...)
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("%s %v failed: %w: %s", name, args, err, string(out))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user