Refactor RDP proxy handling and update related tests
This commit is contained in:
@@ -0,0 +1,199 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/fabricproto"
|
||||
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/mesh"
|
||||
)
|
||||
|
||||
type smokeOutput struct {
|
||||
OK bool `json:"ok"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
EntryNodeID string `json:"entry_node_id"`
|
||||
NextHopID string `json:"next_hop_node_id"`
|
||||
RouteID string `json:"route_id"`
|
||||
ElapsedMS int64 `json:"elapsed_ms"`
|
||||
Result mesh.ProductionForwardResult `json:"result"`
|
||||
Error string `json:"error,omitempty"`
|
||||
EnvelopePath []string `json:"envelope_path,omitempty"`
|
||||
}
|
||||
|
||||
type productionForwardResponse struct {
|
||||
Result mesh.ProductionForwardResult `json:"result,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
var (
|
||||
endpoint = flag.String("endpoint", "", "QUIC fabric endpoint for the entry node, for example quic://host:19131.")
|
||||
peerCert = flag.String("peer-cert-sha256", "", "Expected entry node QUIC TLS certificate SHA-256 fingerprint.")
|
||||
clusterID = flag.String("cluster-id", "", "Cluster ID.")
|
||||
routeID = flag.String("route-id", "", "Configured production route ID.")
|
||||
sourceNodeID = flag.String("source-node-id", "", "Route source node ID.")
|
||||
destNodeID = flag.String("destination-node-id", "", "Route destination node ID.")
|
||||
currentNodeID = flag.String("current-hop-node-id", "", "Current hop node ID expected by the entry node.")
|
||||
nextHopNodeID = flag.String("next-hop-node-id", "", "Next hop node ID from the entry node.")
|
||||
routePath = flag.String("route-path", "", "Comma-separated route path.")
|
||||
channel = flag.String("channel", mesh.ProductionChannelFabricControl, "Production channel class.")
|
||||
timeout = flag.Duration("timeout", 10*time.Second, "Smoke request timeout.")
|
||||
payloadText = flag.String("payload", `{"kind":"fabric-production-smoke"}`, "JSON payload string.")
|
||||
payloadB64 = flag.String("payload-b64", "", "Base64-encoded JSON payload string.")
|
||||
)
|
||||
flag.Parse()
|
||||
|
||||
if *endpoint == "" || *clusterID == "" || *routeID == "" || *sourceNodeID == "" || *destNodeID == "" || *currentNodeID == "" || *nextHopNodeID == "" {
|
||||
writeOutput(smokeOutput{OK: false, Error: "endpoint, cluster-id, route-id, source-node-id, destination-node-id, current-hop-node-id and next-hop-node-id are required"})
|
||||
os.Exit(2)
|
||||
}
|
||||
path := splitRoutePath(*routePath)
|
||||
payloadSource := strings.TrimSpace(*payloadText)
|
||||
if strings.TrimSpace(*payloadB64) != "" {
|
||||
decoded, err := base64.StdEncoding.DecodeString(strings.TrimSpace(*payloadB64))
|
||||
if err != nil {
|
||||
writeOutput(smokeOutput{OK: false, Error: "payload-b64 must be valid base64"})
|
||||
os.Exit(2)
|
||||
}
|
||||
payloadSource = string(decoded)
|
||||
}
|
||||
payload := json.RawMessage(strings.TrimSpace(payloadSource))
|
||||
if !json.Valid(payload) {
|
||||
writeOutput(smokeOutput{OK: false, Error: "payload must be valid JSON"})
|
||||
os.Exit(2)
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
messageType := mesh.ProductionMessageFabricControl
|
||||
if strings.TrimSpace(*channel) == mesh.ProductionChannelVPNPacket {
|
||||
messageType = mesh.ProductionMessageVPNPacketBatch
|
||||
}
|
||||
sum := sha256.Sum256(payload)
|
||||
envelope := mesh.ProductionEnvelope{
|
||||
FabricProtocolVersion: mesh.ProtocolVersion,
|
||||
MessageID: fmt.Sprintf("fabric-production-smoke-%d", now.UnixNano()),
|
||||
RouteID: strings.TrimSpace(*routeID),
|
||||
ClusterID: strings.TrimSpace(*clusterID),
|
||||
SourceNodeID: strings.TrimSpace(*sourceNodeID),
|
||||
DestinationNodeID: strings.TrimSpace(*destNodeID),
|
||||
CurrentHopNodeID: strings.TrimSpace(*currentNodeID),
|
||||
NextHopNodeID: strings.TrimSpace(*nextHopNodeID),
|
||||
RoutePath: path,
|
||||
ChannelClass: strings.TrimSpace(*channel),
|
||||
MessageType: messageType,
|
||||
TTL: 8,
|
||||
HopCount: 0,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(time.Minute),
|
||||
PayloadLength: len(payload),
|
||||
PayloadHash: hex.EncodeToString(sum[:]),
|
||||
Payload: payload,
|
||||
}
|
||||
|
||||
transport := mesh.NewQUICFabricTransport(nil)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), *timeout)
|
||||
defer cancel()
|
||||
started := time.Now()
|
||||
result, err := sendProductionEnvelope(ctx, transport, mesh.FabricTransportTarget{
|
||||
EndpointID: "fabric-production-smoke-entry",
|
||||
PeerID: envelope.CurrentHopNodeID,
|
||||
Endpoint: strings.TrimSpace(*endpoint),
|
||||
Transport: "quic",
|
||||
PeerCertSHA256: strings.TrimSpace(*peerCert),
|
||||
Timeout: *timeout,
|
||||
InboundBuffer: 8,
|
||||
ErrorBuffer: 4,
|
||||
}, envelope)
|
||||
output := smokeOutput{
|
||||
OK: err == nil && result.Accepted,
|
||||
Endpoint: *endpoint,
|
||||
EntryNodeID: envelope.CurrentHopNodeID,
|
||||
NextHopID: envelope.NextHopNodeID,
|
||||
RouteID: envelope.RouteID,
|
||||
ElapsedMS: time.Since(started).Milliseconds(),
|
||||
Result: result,
|
||||
EnvelopePath: path,
|
||||
}
|
||||
if err != nil {
|
||||
output.Error = err.Error()
|
||||
writeOutput(output)
|
||||
os.Exit(1)
|
||||
}
|
||||
writeOutput(output)
|
||||
}
|
||||
|
||||
func sendProductionEnvelope(ctx context.Context, transport *mesh.QUICFabricTransport, target mesh.FabricTransportTarget, envelope mesh.ProductionEnvelope) (mesh.ProductionForwardResult, error) {
|
||||
session, err := transport.Connect(ctx, target)
|
||||
if err != nil {
|
||||
return mesh.ProductionForwardResult{}, err
|
||||
}
|
||||
defer session.Close()
|
||||
payload, err := json.Marshal(envelope)
|
||||
if err != nil {
|
||||
return mesh.ProductionForwardResult{}, err
|
||||
}
|
||||
if err := session.Send(ctx, fabricproto.Frame{
|
||||
Type: fabricproto.FrameData,
|
||||
TrafficClass: fabricproto.TrafficClassReliable,
|
||||
StreamID: mesh.ProductionForwardQUICStreamID,
|
||||
Sequence: 1,
|
||||
Payload: payload,
|
||||
}); err != nil {
|
||||
return mesh.ProductionForwardResult{}, err
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return mesh.ProductionForwardResult{}, ctx.Err()
|
||||
case err := <-session.Errors():
|
||||
if err != nil {
|
||||
return mesh.ProductionForwardResult{}, err
|
||||
}
|
||||
case frame := <-session.Frames():
|
||||
if frame.Type != fabricproto.FrameData || frame.StreamID != mesh.ProductionForwardQUICStreamID || frame.Sequence != 1 {
|
||||
continue
|
||||
}
|
||||
var response productionForwardResponse
|
||||
if err := json.Unmarshal(frame.Payload, &response); err != nil {
|
||||
return mesh.ProductionForwardResult{}, err
|
||||
}
|
||||
if strings.TrimSpace(response.Error) != "" {
|
||||
return mesh.ProductionForwardResult{}, errors.New(response.Error)
|
||||
}
|
||||
return response.Result, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func splitRoutePath(value string) []string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(value, ",")
|
||||
out := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part != "" {
|
||||
out = append(out, part)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func writeOutput(output smokeOutput) {
|
||||
payload, err := json.MarshalIndent(output, "", " ")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "marshal smoke output: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Println(string(payload))
|
||||
}
|
||||
Reference in New Issue
Block a user