package vpnruntime

import (
	"context"
	"encoding/binary"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/example/remote-access-platform/agents/rap-node-agent/internal/fabricproto"
	"github.com/example/remote-access-platform/agents/rap-node-agent/internal/mesh"
)

// Fabric VPN packet payload wire format. All integers are big-endian.
//
//	offset  size  field
//	0       4     magic (fabricVPNPacketPayloadMagic, ASCII "RVPB")
//	4       1     version (fabricVPNPacketPayloadVersionV1 or fabricVPNPacketPayloadVersion)
//	5       1     direction code
//	6       2     packet count
//	8       2     VPN connection ID length
//	10      2     service metadata length (read by the decoder only for version 2)
//	12      8     sent-at timestamp, Unix nanoseconds
//	20      4     reserved; written as zero, not checked by the decoder
//	24      ...   VPN connection ID, service metadata JSON, then the packets
//
// Each packet is a 4-byte length prefix followed by that many packet bytes.
// The encoder only writes service metadata when the batch has no packets
// (session hello frames).
const (
	fabricVPNPacketPayloadMagic     uint32 = 0x52565042 // "RVPB"
	fabricVPNPacketPayloadVersion   uint8  = 2
	fabricVPNPacketPayloadVersionV1 uint8  = 1

	fabricVPNPacketPayloadHeader    = 24
	fabricVPNPacketMaxPacketCount   = 2048
	fabricVPNPacketMaxMetadataBytes = 64 * 1024

	fabricVPNPacketDirectionClientToGateway uint8 = 1
	fabricVPNPacketDirectionGatewayToClient uint8 = 2
)

var (
	// ErrFabricVPNPacketFrameInvalid reports a frame or frame input that
	// cannot carry a VPN packet payload (missing stream, identity, packets).
	ErrFabricVPNPacketFrameInvalid = errors.New("invalid fabric vpn packet frame")
	// ErrFabricVPNPacketPayload reports a payload that cannot be encoded or
	// decoded according to the wire format above.
	ErrFabricVPNPacketPayload = errors.New("invalid fabric vpn packet payload")
)

// FabricVPNPacketFrameInput describes a VPN payload to be carried on a fabric
// DATA stream frame. StreamID, VPNConnectionID and Direction are required.
// Packets is used by NewFabricVPNPacketDataFrame. ServiceTunnel is only
// encoded into session hello frames. A zero Now means "use the current time".
type FabricVPNPacketFrameInput struct {
	StreamID        uint64
	Sequence        uint64
	VPNConnectionID string
	Direction       string
	TrafficClass    string
	ServiceTunnel   FabricServiceTunnel
	Packets         [][]byte
	Now             time.Time
}

// NewFabricVPNPacketDataFrame builds a DATA frame carrying a batch of VPN
// packets. input.Packets is passed through cleanPacketBatch first, and a batch
// that is empty afterwards is rejected with ErrFabricVPNPacketFrameInvalid.
// Packet frames carry no service metadata.
func NewFabricVPNPacketDataFrame(input FabricVPNPacketFrameInput) (fabricproto.Frame, error) {
	if err := validateFabricVPNPacketFrameInput(input); err != nil {
		return fabricproto.Frame{}, err
	}
	packets := cleanPacketBatch(input.Packets)
	if len(packets) == 0 {
		return fabricproto.Frame{}, fmt.Errorf("%w: empty packet batch", ErrFabricVPNPacketFrameInvalid)
	}
	return buildFabricVPNPacketFrame(input, packets)
}

// NewFabricVPNSessionHelloFrame builds a DATA frame with zero packets whose
// payload carries the service-tunnel metadata derived from input.ServiceTunnel.
// input.Packets is ignored.
func NewFabricVPNSessionHelloFrame(input FabricVPNPacketFrameInput) (fabricproto.Frame, error) {
	if err := validateFabricVPNPacketFrameInput(input); err != nil {
		return fabricproto.Frame{}, err
	}
	return buildFabricVPNPacketFrame(input, nil)
}

