158 lines
5.1 KiB
Go
158 lines
5.1 KiB
Go
package mesh
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"time"
|
|
)
|
|
|
|
func ValidateProductionEnvelope(local PeerIdentity, envelope ProductionEnvelope, now time.Time) error {
|
|
if envelope.FabricProtocolVersion != ProtocolVersion {
|
|
return fmt.Errorf("%w: unsupported fabric_protocol_version", ErrForwardEnvelopeInvalid)
|
|
}
|
|
if envelope.MessageID == "" {
|
|
return fmt.Errorf("%w: message_id is required", ErrForwardEnvelopeInvalid)
|
|
}
|
|
if envelope.RouteID == "" {
|
|
return fmt.Errorf("%w: route_id is required", ErrForwardEnvelopeInvalid)
|
|
}
|
|
if envelope.ClusterID == "" || envelope.ClusterID != local.ClusterID {
|
|
return ErrClusterMismatch
|
|
}
|
|
if envelope.SourceNodeID == "" || envelope.DestinationNodeID == "" {
|
|
return fmt.Errorf("%w: source_node_id and destination_node_id are required", ErrForwardEnvelopeInvalid)
|
|
}
|
|
if envelope.CurrentHopNodeID != local.NodeID {
|
|
return ErrNodeMismatch
|
|
}
|
|
if envelope.NextHopNodeID == "" {
|
|
return fmt.Errorf("%w: next_hop_node_id is required", ErrForwardEnvelopeInvalid)
|
|
}
|
|
if len(envelope.RoutePath) > 0 {
|
|
if err := validateProductionRoutePath(local, envelope); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
maxPayloadBytes := MaxProductionEnvelopePayloadBytes
|
|
switch envelope.ChannelClass {
|
|
case ProductionChannelFabricControl:
|
|
if envelope.MessageType != ProductionMessageFabricControl {
|
|
return fmt.Errorf("%w: unsupported message_type", ErrForwardEnvelopeInvalid)
|
|
}
|
|
case ProductionChannelVPNPacket:
|
|
if envelope.MessageType != ProductionMessageVPNPacketBatch {
|
|
return fmt.Errorf("%w: unsupported message_type", ErrForwardEnvelopeInvalid)
|
|
}
|
|
maxPayloadBytes = MaxProductionVPNPacketPayloadBytes
|
|
default:
|
|
return ErrUnauthorizedChannel
|
|
}
|
|
if envelope.TTL <= 0 {
|
|
return ErrTTLExhausted
|
|
}
|
|
if envelope.HopCount < 0 {
|
|
return fmt.Errorf("%w: hop_count must not be negative", ErrForwardEnvelopeInvalid)
|
|
}
|
|
if envelope.CreatedAt.IsZero() || envelope.ExpiresAt.IsZero() {
|
|
return fmt.Errorf("%w: created_at and expires_at are required", ErrForwardEnvelopeInvalid)
|
|
}
|
|
if envelope.CreatedAt.After(now.UTC().Add(MaxProductionEnvelopeFutureSkew)) {
|
|
return fmt.Errorf("%w: created_at exceeds allowed future skew", ErrForwardEnvelopeInvalid)
|
|
}
|
|
if !envelope.ExpiresAt.After(now.UTC()) {
|
|
return ErrRouteExpired
|
|
}
|
|
if envelope.PayloadLength != len(envelope.Payload) {
|
|
return fmt.Errorf("%w: payload_length mismatch", ErrForwardEnvelopeInvalid)
|
|
}
|
|
if envelope.PayloadLength > maxPayloadBytes {
|
|
return fmt.Errorf("%w: payload exceeds channel limit", ErrForwardEnvelopeInvalid)
|
|
}
|
|
if envelope.PayloadHash == "" {
|
|
return fmt.Errorf("%w: payload_hash is required", ErrForwardEnvelopeInvalid)
|
|
}
|
|
sum := sha256.Sum256(envelope.Payload)
|
|
if envelope.PayloadHash != hex.EncodeToString(sum[:]) {
|
|
return fmt.Errorf("%w: payload_hash mismatch", ErrForwardEnvelopeInvalid)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func validateProductionRoutePath(local PeerIdentity, envelope ProductionEnvelope) error {
|
|
if len(envelope.RoutePath) < 2 {
|
|
return ErrInvalidRoutePath
|
|
}
|
|
if envelope.RoutePath[0] != envelope.SourceNodeID || envelope.RoutePath[len(envelope.RoutePath)-1] != envelope.DestinationNodeID {
|
|
return ErrInvalidRoutePath
|
|
}
|
|
currentIndex := -1
|
|
seen := map[string]struct{}{}
|
|
for index, nodeID := range envelope.RoutePath {
|
|
if nodeID == "" {
|
|
return ErrInvalidRoutePath
|
|
}
|
|
if _, duplicate := seen[nodeID]; duplicate {
|
|
return ErrLoopDetected
|
|
}
|
|
seen[nodeID] = struct{}{}
|
|
if nodeID == local.NodeID {
|
|
currentIndex = index
|
|
}
|
|
}
|
|
if currentIndex < 0 || envelope.CurrentHopNodeID != local.NodeID {
|
|
return ErrNodeMismatch
|
|
}
|
|
if containsProductionNodeID(envelope.VisitedNodeIDs, local.NodeID) {
|
|
return ErrLoopDetected
|
|
}
|
|
for _, visitedNodeID := range envelope.VisitedNodeIDs {
|
|
if visitedNodeID == "" || !containsProductionNodeID(envelope.RoutePath, visitedNodeID) {
|
|
return ErrInvalidRoutePath
|
|
}
|
|
}
|
|
if envelope.DestinationNodeID == local.NodeID {
|
|
if envelope.NextHopNodeID != local.NodeID {
|
|
return ErrInvalidRoutePath
|
|
}
|
|
return nil
|
|
}
|
|
if currentIndex >= len(envelope.RoutePath)-1 {
|
|
return ErrInvalidRoutePath
|
|
}
|
|
if envelope.NextHopNodeID != envelope.RoutePath[currentIndex+1] {
|
|
return ErrInvalidRoutePath
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// containsProductionNodeID reports whether needle appears in values.
// A nil or empty values slice never contains anything.
func containsProductionNodeID(values []string, needle string) bool {
	found := false
	for i := 0; i < len(values) && !found; i++ {
		found = values[i] == needle
	}
	return found
}
|
|
|
|
func NewProductionEnvelopeObservation(envelope ProductionEnvelope, observedAt time.Time) ProductionEnvelopeObservation {
|
|
return ProductionEnvelopeObservation{
|
|
MessageID: envelope.MessageID,
|
|
RouteID: envelope.RouteID,
|
|
ClusterID: envelope.ClusterID,
|
|
SourceNodeID: envelope.SourceNodeID,
|
|
DestinationNodeID: envelope.DestinationNodeID,
|
|
CurrentHopNodeID: envelope.CurrentHopNodeID,
|
|
NextHopNodeID: envelope.NextHopNodeID,
|
|
RoutePath: append([]string{}, envelope.RoutePath...),
|
|
VisitedNodeIDs: append([]string{}, envelope.VisitedNodeIDs...),
|
|
ChannelClass: envelope.ChannelClass,
|
|
MessageType: envelope.MessageType,
|
|
TTL: envelope.TTL,
|
|
HopCount: envelope.HopCount,
|
|
PayloadLength: envelope.PayloadLength,
|
|
PayloadHash: envelope.PayloadHash,
|
|
ObservedAt: observedAt.UTC(),
|
|
}
|
|
}
|