рабочий вариант, но скороть 10 МБит
This commit is contained in:
@@ -3,6 +3,7 @@ package fabricvpn
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -18,6 +19,13 @@ import (
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultRuntimeStreamShards = 8
|
||||
maxRuntimeStreamShards = 128
|
||||
minPacketBatchSendTimeout = 5 * time.Second
|
||||
maxPacketBatchSendTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type endpointConfig struct {
|
||||
EndpointID string `json:"endpoint_id"`
|
||||
NodeID string `json:"node_id"`
|
||||
@@ -31,8 +39,15 @@ type endpointConfig struct {
|
||||
type runtimeConfig struct {
|
||||
ClusterID string `json:"cluster_id"`
|
||||
LocalNodeID string `json:"local_node_id"`
|
||||
ExitNodeID string `json:"exit_node_id"`
|
||||
VPNConnectionID string `json:"vpn_connection_id"`
|
||||
TunnelID string `json:"tunnel_id"`
|
||||
PoolID string `json:"pool_id"`
|
||||
ServiceID string `json:"service_id"`
|
||||
LocalServiceID string `json:"local_service_id"`
|
||||
RemoteServiceID string `json:"remote_service_id"`
|
||||
ServiceKind string `json:"service_kind"`
|
||||
ServiceClass string `json:"service_class"`
|
||||
RouteLeaseID string `json:"route_lease_id"`
|
||||
RouteGeneration string `json:"route_generation"`
|
||||
Endpoints []endpointConfig `json:"endpoints"`
|
||||
RouteBundle routeBundleConfig `json:"route_bundle"`
|
||||
ServiceChannelRequest serviceChannelRequest `json:"service_channel_request"`
|
||||
@@ -56,6 +71,7 @@ type routeBundleConfig struct {
|
||||
type routeLeaseConfig struct {
|
||||
SchemaVersion string `json:"schema_version"`
|
||||
LeaseID string `json:"lease_id"`
|
||||
Generation string `json:"generation"`
|
||||
SelectedTargetNode string `json:"selected_target_node"`
|
||||
PrimaryPath routeLeasePath `json:"primary_path"`
|
||||
WarmStandbyPaths []routeLeasePath `json:"warm_standby_paths"`
|
||||
@@ -82,17 +98,19 @@ type SocketProtector interface {
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
opMu sync.Mutex
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
transport *mesh.QUICFabricTransport
|
||||
session mesh.FabricTransportSession
|
||||
packet *vpnruntime.FabricSessionPacketTransport
|
||||
inbox *vpnruntime.FabricPacketInbox
|
||||
cfg runtimeConfig
|
||||
lastErr string
|
||||
endpoint string
|
||||
protector SocketProtector
|
||||
opMu sync.Mutex
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
heartbeatCancel context.CancelFunc
|
||||
transport *mesh.QUICFabricTransport
|
||||
session mesh.FabricTransportSession
|
||||
packet *vpnruntime.FabricSessionPacketTransport
|
||||
inbox *vpnruntime.FabricPacketInbox
|
||||
serviceStreams *vpnruntime.FabricServiceStreamRegistry
|
||||
cfg runtimeConfig
|
||||
lastErr string
|
||||
endpoint string
|
||||
protector SocketProtector
|
||||
|
||||
uplinkPackets atomic.Uint64
|
||||
uplinkBytes atomic.Uint64
|
||||
@@ -100,6 +118,14 @@ type Manager struct {
|
||||
downlinkBytes atomic.Uint64
|
||||
}
|
||||
|
||||
type fabricEndpointConnectResult struct {
|
||||
endpoint endpointConfig
|
||||
session mesh.FabricTransportSession
|
||||
streamIDs map[string][]uint64
|
||||
streamID uint64
|
||||
err error
|
||||
}
|
||||
|
||||
func NewManager() *Manager {
|
||||
return &Manager{}
|
||||
}
|
||||
@@ -117,12 +143,28 @@ func (m *Manager) Start(configJSON string) error {
|
||||
}
|
||||
cfg.ClusterID = strings.TrimSpace(cfg.ClusterID)
|
||||
cfg.LocalNodeID = strings.TrimSpace(cfg.LocalNodeID)
|
||||
cfg.ExitNodeID = strings.TrimSpace(cfg.ExitNodeID)
|
||||
cfg.VPNConnectionID = strings.TrimSpace(cfg.VPNConnectionID)
|
||||
cfg.TunnelID = strings.TrimSpace(cfg.TunnelID)
|
||||
cfg.PoolID = strings.TrimSpace(cfg.PoolID)
|
||||
cfg.ServiceID = strings.TrimSpace(cfg.ServiceID)
|
||||
cfg.LocalServiceID = strings.TrimSpace(cfg.LocalServiceID)
|
||||
cfg.RemoteServiceID = strings.TrimSpace(cfg.RemoteServiceID)
|
||||
cfg.ServiceKind = strings.TrimSpace(cfg.ServiceKind)
|
||||
cfg.ServiceClass = strings.TrimSpace(cfg.ServiceClass)
|
||||
cfg.RouteLeaseID = strings.TrimSpace(firstNonEmpty(cfg.RouteLeaseID, cfg.RouteBundle.RouteLease.LeaseID))
|
||||
cfg.RouteGeneration = strings.TrimSpace(firstNonEmpty(cfg.RouteGeneration, cfg.RouteBundle.RouteLease.Generation, cfg.RouteBundle.RouteLease.LeaseID))
|
||||
cfg.TunnelID = firstNonEmpty(cfg.TunnelID)
|
||||
if cfg.PoolID == "" {
|
||||
cfg.PoolID = vpnruntime.DefaultFabricTunnelPoolID
|
||||
}
|
||||
if cfg.ServiceClass == "" {
|
||||
cfg.ServiceClass = vpnruntime.DefaultFabricTunnelClass
|
||||
}
|
||||
if cfg.ServiceKind == "" {
|
||||
cfg.ServiceKind = vpnruntime.DefaultFabricTunnelServiceKind
|
||||
}
|
||||
cfg.Endpoints = fabricRuntimeEndpoints(cfg)
|
||||
cfg.ExitNodeID = firstNonEmpty(cfg.ExitNodeID, fabricRuntimeTargetNodeID(cfg))
|
||||
if cfg.ClusterID == "" || cfg.LocalNodeID == "" || cfg.VPNConnectionID == "" {
|
||||
return fmt.Errorf("cluster, local node and vpn connection id are required")
|
||||
if cfg.ClusterID == "" || cfg.LocalNodeID == "" || cfg.TunnelID == "" {
|
||||
return fmt.Errorf("cluster, local node and fabric tunnel id are required")
|
||||
}
|
||||
if strings.TrimSpace(cfg.ServiceChannelRequest.SchemaVersion) == "" {
|
||||
return fmt.Errorf("fabric service channel request is required")
|
||||
@@ -131,10 +173,10 @@ func (m *Manager) Start(configJSON string) error {
|
||||
return fmt.Errorf("fabric route lease has no QUIC candidates")
|
||||
}
|
||||
if cfg.StreamShards <= 0 {
|
||||
cfg.StreamShards = 4
|
||||
cfg.StreamShards = defaultRuntimeStreamShards
|
||||
}
|
||||
if cfg.StreamShards > 32 {
|
||||
cfg.StreamShards = 32
|
||||
if cfg.StreamShards > maxRuntimeStreamShards {
|
||||
cfg.StreamShards = maxRuntimeStreamShards
|
||||
}
|
||||
|
||||
m.Stop()
|
||||
@@ -187,43 +229,22 @@ func (m *Manager) connect(ctx context.Context, cfg runtimeConfig, cancel context
|
||||
return mesh.ProductionForwardResult{Delivered: true, MessageID: envelope.MessageID}, nil
|
||||
}, nil, nil)
|
||||
|
||||
var lastErr error
|
||||
for _, endpoint := range cfg.Endpoints {
|
||||
target := mesh.FabricTransportTarget{
|
||||
EndpointID: firstNonEmpty(endpoint.EndpointID, endpoint.Address),
|
||||
PeerID: firstNonEmpty(endpoint.NodeID, cfg.ExitNodeID),
|
||||
Endpoint: endpoint.Address,
|
||||
Transport: firstNonEmpty(endpoint.Transport, "direct_quic"),
|
||||
PeerCertSHA256: firstNonEmpty(endpoint.PeerCertSHA256, endpoint.TLSCertSHA256),
|
||||
Timeout: 5 * time.Second,
|
||||
OutboundBuffer: 512,
|
||||
InboundBuffer: 512,
|
||||
ErrorBuffer: 32,
|
||||
}
|
||||
carrier, selected, err := mesh.FabricTransportForTarget(target, quicTransport)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
session, err := carrier.Connect(dialCtx, selected)
|
||||
if err != nil {
|
||||
dialCancel()
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
streamIDs, streamID, err := openStreams(dialCtx, session, cfg.StreamShards)
|
||||
dialCancel()
|
||||
if err != nil {
|
||||
_ = session.Close()
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
result, err := m.connectFastestEndpoint(ctx, cfg, quicTransport)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
endpoint := result.endpoint
|
||||
session := result.session
|
||||
streamIDs := result.streamIDs
|
||||
streamID := result.streamID
|
||||
heartbeatCtx, heartbeatCancel := context.WithCancel(context.Background())
|
||||
m.mu.Lock()
|
||||
m.cancel = cancel
|
||||
m.heartbeatCancel = heartbeatCancel
|
||||
m.transport = quicTransport
|
||||
m.session = session
|
||||
m.inbox = inbox
|
||||
m.serviceStreams = vpnruntime.NewFabricServiceStreamRegistry()
|
||||
m.cfg = cfg
|
||||
m.endpoint = endpoint.Address
|
||||
m.lastErr = ""
|
||||
@@ -232,18 +253,219 @@ func (m *Manager) connect(ctx context.Context, cfg runtimeConfig, cancel context
|
||||
Receiver: session,
|
||||
Inbox: inbox,
|
||||
StreamID: streamID,
|
||||
ServiceStreams: m.serviceStreams,
|
||||
ServiceTunnel: serviceTunnelFromRuntimeConfig(cfg),
|
||||
StreamIDsByTrafficClass: streamIDs,
|
||||
VPNConnectionID: cfg.VPNConnectionID,
|
||||
TunnelID: cfg.TunnelID,
|
||||
PoolID: cfg.PoolID,
|
||||
ServiceID: cfg.ServiceID,
|
||||
VPNConnectionID: cfg.TunnelID,
|
||||
SendDirection: vpnruntime.FabricDirectionClientToGateway,
|
||||
ReceiveDirection: vpnruntime.FabricDirectionGatewayToClient,
|
||||
}
|
||||
m.mu.Unlock()
|
||||
announceCtx, announceCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
announceErr := announceVPNSessionStreams(announceCtx, session, serviceTunnelFromRuntimeConfig(cfg), streamIDs, streamID)
|
||||
announceCancel()
|
||||
if announceErr != nil {
|
||||
m.setErr(announceErr)
|
||||
}
|
||||
go m.runVPNSessionHeartbeat(heartbeatCtx, session, streamIDs, streamID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) connectFastestEndpoint(ctx context.Context, cfg runtimeConfig, quicTransport *mesh.QUICFabricTransport) (fabricEndpointConnectResult, error) {
|
||||
if len(cfg.Endpoints) == 0 {
|
||||
return fabricEndpointConnectResult{}, fmt.Errorf("no QUIC exit endpoints available")
|
||||
}
|
||||
connectCtx, connectCancel := context.WithCancel(ctx)
|
||||
defer connectCancel()
|
||||
endpointGroups := groupEndpointsByPeer(cfg)
|
||||
results := make(chan fabricEndpointConnectResult, len(endpointGroups))
|
||||
attempts := 0
|
||||
for _, group := range endpointGroups {
|
||||
attempts++
|
||||
go func(group []endpointConfig) {
|
||||
var last fabricEndpointConnectResult
|
||||
for _, endpoint := range group {
|
||||
target := fabricRuntimePacketTarget(cfg, endpoint)
|
||||
carrier, selected, err := mesh.FabricTransportForTarget(target, quicTransport)
|
||||
if err != nil {
|
||||
last = fabricEndpointConnectResult{endpoint: endpoint, err: err}
|
||||
continue
|
||||
}
|
||||
dialCtx, dialCancel := context.WithTimeout(connectCtx, 5*time.Second)
|
||||
session, err := carrier.Connect(dialCtx, selected)
|
||||
if err != nil {
|
||||
dialCancel()
|
||||
last = fabricEndpointConnectResult{endpoint: endpoint, err: err}
|
||||
continue
|
||||
}
|
||||
streamIDs, streamID, err := openStreams(dialCtx, session, cfg.StreamShards)
|
||||
dialCancel()
|
||||
if err != nil {
|
||||
_ = session.Close()
|
||||
last = fabricEndpointConnectResult{endpoint: endpoint, err: err}
|
||||
continue
|
||||
}
|
||||
results <- fabricEndpointConnectResult{
|
||||
endpoint: endpoint,
|
||||
session: session,
|
||||
streamIDs: streamIDs,
|
||||
streamID: streamID,
|
||||
}
|
||||
return
|
||||
}
|
||||
if last.err == nil {
|
||||
last.err = fmt.Errorf("no endpoint attempt completed for peer")
|
||||
}
|
||||
results <- last
|
||||
}(group)
|
||||
}
|
||||
var lastErr error
|
||||
for index := 0; index < attempts; index++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if lastErr != nil {
|
||||
return fabricEndpointConnectResult{}, lastErr
|
||||
}
|
||||
return fabricEndpointConnectResult{}, ctx.Err()
|
||||
case result := <-results:
|
||||
if result.err != nil {
|
||||
lastErr = result.err
|
||||
continue
|
||||
}
|
||||
connectCancel()
|
||||
go closeLateFabricSessions(results, attempts-index-1)
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
if lastErr == nil {
|
||||
lastErr = fmt.Errorf("no QUIC exit endpoints available")
|
||||
lastErr = fmt.Errorf("no endpoint attempt completed")
|
||||
}
|
||||
return fmt.Errorf("fabric bootstrap failed after %d endpoint candidates: %w", len(cfg.Endpoints), lastErr)
|
||||
return fabricEndpointConnectResult{}, fmt.Errorf("fabric bootstrap failed after %d endpoint candidates: %w", len(cfg.Endpoints), lastErr)
|
||||
}
|
||||
|
||||
func groupEndpointsByPeer(cfg runtimeConfig) [][]endpointConfig {
|
||||
groups := make([][]endpointConfig, 0, len(cfg.Endpoints))
|
||||
indexByPeer := map[string]int{}
|
||||
for _, endpoint := range cfg.Endpoints {
|
||||
peer := endpointPeerKey(cfg, endpoint)
|
||||
if index, ok := indexByPeer[peer]; ok {
|
||||
groups[index] = append(groups[index], endpoint)
|
||||
continue
|
||||
}
|
||||
indexByPeer[peer] = len(groups)
|
||||
groups = append(groups, []endpointConfig{endpoint})
|
||||
}
|
||||
return groups
|
||||
}
|
||||
|
||||
func endpointPeerKey(cfg runtimeConfig, endpoint endpointConfig) string {
|
||||
if value := strings.TrimSpace(endpoint.NodeID); value != "" {
|
||||
return value
|
||||
}
|
||||
if value := strings.TrimSpace(fabricRuntimeTargetNodeID(cfg)); value != "" {
|
||||
return value
|
||||
}
|
||||
return firstNonEmpty(endpoint.EndpointID, endpoint.Address)
|
||||
}
|
||||
|
||||
func closeLateFabricSessions(results <-chan fabricEndpointConnectResult, remaining int) {
|
||||
for index := 0; index < remaining; index++ {
|
||||
result := <-results
|
||||
if result.session != nil {
|
||||
_ = result.session.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func fabricRuntimePacketTarget(cfg runtimeConfig, endpoint endpointConfig) mesh.FabricTransportTarget {
|
||||
return mesh.FabricTransportTarget{
|
||||
EndpointID: firstNonEmpty(endpoint.EndpointID, endpoint.Address),
|
||||
PeerID: firstNonEmpty(endpoint.NodeID, fabricRuntimeTargetNodeID(cfg)),
|
||||
Endpoint: endpoint.Address,
|
||||
Transport: firstNonEmpty(endpoint.Transport, "direct_quic"),
|
||||
PeerCertSHA256: firstNonEmpty(endpoint.PeerCertSHA256, endpoint.TLSCertSHA256),
|
||||
OutboundBuffer: 4096,
|
||||
InboundBuffer: 4096,
|
||||
ErrorBuffer: 128,
|
||||
}
|
||||
}
|
||||
|
||||
func announceVPNSessionStreams(ctx context.Context, session mesh.FabricTransportSession, serviceTunnel vpnruntime.FabricServiceTunnel, streamIDsByClass map[string][]uint64, fallbackStreamID uint64) error {
|
||||
serviceTunnel = vpnruntime.NormalizeServiceTunnel(serviceTunnel, serviceTunnel.TunnelID)
|
||||
if session == nil || strings.TrimSpace(serviceTunnel.TunnelID) == "" {
|
||||
return fmt.Errorf("fabric vpn session announce requires an active session")
|
||||
}
|
||||
announced := map[uint64]bool{}
|
||||
sequence := uint64(time.Now().UnixNano())
|
||||
for trafficClass, streamIDs := range streamIDsByClass {
|
||||
for _, streamID := range streamIDs {
|
||||
if streamID == 0 || announced[streamID] {
|
||||
continue
|
||||
}
|
||||
sequence++
|
||||
frame, err := vpnruntime.NewFabricVPNSessionHelloFrame(vpnruntime.FabricVPNPacketFrameInput{
|
||||
StreamID: streamID,
|
||||
Sequence: sequence,
|
||||
VPNConnectionID: serviceTunnel.TunnelID,
|
||||
Direction: vpnruntime.FabricDirectionClientToGateway,
|
||||
TrafficClass: trafficClass,
|
||||
ServiceTunnel: serviceTunnel,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := session.Send(ctx, frame); err != nil {
|
||||
return err
|
||||
}
|
||||
announced[streamID] = true
|
||||
}
|
||||
}
|
||||
if len(announced) == 0 && fallbackStreamID != 0 {
|
||||
frame, err := vpnruntime.NewFabricVPNSessionHelloFrame(vpnruntime.FabricVPNPacketFrameInput{
|
||||
StreamID: fallbackStreamID,
|
||||
Sequence: sequence + 1,
|
||||
VPNConnectionID: serviceTunnel.TunnelID,
|
||||
Direction: vpnruntime.FabricDirectionClientToGateway,
|
||||
TrafficClass: vpnruntime.FabricTrafficClassBulk,
|
||||
ServiceTunnel: serviceTunnel,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := session.Send(ctx, frame); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) runVPNSessionHeartbeat(ctx context.Context, session mesh.FabricTransportSession, streamIDsByClass map[string][]uint64, fallbackStreamID uint64) {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
heartbeatCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||
err := announceVPNSessionStreams(heartbeatCtx, session, m.currentServiceTunnel(), streamIDsByClass, fallbackStreamID)
|
||||
cancel()
|
||||
if err != nil {
|
||||
m.setErr(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) currentServiceTunnel() vpnruntime.FabricServiceTunnel {
|
||||
m.mu.Lock()
|
||||
cfg := m.cfg
|
||||
m.mu.Unlock()
|
||||
return serviceTunnelFromRuntimeConfig(cfg)
|
||||
}
|
||||
|
||||
func (m *Manager) protectedQUICDialer() func(context.Context, string, *tls.Config, *quic.Config) (*quic.Conn, error) {
|
||||
@@ -300,16 +522,22 @@ func (m *Manager) Stop() {
|
||||
func (m *Manager) stopLocked() {
|
||||
m.mu.Lock()
|
||||
cancel := m.cancel
|
||||
heartbeatCancel := m.heartbeatCancel
|
||||
session := m.session
|
||||
transport := m.transport
|
||||
m.cancel = nil
|
||||
m.heartbeatCancel = nil
|
||||
m.session = nil
|
||||
m.transport = nil
|
||||
m.packet = nil
|
||||
m.serviceStreams = nil
|
||||
m.mu.Unlock()
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
if heartbeatCancel != nil {
|
||||
heartbeatCancel()
|
||||
}
|
||||
if session != nil {
|
||||
_ = session.Close()
|
||||
}
|
||||
@@ -322,33 +550,30 @@ func (m *Manager) SendPacket(packet []byte) error {
|
||||
if len(packet) == 0 {
|
||||
return nil
|
||||
}
|
||||
m.opMu.Lock()
|
||||
defer m.opMu.Unlock()
|
||||
if err := m.ensureConnectedLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
m.mu.Lock()
|
||||
transport := m.packet
|
||||
m.mu.Unlock()
|
||||
transport := m.packetTransport()
|
||||
if transport == nil {
|
||||
return fmt.Errorf("fabric vpn runtime is not connected")
|
||||
var err error
|
||||
transport, err = m.reconnectPacketTransport()
|
||||
if err != nil || transport == nil {
|
||||
return fmt.Errorf("fabric vpn runtime is not connected")
|
||||
}
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
packetBatch := [][]byte{append([]byte(nil), packet...)}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), packetBatchSendTimeout(packetBatch))
|
||||
defer cancel()
|
||||
if err := transport.SendGatewayPacketBatch(ctx, [][]byte{append([]byte(nil), packet...)}); err != nil {
|
||||
if err := transport.SendGatewayPacketBatch(ctx, packetBatch); err != nil {
|
||||
m.setErr(err)
|
||||
if reconnectErr := m.reconnectLocked(); reconnectErr != nil {
|
||||
transport, reconnectErr := m.reconnectPacketTransport()
|
||||
if reconnectErr != nil {
|
||||
return err
|
||||
}
|
||||
m.mu.Lock()
|
||||
transport = m.packet
|
||||
m.mu.Unlock()
|
||||
if transport == nil {
|
||||
return err
|
||||
}
|
||||
retryCtx, retryCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
retryPacketBatch := [][]byte{append([]byte(nil), packet...)}
|
||||
retryCtx, retryCancel := context.WithTimeout(context.Background(), packetBatchSendTimeout(retryPacketBatch))
|
||||
defer retryCancel()
|
||||
if retryErr := transport.SendGatewayPacketBatch(retryCtx, [][]byte{append([]byte(nil), packet...)}); retryErr != nil {
|
||||
if retryErr := transport.SendGatewayPacketBatch(retryCtx, retryPacketBatch); retryErr != nil {
|
||||
m.setErr(retryErr)
|
||||
return retryErr
|
||||
}
|
||||
@@ -358,17 +583,94 @@ func (m *Manager) SendPacket(packet []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) SendPacketBatchPayload(payload []byte) error {
|
||||
packets, err := decodePacketBatchPayload(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(packets) == 0 {
|
||||
return nil
|
||||
}
|
||||
transport := m.packetTransport()
|
||||
if transport == nil {
|
||||
var err error
|
||||
transport, err = m.reconnectPacketTransport()
|
||||
if err != nil || transport == nil {
|
||||
return fmt.Errorf("fabric vpn runtime is not connected")
|
||||
}
|
||||
}
|
||||
sendTimeout := packetBatchSendTimeout(packets)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sendTimeout)
|
||||
defer cancel()
|
||||
if err := transport.SendGatewayPacketBatch(ctx, clonePacketBatch(packets)); err != nil {
|
||||
m.setErr(err)
|
||||
transport, reconnectErr := m.reconnectPacketTransport()
|
||||
if reconnectErr != nil {
|
||||
return err
|
||||
}
|
||||
if transport == nil {
|
||||
return err
|
||||
}
|
||||
retryCtx, retryCancel := context.WithTimeout(context.Background(), sendTimeout)
|
||||
defer retryCancel()
|
||||
if retryErr := transport.SendGatewayPacketBatch(retryCtx, clonePacketBatch(packets)); retryErr != nil {
|
||||
m.setErr(retryErr)
|
||||
return retryErr
|
||||
}
|
||||
}
|
||||
var bytes uint64
|
||||
for _, packet := range packets {
|
||||
bytes += uint64(len(packet))
|
||||
}
|
||||
m.uplinkPackets.Add(uint64(len(packets)))
|
||||
m.uplinkBytes.Add(bytes)
|
||||
return nil
|
||||
}
|
||||
|
||||
func packetBatchSendTimeout(packets [][]byte) time.Duration {
|
||||
if len(packets) == 0 {
|
||||
return minPacketBatchSendTimeout
|
||||
}
|
||||
var bytes int
|
||||
for _, packet := range packets {
|
||||
bytes += len(packet)
|
||||
}
|
||||
timeout := minPacketBatchSendTimeout
|
||||
if bytes > 0 {
|
||||
timeout += time.Duration(bytes/(512*1024)) * time.Second
|
||||
}
|
||||
if len(packets) > 512 {
|
||||
timeout += time.Duration(len(packets)/512) * time.Second
|
||||
}
|
||||
if timeout > maxPacketBatchSendTimeout {
|
||||
return maxPacketBatchSendTimeout
|
||||
}
|
||||
return timeout
|
||||
}
|
||||
|
||||
func (m *Manager) ReceivePacket(timeoutMillis int) ([]byte, error) {
|
||||
m.opMu.Lock()
|
||||
defer m.opMu.Unlock()
|
||||
if err := m.ensureConnectedLocked(); err != nil {
|
||||
payload, err := m.ReceivePacketBatchPayload(timeoutMillis)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.mu.Lock()
|
||||
transport := m.packet
|
||||
m.mu.Unlock()
|
||||
packets, err := decodePacketBatchPayload(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(packets) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return append([]byte(nil), packets[0]...), nil
|
||||
}
|
||||
|
||||
func (m *Manager) ReceivePacketBatchPayload(timeoutMillis int) ([]byte, error) {
|
||||
transport := m.packetTransport()
|
||||
if transport == nil {
|
||||
return nil, fmt.Errorf("fabric vpn runtime is not connected")
|
||||
var err error
|
||||
transport, err = m.reconnectPacketTransport()
|
||||
if err != nil || transport == nil {
|
||||
return nil, fmt.Errorf("fabric vpn runtime is not connected")
|
||||
}
|
||||
}
|
||||
timeout := time.Duration(timeoutMillis) * time.Millisecond
|
||||
if timeout <= 0 {
|
||||
@@ -379,16 +681,19 @@ func (m *Manager) ReceivePacket(timeoutMillis int) ([]byte, error) {
|
||||
packets, err := transport.ReceiveGatewayPacketBatch(ctx, timeout)
|
||||
if err != nil {
|
||||
m.setErr(err)
|
||||
_ = m.reconnectLocked()
|
||||
_, _ = m.reconnectPacketTransport()
|
||||
return nil, err
|
||||
}
|
||||
if len(packets) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
packet := append([]byte(nil), packets[0]...)
|
||||
m.downlinkPackets.Add(1)
|
||||
m.downlinkBytes.Add(uint64(len(packet)))
|
||||
return packet, nil
|
||||
var bytes uint64
|
||||
for _, packet := range packets {
|
||||
bytes += uint64(len(packet))
|
||||
}
|
||||
m.downlinkPackets.Add(uint64(len(packets)))
|
||||
m.downlinkBytes.Add(bytes)
|
||||
return encodePacketBatchPayload(packets), nil
|
||||
}
|
||||
|
||||
func (m *Manager) ControlRequest(payloadJSON string) (string, error) {
|
||||
@@ -402,19 +707,63 @@ func (m *Manager) ControlRequest(payloadJSON string) (string, error) {
|
||||
cfg := m.cfg
|
||||
endpointAddress := m.endpoint
|
||||
m.mu.Unlock()
|
||||
if transport == nil || endpointAddress == "" {
|
||||
if transport == nil {
|
||||
return "", fmt.Errorf("fabric control runtime is not connected")
|
||||
}
|
||||
endpoint := endpointConfig{Address: endpointAddress}
|
||||
for _, candidate := range cfg.Endpoints {
|
||||
if strings.TrimSpace(candidate.Address) == endpointAddress {
|
||||
endpoint = candidate
|
||||
break
|
||||
}
|
||||
candidates := prioritizeControlEndpoints(cfg.Endpoints, endpointAddress)
|
||||
if len(candidates) == 0 {
|
||||
return "", fmt.Errorf("fabric control runtime has no bootstrap endpoints")
|
||||
}
|
||||
var lastErr error
|
||||
for _, endpoint := range candidates {
|
||||
response, err := m.controlRequestToEndpoint(transport, cfg, endpoint, payloadJSON)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(endpoint.Address) != "" && strings.TrimSpace(endpoint.Address) != endpointAddress {
|
||||
m.mu.Lock()
|
||||
m.endpoint = strings.TrimSpace(endpoint.Address)
|
||||
m.mu.Unlock()
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
if lastErr != nil {
|
||||
return "", lastErr
|
||||
}
|
||||
return "", fmt.Errorf("fabric control route unavailable")
|
||||
}
|
||||
|
||||
func prioritizeControlEndpoints(endpoints []endpointConfig, activeAddress string) []endpointConfig {
|
||||
activeAddress = strings.TrimSpace(activeAddress)
|
||||
out := make([]endpointConfig, 0, len(endpoints)+1)
|
||||
seen := map[string]bool{}
|
||||
for _, endpoint := range endpoints {
|
||||
address := strings.TrimSpace(endpoint.Address)
|
||||
if address == "" || address != activeAddress {
|
||||
continue
|
||||
}
|
||||
out = append(out, endpoint)
|
||||
seen[address] = true
|
||||
}
|
||||
for _, endpoint := range endpoints {
|
||||
address := strings.TrimSpace(endpoint.Address)
|
||||
if address == "" || seen[address] {
|
||||
continue
|
||||
}
|
||||
out = append(out, endpoint)
|
||||
seen[address] = true
|
||||
}
|
||||
if len(out) == 0 && activeAddress != "" {
|
||||
out = append(out, endpointConfig{Address: activeAddress})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *Manager) controlRequestToEndpoint(transport *mesh.QUICFabricTransport, cfg runtimeConfig, endpoint endpointConfig, payloadJSON string) (string, error) {
|
||||
target := mesh.FabricTransportTarget{
|
||||
EndpointID: firstNonEmpty(endpoint.EndpointID, endpoint.Address),
|
||||
PeerID: firstNonEmpty(endpoint.NodeID, cfg.ExitNodeID),
|
||||
PeerID: firstNonEmpty(endpoint.NodeID, fabricRuntimeTargetNodeID(cfg)),
|
||||
Endpoint: endpoint.Address,
|
||||
Transport: firstNonEmpty(endpoint.Transport, "direct_quic"),
|
||||
PeerCertSHA256: firstNonEmpty(endpoint.PeerCertSHA256, endpoint.TLSCertSHA256),
|
||||
@@ -479,6 +828,130 @@ func (m *Manager) Reconnect() error {
|
||||
return m.reconnectLocked()
|
||||
}
|
||||
|
||||
func (m *Manager) UpdateRuntimeConfig(configJSON string) error {
|
||||
var next runtimeConfig
|
||||
if err := json.Unmarshal([]byte(configJSON), &next); err != nil {
|
||||
return err
|
||||
}
|
||||
next.ClusterID = strings.TrimSpace(next.ClusterID)
|
||||
next.LocalNodeID = strings.TrimSpace(next.LocalNodeID)
|
||||
next.TunnelID = strings.TrimSpace(next.TunnelID)
|
||||
next.PoolID = strings.TrimSpace(next.PoolID)
|
||||
next.ServiceID = strings.TrimSpace(next.ServiceID)
|
||||
next.LocalServiceID = strings.TrimSpace(next.LocalServiceID)
|
||||
next.RemoteServiceID = strings.TrimSpace(next.RemoteServiceID)
|
||||
next.ServiceKind = strings.TrimSpace(next.ServiceKind)
|
||||
next.ServiceClass = strings.TrimSpace(next.ServiceClass)
|
||||
next.RouteLeaseID = strings.TrimSpace(firstNonEmpty(next.RouteLeaseID, next.RouteBundle.RouteLease.LeaseID))
|
||||
next.RouteGeneration = strings.TrimSpace(firstNonEmpty(next.RouteGeneration, next.RouteBundle.RouteLease.Generation, next.RouteBundle.RouteLease.LeaseID))
|
||||
next.Endpoints = fabricRuntimeEndpoints(next)
|
||||
if next.StreamShards <= 0 {
|
||||
next.StreamShards = defaultRuntimeStreamShards
|
||||
}
|
||||
if next.StreamShards > maxRuntimeStreamShards {
|
||||
next.StreamShards = maxRuntimeStreamShards
|
||||
}
|
||||
|
||||
m.opMu.Lock()
|
||||
defer m.opMu.Unlock()
|
||||
m.mu.Lock()
|
||||
current := m.cfg
|
||||
packet := m.packet
|
||||
m.mu.Unlock()
|
||||
if current.TunnelID != "" && next.TunnelID != "" && current.TunnelID != next.TunnelID {
|
||||
return fmt.Errorf("fabric runtime config tunnel id changed from %q to %q", current.TunnelID, next.TunnelID)
|
||||
}
|
||||
if next.ClusterID == "" {
|
||||
next.ClusterID = current.ClusterID
|
||||
}
|
||||
if next.LocalNodeID == "" {
|
||||
next.LocalNodeID = current.LocalNodeID
|
||||
}
|
||||
if next.TunnelID == "" {
|
||||
next.TunnelID = current.TunnelID
|
||||
}
|
||||
if next.PoolID == "" {
|
||||
next.PoolID = current.PoolID
|
||||
}
|
||||
if next.ServiceID == "" {
|
||||
next.ServiceID = current.ServiceID
|
||||
}
|
||||
if next.ServiceKind == "" {
|
||||
next.ServiceKind = current.ServiceKind
|
||||
}
|
||||
if next.ServiceClass == "" {
|
||||
next.ServiceClass = current.ServiceClass
|
||||
}
|
||||
if len(next.Endpoints) == 0 {
|
||||
next.Endpoints = current.Endpoints
|
||||
}
|
||||
reconnectForRoute := shouldReconnectForRuntimeRoute(current, next)
|
||||
if packet != nil {
|
||||
if _, err := packet.UpdateServiceTunnel(serviceTunnelFromRuntimeConfig(next)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.cfg = next
|
||||
m.lastErr = ""
|
||||
m.mu.Unlock()
|
||||
if reconnectForRoute {
|
||||
if err := m.reconnectLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func shouldReconnectForRuntimeRoute(current runtimeConfig, next runtimeConfig) bool {
|
||||
if current.TunnelID == "" || next.TunnelID == "" || current.TunnelID != next.TunnelID {
|
||||
return false
|
||||
}
|
||||
if fabricRuntimeTargetNodeID(current) != fabricRuntimeTargetNodeID(next) {
|
||||
return true
|
||||
}
|
||||
return endpointListSignature(current.Endpoints) != endpointListSignature(next.Endpoints)
|
||||
}
|
||||
|
||||
func endpointListSignature(endpoints []endpointConfig) string {
|
||||
if len(endpoints) == 0 {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for _, endpoint := range endpoints {
|
||||
b.WriteString(endpoint.EndpointID)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(endpoint.NodeID)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(endpoint.Transport)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(endpoint.Address)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(endpoint.PeerCertSHA256)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(endpoint.TLSCertSHA256)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(fmt.Sprintf("%d", endpoint.Priority))
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *Manager) packetTransport() *vpnruntime.FabricSessionPacketTransport {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.packet
|
||||
}
|
||||
|
||||
func (m *Manager) reconnectPacketTransport() (*vpnruntime.FabricSessionPacketTransport, error) {
|
||||
m.opMu.Lock()
|
||||
defer m.opMu.Unlock()
|
||||
if err := m.reconnectLocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m.packetTransport(), nil
|
||||
}
|
||||
|
||||
func (m *Manager) ensureConnectedLocked() error {
|
||||
m.mu.Lock()
|
||||
connected := m.packet != nil
|
||||
@@ -498,11 +971,16 @@ func (m *Manager) reconnectLocked() error {
|
||||
cfg := m.cfg
|
||||
oldSession := m.session
|
||||
oldTransport := m.transport
|
||||
oldHeartbeatCancel := m.heartbeatCancel
|
||||
cancel := m.cancel
|
||||
m.session = nil
|
||||
m.transport = nil
|
||||
m.packet = nil
|
||||
m.heartbeatCancel = nil
|
||||
m.mu.Unlock()
|
||||
if oldHeartbeatCancel != nil {
|
||||
oldHeartbeatCancel()
|
||||
}
|
||||
if oldSession != nil {
|
||||
_ = oldSession.Close()
|
||||
}
|
||||
@@ -521,27 +999,105 @@ func (m *Manager) reconnectLocked() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodePacketBatchPayload(payload []byte) ([][]byte, error) {
|
||||
if len(payload) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
packets := make([][]byte, 0, 16)
|
||||
for offset := 0; offset < len(payload); {
|
||||
if len(payload)-offset < 4 {
|
||||
return nil, fmt.Errorf("invalid packet batch payload: truncated length")
|
||||
}
|
||||
size := int(binary.BigEndian.Uint32(payload[offset : offset+4]))
|
||||
offset += 4
|
||||
if size <= 0 || size > 65535 {
|
||||
return nil, fmt.Errorf("invalid packet batch payload: packet size %d", size)
|
||||
}
|
||||
if len(payload)-offset < size {
|
||||
return nil, fmt.Errorf("invalid packet batch payload: truncated packet")
|
||||
}
|
||||
packet := append([]byte(nil), payload[offset:offset+size]...)
|
||||
packets = append(packets, packet)
|
||||
offset += size
|
||||
}
|
||||
return packets, nil
|
||||
}
|
||||
|
||||
func encodePacketBatchPayload(packets [][]byte) []byte {
|
||||
if len(packets) == 0 {
|
||||
return nil
|
||||
}
|
||||
total := 0
|
||||
for _, packet := range packets {
|
||||
if len(packet) == 0 {
|
||||
continue
|
||||
}
|
||||
total += 4 + len(packet)
|
||||
}
|
||||
if total == 0 {
|
||||
return nil
|
||||
}
|
||||
payload := make([]byte, 0, total)
|
||||
var size [4]byte
|
||||
for _, packet := range packets {
|
||||
if len(packet) == 0 {
|
||||
continue
|
||||
}
|
||||
binary.BigEndian.PutUint32(size[:], uint32(len(packet)))
|
||||
payload = append(payload, size[:]...)
|
||||
payload = append(payload, packet...)
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func clonePacketBatch(packets [][]byte) [][]byte {
|
||||
out := make([][]byte, 0, len(packets))
|
||||
for _, packet := range packets {
|
||||
if len(packet) == 0 {
|
||||
continue
|
||||
}
|
||||
out = append(out, append([]byte(nil), packet...))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *Manager) SnapshotJSON() string {
|
||||
m.mu.Lock()
|
||||
connected := m.packet != nil
|
||||
endpoint := m.endpoint
|
||||
lastErr := m.lastErr
|
||||
vpnConnectionID := m.cfg.VPNConnectionID
|
||||
tunnelID := m.cfg.TunnelID
|
||||
poolID := m.cfg.PoolID
|
||||
serviceID := m.cfg.ServiceID
|
||||
localNodeID := m.cfg.LocalNodeID
|
||||
exitNodeID := m.cfg.ExitNodeID
|
||||
serviceKind := m.cfg.ServiceKind
|
||||
serviceClass := m.cfg.ServiceClass
|
||||
routeLeaseID := m.cfg.RouteLeaseID
|
||||
routeGeneration := m.cfg.RouteGeneration
|
||||
var serviceStreamSnapshot map[string]any
|
||||
if m.serviceStreams != nil {
|
||||
serviceStreamSnapshot = m.serviceStreams.Snapshot()
|
||||
}
|
||||
m.mu.Unlock()
|
||||
payload, _ := json.Marshal(map[string]any{
|
||||
"schema_version": "rap.android_fabric_vpn_runtime.v1",
|
||||
"schema_version": "rap.ipv4_tunnel_fabric_runtime.v1",
|
||||
"platform_adapter": "android_vpnservice_tun",
|
||||
"connected": connected,
|
||||
"endpoint": endpoint,
|
||||
"last_error": lastErr,
|
||||
"vpn_connection": vpnConnectionID,
|
||||
"tunnel_id": tunnelID,
|
||||
"pool_id": poolID,
|
||||
"service_id": serviceID,
|
||||
"service_kind": serviceKind,
|
||||
"service_class": serviceClass,
|
||||
"route_lease_id": routeLeaseID,
|
||||
"route_generation": routeGeneration,
|
||||
"local_node_id": localNodeID,
|
||||
"exit_node_id": exitNodeID,
|
||||
"uplink_packets": m.uplinkPackets.Load(),
|
||||
"uplink_bytes": m.uplinkBytes.Load(),
|
||||
"downlink_packets": m.downlinkPackets.Load(),
|
||||
"downlink_bytes": m.downlinkBytes.Load(),
|
||||
"service_streams": serviceStreamSnapshot,
|
||||
})
|
||||
return string(payload)
|
||||
}
|
||||
@@ -560,15 +1116,23 @@ func openStreams(ctx context.Context, session mesh.FabricTransportSession, shard
|
||||
classes := []struct {
|
||||
name string
|
||||
trafficClass fabricproto.TrafficClass
|
||||
shards int
|
||||
}{
|
||||
{name: vpnruntime.FabricTrafficClassInteractive, trafficClass: fabricproto.TrafficClassInteractive},
|
||||
{name: vpnruntime.FabricTrafficClassBulk, trafficClass: fabricproto.TrafficClassBulk},
|
||||
{name: vpnruntime.FabricTrafficClassControl, trafficClass: fabricproto.TrafficClassControl, shards: 1},
|
||||
{name: vpnruntime.FabricTrafficClassDNS, trafficClass: fabricproto.TrafficClassReliable, shards: 1},
|
||||
{name: vpnruntime.FabricTrafficClassInteractive, trafficClass: fabricproto.TrafficClassInteractive, shards: shards},
|
||||
{name: vpnruntime.FabricTrafficClassReliable, trafficClass: fabricproto.TrafficClassReliable, shards: maxInt(1, shards/2)},
|
||||
{name: vpnruntime.FabricTrafficClassBulk, trafficClass: fabricproto.TrafficClassBulk, shards: shards},
|
||||
{name: vpnruntime.FabricTrafficClassDroppable, trafficClass: fabricproto.TrafficClassDroppable, shards: maxInt(1, shards/2)},
|
||||
}
|
||||
out := make(map[string][]uint64, len(classes))
|
||||
var primary uint64
|
||||
var ordinal uint64
|
||||
for classIndex, class := range classes {
|
||||
for shard := 0; shard < shards; shard++ {
|
||||
streamID := base + uint64(classIndex*shards+shard)
|
||||
_ = classIndex
|
||||
for shard := 0; shard < class.shards; shard++ {
|
||||
ordinal++
|
||||
streamID := base + ordinal
|
||||
if err := session.Send(ctx, fabricproto.Frame{Type: fabricproto.FrameOpenStream, StreamID: streamID, TrafficClass: class.trafficClass}); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -581,6 +1145,29 @@ func openStreams(ctx context.Context, session mesh.FabricTransportSession, shard
|
||||
return out, primary, nil
|
||||
}
|
||||
|
||||
func serviceTunnelFromRuntimeConfig(cfg runtimeConfig) vpnruntime.FabricServiceTunnel {
|
||||
return vpnruntime.NormalizeServiceTunnel(vpnruntime.FabricServiceTunnel{
|
||||
TunnelID: cfg.TunnelID,
|
||||
PoolID: cfg.PoolID,
|
||||
ServiceID: cfg.ServiceID,
|
||||
LocalServiceID: cfg.LocalServiceID,
|
||||
RemoteServiceID: cfg.RemoteServiceID,
|
||||
ServiceKind: cfg.ServiceKind,
|
||||
ServiceClass: cfg.ServiceClass,
|
||||
ServiceRole: vpnruntime.DefaultFabricTunnelRole,
|
||||
RouteLeaseID: cfg.RouteLeaseID,
|
||||
RouteGeneration: cfg.RouteGeneration,
|
||||
StreamShards: cfg.StreamShards,
|
||||
}, cfg.TunnelID)
|
||||
}
|
||||
|
||||
func maxInt(left, right int) int {
|
||||
if left > right {
|
||||
return left
|
||||
}
|
||||
return right
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if strings.TrimSpace(value) != "" {
|
||||
|
||||
@@ -3,11 +3,14 @@ package fabricvpn
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/vpnruntime"
|
||||
)
|
||||
|
||||
func TestFabricRuntimeEndpointsPreferRouteBundle(t *testing.T) {
|
||||
cfg := runtimeConfig{
|
||||
Endpoints: []endpointConfig{{EndpointID: "legacy", Address: "quic://legacy.example:19131"}},
|
||||
Endpoints: []endpointConfig{{EndpointID: "compat", Address: "quic://compat.example:19131"}},
|
||||
RouteBundle: routeBundleConfig{
|
||||
EndpointCandidates: []endpointConfig{{EndpointID: "bundle", Address: "quic://bundle.example:19131"}},
|
||||
},
|
||||
@@ -20,7 +23,7 @@ func TestFabricRuntimeEndpointsPreferRouteBundle(t *testing.T) {
|
||||
|
||||
func TestFabricRuntimeEndpointsPreferRouteLease(t *testing.T) {
|
||||
cfg := runtimeConfig{
|
||||
Endpoints: []endpointConfig{{EndpointID: "legacy", Address: "quic://legacy.example:19131"}},
|
||||
Endpoints: []endpointConfig{{EndpointID: "compat", Address: "quic://compat.example:19131"}},
|
||||
RouteBundle: routeBundleConfig{
|
||||
EndpointCandidates: []endpointConfig{{EndpointID: "bundle", Address: "quic://bundle.example:19131"}},
|
||||
RouteLease: routeLeaseConfig{
|
||||
@@ -41,13 +44,148 @@ func TestFabricRuntimeEndpointsPreferRouteLease(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFabricRuntimeEndpointsFallbackToLegacyEndpoints(t *testing.T) {
|
||||
func TestFabricRuntimePacketTargetIsLongLived(t *testing.T) {
|
||||
cfg := runtimeConfig{
|
||||
Endpoints: []endpointConfig{{EndpointID: "legacy", Address: "quic://legacy.example:19131"}},
|
||||
RouteBundle: routeBundleConfig{RouteLease: routeLeaseConfig{
|
||||
PrimaryPath: routeLeasePath{TargetNodeID: "exit-1"},
|
||||
}},
|
||||
}
|
||||
target := fabricRuntimePacketTarget(cfg, endpointConfig{
|
||||
EndpointID: "exit-public",
|
||||
NodeID: "exit-1",
|
||||
Address: "quic://203.0.113.10:19131",
|
||||
Transport: "direct_quic",
|
||||
PeerCertSHA256: "abc123",
|
||||
})
|
||||
if target.Timeout != 0 {
|
||||
t.Fatalf("packet target timeout = %s, want 0 for long-lived vpn stream", target.Timeout)
|
||||
}
|
||||
if target.PeerID != "exit-1" || target.Endpoint != "quic://203.0.113.10:19131" || target.PeerCertSHA256 != "abc123" {
|
||||
t.Fatalf("unexpected packet target: %+v", target)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServiceTunnelFromRuntimeConfigCarriesRouteEpoch(t *testing.T) {
|
||||
cfg := runtimeConfig{
|
||||
TunnelID: "tunnel-1",
|
||||
PoolID: "home-ipv4",
|
||||
ServiceID: "svc-1",
|
||||
ServiceKind: "ipv4-tunnel",
|
||||
ServiceClass: "vpn_packets",
|
||||
RouteLeaseID: "lease-1",
|
||||
RouteGeneration: "route-gen-1",
|
||||
StreamShards: 8,
|
||||
}
|
||||
tunnel := serviceTunnelFromRuntimeConfig(cfg)
|
||||
if tunnel.RouteLeaseID != "lease-1" || tunnel.RouteGeneration != "route-gen-1" || tunnel.StreamShards != 8 {
|
||||
t.Fatalf("service tunnel route epoch = %+v", tunnel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerUpdateRuntimeConfigKeepsTunnelAndUpdatesRouteEpoch(t *testing.T) {
|
||||
manager := NewManager()
|
||||
manager.cfg = runtimeConfig{
|
||||
ClusterID: "cluster-1",
|
||||
LocalNodeID: "android-1",
|
||||
TunnelID: "tunnel-1",
|
||||
PoolID: "home-ipv4",
|
||||
ServiceID: "svc-1",
|
||||
ServiceKind: "ipv4-tunnel",
|
||||
ServiceClass: "vpn_packets",
|
||||
RouteLeaseID: "lease-1",
|
||||
RouteGeneration: "route-gen-1",
|
||||
StreamShards: 4,
|
||||
}
|
||||
manager.packet = &vpnruntime.FabricSessionPacketTransport{
|
||||
TunnelID: "tunnel-1",
|
||||
ServiceTunnel: vpnruntime.FabricServiceTunnel{
|
||||
TunnelID: "tunnel-1",
|
||||
PoolID: "home-ipv4",
|
||||
ServiceID: "svc-1",
|
||||
RouteLeaseID: "lease-1",
|
||||
RouteGeneration: "route-gen-1",
|
||||
},
|
||||
}
|
||||
|
||||
err := manager.UpdateRuntimeConfig(`{
|
||||
"cluster_id":"cluster-1",
|
||||
"local_node_id":"android-1",
|
||||
"tunnel_id":"tunnel-1",
|
||||
"pool_id":"home-ipv4",
|
||||
"service_id":"svc-1",
|
||||
"service_kind":"ipv4-tunnel",
|
||||
"service_class":"vpn_packets",
|
||||
"route_lease_id":"lease-2",
|
||||
"route_generation":"route-gen-2",
|
||||
"stream_shards":4,
|
||||
"service_channel_request":{"schema_version":"rap.fabric_service_channel_request.v1"}
|
||||
}`)
|
||||
if err != nil {
|
||||
t.Fatalf("update runtime config: %v", err)
|
||||
}
|
||||
snapshot := manager.packet.Snapshot()
|
||||
if snapshot["route_lease_id"] != "lease-2" || snapshot["route_generation"] != "route-gen-2" || snapshot["route_transition_count"] != uint64(1) {
|
||||
t.Fatalf("packet route epoch not updated: %+v", snapshot)
|
||||
}
|
||||
if err := manager.UpdateRuntimeConfig(`{"tunnel_id":"other-tunnel"}`); err == nil {
|
||||
t.Fatal("expected changed tunnel id to be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeRouteReconnectDecisionTracksTargetAndEndpoints(t *testing.T) {
|
||||
current := runtimeConfig{
|
||||
TunnelID: "tunnel-1",
|
||||
Endpoints: []endpointConfig{{EndpointID: "exit-a", NodeID: "node-a", Address: "quic://node-a:19131", Transport: "direct_quic"}},
|
||||
RouteBundle: routeBundleConfig{RouteLease: routeLeaseConfig{
|
||||
PrimaryPath: routeLeasePath{TargetNodeID: "node-a"},
|
||||
}},
|
||||
}
|
||||
sameLeaseNewGeneration := current
|
||||
sameLeaseNewGeneration.RouteLeaseID = "lease-2"
|
||||
sameLeaseNewGeneration.RouteGeneration = "route-gen-2"
|
||||
if shouldReconnectForRuntimeRoute(current, sameLeaseNewGeneration) {
|
||||
t.Fatal("same target/endpoints should update route epoch without reconnect")
|
||||
}
|
||||
newTarget := current
|
||||
newTarget.RouteBundle.RouteLease.PrimaryPath.TargetNodeID = "node-b"
|
||||
if !shouldReconnectForRuntimeRoute(current, newTarget) {
|
||||
t.Fatal("changed target node should reconnect fabric session")
|
||||
}
|
||||
newEndpoint := current
|
||||
newEndpoint.Endpoints = []endpointConfig{{EndpointID: "exit-b", NodeID: "node-b", Address: "quic://node-b:19131", Transport: "direct_quic"}}
|
||||
if !shouldReconnectForRuntimeRoute(current, newEndpoint) {
|
||||
t.Fatal("changed endpoint candidates should reconnect fabric session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacketBatchSendTimeoutScalesWithPayload(t *testing.T) {
|
||||
small := packetBatchSendTimeout([][]byte{make([]byte, 1200)})
|
||||
large := packetBatchSendTimeout([][]byte{make([]byte, 4*1024*1024)})
|
||||
if small != minPacketBatchSendTimeout {
|
||||
t.Fatalf("small timeout = %s, want %s", small, minPacketBatchSendTimeout)
|
||||
}
|
||||
if large <= small {
|
||||
t.Fatalf("large timeout = %s, want greater than %s", large, small)
|
||||
}
|
||||
many := make([][]byte, 2048)
|
||||
for i := range many {
|
||||
many[i] = make([]byte, 1200)
|
||||
}
|
||||
if got := packetBatchSendTimeout(many); got <= small {
|
||||
t.Fatalf("many-packet timeout = %s, want greater than %s", got, small)
|
||||
}
|
||||
if got := packetBatchSendTimeout([][]byte{make([]byte, 100*1024*1024)}); got != maxPacketBatchSendTimeout {
|
||||
t.Fatalf("capped timeout = %s, want %s", got, maxPacketBatchSendTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFabricRuntimeEndpointsFallbackToDisallowedEndpoints(t *testing.T) {
|
||||
cfg := runtimeConfig{
|
||||
Endpoints: []endpointConfig{{EndpointID: "compat", Address: "quic://compat.example:19131"}},
|
||||
}
|
||||
got := fabricRuntimeEndpoints(cfg)
|
||||
if len(got) != 1 || got[0].EndpointID != "legacy" {
|
||||
t.Fatalf("endpoints = %+v, want legacy endpoint fallback", got)
|
||||
if len(got) != 1 || got[0].EndpointID != "compat" {
|
||||
t.Fatalf("endpoints = %+v, want compat endpoint fallback", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,18 +214,18 @@ func TestLiveFabricVPNRuntimeStartsFromRouteLease(t *testing.T) {
|
||||
t.Fatalf("receive live dns packet: %v", err)
|
||||
}
|
||||
if len(packet) > 0 {
|
||||
if packet[9] != 17 || packet[12] != 1 || packet[13] != 1 || packet[14] != 1 || packet[15] != 1 {
|
||||
t.Fatalf("unexpected response packet header: %v", packet[:min(20, len(packet))])
|
||||
if len(packet) >= 20 && packet[9] == 17 && packet[12] == 1 && packet[13] == 1 && packet[14] == 1 && packet[15] == 1 {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatal("timed out waiting for live dns response through fabric vpn")
|
||||
}
|
||||
|
||||
func testDNSIPv4Packet() []byte {
|
||||
nonce := uint16(time.Now().UnixNano())
|
||||
dns := []byte{
|
||||
0x12, 0x34, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00,
|
||||
byte(nonce >> 8), byte(nonce), 0x01, 0x00, 0x00, 0x01, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x07, 'e', 'x', 'a',
|
||||
'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00,
|
||||
0x00, 0x01, 0x00, 0x01,
|
||||
@@ -102,8 +240,8 @@ func testDNSIPv4Packet() []byte {
|
||||
packet[9] = 17
|
||||
copy(packet[12:16], []byte{10, 77, 0, 2})
|
||||
copy(packet[16:20], []byte{1, 1, 1, 1})
|
||||
packet[20] = 0xcf
|
||||
packet[21] = 0x08
|
||||
packet[20] = byte(0xc0 | ((nonce >> 8) & 0x3f))
|
||||
packet[21] = byte(nonce)
|
||||
packet[22] = 0x00
|
||||
packet[23] = 0x35
|
||||
packet[24] = byte(udpLen >> 8)
|
||||
|
||||
Reference in New Issue
Block a user