// validateFabricVPNPacketFrameInput checks the fields every VPN frame needs.
func validateFabricVPNPacketFrameInput(input FabricVPNPacketFrameInput) error {
	if input.StreamID == 0 {
		return fmt.Errorf("%w: missing stream id", ErrFabricVPNPacketFrameInvalid)
	}
	if input.VPNConnectionID == "" || input.Direction == "" {
		return fmt.Errorf("%w: missing vpn identity", ErrFabricVPNPacketFrameInvalid)
	}
	return nil
}

// buildFabricVPNPacketFrame encodes the payload and wraps it in a DATA frame.
func buildFabricVPNPacketFrame(input FabricVPNPacketFrameInput, packets [][]byte) (fabricproto.Frame, error) {
	payload, err := encodeFabricVPNPacketPayload(input, packets)
	if err != nil {
		return fabricproto.Frame{}, err
	}
	return fabricproto.Frame{
		Type:         fabricproto.FrameData,
		TrafficClass: fabricFrameTrafficClass(input.TrafficClass, packets),
		StreamID:     input.StreamID,
		Sequence:     input.Sequence,
		Payload:      payload,
	}, nil
}

// DecodeFabricVPNPacketDataFrame decodes the VPN payload of a DATA frame on a
// non-zero stream. Any other frame yields ErrFabricVPNPacketFrameInvalid, and
// malformed payloads yield ErrFabricVPNPacketPayload.
func DecodeFabricVPNPacketDataFrame(frame fabricproto.Frame) (mesh.VPNPacketBatchPayload, error) {
	if frame.Type != fabricproto.FrameData || frame.StreamID == 0 {
		return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: expected DATA stream frame", ErrFabricVPNPacketFrameInvalid)
	}
	return decodeFabricVPNPacketPayload(frame.Payload)
}

// DeliverFabricSessionFrame decodes a fabric VPN frame and enqueues its packets
// on the inbox. A nil inbox returns mesh.ErrForwardRuntimeUnavailable. Batches
// with no packets after cleanPacketBatch are dropped without error. Session
// hello frames encode zero packets, so they take that path as well.
// NOTE(review): the hello frame's service metadata is discarded on this path;
// confirm that is intended. The context argument is currently unused.
func (i *FabricPacketInbox) DeliverFabricSessionFrame(_ context.Context, frame fabricproto.Frame) error {
	if i == nil {
		return mesh.ErrForwardRuntimeUnavailable
	}
	payload, err := DecodeFabricVPNPacketDataFrame(frame)
	if err != nil {
		return err
	}
	payload.Packets = cleanPacketBatch(payload.Packets)
	if len(payload.Packets) == 0 {
		return nil
	}
	return i.enqueue(payload)
}

