Pin QUIC fabric endpoint certificates
This commit is contained in:
@@ -2,7 +2,10 @@ package mesh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -34,6 +37,34 @@ func NewQUICFabricTransport(config *quic.Config) *QUICFabricTransport {
|
||||
return &QUICFabricTransport{Config: config}
|
||||
}
|
||||
|
||||
func quicTLSConfigForTarget(target FabricTransportTarget) *tls.Config {
|
||||
expectedFingerprint := normalizeCertSHA256(target.PeerCertSHA256)
|
||||
config := &tls.Config{NextProtos: []string{fabricQUICNextProto}}
|
||||
if expectedFingerprint == "" {
|
||||
return config
|
||||
}
|
||||
config.InsecureSkipVerify = true
|
||||
config.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
|
||||
if len(rawCerts) == 0 {
|
||||
return fmt.Errorf("quic peer certificate missing")
|
||||
}
|
||||
sum := sha256.Sum256(rawCerts[0])
|
||||
actual := hex.EncodeToString(sum[:])
|
||||
if actual != expectedFingerprint {
|
||||
return fmt.Errorf("quic peer certificate fingerprint mismatch")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
func normalizeCertSHA256(value string) string {
|
||||
value = strings.ToLower(strings.TrimSpace(value))
|
||||
value = strings.ReplaceAll(value, "sha256:", "")
|
||||
value = strings.ReplaceAll(value, ":", "")
|
||||
return value
|
||||
}
|
||||
|
||||
func (t *QUICFabricTransport) Connect(ctx context.Context, target FabricTransportTarget) (FabricTransportSession, error) {
|
||||
if target.Endpoint == "" {
|
||||
return nil, fmt.Errorf("quic fabric endpoint is required")
|
||||
@@ -41,7 +72,7 @@ func (t *QUICFabricTransport) Connect(ctx context.Context, target FabricTranspor
|
||||
target.Endpoint = strings.TrimPrefix(strings.TrimSpace(target.Endpoint), "quic://")
|
||||
tlsConfig := target.TLSConfig
|
||||
if tlsConfig == nil {
|
||||
tlsConfig = &tls.Config{NextProtos: []string{fabricQUICNextProto}}
|
||||
tlsConfig = quicTLSConfigForTarget(target)
|
||||
} else {
|
||||
tlsConfig = tlsConfig.Clone()
|
||||
if len(tlsConfig.NextProtos) == 0 {
|
||||
|
||||
@@ -4,11 +4,14 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -102,6 +105,56 @@ func TestQUICFabricTransportDataAck(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestQUICFabricTransportVerifiesPinnedCertificate(t *testing.T) {
|
||||
tlsConfig := testQUICTLSConfig(t)
|
||||
listener := startQUICFabricEchoServerWithTLS(t, tlsConfig)
|
||||
defer listener.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
session, err := NewQUICFabricTransport(nil).Connect(ctx, FabricTransportTarget{
|
||||
Endpoint: listener.Addr().String(),
|
||||
PeerCertSHA256: testQUICCertSHA256(t, tlsConfig),
|
||||
Timeout: time.Second,
|
||||
InboundBuffer: 4,
|
||||
ErrorBuffer: 4,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("connect quic fabric with pinned certificate: %v", err)
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
if err := session.Send(ctx, fabricproto.Frame{Type: fabricproto.FramePing, Sequence: 43, Payload: []byte("pin")}); err != nil {
|
||||
t.Fatalf("send ping: %v", err)
|
||||
}
|
||||
select {
|
||||
case frame := <-session.Frames():
|
||||
if frame.Type != fabricproto.FramePong || frame.Sequence != 43 || string(frame.Payload) != "pin" {
|
||||
t.Fatalf("frame = %+v", frame)
|
||||
}
|
||||
case err := <-session.Errors():
|
||||
t.Fatalf("session error: %v", err)
|
||||
case <-ctx.Done():
|
||||
t.Fatal(ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestQUICFabricTransportRejectsPinnedCertificateMismatch(t *testing.T) {
|
||||
listener := startQUICFabricEchoServer(t)
|
||||
defer listener.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
_, err := NewQUICFabricTransport(nil).Connect(ctx, FabricTransportTarget{
|
||||
Endpoint: listener.Addr().String(),
|
||||
PeerCertSHA256: strings.Repeat("0", 64),
|
||||
Timeout: time.Second,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("connect succeeded with mismatched certificate pin")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQUICFabricServerHandlesFabricFrames(t *testing.T) {
|
||||
var events []FabricSessionEventLogEntry
|
||||
server, err := StartQUICFabricServer(context.Background(), QUICFabricServerConfig{
|
||||
@@ -152,7 +205,12 @@ func TestQUICFabricServerHandlesFabricFrames(t *testing.T) {
|
||||
|
||||
func startQUICFabricEchoServer(t *testing.T) *quic.Listener {
|
||||
t.Helper()
|
||||
listener, err := quic.ListenAddr("127.0.0.1:0", testQUICTLSConfig(t), &quic.Config{EnableDatagrams: true})
|
||||
return startQUICFabricEchoServerWithTLS(t, testQUICTLSConfig(t))
|
||||
}
|
||||
|
||||
func startQUICFabricEchoServerWithTLS(t *testing.T, tlsConfig *tls.Config) *quic.Listener {
|
||||
t.Helper()
|
||||
listener, err := quic.ListenAddr("127.0.0.1:0", tlsConfig, &quic.Config{EnableDatagrams: true})
|
||||
if err != nil {
|
||||
t.Fatalf("listen quic: %v", err)
|
||||
}
|
||||
@@ -189,6 +247,15 @@ func startQUICFabricEchoServer(t *testing.T) *quic.Listener {
|
||||
return listener
|
||||
}
|
||||
|
||||
func testQUICCertSHA256(t *testing.T, tlsConfig *tls.Config) string {
|
||||
t.Helper()
|
||||
if len(tlsConfig.Certificates) == 0 || len(tlsConfig.Certificates[0].Certificate) == 0 {
|
||||
t.Fatal("test tls config has no certificate")
|
||||
}
|
||||
sum := sha256.Sum256(tlsConfig.Certificates[0].Certificate[0])
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func testQUICTLSConfig(t *testing.T) *tls.Config {
|
||||
t.Helper()
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
|
||||
@@ -31,6 +31,7 @@ type FabricTransportTarget struct {
|
||||
Token string
|
||||
Header http.Header
|
||||
TLSConfig *tls.Config
|
||||
PeerCertSHA256 string
|
||||
Timeout time.Duration
|
||||
MaxPayload int
|
||||
OutboundBuffer int
|
||||
|
||||
Reference in New Issue
Block a user