Files

200 lines
7.1 KiB
Go

package main
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"flag"
"fmt"
"os"
"strings"
"time"
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/fabricproto"
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/mesh"
)
// smokeOutput is the JSON document this tool prints to stdout, both for usage
// errors and for the outcome of a smoke request.
type smokeOutput struct {
	// OK is true only when the request returned no error and the forward
	// result was marked accepted.
	OK bool `json:"ok"`
	// Endpoint is the entry-node endpoint taken from the -endpoint flag.
	Endpoint string `json:"endpoint"`
	// EntryNodeID is the envelope's current-hop node ID.
	EntryNodeID string `json:"entry_node_id"`
	// NextHopID is the envelope's next-hop node ID.
	NextHopID string `json:"next_hop_node_id"`
	// RouteID is the envelope's route ID.
	RouteID string `json:"route_id"`
	// ElapsedMS is the wall-clock duration of the request in milliseconds.
	ElapsedMS int64 `json:"elapsed_ms"`
	// Result is the forward result returned for the request; it is the zero
	// value when the request failed.
	Result mesh.ProductionForwardResult `json:"result"`
	// Error holds the failure message, if any.
	Error string `json:"error,omitempty"`
	// EnvelopePath is the parsed -route-path that was sent in the envelope.
	EnvelopePath []string `json:"envelope_path,omitempty"`
}
// productionForwardResponse is the JSON body decoded from the reply frame to a
// production-forward request.
type productionForwardResponse struct {
	// Result is the forward result reported in the reply. Note that omitempty
	// has no effect on struct-typed fields in encoding/json; the tag only
	// matters if this type is ever marshaled.
	Result mesh.ProductionForwardResult `json:"result,omitempty"`
	// Error is a failure message reported in the reply; a non-blank value is
	// treated as a failed request.
	Error string `json:"error,omitempty"`
}
// main sends a single production-forward smoke request to a fabric entry node
// and prints the outcome to stdout as a JSON smokeOutput document.
//
// Exit status:
//   - 0: the request returned no error and the result was accepted.
//   - 1: the request failed, or it completed but the result was not accepted.
//   - 2: a required flag is missing/blank or a payload flag is malformed.
func main() {
	var (
		endpoint      = flag.String("endpoint", "", "QUIC fabric endpoint for the entry node, for example quic://host:19131.")
		peerCert      = flag.String("peer-cert-sha256", "", "Expected entry node QUIC TLS certificate SHA-256 fingerprint.")
		clusterID     = flag.String("cluster-id", "", "Cluster ID.")
		routeID       = flag.String("route-id", "", "Configured production route ID.")
		sourceNodeID  = flag.String("source-node-id", "", "Route source node ID.")
		destNodeID    = flag.String("destination-node-id", "", "Route destination node ID.")
		currentNodeID = flag.String("current-hop-node-id", "", "Current hop node ID expected by the entry node.")
		nextHopNodeID = flag.String("next-hop-node-id", "", "Next hop node ID from the entry node.")
		routePath     = flag.String("route-path", "", "Comma-separated route path.")
		channel       = flag.String("channel", mesh.ProductionChannelFabricControl, "Production channel class.")
		timeout       = flag.Duration("timeout", 10*time.Second, "Smoke request timeout.")
		payloadText   = flag.String("payload", `{"kind":"fabric-production-smoke"}`, "JSON payload string.")
		payloadB64    = flag.String("payload-b64", "", "Base64-encoded JSON payload string.")
	)
	flag.Parse()

	// usageError reports an invalid invocation as JSON and exits with status 2.
	usageError := func(message string) {
		writeOutput(smokeOutput{OK: false, Error: message})
		os.Exit(2)
	}

	// Trim once and validate the trimmed values: the envelope is built from the
	// trimmed form, so a whitespace-only flag must count as missing.
	var (
		entryEndpoint = strings.TrimSpace(*endpoint)
		cluster       = strings.TrimSpace(*clusterID)
		route         = strings.TrimSpace(*routeID)
		sourceNode    = strings.TrimSpace(*sourceNodeID)
		destNode      = strings.TrimSpace(*destNodeID)
		currentHop    = strings.TrimSpace(*currentNodeID)
		nextHop       = strings.TrimSpace(*nextHopNodeID)
		channelClass  = strings.TrimSpace(*channel)
	)
	if entryEndpoint == "" || cluster == "" || route == "" || sourceNode == "" || destNode == "" || currentHop == "" || nextHop == "" {
		usageError("endpoint, cluster-id, route-id, source-node-id, destination-node-id, current-hop-node-id and next-hop-node-id are required")
	}

	path := splitRoutePath(*routePath)

	// -payload-b64, when set, takes precedence over -payload.
	payloadSource := strings.TrimSpace(*payloadText)
	if encoded := strings.TrimSpace(*payloadB64); encoded != "" {
		decoded, err := base64.StdEncoding.DecodeString(encoded)
		if err != nil {
			usageError("payload-b64 must be valid base64")
		}
		payloadSource = string(decoded)
	}
	payload := json.RawMessage(strings.TrimSpace(payloadSource))
	if !json.Valid(payload) {
		usageError("payload must be valid JSON")
	}
	// encoding/json compacts and HTML-escapes a json.RawMessage when it is
	// embedded in a marshaled parent value. Canonicalize the payload the same
	// way up front so PayloadLength and PayloadHash describe the bytes that are
	// actually carried in the envelope, even for pretty-printed input.
	// NOTE(review): assumes the receiving node checks length/hash against the
	// payload bytes as received — confirm against the mesh package.
	canonical, err := json.Marshal(payload)
	if err != nil {
		usageError("payload must be valid JSON")
	}
	payload = canonical

	now := time.Now().UTC()
	messageType := mesh.ProductionMessageFabricControl
	if channelClass == mesh.ProductionChannelVPNPacket {
		messageType = mesh.ProductionMessageVPNPacketBatch
	}
	sum := sha256.Sum256(payload)
	envelope := mesh.ProductionEnvelope{
		FabricProtocolVersion: mesh.ProtocolVersion,
		MessageID:             fmt.Sprintf("fabric-production-smoke-%d", now.UnixNano()),
		RouteID:               route,
		ClusterID:             cluster,
		SourceNodeID:          sourceNode,
		DestinationNodeID:     destNode,
		CurrentHopNodeID:      currentHop,
		NextHopNodeID:         nextHop,
		RoutePath:             path,
		ChannelClass:          channelClass,
		MessageType:           messageType,
		TTL:                   8,
		HopCount:              0,
		CreatedAt:             now,
		ExpiresAt:             now.Add(time.Minute),
		PayloadLength:         len(payload),
		PayloadHash:           hex.EncodeToString(sum[:]),
		Payload:               payload,
	}
	target := mesh.FabricTransportTarget{
		EndpointID:     "fabric-production-smoke-entry",
		PeerID:         envelope.CurrentHopNodeID,
		Endpoint:       entryEndpoint,
		Transport:      "quic",
		PeerCertSHA256: strings.TrimSpace(*peerCert),
		Timeout:        *timeout,
		InboundBuffer:  8,
		ErrorBuffer:    4,
	}

	transport := mesh.NewQUICFabricTransport(nil)
	ctx, cancel := context.WithTimeout(context.Background(), *timeout)
	started := time.Now()
	result, err := sendProductionEnvelope(ctx, transport, target, envelope)
	elapsed := time.Since(started)
	// Release the context explicitly rather than with defer: os.Exit below
	// would skip deferred calls.
	cancel()

	output := smokeOutput{
		OK:           err == nil && result.Accepted,
		Endpoint:     entryEndpoint,
		EntryNodeID:  envelope.CurrentHopNodeID,
		NextHopID:    envelope.NextHopNodeID,
		RouteID:      envelope.RouteID,
		ElapsedMS:    elapsed.Milliseconds(),
		Result:       result,
		EnvelopePath: path,
	}
	if err != nil {
		output.Error = err.Error()
	}
	writeOutput(output)
	// A result that was delivered but not accepted is still a failed smoke
	// test, so the exit status must agree with the printed "ok" field.
	if !output.OK {
		os.Exit(1)
	}
}
// sendProductionEnvelope connects to target, sends envelope as a single
// reliable data frame on the production-forward stream, and waits for the
// matching reply frame.
//
// It returns the forward result carried in the reply. An error is returned
// when connecting, encoding, or sending fails, when ctx is done before a reply
// arrives, when the session reports an error or closes its frame channel
// first, when the reply cannot be decoded, or when the reply carries a
// non-blank error message. The session is closed before returning.
func sendProductionEnvelope(ctx context.Context, transport *mesh.QUICFabricTransport, target mesh.FabricTransportTarget, envelope mesh.ProductionEnvelope) (mesh.ProductionForwardResult, error) {
	session, err := transport.Connect(ctx, target)
	if err != nil {
		return mesh.ProductionForwardResult{}, fmt.Errorf("connect to entry node: %w", err)
	}
	defer session.Close()

	payload, err := json.Marshal(envelope)
	if err != nil {
		return mesh.ProductionForwardResult{}, fmt.Errorf("encode production envelope: %w", err)
	}
	if err := session.Send(ctx, fabricproto.Frame{
		Type:         fabricproto.FrameData,
		TrafficClass: fabricproto.TrafficClassReliable,
		StreamID:     mesh.ProductionForwardQUICStreamID,
		Sequence:     1,
		Payload:      payload,
	}); err != nil {
		return mesh.ProductionForwardResult{}, fmt.Errorf("send production envelope: %w", err)
	}

	// Hold the channels in locals so a closed error channel can be disabled by
	// setting it to nil. A receive from a closed channel returns immediately,
	// so without the ok checks below this loop would spin until ctx expires.
	// NOTE(review): assumes Frames/Errors return the same channel on every
	// call — confirm against the session implementation.
	frames := session.Frames()
	sessionErrs := session.Errors()
	for {
		select {
		case <-ctx.Done():
			return mesh.ProductionForwardResult{}, fmt.Errorf("wait for production forward response: %w", ctx.Err())
		case err, ok := <-sessionErrs:
			if !ok {
				// No further errors will be reported; keep waiting on frames.
				sessionErrs = nil
				continue
			}
			if err != nil {
				return mesh.ProductionForwardResult{}, fmt.Errorf("fabric session: %w", err)
			}
		case frame, ok := <-frames:
			if !ok {
				// Prefer a pending session error, if one is already queued,
				// over the generic "closed" message.
				select {
				case err, ok := <-sessionErrs:
					if ok && err != nil {
						return mesh.ProductionForwardResult{}, fmt.Errorf("fabric session: %w", err)
					}
				default:
				}
				return mesh.ProductionForwardResult{}, errors.New("fabric session closed before a production forward response arrived")
			}
			// Skip anything that is not the reply to the frame sent above.
			if frame.Type != fabricproto.FrameData || frame.StreamID != mesh.ProductionForwardQUICStreamID || frame.Sequence != 1 {
				continue
			}
			var response productionForwardResponse
			if err := json.Unmarshal(frame.Payload, &response); err != nil {
				return mesh.ProductionForwardResult{}, fmt.Errorf("decode production forward response: %w", err)
			}
			if strings.TrimSpace(response.Error) != "" {
				return mesh.ProductionForwardResult{}, errors.New(response.Error)
			}
			return response.Result, nil
		}
	}
}
// splitRoutePath parses a comma-separated route path into its hop IDs.
//
// Surrounding whitespace is stripped from the whole value and from every
// segment, and blank segments are dropped. A blank value yields nil; a value
// made up only of separators yields an empty, non-nil slice.
func splitRoutePath(value string) []string {
	trimmed := strings.TrimSpace(value)
	if len(trimmed) == 0 {
		return nil
	}

	segments := strings.Split(trimmed, ",")
	hops := make([]string, 0, len(segments))
	for i := range segments {
		hop := strings.TrimSpace(segments[i])
		if hop == "" {
			continue
		}
		hops = append(hops, hop)
	}
	return hops
}
// writeOutput prints output to stdout as indented JSON followed by a newline.
// If the value cannot be marshaled, nothing is written to stdout and the
// failure is reported on stderr instead.
func writeOutput(output smokeOutput) {
	encoded, marshalErr := json.MarshalIndent(output, "", " ")
	if marshalErr == nil {
		fmt.Fprintln(os.Stdout, string(encoded))
		return
	}
	fmt.Fprintf(os.Stderr, "marshal smoke output: %v\n", marshalErr)
}