// encodeFabricVPNPacketPayload serializes input and packets using the wire
// format documented at the top of this file. When packets is empty, the
// service metadata from input is encoded as JSON. It returns
// ErrFabricVPNPacketPayload (wrapped) when any of these hold:
//   - the packet count exceeds the cap
//   - the direction is unknown
//   - the VPN ID or metadata is too long for its 16-bit length field
//   - any packet is empty or too long for its 32-bit length prefix
func encodeFabricVPNPacketPayload(input FabricVPNPacketFrameInput, packets [][]byte) ([]byte, error) {
	if len(packets) > fabricVPNPacketMaxPacketCount {
		return nil, fmt.Errorf("%w: packet count %d > %d", ErrFabricVPNPacketPayload, len(packets), fabricVPNPacketMaxPacketCount)
	}
	directionCode, err := fabricVPNPacketDirectionCode(input.Direction)
	if err != nil {
		return nil, err
	}
	vpnID := []byte(input.VPNConnectionID)
	if len(vpnID) > 0xffff {
		return nil, fmt.Errorf("%w: vpn connection id too long", ErrFabricVPNPacketPayload)
	}

	// Only packet-less (session hello) payloads carry service metadata.
	var metadata []byte
	if len(packets) == 0 {
		metadata, err = encodeFabricVPNPacketServiceMetadata(input)
		if err != nil {
			return nil, err
		}
	}

	now := input.Now.UTC()
	if now.IsZero() {
		now = time.Now().UTC()
	}

	total := fabricVPNPacketPayloadHeader + len(vpnID) + len(metadata)
	for index, packet := range packets {
		// The decoder rejects zero-length packets and lengths are stored in
		// 32 bits; refuse to emit a payload the peer could not decode.
		if len(packet) == 0 || uint64(len(packet)) > 0xffffffff {
			return nil, fmt.Errorf("%w: invalid packet length at index %d", ErrFabricVPNPacketPayload, index)
		}
		total += 4 + len(packet)
	}

	out := make([]byte, total)
	binary.BigEndian.PutUint32(out[0:4], fabricVPNPacketPayloadMagic)
	out[4] = fabricVPNPacketPayloadVersion
	out[5] = directionCode
	binary.BigEndian.PutUint16(out[6:8], uint16(len(packets)))
	binary.BigEndian.PutUint16(out[8:10], uint16(len(vpnID)))
	binary.BigEndian.PutUint16(out[10:12], uint16(len(metadata)))
	binary.BigEndian.PutUint64(out[12:20], uint64(now.UnixNano()))
	// out[20:24] is reserved and stays zero from make.

	offset := fabricVPNPacketPayloadHeader
	offset += copy(out[offset:], vpnID)
	offset += copy(out[offset:], metadata)
	for _, packet := range packets {
		binary.BigEndian.PutUint32(out[offset:offset+4], uint32(len(packet)))
		offset += 4
		offset += copy(out[offset:], packet)
	}
	return out, nil
}

