package vpnruntime

import (
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"time"

	"github.com/example/remote-access-platform/agents/rap-node-agent/internal/fabricproto"
	"github.com/example/remote-access-platform/agents/rap-node-agent/internal/mesh"
)

// Wire format of a fabric VPN packet batch payload. All integers are
// big-endian.
//
//	offset  size  field
//	0       4     magic, fabricVPNPacketPayloadMagic ("RVPB")
//	4       1     version, fabricVPNPacketPayloadVersion
//	5       1     direction code (fabricVPNPacketDirection* constants)
//	6       2     packet count, 1..fabricVPNPacketMaxPacketCount
//	8       2     VPN connection ID length in bytes
//	10      2     reserved, left zero by the encoder, ignored by the decoder
//	12      8     sent-at timestamp, Unix nanoseconds
//	20      4     reserved, left zero by the encoder, ignored by the decoder
//	24      n     VPN connection ID bytes
//
// The header is followed by packet-count records. Each record is a 4-byte
// length followed by that many packet bytes (length must be > 0). The decoder
// rejects any bytes after the last record.
const (
	fabricVPNPacketPayloadMagic   uint32 = 0x52565042 // RVPB
	fabricVPNPacketPayloadVersion uint8  = 1
	fabricVPNPacketPayloadHeader         = 24
	fabricVPNPacketMaxPacketCount        = 2048

	fabricVPNPacketDirectionClientToGateway uint8 = 1
	fabricVPNPacketDirectionGatewayToClient uint8 = 2
)

var (
	// ErrFabricVPNPacketFrameInvalid reports a fabric frame that is not a
	// usable VPN packet DATA frame: a zero stream ID, a missing VPN identity,
	// an empty batch, or a non-DATA frame type.
	ErrFabricVPNPacketFrameInvalid = errors.New("invalid fabric vpn packet frame")
	// ErrFabricVPNPacketPayload reports a VPN packet batch payload that cannot
	// be encoded, or that is malformed when decoded.
	ErrFabricVPNPacketPayload = errors.New("invalid fabric vpn packet payload")
)

// FabricVPNPacketFrameInput describes one batch of VPN packets to send as a
// single fabric DATA frame.
//
// StreamID must be non-zero. VPNConnectionID must be non-empty and at most
// 65535 bytes. Direction must be FabricDirectionClientToGateway or
// FabricDirectionGatewayToClient. Packets are passed through cleanPacketBatch
// before encoding. A zero Now means the current time is stamped on the
// payload.
type FabricVPNPacketFrameInput struct {
	StreamID        uint64
	Sequence        uint64
	VPNConnectionID string
	Direction       string
	TrafficClass    string
	Packets         [][]byte
	Now             time.Time
}

// NewFabricVPNPacketDataFrame encodes input as a fabricproto DATA frame on
// input.StreamID with input.Sequence.
//
// It returns an error wrapping ErrFabricVPNPacketFrameInvalid in three cases:
// the stream ID is zero, the VPN connection ID or direction is empty, or the
// packet batch is empty after cleanPacketBatch. Encoding failures are returned
// wrapping ErrFabricVPNPacketPayload.
func NewFabricVPNPacketDataFrame(input FabricVPNPacketFrameInput) (fabricproto.Frame, error) {
	if input.StreamID == 0 {
		return fabricproto.Frame{}, fmt.Errorf("%w: missing stream id", ErrFabricVPNPacketFrameInvalid)
	}
	if input.VPNConnectionID == "" || input.Direction == "" {
		return fabricproto.Frame{}, fmt.Errorf("%w: missing vpn identity", ErrFabricVPNPacketFrameInvalid)
	}
	packets := cleanPacketBatch(input.Packets)
	if len(packets) == 0 {
		return fabricproto.Frame{}, fmt.Errorf("%w: empty packet batch", ErrFabricVPNPacketFrameInvalid)
	}
	payload, err := encodeFabricVPNPacketPayload(input, packets)
	if err != nil {
		return fabricproto.Frame{}, err
	}
	return fabricproto.Frame{
		Type:         fabricproto.FrameData,
		TrafficClass: fabricFrameTrafficClass(input.TrafficClass, packets),
		StreamID:     input.StreamID,
		Sequence:     input.Sequence,
		Payload:      payload,
	}, nil
}

// DecodeFabricVPNPacketDataFrame decodes the VPN packet batch carried by a
// fabric DATA frame.
//
// It returns ErrFabricVPNPacketFrameInvalid when the frame is not a DATA
// frame or has a zero stream ID. Payload errors are returned wrapping
// ErrFabricVPNPacketPayload.
func DecodeFabricVPNPacketDataFrame(frame fabricproto.Frame) (mesh.VPNPacketBatchPayload, error) {
	if frame.Type != fabricproto.FrameData || frame.StreamID == 0 {
		return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: expected DATA stream frame", ErrFabricVPNPacketFrameInvalid)
	}
	return decodeFabricVPNPacketPayload(frame.Payload)
}

