148 lines
4.4 KiB
Go
148 lines
4.4 KiB
Go
package mesh
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// VPNPacketBatchPayload is the JSON document carried as the payload of a VPN
// packet batch envelope.
//
// encoding/json emits struct fields in declaration order, and
// NewProductionVPNPacketBatchEnvelope hashes the marshalled bytes, so the
// field order here is part of the serialized form.
type VPNPacketBatchPayload struct {
	// SchemaVersion names the payload format; this file writes and accepts
	// only "rap.vpn_packet_batch.v1".
	SchemaVersion string `json:"schema_version"`
	// VPNConnectionID identifies the VPN connection the packets belong to.
	// DecodeProductionVPNPacketBatch rejects payloads where it is empty.
	VPNConnectionID string `json:"vpn_connection_id"`
	// Direction is passed through as supplied; it is not validated in this
	// file.
	Direction string `json:"direction"`
	// Packets holds the raw packets. encoding/json serializes each []byte
	// element as a base64 string.
	Packets [][]byte `json:"packets"`
	// SentAt is the UTC time at which the batch was wrapped.
	SentAt time.Time `json:"sent_at"`
}
|
|
|
|
// ProductionVPNPacketEnvelopeInput collects the arguments of
// NewProductionVPNPacketBatchEnvelope. Routing fields are copied onto the
// resulting envelope as given; the zero values of MessageID, TTL, ExpiresAt
// and Now select defaults.
type ProductionVPNPacketEnvelopeInput struct {
	// MessageID is the envelope message ID. When empty,
	// "vpnpkt-<unix nanoseconds of Now>" is used.
	MessageID string

	// Routing fields, copied into the envelope unchanged. RoutePath is
	// copied into a fresh slice.
	RouteID           string
	ClusterID         string
	SourceNodeID      string
	DestinationNodeID string
	CurrentHopNodeID  string
	NextHopNodeID     string
	RoutePath         []string

	// TTL values that are zero or negative are replaced by 8. HopCount is
	// copied as given.
	TTL      int
	HopCount int

	// ExpiresAt, when zero, becomes 15 seconds after Now.
	ExpiresAt time.Time

	// Payload contents. Empty entries in Packets are dropped before
	// encoding.
	VPNConnectionID string
	Direction       string
	Packets         [][]byte

	// Now is the creation time to stamp on the envelope and payload. When
	// zero, the current time is used.
	Now time.Time
}
|
|
|
|
func NewProductionVPNPacketBatchEnvelope(input ProductionVPNPacketEnvelopeInput) (ProductionEnvelope, error) {
|
|
now := input.Now.UTC()
|
|
if now.IsZero() {
|
|
now = time.Now().UTC()
|
|
}
|
|
packets := cleanProductionVPNPacketBatch(input.Packets)
|
|
if len(packets) == 0 {
|
|
return ProductionEnvelope{}, fmt.Errorf("%w: empty vpn packet batch", ErrForwardEnvelopeInvalid)
|
|
}
|
|
if input.MessageID == "" {
|
|
input.MessageID = fmt.Sprintf("vpnpkt-%d", now.UnixNano())
|
|
}
|
|
if input.TTL <= 0 {
|
|
input.TTL = 8
|
|
}
|
|
if input.ExpiresAt.IsZero() {
|
|
input.ExpiresAt = now.Add(15 * time.Second)
|
|
}
|
|
payload, err := json.Marshal(VPNPacketBatchPayload{
|
|
SchemaVersion: "rap.vpn_packet_batch.v1",
|
|
VPNConnectionID: input.VPNConnectionID,
|
|
Direction: input.Direction,
|
|
Packets: packets,
|
|
SentAt: now,
|
|
})
|
|
if err != nil {
|
|
return ProductionEnvelope{}, err
|
|
}
|
|
if len(payload) > MaxProductionVPNPacketPayloadBytes {
|
|
return ProductionEnvelope{}, fmt.Errorf("%w: vpn packet batch exceeds channel limit", ErrForwardEnvelopeInvalid)
|
|
}
|
|
sum := sha256.Sum256(payload)
|
|
return ProductionEnvelope{
|
|
FabricProtocolVersion: ProtocolVersion,
|
|
MessageID: input.MessageID,
|
|
RouteID: input.RouteID,
|
|
ClusterID: input.ClusterID,
|
|
SourceNodeID: input.SourceNodeID,
|
|
DestinationNodeID: input.DestinationNodeID,
|
|
CurrentHopNodeID: input.CurrentHopNodeID,
|
|
NextHopNodeID: input.NextHopNodeID,
|
|
RoutePath: append([]string{}, input.RoutePath...),
|
|
ChannelClass: ProductionChannelVPNPacket,
|
|
MessageType: ProductionMessageVPNPacketBatch,
|
|
TTL: input.TTL,
|
|
HopCount: input.HopCount,
|
|
CreatedAt: now,
|
|
ExpiresAt: input.ExpiresAt.UTC(),
|
|
PayloadLength: len(payload),
|
|
PayloadHash: hex.EncodeToString(sum[:]),
|
|
Payload: payload,
|
|
}, nil
|
|
}
|
|
|
|
func DecodeProductionVPNPacketBatch(envelope ProductionEnvelope) (VPNPacketBatchPayload, error) {
|
|
if envelope.ChannelClass != ProductionChannelVPNPacket || envelope.MessageType != ProductionMessageVPNPacketBatch {
|
|
return VPNPacketBatchPayload{}, ErrUnauthorizedChannel
|
|
}
|
|
var payload VPNPacketBatchPayload
|
|
if err := json.Unmarshal(envelope.Payload, &payload); err != nil {
|
|
return VPNPacketBatchPayload{}, err
|
|
}
|
|
if payload.SchemaVersion != "rap.vpn_packet_batch.v1" || payload.VPNConnectionID == "" {
|
|
return VPNPacketBatchPayload{}, fmt.Errorf("%w: invalid vpn packet batch payload", ErrForwardEnvelopeInvalid)
|
|
}
|
|
payload.Packets = cleanProductionVPNPacketBatch(payload.Packets)
|
|
if len(payload.Packets) == 0 {
|
|
return VPNPacketBatchPayload{}, fmt.Errorf("%w: empty vpn packet batch payload", ErrForwardEnvelopeInvalid)
|
|
}
|
|
return payload, nil
|
|
}
|
|
|
|
// cleanProductionVPNPacketBatch returns a deep copy of packets with every
// zero-length packet removed.
//
// An empty or nil input yields nil. A non-empty input made up only of empty
// packets yields a non-nil slice of length zero. The returned packets do not
// share memory with the input.
func cleanProductionVPNPacketBatch(packets [][]byte) [][]byte {
	total := len(packets)
	if total == 0 {
		return nil
	}

	kept := make([][]byte, 0, total)
	for i := 0; i < total; i++ {
		if raw := packets[i]; len(raw) > 0 {
			dup := append([]byte(nil), raw...)
			kept = append(kept, dup)
		}
	}
	return kept
}
|
|
|
|
func inferVPNPacketTrafficClass(explicit string, packets [][]byte) string {
|
|
explicit = strings.TrimSpace(strings.ToLower(explicit))
|
|
if explicit != "" && explicit != "bulk" {
|
|
return explicit
|
|
}
|
|
for _, packet := range packets {
|
|
if isVPNPacketTCPControl(packet) {
|
|
return "interactive"
|
|
}
|
|
}
|
|
return explicit
|
|
}
|
|
|
|
// isVPNPacketTCPControl reports whether packet is an IPv4 packet whose TCP
// header has at least one of the FIN, SYN, RST or ACK flags set.
//
// It returns false when the packet is not IPv4, is not TCP (protocol 6), is
// too short to hold the IP header plus a minimal 20-byte TCP header, or is a
// non-initial IPv4 fragment.
//
// NOTE(review): the mask includes ACK, so ordinary data segments on an
// established connection also match. Confirm that is intended.
func isVPNPacketTCPControl(packet []byte) bool {
	const (
		minIPv4HeaderLen = 20
		minTCPHeaderLen  = 20
		protocolTCP      = 6
		tcpFlagsOffset   = 13 // offset of the flags byte within the TCP header
		// FIN (0x01) | SYN (0x02) | RST (0x04) | ACK (0x10).
		controlFlagMask = 0x17
	)

	if len(packet) < minIPv4HeaderLen || packet[0]>>4 != 4 {
		return false
	}
	// IHL is the low nibble of byte 0, counted in 32-bit words.
	ihl := int(packet[0]&0x0f) * 4
	if ihl < minIPv4HeaderLen || len(packet) < ihl+minTCPHeaderLen || packet[9] != protocolTCP {
		return false
	}
	// The fragment offset is the low 13 bits of bytes 6-7. Only the first
	// fragment (offset 0) carries the TCP header; in later fragments the
	// bytes after the IP header are payload and must not be read as flags.
	if packet[6]&0x1f != 0 || packet[7] != 0 {
		return false
	}
	return packet[ihl+tcpFlagsOffset]&controlFlagMask != 0
}
|