// decodeFabricVPNPacketPayload parses a payload in the wire format documented
// at the top of this file. It accepts versions 1 and 2, rejects trailing bytes,
// and copies each packet out of payload so the result does not alias it.
// All length checks compare against the remaining byte count so that hostile
// length fields cannot overflow int (notably on 32-bit platforms).
func decodeFabricVPNPacketPayload(payload []byte) (mesh.VPNPacketBatchPayload, error) {
	if len(payload) < fabricVPNPacketPayloadHeader {
		return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: short payload", ErrFabricVPNPacketPayload)
	}
	if binary.BigEndian.Uint32(payload[0:4]) != fabricVPNPacketPayloadMagic {
		return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: bad magic", ErrFabricVPNPacketPayload)
	}
	version := payload[4]
	if version != fabricVPNPacketPayloadVersionV1 && version != fabricVPNPacketPayloadVersion {
		return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: unsupported version %d", ErrFabricVPNPacketPayload, version)
	}
	direction, err := fabricVPNPacketDirectionName(payload[5])
	if err != nil {
		return mesh.VPNPacketBatchPayload{}, err
	}
	packetCount := int(binary.BigEndian.Uint16(payload[6:8]))
	vpnIDLength := int(binary.BigEndian.Uint16(payload[8:10]))
	metadataLength := 0
	if version > fabricVPNPacketPayloadVersionV1 {
		// Version 1 payloads have no service metadata section.
		metadataLength = int(binary.BigEndian.Uint16(payload[10:12]))
	}
	if packetCount > fabricVPNPacketMaxPacketCount {
		return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: invalid packet count %d", ErrFabricVPNPacketPayload, packetCount)
	}

	offset := fabricVPNPacketPayloadHeader
	if vpnIDLength > len(payload)-offset {
		return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: truncated vpn id", ErrFabricVPNPacketPayload)
	}
	vpnID := string(payload[offset : offset+vpnIDLength])
	offset += vpnIDLength
	if vpnID == "" {
		return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: empty vpn id", ErrFabricVPNPacketPayload)
	}

	metadata := fabricVPNPacketServiceMetadata{}
	if metadataLength > 0 {
		// metadataLength comes from a uint16, so the MaxMetadataBytes check is
		// defensive; the effective ceiling is 0xffff.
		if metadataLength > fabricVPNPacketMaxMetadataBytes || metadataLength > len(payload)-offset {
			return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: truncated service metadata", ErrFabricVPNPacketPayload)
		}
		if err := json.Unmarshal(payload[offset:offset+metadataLength], &metadata); err != nil {
			return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: invalid service metadata: %v", ErrFabricVPNPacketPayload, err)
		}
		offset += metadataLength
	}

	packets := make([][]byte, 0, packetCount)
	for index := 0; index < packetCount; index++ {
		if len(payload)-offset < 4 {
			return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: truncated packet length", ErrFabricVPNPacketPayload)
		}
		packetLength := binary.BigEndian.Uint32(payload[offset : offset+4])
		offset += 4
		// Compare with the remaining length instead of computing
		// offset+packetLength, which can overflow int on 32-bit targets.
		if packetLength == 0 || uint64(packetLength) > uint64(len(payload)-offset) {
			return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: truncated packet", ErrFabricVPNPacketPayload)
		}
		end := offset + int(packetLength)
		packets = append(packets, append([]byte(nil), payload[offset:end]...))
		offset = end
	}
	if offset != len(payload) {
		return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: trailing bytes", ErrFabricVPNPacketPayload)
	}

	sentAt := time.Unix(0, int64(binary.BigEndian.Uint64(payload[12:20]))).UTC()
	return mesh.VPNPacketBatchPayload{
		SchemaVersion:   "rap.vpn_packet_batch.fabric.v1",
		VPNConnectionID: vpnID,
		TunnelID:        firstNonEmptyTunnelString(metadata.TunnelID, vpnID),
		PoolID:          metadata.PoolID,
		ServiceID:       metadata.ServiceID,
		LocalServiceID:  metadata.LocalServiceID,
		RemoteServiceID: metadata.RemoteServiceID,
		ServiceKind:     metadata.ServiceKind,
		ServiceClass:    metadata.ServiceClass,
		ServiceRole:     metadata.ServiceRole,
		RouteLeaseID:    metadata.RouteLeaseID,
		RouteGeneration: metadata.RouteGeneration,
		DataPlane:       metadata.DataPlane,
		TransportOwner:  metadata.TransportOwner,
		RouteVisibility: metadata.RouteVisibility,
		TrafficClasses:  metadata.TrafficClasses,
		StreamShards:    metadata.StreamShards,
		Direction:       direction,
		Packets:         packets,
		SentAt:          sentAt,
	}, nil
}

// fabricVPNPacketServiceMetadata is the JSON document carried in the metadata
// section of session hello payloads.
type fabricVPNPacketServiceMetadata struct {
	TunnelID        string   `json:"tunnel_id,omitempty"`
	PoolID          string   `json:"pool_id,omitempty"`
	ServiceID       string   `json:"service_id,omitempty"`
	LocalServiceID  string   `json:"local_service_id,omitempty"`
	RemoteServiceID string   `json:"remote_service_id,omitempty"`
	ServiceKind     string   `json:"service_kind,omitempty"`
	ServiceClass    string   `json:"service_class,omitempty"`
	ServiceRole     string   `json:"service_role,omitempty"`
	RouteLeaseID    string   `json:"route_lease_id,omitempty"`
	RouteGeneration string   `json:"route_generation,omitempty"`
	DataPlane       string   `json:"data_plane,omitempty"`
	TransportOwner  string   `json:"transport_owner,omitempty"`
	RouteVisibility string   `json:"route_visibility,omitempty"`
	TrafficClasses  []string `json:"traffic_classes,omitempty"`
	StreamShards    int      `json:"stream_shards,omitempty"`
}

