292 lines
9.9 KiB
Go
292 lines
9.9 KiB
Go
package mesh
|
|
|
|
import (
	"context"
	"encoding/json"
	"errors"
	"net/http"
	"time"
)
|
|
|
|
// ProductionEnvelopeObserver is a callback that is handed a
// ProductionEnvelopeObservation together with the request context. A non-nil
// return value reports that the observation failed.
type ProductionEnvelopeObserver func(context.Context, ProductionEnvelopeObservation) error
|
|
// ProductionForwardLogger is a sink for ProductionForwardLogEntry values. It
// returns nothing, so it has no way to report a failure back to the caller.
type ProductionForwardLogger func(ProductionForwardLogEntry)
|
|
|
|
type Server struct {
|
|
Local PeerIdentity
|
|
SyntheticRuntime *SyntheticRuntime
|
|
ProductionForwardingEnabled bool
|
|
ProductionEnvelopeObserver ProductionEnvelopeObserver
|
|
ProductionForwardTransport ProductionForwardTransport
|
|
ProductionForwardLogger ProductionForwardLogger
|
|
ProductionRoutes []SyntheticRoute
|
|
}
|
|
|
|
func (s Server) Handler() http.Handler {
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/mesh/v1/health", s.handleHealth)
|
|
mux.HandleFunc("/mesh/v1/forward", s.handleForward)
|
|
mux.HandleFunc("/mesh/v1/synthetic/probe", s.handleSyntheticProbe)
|
|
return mux
|
|
}
|
|
|
|
func (s Server) handleHealth(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
var message HealthMessage
|
|
if err := json.NewDecoder(r.Body).Decode(&message); err != nil {
|
|
http.Error(w, "invalid health message", http.StatusBadRequest)
|
|
return
|
|
}
|
|
if message.ProtocolVersion != ProtocolVersion {
|
|
http.Error(w, "unsupported mesh protocol version", http.StatusBadRequest)
|
|
return
|
|
}
|
|
if err := ValidatePeer(s.Local, message.From); err != nil {
|
|
http.Error(w, err.Error(), http.StatusForbidden)
|
|
return
|
|
}
|
|
if message.To.NodeID != "" && message.To.NodeID != s.Local.NodeID {
|
|
http.Error(w, ErrNodeMismatch.Error(), http.StatusForbidden)
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_ = json.NewEncoder(w).Encode(HealthAck{
|
|
ProtocolVersion: ProtocolVersion,
|
|
Accepted: true,
|
|
By: s.Local,
|
|
})
|
|
}
|
|
|
|
func (s Server) handleForward(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
if !s.ProductionForwardingEnabled {
|
|
s.logProductionForward(ProductionForwardLogEntry{
|
|
Event: "production_forward_rejected",
|
|
ClusterID: s.Local.ClusterID,
|
|
LocalNodeID: s.Local.NodeID,
|
|
Reason: ErrForwardDisabled.Error(),
|
|
StatusCode: http.StatusNotImplemented,
|
|
OccurredAt: time.Now().UTC(),
|
|
})
|
|
http.Error(w, ErrForwardDisabled.Error(), http.StatusNotImplemented)
|
|
return
|
|
}
|
|
var envelope ProductionEnvelope
|
|
if err := json.NewDecoder(r.Body).Decode(&envelope); err != nil {
|
|
s.logProductionForward(ProductionForwardLogEntry{
|
|
Event: "production_forward_rejected",
|
|
ClusterID: s.Local.ClusterID,
|
|
LocalNodeID: s.Local.NodeID,
|
|
Reason: "invalid production mesh envelope",
|
|
StatusCode: http.StatusBadRequest,
|
|
OccurredAt: time.Now().UTC(),
|
|
})
|
|
http.Error(w, "invalid production mesh envelope", http.StatusBadRequest)
|
|
return
|
|
}
|
|
if err := ValidateProductionEnvelope(s.Local, envelope, time.Now().UTC()); err != nil {
|
|
s.rejectProductionForward(w, envelope, err, forwardStatusCode(err))
|
|
return
|
|
}
|
|
if err := ValidateProductionEnvelopeRouteConfig(s.Local, envelope, s.ProductionRoutes, time.Now().UTC()); err != nil {
|
|
s.rejectProductionForward(w, envelope, err, forwardStatusCode(err))
|
|
return
|
|
}
|
|
s.logProductionForward(productionForwardLogEntry("production_forward_accepted", s.Local, envelope, "", 0))
|
|
if s.ProductionEnvelopeObserver != nil {
|
|
observation := NewProductionEnvelopeObservation(envelope, time.Now().UTC())
|
|
if err := observeProductionEnvelope(r.Context(), s.ProductionEnvelopeObserver, observation); err != nil {
|
|
s.logProductionForward(productionForwardLogEntry("production_forward_rejected", s.Local, envelope, ErrForwardObservationFailed.Error(), http.StatusInternalServerError))
|
|
http.Error(w, ErrForwardObservationFailed.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
if envelope.DestinationNodeID == s.Local.NodeID {
|
|
s.logProductionForward(productionForwardLogEntry("production_forward_delivered", s.Local, envelope, "", http.StatusOK))
|
|
writeProductionForwardResult(w, ProductionForwardResult{
|
|
Accepted: true,
|
|
Delivered: true,
|
|
By: s.Local,
|
|
MessageID: envelope.MessageID,
|
|
RouteID: envelope.RouteID,
|
|
})
|
|
return
|
|
}
|
|
if envelope.NextHopNodeID == s.Local.NodeID {
|
|
s.rejectProductionForward(w, envelope, ErrLoopDetected, forwardStatusCode(ErrLoopDetected))
|
|
return
|
|
}
|
|
if len(envelope.RoutePath) == 0 && envelope.NextHopNodeID != envelope.DestinationNodeID {
|
|
s.rejectProductionForward(w, envelope, ErrForwardRuntimeUnavailable, http.StatusNotImplemented)
|
|
return
|
|
}
|
|
if s.ProductionForwardTransport == nil {
|
|
s.rejectProductionForward(w, envelope, ErrForwardRuntimeUnavailable, http.StatusNotImplemented)
|
|
return
|
|
}
|
|
if envelope.TTL <= 1 {
|
|
s.rejectProductionForward(w, envelope, ErrTTLExhausted, forwardStatusCode(ErrTTLExhausted))
|
|
return
|
|
}
|
|
forwarded := envelope
|
|
forwarded.CurrentHopNodeID = envelope.NextHopNodeID
|
|
forwarded.NextHopNodeID = nextProductionHopAfter(envelope.RoutePath, envelope.NextHopNodeID, envelope.DestinationNodeID)
|
|
forwarded.TTL = envelope.TTL - 1
|
|
forwarded.HopCount = envelope.HopCount + 1
|
|
forwarded.VisitedNodeIDs = append(append([]string{}, envelope.VisitedNodeIDs...), s.Local.NodeID)
|
|
result, err := s.ProductionForwardTransport.SendProduction(r.Context(), envelope.NextHopNodeID, forwarded)
|
|
if err != nil {
|
|
s.rejectProductionForward(w, envelope, err, forwardStatusCode(err))
|
|
return
|
|
}
|
|
s.logProductionForward(productionForwardLogEntry("production_forward_forwarded", s.Local, envelope, "", http.StatusOK))
|
|
result.Accepted = true
|
|
result.Forwarded = true
|
|
result.By = s.Local
|
|
result.MessageID = envelope.MessageID
|
|
result.RouteID = envelope.RouteID
|
|
result.NextNodeID = envelope.NextHopNodeID
|
|
writeProductionForwardResult(w, result)
|
|
}
|
|
|
|
func (s Server) rejectProductionForward(w http.ResponseWriter, envelope ProductionEnvelope, err error, statusCode int) {
|
|
s.logProductionForward(productionForwardLogEntry("production_forward_rejected", s.Local, envelope, err.Error(), statusCode))
|
|
http.Error(w, err.Error(), statusCode)
|
|
}
|
|
|
|
func (s Server) logProductionForward(entry ProductionForwardLogEntry) {
|
|
if s.ProductionForwardLogger == nil {
|
|
return
|
|
}
|
|
if entry.OccurredAt.IsZero() {
|
|
entry.OccurredAt = time.Now().UTC()
|
|
}
|
|
s.ProductionForwardLogger(entry)
|
|
}
|
|
|
|
func productionForwardLogEntry(event string, local PeerIdentity, envelope ProductionEnvelope, reason string, statusCode int) ProductionForwardLogEntry {
|
|
return ProductionForwardLogEntry{
|
|
Event: event,
|
|
RouteID: envelope.RouteID,
|
|
MessageID: envelope.MessageID,
|
|
ClusterID: envelope.ClusterID,
|
|
LocalNodeID: local.NodeID,
|
|
SourceNodeID: envelope.SourceNodeID,
|
|
DestinationNodeID: envelope.DestinationNodeID,
|
|
CurrentHopNodeID: envelope.CurrentHopNodeID,
|
|
NextHopNodeID: envelope.NextHopNodeID,
|
|
ChannelClass: envelope.ChannelClass,
|
|
MessageType: envelope.MessageType,
|
|
Reason: reason,
|
|
StatusCode: statusCode,
|
|
TTL: envelope.TTL,
|
|
HopCount: envelope.HopCount,
|
|
RoutePathLength: len(envelope.RoutePath),
|
|
VisitedCount: len(envelope.VisitedNodeIDs),
|
|
PayloadLength: envelope.PayloadLength,
|
|
OccurredAt: time.Now().UTC(),
|
|
}
|
|
}
|
|
|
|
// nextProductionHopAfter returns the node that follows currentNodeID in
// routePath, using the first occurrence of currentNodeID.
//
// If currentNodeID is the last element of the path it is returned unchanged.
// If the path is empty or does not contain currentNodeID, destinationNodeID
// is returned.
func nextProductionHopAfter(routePath []string, currentNodeID string, destinationNodeID string) string {
	last := len(routePath) - 1
	for i := 0; i <= last; i++ {
		if routePath[i] != currentNodeID {
			continue
		}
		if i == last {
			// Nothing follows the current node on the path.
			return currentNodeID
		}
		return routePath[i+1]
	}
	return destinationNodeID
}
|
|
|
|
func writeProductionForwardResult(w http.ResponseWriter, result ProductionForwardResult) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_ = json.NewEncoder(w).Encode(result)
|
|
}
|
|
|
|
func observeProductionEnvelope(ctx context.Context, observer ProductionEnvelopeObserver, observation ProductionEnvelopeObservation) (err error) {
|
|
if observer == nil {
|
|
return nil
|
|
}
|
|
defer func() {
|
|
if recover() != nil {
|
|
err = ErrForwardObservationFailed
|
|
}
|
|
}()
|
|
return observer(ctx, observation)
|
|
}
|
|
|
|
func (s Server) handleSyntheticProbe(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
if s.SyntheticRuntime == nil {
|
|
http.Error(w, ErrMeshRuntimeDisabled.Error(), http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
var envelope SyntheticEnvelope
|
|
if err := json.NewDecoder(r.Body).Decode(&envelope); err != nil {
|
|
http.Error(w, "invalid synthetic mesh envelope", http.StatusBadRequest)
|
|
return
|
|
}
|
|
ack, err := s.SyntheticRuntime.Receive(r.Context(), envelope)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), syntheticStatusCode(err))
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_ = json.NewEncoder(w).Encode(ack)
|
|
}
|
|
|
|
func NewHealthMessage(from, to PeerIdentity) HealthMessage {
|
|
status := "reachable"
|
|
return HealthMessage{
|
|
ProtocolVersion: ProtocolVersion,
|
|
From: from,
|
|
To: to,
|
|
ObservedAt: time.Now().UTC(),
|
|
LinkStatus: status,
|
|
}
|
|
}
|
|
|
|
func syntheticStatusCode(err error) int {
|
|
switch err {
|
|
case ErrClusterMismatch, ErrNodeMismatch, ErrUnauthorizedChannel, ErrLoopDetected:
|
|
return http.StatusForbidden
|
|
case ErrMeshRuntimeDisabled:
|
|
return http.StatusServiceUnavailable
|
|
case ErrRouteExpired, ErrTTLExhausted, ErrInvalidRoutePath, ErrUnsupportedSyntheticMessage, ErrRouteIDRequired:
|
|
return http.StatusBadRequest
|
|
case ErrRouteNotFound, ErrSyntheticPeerUnavailable:
|
|
return http.StatusNotFound
|
|
default:
|
|
return http.StatusBadRequest
|
|
}
|
|
}
|
|
|
|
func forwardStatusCode(err error) int {
|
|
switch err {
|
|
case ErrClusterMismatch, ErrNodeMismatch, ErrUnauthorizedChannel, ErrLoopDetected:
|
|
return http.StatusForbidden
|
|
case ErrRouteExpired, ErrTTLExhausted, ErrInvalidRoutePath, ErrRouteIDRequired:
|
|
return http.StatusBadRequest
|
|
case ErrForwardRuntimeUnavailable:
|
|
return http.StatusNotImplemented
|
|
case ErrRouteNotFound:
|
|
return http.StatusNotFound
|
|
case ErrForwardPeerUnavailable:
|
|
return http.StatusBadGateway
|
|
default:
|
|
return http.StatusBadRequest
|
|
}
|
|
}
|