// DeliverFabricSessionFrame decodes a VPN packet DATA frame and enqueues its
// batch on the inbox.
//
// The batch is passed through cleanPacketBatch after decoding. If nothing
// remains, the frame is dropped and nil is returned. A nil inbox returns
// mesh.ErrForwardRuntimeUnavailable. The context is currently unused.
func (i *FabricPacketInbox) DeliverFabricSessionFrame(_ context.Context, frame fabricproto.Frame) error {
	if i == nil {
		return mesh.ErrForwardRuntimeUnavailable
	}
	payload, err := DecodeFabricVPNPacketDataFrame(frame)
	if err != nil {
		return err
	}
	payload.Packets = cleanPacketBatch(payload.Packets)
	if len(payload.Packets) == 0 {
		return nil
	}
	return i.enqueue(payload)
}

// encodeFabricVPNPacketPayload serializes input's identity and packets into
// the binary payload layout documented on the constants above.
//
// It rejects any input the decoder would reject, so the encoder never emits a
// payload the peer cannot parse:
//   - an empty or oversized batch;
//   - an unknown direction;
//   - an empty VPN ID, or one longer than 65535 bytes;
//   - a packet that is empty or whose length does not fit the 4-byte prefix.
//
// All of these errors wrap ErrFabricVPNPacketPayload.
func encodeFabricVPNPacketPayload(input FabricVPNPacketFrameInput, packets [][]byte) ([]byte, error) {
	if len(packets) == 0 {
		return nil, fmt.Errorf("%w: empty packet batch", ErrFabricVPNPacketPayload)
	}
	if len(packets) > fabricVPNPacketMaxPacketCount {
		return nil, fmt.Errorf("%w: packet count %d > %d", ErrFabricVPNPacketPayload, len(packets), fabricVPNPacketMaxPacketCount)
	}
	directionCode, err := fabricVPNPacketDirectionCode(input.Direction)
	if err != nil {
		return nil, err
	}
	vpnID := []byte(input.VPNConnectionID)
	if len(vpnID) == 0 {
		// The decoder rejects an empty VPN ID.
		return nil, fmt.Errorf("%w: empty vpn id", ErrFabricVPNPacketPayload)
	}
	if len(vpnID) > 0xffff {
		return nil, fmt.Errorf("%w: vpn connection id too long", ErrFabricVPNPacketPayload)
	}
	now := input.Now.UTC()
	if now.IsZero() {
		now = time.Now().UTC()
	}

	total := fabricVPNPacketPayloadHeader + len(vpnID)
	for index, packet := range packets {
		// The decoder rejects zero-length records. A length above the 4-byte
		// prefix would otherwise be silently truncated by the uint32
		// conversion below.
		if len(packet) == 0 || uint64(len(packet)) > 0xffffffff {
			return nil, fmt.Errorf("%w: packet %d has invalid length %d", ErrFabricVPNPacketPayload, index, len(packet))
		}
		total += 4 + len(packet)
	}

	out := make([]byte, total)
	binary.BigEndian.PutUint32(out[0:4], fabricVPNPacketPayloadMagic)
	out[4] = fabricVPNPacketPayloadVersion
	out[5] = directionCode
	binary.BigEndian.PutUint16(out[6:8], uint16(len(packets)))
	binary.BigEndian.PutUint16(out[8:10], uint16(len(vpnID)))
	// Bytes 10:12 and 20:24 are reserved; make has already zeroed them.
	binary.BigEndian.PutUint64(out[12:20], uint64(now.UnixNano()))

	offset := fabricVPNPacketPayloadHeader
	offset += copy(out[offset:], vpnID)
	for _, packet := range packets {
		binary.BigEndian.PutUint32(out[offset:offset+4], uint32(len(packet)))
		offset += 4
		offset += copy(out[offset:], packet)
	}
	return out, nil
}

