255 lines
7.5 KiB
Go
255 lines
7.5 KiB
Go
//go:build linux
|
|
|
|
package vpnruntime
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"os/exec"
|
|
"strings"
|
|
"syscall"
|
|
"unsafe"
|
|
)
|
|
|
|
// Kernel TUN/TAP ABI values used when attaching to /dev/net/tun via ioctl.
const (
	tunDevicePath = "/dev/net/tun" // TUN/TAP clone device opened for every interface
	iffTun        = 0x0001         // IFF_TUN: TUN (layer-3) device, no Ethernet header
	iffNoPI       = 0x1000         // IFF_NO_PI: no packet-information prefix on frames
	tunSetIFF     = 0x400454ca     // TUNSETIFF ioctl request number
	ifNameSize    = 16             // IFNAMSIZ: size of ifr_name, including the trailing NUL
)

// Link tuning applied to the gateway interface. Both are passed verbatim as
// command-line arguments, hence strings. The MSS is presumably the MTU minus
// 40 bytes of IPv4+TCP headers (1280 - 40) — confirm before changing either.
const (
	gatewayTunMTU = "1280"
	gatewayTCPMSS = "1240"
)
|
|
|
|
// tunDevice is an open TUN interface backed by a /dev/net/tun file handle.
type tunDevice struct {
	// file is the open /dev/net/tun handle; Close closes it.
	file *os.File
	// fd is the raw descriptor of file; Read and Write issue syscalls on it directly.
	fd int
	// name is the interface name, used for the `ip link` call in Close.
	name string
}
|
|
|
|
func openGatewayTun(name, addressCIDR, routeCIDR string) (*tunDevice, error) {
|
|
dev, err := openGatewayTunDevice(name)
|
|
if errors.Is(err, syscall.EBUSY) {
|
|
cleanupStaleGatewayInterface(name)
|
|
dev, err = openGatewayTunDevice(name)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := configureGatewayInterface(name, addressCIDR, routeCIDR); err != nil {
|
|
_ = dev.Close()
|
|
return nil, err
|
|
}
|
|
return dev, nil
|
|
}
|
|
|
|
func openGatewayTunDevice(name string) (*tunDevice, error) {
|
|
file, err := os.OpenFile(tunDevicePath, os.O_RDWR, 0)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("open %s: %w", tunDevicePath, err)
|
|
}
|
|
ifr := make([]byte, 40)
|
|
copy(ifr[:ifNameSize], []byte(name))
|
|
*(*uint16)(unsafe.Pointer(&ifr[ifNameSize])) = iffTun | iffNoPI
|
|
if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, file.Fd(), uintptr(tunSetIFF), uintptr(unsafe.Pointer(&ifr[0]))); errno != 0 {
|
|
file.Close()
|
|
return nil, fmt.Errorf("configure tun %s: %w", name, errno)
|
|
}
|
|
return &tunDevice{file: file, fd: int(file.Fd()), name: name}, nil
|
|
}
|
|
|
|
func cleanupStaleGatewayInterface(name string) {
|
|
if strings.TrimSpace(name) == "" {
|
|
return
|
|
}
|
|
_ = runCommand("ip", "link", "set", name, "down")
|
|
_ = runCommand("ip", "link", "delete", name)
|
|
}
|
|
|
|
func (d *tunDevice) Read(packet []byte) (int, error) {
|
|
return syscall.Read(d.fd, packet)
|
|
}
|
|
|
|
func (d *tunDevice) Write(packet []byte) (int, error) {
|
|
return syscall.Write(d.fd, packet)
|
|
}
|
|
|
|
func (d *tunDevice) Close() error {
|
|
_ = runCommand("ip", "link", "set", d.name, "down")
|
|
return d.file.Close()
|
|
}
|
|
|
|
func configureGatewayInterface(name, addressCIDR, routeCIDR string) error {
|
|
if _, _, err := net.ParseCIDR(addressCIDR); err != nil {
|
|
return fmt.Errorf("invalid vpn gateway address %q: %w", addressCIDR, err)
|
|
}
|
|
if err := runCommand("ip", "addr", "replace", addressCIDR, "dev", name); err != nil {
|
|
return err
|
|
}
|
|
if err := runCommand("ip", "link", "set", "dev", name, "mtu", gatewayTunMTU); err != nil {
|
|
return err
|
|
}
|
|
if err := runCommand("ip", "link", "set", name, "up"); err != nil {
|
|
return err
|
|
}
|
|
if err := enableIPv4Forwarding(); err != nil {
|
|
return err
|
|
}
|
|
if err := disableReversePathFiltering(name); err != nil {
|
|
return err
|
|
}
|
|
if err := ensureForwardingRules(name); err != nil {
|
|
return err
|
|
}
|
|
if err := ensureMasqueradeRules(routeCIDR); err != nil {
|
|
return err
|
|
}
|
|
if err := ensureMSSClampRule(name); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ensureMasqueradeRules(routeCIDR string) error {
|
|
egress, _ := defaultIPv4Interface()
|
|
if egress != "" {
|
|
if err := ensureIPTablesRule("nat", "POSTROUTING", "-s", routeCIDR, "-o", egress, "-j", "MASQUERADE"); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return ensureIPTablesRule("nat", "POSTROUTING", "-s", routeCIDR, "-j", "MASQUERADE")
|
|
}
|
|
|
|
func ensureMSSClampRule(interfaceName string) error {
|
|
if err := ensureIPTablesRule("mangle", "FORWARD", "-i", interfaceName, "-p", "tcp", "--tcp-flags", "SYN,RST", "SYN", "-j", "TCPMSS", "--set-mss", gatewayTCPMSS); err != nil {
|
|
return err
|
|
}
|
|
return ensureIPTablesRule("mangle", "FORWARD", "-o", interfaceName, "-p", "tcp", "--tcp-flags", "SYN,RST", "SYN", "-j", "TCPMSS", "--set-mss", gatewayTCPMSS)
|
|
}
|
|
|
|
// defaultIPv4Interface runs `ip -o -4 route show default` and returns the
// token that follows the first "dev" in its output.
//
// It returns "" with a nil error when no "dev" token is present, and "" with
// an error (including the command's combined output) when `ip` fails.
func defaultIPv4Interface() (string, error) {
	cmd := exec.Command("ip", "-o", "-4", "route", "show", "default")
	raw, runErr := cmd.CombinedOutput()
	if runErr != nil {
		return "", fmt.Errorf("ip default route failed: %w: %s", runErr, string(raw))
	}

	tokens := strings.Fields(string(raw))
	for idx, tok := range tokens {
		if tok == "dev" && idx+1 < len(tokens) {
			return tokens[idx+1], nil
		}
	}
	return "", nil
}
|
|
|
|
func ensureForwardingRules(interfaceName string) error {
|
|
if err := ensureIPTablesRule("filter", "FORWARD", "-i", interfaceName, "-j", "ACCEPT"); err != nil {
|
|
return err
|
|
}
|
|
err := ensureIPTablesRule("filter", "FORWARD", "-o", interfaceName, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT")
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
return ensureIPTablesRule("filter", "FORWARD", "-o", interfaceName, "-j", "ACCEPT")
|
|
}
|
|
|
|
func ensureIPTablesRule(table, chain string, rule ...string) error {
|
|
checkArgs := append([]string{"-t", table, "-C", chain}, rule...)
|
|
if err := runCommand("iptables", checkArgs...); err == nil {
|
|
return nil
|
|
}
|
|
addArgs := append([]string{"-t", table, "-I", chain, "1"}, rule...)
|
|
return runCommand("iptables", addArgs...)
|
|
}
|
|
|
|
func enableIPv4Forwarding() error {
|
|
if current, err := os.ReadFile("/proc/sys/net/ipv4/ip_forward"); err == nil && len(current) > 0 && current[0] == '1' {
|
|
return nil
|
|
}
|
|
if err := os.WriteFile("/proc/sys/net/ipv4/ip_forward", []byte("1\n"), 0o644); err == nil {
|
|
return nil
|
|
}
|
|
return runCommand("sysctl", "-w", "net.ipv4.ip_forward=1")
|
|
}
|
|
|
|
func disableReversePathFiltering(interfaceName string) error {
|
|
keys := []string{"all", "default", interfaceName}
|
|
if entries, err := os.ReadDir("/proc/sys/net/ipv4/conf"); err == nil {
|
|
for _, entry := range entries {
|
|
if entry.IsDir() {
|
|
keys = append(keys, entry.Name())
|
|
}
|
|
}
|
|
}
|
|
seen := make(map[string]bool)
|
|
for _, key := range keys {
|
|
if seen[key] {
|
|
continue
|
|
}
|
|
seen[key] = true
|
|
path := fmt.Sprintf("/proc/sys/net/ipv4/conf/%s/rp_filter", key)
|
|
if _, err := os.Stat(path); err != nil {
|
|
continue
|
|
}
|
|
if err := os.WriteFile(path, []byte("0\n"), 0o644); err != nil {
|
|
if sysctlErr := runCommand("sysctl", "-w", fmt.Sprintf("net.ipv4.conf.%s.rp_filter=0", key)); sysctlErr != nil {
|
|
return sysctlErr
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// runCommand executes name with args and waits for it to finish. On failure
// the returned error wraps the exec error and includes the command line and
// its combined stdout/stderr output.
func runCommand(name string, args ...string) error {
	output, runErr := exec.Command(name, args...).CombinedOutput()
	if runErr == nil {
		return nil
	}
	return fmt.Errorf("%s %v failed: %w: %s", name, args, runErr, string(output))
}
|
|
|
|
func gatewayPlatformSnapshot(interfaceName, routeCIDR string) map[string]any {
|
|
out := map[string]any{
|
|
"os": "linux",
|
|
"interface": interfaceName,
|
|
"route_cidr": routeCIDR,
|
|
}
|
|
if value, err := readTrimmedFile("/proc/sys/net/ipv4/ip_forward"); err == nil {
|
|
out["ipv4_forward"] = value
|
|
}
|
|
for _, key := range []string{"all", "default", interfaceName} {
|
|
if strings.TrimSpace(key) == "" {
|
|
continue
|
|
}
|
|
if value, err := readTrimmedFile(fmt.Sprintf("/proc/sys/net/ipv4/conf/%s/rp_filter", key)); err == nil {
|
|
out["rp_filter_"+key] = value
|
|
}
|
|
}
|
|
if interfaceName != "" {
|
|
out["forward_in_rule"] = iptablesRulePresent("filter", "FORWARD", "-i", interfaceName, "-j", "ACCEPT")
|
|
out["forward_out_established_rule"] = iptablesRulePresent("filter", "FORWARD", "-o", interfaceName, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT")
|
|
}
|
|
if routeCIDR != "" {
|
|
out["masquerade_rule"] = iptablesRulePresent("nat", "POSTROUTING", "-s", routeCIDR, "-j", "MASQUERADE")
|
|
if egress, err := defaultIPv4Interface(); err == nil && egress != "" {
|
|
out["default_egress"] = egress
|
|
out["egress_masquerade_rule"] = iptablesRulePresent("nat", "POSTROUTING", "-s", routeCIDR, "-o", egress, "-j", "MASQUERADE")
|
|
}
|
|
}
|
|
return out
|
|
}
|
|
|
|
// readTrimmedFile returns the contents of path with leading and trailing
// white space removed. Any read error is returned as-is with an empty string.
func readTrimmedFile(path string) (string, error) {
	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return "", readErr
	}
	text := string(raw)
	return strings.TrimSpace(text), nil
}
|
|
|
|
// iptablesRulePresent reports whether `iptables -t <table> -C <chain> <rule...>`
// exits successfully. Any failure (missing rule, missing chain, missing
// binary, insufficient privileges) reports false.
func iptablesRulePresent(table, chain string, rule ...string) bool {
	args := make([]string, 0, 4+len(rule))
	args = append(args, "-t", table, "-C", chain)
	args = append(args, rule...)
	runErr := exec.Command("iptables", args...).Run()
	return runErr == nil
}
|