diff --git a/agents/rap-node-agent/internal/client/client.go b/agents/rap-node-agent/internal/client/client.go index 2bced1c..d708639 100644 --- a/agents/rap-node-agent/internal/client/client.go +++ b/agents/rap-node-agent/internal/client/client.go @@ -3,9 +3,12 @@ package client import ( "bytes" "context" + "encoding/binary" "encoding/json" "fmt" + "io" "net/http" + "net/url" "time" ) @@ -64,6 +67,95 @@ type HeartbeatRequest struct { type HeartbeatResponse struct { Heartbeat json.RawMessage `json:"heartbeat"` TestingFlags EffectiveTestingFlags `json:"testing_flags"` + UpdateHint *NodeUpdateHint `json:"update_hint,omitempty"` +} + +type NodeUpdateHint struct { + SchemaVersion string `json:"schema_version"` + Generation string `json:"generation,omitempty"` + CheckNow bool `json:"check_now"` + Products []string `json:"products,omitempty"` + Reason string `json:"reason,omitempty"` + DeliveryMode string `json:"delivery_mode,omitempty"` + SubscriptionStatus string `json:"subscription_status,omitempty"` + UpdateService *NodeUpdateServiceAssignment `json:"update_service,omitempty"` + FallbackPollSeconds int `json:"fallback_poll_seconds,omitempty"` +} + +type NodeUpdateServiceAssignment struct { + SchemaVersion string `json:"schema_version"` + NodeID string `json:"node_id,omitempty"` + NodeName string `json:"node_name,omitempty"` + Endpoint string `json:"endpoint,omitempty"` + Region string `json:"region,omitempty"` + Status string `json:"status"` + Reason string `json:"reason,omitempty"` + AssignedAt time.Time `json:"assigned_at"` + ExpiresAt time.Time `json:"expires_at"` +} + +type NodeUpdatePlanRequest struct { + Product string + CurrentVersion string + OS string + Arch string + InstallType string + Channel string +} + +type NodeUpdatePlanResponse struct { + Plan NodeUpdatePlan `json:"node_update_plan"` +} + +type NodeUpdatePlan struct { + SchemaVersion string `json:"schema_version"` + ClusterID string `json:"cluster_id"` + NodeID string `json:"node_id"` + Product string `json:"product"` + CurrentVersion string `json:"current_version,omitempty"` + Action string `json:"action"` + Reason string `json:"reason"` + TargetVersion string `json:"target_version,omitempty"` + Channel string `json:"channel,omitempty"` + Strategy string `json:"strategy,omitempty"` + RollbackAllowed bool `json:"rollback_allowed"` + HealthWindowSec int `json:"health_window_seconds,omitempty"` + Artifact *ReleaseArtifact `json:"artifact,omitempty"` + AuthorityPayload json.RawMessage `json:"authority_payload,omitempty"` + AuthoritySignature *ClusterSignature `json:"authority_signature,omitempty"` + ProductionForwarding bool `json:"production_forwarding"` +} + +type ReleaseArtifact struct { + ID string `json:"id"` + ReleaseID string `json:"release_id"` + ClusterID string `json:"cluster_id"` + Product string `json:"product"` + Version string `json:"version"` + OS string `json:"os"` + Arch string `json:"arch"` + InstallType string `json:"install_type"` + Kind string `json:"kind"` + URL string `json:"url"` + URLs []string `json:"urls,omitempty"` + SHA256 string `json:"sha256"` + SizeBytes int64 `json:"size_bytes"` + Signature *string `json:"signature,omitempty"` + Metadata json.RawMessage `json:"metadata"` + CreatedAt time.Time `json:"created_at"` +} + +type NodeUpdateStatusRequest struct { + Product string `json:"product"` + CurrentVersion string `json:"current_version,omitempty"` + TargetVersion string `json:"target_version,omitempty"` + Phase string `json:"phase"` + Status string `json:"status"` + AttemptID string `json:"attempt_id,omitempty"` + ErrorMessage *string `json:"error_message,omitempty"` + RollbackVersion *string `json:"rollback_version,omitempty"` + Payload map[string]any `json:"payload,omitempty"` + ObservedAt time.Time `json:"observed_at,omitempty"` } type EffectiveTestingFlags struct { @@ -91,6 +183,45 @@ type WorkloadStatusRequest struct { StatusPayload map[string]any `json:"status_payload"` } +type NodeVPNAssignmentLease struct { + LeaseID string `json:"lease_id"` + OwnerNodeID string `json:"owner_node_id"` + LeaseGeneration int64 `json:"lease_generation"` + Status string `json:"status"` + RenewedAt time.Time `json:"renewed_at"` + ExpiresAt time.Time `json:"expires_at"` +} + +type NodeVPNAssignment struct { + VPNConnectionID string `json:"vpn_connection_id"` + ClusterID string `json:"cluster_id"` + OrganizationID string `json:"organization_id"` + Name string `json:"name"` + TargetEndpoint json.RawMessage `json:"target_endpoint"` + ProtocolFamily string `json:"protocol_family"` + Mode string `json:"mode"` + DesiredState string `json:"desired_state"` + RoutingUsage json.RawMessage `json:"routing_usage"` + RoutePolicy json.RawMessage `json:"route_policy"` + QoSPolicy json.RawMessage `json:"qos_policy"` + PlacementPolicy json.RawMessage `json:"placement_policy"` + Status string `json:"status"` + HasCredentialRef bool `json:"has_credential_ref"` + AssignmentReason string `json:"assignment_reason"` + ActiveLease *NodeVPNAssignmentLease `json:"active_lease,omitempty"` + UpdatedAt time.Time `json:"updated_at"` +} + +type NodeVPNAssignmentStatusRequest struct { + ObservedStatus string `json:"observed_status"` + StatusPayload map[string]any `json:"status_payload"` + ObservedAt time.Time `json:"observed_at,omitempty"` +} + +type NodeVPNAssignmentLeaseRenewRequest struct { + TTLSeconds int `json:"ttl_seconds"` +} + type MeshLinkObservationRequest struct { SourceNodeID string `json:"source_node_id"` TargetNodeID string `json:"target_node_id"` @@ -129,26 +260,155 @@ type SyntheticMeshRouteConfig struct { } type SyntheticMeshConfig struct { - Enabled bool `json:"enabled"` - SchemaVersion string `json:"schema_version"` - ClusterID string `json:"cluster_id"` - LocalNodeID string `json:"local_node_id"` - AuthorityRequired bool `json:"authority_required"` - ClusterAuthority *ClusterAuthorityDescriptor `json:"cluster_authority,omitempty"` - AuthorityPayload json.RawMessage `json:"authority_payload,omitempty"` - AuthoritySignature *ClusterSignature `json:"authority_signature,omitempty"` - ConfigVersion string `json:"config_version,omitempty"` - PeerDirectoryVersion string `json:"peer_directory_version,omitempty"` - PolicyVersion string `json:"policy_version,omitempty"` - PeerEndpoints map[string]string `json:"peer_endpoints"` - PeerEndpointCandidates map[string][]PeerEndpointCandidate `json:"peer_endpoint_candidates,omitempty"` - PeerDirectory []PeerDirectoryEntry `json:"peer_directory,omitempty"` - RecoverySeeds []PeerRecoverySeed `json:"recovery_seeds,omitempty"` - RendezvousLeases []PeerRendezvousLease `json:"rendezvous_leases,omitempty"` - RendezvousRelayPolicy *RendezvousRelayPolicyReport `json:"rendezvous_relay_policy,omitempty"` - RoutePathDecisions *RoutePathDecisionReport `json:"route_path_decisions,omitempty"` - Routes []SyntheticMeshRouteConfig `json:"routes"` - ProductionForwarding bool `json:"production_forwarding"` + Enabled bool `json:"enabled"` + SchemaVersion string `json:"schema_version"` + ClusterID string `json:"cluster_id"` + LocalNodeID string `json:"local_node_id"` + AuthorityRequired bool `json:"authority_required"` + ClusterAuthority *ClusterAuthorityDescriptor `json:"cluster_authority,omitempty"` + AuthorityPayload json.RawMessage `json:"authority_payload,omitempty"` + AuthoritySignature *ClusterSignature `json:"authority_signature,omitempty"` + ConfigVersion string `json:"config_version,omitempty"` + PeerDirectoryVersion string `json:"peer_directory_version,omitempty"` + PolicyVersion string `json:"policy_version,omitempty"` + PeerEndpoints map[string]string `json:"peer_endpoints"` + PeerEndpointCandidates map[string][]PeerEndpointCandidate `json:"peer_endpoint_candidates,omitempty"` + PeerDirectory []PeerDirectoryEntry `json:"peer_directory,omitempty"` + RecoverySeeds []PeerRecoverySeed `json:"recovery_seeds,omitempty"` + RendezvousLeases []PeerRendezvousLease `json:"rendezvous_leases,omitempty"` + RendezvousRelayPolicy *RendezvousRelayPolicyReport `json:"rendezvous_relay_policy,omitempty"` + RoutePathDecisions *RoutePathDecisionReport `json:"route_path_decisions,omitempty"` + ServiceChannelFeedback *FabricServiceChannelFeedbackReport `json:"service_channel_route_feedback,omitempty"` + ServiceChannelAdaptivePolicy *FabricServiceChannelAdaptivePolicy `json:"service_channel_adaptive_policy,omitempty"` + ServiceChannelRemediationCommands []FabricServiceChannelRemediationCommand `json:"service_channel_remediation_commands,omitempty"` + MeshListener *MeshListenerConfig `json:"mesh_listener,omitempty"` + Routes []SyntheticMeshRouteConfig `json:"routes"` + ProductionForwarding bool `json:"production_forwarding"` +} + +type FabricServiceChannelRemediationCommand struct { + SchemaVersion string `json:"schema_version"` + CommandID string `json:"command_id"` + Action string `json:"action"` + ClusterID string `json:"cluster_id"` + ChannelID string `json:"channel_id"` + ResourceID string `json:"resource_id,omitempty"` + ServiceClass string `json:"service_class"` + EntryNodeID string `json:"entry_node_id,omitempty"` + ExitNodeID string `json:"exit_node_id,omitempty"` + PrimaryRouteID string `json:"primary_route_id,omitempty"` + ReplacementRouteID string `json:"replacement_route_id,omitempty"` + ReplacementRouteStatus string `json:"replacement_route_status,omitempty"` + PoolPolicyFingerprint string `json:"pool_policy_fingerprint,omitempty"` + GuardStatus string `json:"guard_status,omitempty"` + GuardReason string `json:"guard_reason,omitempty"` + ExecutionStatus string `json:"execution_status,omitempty"` + ExecutionReason string `json:"execution_reason,omitempty"` + ExecutionGeneration string `json:"execution_generation,omitempty"` + ExecutionObservedAt string `json:"execution_observed_at,omitempty"` + Reason string `json:"reason,omitempty"` + OperatorAction string `json:"operator_action,omitempty"` + IssuedAt time.Time `json:"issued_at"` + ExpiresAt time.Time `json:"expires_at"` +} + +type FabricServiceChannelFeedbackReport struct { + SchemaVersion string `json:"schema_version"` + GeneratedAt time.Time `json:"generated_at"` + FeedbackMaxAgeSeconds int `json:"feedback_max_age_seconds"` + RecoveryPolicy *FabricServiceChannelRecoveryPolicy `json:"recovery_policy,omitempty"` + MissingProvenanceCount int `json:"missing_provenance_count,omitempty"` + StalePolicyCount int `json:"stale_policy_count,omitempty"` + StaleGenerationCount int `json:"stale_generation_count,omitempty"` + ObservationCount int `json:"observation_count"` + FencedRouteCount int `json:"fenced_route_count"` + DegradedRouteCount int `json:"degraded_route_count"` + HealthyRouteCount int `json:"healthy_route_count"` + RecoveredRouteCount int `json:"recovered_route_count,omitempty"` + RecoveryHysteresisCount int `json:"recovery_hysteresis_count,omitempty"` + RecoveryPromotedCount int `json:"recovery_promoted_count,omitempty"` + RecoveryDemotedCount int `json:"recovery_demoted_count,omitempty"` + Observations []FabricServiceChannelFeedbackObservation `json:"observations,omitempty"` +} + +type FabricServiceChannelAdaptivePolicy struct { + SchemaVersion string `json:"schema_version"` + Fingerprint string `json:"fingerprint,omitempty"` + MaxParallelWindow int `json:"max_parallel_window"` + BulkPressureChannelThreshold int `json:"bulk_pressure_channel_threshold"` + QueuePressureHighWatermark int `json:"queue_pressure_high_watermark"` + QueuePressureMaxInFlight int `json:"queue_pressure_max_in_flight"` + ClassWindows map[string]int `json:"class_windows"` + Source string `json:"source"` + UpdatedByUserID *string `json:"updated_by_user_id,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` + ControlPlaneOnly bool `json:"control_plane_only"` + ProductionForwarding bool `json:"production_forwarding"` +} + +type FabricServiceChannelRecoveryPolicy struct { + SchemaVersion string `json:"schema_version"` + Fingerprint string `json:"fingerprint,omitempty"` + HysteresisPenalty int `json:"hysteresis_penalty"` + PromotionMinSamples int `json:"promotion_min_samples"` + DemotionFailureThreshold int `json:"demotion_failure_threshold"` + DemotionDropThreshold int `json:"demotion_drop_threshold"` + DemotionSlowThreshold int `json:"demotion_slow_threshold"` + DemotionRebuildEnabled bool `json:"demotion_rebuild_enabled"` + DemotionFencedEnabled bool `json:"demotion_fenced_enabled"` + Source string `json:"source"` + UpdatedByUserID *string `json:"updated_by_user_id,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` + ControlPlaneOnly bool `json:"control_plane_only"` + ProductionForwarding bool `json:"production_forwarding"` +} + +type FabricServiceChannelFeedbackObservation struct { + ID string `json:"id,omitempty"` + ClusterID string `json:"cluster_id"` + ReporterNodeID string `json:"reporter_node_id"` + RouteID string `json:"route_id"` + ServiceClass string `json:"service_class"` + FeedbackStatus string `json:"feedback_status"` + ScoreAdjustment int `json:"score_adjustment"` + EffectiveScoreAdjustment int `json:"effective_score_adjustment,omitempty"` + Reasons []string `json:"reasons,omitempty"` + LastError string `json:"last_error,omitempty"` + ConsecutiveFailures int `json:"consecutive_failures,omitempty"` + StallCount int `json:"stall_count,omitempty"` + LastSendDurationMs int64 `json:"last_send_duration_ms,omitempty"` + RecoveryState string `json:"recovery_state,omitempty"` + ObservedPolicyFingerprint string `json:"observed_policy_fingerprint,omitempty"` + EffectivePolicyFingerprint string `json:"effective_policy_fingerprint,omitempty"` + ObservedRouteGeneration string `json:"observed_route_generation,omitempty"` + EffectiveRouteGeneration string `json:"effective_route_generation,omitempty"` + ProvenanceMissing bool `json:"provenance_missing,omitempty"` + StalePolicy bool `json:"stale_policy,omitempty"` + StaleGeneration bool `json:"stale_generation,omitempty"` + StaleReason string `json:"stale_reason,omitempty"` + Payload json.RawMessage `json:"payload"` + ObservedAt time.Time `json:"observed_at"` + ExpiresAt time.Time `json:"expires_at"` +} + +type MeshListenerConfig struct { + SchemaVersion string `json:"schema_version"` + Source string `json:"source"` + DesiredState string `json:"desired_state"` + ListenAddr string `json:"listen_addr"` + ListenPortMode string `json:"listen_port_mode"` + AutoPortStart int `json:"auto_port_start,omitempty"` + AutoPortEnd int `json:"auto_port_end,omitempty"` + AdvertiseEndpoint string `json:"advertise_endpoint,omitempty"` + AdvertiseTransport string `json:"advertise_transport,omitempty"` + ConnectivityMode string `json:"connectivity_mode,omitempty"` + NATType string `json:"nat_type,omitempty"` + Region string `json:"region,omitempty"` + ConfigVersion string `json:"config_version,omitempty"` + UpdatedByUserID string `json:"updated_by_user_id,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` + ControlPlaneOnly bool `json:"control_plane_only"` + ProductionForwarding bool `json:"production_forwarding"` } type ClusterAuthorityDescriptor struct { @@ -233,6 +493,11 @@ type RendezvousRelayPolicyReport struct { type RoutePathDecision struct { DecisionID string `json:"decision_id"` RouteID string `json:"route_id"` + ReplacementRouteID string `json:"replacement_route_id,omitempty"` + RebuildRequestID string `json:"rebuild_request_id,omitempty"` + RebuildStatus string `json:"rebuild_status,omitempty"` + RebuildReason string `json:"rebuild_reason,omitempty"` + RebuildAttempt int `json:"rebuild_attempt,omitempty"` ClusterID string `json:"cluster_id"` LocalNodeID string `json:"local_node_id"` SourceNodeID string `json:"source_node_id"` @@ -258,14 +523,21 @@ type RoutePathDecision struct { } type RoutePathDecisionReport struct { - SchemaVersion string `json:"schema_version"` - DecisionMode string `json:"decision_mode"` - Generation string `json:"generation"` - DecisionCount int `json:"decision_count"` - ReplacementDecisionCount int `json:"replacement_decision_count"` - ControlPlaneOnly bool `json:"control_plane_only"` - ProductionForwarding bool `json:"production_forwarding"` - Decisions []RoutePathDecision `json:"decisions,omitempty"` + SchemaVersion string `json:"schema_version"` + DecisionMode string `json:"decision_mode"` + Generation string `json:"generation"` + RecoveryPolicy *FabricServiceChannelRecoveryPolicy `json:"recovery_policy,omitempty"` + DecisionCount int `json:"decision_count"` + ReplacementDecisionCount int `json:"replacement_decision_count"` + DegradedDecisionCount int `json:"degraded_decision_count"` + RebuildRequestCount int `json:"rebuild_request_count,omitempty"` + RebuildAppliedCount int `json:"rebuild_applied_count,omitempty"` + RecoveryHysteresisCount int `json:"recovery_hysteresis_count,omitempty"` + RecoveryPromotedCount int `json:"recovery_promoted_count,omitempty"` + RecoveryDemotedCount int `json:"recovery_demoted_count,omitempty"` + ControlPlaneOnly bool `json:"control_plane_only"` + ProductionForwarding bool `json:"production_forwarding"` + Decisions []RoutePathDecision `json:"decisions,omitempty"` } type PeerEndpointCandidate struct { @@ -319,6 +591,29 @@ func (c *Client) Heartbeat(ctx context.Context, clusterID, nodeID string, reques return response, nil } +func (c *Client) NodeUpdatePlan(ctx context.Context, clusterID, nodeID string, request NodeUpdatePlanRequest) (NodeUpdatePlan, error) { + values := url.Values{} + values.Set("product", request.Product) + values.Set("current_version", request.CurrentVersion) + values.Set("os", request.OS) + values.Set("arch", request.Arch) + values.Set("install_type", request.InstallType) + if request.Channel != "" { + values.Set("channel", request.Channel) + } + var response NodeUpdatePlanResponse + path := fmt.Sprintf("/clusters/%s/nodes/%s/updates/plan?%s", clusterID, nodeID, values.Encode()) + if err := c.getJSON(ctx, path, &response); err != nil { + return NodeUpdatePlan{}, err + } + return response.Plan, nil +} + +func (c *Client) ReportNodeUpdateStatus(ctx context.Context, clusterID, nodeID string, request NodeUpdateStatusRequest) error { + path := fmt.Sprintf("/clusters/%s/nodes/%s/updates/status", clusterID, nodeID) + return c.postJSON(ctx, path, request, nil) +} + func (c *Client) DesiredWorkloads(ctx context.Context, clusterID, nodeID string) ([]DesiredWorkload, error) { var response struct { DesiredWorkloads []DesiredWorkload `json:"desired_workloads"` @@ -335,6 +630,58 @@ func (c *Client) ReportWorkloadStatus(ctx context.Context, clusterID, nodeID, se return c.postJSON(ctx, path, request, nil) } +func (c *Client) NodeVPNAssignments(ctx context.Context, clusterID, nodeID string) ([]NodeVPNAssignment, error) { + var response struct { + Assignments []NodeVPNAssignment `json:"vpn_assignments"` + } + path := fmt.Sprintf("/clusters/%s/nodes/%s/vpn/assignments", clusterID, nodeID) + if err := c.getJSON(ctx, path, &response); err != nil { + return nil, err + } + return response.Assignments, nil +} + +func (c *Client) ReportNodeVPNAssignmentStatus(ctx context.Context, clusterID, nodeID, vpnConnectionID string, request NodeVPNAssignmentStatusRequest) error { + path := fmt.Sprintf("/clusters/%s/nodes/%s/vpn/assignments/%s/status", clusterID, nodeID, vpnConnectionID) + return c.postJSON(ctx, path, request, nil) +} + +func (c *Client) RenewNodeVPNAssignmentLease(ctx context.Context, clusterID, nodeID, vpnConnectionID, leaseID string, request NodeVPNAssignmentLeaseRenewRequest) error { + path := fmt.Sprintf("/clusters/%s/nodes/%s/vpn/assignments/%s/lease/%s/renew", clusterID, nodeID, vpnConnectionID, leaseID) + return c.postJSON(ctx, path, request, nil) +} + +func (c *Client) SendVPNGatewayPacket(ctx context.Context, clusterID, vpnConnectionID string, packet []byte) error { + if len(packet) == 0 { + return nil + } + path := fmt.Sprintf("/clusters/%s/vpn-connections/%s/tunnel/gateway/packets", clusterID, vpnConnectionID) + return c.postBytes(ctx, path, packet) +} + +func (c *Client) SendVPNGatewayPacketBatch(ctx context.Context, clusterID, vpnConnectionID string, packets [][]byte) error { + packets = cleanVPNPacketBatch(packets) + if len(packets) == 0 { + return nil + } + path := fmt.Sprintf("/clusters/%s/vpn-connections/%s/tunnel/gateway/packets?batch=true", clusterID, vpnConnectionID) + return c.postBytes(ctx, path, encodeVPNPacketBatch(packets)) +} + +func (c *Client) ReceiveVPNGatewayPacket(ctx context.Context, clusterID, vpnConnectionID string, timeout time.Duration) ([]byte, bool, error) { + path := fmt.Sprintf("/clusters/%s/vpn-connections/%s/tunnel/gateway/packets?timeout_ms=%d", clusterID, vpnConnectionID, timeout.Milliseconds()) + return c.getBytes(ctx, path) +} + +func (c *Client) ReceiveVPNGatewayPacketBatch(ctx context.Context, clusterID, vpnConnectionID string, timeout time.Duration) ([][]byte, error) { + path := fmt.Sprintf("/clusters/%s/vpn-connections/%s/tunnel/gateway/packets?batch=true&timeout_ms=%d", clusterID, vpnConnectionID, timeout.Milliseconds()) + payload, ok, err := c.getBytes(ctx, path) + if err != nil || !ok { + return nil, err + } + return decodeVPNPacketBatch(payload) +} + func (c *Client) ReportMeshLink(ctx context.Context, clusterID string, request MeshLinkObservationRequest) error { path := fmt.Sprintf("/clusters/%s/mesh/links", clusterID) return c.postJSON(ctx, path, request, nil) @@ -375,6 +722,49 @@ func (c *Client) getJSON(ctx context.Context, path string, response any) error { return json.NewDecoder(httpResp.Body).Decode(response) } +func (c *Client) getBytes(ctx context.Context, path string) ([]byte, bool, error) { + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+path, nil) + if err != nil { + return nil, false, err + } + httpResp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, false, err + } + defer httpResp.Body.Close() + if httpResp.StatusCode == http.StatusNoContent { + return nil, false, nil + } + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + return nil, false, fmt.Errorf("backend returned status %d", httpResp.StatusCode) + } + payload, err := io.ReadAll(io.LimitReader(httpResp.Body, vpnPacketBatchMaxBytes)) + if err != nil { + return nil, false, err + } + if len(payload) == 0 { + return nil, false, nil + } + return payload, true, nil +} + +func (c *Client) postBytes(ctx context.Context, path string, payload []byte) error { + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+path, bytes.NewReader(payload)) + if err != nil { + return err + } + httpReq.Header.Set("Content-Type", "application/octet-stream") + httpResp, err := c.httpClient.Do(httpReq) + if err != nil { + return err + } + defer httpResp.Body.Close() + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + return fmt.Errorf("backend returned status %d", httpResp.StatusCode) + } + return nil +} + func (c *Client) postJSON(ctx context.Context, path string, request any, response any) error { payload, err := json.Marshal(request) if err != nil { @@ -398,3 +788,59 @@ func (c *Client) postJSON(ctx context.Context, path string, request any, respons } return json.NewDecoder(httpResp.Body).Decode(response) } + +const ( + vpnPacketMaxBytes = 65535 + vpnPacketBatchMaxBytes = 4 * 1024 * 1024 +) + +func encodeVPNPacketBatch(packets [][]byte) []byte { + packets = cleanVPNPacketBatch(packets) + total := 0 + for _, packet := range packets { + total += 4 + len(packet) + } + out := make([]byte, total) + offset := 0 + for _, packet := range packets { + binary.BigEndian.PutUint32(out[offset:offset+4], uint32(len(packet))) + offset += 4 + copy(out[offset:offset+len(packet)], packet) + offset += len(packet) + } + return out +} + +func decodeVPNPacketBatch(payload []byte) ([][]byte, error) { + var packets [][]byte + for offset := 0; offset < len(payload); { + if offset+4 > len(payload) { + return nil, fmt.Errorf("truncated vpn packet batch header") + } + size := int(binary.BigEndian.Uint32(payload[offset : offset+4])) + offset += 4 + if size <= 0 || size > vpnPacketMaxBytes { + return nil, fmt.Errorf("invalid vpn packet batch item size") + } + if offset+size > len(payload) { + return nil, fmt.Errorf("truncated vpn packet batch item") + } + packets = append(packets, append([]byte(nil), payload[offset:offset+size]...)) + offset += size + } + return cleanVPNPacketBatch(packets), nil +} + +func cleanVPNPacketBatch(packets [][]byte) [][]byte { + if len(packets) == 0 { + return nil + } + cleaned := make([][]byte, 0, len(packets)) + for _, packet := range packets { + if len(packet) == 0 { + continue + } + cleaned = append(cleaned, append([]byte(nil), packet...)) + } + return cleaned +} diff --git a/agents/rap-node-agent/internal/mesh/vpn_packet.go b/agents/rap-node-agent/internal/mesh/vpn_packet.go new file mode 100644 index 0000000..d19bf78 --- /dev/null +++ b/agents/rap-node-agent/internal/mesh/vpn_packet.go @@ -0,0 +1,121 @@ +package mesh + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "time" +) + +type VPNPacketBatchPayload struct { + SchemaVersion string `json:"schema_version"` + VPNConnectionID string `json:"vpn_connection_id"` + Direction string `json:"direction"` + Packets [][]byte `json:"packets"` + SentAt time.Time `json:"sent_at"` +} + +type ProductionVPNPacketEnvelopeInput struct { + MessageID string + RouteID string + ClusterID string + SourceNodeID string + DestinationNodeID string + CurrentHopNodeID string + NextHopNodeID string + RoutePath []string + TTL int + HopCount int + ExpiresAt time.Time + VPNConnectionID string + Direction string + Packets [][]byte + Now time.Time +} + +func NewProductionVPNPacketBatchEnvelope(input ProductionVPNPacketEnvelopeInput) (ProductionEnvelope, error) { + now := input.Now.UTC() + if now.IsZero() { + now = time.Now().UTC() + } + packets := cleanProductionVPNPacketBatch(input.Packets) + if len(packets) == 0 { + return ProductionEnvelope{}, fmt.Errorf("%w: empty vpn packet batch", ErrForwardEnvelopeInvalid) + } + if input.MessageID == "" { + input.MessageID = fmt.Sprintf("vpnpkt-%d", now.UnixNano()) + } + if input.TTL <= 0 { + input.TTL = 8 + } + if input.ExpiresAt.IsZero() { + input.ExpiresAt = now.Add(15 * time.Second) + } + payload, err := json.Marshal(VPNPacketBatchPayload{ + SchemaVersion: "rap.vpn_packet_batch.v1", + VPNConnectionID: input.VPNConnectionID, + Direction: input.Direction, + Packets: packets, + SentAt: now, + }) + if err != nil { + return ProductionEnvelope{}, err + } + if len(payload) > MaxProductionVPNPacketPayloadBytes { + return ProductionEnvelope{}, fmt.Errorf("%w: vpn packet batch exceeds channel limit", ErrForwardEnvelopeInvalid) + } + sum := sha256.Sum256(payload) + return ProductionEnvelope{ + FabricProtocolVersion: ProtocolVersion, + MessageID: input.MessageID, + RouteID: input.RouteID, + ClusterID: input.ClusterID, + SourceNodeID: input.SourceNodeID, + DestinationNodeID: input.DestinationNodeID, + CurrentHopNodeID: input.CurrentHopNodeID, + NextHopNodeID: input.NextHopNodeID, + RoutePath: append([]string{}, input.RoutePath...), + ChannelClass: ProductionChannelVPNPacket, + MessageType: ProductionMessageVPNPacketBatch, + TTL: input.TTL, + HopCount: input.HopCount, + CreatedAt: now, + ExpiresAt: input.ExpiresAt.UTC(), + PayloadLength: len(payload), + PayloadHash: hex.EncodeToString(sum[:]), + Payload: payload, + }, nil +} + +func DecodeProductionVPNPacketBatch(envelope ProductionEnvelope) (VPNPacketBatchPayload, error) { + if envelope.ChannelClass != ProductionChannelVPNPacket || envelope.MessageType != ProductionMessageVPNPacketBatch { + return VPNPacketBatchPayload{}, ErrUnauthorizedChannel + } + var payload VPNPacketBatchPayload + if err := json.Unmarshal(envelope.Payload, &payload); err != nil { + return VPNPacketBatchPayload{}, err + } + if payload.SchemaVersion != "rap.vpn_packet_batch.v1" || payload.VPNConnectionID == "" { + return VPNPacketBatchPayload{}, fmt.Errorf("%w: invalid vpn packet batch payload", ErrForwardEnvelopeInvalid) + } + payload.Packets = cleanProductionVPNPacketBatch(payload.Packets) + if len(payload.Packets) == 0 { + return VPNPacketBatchPayload{}, fmt.Errorf("%w: empty vpn packet batch payload", ErrForwardEnvelopeInvalid) + } + return payload, nil +} + +func cleanProductionVPNPacketBatch(packets [][]byte) [][]byte { + if len(packets) == 0 { + return nil + } + cleaned := make([][]byte, 0, len(packets)) + for _, packet := range packets { + if len(packet) == 0 { + continue + } + cleaned = append(cleaned, append([]byte(nil), packet...)) + } + return cleaned +} diff --git a/agents/rap-node-agent/internal/vpnruntime/checksum.go b/agents/rap-node-agent/internal/vpnruntime/checksum.go new file mode 100644 index 0000000..bf42208 --- /dev/null +++ b/agents/rap-node-agent/internal/vpnruntime/checksum.go @@ -0,0 +1,77 @@ +package vpnruntime + +import "encoding/binary" + +func normalizeIPv4PacketChecksums(packet []byte) bool { + if len(packet) < 20 || packet[0]>>4 != 4 { + return false + } + ihl := int(packet[0]&0x0f) * 4 + if ihl < 20 || len(packet) < ihl { + return false + } + totalLen := int(binary.BigEndian.Uint16(packet[2:4])) + if totalLen <= 0 || totalLen > len(packet) { + totalLen = len(packet) + } + if totalLen < ihl { + return false + } + + packet[10], packet[11] = 0, 0 + binary.BigEndian.PutUint16(packet[10:12], checksum(packet[:ihl])) + + proto := packet[9] + payload := packet[ihl:totalLen] + switch proto { + case 6: + if len(payload) < 20 { + return true + } + payload[16], payload[17] = 0, 0 + binary.BigEndian.PutUint16(payload[16:18], transportChecksum(packet, payload, proto)) + case 17: + if len(payload) < 8 { + return true + } + payload[6], payload[7] = 0, 0 + sum := transportChecksum(packet, payload, proto) + if sum == 0 { + sum = 0xffff + } + binary.BigEndian.PutUint16(payload[6:8], sum) + case 1: + if len(payload) < 4 { + return true + } + payload[2], payload[3] = 0, 0 + binary.BigEndian.PutUint16(payload[2:4], checksum(payload)) + } + return true +} + +func transportChecksum(ipHeader []byte, payload []byte, proto byte) uint16 { + pseudo := make([]byte, 12+len(payload)) + copy(pseudo[0:4], ipHeader[12:16]) + copy(pseudo[4:8], ipHeader[16:20]) + pseudo[8] = 0 + pseudo[9] = proto + binary.BigEndian.PutUint16(pseudo[10:12], uint16(len(payload))) + copy(pseudo[12:], payload) + return checksum(pseudo) +} + +func checksum(data []byte) uint16 { + var sum uint32 + for len(data) >= 2 { + sum += uint32(binary.BigEndian.Uint16(data[:2])) + data = data[2:] + } + if len(data) == 1 { + sum += uint32(data[0]) << 8 + } + for (sum >> 16) != 0 { + sum = (sum & 0xffff) + (sum >> 16) + } + return ^uint16(sum) +} diff --git a/agents/rap-node-agent/internal/vpnruntime/checksum_test.go b/agents/rap-node-agent/internal/vpnruntime/checksum_test.go new file mode 100644 index 0000000..c4b51f6 --- /dev/null +++ b/agents/rap-node-agent/internal/vpnruntime/checksum_test.go @@ -0,0 +1,52 @@ +package vpnruntime + +import ( + "encoding/binary" + "testing" +) + +func TestNormalizeIPv4PacketChecksumsRepairsTCP(t *testing.T) { + packet := []byte{ + 0x45, 0x00, 0x00, 0x28, 0x00, 0x00, 0x40, 0x00, 0x40, 0x06, 0x12, 0x34, 192, 168, 200, 61, 10, 77, 0, 2, + 0x46, 0xa0, 0xdd, 0x78, 0, 0, 0, 1, 0, 0, 0, 0, 0x50, 0x12, 0x72, 0x10, 0xab, 0xcd, 0, 0, + } + if !normalizeIPv4PacketChecksums(packet) { + t.Fatal("normalize returned false") + } + if got := checksum(packet[:20]); got != 0 { + t.Fatalf("ip checksum verification = %#x, want 0", got) + } + tcp := packet[20:] + pseudo := make([]byte, 12+len(tcp)) + copy(pseudo[0:4], packet[12:16]) + copy(pseudo[4:8], packet[16:20]) + pseudo[9] = 6 + binary.BigEndian.PutUint16(pseudo[10:12], uint16(len(tcp))) + copy(pseudo[12:], tcp) + if got := checksum(pseudo); got != 0 { + t.Fatalf("tcp checksum verification = %#x, want 0", got) + } +} + +func TestNormalizeIPv4PacketChecksumsRepairsUDP(t *testing.T) { + packet := []byte{ + 0x45, 0x00, 0x00, 0x20, 0, 0, 0x40, 0, 0x40, 0x11, 0x12, 0x34, 10, 77, 0, 2, 192, 168, 200, 210, + 0x30, 0x39, 0x00, 0x35, 0x00, 0x0c, 0xab, 0xcd, 0xde, 0xad, 0xbe, 0xef, + } + if !normalizeIPv4PacketChecksums(packet) { + t.Fatal("normalize returned false") + } + if got := checksum(packet[:20]); got != 0 { + t.Fatalf("ip checksum verification = %#x, want 0", got) + } + udp := packet[20:] + pseudo := make([]byte, 12+len(udp)) + copy(pseudo[0:4], packet[12:16]) + copy(pseudo[4:8], packet[16:20]) + pseudo[9] = 17 + binary.BigEndian.PutUint16(pseudo[10:12], uint16(len(udp))) + copy(pseudo[12:], udp) + if got := checksum(pseudo); got != 0 { + t.Fatalf("udp checksum verification = %#x, want 0", got) + } +} diff --git a/agents/rap-node-agent/internal/vpnruntime/fabric_transport.go b/agents/rap-node-agent/internal/vpnruntime/fabric_transport.go new file mode 100644 index 0000000..fc1ee70 --- /dev/null +++ b/agents/rap-node-agent/internal/vpnruntime/fabric_transport.go @@ -0,0 +1,2638 @@ +package vpnruntime + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "hash/fnv" + "sort" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/example/remote-access-platform/agents/rap-node-agent/internal/mesh" +) + +const ( + FabricDirectionClientToGateway = "client_to_gateway" + FabricDirectionGatewayToClient = "gateway_to_client" + + defaultFabricFlowShardCount = 32 + defaultFabricFlowQueueCapacity = 1024 + defaultFabricFlowParallelSendWindow = 4 + defaultFabricFlowQualityWindowCapacity = 32 + defaultFabricFlowFailureThreshold = 2 + defaultFabricFlowSlowSendThreshold = 2 * time.Second + defaultFabricRouteQualitySwitchThreshold = 30 +) + +type FabricPacketTransport struct { + ForwardTransport mesh.ProductionForwardTransport + Inbox *FabricPacketInbox + + ClusterID string + VPNConnectionID string + RouteID string + LocalNodeID string + RemoteNodeID string + NextHopNodeID string + RoutePath []string + SendDirection string + ReceiveDirection string +} + +type FabricClientPacketIngress struct { + ForwardTransport mesh.ProductionForwardTransport + Inbox *FabricPacketInbox + Routes func() []mesh.SyntheticRoute + LocalGateway func(vpnConnectionID string) bool + FlowScheduler *FabricFlowScheduler + MaxParallelFlowSends int + RecoveryPolicyFingerprint string + AdaptivePolicyFingerprint string + + ClusterID string + LocalNodeID string + RouteManager FabricServiceChannelRouteManager + RouteManagerTransition FabricServiceChannelRouteManagerTransition + RouteQualityPreferences map[string]FabricServiceChannelRouteQualityPreference + + mu sync.Mutex + lastSelectedRouteID string + lastSelectedNextHop string + lastError string + sendBatches uint64 + sendPackets uint64 + sendRouteAttempts uint64 + sendRouteFailures uint64 + sendFallbackLocal uint64 + sendFlowBatches uint64 + sendFlowPackets uint64 + sendFlowDropped uint64 + sendFlowParallel uint64 + receiveBatches uint64 + receivePackets uint64 + receiveEmpty uint64 +} + +type FabricServiceChannelRouteManager struct { + SchemaVersion string `json:"schema_version"` + Generation string `json:"generation,omitempty"` + RebuildRequestCount int `json:"rebuild_request_count"` + RebuildAppliedCount int `json:"rebuild_applied_count"` + WithdrawnRouteCount int `json:"withdrawn_route_count"` + PendingFallbackCount int `json:"pending_degraded_fallback_count"` + LastAppliedAt string `json:"last_applied_at,omitempty"` + Decisions []FabricServiceChannelRouteManagerDecision `json:"decisions,omitempty"` + withdrawnRoutes map[string]FabricServiceChannelRouteManagerDecision + replacements map[string]string +} + +type FabricServiceChannelRouteManagerTransition struct { + SchemaVersion string `json:"schema_version"` + PreviousGeneration string `json:"previous_generation,omitempty"` + Generation string `json:"generation,omitempty"` + Status string `json:"status,omitempty"` + ObservedAt string `json:"observed_at,omitempty"` + DecisionCount int `json:"decision_count"` + WithdrawnRouteCount int `json:"withdrawn_route_count"` + RestoredRouteCount int `json:"restored_route_count"` + ClearedSelectedRouteID string `json:"cleared_selected_route_id,omitempty"` + PendingFallbackCount int `json:"pending_degraded_fallback_count"` + RebuildAppliedCount int `json:"rebuild_applied_count"` +} + +type FabricServiceChannelRouteManagerDecision struct { + RouteID string `json:"route_id"` + ReplacementRouteID string `json:"replacement_route_id,omitempty"` + RebuildRequestID string `json:"rebuild_request_id,omitempty"` + RebuildStatus string `json:"rebuild_status,omitempty"` + RebuildReason string `json:"rebuild_reason,omitempty"` + RebuildAttempt int `json:"rebuild_attempt,omitempty"` + DecisionSource string `json:"decision_source,omitempty"` + Generation string `json:"generation,omitempty"` + EffectiveHops []string `json:"effective_hops,omitempty"` +} + +type FabricServiceChannelRouteQualityPreference struct { + RouteID string `json:"route_id"` + FeedbackStatus string `json:"feedback_status,omitempty"` + ScoreAdjustment int `json:"score_adjustment"` + RawScoreAdjustment int `json:"raw_score_adjustment,omitempty"` + Reasons []string `json:"reasons,omitempty"` + LastSendDurationMs int64 `json:"last_send_duration_ms,omitempty"` + ObservedAt string `json:"observed_at,omitempty"` + ExpiresAt string `json:"expires_at,omitempty"` +} + +type FabricFlowScheduler struct { + mu sync.Mutex + shardCount int + queueCapacity int + adaptivePolicy FabricServiceChannelAdaptivePolicy + queues map[string]*fabricFlowQueue + enqueued uint64 + dequeued uint64 + dropped uint64 + highWatermark int + inFlight int + maxInFlight int +} + +type FabricServiceChannelAdaptivePolicy struct { + SchemaVersion string + Fingerprint string + MaxParallelWindow int + BulkPressureChannelThreshold int + QueuePressureHighWatermark int + QueuePressureMaxInFlight int + ClassWindows map[string]int +} + +const ( + FabricTrafficClassControl = "control" + FabricTrafficClassInteractive = "interactive" + FabricTrafficClassReliable = "reliable" + FabricTrafficClassBulk = "bulk" + FabricTrafficClassDroppable = "droppable" +) + +type fabricFlowQueue struct { + TrafficClass string + Depth int + Enqueued uint64 + Dequeued uint64 + Dropped uint64 + HighWatermark int + Served uint64 + InFlight int + MaxInFlight int + SendAttempts uint64 + SendSuccesses uint64 + SendFailures uint64 + LastServedAt time.Time + LastRouteID string + RoutePolicyVersion string + RouteGeneration string + RecoveryPolicyFingerprint string + LastNextHop string + LastFailedRouteID string + LastFailedRoutePolicyVersion string + LastFailedRouteGeneration string + LastError string + ConsecutiveFailures uint64 + StallCount uint64 + LastSendDurationMillis int64 + RouteRebuildRecommended bool + DegradedFallbackRecommended bool + QualityPreferenceRouteID string + QualityPreferenceScore int + QualityPreferenceRawScore int + QualityPreferenceReasons []string + LatencyLe10Millis uint64 + LatencyLe100Millis uint64 + LatencyLe1000Millis uint64 + LatencyGt1000Millis uint64 + QualityWindow []fabricFlowQualitySample +} + +type fabricFlowQualitySample struct { + At time.Time + Success bool + Failure bool + Dropped bool + Slow bool + DurationMillis int64 +} + +type fabricFlowQualityWindowStats struct { + SampleCount int + SuccessCount int + FailureCount int + SlowCount int + DropCount int + AvgLatencyMs int64 + LastUpdatedAt time.Time +} + +type FabricScheduledPacketBatch struct { + ChannelID string + FlowID string + Shard int + TrafficClass string + Packets [][]byte + QueueDepth int + Dropped uint64 + Classifier string + ServiceMode string +} + +type FabricFlowSchedulerSnapshot struct { + SchemaVersion string `json:"schema_version"` + Enabled bool `json:"enabled"` + ServiceNeutral bool `json:"service_neutral"` + Classifier string `json:"classifier"` + ServiceMode string `json:"service_mode"` + ShardCount int `json:"shard_count"` + QueueCapacity int `json:"queue_capacity"` + ChannelCount int `json:"channel_count"` + Enqueued uint64 `json:"enqueued"` + Dequeued uint64 `json:"dequeued"` + Dropped uint64 `json:"dropped"` + HighWatermark int `json:"high_watermark"` + BackpressureActive bool `json:"backpressure_active"` + InFlight int `json:"in_flight"` + MaxInFlight int `json:"max_in_flight"` + AdaptiveBackpressureActive bool `json:"adaptive_backpressure_active,omitempty"` + AdaptiveBackpressureReason string `json:"adaptive_backpressure_reason,omitempty"` + RecommendedParallelWindows map[string]int `json:"recommended_parallel_windows,omitempty"` + AdaptivePolicyFingerprint string `json:"adaptive_policy_fingerprint,omitempty"` + SlowChannelCount int `json:"slow_channel_count"` + FailingChannelCount int `json:"failing_channel_count"` + QualityWindowSampleCount int `json:"quality_window_sample_count"` + QualityWindowFailureCount int `json:"quality_window_failure_count"` + QualityWindowSlowCount int `json:"quality_window_slow_count"` + QualityWindowDropCount int `json:"quality_window_drop_count"` + QueueDepths map[string]int `json:"queue_depths"` + TrafficClassCounts map[string]int `json:"traffic_class_counts,omitempty"` + ChannelStats map[string]FabricFlowStat `json:"channel_stats"` +} + +type FabricFlowStat struct { + TrafficClass string `json:"traffic_class,omitempty"` + Depth int `json:"depth"` + Enqueued uint64 `json:"enqueued"` + Dequeued uint64 `json:"dequeued"` + Dropped uint64 `json:"dropped"` + HighWatermark int `json:"high_watermark"` + Served uint64 `json:"served"` + InFlight int `json:"in_flight"` + MaxInFlight int `json:"max_in_flight"` + SendAttempts uint64 `json:"send_attempts"` + SendSuccesses uint64 `json:"send_successes"` + SendFailures uint64 `json:"send_failures"` + LastServedAt string `json:"last_served_at,omitempty"` + LastRouteID string `json:"last_route_id,omitempty"` + LastNextHop string `json:"last_next_hop,omitempty"` + RoutePolicyVersion string `json:"route_policy_version,omitempty"` + RouteGeneration string `json:"route_generation,omitempty"` + RecoveryPolicyFingerprint string `json:"recovery_policy_fingerprint,omitempty"` + LastFailedRouteID string `json:"last_failed_route_id,omitempty"` + LastFailedRoutePolicyVersion string `json:"last_failed_route_policy_version,omitempty"` + LastFailedRouteGeneration string `json:"last_failed_route_generation,omitempty"` + LastError string `json:"last_error,omitempty"` + ConsecutiveFailures uint64 `json:"consecutive_failures"` + StallCount uint64 `json:"stall_count"` + LastSendDurationMillis int64 `json:"last_send_duration_ms,omitempty"` + RouteRebuildRecommended bool `json:"route_rebuild_recommended"` + DegradedFallbackRecommended bool `json:"degraded_fallback_recommended"` + QualityPreferenceRouteID string `json:"quality_preference_route_id,omitempty"` + QualityPreferenceScore int `json:"quality_preference_score,omitempty"` + QualityPreferenceRawScore int `json:"quality_preference_raw_score,omitempty"` + QualityPreferenceReasons []string `json:"quality_preference_reasons,omitempty"` + LatencyLe10Millis uint64 `json:"latency_le_10ms"` + LatencyLe100Millis uint64 `json:"latency_le_100ms"` + LatencyLe1000Millis uint64 `json:"latency_le_1000ms"` + LatencyGt1000Millis uint64 `json:"latency_gt_1000ms"` + QualityWindowSampleCount int `json:"quality_window_sample_count"` + QualityWindowSuccessCount int `json:"quality_window_success_count"` + QualityWindowFailureCount int `json:"quality_window_failure_count"` + QualityWindowSlowCount int `json:"quality_window_slow_count"` + QualityWindowDropCount int `json:"quality_window_drop_count"` + QualityWindowAvgLatencyMs int64 `json:"quality_window_avg_latency_ms,omitempty"` + QualityWindowLastUpdatedAt string `json:"quality_window_last_updated_at,omitempty"` +} + +func NewFabricFlowScheduler(shardCount int, queueCapacity int) *FabricFlowScheduler { + if shardCount <= 0 { + shardCount = defaultFabricFlowShardCount + } + if queueCapacity <= 0 { + queueCapacity = defaultFabricFlowQueueCapacity + } + return &FabricFlowScheduler{ + shardCount: shardCount, + queueCapacity: queueCapacity, + adaptivePolicy: defaultFabricServiceChannelAdaptivePolicy(), + queues: map[string]*fabricFlowQueue{}, + } +} + +func defaultFabricServiceChannelAdaptivePolicy() FabricServiceChannelAdaptivePolicy { + return normalizeFabricServiceChannelAdaptivePolicy(FabricServiceChannelAdaptivePolicy{ + SchemaVersion: "rap.fabric_service_channel_adaptive_policy.v1", + MaxParallelWindow: defaultFabricFlowParallelSendWindow, + BulkPressureChannelThreshold: 16, + QueuePressureHighWatermark: 16, + QueuePressureMaxInFlight: defaultFabricFlowParallelSendWindow * 4, + ClassWindows: map[string]int{ + FabricTrafficClassControl: defaultFabricFlowParallelSendWindow, + FabricTrafficClassInteractive: defaultFabricFlowParallelSendWindow, + FabricTrafficClassReliable: 3, + FabricTrafficClassBulk: 1, + FabricTrafficClassDroppable: 1, + }, + }) +} + +func normalizeFabricServiceChannelAdaptivePolicy(policy FabricServiceChannelAdaptivePolicy) FabricServiceChannelAdaptivePolicy { + if policy.SchemaVersion == "" { + policy.SchemaVersion = "rap.fabric_service_channel_adaptive_policy.v1" + } + if policy.MaxParallelWindow <= 0 { + policy.MaxParallelWindow = defaultFabricFlowParallelSendWindow + } + if policy.BulkPressureChannelThreshold <= 0 { + policy.BulkPressureChannelThreshold = 16 + } + if policy.QueuePressureHighWatermark <= 0 { + policy.QueuePressureHighWatermark = 16 + } + if policy.QueuePressureMaxInFlight <= 0 { + policy.QueuePressureMaxInFlight = defaultFabricFlowParallelSendWindow * 4 + } + if policy.ClassWindows == nil { + policy.ClassWindows = map[string]int{} + } + defaults := map[string]int{ + FabricTrafficClassControl: policy.MaxParallelWindow, + FabricTrafficClassInteractive: policy.MaxParallelWindow, + FabricTrafficClassReliable: minPositive(policy.MaxParallelWindow, 3), + FabricTrafficClassBulk: 1, + FabricTrafficClassDroppable: 1, + } + next := map[string]int{} + for className, fallback := range defaults { + value := policy.ClassWindows[className] + if value <= 0 { + value = fallback + } + next[className] = clampFabricWindow(value, 1, policy.MaxParallelWindow) + } + policy.ClassWindows = next + return policy +} + +func (s *FabricFlowScheduler) ConfigureAdaptivePolicy(policy FabricServiceChannelAdaptivePolicy) { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.adaptivePolicy = normalizeFabricServiceChannelAdaptivePolicy(policy) +} + +func (s *FabricFlowScheduler) ScheduleClientPackets(packets [][]byte) []FabricScheduledPacketBatch { + return s.scheduleClientPackets("", "", packets) +} + +func (s *FabricFlowScheduler) ScheduleClientPacketsForConnection(vpnConnectionID string, packets [][]byte) []FabricScheduledPacketBatch { + return s.scheduleClientPackets(vpnConnectionID, "", packets) +} + +func (s *FabricFlowScheduler) ScheduleClientPacketsForConnectionClass(vpnConnectionID string, trafficClass string, packets [][]byte) []FabricScheduledPacketBatch { + return s.scheduleClientPackets(vpnConnectionID, trafficClass, packets) +} + +func (s *FabricFlowScheduler) scheduleClientPackets(vpnConnectionID string, trafficClass string, packets [][]byte) []FabricScheduledPacketBatch { + packets = cleanPacketBatch(packets) + if len(packets) == 0 { + return nil + } + if s == nil { + s = NewFabricFlowScheduler(0, 0) + } + trafficClass = normalizeFabricTrafficClass(trafficClass) + grouped := map[string]*FabricScheduledPacketBatch{} + for _, packet := range packets { + flowID, shard := classifyPacketFlow(packet, s.shardCountValue()) + channelID := fabricFlowChannelIDForClass(vpnConnectionID, trafficClass, shard) + queueDepth, dropped := s.enqueue(channelID, trafficClass) + if dropped { + continue + } + batch := grouped[channelID] + if batch == nil { + batch = &FabricScheduledPacketBatch{ + ChannelID: channelID, + FlowID: flowID, + Shard: shard, + TrafficClass: trafficClass, + Classifier: "ip_5tuple_or_packet_hash", + ServiceMode: "application_protocol_agnostic", + } + grouped[channelID] = batch + } + batch.Packets = append(batch.Packets, append([]byte(nil), packet...)) + batch.QueueDepth = queueDepth + } + out := make([]FabricScheduledPacketBatch, 0, len(grouped)) + for _, batch := range grouped { + out = append(out, *batch) + } + s.sortScheduledBatches(out) + return out +} + +func fabricFlowChannelID(vpnConnectionID string, shard int) string { + return fabricFlowChannelIDForClass(vpnConnectionID, "", shard) +} + +func fabricFlowChannelIDForClass(vpnConnectionID string, trafficClass string, shard int) string { + base := fmt.Sprintf("flow-%02d", shard) + vpnConnectionID = strings.TrimSpace(vpnConnectionID) + if vpnConnectionID == "" { + return base + } + trafficClass = normalizeFabricTrafficClass(trafficClass) + if trafficClass != "" && trafficClass != FabricTrafficClassBulk { + return "vpn:" + vpnConnectionID + ":" + trafficClass + ":" + base + } + return "vpn:" + vpnConnectionID + ":" + base +} + +func (s *FabricFlowScheduler) Complete(batch FabricScheduledPacketBatch) { + if s == nil || len(batch.Packets) == 0 { + return + } + s.dequeue(batch.ChannelID, len(batch.Packets)) +} + +func (s *FabricFlowScheduler) BeginSend(channelID string) { + if s == nil || channelID == "" { + return + } + s.mu.Lock() + defer s.mu.Unlock() + queue := s.ensureQueueLocked(channelID) + queue.InFlight++ + queue.SendAttempts++ + if queue.InFlight > queue.MaxInFlight { + queue.MaxInFlight = queue.InFlight + } + s.inFlight++ + if s.inFlight > s.maxInFlight { + s.maxInFlight = s.inFlight + } +} + +func (s *FabricFlowScheduler) EndSend(channelID string) { + if s == nil || channelID == "" { + return + } + s.mu.Lock() + defer s.mu.Unlock() + queue := s.queues[channelID] + if queue != nil && queue.InFlight > 0 { + queue.InFlight-- + } + if s.inFlight > 0 { + s.inFlight-- + } +} + +func (s *FabricFlowScheduler) RecommendedParallelSendWindow(maxWindow int) int { + return s.RecommendedParallelSendWindowForTrafficClass("", maxWindow) +} + +func (s *FabricFlowScheduler) RecommendedParallelSendWindowForTrafficClass(trafficClass string, maxWindow int) int { + if maxWindow <= 1 { + return 1 + } + if s == nil { + return maxWindow + } + trafficClass = normalizeFabricTrafficClass(trafficClass) + s.mu.Lock() + defer s.mu.Unlock() + if maxWindow > s.adaptivePolicy.MaxParallelWindow && s.adaptivePolicy.MaxParallelWindow > 0 { + maxWindow = s.adaptivePolicy.MaxParallelWindow + } + global := s.parallelPressureLocked("") + classPressure := s.parallelPressureLocked(trafficClass) + if fabricTrafficClassPriority(trafficClass) <= fabricTrafficClassPriority(FabricTrafficClassInteractive) { + if classPressure.hasDrops { + return boundedParallelWindow(maxWindow - 1) + } + if classPressure.failing > 0 || classPressure.slow > 0 { + return boundedParallelWindow(maxWindow - 1) + } + return maxWindow + } + if trafficClass == FabricTrafficClassReliable { + if classPressure.hasDrops || classPressure.failing > 0 { + return classWindowLimit(s.adaptivePolicy, trafficClass, boundedParallelWindow(maxWindow/2)) + } + if global.hasDrops || global.failing+global.slow > 0 || global.highPressure { + return classWindowLimit(s.adaptivePolicy, trafficClass, boundedParallelWindow(maxWindow-1)) + } + return maxWindow + } + if classPressure.hasDrops { + return classWindowLimit(s.adaptivePolicy, trafficClass, boundedParallelWindow(maxWindow/2)) + } + if global.hasDrops { + return classWindowLimit(s.adaptivePolicy, trafficClass, boundedParallelWindow(maxWindow/2)) + } + if global.highPressure && global.interactiveOrControlQueues > 0 { + return classWindowLimit(s.adaptivePolicy, trafficClass, boundedParallelWindow(maxWindow/2)) + } + if global.highPressure { + return classWindowLimit(s.adaptivePolicy, trafficClass, boundedParallelWindow(maxWindow/2)) + } + if classPressure.failing >= maxWindow || classPressure.slow >= maxWindow { + return classWindowLimit(s.adaptivePolicy, trafficClass, 1) + } + if classPressure.failing+classPressure.slow > 0 { + return classWindowLimit(s.adaptivePolicy, trafficClass, boundedParallelWindow(maxWindow-1)) + } + return maxWindow +} + +type fabricParallelPressure struct { + hasDrops bool + failing int + slow int + highPressure bool + interactiveOrControlQueues int + bulkQueues int +} + +func (s *FabricFlowScheduler) parallelPressureLocked(trafficClass string) fabricParallelPressure { + out := fabricParallelPressure{} + if s == nil { + return out + } + trafficClass = strings.TrimSpace(trafficClass) + failing := 0 + slow := 0 + for _, queue := range s.queues { + if queue == nil { + continue + } + queueClass := normalizeFabricTrafficClass(queue.TrafficClass) + if queueClass == FabricTrafficClassControl || queueClass == FabricTrafficClassInteractive { + out.interactiveOrControlQueues++ + } + if queueClass == FabricTrafficClassBulk { + out.bulkQueues++ + } + if trafficClass != "" && queueClass != trafficClass { + continue + } + stats := queue.qualityWindowStats() + if stats.DropCount > 0 { + out.hasDrops = true + } + if stats.FailureCount > stats.SuccessCount || (stats.FailureCount > 0 && queue.DegradedFallbackRecommended) { + failing++ + } + if stats.SlowCount > 0 { + slow++ + } + policy := s.adaptivePolicy + if policy.QueuePressureHighWatermark <= 0 { + policy = defaultFabricServiceChannelAdaptivePolicy() + } + if queue.HighWatermark >= policy.QueuePressureHighWatermark || queue.MaxInFlight >= policy.QueuePressureMaxInFlight { + out.highPressure = true + } + } + policy := s.adaptivePolicy + if policy.QueuePressureHighWatermark <= 0 { + policy = defaultFabricServiceChannelAdaptivePolicy() + } + if s.highWatermark >= policy.QueuePressureHighWatermark || s.maxInFlight >= policy.QueuePressureMaxInFlight { + out.highPressure = true + } + if out.bulkQueues >= policy.BulkPressureChannelThreshold && out.interactiveOrControlQueues > 0 { + out.highPressure = true + } + out.failing = failing + out.slow = slow + return out +} + +func boundedParallelWindow(value int) int { + if value < 1 { + return 1 + } + return value +} + +func (s *FabricFlowScheduler) Snapshot() FabricFlowSchedulerSnapshot { + snapshot := FabricFlowSchedulerSnapshot{ + SchemaVersion: "rap.fabric_flow_scheduler.v1", + Enabled: s != nil, + ServiceNeutral: true, + Classifier: "ip_5tuple_or_packet_hash", + ServiceMode: "application_protocol_agnostic", + QueueDepths: map[string]int{}, + TrafficClassCounts: map[string]int{}, + RecommendedParallelWindows: map[string]int{}, + ChannelStats: map[string]FabricFlowStat{}, + } + if s == nil { + snapshot.ShardCount = defaultFabricFlowShardCount + snapshot.QueueCapacity = defaultFabricFlowQueueCapacity + return snapshot + } + s.mu.Lock() + defer s.mu.Unlock() + snapshot.ShardCount = s.shardCount + snapshot.QueueCapacity = s.queueCapacity + snapshot.AdaptivePolicyFingerprint = s.adaptivePolicy.Fingerprint + snapshot.ChannelCount = len(s.queues) + snapshot.Enqueued = s.enqueued + snapshot.Dequeued = s.dequeued + snapshot.Dropped = s.dropped + snapshot.HighWatermark = s.highWatermark + snapshot.InFlight = s.inFlight + snapshot.MaxInFlight = s.maxInFlight + for channelID, queue := range s.queues { + qualityStats := queue.qualityWindowStats() + snapshot.QueueDepths[channelID] = queue.Depth + trafficClass := normalizeFabricTrafficClass(queue.TrafficClass) + snapshot.TrafficClassCounts[trafficClass]++ + stat := FabricFlowStat{ + Depth: queue.Depth, + TrafficClass: trafficClass, + Enqueued: queue.Enqueued, + Dequeued: queue.Dequeued, + Dropped: queue.Dropped, + HighWatermark: queue.HighWatermark, + Served: queue.Served, + InFlight: queue.InFlight, + MaxInFlight: queue.MaxInFlight, + SendAttempts: queue.SendAttempts, + SendSuccesses: queue.SendSuccesses, + SendFailures: queue.SendFailures, + LastRouteID: queue.LastRouteID, + RoutePolicyVersion: queue.RoutePolicyVersion, + RouteGeneration: queue.RouteGeneration, + RecoveryPolicyFingerprint: queue.RecoveryPolicyFingerprint, + LastNextHop: queue.LastNextHop, + LastFailedRouteID: queue.LastFailedRouteID, + LastFailedRoutePolicyVersion: queue.LastFailedRoutePolicyVersion, + LastFailedRouteGeneration: queue.LastFailedRouteGeneration, + LastError: queue.LastError, + ConsecutiveFailures: queue.ConsecutiveFailures, + StallCount: queue.StallCount, + LastSendDurationMillis: queue.LastSendDurationMillis, + RouteRebuildRecommended: queue.RouteRebuildRecommended, + DegradedFallbackRecommended: queue.DegradedFallbackRecommended, + QualityPreferenceRouteID: queue.QualityPreferenceRouteID, + QualityPreferenceScore: queue.QualityPreferenceScore, + QualityPreferenceRawScore: queue.QualityPreferenceRawScore, + QualityPreferenceReasons: append([]string{}, queue.QualityPreferenceReasons...), + LatencyLe10Millis: queue.LatencyLe10Millis, + LatencyLe100Millis: queue.LatencyLe100Millis, + LatencyLe1000Millis: queue.LatencyLe1000Millis, + LatencyGt1000Millis: queue.LatencyGt1000Millis, + QualityWindowSampleCount: qualityStats.SampleCount, + QualityWindowSuccessCount: qualityStats.SuccessCount, + QualityWindowFailureCount: qualityStats.FailureCount, + QualityWindowSlowCount: qualityStats.SlowCount, + QualityWindowDropCount: qualityStats.DropCount, + QualityWindowAvgLatencyMs: qualityStats.AvgLatencyMs, + } + if !queue.LastServedAt.IsZero() { + stat.LastServedAt = queue.LastServedAt.UTC().Format(time.RFC3339Nano) + } + if !qualityStats.LastUpdatedAt.IsZero() { + stat.QualityWindowLastUpdatedAt = qualityStats.LastUpdatedAt.UTC().Format(time.RFC3339Nano) + } + snapshot.ChannelStats[channelID] = FabricFlowStat{ + Depth: stat.Depth, + TrafficClass: stat.TrafficClass, + Enqueued: stat.Enqueued, + Dequeued: stat.Dequeued, + Dropped: stat.Dropped, + HighWatermark: stat.HighWatermark, + Served: stat.Served, + InFlight: stat.InFlight, + MaxInFlight: stat.MaxInFlight, + SendAttempts: stat.SendAttempts, + SendSuccesses: stat.SendSuccesses, + SendFailures: stat.SendFailures, + LastServedAt: stat.LastServedAt, + LastRouteID: stat.LastRouteID, + LastNextHop: stat.LastNextHop, + RoutePolicyVersion: stat.RoutePolicyVersion, + RouteGeneration: stat.RouteGeneration, + RecoveryPolicyFingerprint: stat.RecoveryPolicyFingerprint, + LastFailedRouteID: stat.LastFailedRouteID, + LastFailedRoutePolicyVersion: stat.LastFailedRoutePolicyVersion, + LastFailedRouteGeneration: stat.LastFailedRouteGeneration, + LastError: stat.LastError, + ConsecutiveFailures: stat.ConsecutiveFailures, + StallCount: stat.StallCount, + LastSendDurationMillis: stat.LastSendDurationMillis, + RouteRebuildRecommended: stat.RouteRebuildRecommended, + DegradedFallbackRecommended: stat.DegradedFallbackRecommended, + QualityPreferenceRouteID: stat.QualityPreferenceRouteID, + QualityPreferenceScore: stat.QualityPreferenceScore, + QualityPreferenceRawScore: stat.QualityPreferenceRawScore, + QualityPreferenceReasons: append([]string{}, stat.QualityPreferenceReasons...), + LatencyLe10Millis: stat.LatencyLe10Millis, + LatencyLe100Millis: stat.LatencyLe100Millis, + LatencyLe1000Millis: stat.LatencyLe1000Millis, + LatencyGt1000Millis: stat.LatencyGt1000Millis, + QualityWindowSampleCount: stat.QualityWindowSampleCount, + QualityWindowSuccessCount: stat.QualityWindowSuccessCount, + QualityWindowFailureCount: stat.QualityWindowFailureCount, + QualityWindowSlowCount: stat.QualityWindowSlowCount, + QualityWindowDropCount: stat.QualityWindowDropCount, + QualityWindowAvgLatencyMs: stat.QualityWindowAvgLatencyMs, + QualityWindowLastUpdatedAt: stat.QualityWindowLastUpdatedAt, + } + snapshot.QualityWindowSampleCount += qualityStats.SampleCount + snapshot.QualityWindowFailureCount += qualityStats.FailureCount + snapshot.QualityWindowSlowCount += qualityStats.SlowCount + snapshot.QualityWindowDropCount += qualityStats.DropCount + if queue.Depth >= s.queueCapacity || qualityStats.DropCount > 0 { + snapshot.BackpressureActive = true + } + if (queue.RouteRebuildRecommended || queue.DegradedFallbackRecommended) && qualityStats.FailureCount > 0 { + snapshot.BackpressureActive = true + } + if qualityStats.SlowCount > 0 { + snapshot.SlowChannelCount++ + } + if qualityStats.FailureCount > qualityStats.SuccessCount || (qualityStats.FailureCount > 0 && queue.DegradedFallbackRecommended) { + snapshot.FailingChannelCount++ + } + } + if snapshot.QualityWindowDropCount > 0 { + snapshot.BackpressureActive = true + } + for _, trafficClass := range []string{FabricTrafficClassControl, FabricTrafficClassInteractive, FabricTrafficClassReliable, FabricTrafficClassBulk, FabricTrafficClassDroppable} { + snapshot.RecommendedParallelWindows[trafficClass] = s.recommendedParallelSendWindowForTrafficClassLocked(trafficClass, s.adaptivePolicy.MaxParallelWindow) + } + if len(snapshot.RecommendedParallelWindows) > 0 { + bulkWindow := snapshot.RecommendedParallelWindows[FabricTrafficClassBulk] + interactiveWindow := snapshot.RecommendedParallelWindows[FabricTrafficClassInteractive] + if bulkWindow > 0 && interactiveWindow > 0 && bulkWindow < interactiveWindow { + snapshot.AdaptiveBackpressureActive = true + snapshot.AdaptiveBackpressureReason = "bulk_window_reduced_to_protect_interactive" + } + } + return snapshot +} + +func (s *FabricFlowScheduler) recommendedParallelSendWindowForTrafficClassLocked(trafficClass string, maxWindow int) int { + if maxWindow <= 1 { + return 1 + } + trafficClass = normalizeFabricTrafficClass(trafficClass) + if s == nil { + return maxWindow + } + if maxWindow > s.adaptivePolicy.MaxParallelWindow && s.adaptivePolicy.MaxParallelWindow > 0 { + maxWindow = s.adaptivePolicy.MaxParallelWindow + } + // The public method cannot be called here because Snapshot already holds the + // scheduler mutex. Keep this wrapper intentionally small by mirroring the + // public policy on already-locked state. + globalPressure := s.parallelPressureLocked("") + classPressure := s.parallelPressureLocked(trafficClass) + if fabricTrafficClassPriority(trafficClass) <= fabricTrafficClassPriority(FabricTrafficClassInteractive) { + if classPressure.hasDrops || classPressure.failing > 0 || classPressure.slow > 0 { + return boundedParallelWindow(maxWindow - 1) + } + return maxWindow + } + if trafficClass == FabricTrafficClassReliable { + if classPressure.hasDrops || classPressure.failing > 0 { + return classWindowLimit(s.adaptivePolicy, trafficClass, boundedParallelWindow(maxWindow/2)) + } + if globalPressure.hasDrops || globalPressure.failing+globalPressure.slow > 0 || globalPressure.highPressure { + return classWindowLimit(s.adaptivePolicy, trafficClass, boundedParallelWindow(maxWindow-1)) + } + return maxWindow + } + if classPressure.hasDrops || globalPressure.hasDrops { + return classWindowLimit(s.adaptivePolicy, trafficClass, boundedParallelWindow(maxWindow/2)) + } + if globalPressure.highPressure && globalPressure.interactiveOrControlQueues > 0 { + return classWindowLimit(s.adaptivePolicy, trafficClass, boundedParallelWindow(maxWindow/2)) + } + if globalPressure.highPressure { + return classWindowLimit(s.adaptivePolicy, trafficClass, boundedParallelWindow(maxWindow/2)) + } + if classPressure.failing >= maxWindow || classPressure.slow >= maxWindow { + return classWindowLimit(s.adaptivePolicy, trafficClass, 1) + } + if classPressure.failing+classPressure.slow > 0 { + return classWindowLimit(s.adaptivePolicy, trafficClass, boundedParallelWindow(maxWindow-1)) + } + return maxWindow +} + +func classWindowLimit(policy FabricServiceChannelAdaptivePolicy, trafficClass string, maxWindow int) int { + if policy.ClassWindows != nil { + if value := policy.ClassWindows[normalizeFabricTrafficClass(trafficClass)]; value > 0 && value < maxWindow { + return value + } + } + return maxWindow +} + +func clampFabricWindow(value, minValue, maxValue int) int { + if value < minValue { + return minValue + } + if value > maxValue { + return maxValue + } + return value +} + +func minPositive(a, b int) int { + if a <= 0 { + return b + } + if b <= 0 || a < b { + return a + } + return b +} + +func (s *FabricFlowScheduler) Dropped() uint64 { + if s == nil { + return 0 + } + s.mu.Lock() + defer s.mu.Unlock() + return s.dropped +} + +func (s *FabricFlowScheduler) shardCountValue() int { + if s == nil || s.shardCount <= 0 { + return defaultFabricFlowShardCount + } + return s.shardCount +} + +func (s *FabricFlowScheduler) enqueue(channelID string, trafficClass string) (int, bool) { + if s == nil { + return 0, false + } + s.mu.Lock() + defer s.mu.Unlock() + queue := s.ensureQueueLocked(channelID) + if queue.TrafficClass == "" { + queue.TrafficClass = normalizeFabricTrafficClass(trafficClass) + } + if queue.Depth >= s.queueCapacity { + queue.Dropped++ + s.dropped++ + queue.recordQualitySample(fabricFlowQualitySample{ + At: time.Now().UTC(), + Dropped: true, + }) + return queue.Depth, true + } + queue.Depth++ + queue.Enqueued++ + s.enqueued++ + if queue.Depth > queue.HighWatermark { + queue.HighWatermark = queue.Depth + } + if queue.Depth > s.highWatermark { + s.highWatermark = queue.Depth + } + return queue.Depth, false +} + +func (s *FabricFlowScheduler) dequeue(channelID string, count int) { + if s == nil || count <= 0 { + return + } + s.mu.Lock() + defer s.mu.Unlock() + queue := s.queues[channelID] + if queue == nil { + return + } + if count > queue.Depth { + count = queue.Depth + } + queue.Depth -= count + queue.Dequeued += uint64(count) + queue.Served++ + queue.LastServedAt = time.Now().UTC() + s.dequeued += uint64(count) +} + +func (s *FabricFlowScheduler) sortScheduledBatches(batches []FabricScheduledPacketBatch) { + if len(batches) < 2 { + return + } + s.mu.Lock() + defer s.mu.Unlock() + sort.Slice(batches, func(a, b int) bool { + leftPriority := fabricTrafficClassPriority(batches[a].TrafficClass) + rightPriority := fabricTrafficClassPriority(batches[b].TrafficClass) + if leftPriority != rightPriority { + return leftPriority < rightPriority + } + left := s.queues[batches[a].ChannelID] + right := s.queues[batches[b].ChannelID] + leftStalled := left != nil && (left.RouteRebuildRecommended || left.DegradedFallbackRecommended) + rightStalled := right != nil && (right.RouteRebuildRecommended || right.DegradedFallbackRecommended) + if leftStalled != rightStalled { + return !leftStalled + } + leftServed := uint64(0) + rightServed := uint64(0) + if left != nil { + leftServed = left.Served + } + if right != nil { + rightServed = right.Served + } + if leftServed != rightServed { + return leftServed < rightServed + } + leftServedAt := time.Time{} + rightServedAt := time.Time{} + if left != nil { + leftServedAt = left.LastServedAt + } + if right != nil { + rightServedAt = right.LastServedAt + } + if !leftServedAt.Equal(rightServedAt) { + if leftServedAt.IsZero() { + return true + } + if rightServedAt.IsZero() { + return false + } + return leftServedAt.Before(rightServedAt) + } + return batches[a].ChannelID < batches[b].ChannelID + }) +} + +func normalizeFabricTrafficClass(value string) string { + switch strings.TrimSpace(strings.ToLower(value)) { + case FabricTrafficClassControl: + return FabricTrafficClassControl + case FabricTrafficClassInteractive: + return FabricTrafficClassInteractive + case FabricTrafficClassReliable: + return FabricTrafficClassReliable + case FabricTrafficClassDroppable: + return FabricTrafficClassDroppable + case FabricTrafficClassBulk: + return FabricTrafficClassBulk + default: + return FabricTrafficClassBulk + } +} + +func fabricTrafficClassPriority(value string) int { + switch normalizeFabricTrafficClass(value) { + case FabricTrafficClassControl: + return 0 + case FabricTrafficClassInteractive: + return 1 + case FabricTrafficClassReliable: + return 2 + case FabricTrafficClassBulk: + return 3 + case FabricTrafficClassDroppable: + return 4 + default: + return 3 + } +} + +func (s *FabricFlowScheduler) RoutePreference(channelID string) (preferredRouteID string, avoidRouteID string) { + if s == nil || channelID == "" { + return "", "" + } + s.mu.Lock() + defer s.mu.Unlock() + queue := s.queues[channelID] + if queue == nil { + return "", "" + } + if queue.ConsecutiveFailures == 0 { + preferredRouteID = queue.LastRouteID + } + return preferredRouteID, queue.LastFailedRouteID +} + +func (s *FabricFlowScheduler) RecordRouteSuccess(channelID string, routeID string, nextHop string, duration time.Duration, preferences ...FabricServiceChannelRouteQualityPreference) { + s.RecordRouteSuccessWithProvenance(channelID, routeID, nextHop, duration, fabricFlowRouteProvenance{}, preferences...) +} + +func (s *FabricFlowScheduler) RecordRouteSuccessWithProvenance(channelID string, routeID string, nextHop string, duration time.Duration, provenance fabricFlowRouteProvenance, preferences ...FabricServiceChannelRouteQualityPreference) { + if s == nil || channelID == "" { + return + } + s.mu.Lock() + defer s.mu.Unlock() + queue := s.ensureQueueLocked(channelID) + queue.LastRouteID = routeID + queue.RoutePolicyVersion = strings.TrimSpace(provenance.PolicyVersion) + queue.RouteGeneration = strings.TrimSpace(provenance.Generation) + queue.RecoveryPolicyFingerprint = strings.TrimSpace(provenance.RecoveryPolicyFingerprint) + queue.LastNextHop = nextHop + queue.LastFailedRouteID = "" + queue.LastFailedRoutePolicyVersion = "" + queue.LastFailedRouteGeneration = "" + queue.LastError = "" + queue.ConsecutiveFailures = 0 + queue.LastSendDurationMillis = fabricSendDurationMillis(duration) + queue.SendSuccesses++ + queue.recordLatency(duration) + queue.recordQualitySample(fabricFlowQualitySample{ + At: time.Now().UTC(), + Success: true, + Slow: duration > defaultFabricFlowSlowSendThreshold, + DurationMillis: queue.LastSendDurationMillis, + }) + queue.recordQualityPreference(preferences...) + if duration > defaultFabricFlowSlowSendThreshold { + queue.StallCount++ + queue.RouteRebuildRecommended = true + } else { + queue.RouteRebuildRecommended = false + queue.DegradedFallbackRecommended = false + } +} + +func (s *FabricFlowScheduler) RecordRouteFailure(channelID string, routeID string, nextHop string, err error, duration time.Duration) { + s.RecordRouteFailureWithProvenance(channelID, routeID, nextHop, err, duration, fabricFlowRouteProvenance{}) +} + +func (s *FabricFlowScheduler) RecordRouteFailureWithProvenance(channelID string, routeID string, nextHop string, err error, duration time.Duration, provenance fabricFlowRouteProvenance) { + if s == nil || channelID == "" { + return + } + s.mu.Lock() + defer s.mu.Unlock() + queue := s.ensureQueueLocked(channelID) + queue.LastFailedRouteID = routeID + queue.LastFailedRoutePolicyVersion = strings.TrimSpace(provenance.PolicyVersion) + queue.LastFailedRouteGeneration = strings.TrimSpace(provenance.Generation) + if fp := strings.TrimSpace(provenance.RecoveryPolicyFingerprint); fp != "" { + queue.RecoveryPolicyFingerprint = fp + } + queue.LastNextHop = nextHop + queue.ConsecutiveFailures++ + queue.StallCount++ + queue.LastSendDurationMillis = fabricSendDurationMillis(duration) + queue.SendFailures++ + queue.recordLatency(duration) + queue.recordQualitySample(fabricFlowQualitySample{ + At: time.Now().UTC(), + Failure: true, + Slow: duration > defaultFabricFlowSlowSendThreshold, + DurationMillis: queue.LastSendDurationMillis, + }) + if err != nil { + queue.LastError = err.Error() + } + queue.RouteRebuildRecommended = true + if queue.ConsecutiveFailures >= defaultFabricFlowFailureThreshold { + queue.DegradedFallbackRecommended = true + } + queue.clearQualityPreference() +} + +func (s *FabricFlowScheduler) RecordLocalFallback(channelID string) { + if s == nil || channelID == "" { + return + } + s.mu.Lock() + defer s.mu.Unlock() + queue := s.ensureQueueLocked(channelID) + queue.LastRouteID = "local_gateway" + queue.RoutePolicyVersion = "" + queue.RouteGeneration = "" + queue.RecoveryPolicyFingerprint = "" + queue.LastNextHop = "local_gateway" + queue.LastFailedRouteID = "" + queue.LastFailedRoutePolicyVersion = "" + queue.LastFailedRouteGeneration = "" + queue.LastError = "" + queue.ConsecutiveFailures = 0 + queue.RouteRebuildRecommended = false + queue.DegradedFallbackRecommended = false + queue.clearQualityPreference() +} + +func (s *FabricFlowScheduler) ClearQualityPreferencesNotIn(validRouteIDs map[string]struct{}) { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + for _, queue := range s.queues { + if queue == nil || queue.QualityPreferenceRouteID == "" { + continue + } + if _, ok := validRouteIDs[queue.QualityPreferenceRouteID]; !ok { + queue.clearQualityPreference() + } + } +} + +func (s *FabricFlowScheduler) ClearQualityPreferencesForRoutes(routeIDs map[string]struct{}) { + if s == nil || len(routeIDs) == 0 { + return + } + s.mu.Lock() + defer s.mu.Unlock() + for _, queue := range s.queues { + if queue == nil || queue.QualityPreferenceRouteID == "" { + continue + } + if _, ok := routeIDs[queue.QualityPreferenceRouteID]; ok { + queue.clearQualityPreference() + } + } +} + +func (q *fabricFlowQueue) recordLatency(duration time.Duration) { + if q == nil { + return + } + millis := fabricSendDurationMillis(duration) + switch { + case millis <= 10: + q.LatencyLe10Millis++ + case millis <= 100: + q.LatencyLe100Millis++ + case millis <= 1000: + q.LatencyLe1000Millis++ + default: + q.LatencyGt1000Millis++ + } +} + +func (q *fabricFlowQueue) recordQualitySample(sample fabricFlowQualitySample) { + if q == nil { + return + } + if sample.At.IsZero() { + sample.At = time.Now().UTC() + } + q.QualityWindow = append(q.QualityWindow, sample) + if len(q.QualityWindow) > defaultFabricFlowQualityWindowCapacity { + keepFrom := len(q.QualityWindow) - defaultFabricFlowQualityWindowCapacity + copy(q.QualityWindow, q.QualityWindow[keepFrom:]) + q.QualityWindow = q.QualityWindow[:defaultFabricFlowQualityWindowCapacity] + } +} + +func (q *fabricFlowQueue) qualityWindowStats() fabricFlowQualityWindowStats { + stats := fabricFlowQualityWindowStats{} + if q == nil || len(q.QualityWindow) == 0 { + return stats + } + var latencyTotal int64 + var latencySamples int64 + for _, sample := range q.QualityWindow { + stats.SampleCount++ + if sample.Success { + stats.SuccessCount++ + } + if sample.Failure { + stats.FailureCount++ + } + if sample.Slow { + stats.SlowCount++ + } + if sample.Dropped { + stats.DropCount++ + } + if sample.DurationMillis > 0 { + latencyTotal += sample.DurationMillis + latencySamples++ + } + if sample.At.After(stats.LastUpdatedAt) { + stats.LastUpdatedAt = sample.At + } + } + if latencySamples > 0 { + stats.AvgLatencyMs = latencyTotal / latencySamples + } + return stats +} + +func (q *fabricFlowQueue) recordQualityPreference(preferences ...FabricServiceChannelRouteQualityPreference) { + if q == nil { + return + } + if len(preferences) == 0 || strings.TrimSpace(preferences[0].RouteID) == "" || preferences[0].ScoreAdjustment <= 0 { + q.clearQualityPreference() + return + } + preference := preferences[0] + q.QualityPreferenceRouteID = preference.RouteID + q.QualityPreferenceScore = preference.ScoreAdjustment + q.QualityPreferenceRawScore = preference.RawScoreAdjustment + if q.QualityPreferenceRawScore <= 0 { + q.QualityPreferenceRawScore = preference.ScoreAdjustment + } + q.QualityPreferenceReasons = dedupeStrings(preference.Reasons) +} + +func (q *fabricFlowQueue) clearQualityPreference() { + if q == nil { + return + } + q.QualityPreferenceRouteID = "" + q.QualityPreferenceScore = 0 + q.QualityPreferenceRawScore = 0 + q.QualityPreferenceReasons = nil +} + +func fabricSendDurationMillis(duration time.Duration) int64 { + if duration <= 0 { + return 0 + } + millis := duration.Milliseconds() + if millis == 0 { + return 1 + } + return millis +} + +func (s *FabricFlowScheduler) ensureQueueLocked(channelID string) *fabricFlowQueue { + if s.queues == nil { + s.queues = map[string]*fabricFlowQueue{} + } + queue := s.queues[channelID] + if queue == nil { + queue = &fabricFlowQueue{} + s.queues[channelID] = queue + } + return queue +} + +type LocalPacketTransport struct { + Inbox *FabricPacketInbox + VPNConnectionID string +} + +type AdaptivePacketTransport struct { + Primary PacketTransport + Fallback PacketTransport + PrimaryTimeout time.Duration + lastReceive atomic.Int32 +} + +const ( + adaptiveTransportPrimary int32 = iota + adaptiveTransportFallback +) + +func (t *LocalPacketTransport) SendGatewayPacketBatch(_ context.Context, packets [][]byte) error { + packets = cleanPacketBatch(packets) + if len(packets) == 0 { + return nil + } + if t == nil || t.Inbox == nil { + return mesh.ErrForwardRuntimeUnavailable + } + return t.Inbox.DeliverLocalPacketBatch(t.VPNConnectionID, FabricDirectionGatewayToClient, packets) +} + +func (t *LocalPacketTransport) ReceiveGatewayPacketBatch(ctx context.Context, timeout time.Duration) ([][]byte, error) { + if t == nil || t.Inbox == nil { + return nil, mesh.ErrForwardRuntimeUnavailable + } + return t.Inbox.Receive(ctx, t.VPNConnectionID, FabricDirectionClientToGateway, timeout) +} + +func (t *AdaptivePacketTransport) SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error { + packets = cleanPacketBatch(packets) + if len(packets) == 0 { + return nil + } + if t == nil || (t.Primary == nil && t.Fallback == nil) { + return mesh.ErrForwardRuntimeUnavailable + } + preferred, alternate := t.preferredSendOrder() + if preferred != nil { + if err := preferred.SendGatewayPacketBatch(ctx, packets); err == nil { + return nil + } else if alternate == nil { + return err + } + } + if alternate != nil { + return alternate.SendGatewayPacketBatch(ctx, packets) + } + return mesh.ErrForwardRuntimeUnavailable +} + +func (t *AdaptivePacketTransport) ReceiveGatewayPacketBatch(ctx context.Context, timeout time.Duration) ([][]byte, error) { + if t == nil || (t.Primary == nil && t.Fallback == nil) { + return nil, mesh.ErrForwardRuntimeUnavailable + } + if t.Primary == nil { + t.lastReceive.Store(adaptiveTransportFallback) + return t.Fallback.ReceiveGatewayPacketBatch(ctx, timeout) + } + if t.Fallback == nil { + t.lastReceive.Store(adaptiveTransportPrimary) + return t.Primary.ReceiveGatewayPacketBatch(ctx, timeout) + } + primaryTimeout := t.PrimaryTimeout + if primaryTimeout <= 0 { + primaryTimeout = 50 * time.Millisecond + } + if timeout > 0 && primaryTimeout > timeout { + primaryTimeout = timeout + } + packets, err := t.Primary.ReceiveGatewayPacketBatch(ctx, primaryTimeout) + if err == nil && len(packets) > 0 { + t.lastReceive.Store(adaptiveTransportPrimary) + return packets, nil + } + if err != nil && timeout <= 0 { + return nil, err + } + fallbackTimeout := timeout + if timeout > primaryTimeout { + fallbackTimeout = timeout - primaryTimeout + } + packets, fallbackErr := t.Fallback.ReceiveGatewayPacketBatch(ctx, fallbackTimeout) + if fallbackErr == nil && len(packets) > 0 { + t.lastReceive.Store(adaptiveTransportFallback) + } + if fallbackErr != nil { + return nil, fallbackErr + } + return packets, nil +} + +func (t *AdaptivePacketTransport) preferredSendOrder() (PacketTransport, PacketTransport) { + if t == nil { + return nil, nil + } + if t.lastReceive.Load() == adaptiveTransportFallback { + return t.Fallback, t.Primary + } + return t.Primary, t.Fallback +} + +func (t *FabricPacketTransport) SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error { + packets = cleanPacketBatch(packets) + if len(packets) == 0 { + return nil + } + if t == nil || t.ForwardTransport == nil { + return mesh.ErrForwardRuntimeUnavailable + } + if t.ClusterID == "" || t.VPNConnectionID == "" || t.RouteID == "" || t.LocalNodeID == "" || t.RemoteNodeID == "" { + return errors.New("fabric packet transport route identity is incomplete") + } + nextHop := t.NextHopNodeID + if nextHop == "" { + nextHop = t.RemoteNodeID + } + envelopeCurrentHop := nextHop + envelopeNextHop := nextHopAfter(t.RoutePath, envelopeCurrentHop, t.RemoteNodeID) + direction := t.SendDirection + if direction == "" { + direction = FabricDirectionGatewayToClient + } + envelope, err := mesh.NewProductionVPNPacketBatchEnvelope(mesh.ProductionVPNPacketEnvelopeInput{ + RouteID: t.RouteID, + ClusterID: t.ClusterID, + SourceNodeID: t.LocalNodeID, + DestinationNodeID: t.RemoteNodeID, + CurrentHopNodeID: envelopeCurrentHop, + NextHopNodeID: envelopeNextHop, + RoutePath: t.RoutePath, + VPNConnectionID: t.VPNConnectionID, + Direction: direction, + Packets: packets, + }) + if err != nil { + return err + } + _, err = t.ForwardTransport.SendProduction(ctx, nextHop, envelope) + return err +} + +func (i *FabricClientPacketIngress) SendClientPacketBatch(ctx context.Context, clusterID string, vpnConnectionID string, packets [][]byte) error { + return i.SendClientPacketBatchWithTrafficClass(ctx, clusterID, vpnConnectionID, "", packets) +} + +func (i *FabricClientPacketIngress) SendClientPacketBatchWithTrafficClass(ctx context.Context, clusterID string, vpnConnectionID string, trafficClass string, packets [][]byte) error { + packets = cleanPacketBatch(packets) + if len(packets) == 0 { + return nil + } + if i == nil { + return mesh.ErrForwardRuntimeUnavailable + } + i.recordSendBatch(len(packets)) + scheduler := i.flowScheduler() + droppedBefore := scheduler.Dropped() + scheduled := scheduler.ScheduleClientPacketsForConnectionClass(vpnConnectionID, trafficClass, packets) + droppedAfter := scheduler.Dropped() + if droppedAfter > droppedBefore { + i.recordFlowDropped(droppedAfter - droppedBefore) + } + if len(scheduled) == 0 { + i.recordError(mesh.ErrSyntheticRelayQueueFull) + return mesh.ErrSyntheticRelayQueueFull + } + maxParallel := scheduler.RecommendedParallelSendWindowForTrafficClass(trafficClass, i.maxParallelFlowSends()) + if maxParallel > 1 && len(scheduled) > 1 { + return i.sendScheduledClientPacketBatchesParallel(ctx, clusterID, vpnConnectionID, scheduled, maxParallel) + } + for _, batch := range scheduled { + i.recordFlowBatch(len(batch.Packets)) + if err := i.sendScheduledClientPacketBatch(ctx, clusterID, vpnConnectionID, batch); err != nil { + return err + } + } + return nil +} + +func (i *FabricClientPacketIngress) sendScheduledClientPacketBatchesParallel(ctx context.Context, clusterID string, vpnConnectionID string, scheduled []FabricScheduledPacketBatch, maxParallel int) error { + if maxParallel <= 1 || len(scheduled) <= 1 { + for _, batch := range scheduled { + i.recordFlowBatch(len(batch.Packets)) + if err := i.sendScheduledClientPacketBatch(ctx, clusterID, vpnConnectionID, batch); err != nil { + return err + } + } + return nil + } + if maxParallel > len(scheduled) { + maxParallel = len(scheduled) + } + i.recordFlowParallel() + sem := make(chan struct{}, maxParallel) + var wg sync.WaitGroup + var errMu sync.Mutex + var firstErr error + for _, scheduledBatch := range scheduled { + batch := scheduledBatch + i.recordFlowBatch(len(batch.Packets)) + sem <- struct{}{} + wg.Add(1) + go func() { + defer wg.Done() + defer func() { <-sem }() + if err := i.sendScheduledClientPacketBatch(ctx, clusterID, vpnConnectionID, batch); err != nil { + errMu.Lock() + if firstErr == nil { + firstErr = err + } + errMu.Unlock() + } + }() + } + wg.Wait() + return firstErr +} + +func (i *FabricClientPacketIngress) PreferClientRoute(routeID string) { + routeID = strings.TrimSpace(routeID) + if i == nil || routeID == "" { + return + } + i.mu.Lock() + defer i.mu.Unlock() + i.lastSelectedRouteID = routeID + i.lastSelectedNextHop = "" +} + +func (i *FabricClientPacketIngress) sendScheduledClientPacketBatch(ctx context.Context, clusterID string, vpnConnectionID string, batch FabricScheduledPacketBatch) error { + scheduler := i.flowScheduler() + scheduler.BeginSend(batch.ChannelID) + defer scheduler.EndSend(batch.ChannelID) + defer scheduler.Complete(batch) + packets := cleanPacketBatch(batch.Packets) + if len(packets) == 0 { + return nil + } + candidates := i.routeCandidatesForChannel(clusterID, batch.ChannelID) + if len(candidates) == 0 && i.localGatewayReady(vpnConnectionID) { + if err := i.inbox().DeliverLocalPacketBatch(vpnConnectionID, FabricDirectionClientToGateway, packets); err != nil { + i.recordError(err) + return err + } + scheduler.RecordLocalFallback(batch.ChannelID) + i.recordLocalFallback() + return nil + } + if len(candidates) == 0 { + i.recordError(mesh.ErrRouteNotFound) + return mesh.ErrRouteNotFound + } + transport := i.forwardTransport() + if transport == nil { + i.recordError(mesh.ErrForwardRuntimeUnavailable) + return mesh.ErrForwardRuntimeUnavailable + } + var lastErr error + for _, candidate := range candidates { + startedAt := time.Now() + i.recordRouteAttempt() + envelopeCurrentHop := candidate.NextHop + envelopeNextHop := nextHopAfter(candidate.Route.Hops, envelopeCurrentHop, candidate.Route.DestinationNodeID) + envelope, err := mesh.NewProductionVPNPacketBatchEnvelope(mesh.ProductionVPNPacketEnvelopeInput{ + RouteID: candidate.Route.RouteID, + ClusterID: candidate.Route.ClusterID, + SourceNodeID: candidate.Route.SourceNodeID, + DestinationNodeID: candidate.Route.DestinationNodeID, + CurrentHopNodeID: envelopeCurrentHop, + NextHopNodeID: envelopeNextHop, + RoutePath: candidate.Route.Hops, + TTL: candidate.Route.MaxTTL, + ExpiresAt: candidate.Route.ExpiresAt, + VPNConnectionID: vpnConnectionID, + Direction: FabricDirectionClientToGateway, + Packets: packets, + }) + if err != nil { + lastErr = err + scheduler.RecordRouteFailureWithProvenance(batch.ChannelID, candidate.Route.RouteID, candidate.NextHop, err, time.Since(startedAt), i.routeProvenanceFor(candidate.Route)) + i.recordRouteFailure(err) + continue + } + if _, err = transport.SendProduction(ctx, candidate.NextHop, envelope); err != nil { + lastErr = err + scheduler.RecordRouteFailureWithProvenance(batch.ChannelID, candidate.Route.RouteID, candidate.NextHop, err, time.Since(startedAt), i.routeProvenanceFor(candidate.Route)) + i.recordRouteFailure(err) + continue + } + preference, _ := i.routeQualityPreference(candidate.Route.RouteID) + scheduler.RecordRouteSuccessWithProvenance(batch.ChannelID, candidate.Route.RouteID, candidate.NextHop, time.Since(startedAt), i.routeProvenanceFor(candidate.Route), preference) + i.recordRouteSuccess(candidate.Route.RouteID, candidate.NextHop) + return nil + } + if i.localGatewayReady(vpnConnectionID) { + if err := i.inbox().DeliverLocalPacketBatch(vpnConnectionID, FabricDirectionClientToGateway, packets); err != nil { + i.recordError(err) + return err + } + scheduler.RecordLocalFallback(batch.ChannelID) + i.recordLocalFallback() + return nil + } + if lastErr == nil { + lastErr = mesh.ErrRouteNotFound + } + i.recordError(lastErr) + return lastErr +} + +func (i *FabricClientPacketIngress) ReceiveClientPacketBatch(ctx context.Context, clusterID string, vpnConnectionID string, timeout time.Duration) ([][]byte, error) { + inbox := i.inbox() + if i == nil || inbox == nil { + return nil, mesh.ErrForwardRuntimeUnavailable + } + if _, _, ok := i.selectRoute(clusterID); !ok { + if !i.localGatewayReady(vpnConnectionID) { + i.recordReceiveEmpty() + return nil, mesh.ErrRouteNotFound + } + } + packets, err := inbox.Receive(ctx, vpnConnectionID, FabricDirectionGatewayToClient, timeout) + if err != nil { + i.recordError(err) + return nil, err + } + if len(packets) == 0 { + i.recordReceiveEmpty() + return nil, nil + } + i.recordReceiveBatch(len(packets)) + return packets, nil +} + +func (i *FabricClientPacketIngress) localGatewayReady(vpnConnectionID string) bool { + if i == nil || i.inbox() == nil || vpnConnectionID == "" { + return false + } + localGateway := i.localGateway() + return localGateway != nil && localGateway(vpnConnectionID) +} + +func (i *FabricClientPacketIngress) selectRoute(clusterID string) (mesh.SyntheticRoute, string, bool) { + candidates := i.routeCandidates(clusterID) + if len(candidates) == 0 { + return mesh.SyntheticRoute{}, "", false + } + return candidates[0].Route, candidates[0].NextHop, true +} + +type fabricClientRouteCandidate struct { + Route mesh.SyntheticRoute + NextHop string +} + +func (i *FabricClientPacketIngress) routeCandidates(clusterID string) []fabricClientRouteCandidate { + return i.routeCandidatesWithPreference(clusterID, i.lastRouteID(), "") +} + +func (i *FabricClientPacketIngress) routeCandidatesForChannel(clusterID string, channelID string) []fabricClientRouteCandidate { + preferredRouteID, avoidRouteID := i.flowScheduler().RoutePreference(channelID) + if preferredRouteID == "" && avoidRouteID == "" { + preferredRouteID = i.lastRouteID() + } + return i.routeCandidatesWithPreference(clusterID, preferredRouteID, avoidRouteID) +} + +func (i *FabricClientPacketIngress) routeCandidatesWithPreference(clusterID string, preferredRouteID string, avoidRouteID string) []fabricClientRouteCandidate { + routesFunc := i.routesFunc() + if i == nil || routesFunc == nil { + return nil + } + if clusterID == "" { + clusterID = i.ClusterID + } + now := time.Now().UTC() + var preferred []fabricClientRouteCandidate + var alternates []fabricClientRouteCandidate + var deferred []fabricClientRouteCandidate + manager := i.routeManager() + if preferredRouteID != "" && manager.isWithdrawn(preferredRouteID) { + if replacementRouteID := manager.replacementRouteID(preferredRouteID); replacementRouteID != "" { + preferredRouteID = replacementRouteID + } else { + if avoidRouteID == "" { + avoidRouteID = preferredRouteID + } + preferredRouteID = "" + } + } + for _, route := range routesFunc() { + if route.ClusterID != clusterID || route.SourceNodeID != i.LocalNodeID || !containsString(route.AllowedChannels, mesh.ProductionChannelVPNPacket) { + continue + } + if manager.isWithdrawn(route.RouteID) { + continue + } + if !route.ExpiresAt.IsZero() && !route.ExpiresAt.After(now) { + continue + } + nextHop := nextHopAfter(route.Hops, i.LocalNodeID, route.DestinationNodeID) + if nextHop == "" || nextHop == i.LocalNodeID { + continue + } + candidate := fabricClientRouteCandidate{Route: route, NextHop: nextHop} + if preferredRouteID != "" && route.RouteID == preferredRouteID { + preferred = append(preferred, candidate) + } else if avoidRouteID != "" && route.RouteID == avoidRouteID { + deferred = append(deferred, candidate) + } else { + alternates = append(alternates, candidate) + } + } + out := append(preferred, alternates...) + out = i.applyRouteQualityPreferences(out, preferredRouteID) + return append(out, deferred...) +} + +func (i *FabricClientPacketIngress) applyRouteQualityPreferences(candidates []fabricClientRouteCandidate, preferredRouteID string) []fabricClientRouteCandidate { + if len(candidates) < 2 { + return candidates + } + preferences := i.routeQualityPreferences() + if len(preferences) == 0 { + return candidates + } + preferredScore := 0 + if preferredRouteID != "" { + if preference, ok := preferences[preferredRouteID]; ok { + preferredScore = preference.ScoreAdjustment + } + } + bestIndex := -1 + bestScore := 0 + for index, candidate := range candidates { + preference, ok := preferences[candidate.Route.RouteID] + if !ok || preference.ScoreAdjustment <= 0 { + continue + } + if bestIndex == -1 || preference.ScoreAdjustment > bestScore { + bestIndex = index + bestScore = preference.ScoreAdjustment + } + } + if bestIndex <= 0 || bestScore < preferredScore+defaultFabricRouteQualitySwitchThreshold { + return candidates + } + out := make([]fabricClientRouteCandidate, 0, len(candidates)) + out = append(out, candidates[bestIndex]) + out = append(out, candidates[:bestIndex]...) + out = append(out, candidates[bestIndex+1:]...) + return out +} + +func (t *FabricPacketTransport) ReceiveGatewayPacketBatch(ctx context.Context, timeout time.Duration) ([][]byte, error) { + if t == nil || t.Inbox == nil { + return nil, mesh.ErrForwardRuntimeUnavailable + } + direction := t.ReceiveDirection + if direction == "" { + direction = FabricDirectionClientToGateway + } + return t.Inbox.Receive(ctx, t.VPNConnectionID, direction, timeout) +} + +type FabricPacketInbox struct { + capacity int + mu sync.Mutex + queues map[string]chan mesh.VPNPacketBatchPayload + dropped uint64 +} + +func NewFabricPacketInbox(capacity int) *FabricPacketInbox { + if capacity <= 0 { + capacity = 4096 + } + return &FabricPacketInbox{ + capacity: capacity, + queues: map[string]chan mesh.VPNPacketBatchPayload{}, + } +} + +func (i *FabricPacketInbox) DeliverProductionEnvelope(_ context.Context, envelope mesh.ProductionEnvelope) error { + if i == nil { + return mesh.ErrForwardRuntimeUnavailable + } + if envelope.ChannelClass != mesh.ProductionChannelVPNPacket || envelope.MessageType != mesh.ProductionMessageVPNPacketBatch { + return nil + } + payload, err := mesh.DecodeProductionVPNPacketBatch(envelope) + if err != nil { + return err + } + payload.Packets = cleanPacketBatch(payload.Packets) + if len(payload.Packets) == 0 { + return nil + } + return i.enqueue(payload) +} + +func (i *FabricPacketInbox) DeliverLocalPacketBatch(vpnConnectionID, direction string, packets [][]byte) error { + if i == nil { + return mesh.ErrForwardRuntimeUnavailable + } + if vpnConnectionID == "" || direction == "" { + return mesh.ErrForwardEnvelopeInvalid + } + packets = cleanPacketBatch(packets) + if len(packets) == 0 { + return nil + } + return i.enqueue(mesh.VPNPacketBatchPayload{ + SchemaVersion: "rap.vpn_packet_batch.v1", + VPNConnectionID: vpnConnectionID, + Direction: direction, + Packets: packets, + SentAt: time.Now().UTC(), + }) +} + +func (i *FabricPacketInbox) Receive(ctx context.Context, vpnConnectionID, direction string, timeout time.Duration) ([][]byte, error) { + if i == nil { + return nil, mesh.ErrForwardRuntimeUnavailable + } + if vpnConnectionID == "" || direction == "" { + return nil, mesh.ErrForwardEnvelopeInvalid + } + if timeout <= 0 { + timeout = 25 * time.Second + } + timer := time.NewTimer(timeout) + defer timer.Stop() + queue := i.queue(vpnConnectionID, direction) + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + return nil, nil + case payload := <-queue: + packets := cleanPacketBatch(payload.Packets) + if len(packets) == 0 { + continue + } + return packets, nil + } + } +} + +func (i *FabricPacketInbox) enqueue(payload mesh.VPNPacketBatchPayload) error { + queue := i.queue(payload.VPNConnectionID, payload.Direction) + select { + case queue <- payload: + default: + i.mu.Lock() + i.dropped++ + i.mu.Unlock() + } + return nil +} + +func (i *FabricPacketInbox) queue(vpnConnectionID, direction string) chan mesh.VPNPacketBatchPayload { + key := vpnConnectionID + "\x00" + direction + i.mu.Lock() + defer i.mu.Unlock() + if i.queues == nil { + i.queues = map[string]chan mesh.VPNPacketBatchPayload{} + } + queue, ok := i.queues[key] + if !ok { + queue = make(chan mesh.VPNPacketBatchPayload, i.capacity) + i.queues[key] = queue + } + return queue +} + +func (i *FabricPacketInbox) Dropped() uint64 { + if i == nil { + return 0 + } + i.mu.Lock() + defer i.mu.Unlock() + return i.dropped +} + +type FabricPacketInboxSnapshot struct { + SchemaVersion string `json:"schema_version"` + Capacity int `json:"capacity"` + Dropped uint64 `json:"dropped"` + QueueDepths map[string]int `json:"queue_depths"` + QueueCount int `json:"queue_count"` +} + +func (i *FabricPacketInbox) Snapshot() FabricPacketInboxSnapshot { + snapshot := FabricPacketInboxSnapshot{ + SchemaVersion: "rap.fabric_packet_inbox.v1", + QueueDepths: map[string]int{}, + } + if i == nil { + return snapshot + } + i.mu.Lock() + defer i.mu.Unlock() + snapshot.Capacity = i.capacity + snapshot.Dropped = i.dropped + snapshot.QueueCount = len(i.queues) + for key, queue := range i.queues { + snapshot.QueueDepths[strings.ReplaceAll(key, "\x00", ":")] = len(queue) + } + return snapshot +} + +type FabricClientPacketIngressSnapshot struct { + SchemaVersion string `json:"schema_version"` + ClusterID string `json:"cluster_id,omitempty"` + LocalNodeID string `json:"local_node_id,omitempty"` + RouteCandidateCount int `json:"route_candidate_count"` + LastSelectedRouteID string `json:"last_selected_route_id,omitempty"` + LastSelectedNextHop string `json:"last_selected_next_hop,omitempty"` + LastError string `json:"last_error,omitempty"` + SendBatches uint64 `json:"send_batches"` + SendPackets uint64 `json:"send_packets"` + SendRouteAttempts uint64 `json:"send_route_attempts"` + SendRouteFailures uint64 `json:"send_route_failures"` + SendFallbackLocal uint64 `json:"send_fallback_local"` + SendFlowBatches uint64 `json:"send_flow_batches"` + SendFlowPackets uint64 `json:"send_flow_packets"` + SendFlowDropped uint64 `json:"send_flow_dropped"` + SendFlowParallel uint64 `json:"send_flow_parallel_batches"` + MaxParallelFlowSends int `json:"max_parallel_flow_sends"` + RecommendedParallelFlowSends int `json:"recommended_parallel_flow_sends"` + ReceiveBatches uint64 `json:"receive_batches"` + ReceivePackets uint64 `json:"receive_packets"` + ReceiveEmpty uint64 `json:"receive_empty"` + Inbox FabricPacketInboxSnapshot `json:"inbox"` + FlowScheduler FabricFlowSchedulerSnapshot `json:"flow_scheduler"` + RouteManager FabricServiceChannelRouteManager `json:"route_manager"` + RouteManagerTransition FabricServiceChannelRouteManagerTransition `json:"route_manager_transition"` + RouteQualityPreferenceCount int `json:"route_quality_preference_count"` + RouteQualityPreferences []FabricServiceChannelRouteQualityPreference `json:"route_quality_preferences,omitempty"` +} + +func (i *FabricClientPacketIngress) Snapshot(clusterID string) FabricClientPacketIngressSnapshot { + snapshot := FabricClientPacketIngressSnapshot{ + SchemaVersion: "rap.fabric_service_channel_route_manager.v1", + } + if i == nil { + return snapshot + } + i.mu.Lock() + snapshot.ClusterID = firstNonEmpty(clusterID, i.ClusterID) + snapshot.LocalNodeID = i.LocalNodeID + snapshot.LastSelectedRouteID = i.lastSelectedRouteID + snapshot.LastSelectedNextHop = i.lastSelectedNextHop + snapshot.LastError = i.lastError + snapshot.SendBatches = i.sendBatches + snapshot.SendPackets = i.sendPackets + snapshot.SendRouteAttempts = i.sendRouteAttempts + snapshot.SendRouteFailures = i.sendRouteFailures + snapshot.SendFallbackLocal = i.sendFallbackLocal + snapshot.SendFlowBatches = i.sendFlowBatches + snapshot.SendFlowPackets = i.sendFlowPackets + snapshot.SendFlowDropped = i.sendFlowDropped + snapshot.SendFlowParallel = i.sendFlowParallel + snapshot.MaxParallelFlowSends = i.maxParallelFlowSendsLocked() + snapshot.ReceiveBatches = i.receiveBatches + snapshot.ReceivePackets = i.receivePackets + snapshot.ReceiveEmpty = i.receiveEmpty + snapshot.RouteManager = i.RouteManager.snapshot() + snapshot.RouteManagerTransition = i.RouteManagerTransition.snapshot() + snapshot.RouteQualityPreferenceCount = len(i.RouteQualityPreferences) + snapshot.RouteQualityPreferences = routeQualityPreferenceSlice(i.RouteQualityPreferences) + recoveryPolicyFingerprint := strings.TrimSpace(i.RecoveryPolicyFingerprint) + i.mu.Unlock() + snapshot.RouteCandidateCount = len(i.routeCandidates(snapshot.ClusterID)) + snapshot.Inbox = i.inbox().Snapshot() + snapshot.FlowScheduler = i.flowScheduler().Snapshot() + annotateFabricFlowSchedulerProvenance(&snapshot.FlowScheduler, i.routeProvenance(snapshot.ClusterID), recoveryPolicyFingerprint) + snapshot.RecommendedParallelFlowSends = i.flowScheduler().RecommendedParallelSendWindowForTrafficClass(FabricTrafficClassBulk, snapshot.MaxParallelFlowSends) + return snapshot +} + +func (i *FabricClientPacketIngress) UpdateRuntime(forwardTransport mesh.ProductionForwardTransport, inbox *FabricPacketInbox, clusterID string, localNodeID string, localGateway func(string) bool, routes func() []mesh.SyntheticRoute, recoveryPolicyFingerprint string, adaptivePolicies ...FabricServiceChannelAdaptivePolicy) { + if i == nil { + return + } + i.mu.Lock() + defer i.mu.Unlock() + i.ForwardTransport = forwardTransport + i.Inbox = inbox + i.ClusterID = clusterID + i.LocalNodeID = localNodeID + i.LocalGateway = localGateway + i.Routes = routes + i.RecoveryPolicyFingerprint = strings.TrimSpace(recoveryPolicyFingerprint) + adaptivePolicy := defaultFabricServiceChannelAdaptivePolicy() + if len(adaptivePolicies) > 0 { + adaptivePolicy = adaptivePolicies[0] + } + i.AdaptivePolicyFingerprint = strings.TrimSpace(adaptivePolicy.Fingerprint) + if i.FlowScheduler == nil { + i.FlowScheduler = NewFabricFlowScheduler(0, 0) + } + i.FlowScheduler.ConfigureAdaptivePolicy(adaptivePolicy) + if i.MaxParallelFlowSends <= 0 { + i.MaxParallelFlowSends = defaultFabricFlowParallelSendWindow + } +} + +type fabricRouteProvenance struct { + PolicyVersion string + Generation string +} + +type fabricFlowRouteProvenance struct { + PolicyVersion string + Generation string + RecoveryPolicyFingerprint string +} + +func (i *FabricClientPacketIngress) routeProvenanceFor(route mesh.SyntheticRoute) fabricFlowRouteProvenance { + policyVersion := strings.TrimSpace(route.PolicyVersion) + if policyVersion == "" { + policyVersion = strings.TrimSpace(route.RouteVersion) + } + generation := policyVersion + return fabricFlowRouteProvenance{ + PolicyVersion: policyVersion, + Generation: generation, + RecoveryPolicyFingerprint: strings.TrimSpace(i.RecoveryPolicyFingerprint), + } +} + +func (i *FabricClientPacketIngress) routeProvenance(clusterID string) map[string]fabricRouteProvenance { + out := map[string]fabricRouteProvenance{} + routesFunc := i.routesFunc() + if i == nil || routesFunc == nil { + return out + } + localNodeID := strings.TrimSpace(i.LocalNodeID) + for _, route := range routesFunc() { + if strings.TrimSpace(route.RouteID) == "" { + continue + } + if clusterID != "" && route.ClusterID != clusterID { + continue + } + if localNodeID != "" && route.SourceNodeID != localNodeID { + continue + } + policyVersion := strings.TrimSpace(route.PolicyVersion) + if policyVersion == "" { + policyVersion = strings.TrimSpace(route.RouteVersion) + } + out[route.RouteID] = fabricRouteProvenance{ + PolicyVersion: policyVersion, + Generation: policyVersion, + } + } + return out +} + +func annotateFabricFlowSchedulerProvenance(snapshot *FabricFlowSchedulerSnapshot, routes map[string]fabricRouteProvenance, recoveryPolicyFingerprint string) { + if snapshot == nil || len(snapshot.ChannelStats) == 0 { + return + } + for channelID, stat := range snapshot.ChannelStats { + if recoveryPolicyFingerprint != "" { + stat.RecoveryPolicyFingerprint = recoveryPolicyFingerprint + } + if route, ok := routes[stat.LastRouteID]; ok { + stat.RoutePolicyVersion = route.PolicyVersion + stat.RouteGeneration = route.Generation + } + if route, ok := routes[stat.LastFailedRouteID]; ok { + stat.LastFailedRoutePolicyVersion = route.PolicyVersion + stat.LastFailedRouteGeneration = route.Generation + if stat.RoutePolicyVersion == "" { + stat.RoutePolicyVersion = route.PolicyVersion + } + if stat.RouteGeneration == "" { + stat.RouteGeneration = route.Generation + } + } + snapshot.ChannelStats[channelID] = stat + } +} + +func (i *FabricClientPacketIngress) UpdateRouteManager(decisions []FabricServiceChannelRouteManagerDecision, generation string, observedAt time.Time) { + if i == nil { + return + } + manager := NewFabricServiceChannelRouteManager(decisions, generation, observedAt) + i.mu.Lock() + transition := newFabricServiceChannelRouteManagerTransition(i.RouteManager, manager, observedAt) + i.RouteManager = manager + if i.lastSelectedRouteID != "" && manager.isWithdrawn(i.lastSelectedRouteID) { + clearedRouteID := i.lastSelectedRouteID + transition.ClearedSelectedRouteID = clearedRouteID + i.lastSelectedRouteID = manager.replacementRouteID(clearedRouteID) + i.lastSelectedNextHop = "" + } + i.RouteManagerTransition = transition + withdrawnRoutes := manager.withdrawnRouteIDs() + scheduler := i.FlowScheduler + i.mu.Unlock() + if scheduler != nil { + scheduler.ClearQualityPreferencesForRoutes(withdrawnRoutes) + } +} + +func (i *FabricClientPacketIngress) UpdateRouteQualityPreferences(preferences []FabricServiceChannelRouteQualityPreference, observedAt time.Time) { + if i == nil { + return + } + now := observedAt.UTC() + if now.IsZero() { + now = time.Now().UTC() + } + next := map[string]FabricServiceChannelRouteQualityPreference{} + for _, preference := range preferences { + preference.RouteID = strings.TrimSpace(preference.RouteID) + preference.FeedbackStatus = strings.TrimSpace(preference.FeedbackStatus) + if preference.RouteID == "" || preference.ScoreAdjustment <= 0 { + continue + } + if preference.FeedbackStatus != "" && preference.FeedbackStatus != "healthy" { + continue + } + if preference.ExpiresAt != "" { + expiresAt, err := time.Parse(time.RFC3339Nano, preference.ExpiresAt) + if err == nil && !expiresAt.After(now) { + continue + } + } + if preference.RawScoreAdjustment <= 0 { + preference.RawScoreAdjustment = preference.ScoreAdjustment + } + preference.Reasons = dedupeStrings(preference.Reasons) + next[preference.RouteID] = preference + } + i.mu.Lock() + i.RouteQualityPreferences = next + scheduler := i.FlowScheduler + i.mu.Unlock() + validRouteIDs := make(map[string]struct{}, len(next)) + for routeID := range next { + validRouteIDs[routeID] = struct{}{} + } + if scheduler != nil { + scheduler.ClearQualityPreferencesNotIn(validRouteIDs) + } +} + +func NewFabricServiceChannelRouteManager(decisions []FabricServiceChannelRouteManagerDecision, generation string, observedAt time.Time) FabricServiceChannelRouteManager { + manager := FabricServiceChannelRouteManager{ + SchemaVersion: "rap.fabric_service_channel_route_manager_rebuild.v1", + Generation: strings.TrimSpace(generation), + Decisions: []FabricServiceChannelRouteManagerDecision{}, + withdrawnRoutes: map[string]FabricServiceChannelRouteManagerDecision{}, + replacements: map[string]string{}, + } + if !observedAt.IsZero() { + manager.LastAppliedAt = observedAt.UTC().Format(time.RFC3339Nano) + } + for _, decision := range decisions { + decision.RouteID = strings.TrimSpace(decision.RouteID) + decision.ReplacementRouteID = strings.TrimSpace(decision.ReplacementRouteID) + decision.RebuildStatus = strings.TrimSpace(decision.RebuildStatus) + decision.RebuildRequestID = strings.TrimSpace(decision.RebuildRequestID) + if decision.RouteID == "" || decision.RebuildStatus == "" { + continue + } + decision.EffectiveHops = append([]string{}, decision.EffectiveHops...) + manager.Decisions = append(manager.Decisions, decision) + manager.RebuildRequestCount++ + switch decision.RebuildStatus { + case "applied": + manager.RebuildAppliedCount++ + manager.WithdrawnRouteCount++ + manager.withdrawnRoutes[decision.RouteID] = decision + if decision.ReplacementRouteID != "" { + manager.replacements[decision.RouteID] = decision.ReplacementRouteID + } + case "pending_degraded_fallback": + manager.PendingFallbackCount++ + manager.WithdrawnRouteCount++ + manager.withdrawnRoutes[decision.RouteID] = decision + } + } + return manager +} + +func (m FabricServiceChannelRouteManager) snapshot() FabricServiceChannelRouteManager { + out := m + out.Decisions = append([]FabricServiceChannelRouteManagerDecision{}, m.Decisions...) + out.withdrawnRoutes = nil + out.replacements = nil + if out.SchemaVersion == "" { + out.SchemaVersion = "rap.fabric_service_channel_route_manager_rebuild.v1" + } + return out +} + +func newFabricServiceChannelRouteManagerTransition(previous FabricServiceChannelRouteManager, next FabricServiceChannelRouteManager, observedAt time.Time) FabricServiceChannelRouteManagerTransition { + transition := FabricServiceChannelRouteManagerTransition{ + SchemaVersion: "rap.fabric_service_channel_route_manager_transition.v1", + PreviousGeneration: strings.TrimSpace(previous.Generation), + Generation: strings.TrimSpace(next.Generation), + DecisionCount: len(next.Decisions), + WithdrawnRouteCount: next.WithdrawnRouteCount, + PendingFallbackCount: next.PendingFallbackCount, + RebuildAppliedCount: next.RebuildAppliedCount, + } + if !observedAt.IsZero() { + transition.ObservedAt = observedAt.UTC().Format(time.RFC3339Nano) + } + previousWithdrawn := previous.withdrawnRouteIDs() + nextWithdrawn := next.withdrawnRouteIDs() + for routeID := range previousWithdrawn { + if _, ok := nextWithdrawn[routeID]; !ok { + transition.RestoredRouteCount++ + } + } + switch { + case transition.RestoredRouteCount > 0 && transition.WithdrawnRouteCount == 0: + transition.Status = "restored_by_new_config" + case transition.RebuildAppliedCount > 0: + transition.Status = "applied_rebuild" + case transition.PendingFallbackCount > 0: + transition.Status = "pending_degraded_fallback" + case transition.DecisionCount > 0: + transition.Status = "decisions_observed" + default: + transition.Status = "empty" + } + return transition +} + +func (t FabricServiceChannelRouteManagerTransition) snapshot() FabricServiceChannelRouteManagerTransition { + if t.SchemaVersion == "" { + t.SchemaVersion = "rap.fabric_service_channel_route_manager_transition.v1" + } + return t +} + +func (m FabricServiceChannelRouteManager) withdrawnRouteIDs() map[string]struct{} { + out := map[string]struct{}{} + if m.withdrawnRoutes != nil { + for routeID := range m.withdrawnRoutes { + routeID = strings.TrimSpace(routeID) + if routeID != "" { + out[routeID] = struct{}{} + } + } + return out + } + for _, decision := range m.Decisions { + if decision.RouteID == "" { + continue + } + if decision.RebuildStatus == "applied" || decision.RebuildStatus == "pending_degraded_fallback" { + out[strings.TrimSpace(decision.RouteID)] = struct{}{} + } + } + return out +} + +func (m FabricServiceChannelRouteManager) isWithdrawn(routeID string) bool { + routeID = strings.TrimSpace(routeID) + if routeID == "" { + return false + } + if m.withdrawnRoutes != nil { + _, ok := m.withdrawnRoutes[routeID] + return ok + } + for _, decision := range m.Decisions { + if decision.RouteID == routeID && (decision.RebuildStatus == "applied" || decision.RebuildStatus == "pending_degraded_fallback") { + return true + } + } + return false +} + +func (m FabricServiceChannelRouteManager) replacementRouteID(routeID string) string { + routeID = strings.TrimSpace(routeID) + if routeID == "" { + return "" + } + if m.replacements != nil { + return strings.TrimSpace(m.replacements[routeID]) + } + for _, decision := range m.Decisions { + if strings.TrimSpace(decision.RouteID) == routeID && decision.RebuildStatus == "applied" { + return strings.TrimSpace(decision.ReplacementRouteID) + } + } + return "" +} + +func (i *FabricClientPacketIngress) forwardTransport() mesh.ProductionForwardTransport { + if i == nil { + return nil + } + i.mu.Lock() + defer i.mu.Unlock() + return i.ForwardTransport +} + +func (i *FabricClientPacketIngress) inbox() *FabricPacketInbox { + if i == nil { + return nil + } + i.mu.Lock() + defer i.mu.Unlock() + return i.Inbox +} + +func (i *FabricClientPacketIngress) localGateway() func(string) bool { + if i == nil { + return nil + } + i.mu.Lock() + defer i.mu.Unlock() + return i.LocalGateway +} + +func (i *FabricClientPacketIngress) routesFunc() func() []mesh.SyntheticRoute { + if i == nil { + return nil + } + i.mu.Lock() + defer i.mu.Unlock() + return i.Routes +} + +func (i *FabricClientPacketIngress) flowScheduler() *FabricFlowScheduler { + if i == nil { + return NewFabricFlowScheduler(0, 0) + } + i.mu.Lock() + defer i.mu.Unlock() + if i.FlowScheduler == nil { + i.FlowScheduler = NewFabricFlowScheduler(0, 0) + } + return i.FlowScheduler +} + +func (i *FabricClientPacketIngress) maxParallelFlowSends() int { + if i == nil { + return 1 + } + i.mu.Lock() + defer i.mu.Unlock() + return i.maxParallelFlowSendsLocked() +} + +func (i *FabricClientPacketIngress) maxParallelFlowSendsLocked() int { + if i == nil || i.MaxParallelFlowSends <= 0 { + return 1 + } + return i.MaxParallelFlowSends +} + +func (i *FabricClientPacketIngress) routeManager() FabricServiceChannelRouteManager { + if i == nil { + return FabricServiceChannelRouteManager{} + } + i.mu.Lock() + defer i.mu.Unlock() + return i.RouteManager +} + +func (i *FabricClientPacketIngress) routeQualityPreferences() map[string]FabricServiceChannelRouteQualityPreference { + if i == nil { + return nil + } + i.mu.Lock() + defer i.mu.Unlock() + out := make(map[string]FabricServiceChannelRouteQualityPreference, len(i.RouteQualityPreferences)) + for routeID, preference := range i.RouteQualityPreferences { + out[routeID] = preference + } + return out +} + +func (i *FabricClientPacketIngress) routeQualityPreference(routeID string) (FabricServiceChannelRouteQualityPreference, bool) { + routeID = strings.TrimSpace(routeID) + if i == nil || routeID == "" { + return FabricServiceChannelRouteQualityPreference{}, false + } + i.mu.Lock() + defer i.mu.Unlock() + preference, ok := i.RouteQualityPreferences[routeID] + return preference, ok +} + +func routeQualityPreferenceSlice(preferences map[string]FabricServiceChannelRouteQualityPreference) []FabricServiceChannelRouteQualityPreference { + if len(preferences) == 0 { + return nil + } + out := make([]FabricServiceChannelRouteQualityPreference, 0, len(preferences)) + for _, preference := range preferences { + preference.Reasons = dedupeStrings(preference.Reasons) + out = append(out, preference) + } + sort.SliceStable(out, func(a, b int) bool { + if out[a].ScoreAdjustment != out[b].ScoreAdjustment { + return out[a].ScoreAdjustment > out[b].ScoreAdjustment + } + return out[a].RouteID < out[b].RouteID + }) + return out +} + +func (i *FabricClientPacketIngress) lastRouteID() string { + if i == nil { + return "" + } + i.mu.Lock() + defer i.mu.Unlock() + return i.lastSelectedRouteID +} + +func (i *FabricClientPacketIngress) recordSendBatch(packetCount int) { + i.mu.Lock() + defer i.mu.Unlock() + i.sendBatches++ + i.sendPackets += uint64(packetCount) +} + +func (i *FabricClientPacketIngress) recordRouteAttempt() { + i.mu.Lock() + defer i.mu.Unlock() + i.sendRouteAttempts++ +} + +func (i *FabricClientPacketIngress) recordRouteFailure(err error) { + i.mu.Lock() + defer i.mu.Unlock() + i.sendRouteFailures++ + if err != nil { + i.lastError = err.Error() + } +} + +func (i *FabricClientPacketIngress) recordRouteSuccess(routeID, nextHop string) { + i.mu.Lock() + defer i.mu.Unlock() + i.lastSelectedRouteID = routeID + i.lastSelectedNextHop = nextHop + i.lastError = "" +} + +func (i *FabricClientPacketIngress) recordLocalFallback() { + i.mu.Lock() + defer i.mu.Unlock() + i.sendFallbackLocal++ + i.lastSelectedRouteID = "local_gateway" + i.lastSelectedNextHop = i.LocalNodeID + i.lastError = "" +} + +func (i *FabricClientPacketIngress) recordFlowBatch(packetCount int) { + i.mu.Lock() + defer i.mu.Unlock() + i.sendFlowBatches++ + i.sendFlowPackets += uint64(packetCount) +} + +func (i *FabricClientPacketIngress) recordFlowDropped(packetCount uint64) { + i.mu.Lock() + defer i.mu.Unlock() + i.sendFlowDropped += packetCount +} + +func (i *FabricClientPacketIngress) recordFlowParallel() { + i.mu.Lock() + defer i.mu.Unlock() + i.sendFlowParallel++ +} + +func (i *FabricClientPacketIngress) recordReceiveBatch(packetCount int) { + i.mu.Lock() + defer i.mu.Unlock() + i.receiveBatches++ + i.receivePackets += uint64(packetCount) +} + +func (i *FabricClientPacketIngress) recordReceiveEmpty() { + i.mu.Lock() + defer i.mu.Unlock() + i.receiveEmpty++ +} + +func (i *FabricClientPacketIngress) recordError(err error) { + if err == nil { + return + } + i.mu.Lock() + defer i.mu.Unlock() + i.lastError = err.Error() +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return strings.TrimSpace(value) + } + } + return "" +} + +func nextHopAfter(path []string, localNodeID string, destinationNodeID string) string { + if len(path) == 0 { + return destinationNodeID + } + for index, nodeID := range path { + if nodeID == localNodeID { + if index+1 < len(path) { + return path[index+1] + } + return localNodeID + } + } + return destinationNodeID +} + +func containsString(values []string, needle string) bool { + for _, value := range values { + if value == needle { + return true + } + } + return false +} + +func dedupeStrings(values []string) []string { + if len(values) == 0 { + return nil + } + seen := map[string]struct{}{} + out := make([]string, 0, len(values)) + for _, value := range values { + value = strings.TrimSpace(value) + if value == "" { + continue + } + if _, ok := seen[value]; ok { + continue + } + seen[value] = struct{}{} + out = append(out, value) + } + return out +} + +func classifyPacketFlow(packet []byte, shardCount int) (string, int) { + if shardCount <= 0 { + shardCount = defaultFabricFlowShardCount + } + key := packetFlowKey(packet) + hash := fnv.New32a() + _, _ = hash.Write([]byte(key)) + shard := int(hash.Sum32() % uint32(shardCount)) + return key, shard +} + +func packetFlowKey(packet []byte) string { + if len(packet) == 0 { + return "empty" + } + version := packet[0] >> 4 + switch version { + case 4: + return ipv4PacketFlowKey(packet) + case 6: + return ipv6PacketFlowKey(packet) + default: + sum := fnv.New64a() + _, _ = sum.Write(packet) + return fmt.Sprintf("opaque:%x", sum.Sum64()) + } +} + +func ipv4PacketFlowKey(packet []byte) string { + if len(packet) < 20 { + return packetHashFlowKey("ipv4-short", packet) + } + ihl := int(packet[0]&0x0f) * 4 + if ihl < 20 || len(packet) < ihl { + return packetHashFlowKey("ipv4-invalid", packet) + } + proto := packet[9] + src := binary.BigEndian.Uint32(packet[12:16]) + dst := binary.BigEndian.Uint32(packet[16:20]) + srcPort, dstPort := transportPorts(proto, packet[ihl:]) + if src > dst || (src == dst && srcPort > dstPort) { + src, dst = dst, src + srcPort, dstPort = dstPort, srcPort + } + return fmt.Sprintf("ipv4:%d:%08x:%d:%08x:%d", proto, src, srcPort, dst, dstPort) +} + +func ipv6PacketFlowKey(packet []byte) string { + if len(packet) < 40 { + return packetHashFlowKey("ipv6-short", packet) + } + nextHeader := packet[6] + src := packet[8:24] + dst := packet[24:40] + srcPort, dstPort := transportPorts(nextHeader, packet[40:]) + srcKey := fmt.Sprintf("%x", src) + dstKey := fmt.Sprintf("%x", dst) + if srcKey > dstKey || (srcKey == dstKey && srcPort > dstPort) { + srcKey, dstKey = dstKey, srcKey + srcPort, dstPort = dstPort, srcPort + } + return fmt.Sprintf("ipv6:%d:%s:%d:%s:%d", nextHeader, srcKey, srcPort, dstKey, dstPort) +} + +func transportPorts(proto byte, payload []byte) (uint16, uint16) { + switch proto { + case 6, 17: + if len(payload) >= 4 { + return binary.BigEndian.Uint16(payload[0:2]), binary.BigEndian.Uint16(payload[2:4]) + } + } + return 0, 0 +} + +func packetHashFlowKey(prefix string, packet []byte) string { + sum := fnv.New64a() + _, _ = sum.Write(packet) + return fmt.Sprintf("%s:%x", prefix, sum.Sum64()) +} + +func cleanPacketBatch(packets [][]byte) [][]byte { + if len(packets) == 0 { + return nil + } + cleaned := make([][]byte, 0, len(packets)) + for _, packet := range packets { + if len(packet) == 0 { + continue + } + cleaned = append(cleaned, append([]byte(nil), packet...)) + } + return cleaned +} diff --git a/agents/rap-node-agent/internal/vpnruntime/gateway.go b/agents/rap-node-agent/internal/vpnruntime/gateway.go new file mode 100644 index 0000000..25c9138 --- /dev/null +++ b/agents/rap-node-agent/internal/vpnruntime/gateway.go @@ -0,0 +1,495 @@ +package vpnruntime + +import ( + "context" + "fmt" + "io" + "log" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/example/remote-access-platform/agents/rap-node-agent/internal/client" +) + +type Gateway struct { + API *client.Client + Transport PacketTransport + ClusterID string + VPNConnectionID string + InterfaceName string + AddressCIDR string + RouteCIDR string + PollTimeout time.Duration + + mu sync.Mutex + running bool + lastErr error + cancel context.CancelFunc + + clientToGatewayBatches atomic.Uint64 + clientToGatewayPackets atomic.Uint64 + clientToGatewayBytes atomic.Uint64 + gatewayToClientBatches atomic.Uint64 + gatewayToClientPackets atomic.Uint64 + gatewayToClientBytes atomic.Uint64 + tunReadPackets atomic.Uint64 + tunReadBytes atomic.Uint64 + tunWritePackets atomic.Uint64 + tunWriteBytes atomic.Uint64 + uploadQueueDrops atomic.Uint64 + downloadErrors atomic.Uint64 + uploadErrors atomic.Uint64 + + lastClientToGatewayPacket string + lastGatewayToClientPacket string + lastRuntimeActivityAt time.Time +} + +const ( + vpnGatewayBatchMaxPackets = 2048 + vpnGatewayBatchMaxBytes = 8 * 1024 * 1024 + vpnGatewayBatchFlushTimeout = 3 * time.Millisecond +) + +type readWriteCloser interface { + io.Reader + io.Writer + io.Closer +} + +type PacketTransport interface { + SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error + ReceiveGatewayPacketBatch(ctx context.Context, timeout time.Duration) ([][]byte, error) +} + +type BackendPacketTransport struct { + API *client.Client + ClusterID string + VPNConnectionID string +} + +func (t BackendPacketTransport) SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error { + return t.API.SendVPNGatewayPacketBatch(ctx, t.ClusterID, t.VPNConnectionID, packets) +} + +func (t BackendPacketTransport) ReceiveGatewayPacketBatch(ctx context.Context, timeout time.Duration) ([][]byte, error) { + return t.API.ReceiveVPNGatewayPacketBatch(ctx, t.ClusterID, t.VPNConnectionID, timeout) +} + +func (g *Gateway) EnsureStarted(ctx context.Context) error { + g.mu.Lock() + if g.running { + g.mu.Unlock() + return nil + } + g.mu.Unlock() + + if err := g.normalize(); err != nil { + g.setStopped(err) + return err + } + tun, err := openGatewayTun(g.InterfaceName, g.AddressCIDR, g.RouteCIDR) + if err != nil { + g.setStopped(err) + return err + } + runCtx, cancel := context.WithCancel(ctx) + + g.mu.Lock() + if g.running { + g.mu.Unlock() + cancel() + _ = tun.Close() + return nil + } + g.running = true + g.lastErr = nil + g.cancel = cancel + g.mu.Unlock() + + go func() { + if err := g.run(runCtx, tun); err != nil && runCtx.Err() == nil { + log.Printf("vpn gateway runtime stopped: vpn_connection_id=%s error=%v", g.VPNConnectionID, err) + g.setStopped(err) + return + } + g.setStopped(runCtx.Err()) + }() + return nil +} + +func (g *Gateway) Stop() { + g.mu.Lock() + cancel := g.cancel + g.cancel = nil + g.running = false + g.mu.Unlock() + if cancel != nil { + cancel() + } +} + +func (g *Gateway) Status() (bool, string) { + g.mu.Lock() + defer g.mu.Unlock() + if g.lastErr != nil { + return g.running, g.lastErr.Error() + } + return g.running, "" +} + +func (g *Gateway) IsReadyForConnection(vpnConnectionID string) bool { + g.mu.Lock() + defer g.mu.Unlock() + return g.running && g.VPNConnectionID == vpnConnectionID && vpnConnectionID != "" +} + +func (g *Gateway) Snapshot() map[string]any { + g.mu.Lock() + running := g.running + lastErr := "" + if g.lastErr != nil { + lastErr = g.lastErr.Error() + } + lastClientToGatewayPacket := g.lastClientToGatewayPacket + lastGatewayToClientPacket := g.lastGatewayToClientPacket + lastRuntimeActivityAt := g.lastRuntimeActivityAt + g.mu.Unlock() + + out := map[string]any{ + "running": running, + "transport": g.transportName(), + "poll_timeout_ms": g.PollTimeout.Milliseconds(), + "client_to_gateway_batches": g.clientToGatewayBatches.Load(), + "client_to_gateway_packets": g.clientToGatewayPackets.Load(), + "client_to_gateway_bytes": g.clientToGatewayBytes.Load(), + "gateway_to_client_batches": g.gatewayToClientBatches.Load(), + "gateway_to_client_packets": g.gatewayToClientPackets.Load(), + "gateway_to_client_bytes": g.gatewayToClientBytes.Load(), + "tun_read_packets": g.tunReadPackets.Load(), + "tun_read_bytes": g.tunReadBytes.Load(), + "tun_write_packets": g.tunWritePackets.Load(), + "tun_write_bytes": g.tunWriteBytes.Load(), + "upload_queue_drops": g.uploadQueueDrops.Load(), + "download_errors": g.downloadErrors.Load(), + "upload_errors": g.uploadErrors.Load(), + "last_client_to_gateway": lastClientToGatewayPacket, + "last_gateway_to_client": lastGatewayToClientPacket, + } + if lastErr != "" { + out["last_error"] = lastErr + } + if !lastRuntimeActivityAt.IsZero() { + out["last_runtime_activity_at"] = lastRuntimeActivityAt.UTC().Format(time.RFC3339Nano) + } + return out +} + +func (g *Gateway) transportName() string { + switch g.Transport.(type) { + case *FabricPacketTransport: + return "fabric_mesh" + case *LocalPacketTransport: + return "local_fabric_inbox" + case *AdaptivePacketTransport: + return "adaptive_fabric_backend" + case BackendPacketTransport: + return "backend_http_packet_relay" + default: + if g.Transport == nil { + return "none" + } + return fmt.Sprintf("%T", g.Transport) + } +} + +func (g *Gateway) setStopped(err error) { + g.mu.Lock() + defer g.mu.Unlock() + g.running = false + g.lastErr = err + g.cancel = nil +} + +func (g *Gateway) normalize() error { + if g.Transport == nil { + if g.API == nil { + return fmt.Errorf("api client or packet transport is required") + } + g.Transport = BackendPacketTransport{ + API: g.API, + ClusterID: g.ClusterID, + VPNConnectionID: g.VPNConnectionID, + } + } + if g.ClusterID == "" || g.VPNConnectionID == "" { + return fmt.Errorf("cluster id and vpn connection id are required") + } + if g.InterfaceName == "" { + g.InterfaceName = "rapvpn0" + } + if g.AddressCIDR == "" { + g.AddressCIDR = "10.77.0.1/24" + } + if g.RouteCIDR == "" { + g.RouteCIDR = "10.77.0.0/24" + } + if g.PollTimeout <= 0 { + g.PollTimeout = 25 * time.Second + } + return nil +} + +func (g *Gateway) run(ctx context.Context, tun readWriteCloser) error { + defer tun.Close() + + errCh := make(chan error, 2) + go func() { errCh <- g.copyGatewayToClient(ctx, tun) }() + go func() { errCh <- g.copyClientToGateway(ctx, tun) }() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: + return err + } +} + +func (g *Gateway) copyGatewayToClient(ctx context.Context, tun io.Reader) error { + packets := make(chan []byte, 32768) + errCh := make(chan error, 1) + go func() { + errCh <- g.uploadGatewayPackets(ctx, packets) + }() + + buffer := make([]byte, 65535) + for { + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: + if err != nil { + return err + } + default: + } + n, err := tun.Read(buffer) + if err != nil { + return err + } + if n <= 0 { + continue + } + packet := append([]byte(nil), buffer[:n]...) + normalizeIPv4PacketChecksums(packet) + g.recordTunRead(packet) + select { + case packets <- packet: + default: + g.uploadQueueDrops.Add(1) + log.Printf("vpn gateway packet upload queue full; dropping packet: vpn_connection_id=%s", g.VPNConnectionID) + } + } +} + +func (g *Gateway) uploadGatewayPackets(ctx context.Context, packets <-chan []byte) error { + batch := make([][]byte, 0, vpnGatewayBatchMaxPackets) + batchBytes := 0 + timer := time.NewTimer(time.Hour) + if !timer.Stop() { + <-timer.C + } + timerActive := false + flush := func() { + if len(batch) == 0 { + return + } + packetCount := len(batch) + byteCount := packetBytesTotal(batch) + if err := g.Transport.SendGatewayPacketBatch(ctx, batch); err != nil { + g.uploadErrors.Add(1) + log.Printf("vpn gateway packet batch upload failed: vpn_connection_id=%s packets=%d error=%v", g.VPNConnectionID, len(batch), err) + } else { + g.recordGatewayToClientBatch(packetCount, byteCount, batch[0]) + } + for i := range batch { + batch[i] = nil + } + batch = batch[:0] + batchBytes = 0 + } + for { + if len(batch) == 0 && timerActive { + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timerActive = false + } + select { + case <-ctx.Done(): + flush() + return ctx.Err() + case packet := <-packets: + packetBytes := len(packet) + if packetBytes <= 0 { + continue + } + packetFrameSize := 4 + packetBytes + if len(batch) > 0 { + if len(batch) >= vpnGatewayBatchMaxPackets || batchBytes+packetFrameSize > vpnGatewayBatchMaxBytes { + flush() + } + } + batch = append(batch, packet) + batchBytes += packetFrameSize + if len(batch) >= vpnGatewayBatchMaxPackets || batchBytes >= vpnGatewayBatchMaxBytes { + flush() + continue + } + if !timerActive { + timer.Reset(vpnGatewayBatchFlushTimeout) + timerActive = true + } + case <-timer.C: + timerActive = false + flush() + } + } +} + +func (g *Gateway) copyClientToGateway(ctx context.Context, tun io.Writer) error { + for { + packets, err := g.Transport.ReceiveGatewayPacketBatch(ctx, g.PollTimeout) + if err != nil { + log.Printf("vpn gateway packet download failed: vpn_connection_id=%s error=%v", g.VPNConnectionID, err) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(time.Second): + } + continue + } + if len(packets) == 0 { + continue + } + g.recordClientToGatewayBatch(len(packets), packetBytesTotal(packets), packets[0]) + for _, packet := range packets { + normalizeIPv4PacketChecksums(packet) + if _, err := tun.Write(packet); err != nil { + g.downloadErrors.Add(1) + return err + } + g.recordTunWrite(packet) + } + } +} + +func (g *Gateway) recordClientToGatewayBatch(packetCount int, byteCount int, first []byte) { + next := g.clientToGatewayBatches.Add(1) + g.clientToGatewayPackets.Add(uint64(packetCount)) + g.clientToGatewayBytes.Add(uint64(byteCount)) + summary := summarizePacket(first) + g.mu.Lock() + g.lastClientToGatewayPacket = summary + g.lastRuntimeActivityAt = time.Now().UTC() + g.mu.Unlock() + if next <= 5 { + log.Printf( + "vpn gateway client_to_gateway batch received: vpn_connection_id=%s batch=%d packets=%d bytes=%d first=%s", + g.VPNConnectionID, + next, + packetCount, + byteCount, + summary, + ) + } +} + +func (g *Gateway) recordGatewayToClientBatch(packetCount int, byteCount int, first []byte) { + next := g.gatewayToClientBatches.Add(1) + g.gatewayToClientPackets.Add(uint64(packetCount)) + g.gatewayToClientBytes.Add(uint64(byteCount)) + summary := summarizePacket(first) + g.mu.Lock() + g.lastGatewayToClientPacket = summary + g.lastRuntimeActivityAt = time.Now().UTC() + g.mu.Unlock() + if next <= 5 { + log.Printf( + "vpn gateway gateway_to_client batch uploaded: vpn_connection_id=%s batch=%d packets=%d bytes=%d first=%s", + g.VPNConnectionID, + next, + packetCount, + byteCount, + summary, + ) + } +} + +func (g *Gateway) recordTunWrite(packet []byte) { + next := g.tunWritePackets.Add(1) + g.tunWriteBytes.Add(uint64(len(packet))) + if next <= 5 { + log.Printf("vpn gateway packet written to tun: vpn_connection_id=%s packet=%d bytes=%d summary=%s", g.VPNConnectionID, next, len(packet), summarizePacket(packet)) + } +} + +func (g *Gateway) recordTunRead(packet []byte) { + next := g.tunReadPackets.Add(1) + g.tunReadBytes.Add(uint64(len(packet))) + if next <= 5 { + log.Printf("vpn gateway packet read from tun: vpn_connection_id=%s packet=%d bytes=%d summary=%s", g.VPNConnectionID, next, len(packet), summarizePacket(packet)) + } +} + +func packetBytesTotal(packets [][]byte) int { + total := 0 + for _, packet := range packets { + total += len(packet) + } + return total +} + +func summarizePacket(packet []byte) string { + if len(packet) < 1 { + return "empty" + } + version := packet[0] >> 4 + switch version { + case 4: + return summarizeIPv4(packet) + case 6: + return summarizeIPv6(packet) + default: + return fmt.Sprintf("ip_version=%d bytes=%d", version, len(packet)) + } +} + +func summarizeIPv4(packet []byte) string { + if len(packet) < 20 { + return fmt.Sprintf("ipv4 truncated bytes=%d", len(packet)) + } + ihl := int(packet[0]&0x0f) * 4 + if ihl < 20 || len(packet) < ihl { + return fmt.Sprintf("ipv4 invalid_ihl=%d bytes=%d", ihl, len(packet)) + } + proto := packet[9] + src := net.IP(packet[12:16]).String() + dst := net.IP(packet[16:20]).String() + return fmt.Sprintf("ipv4 %s -> %s proto=%d bytes=%d", src, dst, proto, len(packet)) +} + +func summarizeIPv6(packet []byte) string { + if len(packet) < 40 { + return fmt.Sprintf("ipv6 truncated bytes=%d", len(packet)) + } + nextHeader := packet[6] + src := net.IP(packet[8:24]).String() + dst := net.IP(packet[24:40]).String() + return fmt.Sprintf("ipv6 %s -> %s next=%d bytes=%d", src, dst, nextHeader, len(packet)) +} diff --git a/agents/rap-node-agent/internal/vpnruntime/tun_linux.go b/agents/rap-node-agent/internal/vpnruntime/tun_linux.go new file mode 100644 index 0000000..36bb960 --- /dev/null +++ b/agents/rap-node-agent/internal/vpnruntime/tun_linux.go @@ -0,0 +1,206 @@ +//go:build linux + +package vpnruntime + +import ( + "errors" + "fmt" + "net" + "os" + "os/exec" + "strings" + "syscall" + "unsafe" +) + +const ( + tunDevicePath = "/dev/net/tun" + iffTun = 0x0001 + iffNoPI = 0x1000 + tunSetIFF = 0x400454ca + ifNameSize = 16 +) + +type tunDevice struct { + file *os.File + fd int + name string +} + +func openGatewayTun(name, addressCIDR, routeCIDR string) (*tunDevice, error) { + dev, err := openGatewayTunDevice(name) + if errors.Is(err, syscall.EBUSY) { + cleanupStaleGatewayInterface(name) + dev, err = openGatewayTunDevice(name) + } + if err != nil { + return nil, err + } + if err := configureGatewayInterface(name, addressCIDR, routeCIDR); err != nil { + _ = dev.Close() + return nil, err + } + return dev, nil +} + +func openGatewayTunDevice(name string) (*tunDevice, error) { + file, err := os.OpenFile(tunDevicePath, os.O_RDWR, 0) + if err != nil { + return nil, fmt.Errorf("open %s: %w", tunDevicePath, err) + } + ifr := make([]byte, 40) + copy(ifr[:ifNameSize], []byte(name)) + *(*uint16)(unsafe.Pointer(&ifr[ifNameSize])) = iffTun | iffNoPI + if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, file.Fd(), uintptr(tunSetIFF), uintptr(unsafe.Pointer(&ifr[0]))); errno != 0 { + file.Close() + return nil, fmt.Errorf("configure tun %s: %w", name, errno) + } + return &tunDevice{file: file, fd: int(file.Fd()), name: name}, nil +} + +func cleanupStaleGatewayInterface(name string) { + if strings.TrimSpace(name) == "" { + return + } + _ = runCommand("ip", "link", "set", name, "down") + _ = runCommand("ip", "link", "delete", name) +} + +func (d *tunDevice) Read(packet []byte) (int, error) { + return syscall.Read(d.fd, packet) +} + +func (d *tunDevice) Write(packet []byte) (int, error) { + return syscall.Write(d.fd, packet) +} + +func (d *tunDevice) Close() error { + _ = runCommand("ip", "link", "set", d.name, "down") + return d.file.Close() +} + +func configureGatewayInterface(name, addressCIDR, routeCIDR string) error { + if _, _, err := net.ParseCIDR(addressCIDR); err != nil { + return fmt.Errorf("invalid vpn gateway address %q: %w", addressCIDR, err) + } + if err := runCommand("ip", "addr", "replace", addressCIDR, "dev", name); err != nil { + return err + } + if err := runCommand("ip", "link", "set", name, "up"); err != nil { + return err + } + if err := enableIPv4Forwarding(); err != nil { + return err + } + if err := disableReversePathFiltering(name); err != nil { + return err + } + if err := ensureForwardingRules(name); err != nil { + return err + } + if err := ensureMasqueradeRules(routeCIDR); err != nil { + return err + } + if err := ensureMSSClampRule(name); err != nil { + return err + } + return nil +} + +func ensureMasqueradeRules(routeCIDR string) error { + egress, _ := defaultIPv4Interface() + if egress != "" { + if err := ensureIPTablesRule("nat", "POSTROUTING", "-s", routeCIDR, "-o", egress, "-j", "MASQUERADE"); err != nil { + return err + } + } + return ensureIPTablesRule("nat", "POSTROUTING", "-s", routeCIDR, "-j", "MASQUERADE") +} + +func ensureMSSClampRule(interfaceName string) error { + err := ensureIPTablesRule("mangle", "FORWARD", "-i", interfaceName, "-p", "tcp", "--tcp-flags", "SYN,RST", "SYN", "-j", "TCPMSS", "--clamp-mss-to-pmtu") + if err == nil { + return nil + } + return nil +} + +func defaultIPv4Interface() (string, error) { + out, err := exec.Command("ip", "-o", "-4", "route", "show", "default").CombinedOutput() + if err != nil { + return "", fmt.Errorf("ip default route failed: %w: %s", err, string(out)) + } + fields := strings.Fields(string(out)) + for i := 0; i+1 < len(fields); i++ { + if fields[i] == "dev" { + return fields[i+1], nil + } + } + return "", nil +} + +func ensureForwardingRules(interfaceName string) error { + if err := ensureIPTablesRule("filter", "FORWARD", "-i", interfaceName, "-j", "ACCEPT"); err != nil { + return err + } + err := ensureIPTablesRule("filter", "FORWARD", "-o", interfaceName, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT") + if err == nil { + return nil + } + return ensureIPTablesRule("filter", "FORWARD", "-o", interfaceName, "-j", "ACCEPT") +} + +func ensureIPTablesRule(table, chain string, rule ...string) error { + checkArgs := append([]string{"-t", table, "-C", chain}, rule...) + if err := runCommand("iptables", checkArgs...); err == nil { + return nil + } + addArgs := append([]string{"-t", table, "-I", chain, "1"}, rule...) + return runCommand("iptables", addArgs...) +} + +func enableIPv4Forwarding() error { + if current, err := os.ReadFile("/proc/sys/net/ipv4/ip_forward"); err == nil && len(current) > 0 && current[0] == '1' { + return nil + } + if err := os.WriteFile("/proc/sys/net/ipv4/ip_forward", []byte("1\n"), 0o644); err == nil { + return nil + } + return runCommand("sysctl", "-w", "net.ipv4.ip_forward=1") +} + +func disableReversePathFiltering(interfaceName string) error { + keys := []string{"all", "default", interfaceName} + if entries, err := os.ReadDir("/proc/sys/net/ipv4/conf"); err == nil { + for _, entry := range entries { + if entry.IsDir() { + keys = append(keys, entry.Name()) + } + } + } + seen := make(map[string]bool) + for _, key := range keys { + if seen[key] { + continue + } + seen[key] = true + path := fmt.Sprintf("/proc/sys/net/ipv4/conf/%s/rp_filter", key) + if _, err := os.Stat(path); err != nil { + continue + } + if err := os.WriteFile(path, []byte("0\n"), 0o644); err != nil { + if sysctlErr := runCommand("sysctl", "-w", fmt.Sprintf("net.ipv4.conf.%s.rp_filter=0", key)); sysctlErr != nil { + return sysctlErr + } + } + } + return nil +} + +func runCommand(name string, args ...string) error { + cmd := exec.Command(name, args...) + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("%s %v failed: %w: %s", name, args, err, string(out)) + } + return nil +} diff --git a/agents/rap-node-agent/internal/vpnruntime/tun_unsupported.go b/agents/rap-node-agent/internal/vpnruntime/tun_unsupported.go new file mode 100644 index 0000000..b763ac3 --- /dev/null +++ b/agents/rap-node-agent/internal/vpnruntime/tun_unsupported.go @@ -0,0 +1,23 @@ +//go:build !linux && !windows + +package vpnruntime + +import "fmt" + +type tunDevice struct{} + +func openGatewayTun(name, addressCIDR, routeCIDR string) (*tunDevice, error) { + return nil, fmt.Errorf("vpn gateway runtime is currently supported only on linux") +} + +func (d *tunDevice) Read(packet []byte) (int, error) { + return 0, fmt.Errorf("vpn gateway runtime is currently supported only on linux") +} + +func (d *tunDevice) Write(packet []byte) (int, error) { + return 0, fmt.Errorf("vpn gateway runtime is currently supported only on linux") +} + +func (d *tunDevice) Close() error { + return nil +}