// encodeFabricVPNPacketServiceMetadata marshals the normalized service tunnel
// of input to JSON. TunnelID falls back to the VPN connection ID. The result
// must fit the 16-bit metadata length field, so the effective limit is 0xffff
// bytes.
func encodeFabricVPNPacketServiceMetadata(input FabricVPNPacketFrameInput) ([]byte, error) {
	tunnel := NormalizeServiceTunnel(input.ServiceTunnel, input.VPNConnectionID)
	metadata := fabricVPNPacketServiceMetadata{
		TunnelID:        firstNonEmptyTunnelString(tunnel.TunnelID, input.VPNConnectionID),
		PoolID:          tunnel.PoolID,
		ServiceID:       tunnel.ServiceID,
		LocalServiceID:  tunnel.LocalServiceID,
		RemoteServiceID: tunnel.RemoteServiceID,
		ServiceKind:     tunnel.ServiceKind,
		ServiceClass:    tunnel.ServiceClass,
		ServiceRole:     tunnel.ServiceRole,
		RouteLeaseID:    tunnel.RouteLeaseID,
		RouteGeneration: tunnel.RouteGeneration,
		DataPlane:       tunnel.DataPlane,
		TransportOwner:  tunnel.TransportOwner,
		RouteVisibility: tunnel.RouteVisibility,
		TrafficClasses:  append([]string(nil), tunnel.TrafficClasses...),
		StreamShards:    tunnel.StreamShards,
	}
	payload, err := json.Marshal(metadata)
	if err != nil {
		return nil, err
	}
	if len(payload) > fabricVPNPacketMaxMetadataBytes || len(payload) > 0xffff {
		return nil, fmt.Errorf("%w: service metadata too large", ErrFabricVPNPacketPayload)
	}
	return payload, nil
}

// fabricVPNPacketDirectionCode maps a direction name to its one-byte wire code.
func fabricVPNPacketDirectionCode(direction string) (uint8, error) {
	switch direction {
	case FabricDirectionClientToGateway:
		return fabricVPNPacketDirectionClientToGateway, nil
	case FabricDirectionGatewayToClient:
		return fabricVPNPacketDirectionGatewayToClient, nil
	default:
		return 0, fmt.Errorf("%w: unknown direction %q", ErrFabricVPNPacketPayload, direction)
	}
}

// fabricVPNPacketDirectionName maps a one-byte wire code back to its direction name.
func fabricVPNPacketDirectionName(direction uint8) (string, error) {
	switch direction {
	case fabricVPNPacketDirectionClientToGateway:
		return FabricDirectionClientToGateway, nil
	case fabricVPNPacketDirectionGatewayToClient:
		return FabricDirectionGatewayToClient, nil
	default:
		return "", fmt.Errorf("%w: unknown direction %d", ErrFabricVPNPacketPayload, direction)
	}
}

// fabricFrameTrafficClass maps a VPN traffic class name (after
// normalizeFabricTrafficClass) to the fabric frame traffic class. DNS shares
// the reliable class, and unrecognized names fall back to bulk. The packets
// argument is currently unused.
func fabricFrameTrafficClass(trafficClass string, packets [][]byte) fabricproto.TrafficClass {
	switch normalizeFabricTrafficClass(trafficClass) {
	case FabricTrafficClassControl:
		return fabricproto.TrafficClassControl
	case FabricTrafficClassDNS:
		return fabricproto.TrafficClassReliable
	case FabricTrafficClassInteractive:
		return fabricproto.TrafficClassInteractive
	case FabricTrafficClassReliable:
		return fabricproto.TrafficClassReliable
	case FabricTrafficClassDroppable:
		return fabricproto.TrafficClassDroppable
	default:
		return fabricproto.TrafficClassBulk
	}
}