Files
rdp-proxy/agents/rap-node-agent/internal/mesh/fabric_quic_transport_test.go
T

291 lines
8.4 KiB
Go

package mesh
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/hex"
"encoding/pem"
"math/big"
"strings"
"testing"
"time"
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/fabricproto"
"github.com/quic-go/quic-go"
)
// TestQUICFabricTransportPingPong sends a ping over a QUIC fabric session whose
// client QUIC config enables datagrams, and expects a pong that echoes the same
// sequence number and payload back.
func TestQUICFabricTransportPingPong(t *testing.T) {
	echo := startQUICFabricEchoServer(t)
	defer echo.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	target := FabricTransportTarget{
		Endpoint: echo.Addr().String(),
		// The echo server presents a throwaway self-signed certificate.
		TLSConfig: &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{fabricQUICNextProto},
		},
		Timeout:       time.Second,
		InboundBuffer: 4,
		ErrorBuffer:   4,
	}
	session, err := NewQUICFabricTransport(&quic.Config{EnableDatagrams: true}).Connect(ctx, target)
	if err != nil {
		t.Fatalf("connect quic fabric: %v", err)
	}
	defer session.Close()

	ping := fabricproto.Frame{Type: fabricproto.FramePing, Sequence: 42, Payload: []byte("quic")}
	if err := session.Send(ctx, ping); err != nil {
		t.Fatalf("send ping: %v", err)
	}

	select {
	case <-ctx.Done():
		t.Fatal(ctx.Err())
	case err := <-session.Errors():
		t.Fatalf("session error: %v", err)
	case got := <-session.Frames():
		if got.Type != fabricproto.FramePong || got.Sequence != 42 || string(got.Payload) != "quic" {
			t.Fatalf("frame = %+v", got)
		}
	}
}
// TestQUICFabricTransportDataAck opens stream 9 and sends one interactive data
// frame on it. It then expects the peer to answer with an ack carrying the
// same stream ID and sequence number.
func TestQUICFabricTransportDataAck(t *testing.T) {
	echo := startQUICFabricEchoServer(t)
	defer echo.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	transport := NewQUICFabricTransport(nil)
	session, err := transport.Connect(ctx, FabricTransportTarget{
		Endpoint: echo.Addr().String(),
		TLSConfig: &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{fabricQUICNextProto},
		},
		Timeout:       time.Second,
		InboundBuffer: 4,
		ErrorBuffer:   4,
	})
	if err != nil {
		t.Fatalf("connect quic fabric: %v", err)
	}
	defer session.Close()

	const (
		streamID = 9
		sequence = 7
	)
	// Frames are sent in order. Each carries the label used in its failure message.
	outbound := []struct {
		what  string
		frame fabricproto.Frame
	}{
		{"open stream", fabricproto.Frame{
			Type:         fabricproto.FrameOpenStream,
			StreamID:     streamID,
			TrafficClass: fabricproto.TrafficClassInteractive,
		}},
		{"send data", fabricproto.Frame{
			Type:         fabricproto.FrameData,
			StreamID:     streamID,
			Sequence:     sequence,
			TrafficClass: fabricproto.TrafficClassInteractive,
			Payload:      []byte("packet"),
		}},
	}
	for _, step := range outbound {
		if err := session.Send(ctx, step.frame); err != nil {
			t.Fatalf("%s: %v", step.what, err)
		}
	}

	select {
	case <-ctx.Done():
		t.Fatal(ctx.Err())
	case err := <-session.Errors():
		t.Fatalf("session error: %v", err)
	case ack := <-session.Frames():
		if ack.Type != fabricproto.FrameAck || ack.StreamID != streamID || ack.Sequence != sequence {
			t.Fatalf("frame = %+v", ack)
		}
	}
}
// TestQUICFabricTransportVerifiesPinnedCertificate connects without setting a
// TLSConfig on the target. The only trust input is a PeerCertSHA256 pin
// computed from the server's own certificate. A ping/pong round trip over the
// resulting session must succeed.
func TestQUICFabricTransportVerifiesPinnedCertificate(t *testing.T) {
	serverTLS := testQUICTLSConfig(t)
	echo := startQUICFabricEchoServerWithTLS(t, serverTLS)
	defer echo.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	session, err := NewQUICFabricTransport(nil).Connect(ctx, FabricTransportTarget{
		Endpoint:       echo.Addr().String(),
		PeerCertSHA256: testQUICCertSHA256(t, serverTLS),
		Timeout:        time.Second,
		InboundBuffer:  4,
		ErrorBuffer:    4,
	})
	if err != nil {
		t.Fatalf("connect quic fabric with pinned certificate: %v", err)
	}
	defer session.Close()

	const (
		seq  = 43
		body = "pin"
	)
	if err := session.Send(ctx, fabricproto.Frame{Type: fabricproto.FramePing, Sequence: seq, Payload: []byte(body)}); err != nil {
		t.Fatalf("send ping: %v", err)
	}

	select {
	case <-ctx.Done():
		t.Fatal(ctx.Err())
	case err := <-session.Errors():
		t.Fatalf("session error: %v", err)
	case pong := <-session.Frames():
		if pong.Type != fabricproto.FramePong || pong.Sequence != seq || string(pong.Payload) != body {
			t.Fatalf("frame = %+v", pong)
		}
	}
}
// TestQUICFabricTransportRejectsPinnedCertificateMismatch checks that Connect
// fails when PeerCertSHA256 does not match the server certificate. The pin is
// 64 hex zeros, which is effectively guaranteed not to equal the SHA-256
// fingerprint of the freshly generated test certificate.
func TestQUICFabricTransportRejectsPinnedCertificateMismatch(t *testing.T) {
	listener := startQUICFabricEchoServer(t)
	defer listener.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	session, err := NewQUICFabricTransport(nil).Connect(ctx, FabricTransportTarget{
		Endpoint:       listener.Addr().String(),
		PeerCertSHA256: strings.Repeat("0", 64),
		Timeout:        time.Second,
	})
	if err == nil {
		// Release the unexpectedly established session before failing, so the
		// connection does not outlive the test.
		session.Close()
		t.Fatal("connect succeeded with mismatched certificate pin")
	}
	t.Logf("connect rejected as expected: %v", err)
}
// TestQUICFabricServerHandlesFabricFrames starts the QUIC fabric server, pings
// it through the client transport, and checks two things: the pong reply, and
// the session events the server reports through its Logger callback. The first
// event must be "fabric_session_quic_stream_opened", and at least two must arrive.
func TestQUICFabricServerHandlesFabricFrames(t *testing.T) {
	// The Logger callback may be invoked from server goroutines while this test
	// reads the results. Appending to a shared slice would therefore be a data
	// race, so entries are handed over through a channel instead. The send never
	// blocks the server: once the buffer is full, further entries are dropped.
	// That is harmless here because only the first two are inspected.
	eventCh := make(chan FabricSessionEventLogEntry, 64)
	server, err := StartQUICFabricServer(context.Background(), QUICFabricServerConfig{
		ListenAddr: "127.0.0.1:0",
		TLSConfig:  testQUICTLSConfig(t),
		Logger: func(entry FabricSessionEventLogEntry) {
			select {
			case eventCh <- entry:
			default:
			}
		},
	})
	if err != nil {
		t.Fatalf("start quic fabric server: %v", err)
	}
	defer server.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	session, err := NewQUICFabricTransport(nil).Connect(ctx, FabricTransportTarget{
		Endpoint: server.Addr().String(),
		TLSConfig: &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{fabricQUICNextProto},
		},
		Timeout:       time.Second,
		InboundBuffer: 4,
		ErrorBuffer:   4,
	})
	if err != nil {
		t.Fatalf("connect quic fabric: %v", err)
	}
	defer session.Close()

	if err := session.Send(ctx, fabricproto.Frame{Type: fabricproto.FramePing, Sequence: 77, Payload: []byte("server")}); err != nil {
		t.Fatalf("send ping: %v", err)
	}
	select {
	case frame := <-session.Frames():
		if frame.Type != fabricproto.FramePong || frame.Sequence != 77 || string(frame.Payload) != "server" {
			t.Fatalf("frame = %+v", frame)
		}
	case err := <-session.Errors():
		t.Fatalf("session error: %v", err)
	case <-ctx.Done():
		t.Fatal(ctx.Err())
	}

	// Log entries may be recorded after the pong reaches the client, so wait
	// (bounded by ctx) until at least two have arrived before inspecting them.
	var events []FabricSessionEventLogEntry
	for len(events) < 2 {
		select {
		case entry := <-eventCh:
			events = append(events, entry)
		case <-ctx.Done():
			t.Fatalf("events = %+v (%v)", events, ctx.Err())
		}
	}
	if events[0].Event != "fabric_session_quic_stream_opened" {
		t.Fatalf("events = %+v", events)
	}
}
// startQUICFabricEchoServer is a convenience wrapper around
// startQUICFabricEchoServerWithTLS. It uses a freshly generated self-signed
// TLS config from testQUICTLSConfig. The caller must close the returned listener.
func startQUICFabricEchoServer(t *testing.T) *quic.Listener {
	t.Helper()
	serverTLS := testQUICTLSConfig(t)
	return startQUICFabricEchoServerWithTLS(t, serverTLS)
}
// startQUICFabricEchoServerWithTLS listens on an ephemeral 127.0.0.1 port using
// tlsConfig, with QUIC datagrams enabled. It serves a single connection in the
// background:
//   - It accepts the first stream on that connection.
//   - It feeds every frame read from the stream through a fresh fabricproto
//     session and writes the session's responses back on the same stream.
//
// When serving stops, the connection is closed with an application error code
// that says why:
//   - 0: reading a frame failed
//   - 1: no stream was accepted
//   - 2: the session rejected a frame
//   - 3: writing a response failed
//
// The caller owns the returned listener and must close it.
// NOTE(review): closing the listener is presumed to unblock a goroutine still
// parked in Accept (quic-go behaviour); confirm against the pinned quic-go version.
func startQUICFabricEchoServerWithTLS(t *testing.T, tlsConfig *tls.Config) *quic.Listener {
	t.Helper()
	ln, err := quic.ListenAddr("127.0.0.1:0", tlsConfig, &quic.Config{EnableDatagrams: true})
	if err != nil {
		t.Fatalf("listen quic: %v", err)
	}

	go func() {
		bg := context.Background()
		conn, err := ln.Accept(bg)
		if err != nil {
			return
		}

		// serve echoes frames until something fails. It reports the close code
		// and reason to send to the peer.
		serve := func() (quic.ApplicationErrorCode, string) {
			stream, err := conn.AcceptStream(bg)
			if err != nil {
				return 1, "accept stream failed"
			}
			fabric := fabricproto.NewSession(fabricproto.SessionConfig{})
			for {
				in, err := fabricproto.ReadFrame(stream, fabricproto.DefaultMaxPayload)
				if err != nil {
					return 0, "closed"
				}
				_, replies, err := fabric.HandleFrame(in)
				if err != nil {
					return 2, err.Error()
				}
				for _, reply := range replies {
					if err := fabricproto.WriteFrame(stream, reply); err != nil {
						return 3, err.Error()
					}
				}
			}
		}

		code, reason := serve()
		_ = conn.CloseWithError(code, reason)
	}()

	return ln
}
// testQUICCertSHA256 returns the lowercase hex SHA-256 digest of the first DER
// certificate in the chain of tlsConfig's first certificate. That is the form
// the tests pass as a PeerCertSHA256 pin. The test fails if tlsConfig carries
// no certificate.
func testQUICCertSHA256(t *testing.T, tlsConfig *tls.Config) string {
	t.Helper()
	if len(tlsConfig.Certificates) > 0 {
		if chain := tlsConfig.Certificates[0].Certificate; len(chain) > 0 {
			digest := sha256.Sum256(chain[0])
			return hex.EncodeToString(digest[:])
		}
	}
	t.Fatal("test tls config has no certificate")
	return "" // unreachable: t.Fatal stops the test goroutine
}
// testQUICTLSConfig builds a server-side TLS config around a freshly generated,
// self-signed RSA-2048 certificate. The certificate has CN "rap-fabric-test"
// and the DNS name "localhost", and is valid from one minute before creation
// until one hour after. The config advertises fabricQUICNextProto as its only
// ALPN protocol.
func testQUICTLSConfig(t *testing.T) *tls.Config {
	t.Helper()
	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("generate key: %v", err)
	}

	tmpl := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject:      pkix.Name{CommonName: "rap-fabric-test"},
		NotBefore:    time.Now().Add(-time.Minute),
		NotAfter:     time.Now().Add(time.Hour),
		KeyUsage:     x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		DNSNames:     []string{"localhost"},
	}
	// Self-signed: the template serves as both subject and issuer.
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privateKey.PublicKey, privateKey)
	if err != nil {
		t.Fatalf("create certificate: %v", err)
	}

	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})
	pair, err := tls.X509KeyPair(certPEM, keyPEM)
	if err != nil {
		t.Fatalf("key pair: %v", err)
	}

	return &tls.Config{
		Certificates: []tls.Certificate{pair},
		NextProtos:   []string{fabricQUICNextProto},
	}
}