Refactor RDP proxy handling and update related tests

This commit is contained in:
2026-05-17 20:38:35 +03:00
parent 8e9402580f
commit d551e57fd5
172 changed files with 22117 additions and 2509 deletions
@@ -5,6 +5,8 @@ import (
"encoding/json"
"errors"
"fmt"
"net"
"net/url"
"strings"
"time"
@@ -190,7 +192,7 @@ func (s *PostgresStore) UpdateCluster(ctx context.Context, input UpdateClusterIn
func (s *PostgresStore) GetClusterAuthority(ctx context.Context, clusterID string) (ClusterAuthorityKey, error) {
row := s.db.QueryRow(ctx, `
SELECT cluster_id::text, authority_state, key_algorithm, public_key,
public_key_fingerprint, private_key, created_at, updated_at
public_key_fingerprint, private_key, created_at, updated_at, metadata
FROM cluster_authorities
WHERE cluster_id = $1::uuid
`, clusterID)
@@ -3497,7 +3499,7 @@ func (s *PostgresStore) CheckVPNLeaseOwnerEligibility(ctx context.Context, clust
WHERE nra.cluster_id = vc.cluster_id
AND nra.node_id = $3::uuid
AND nra.status = 'active'
AND nra.role IN ('vpn-exit', 'vpn-connector')
AND nra.role IN ('vpn-exit', 'vpn-connector', 'ipv4-egress')
AND (nra.organization_id IS NULL OR nra.organization_id = vc.organization_id)
) AS has_authorized_role
FROM vpn_connections vc
@@ -3582,7 +3584,7 @@ func (s *PostgresStore) ListNodeVPNAssignments(ctx context.Context, clusterID, n
WHERE nra.cluster_id = vc.cluster_id
AND nra.node_id = $2::uuid
AND nra.status = 'active'
AND nra.role IN ('vpn-exit', 'vpn-connector')
AND nra.role IN ('vpn-exit', 'vpn-connector', 'ipv4-egress')
AND (nra.organization_id IS NULL OR nra.organization_id = vc.organization_id)
) AS has_authorized_role,
EXISTS (
@@ -3769,13 +3771,33 @@ func scanClusterAuthority(row scanner) (ClusterAuthorityKey, error) {
&item.PrivateKey,
&item.CreatedAt,
&item.UpdatedAt,
&item.Metadata,
); err != nil {
return ClusterAuthorityKey{}, err
}
item.SchemaVersion = clusterauth.AuthoritySchemaVersion
ensureRaw(&item.Metadata, `{}`)
item.QuorumDescriptor = clusterAuthorityQuorumDescriptorFromMetadata(item.Metadata)
return item, nil
}
func clusterAuthorityQuorumDescriptorFromMetadata(metadata json.RawMessage) *QuorumDescriptor {
if len(metadata) == 0 || !json.Valid(metadata) {
return nil
}
var envelope struct {
QuorumDescriptor *QuorumDescriptor `json:"quorum_descriptor"`
Quorum *QuorumDescriptor `json:"quorum"`
}
if err := json.Unmarshal(metadata, &envelope); err != nil {
return nil
}
if envelope.QuorumDescriptor != nil {
return envelope.QuorumDescriptor
}
return envelope.Quorum
}
func scanNodeGroup(row scanner) (ClusterNodeGroup, error) {
var item ClusterNodeGroup
if err := row.Scan(
@@ -4517,6 +4539,8 @@ func (s *PostgresStore) GetVPNClientProfile(
), '[]'::jsonb) AS allowed_node_ids,
COALESCE(vc.placement_policy->'entry_node_ids', '[]'::jsonb) AS entry_node_ids,
COALESCE(vc.placement_policy->>'exit_node_id', '') AS exit_node_id,
COALESCE(pool.id::text, '') AS exit_pool_id,
COALESCE(pool.name, vc.name) AS exit_pool_name,
CASE WHEN l.id IS NULL THEN NULL ELSE jsonb_build_object(
'lease_id', l.id::text,
'owner_node_id', l.owner_node_id::text,
@@ -4576,6 +4600,34 @@ func (s *PostgresStore) GetVPNClientProfile(
'runtime_observed_at', gateway_status.observed_at
)) END AS client_config
FROM vpn_connections vc
LEFT JOIN LATERAL (
SELECT ep.id, ep.name
FROM fabric_egress_pools ep
WHERE ep.cluster_id = vc.cluster_id
AND ep.status = 'active'
AND (
ep.id::text = COALESCE(vc.placement_policy->>'exit_pool_id', '')
OR ep.name = COALESCE(vc.placement_policy->>'exit_pool_name', '')
OR EXISTS (
SELECT 1
FROM fabric_egress_pool_nodes epn
WHERE epn.egress_pool_id = ep.id
AND epn.cluster_id = vc.cluster_id
AND epn.status = 'active'
AND epn.node_id::text = ANY (
SELECT jsonb_array_elements_text(COALESCE(vc.placement_policy->'exit_node_ids', '[]'::jsonb))
)
)
)
ORDER BY
CASE
WHEN ep.id::text = COALESCE(vc.placement_policy->>'exit_pool_id', '') THEN 0
WHEN ep.name = COALESCE(vc.placement_policy->>'exit_pool_name', '') THEN 1
ELSE 2
END,
ep.name
LIMIT 1
) pool ON TRUE
LEFT JOIN vpn_connection_leases l
ON l.cluster_id = vc.cluster_id
AND l.vpn_connection_id = vc.id
@@ -4620,6 +4672,8 @@ func (s *PostgresStore) GetVPNClientProfile(
&allowedRaw,
&entryRaw,
&item.ExitNodeID,
&item.ExitPoolID,
&item.ExitPoolName,
&activeLeaseRaw,
&item.RoutePolicies,
&item.ClientConfig,
@@ -4641,6 +4695,15 @@ func (s *PostgresStore) GetVPNClientProfile(
ensureRaw(&item.PlacementPolicy, `{}`)
ensureRaw(&item.RoutePolicies, `[]`)
ensureRaw(&item.ClientConfig, `{}`)
if item.ExitPoolName != "" || item.ExitPoolID != "" {
item.ClientConfig = mergeJSONObjects(item.ClientConfig, map[string]any{
"exit_pool": map[string]any{
"id": item.ExitPoolID,
"name": firstNonEmptyMetadataString(item.ExitPoolName, item.Name),
"kind": "virtual_pool",
},
})
}
item.ClientConfig = enrichVPNClientFabricRoute(item, preferredEntryNodeID, preferredExitNodeID)
profile.Connections = append(profile.Connections, item)
}
@@ -4651,8 +4714,13 @@ func (s *PostgresStore) GetVPNClientProfile(
if err != nil {
return VPNClientProfile{}, err
}
exitEndpoints, err := s.vpnEntryEndpointCandidates(ctx, clusterID, vpnProfileExitNodeIDs(profile))
if err != nil {
return VPNClientProfile{}, err
}
for i := range profile.Connections {
profile.Connections[i].ClientConfig = enrichVPNClientEntryEndpointCandidates(profile.Connections[i], entryEndpoints)
profile.Connections[i].ClientConfig = enrichVPNClientExitEndpointCandidates(profile.Connections[i], exitEndpoints)
}
return profile, nil
}
@@ -4733,6 +4801,18 @@ func vpnProfileEntryNodeIDs(profile VPNClientProfile) []string {
return dedupeStrings(out)
}
func vpnProfileExitNodeIDs(profile VPNClientProfile) []string {
var out []string
for _, connection := range profile.Connections {
route := vpnFabricRouteFromClientConfig(connection.ClientConfig)
out = append(out, route.SelectedExitNodeID)
out = append(out, route.ExitPoolNodeIDs...)
out = append(out, connection.ExitNodeID)
out = append(out, connection.AllowedNodeIDs...)
}
return dedupeStrings(out)
}
func (s *PostgresStore) vpnEntryEndpointCandidates(ctx context.Context, clusterID string, entryNodeIDs []string) (map[string][]map[string]any, error) {
entryNodeIDs = dedupeStrings(entryNodeIDs)
out := make(map[string][]map[string]any, len(entryNodeIDs))
@@ -4778,13 +4858,12 @@ func vpnEntryEndpointCandidatesFromHeartbeat(nodeID string, capabilities json.Ra
if len(metadata) == 0 || json.Unmarshal(metadata, &payload) != nil {
return nil
}
certByCandidate := endpointCandidateCertsFromHeartbeatMetadata(metadata)
report := payload.MeshEndpointReport
var out []map[string]any
seen := map[string]struct{}{}
for _, candidate := range report.EndpointCandidates {
address := strings.TrimSpace(candidate.Address)
if address == "" {
continue
}
candidateNodeID := strings.TrimSpace(candidate.NodeID)
if candidateNodeID == "" {
candidateNodeID = nodeID
@@ -4793,6 +4872,9 @@ func vpnEntryEndpointCandidatesFromHeartbeat(nodeID string, capabilities json.Ra
if transport == "" {
transport = strings.TrimSpace(report.Transport)
}
if !usableVPNFabricPeerEndpoint(address, transport) {
continue
}
connectivityMode := strings.TrimSpace(candidate.ConnectivityMode)
if connectivityMode == "" {
connectivityMode = strings.TrimSpace(report.ConnectivityMode)
@@ -4813,6 +4895,11 @@ func vpnEntryEndpointCandidatesFromHeartbeat(nodeID string, capabilities json.Ra
if endpointID == "" {
endpointID = "mesh-" + candidateNodeID
}
key := candidateNodeID + "\x00" + strings.ToLower(transport) + "\x00" + strings.ToLower(address)
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
item := map[string]any{
"node_id": candidateNodeID,
"endpoint_id": endpointID,
@@ -4826,6 +4913,15 @@ func vpnEntryEndpointCandidatesFromHeartbeat(nodeID string, capabilities json.Ra
"status": "reported",
"source": "node_latest_heartbeat.mesh_endpoint_report.endpoint_candidates",
}
if certSHA256 := firstNonEmptyMetadataString(
endpointCandidateMetadataString(candidate.Metadata, "tls_cert_sha256", "peer_cert_sha256"),
certByCandidate[endpointID],
certByCandidate[address],
certByCandidate[candidateNodeID+"\x00"+address],
); certSHA256 != "" {
item["tls_cert_sha256"] = certSHA256
item["peer_cert_sha256"] = certSHA256
}
if apiBaseURL := vpnEntryAPIBaseURL(address); apiBaseURL != "" {
item["api_base_url"] = apiBaseURL
}
@@ -4833,7 +4929,7 @@ func vpnEntryEndpointCandidatesFromHeartbeat(nodeID string, capabilities json.Ra
}
if len(out) == 0 {
address := strings.TrimSpace(report.PeerEndpoint)
if address != "" {
if usableVPNFabricPeerEndpoint(address, strings.TrimSpace(report.Transport)) {
item := map[string]any{
"node_id": nodeID,
"endpoint_id": "mesh-peer-endpoint-" + nodeID,
@@ -4856,6 +4952,107 @@ func vpnEntryEndpointCandidatesFromHeartbeat(nodeID string, capabilities json.Ra
return out
}
func endpointCandidateCertsFromHeartbeatMetadata(metadata json.RawMessage) map[string]string {
out := map[string]string{}
var payload map[string]any
if len(metadata) == 0 || json.Unmarshal(metadata, &payload) != nil {
return out
}
report, _ := payload["mesh_endpoint_report"].(map[string]any)
candidates, _ := report["endpoint_candidates"].([]any)
for _, raw := range candidates {
candidate, _ := raw.(map[string]any)
if candidate == nil {
continue
}
meta, _ := candidate["metadata"].(map[string]any)
cert := strings.TrimSpace(metadataAnyString(meta["tls_cert_sha256"]))
if cert == "" {
cert = strings.TrimSpace(metadataAnyString(meta["peer_cert_sha256"]))
}
if cert == "" {
continue
}
endpointID := strings.TrimSpace(metadataAnyString(candidate["endpoint_id"]))
address := strings.TrimSpace(metadataAnyString(candidate["address"]))
nodeID := strings.TrimSpace(metadataAnyString(candidate["node_id"]))
if endpointID != "" {
out[endpointID] = cert
}
if address != "" {
out[address] = cert
}
if nodeID != "" && address != "" {
out[nodeID+"\x00"+address] = cert
}
}
return out
}
func metadataAnyString(value any) string {
switch typed := value.(type) {
case string:
return typed
default:
return ""
}
}
func firstNonEmptyMetadataString(values ...string) string {
for _, value := range values {
if strings.TrimSpace(value) != "" {
return strings.TrimSpace(value)
}
}
return ""
}
func usableVPNFabricPeerEndpoint(address string, transport string) bool {
address = strings.TrimSpace(address)
if address == "" {
return false
}
transport = strings.ToLower(strings.TrimSpace(transport))
if !strings.Contains(transport, "quic") {
return false
}
parsed, err := url.Parse(address)
if err != nil {
return false
}
if strings.ToLower(parsed.Scheme) != "quic" {
return false
}
host := parsed.Hostname()
if host == "" {
return false
}
ip := net.ParseIP(host)
if ip == nil {
return true
}
if ip.IsUnspecified() || ip.IsLoopback() {
return false
}
return true
}
func endpointCandidateMetadataString(metadata json.RawMessage, keys ...string) string {
if len(metadata) == 0 {
return ""
}
var values map[string]any
if json.Unmarshal(metadata, &values) != nil {
return ""
}
for _, key := range keys {
if value, ok := values[key].(string); ok && strings.TrimSpace(value) != "" {
return strings.TrimSpace(value)
}
}
return ""
}
func heartbeatCapabilityEnabled(capabilities json.RawMessage, name string) bool {
var cfg map[string]any
if len(capabilities) == 0 || json.Unmarshal(capabilities, &cfg) != nil {
@@ -4921,6 +5118,44 @@ func enrichVPNClientEntryEndpointCandidates(connection VPNClientConnection, endp
return out
}
func enrichVPNClientExitEndpointCandidates(connection VPNClientConnection, endpoints map[string][]map[string]any) json.RawMessage {
var cfg map[string]any
if err := json.Unmarshal(connection.ClientConfig, &cfg); err != nil || cfg == nil {
cfg = map[string]any{}
}
route := vpnFabricRouteFromClientConfig(connection.ClientConfig)
exitIDs := dedupeStrings(append([]string{route.SelectedExitNodeID}, route.ExitPoolNodeIDs...))
exitIDs = dedupeStrings(append(exitIDs, connection.ExitNodeID))
exitIDs = dedupeStrings(append(exitIDs, connection.AllowedNodeIDs...))
var candidates []map[string]any
seen := map[string]struct{}{}
for _, nodeID := range exitIDs {
for _, candidate := range endpoints[nodeID] {
address, _ := candidate["address"].(string)
endpointID, _ := candidate["endpoint_id"].(string)
key := nodeID + "\x00" + endpointID + "\x00" + address
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
enriched := make(map[string]any, len(candidate)+2)
for k, v := range candidate {
enriched[k] = v
}
enriched["selected_exit"] = nodeID != "" && nodeID == route.SelectedExitNodeID
enriched["exit_pool_member"] = true
candidates = append(candidates, enriched)
}
}
cfg["vpn_exit_endpoint_candidates"] = candidates
cfg["vpn_exit_endpoint_candidate_count"] = len(candidates)
out, err := json.Marshal(cfg)
if err != nil {
return connection.ClientConfig
}
return out
}
func listVPNConnectionAllowedNodes(ctx context.Context, q rowQuerier, clusterID, vpnConnectionID string) ([]VPNConnectionAllowedNode, error) {
rows, err := q.Query(ctx, `
SELECT vpn_connection_id::text, cluster_id::text, node_id::text, role_preference,
@@ -5087,13 +5322,32 @@ func ensureRaw(raw *json.RawMessage, fallback string) {
}
}
func mergeJSONObjects(raw json.RawMessage, values map[string]any) json.RawMessage {
out := map[string]any{}
_ = json.Unmarshal(raw, &out)
if out == nil {
out = map[string]any{}
}
for key, value := range values {
out[key] = value
}
payload, err := json.Marshal(out)
if err != nil {
return raw
}
return payload
}
func enrichVPNClientFabricRoute(item VPNClientConnection, preferredEntryNodeID, preferredExitNodeID string) json.RawMessage {
var cfg map[string]any
if err := json.Unmarshal(item.ClientConfig, &cfg); err != nil || cfg == nil {
cfg = map[string]any{}
}
entryPool := dedupeStrings(append([]string{}, item.EntryNodeIDs...))
if len(entryPool) == 0 {
placementPolicy := jsonObjectFromRaw(item.PlacementPolicy)
entrySelector, _ := placementPolicy["entry_selector"].(string)
clientNodeEntry := strings.EqualFold(strings.TrimSpace(entrySelector), "client_node") || placementPolicy["android_node_agent_target"] == true
if len(entryPool) == 0 && !clientNodeEntry {
entryPool = dedupeStrings(append([]string{}, item.AllowedNodeIDs...))
}
exitPool := []string{}
@@ -5107,7 +5361,10 @@ func enrichVPNClientFabricRoute(item VPNClientConnection, preferredEntryNodeID,
exitPool = dedupeStrings(exitPool)
preferredEntryNodeID = strings.TrimSpace(preferredEntryNodeID)
selectedEntry := selectPreferredNode(entryPool, preferredEntryNodeID)
selectedEntry := ""
if !clientNodeEntry {
selectedEntry = selectPreferredNode(entryPool, preferredEntryNodeID)
}
selectedExit := selectPreferredNode(exitPool, preferredExitNodeID)
if selectedExit == "" && item.ActiveLease != nil && item.ActiveLease.OwnerNodeID != "" {
selectedExit = item.ActiveLease.OwnerNodeID
@@ -5116,6 +5373,8 @@ func enrichVPNClientFabricRoute(item VPNClientConnection, preferredEntryNodeID,
switch {
case selectedEntry != "" && selectedExit != "":
status = "planned"
case clientNodeEntry && selectedExit != "":
status = "planned"
case selectedEntry == "":
status = "waiting_for_entry"
case selectedExit == "":
@@ -5129,8 +5388,10 @@ func enrichVPNClientFabricRoute(item VPNClientConnection, preferredEntryNodeID,
"preferred_data_plane": "fabric_service_channel",
"fallback_data_plane": "none",
"backend_relay_fallback": false,
"selection_mode": "farm_authoritative_entry_to_exit",
"selection_mode": "farm_authoritative_client_node_to_exit_pool",
"route_authority": "fabric_farm",
"entry_selector": firstNonEmptyString(entrySelector, "entry-node"),
"client_node_entry": clientNodeEntry,
"vpn_builds_routes": false,
"vpn_builds_tunnels": false,
"farm_builds_routes": true,
@@ -5163,7 +5424,9 @@ func enrichVPNClientFabricRoute(item VPNClientConnection, preferredEntryNodeID,
"diagnostics_only_protocol_summaries": true,
},
"route_selection": map[string]any{
"mode": "farm_authoritative_lowest_latency_healthy_route",
"mode": "farm_authoritative_lowest_latency_healthy_route_to_exit_pool",
"entry_selector": firstNonEmptyString(entrySelector, "entry-node"),
"client_node_entry": clientNodeEntry,
"selected_entry_node_id": selectedEntry,
"selected_exit_node_id": selectedExit,
"route_candidates": routeCandidates,
@@ -5175,7 +5438,7 @@ func enrichVPNClientFabricRoute(item VPNClientConnection, preferredEntryNodeID,
"preserve_vpn_connection_id": true,
"alternate_route_count": alternateVPNRouteCount(routeCandidates, selectedEntry, selectedExit),
"reroute_triggers": []string{
"entry_unhealthy",
"client_node_mesh_path_unhealthy",
"exit_unhealthy",
"mesh_route_latency_regression",
"mesh_route_loss_regression",
@@ -5199,12 +5462,30 @@ func enrichVPNClientFabricRoute(item VPNClientConnection, preferredEntryNodeID,
return out
}
func jsonObjectFromRaw(raw json.RawMessage) map[string]any {
var out map[string]any
if len(raw) == 0 || json.Unmarshal(raw, &out) != nil || out == nil {
return map[string]any{}
}
return out
}
func vpnFabricRouteCandidates(entryPool, exitPool []string, selectedEntry, selectedExit string) []map[string]any {
type pair struct {
entry string
exit string
}
pairs := make([]pair, 0, len(entryPool)*len(exitPool)+1)
if len(entryPool) == 0 && selectedExit != "" {
pairs = append(pairs, pair{exit: selectedExit})
}
if len(entryPool) == 0 {
for _, exit := range exitPool {
if exit != "" {
pairs = append(pairs, pair{exit: exit})
}
}
}
if selectedEntry != "" && selectedExit != "" {
pairs = append(pairs, pair{entry: selectedEntry, exit: selectedExit})
}
@@ -5219,6 +5500,9 @@ func vpnFabricRouteCandidates(entryPool, exitPool []string, selectedEntry, selec
seen := map[string]struct{}{}
out := make([]map[string]any, 0, len(pairs))
for _, pair := range pairs {
if pair.exit == "" {
continue
}
key := pair.entry + "\x00" + pair.exit
if _, ok := seen[key]; ok {
continue
@@ -5226,17 +5510,22 @@ func vpnFabricRouteCandidates(entryPool, exitPool []string, selectedEntry, selec
seen[key] = struct{}{}
priority := len(out) + 1
role := "alternate"
if pair.entry == selectedEntry && pair.exit == selectedExit {
if pair.exit == selectedExit && (pair.entry == selectedEntry || selectedEntry == "") {
role = "preferred"
priority = 0
}
out = append(out, map[string]any{
"entry_node_id": pair.entry,
"exit_node_id": pair.exit,
"role": role,
"priority": priority,
"status": "candidate",
})
candidate := map[string]any{
"exit_node_id": pair.exit,
"role": role,
"priority": priority,
"status": "candidate",
"source_role": "vpn-client",
"route_scope": "client_node_to_exit_pool",
}
if pair.entry != "" {
candidate["entry_node_id"] = pair.entry
}
out = append(out, candidate)
}
return out
}