// decodeFabricVPNPacketPayload parses a payload produced by
// encodeFabricVPNPacketPayload.
//
// Each returned packet is copied out of payload, so the result does not alias
// the caller's buffer. Any structural problem returns an error wrapping
// ErrFabricVPNPacketPayload: short header, bad magic, wrong version, unknown
// direction, bad packet count, empty or truncated VPN ID, empty or truncated
// packet, or trailing bytes.
func decodeFabricVPNPacketPayload(payload []byte) (mesh.VPNPacketBatchPayload, error) {
	if len(payload) < fabricVPNPacketPayloadHeader {
		return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: short payload", ErrFabricVPNPacketPayload)
	}
	if binary.BigEndian.Uint32(payload[0:4]) != fabricVPNPacketPayloadMagic {
		return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: bad magic", ErrFabricVPNPacketPayload)
	}
	if payload[4] != fabricVPNPacketPayloadVersion {
		return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: unsupported version %d", ErrFabricVPNPacketPayload, payload[4])
	}
	direction, err := fabricVPNPacketDirectionName(payload[5])
	if err != nil {
		return mesh.VPNPacketBatchPayload{}, err
	}
	packetCount := int(binary.BigEndian.Uint16(payload[6:8]))
	vpnIDLength := int(binary.BigEndian.Uint16(payload[8:10]))
	if packetCount <= 0 || packetCount > fabricVPNPacketMaxPacketCount {
		return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: invalid packet count %d", ErrFabricVPNPacketPayload, packetCount)
	}

	// Invariant from here on: offset <= len(payload). Bounds are checked
	// against the remaining length (len(payload)-offset) rather than by
	// computing offset+n, which could overflow int on 32-bit platforms and
	// let a hostile length slip through into a panicking slice expression.
	offset := fabricVPNPacketPayloadHeader
	if vpnIDLength > len(payload)-offset {
		return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: truncated vpn id", ErrFabricVPNPacketPayload)
	}
	vpnID := string(payload[offset : offset+vpnIDLength])
	offset += vpnIDLength
	if vpnID == "" {
		return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: empty vpn id", ErrFabricVPNPacketPayload)
	}

	packets := make([][]byte, 0, packetCount)
	for index := 0; index < packetCount; index++ {
		if len(payload)-offset < 4 {
			return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: truncated packet length", ErrFabricVPNPacketPayload)
		}
		// On 32-bit platforms a length >= 2^31 converts to a negative int and
		// is rejected by the <= 0 check below.
		packetLength := int(binary.BigEndian.Uint32(payload[offset : offset+4]))
		offset += 4
		if packetLength <= 0 || packetLength > len(payload)-offset {
			return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: truncated packet", ErrFabricVPNPacketPayload)
		}
		packets = append(packets, append([]byte(nil), payload[offset:offset+packetLength]...))
		offset += packetLength
	}
	if offset != len(payload) {
		return mesh.VPNPacketBatchPayload{}, fmt.Errorf("%w: trailing bytes", ErrFabricVPNPacketPayload)
	}

	sentAt := time.Unix(0, int64(binary.BigEndian.Uint64(payload[12:20]))).UTC()
	return mesh.VPNPacketBatchPayload{
		SchemaVersion:   "rap.vpn_packet_batch.fabric.v1",
		VPNConnectionID: vpnID,
		Direction:       direction,
		Packets:         packets,
		SentAt:          sentAt,
	}, nil
}

// fabricVPNPacketDirectionCode maps a direction name to its one-byte wire
// code. Unknown names return an error wrapping ErrFabricVPNPacketPayload.
func fabricVPNPacketDirectionCode(direction string) (uint8, error) {
	switch direction {
	case FabricDirectionClientToGateway:
		return fabricVPNPacketDirectionClientToGateway, nil
	case FabricDirectionGatewayToClient:
		return fabricVPNPacketDirectionGatewayToClient, nil
	default:
		return 0, fmt.Errorf("%w: unknown direction %q", ErrFabricVPNPacketPayload, direction)
	}
}

// fabricVPNPacketDirectionName maps a one-byte wire code back to its
// direction name. Unknown codes return an error wrapping
// ErrFabricVPNPacketPayload.
func fabricVPNPacketDirectionName(direction uint8) (string, error) {
	switch direction {
	case fabricVPNPacketDirectionClientToGateway:
		return FabricDirectionClientToGateway, nil
	case fabricVPNPacketDirectionGatewayToClient:
		return FabricDirectionGatewayToClient, nil
	default:
		return "", fmt.Errorf("%w: unknown direction %d", ErrFabricVPNPacketPayload, direction)
	}
}

// fabricFrameTrafficClass picks the fabricproto traffic class for a batch.
//
// A recognized class name, after normalizeFabricTrafficClass, maps to the
// matching fabricproto class. Otherwise the batch is sent Interactive when
// batchHasTCPControlPacket reports a TCP control packet, and Bulk when it
// does not.
func fabricFrameTrafficClass(trafficClass string, packets [][]byte) fabricproto.TrafficClass {
	switch normalizeFabricTrafficClass(trafficClass) {
	case FabricTrafficClassControl:
		return fabricproto.TrafficClassControl
	case FabricTrafficClassInteractive:
		return fabricproto.TrafficClassInteractive
	case FabricTrafficClassReliable:
		return fabricproto.TrafficClassReliable
	case FabricTrafficClassDroppable:
		return fabricproto.TrafficClassDroppable
	default:
		if batchHasTCPControlPacket(packets) {
			return fabricproto.TrafficClassInteractive
		}
		return fabricproto.TrafficClassBulk
	}
}