Add tracked vpnruntime implementation for CI guard tests

This commit is contained in:
2026-05-12 10:02:49 +03:00
parent 60ef659084
commit 2eb4a769d0
8 changed files with 4086 additions and 28 deletions
@@ -3,9 +3,12 @@ package client
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/binary"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/url"
"time" "time"
) )
@@ -64,6 +67,95 @@ type HeartbeatRequest struct {
type HeartbeatResponse struct { type HeartbeatResponse struct {
Heartbeat json.RawMessage `json:"heartbeat"` Heartbeat json.RawMessage `json:"heartbeat"`
TestingFlags EffectiveTestingFlags `json:"testing_flags"` TestingFlags EffectiveTestingFlags `json:"testing_flags"`
UpdateHint *NodeUpdateHint `json:"update_hint,omitempty"`
}
type NodeUpdateHint struct {
SchemaVersion string `json:"schema_version"`
Generation string `json:"generation,omitempty"`
CheckNow bool `json:"check_now"`
Products []string `json:"products,omitempty"`
Reason string `json:"reason,omitempty"`
DeliveryMode string `json:"delivery_mode,omitempty"`
SubscriptionStatus string `json:"subscription_status,omitempty"`
UpdateService *NodeUpdateServiceAssignment `json:"update_service,omitempty"`
FallbackPollSeconds int `json:"fallback_poll_seconds,omitempty"`
}
type NodeUpdateServiceAssignment struct {
SchemaVersion string `json:"schema_version"`
NodeID string `json:"node_id,omitempty"`
NodeName string `json:"node_name,omitempty"`
Endpoint string `json:"endpoint,omitempty"`
Region string `json:"region,omitempty"`
Status string `json:"status"`
Reason string `json:"reason,omitempty"`
AssignedAt time.Time `json:"assigned_at"`
ExpiresAt time.Time `json:"expires_at"`
}
type NodeUpdatePlanRequest struct {
Product string
CurrentVersion string
OS string
Arch string
InstallType string
Channel string
}
type NodeUpdatePlanResponse struct {
Plan NodeUpdatePlan `json:"node_update_plan"`
}
type NodeUpdatePlan struct {
SchemaVersion string `json:"schema_version"`
ClusterID string `json:"cluster_id"`
NodeID string `json:"node_id"`
Product string `json:"product"`
CurrentVersion string `json:"current_version,omitempty"`
Action string `json:"action"`
Reason string `json:"reason"`
TargetVersion string `json:"target_version,omitempty"`
Channel string `json:"channel,omitempty"`
Strategy string `json:"strategy,omitempty"`
RollbackAllowed bool `json:"rollback_allowed"`
HealthWindowSec int `json:"health_window_seconds,omitempty"`
Artifact *ReleaseArtifact `json:"artifact,omitempty"`
AuthorityPayload json.RawMessage `json:"authority_payload,omitempty"`
AuthoritySignature *ClusterSignature `json:"authority_signature,omitempty"`
ProductionForwarding bool `json:"production_forwarding"`
}
type ReleaseArtifact struct {
ID string `json:"id"`
ReleaseID string `json:"release_id"`
ClusterID string `json:"cluster_id"`
Product string `json:"product"`
Version string `json:"version"`
OS string `json:"os"`
Arch string `json:"arch"`
InstallType string `json:"install_type"`
Kind string `json:"kind"`
URL string `json:"url"`
URLs []string `json:"urls,omitempty"`
SHA256 string `json:"sha256"`
SizeBytes int64 `json:"size_bytes"`
Signature *string `json:"signature,omitempty"`
Metadata json.RawMessage `json:"metadata"`
CreatedAt time.Time `json:"created_at"`
}
type NodeUpdateStatusRequest struct {
Product string `json:"product"`
CurrentVersion string `json:"current_version,omitempty"`
TargetVersion string `json:"target_version,omitempty"`
Phase string `json:"phase"`
Status string `json:"status"`
AttemptID string `json:"attempt_id,omitempty"`
ErrorMessage *string `json:"error_message,omitempty"`
RollbackVersion *string `json:"rollback_version,omitempty"`
Payload map[string]any `json:"payload,omitempty"`
ObservedAt time.Time `json:"observed_at,omitempty"`
} }
type EffectiveTestingFlags struct { type EffectiveTestingFlags struct {
@@ -91,6 +183,45 @@ type WorkloadStatusRequest struct {
StatusPayload map[string]any `json:"status_payload"` StatusPayload map[string]any `json:"status_payload"`
} }
type NodeVPNAssignmentLease struct {
LeaseID string `json:"lease_id"`
OwnerNodeID string `json:"owner_node_id"`
LeaseGeneration int64 `json:"lease_generation"`
Status string `json:"status"`
RenewedAt time.Time `json:"renewed_at"`
ExpiresAt time.Time `json:"expires_at"`
}
type NodeVPNAssignment struct {
VPNConnectionID string `json:"vpn_connection_id"`
ClusterID string `json:"cluster_id"`
OrganizationID string `json:"organization_id"`
Name string `json:"name"`
TargetEndpoint json.RawMessage `json:"target_endpoint"`
ProtocolFamily string `json:"protocol_family"`
Mode string `json:"mode"`
DesiredState string `json:"desired_state"`
RoutingUsage json.RawMessage `json:"routing_usage"`
RoutePolicy json.RawMessage `json:"route_policy"`
QoSPolicy json.RawMessage `json:"qos_policy"`
PlacementPolicy json.RawMessage `json:"placement_policy"`
Status string `json:"status"`
HasCredentialRef bool `json:"has_credential_ref"`
AssignmentReason string `json:"assignment_reason"`
ActiveLease *NodeVPNAssignmentLease `json:"active_lease,omitempty"`
UpdatedAt time.Time `json:"updated_at"`
}
type NodeVPNAssignmentStatusRequest struct {
ObservedStatus string `json:"observed_status"`
StatusPayload map[string]any `json:"status_payload"`
ObservedAt time.Time `json:"observed_at,omitempty"`
}
type NodeVPNAssignmentLeaseRenewRequest struct {
TTLSeconds int `json:"ttl_seconds"`
}
type MeshLinkObservationRequest struct { type MeshLinkObservationRequest struct {
SourceNodeID string `json:"source_node_id"` SourceNodeID string `json:"source_node_id"`
TargetNodeID string `json:"target_node_id"` TargetNodeID string `json:"target_node_id"`
@@ -147,10 +278,139 @@ type SyntheticMeshConfig struct {
RendezvousLeases []PeerRendezvousLease `json:"rendezvous_leases,omitempty"` RendezvousLeases []PeerRendezvousLease `json:"rendezvous_leases,omitempty"`
RendezvousRelayPolicy *RendezvousRelayPolicyReport `json:"rendezvous_relay_policy,omitempty"` RendezvousRelayPolicy *RendezvousRelayPolicyReport `json:"rendezvous_relay_policy,omitempty"`
RoutePathDecisions *RoutePathDecisionReport `json:"route_path_decisions,omitempty"` RoutePathDecisions *RoutePathDecisionReport `json:"route_path_decisions,omitempty"`
ServiceChannelFeedback *FabricServiceChannelFeedbackReport `json:"service_channel_route_feedback,omitempty"`
ServiceChannelAdaptivePolicy *FabricServiceChannelAdaptivePolicy `json:"service_channel_adaptive_policy,omitempty"`
ServiceChannelRemediationCommands []FabricServiceChannelRemediationCommand `json:"service_channel_remediation_commands,omitempty"`
MeshListener *MeshListenerConfig `json:"mesh_listener,omitempty"`
Routes []SyntheticMeshRouteConfig `json:"routes"` Routes []SyntheticMeshRouteConfig `json:"routes"`
ProductionForwarding bool `json:"production_forwarding"` ProductionForwarding bool `json:"production_forwarding"`
} }
type FabricServiceChannelRemediationCommand struct {
SchemaVersion string `json:"schema_version"`
CommandID string `json:"command_id"`
Action string `json:"action"`
ClusterID string `json:"cluster_id"`
ChannelID string `json:"channel_id"`
ResourceID string `json:"resource_id,omitempty"`
ServiceClass string `json:"service_class"`
EntryNodeID string `json:"entry_node_id,omitempty"`
ExitNodeID string `json:"exit_node_id,omitempty"`
PrimaryRouteID string `json:"primary_route_id,omitempty"`
ReplacementRouteID string `json:"replacement_route_id,omitempty"`
ReplacementRouteStatus string `json:"replacement_route_status,omitempty"`
PoolPolicyFingerprint string `json:"pool_policy_fingerprint,omitempty"`
GuardStatus string `json:"guard_status,omitempty"`
GuardReason string `json:"guard_reason,omitempty"`
ExecutionStatus string `json:"execution_status,omitempty"`
ExecutionReason string `json:"execution_reason,omitempty"`
ExecutionGeneration string `json:"execution_generation,omitempty"`
ExecutionObservedAt string `json:"execution_observed_at,omitempty"`
Reason string `json:"reason,omitempty"`
OperatorAction string `json:"operator_action,omitempty"`
IssuedAt time.Time `json:"issued_at"`
ExpiresAt time.Time `json:"expires_at"`
}
type FabricServiceChannelFeedbackReport struct {
SchemaVersion string `json:"schema_version"`
GeneratedAt time.Time `json:"generated_at"`
FeedbackMaxAgeSeconds int `json:"feedback_max_age_seconds"`
RecoveryPolicy *FabricServiceChannelRecoveryPolicy `json:"recovery_policy,omitempty"`
MissingProvenanceCount int `json:"missing_provenance_count,omitempty"`
StalePolicyCount int `json:"stale_policy_count,omitempty"`
StaleGenerationCount int `json:"stale_generation_count,omitempty"`
ObservationCount int `json:"observation_count"`
FencedRouteCount int `json:"fenced_route_count"`
DegradedRouteCount int `json:"degraded_route_count"`
HealthyRouteCount int `json:"healthy_route_count"`
RecoveredRouteCount int `json:"recovered_route_count,omitempty"`
RecoveryHysteresisCount int `json:"recovery_hysteresis_count,omitempty"`
RecoveryPromotedCount int `json:"recovery_promoted_count,omitempty"`
RecoveryDemotedCount int `json:"recovery_demoted_count,omitempty"`
Observations []FabricServiceChannelFeedbackObservation `json:"observations,omitempty"`
}
type FabricServiceChannelAdaptivePolicy struct {
SchemaVersion string `json:"schema_version"`
Fingerprint string `json:"fingerprint,omitempty"`
MaxParallelWindow int `json:"max_parallel_window"`
BulkPressureChannelThreshold int `json:"bulk_pressure_channel_threshold"`
QueuePressureHighWatermark int `json:"queue_pressure_high_watermark"`
QueuePressureMaxInFlight int `json:"queue_pressure_max_in_flight"`
ClassWindows map[string]int `json:"class_windows"`
Source string `json:"source"`
UpdatedByUserID *string `json:"updated_by_user_id,omitempty"`
UpdatedAt string `json:"updated_at,omitempty"`
ControlPlaneOnly bool `json:"control_plane_only"`
ProductionForwarding bool `json:"production_forwarding"`
}
type FabricServiceChannelRecoveryPolicy struct {
SchemaVersion string `json:"schema_version"`
Fingerprint string `json:"fingerprint,omitempty"`
HysteresisPenalty int `json:"hysteresis_penalty"`
PromotionMinSamples int `json:"promotion_min_samples"`
DemotionFailureThreshold int `json:"demotion_failure_threshold"`
DemotionDropThreshold int `json:"demotion_drop_threshold"`
DemotionSlowThreshold int `json:"demotion_slow_threshold"`
DemotionRebuildEnabled bool `json:"demotion_rebuild_enabled"`
DemotionFencedEnabled bool `json:"demotion_fenced_enabled"`
Source string `json:"source"`
UpdatedByUserID *string `json:"updated_by_user_id,omitempty"`
UpdatedAt time.Time `json:"updated_at,omitempty"`
ControlPlaneOnly bool `json:"control_plane_only"`
ProductionForwarding bool `json:"production_forwarding"`
}
type FabricServiceChannelFeedbackObservation struct {
ID string `json:"id,omitempty"`
ClusterID string `json:"cluster_id"`
ReporterNodeID string `json:"reporter_node_id"`
RouteID string `json:"route_id"`
ServiceClass string `json:"service_class"`
FeedbackStatus string `json:"feedback_status"`
ScoreAdjustment int `json:"score_adjustment"`
EffectiveScoreAdjustment int `json:"effective_score_adjustment,omitempty"`
Reasons []string `json:"reasons,omitempty"`
LastError string `json:"last_error,omitempty"`
ConsecutiveFailures int `json:"consecutive_failures,omitempty"`
StallCount int `json:"stall_count,omitempty"`
LastSendDurationMs int64 `json:"last_send_duration_ms,omitempty"`
RecoveryState string `json:"recovery_state,omitempty"`
ObservedPolicyFingerprint string `json:"observed_policy_fingerprint,omitempty"`
EffectivePolicyFingerprint string `json:"effective_policy_fingerprint,omitempty"`
ObservedRouteGeneration string `json:"observed_route_generation,omitempty"`
EffectiveRouteGeneration string `json:"effective_route_generation,omitempty"`
ProvenanceMissing bool `json:"provenance_missing,omitempty"`
StalePolicy bool `json:"stale_policy,omitempty"`
StaleGeneration bool `json:"stale_generation,omitempty"`
StaleReason string `json:"stale_reason,omitempty"`
Payload json.RawMessage `json:"payload"`
ObservedAt time.Time `json:"observed_at"`
ExpiresAt time.Time `json:"expires_at"`
}
type MeshListenerConfig struct {
SchemaVersion string `json:"schema_version"`
Source string `json:"source"`
DesiredState string `json:"desired_state"`
ListenAddr string `json:"listen_addr"`
ListenPortMode string `json:"listen_port_mode"`
AutoPortStart int `json:"auto_port_start,omitempty"`
AutoPortEnd int `json:"auto_port_end,omitempty"`
AdvertiseEndpoint string `json:"advertise_endpoint,omitempty"`
AdvertiseTransport string `json:"advertise_transport,omitempty"`
ConnectivityMode string `json:"connectivity_mode,omitempty"`
NATType string `json:"nat_type,omitempty"`
Region string `json:"region,omitempty"`
ConfigVersion string `json:"config_version,omitempty"`
UpdatedByUserID string `json:"updated_by_user_id,omitempty"`
UpdatedAt string `json:"updated_at,omitempty"`
ControlPlaneOnly bool `json:"control_plane_only"`
ProductionForwarding bool `json:"production_forwarding"`
}
type ClusterAuthorityDescriptor struct { type ClusterAuthorityDescriptor struct {
SchemaVersion string `json:"schema_version"` SchemaVersion string `json:"schema_version"`
ClusterID string `json:"cluster_id"` ClusterID string `json:"cluster_id"`
@@ -233,6 +493,11 @@ type RendezvousRelayPolicyReport struct {
type RoutePathDecision struct { type RoutePathDecision struct {
DecisionID string `json:"decision_id"` DecisionID string `json:"decision_id"`
RouteID string `json:"route_id"` RouteID string `json:"route_id"`
ReplacementRouteID string `json:"replacement_route_id,omitempty"`
RebuildRequestID string `json:"rebuild_request_id,omitempty"`
RebuildStatus string `json:"rebuild_status,omitempty"`
RebuildReason string `json:"rebuild_reason,omitempty"`
RebuildAttempt int `json:"rebuild_attempt,omitempty"`
ClusterID string `json:"cluster_id"` ClusterID string `json:"cluster_id"`
LocalNodeID string `json:"local_node_id"` LocalNodeID string `json:"local_node_id"`
SourceNodeID string `json:"source_node_id"` SourceNodeID string `json:"source_node_id"`
@@ -261,8 +526,15 @@ type RoutePathDecisionReport struct {
SchemaVersion string `json:"schema_version"` SchemaVersion string `json:"schema_version"`
DecisionMode string `json:"decision_mode"` DecisionMode string `json:"decision_mode"`
Generation string `json:"generation"` Generation string `json:"generation"`
RecoveryPolicy *FabricServiceChannelRecoveryPolicy `json:"recovery_policy,omitempty"`
DecisionCount int `json:"decision_count"` DecisionCount int `json:"decision_count"`
ReplacementDecisionCount int `json:"replacement_decision_count"` ReplacementDecisionCount int `json:"replacement_decision_count"`
DegradedDecisionCount int `json:"degraded_decision_count"`
RebuildRequestCount int `json:"rebuild_request_count,omitempty"`
RebuildAppliedCount int `json:"rebuild_applied_count,omitempty"`
RecoveryHysteresisCount int `json:"recovery_hysteresis_count,omitempty"`
RecoveryPromotedCount int `json:"recovery_promoted_count,omitempty"`
RecoveryDemotedCount int `json:"recovery_demoted_count,omitempty"`
ControlPlaneOnly bool `json:"control_plane_only"` ControlPlaneOnly bool `json:"control_plane_only"`
ProductionForwarding bool `json:"production_forwarding"` ProductionForwarding bool `json:"production_forwarding"`
Decisions []RoutePathDecision `json:"decisions,omitempty"` Decisions []RoutePathDecision `json:"decisions,omitempty"`
@@ -319,6 +591,29 @@ func (c *Client) Heartbeat(ctx context.Context, clusterID, nodeID string, reques
return response, nil return response, nil
} }
func (c *Client) NodeUpdatePlan(ctx context.Context, clusterID, nodeID string, request NodeUpdatePlanRequest) (NodeUpdatePlan, error) {
values := url.Values{}
values.Set("product", request.Product)
values.Set("current_version", request.CurrentVersion)
values.Set("os", request.OS)
values.Set("arch", request.Arch)
values.Set("install_type", request.InstallType)
if request.Channel != "" {
values.Set("channel", request.Channel)
}
var response NodeUpdatePlanResponse
path := fmt.Sprintf("/clusters/%s/nodes/%s/updates/plan?%s", clusterID, nodeID, values.Encode())
if err := c.getJSON(ctx, path, &response); err != nil {
return NodeUpdatePlan{}, err
}
return response.Plan, nil
}
func (c *Client) ReportNodeUpdateStatus(ctx context.Context, clusterID, nodeID string, request NodeUpdateStatusRequest) error {
path := fmt.Sprintf("/clusters/%s/nodes/%s/updates/status", clusterID, nodeID)
return c.postJSON(ctx, path, request, nil)
}
func (c *Client) DesiredWorkloads(ctx context.Context, clusterID, nodeID string) ([]DesiredWorkload, error) { func (c *Client) DesiredWorkloads(ctx context.Context, clusterID, nodeID string) ([]DesiredWorkload, error) {
var response struct { var response struct {
DesiredWorkloads []DesiredWorkload `json:"desired_workloads"` DesiredWorkloads []DesiredWorkload `json:"desired_workloads"`
@@ -335,6 +630,58 @@ func (c *Client) ReportWorkloadStatus(ctx context.Context, clusterID, nodeID, se
return c.postJSON(ctx, path, request, nil) return c.postJSON(ctx, path, request, nil)
} }
func (c *Client) NodeVPNAssignments(ctx context.Context, clusterID, nodeID string) ([]NodeVPNAssignment, error) {
var response struct {
Assignments []NodeVPNAssignment `json:"vpn_assignments"`
}
path := fmt.Sprintf("/clusters/%s/nodes/%s/vpn/assignments", clusterID, nodeID)
if err := c.getJSON(ctx, path, &response); err != nil {
return nil, err
}
return response.Assignments, nil
}
func (c *Client) ReportNodeVPNAssignmentStatus(ctx context.Context, clusterID, nodeID, vpnConnectionID string, request NodeVPNAssignmentStatusRequest) error {
path := fmt.Sprintf("/clusters/%s/nodes/%s/vpn/assignments/%s/status", clusterID, nodeID, vpnConnectionID)
return c.postJSON(ctx, path, request, nil)
}
func (c *Client) RenewNodeVPNAssignmentLease(ctx context.Context, clusterID, nodeID, vpnConnectionID, leaseID string, request NodeVPNAssignmentLeaseRenewRequest) error {
path := fmt.Sprintf("/clusters/%s/nodes/%s/vpn/assignments/%s/lease/%s/renew", clusterID, nodeID, vpnConnectionID, leaseID)
return c.postJSON(ctx, path, request, nil)
}
func (c *Client) SendVPNGatewayPacket(ctx context.Context, clusterID, vpnConnectionID string, packet []byte) error {
if len(packet) == 0 {
return nil
}
path := fmt.Sprintf("/clusters/%s/vpn-connections/%s/tunnel/gateway/packets", clusterID, vpnConnectionID)
return c.postBytes(ctx, path, packet)
}
func (c *Client) SendVPNGatewayPacketBatch(ctx context.Context, clusterID, vpnConnectionID string, packets [][]byte) error {
packets = cleanVPNPacketBatch(packets)
if len(packets) == 0 {
return nil
}
path := fmt.Sprintf("/clusters/%s/vpn-connections/%s/tunnel/gateway/packets?batch=true", clusterID, vpnConnectionID)
return c.postBytes(ctx, path, encodeVPNPacketBatch(packets))
}
func (c *Client) ReceiveVPNGatewayPacket(ctx context.Context, clusterID, vpnConnectionID string, timeout time.Duration) ([]byte, bool, error) {
path := fmt.Sprintf("/clusters/%s/vpn-connections/%s/tunnel/gateway/packets?timeout_ms=%d", clusterID, vpnConnectionID, timeout.Milliseconds())
return c.getBytes(ctx, path)
}
func (c *Client) ReceiveVPNGatewayPacketBatch(ctx context.Context, clusterID, vpnConnectionID string, timeout time.Duration) ([][]byte, error) {
path := fmt.Sprintf("/clusters/%s/vpn-connections/%s/tunnel/gateway/packets?batch=true&timeout_ms=%d", clusterID, vpnConnectionID, timeout.Milliseconds())
payload, ok, err := c.getBytes(ctx, path)
if err != nil || !ok {
return nil, err
}
return decodeVPNPacketBatch(payload)
}
func (c *Client) ReportMeshLink(ctx context.Context, clusterID string, request MeshLinkObservationRequest) error { func (c *Client) ReportMeshLink(ctx context.Context, clusterID string, request MeshLinkObservationRequest) error {
path := fmt.Sprintf("/clusters/%s/mesh/links", clusterID) path := fmt.Sprintf("/clusters/%s/mesh/links", clusterID)
return c.postJSON(ctx, path, request, nil) return c.postJSON(ctx, path, request, nil)
@@ -375,6 +722,49 @@ func (c *Client) getJSON(ctx context.Context, path string, response any) error {
return json.NewDecoder(httpResp.Body).Decode(response) return json.NewDecoder(httpResp.Body).Decode(response)
} }
func (c *Client) getBytes(ctx context.Context, path string) ([]byte, bool, error) {
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+path, nil)
if err != nil {
return nil, false, err
}
httpResp, err := c.httpClient.Do(httpReq)
if err != nil {
return nil, false, err
}
defer httpResp.Body.Close()
if httpResp.StatusCode == http.StatusNoContent {
return nil, false, nil
}
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
return nil, false, fmt.Errorf("backend returned status %d", httpResp.StatusCode)
}
payload, err := io.ReadAll(io.LimitReader(httpResp.Body, vpnPacketBatchMaxBytes))
if err != nil {
return nil, false, err
}
if len(payload) == 0 {
return nil, false, nil
}
return payload, true, nil
}
func (c *Client) postBytes(ctx context.Context, path string, payload []byte) error {
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+path, bytes.NewReader(payload))
if err != nil {
return err
}
httpReq.Header.Set("Content-Type", "application/octet-stream")
httpResp, err := c.httpClient.Do(httpReq)
if err != nil {
return err
}
defer httpResp.Body.Close()
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
return fmt.Errorf("backend returned status %d", httpResp.StatusCode)
}
return nil
}
func (c *Client) postJSON(ctx context.Context, path string, request any, response any) error { func (c *Client) postJSON(ctx context.Context, path string, request any, response any) error {
payload, err := json.Marshal(request) payload, err := json.Marshal(request)
if err != nil { if err != nil {
@@ -398,3 +788,59 @@ func (c *Client) postJSON(ctx context.Context, path string, request any, respons
} }
return json.NewDecoder(httpResp.Body).Decode(response) return json.NewDecoder(httpResp.Body).Decode(response)
} }
const (
vpnPacketMaxBytes = 65535
vpnPacketBatchMaxBytes = 4 * 1024 * 1024
)
func encodeVPNPacketBatch(packets [][]byte) []byte {
packets = cleanVPNPacketBatch(packets)
total := 0
for _, packet := range packets {
total += 4 + len(packet)
}
out := make([]byte, total)
offset := 0
for _, packet := range packets {
binary.BigEndian.PutUint32(out[offset:offset+4], uint32(len(packet)))
offset += 4
copy(out[offset:offset+len(packet)], packet)
offset += len(packet)
}
return out
}
func decodeVPNPacketBatch(payload []byte) ([][]byte, error) {
var packets [][]byte
for offset := 0; offset < len(payload); {
if offset+4 > len(payload) {
return nil, fmt.Errorf("truncated vpn packet batch header")
}
size := int(binary.BigEndian.Uint32(payload[offset : offset+4]))
offset += 4
if size <= 0 || size > vpnPacketMaxBytes {
return nil, fmt.Errorf("invalid vpn packet batch item size")
}
if offset+size > len(payload) {
return nil, fmt.Errorf("truncated vpn packet batch item")
}
packets = append(packets, append([]byte(nil), payload[offset:offset+size]...))
offset += size
}
return cleanVPNPacketBatch(packets), nil
}
func cleanVPNPacketBatch(packets [][]byte) [][]byte {
if len(packets) == 0 {
return nil
}
cleaned := make([][]byte, 0, len(packets))
for _, packet := range packets {
if len(packet) == 0 {
continue
}
cleaned = append(cleaned, append([]byte(nil), packet...))
}
return cleaned
}
@@ -0,0 +1,121 @@
package mesh
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"time"
)
type VPNPacketBatchPayload struct {
SchemaVersion string `json:"schema_version"`
VPNConnectionID string `json:"vpn_connection_id"`
Direction string `json:"direction"`
Packets [][]byte `json:"packets"`
SentAt time.Time `json:"sent_at"`
}
type ProductionVPNPacketEnvelopeInput struct {
MessageID string
RouteID string
ClusterID string
SourceNodeID string
DestinationNodeID string
CurrentHopNodeID string
NextHopNodeID string
RoutePath []string
TTL int
HopCount int
ExpiresAt time.Time
VPNConnectionID string
Direction string
Packets [][]byte
Now time.Time
}
func NewProductionVPNPacketBatchEnvelope(input ProductionVPNPacketEnvelopeInput) (ProductionEnvelope, error) {
now := input.Now.UTC()
if now.IsZero() {
now = time.Now().UTC()
}
packets := cleanProductionVPNPacketBatch(input.Packets)
if len(packets) == 0 {
return ProductionEnvelope{}, fmt.Errorf("%w: empty vpn packet batch", ErrForwardEnvelopeInvalid)
}
if input.MessageID == "" {
input.MessageID = fmt.Sprintf("vpnpkt-%d", now.UnixNano())
}
if input.TTL <= 0 {
input.TTL = 8
}
if input.ExpiresAt.IsZero() {
input.ExpiresAt = now.Add(15 * time.Second)
}
payload, err := json.Marshal(VPNPacketBatchPayload{
SchemaVersion: "rap.vpn_packet_batch.v1",
VPNConnectionID: input.VPNConnectionID,
Direction: input.Direction,
Packets: packets,
SentAt: now,
})
if err != nil {
return ProductionEnvelope{}, err
}
if len(payload) > MaxProductionVPNPacketPayloadBytes {
return ProductionEnvelope{}, fmt.Errorf("%w: vpn packet batch exceeds channel limit", ErrForwardEnvelopeInvalid)
}
sum := sha256.Sum256(payload)
return ProductionEnvelope{
FabricProtocolVersion: ProtocolVersion,
MessageID: input.MessageID,
RouteID: input.RouteID,
ClusterID: input.ClusterID,
SourceNodeID: input.SourceNodeID,
DestinationNodeID: input.DestinationNodeID,
CurrentHopNodeID: input.CurrentHopNodeID,
NextHopNodeID: input.NextHopNodeID,
RoutePath: append([]string{}, input.RoutePath...),
ChannelClass: ProductionChannelVPNPacket,
MessageType: ProductionMessageVPNPacketBatch,
TTL: input.TTL,
HopCount: input.HopCount,
CreatedAt: now,
ExpiresAt: input.ExpiresAt.UTC(),
PayloadLength: len(payload),
PayloadHash: hex.EncodeToString(sum[:]),
Payload: payload,
}, nil
}
func DecodeProductionVPNPacketBatch(envelope ProductionEnvelope) (VPNPacketBatchPayload, error) {
if envelope.ChannelClass != ProductionChannelVPNPacket || envelope.MessageType != ProductionMessageVPNPacketBatch {
return VPNPacketBatchPayload{}, ErrUnauthorizedChannel
}
var payload VPNPacketBatchPayload
if err := json.Unmarshal(envelope.Payload, &payload); err != nil {
return VPNPacketBatchPayload{}, err
}
if payload.SchemaVersion != "rap.vpn_packet_batch.v1" || payload.VPNConnectionID == "" {
return VPNPacketBatchPayload{}, fmt.Errorf("%w: invalid vpn packet batch payload", ErrForwardEnvelopeInvalid)
}
payload.Packets = cleanProductionVPNPacketBatch(payload.Packets)
if len(payload.Packets) == 0 {
return VPNPacketBatchPayload{}, fmt.Errorf("%w: empty vpn packet batch payload", ErrForwardEnvelopeInvalid)
}
return payload, nil
}
func cleanProductionVPNPacketBatch(packets [][]byte) [][]byte {
if len(packets) == 0 {
return nil
}
cleaned := make([][]byte, 0, len(packets))
for _, packet := range packets {
if len(packet) == 0 {
continue
}
cleaned = append(cleaned, append([]byte(nil), packet...))
}
return cleaned
}
@@ -0,0 +1,77 @@
package vpnruntime
import "encoding/binary"
func normalizeIPv4PacketChecksums(packet []byte) bool {
if len(packet) < 20 || packet[0]>>4 != 4 {
return false
}
ihl := int(packet[0]&0x0f) * 4
if ihl < 20 || len(packet) < ihl {
return false
}
totalLen := int(binary.BigEndian.Uint16(packet[2:4]))
if totalLen <= 0 || totalLen > len(packet) {
totalLen = len(packet)
}
if totalLen < ihl {
return false
}
packet[10], packet[11] = 0, 0
binary.BigEndian.PutUint16(packet[10:12], checksum(packet[:ihl]))
proto := packet[9]
payload := packet[ihl:totalLen]
switch proto {
case 6:
if len(payload) < 20 {
return true
}
payload[16], payload[17] = 0, 0
binary.BigEndian.PutUint16(payload[16:18], transportChecksum(packet, payload, proto))
case 17:
if len(payload) < 8 {
return true
}
payload[6], payload[7] = 0, 0
sum := transportChecksum(packet, payload, proto)
if sum == 0 {
sum = 0xffff
}
binary.BigEndian.PutUint16(payload[6:8], sum)
case 1:
if len(payload) < 4 {
return true
}
payload[2], payload[3] = 0, 0
binary.BigEndian.PutUint16(payload[2:4], checksum(payload))
}
return true
}
func transportChecksum(ipHeader []byte, payload []byte, proto byte) uint16 {
pseudo := make([]byte, 12+len(payload))
copy(pseudo[0:4], ipHeader[12:16])
copy(pseudo[4:8], ipHeader[16:20])
pseudo[8] = 0
pseudo[9] = proto
binary.BigEndian.PutUint16(pseudo[10:12], uint16(len(payload)))
copy(pseudo[12:], payload)
return checksum(pseudo)
}
func checksum(data []byte) uint16 {
var sum uint32
for len(data) >= 2 {
sum += uint32(binary.BigEndian.Uint16(data[:2]))
data = data[2:]
}
if len(data) == 1 {
sum += uint32(data[0]) << 8
}
for (sum >> 16) != 0 {
sum = (sum & 0xffff) + (sum >> 16)
}
return ^uint16(sum)
}
@@ -0,0 +1,52 @@
package vpnruntime
import (
"encoding/binary"
"testing"
)
func TestNormalizeIPv4PacketChecksumsRepairsTCP(t *testing.T) {
packet := []byte{
0x45, 0x00, 0x00, 0x28, 0x00, 0x00, 0x40, 0x00, 0x40, 0x06, 0x12, 0x34, 192, 168, 200, 61, 10, 77, 0, 2,
0x46, 0xa0, 0xdd, 0x78, 0, 0, 0, 1, 0, 0, 0, 0, 0x50, 0x12, 0x72, 0x10, 0xab, 0xcd, 0, 0,
}
if !normalizeIPv4PacketChecksums(packet) {
t.Fatal("normalize returned false")
}
if got := checksum(packet[:20]); got != 0 {
t.Fatalf("ip checksum verification = %#x, want 0", got)
}
tcp := packet[20:]
pseudo := make([]byte, 12+len(tcp))
copy(pseudo[0:4], packet[12:16])
copy(pseudo[4:8], packet[16:20])
pseudo[9] = 6
binary.BigEndian.PutUint16(pseudo[10:12], uint16(len(tcp)))
copy(pseudo[12:], tcp)
if got := checksum(pseudo); got != 0 {
t.Fatalf("tcp checksum verification = %#x, want 0", got)
}
}
func TestNormalizeIPv4PacketChecksumsRepairsUDP(t *testing.T) {
packet := []byte{
0x45, 0x00, 0x00, 0x20, 0, 0, 0x40, 0, 0x40, 0x11, 0x12, 0x34, 10, 77, 0, 2, 192, 168, 200, 210,
0x30, 0x39, 0x00, 0x35, 0x00, 0x0c, 0xab, 0xcd, 0xde, 0xad, 0xbe, 0xef,
}
if !normalizeIPv4PacketChecksums(packet) {
t.Fatal("normalize returned false")
}
if got := checksum(packet[:20]); got != 0 {
t.Fatalf("ip checksum verification = %#x, want 0", got)
}
udp := packet[20:]
pseudo := make([]byte, 12+len(udp))
copy(pseudo[0:4], packet[12:16])
copy(pseudo[4:8], packet[16:20])
pseudo[9] = 17
binary.BigEndian.PutUint16(pseudo[10:12], uint16(len(udp)))
copy(pseudo[12:], udp)
if got := checksum(pseudo); got != 0 {
t.Fatalf("udp checksum verification = %#x, want 0", got)
}
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,495 @@
package vpnruntime
import (
"context"
"fmt"
"io"
"log"
"net"
"sync"
"sync/atomic"
"time"
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/client"
)
type Gateway struct {
API *client.Client
Transport PacketTransport
ClusterID string
VPNConnectionID string
InterfaceName string
AddressCIDR string
RouteCIDR string
PollTimeout time.Duration
mu sync.Mutex
running bool
lastErr error
cancel context.CancelFunc
clientToGatewayBatches atomic.Uint64
clientToGatewayPackets atomic.Uint64
clientToGatewayBytes atomic.Uint64
gatewayToClientBatches atomic.Uint64
gatewayToClientPackets atomic.Uint64
gatewayToClientBytes atomic.Uint64
tunReadPackets atomic.Uint64
tunReadBytes atomic.Uint64
tunWritePackets atomic.Uint64
tunWriteBytes atomic.Uint64
uploadQueueDrops atomic.Uint64
downloadErrors atomic.Uint64
uploadErrors atomic.Uint64
lastClientToGatewayPacket string
lastGatewayToClientPacket string
lastRuntimeActivityAt time.Time
}
const (
vpnGatewayBatchMaxPackets = 2048
vpnGatewayBatchMaxBytes = 8 * 1024 * 1024
vpnGatewayBatchFlushTimeout = 3 * time.Millisecond
)
type readWriteCloser interface {
io.Reader
io.Writer
io.Closer
}
type PacketTransport interface {
SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error
ReceiveGatewayPacketBatch(ctx context.Context, timeout time.Duration) ([][]byte, error)
}
type BackendPacketTransport struct {
API *client.Client
ClusterID string
VPNConnectionID string
}
func (t BackendPacketTransport) SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error {
return t.API.SendVPNGatewayPacketBatch(ctx, t.ClusterID, t.VPNConnectionID, packets)
}
func (t BackendPacketTransport) ReceiveGatewayPacketBatch(ctx context.Context, timeout time.Duration) ([][]byte, error) {
return t.API.ReceiveVPNGatewayPacketBatch(ctx, t.ClusterID, t.VPNConnectionID, timeout)
}
func (g *Gateway) EnsureStarted(ctx context.Context) error {
g.mu.Lock()
if g.running {
g.mu.Unlock()
return nil
}
g.mu.Unlock()
if err := g.normalize(); err != nil {
g.setStopped(err)
return err
}
tun, err := openGatewayTun(g.InterfaceName, g.AddressCIDR, g.RouteCIDR)
if err != nil {
g.setStopped(err)
return err
}
runCtx, cancel := context.WithCancel(ctx)
g.mu.Lock()
if g.running {
g.mu.Unlock()
cancel()
_ = tun.Close()
return nil
}
g.running = true
g.lastErr = nil
g.cancel = cancel
g.mu.Unlock()
go func() {
if err := g.run(runCtx, tun); err != nil && runCtx.Err() == nil {
log.Printf("vpn gateway runtime stopped: vpn_connection_id=%s error=%v", g.VPNConnectionID, err)
g.setStopped(err)
return
}
g.setStopped(runCtx.Err())
}()
return nil
}
func (g *Gateway) Stop() {
g.mu.Lock()
cancel := g.cancel
g.cancel = nil
g.running = false
g.mu.Unlock()
if cancel != nil {
cancel()
}
}
func (g *Gateway) Status() (bool, string) {
g.mu.Lock()
defer g.mu.Unlock()
if g.lastErr != nil {
return g.running, g.lastErr.Error()
}
return g.running, ""
}
func (g *Gateway) IsReadyForConnection(vpnConnectionID string) bool {
g.mu.Lock()
defer g.mu.Unlock()
return g.running && g.VPNConnectionID == vpnConnectionID && vpnConnectionID != ""
}
func (g *Gateway) Snapshot() map[string]any {
g.mu.Lock()
running := g.running
lastErr := ""
if g.lastErr != nil {
lastErr = g.lastErr.Error()
}
lastClientToGatewayPacket := g.lastClientToGatewayPacket
lastGatewayToClientPacket := g.lastGatewayToClientPacket
lastRuntimeActivityAt := g.lastRuntimeActivityAt
g.mu.Unlock()
out := map[string]any{
"running": running,
"transport": g.transportName(),
"poll_timeout_ms": g.PollTimeout.Milliseconds(),
"client_to_gateway_batches": g.clientToGatewayBatches.Load(),
"client_to_gateway_packets": g.clientToGatewayPackets.Load(),
"client_to_gateway_bytes": g.clientToGatewayBytes.Load(),
"gateway_to_client_batches": g.gatewayToClientBatches.Load(),
"gateway_to_client_packets": g.gatewayToClientPackets.Load(),
"gateway_to_client_bytes": g.gatewayToClientBytes.Load(),
"tun_read_packets": g.tunReadPackets.Load(),
"tun_read_bytes": g.tunReadBytes.Load(),
"tun_write_packets": g.tunWritePackets.Load(),
"tun_write_bytes": g.tunWriteBytes.Load(),
"upload_queue_drops": g.uploadQueueDrops.Load(),
"download_errors": g.downloadErrors.Load(),
"upload_errors": g.uploadErrors.Load(),
"last_client_to_gateway": lastClientToGatewayPacket,
"last_gateway_to_client": lastGatewayToClientPacket,
}
if lastErr != "" {
out["last_error"] = lastErr
}
if !lastRuntimeActivityAt.IsZero() {
out["last_runtime_activity_at"] = lastRuntimeActivityAt.UTC().Format(time.RFC3339Nano)
}
return out
}
func (g *Gateway) transportName() string {
switch g.Transport.(type) {
case *FabricPacketTransport:
return "fabric_mesh"
case *LocalPacketTransport:
return "local_fabric_inbox"
case *AdaptivePacketTransport:
return "adaptive_fabric_backend"
case BackendPacketTransport:
return "backend_http_packet_relay"
default:
if g.Transport == nil {
return "none"
}
return fmt.Sprintf("%T", g.Transport)
}
}
func (g *Gateway) setStopped(err error) {
g.mu.Lock()
defer g.mu.Unlock()
g.running = false
g.lastErr = err
g.cancel = nil
}
func (g *Gateway) normalize() error {
if g.Transport == nil {
if g.API == nil {
return fmt.Errorf("api client or packet transport is required")
}
g.Transport = BackendPacketTransport{
API: g.API,
ClusterID: g.ClusterID,
VPNConnectionID: g.VPNConnectionID,
}
}
if g.ClusterID == "" || g.VPNConnectionID == "" {
return fmt.Errorf("cluster id and vpn connection id are required")
}
if g.InterfaceName == "" {
g.InterfaceName = "rapvpn0"
}
if g.AddressCIDR == "" {
g.AddressCIDR = "10.77.0.1/24"
}
if g.RouteCIDR == "" {
g.RouteCIDR = "10.77.0.0/24"
}
if g.PollTimeout <= 0 {
g.PollTimeout = 25 * time.Second
}
return nil
}
func (g *Gateway) run(ctx context.Context, tun readWriteCloser) error {
defer tun.Close()
errCh := make(chan error, 2)
go func() { errCh <- g.copyGatewayToClient(ctx, tun) }()
go func() { errCh <- g.copyClientToGateway(ctx, tun) }()
select {
case <-ctx.Done():
return ctx.Err()
case err := <-errCh:
return err
}
}
func (g *Gateway) copyGatewayToClient(ctx context.Context, tun io.Reader) error {
packets := make(chan []byte, 32768)
errCh := make(chan error, 1)
go func() {
errCh <- g.uploadGatewayPackets(ctx, packets)
}()
buffer := make([]byte, 65535)
for {
select {
case <-ctx.Done():
return ctx.Err()
case err := <-errCh:
if err != nil {
return err
}
default:
}
n, err := tun.Read(buffer)
if err != nil {
return err
}
if n <= 0 {
continue
}
packet := append([]byte(nil), buffer[:n]...)
normalizeIPv4PacketChecksums(packet)
g.recordTunRead(packet)
select {
case packets <- packet:
default:
g.uploadQueueDrops.Add(1)
log.Printf("vpn gateway packet upload queue full; dropping packet: vpn_connection_id=%s", g.VPNConnectionID)
}
}
}
func (g *Gateway) uploadGatewayPackets(ctx context.Context, packets <-chan []byte) error {
batch := make([][]byte, 0, vpnGatewayBatchMaxPackets)
batchBytes := 0
timer := time.NewTimer(time.Hour)
if !timer.Stop() {
<-timer.C
}
timerActive := false
flush := func() {
if len(batch) == 0 {
return
}
packetCount := len(batch)
byteCount := packetBytesTotal(batch)
if err := g.Transport.SendGatewayPacketBatch(ctx, batch); err != nil {
g.uploadErrors.Add(1)
log.Printf("vpn gateway packet batch upload failed: vpn_connection_id=%s packets=%d error=%v", g.VPNConnectionID, len(batch), err)
} else {
g.recordGatewayToClientBatch(packetCount, byteCount, batch[0])
}
for i := range batch {
batch[i] = nil
}
batch = batch[:0]
batchBytes = 0
}
for {
if len(batch) == 0 && timerActive {
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
timerActive = false
}
select {
case <-ctx.Done():
flush()
return ctx.Err()
case packet := <-packets:
packetBytes := len(packet)
if packetBytes <= 0 {
continue
}
packetFrameSize := 4 + packetBytes
if len(batch) > 0 {
if len(batch) >= vpnGatewayBatchMaxPackets || batchBytes+packetFrameSize > vpnGatewayBatchMaxBytes {
flush()
}
}
batch = append(batch, packet)
batchBytes += packetFrameSize
if len(batch) >= vpnGatewayBatchMaxPackets || batchBytes >= vpnGatewayBatchMaxBytes {
flush()
continue
}
if !timerActive {
timer.Reset(vpnGatewayBatchFlushTimeout)
timerActive = true
}
case <-timer.C:
timerActive = false
flush()
}
}
}
func (g *Gateway) copyClientToGateway(ctx context.Context, tun io.Writer) error {
for {
packets, err := g.Transport.ReceiveGatewayPacketBatch(ctx, g.PollTimeout)
if err != nil {
log.Printf("vpn gateway packet download failed: vpn_connection_id=%s error=%v", g.VPNConnectionID, err)
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(time.Second):
}
continue
}
if len(packets) == 0 {
continue
}
g.recordClientToGatewayBatch(len(packets), packetBytesTotal(packets), packets[0])
for _, packet := range packets {
normalizeIPv4PacketChecksums(packet)
if _, err := tun.Write(packet); err != nil {
g.downloadErrors.Add(1)
return err
}
g.recordTunWrite(packet)
}
}
}
func (g *Gateway) recordClientToGatewayBatch(packetCount int, byteCount int, first []byte) {
next := g.clientToGatewayBatches.Add(1)
g.clientToGatewayPackets.Add(uint64(packetCount))
g.clientToGatewayBytes.Add(uint64(byteCount))
summary := summarizePacket(first)
g.mu.Lock()
g.lastClientToGatewayPacket = summary
g.lastRuntimeActivityAt = time.Now().UTC()
g.mu.Unlock()
if next <= 5 {
log.Printf(
"vpn gateway client_to_gateway batch received: vpn_connection_id=%s batch=%d packets=%d bytes=%d first=%s",
g.VPNConnectionID,
next,
packetCount,
byteCount,
summary,
)
}
}
func (g *Gateway) recordGatewayToClientBatch(packetCount int, byteCount int, first []byte) {
next := g.gatewayToClientBatches.Add(1)
g.gatewayToClientPackets.Add(uint64(packetCount))
g.gatewayToClientBytes.Add(uint64(byteCount))
summary := summarizePacket(first)
g.mu.Lock()
g.lastGatewayToClientPacket = summary
g.lastRuntimeActivityAt = time.Now().UTC()
g.mu.Unlock()
if next <= 5 {
log.Printf(
"vpn gateway gateway_to_client batch uploaded: vpn_connection_id=%s batch=%d packets=%d bytes=%d first=%s",
g.VPNConnectionID,
next,
packetCount,
byteCount,
summary,
)
}
}
func (g *Gateway) recordTunWrite(packet []byte) {
next := g.tunWritePackets.Add(1)
g.tunWriteBytes.Add(uint64(len(packet)))
if next <= 5 {
log.Printf("vpn gateway packet written to tun: vpn_connection_id=%s packet=%d bytes=%d summary=%s", g.VPNConnectionID, next, len(packet), summarizePacket(packet))
}
}
func (g *Gateway) recordTunRead(packet []byte) {
next := g.tunReadPackets.Add(1)
g.tunReadBytes.Add(uint64(len(packet)))
if next <= 5 {
log.Printf("vpn gateway packet read from tun: vpn_connection_id=%s packet=%d bytes=%d summary=%s", g.VPNConnectionID, next, len(packet), summarizePacket(packet))
}
}
func packetBytesTotal(packets [][]byte) int {
total := 0
for _, packet := range packets {
total += len(packet)
}
return total
}
func summarizePacket(packet []byte) string {
if len(packet) < 1 {
return "empty"
}
version := packet[0] >> 4
switch version {
case 4:
return summarizeIPv4(packet)
case 6:
return summarizeIPv6(packet)
default:
return fmt.Sprintf("ip_version=%d bytes=%d", version, len(packet))
}
}
func summarizeIPv4(packet []byte) string {
if len(packet) < 20 {
return fmt.Sprintf("ipv4 truncated bytes=%d", len(packet))
}
ihl := int(packet[0]&0x0f) * 4
if ihl < 20 || len(packet) < ihl {
return fmt.Sprintf("ipv4 invalid_ihl=%d bytes=%d", ihl, len(packet))
}
proto := packet[9]
src := net.IP(packet[12:16]).String()
dst := net.IP(packet[16:20]).String()
return fmt.Sprintf("ipv4 %s -> %s proto=%d bytes=%d", src, dst, proto, len(packet))
}
func summarizeIPv6(packet []byte) string {
if len(packet) < 40 {
return fmt.Sprintf("ipv6 truncated bytes=%d", len(packet))
}
nextHeader := packet[6]
src := net.IP(packet[8:24]).String()
dst := net.IP(packet[24:40]).String()
return fmt.Sprintf("ipv6 %s -> %s next=%d bytes=%d", src, dst, nextHeader, len(packet))
}
@@ -0,0 +1,206 @@
//go:build linux
package vpnruntime
import (
"errors"
"fmt"
"net"
"os"
"os/exec"
"strings"
"syscall"
"unsafe"
)
const (
tunDevicePath = "/dev/net/tun"
iffTun = 0x0001
iffNoPI = 0x1000
tunSetIFF = 0x400454ca
ifNameSize = 16
)
type tunDevice struct {
file *os.File
fd int
name string
}
func openGatewayTun(name, addressCIDR, routeCIDR string) (*tunDevice, error) {
dev, err := openGatewayTunDevice(name)
if errors.Is(err, syscall.EBUSY) {
cleanupStaleGatewayInterface(name)
dev, err = openGatewayTunDevice(name)
}
if err != nil {
return nil, err
}
if err := configureGatewayInterface(name, addressCIDR, routeCIDR); err != nil {
_ = dev.Close()
return nil, err
}
return dev, nil
}
func openGatewayTunDevice(name string) (*tunDevice, error) {
file, err := os.OpenFile(tunDevicePath, os.O_RDWR, 0)
if err != nil {
return nil, fmt.Errorf("open %s: %w", tunDevicePath, err)
}
ifr := make([]byte, 40)
copy(ifr[:ifNameSize], []byte(name))
*(*uint16)(unsafe.Pointer(&ifr[ifNameSize])) = iffTun | iffNoPI
if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, file.Fd(), uintptr(tunSetIFF), uintptr(unsafe.Pointer(&ifr[0]))); errno != 0 {
file.Close()
return nil, fmt.Errorf("configure tun %s: %w", name, errno)
}
return &tunDevice{file: file, fd: int(file.Fd()), name: name}, nil
}
func cleanupStaleGatewayInterface(name string) {
if strings.TrimSpace(name) == "" {
return
}
_ = runCommand("ip", "link", "set", name, "down")
_ = runCommand("ip", "link", "delete", name)
}
func (d *tunDevice) Read(packet []byte) (int, error) {
return syscall.Read(d.fd, packet)
}
func (d *tunDevice) Write(packet []byte) (int, error) {
return syscall.Write(d.fd, packet)
}
func (d *tunDevice) Close() error {
_ = runCommand("ip", "link", "set", d.name, "down")
return d.file.Close()
}
func configureGatewayInterface(name, addressCIDR, routeCIDR string) error {
if _, _, err := net.ParseCIDR(addressCIDR); err != nil {
return fmt.Errorf("invalid vpn gateway address %q: %w", addressCIDR, err)
}
if err := runCommand("ip", "addr", "replace", addressCIDR, "dev", name); err != nil {
return err
}
if err := runCommand("ip", "link", "set", name, "up"); err != nil {
return err
}
if err := enableIPv4Forwarding(); err != nil {
return err
}
if err := disableReversePathFiltering(name); err != nil {
return err
}
if err := ensureForwardingRules(name); err != nil {
return err
}
if err := ensureMasqueradeRules(routeCIDR); err != nil {
return err
}
if err := ensureMSSClampRule(name); err != nil {
return err
}
return nil
}
func ensureMasqueradeRules(routeCIDR string) error {
egress, _ := defaultIPv4Interface()
if egress != "" {
if err := ensureIPTablesRule("nat", "POSTROUTING", "-s", routeCIDR, "-o", egress, "-j", "MASQUERADE"); err != nil {
return err
}
}
return ensureIPTablesRule("nat", "POSTROUTING", "-s", routeCIDR, "-j", "MASQUERADE")
}
func ensureMSSClampRule(interfaceName string) error {
err := ensureIPTablesRule("mangle", "FORWARD", "-i", interfaceName, "-p", "tcp", "--tcp-flags", "SYN,RST", "SYN", "-j", "TCPMSS", "--clamp-mss-to-pmtu")
if err == nil {
return nil
}
return nil
}
func defaultIPv4Interface() (string, error) {
out, err := exec.Command("ip", "-o", "-4", "route", "show", "default").CombinedOutput()
if err != nil {
return "", fmt.Errorf("ip default route failed: %w: %s", err, string(out))
}
fields := strings.Fields(string(out))
for i := 0; i+1 < len(fields); i++ {
if fields[i] == "dev" {
return fields[i+1], nil
}
}
return "", nil
}
func ensureForwardingRules(interfaceName string) error {
if err := ensureIPTablesRule("filter", "FORWARD", "-i", interfaceName, "-j", "ACCEPT"); err != nil {
return err
}
err := ensureIPTablesRule("filter", "FORWARD", "-o", interfaceName, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT")
if err == nil {
return nil
}
return ensureIPTablesRule("filter", "FORWARD", "-o", interfaceName, "-j", "ACCEPT")
}
func ensureIPTablesRule(table, chain string, rule ...string) error {
checkArgs := append([]string{"-t", table, "-C", chain}, rule...)
if err := runCommand("iptables", checkArgs...); err == nil {
return nil
}
addArgs := append([]string{"-t", table, "-I", chain, "1"}, rule...)
return runCommand("iptables", addArgs...)
}
func enableIPv4Forwarding() error {
if current, err := os.ReadFile("/proc/sys/net/ipv4/ip_forward"); err == nil && len(current) > 0 && current[0] == '1' {
return nil
}
if err := os.WriteFile("/proc/sys/net/ipv4/ip_forward", []byte("1\n"), 0o644); err == nil {
return nil
}
return runCommand("sysctl", "-w", "net.ipv4.ip_forward=1")
}
func disableReversePathFiltering(interfaceName string) error {
keys := []string{"all", "default", interfaceName}
if entries, err := os.ReadDir("/proc/sys/net/ipv4/conf"); err == nil {
for _, entry := range entries {
if entry.IsDir() {
keys = append(keys, entry.Name())
}
}
}
seen := make(map[string]bool)
for _, key := range keys {
if seen[key] {
continue
}
seen[key] = true
path := fmt.Sprintf("/proc/sys/net/ipv4/conf/%s/rp_filter", key)
if _, err := os.Stat(path); err != nil {
continue
}
if err := os.WriteFile(path, []byte("0\n"), 0o644); err != nil {
if sysctlErr := runCommand("sysctl", "-w", fmt.Sprintf("net.ipv4.conf.%s.rp_filter=0", key)); sysctlErr != nil {
return sysctlErr
}
}
}
return nil
}
func runCommand(name string, args ...string) error {
cmd := exec.Command(name, args...)
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("%s %v failed: %w: %s", name, args, err, string(out))
}
return nil
}
@@ -0,0 +1,23 @@
//go:build !linux && !windows
package vpnruntime
import "fmt"
type tunDevice struct{}
func openGatewayTun(name, addressCIDR, routeCIDR string) (*tunDevice, error) {
return nil, fmt.Errorf("vpn gateway runtime is currently supported only on linux")
}
func (d *tunDevice) Read(packet []byte) (int, error) {
return 0, fmt.Errorf("vpn gateway runtime is currently supported only on linux")
}
func (d *tunDevice) Write(packet []byte) (int, error) {
return 0, fmt.Errorf("vpn gateway runtime is currently supported only on linux")
}
func (d *tunDevice) Close() error {
return nil
}