400 lines
12 KiB
Go
400 lines
12 KiB
Go
package mesh
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/fabricproto"
|
|
"github.com/quic-go/quic-go"
|
|
)
|
|
|
|
type QUICFabricServer struct {
|
|
listener *quic.Listener
|
|
logger FabricSessionEventLogger
|
|
reverseMu sync.RWMutex
|
|
reverseTransport *QUICFabricTransport
|
|
fabricFrameHandler FabricFrameHandler
|
|
productionForwardHandler func(context.Context, ProductionEnvelope) (ProductionForwardResult, error)
|
|
webIngressForwardHandler func(context.Context, []byte) ([]byte, error)
|
|
fabricControlHandler func(context.Context, []byte) ([]byte, error)
|
|
syntheticForwardHandler func(context.Context, SyntheticEnvelope) (SyntheticEnvelope, error)
|
|
done chan struct{}
|
|
closeOnce sync.Once
|
|
}
|
|
|
|
type QUICFabricServerConfig struct {
|
|
ListenAddr string
|
|
TLSConfig *tls.Config
|
|
QUICConfig *quic.Config
|
|
Logger FabricSessionEventLogger
|
|
ReverseTransport *QUICFabricTransport
|
|
FabricFrameHandler FabricFrameHandler
|
|
ProductionForwardHandler func(context.Context, ProductionEnvelope) (ProductionForwardResult, error)
|
|
WebIngressForwardHandler func(context.Context, []byte) ([]byte, error)
|
|
FabricControlHandler func(context.Context, []byte) ([]byte, error)
|
|
SyntheticForwardHandler func(context.Context, SyntheticEnvelope) (SyntheticEnvelope, error)
|
|
}
|
|
|
|
type FabricFrameSender interface {
|
|
SendFrame(context.Context, fabricproto.Frame) error
|
|
}
|
|
|
|
// FabricFrameHandler inspects a frame read from a fabric stream before the
// built-in session sees it. Returning handled=true consumes the frame;
// returning a non-nil error closes the whole QUIC connection.
type FabricFrameHandler func(ctx context.Context, sender FabricFrameSender, frame fabricproto.Frame) (handled bool, err error)
|
|
|
|
func StartQUICFabricServer(ctx context.Context, cfg QUICFabricServerConfig) (*QUICFabricServer, error) {
|
|
if cfg.ListenAddr == "" {
|
|
return nil, fmt.Errorf("quic fabric listen addr is required")
|
|
}
|
|
tlsConfig := cfg.TLSConfig
|
|
if tlsConfig == nil {
|
|
return nil, fmt.Errorf("quic fabric tls config is required")
|
|
}
|
|
tlsConfig = tlsConfig.Clone()
|
|
if len(tlsConfig.NextProtos) == 0 {
|
|
tlsConfig.NextProtos = []string{fabricQUICNextProto}
|
|
}
|
|
listener, err := quic.ListenAddr(cfg.ListenAddr, tlsConfig, defaultQUICFabricConfig(cfg.QUICConfig))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
server := &QUICFabricServer{
|
|
listener: listener,
|
|
logger: cfg.Logger,
|
|
reverseTransport: cfg.ReverseTransport,
|
|
fabricFrameHandler: cfg.FabricFrameHandler,
|
|
productionForwardHandler: cfg.ProductionForwardHandler,
|
|
webIngressForwardHandler: cfg.WebIngressForwardHandler,
|
|
fabricControlHandler: cfg.FabricControlHandler,
|
|
syntheticForwardHandler: cfg.SyntheticForwardHandler,
|
|
done: make(chan struct{}),
|
|
}
|
|
go server.acceptLoop(ctx)
|
|
return server, nil
|
|
}
|
|
|
|
func (s *QUICFabricServer) Addr() net.Addr {
|
|
if s == nil || s.listener == nil {
|
|
return nil
|
|
}
|
|
return s.listener.Addr()
|
|
}
|
|
|
|
func (s *QUICFabricServer) SetReverseTransport(transport *QUICFabricTransport) {
|
|
if s == nil {
|
|
return
|
|
}
|
|
s.reverseMu.Lock()
|
|
s.reverseTransport = transport
|
|
s.reverseMu.Unlock()
|
|
}
|
|
|
|
func (s *QUICFabricServer) Close() error {
|
|
if s == nil {
|
|
return nil
|
|
}
|
|
var err error
|
|
s.closeOnce.Do(func() {
|
|
close(s.done)
|
|
if s.listener != nil {
|
|
err = s.listener.Close()
|
|
}
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (s *QUICFabricServer) acceptLoop(ctx context.Context) {
|
|
defer s.Close()
|
|
for {
|
|
conn, err := s.listener.Accept(ctx)
|
|
if err != nil {
|
|
return
|
|
}
|
|
go s.handleConn(ctx, conn)
|
|
}
|
|
}
|
|
|
|
func (s *QUICFabricServer) handleConn(ctx context.Context, conn *quic.Conn) {
|
|
for {
|
|
stream, err := conn.AcceptStream(ctx)
|
|
if err != nil {
|
|
_ = conn.CloseWithError(0, "accept stream stopped")
|
|
return
|
|
}
|
|
go s.handleStream(ctx, conn, stream)
|
|
}
|
|
}
|
|
|
|
func (s *QUICFabricServer) handleStream(ctx context.Context, conn *quic.Conn, stream *quic.Stream) {
|
|
session := fabricproto.NewSession(fabricproto.SessionConfig{})
|
|
sender := &quicStreamFrameSender{stream: stream}
|
|
defer func() { _ = stream.Close() }()
|
|
s.logFabricSession(FabricSessionEventLogEntry{
|
|
Event: "fabric_session_quic_stream_opened",
|
|
AcceptedBy: "quic",
|
|
RemoteAddr: conn.RemoteAddr().String(),
|
|
})
|
|
defer s.logFabricSession(FabricSessionEventLogEntry{
|
|
Event: "fabric_session_quic_stream_closed",
|
|
AcceptedBy: "quic",
|
|
RemoteAddr: conn.RemoteAddr().String(),
|
|
})
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
_ = stream.Close()
|
|
return
|
|
default:
|
|
}
|
|
frame, err := fabricproto.ReadFrame(stream, fabricproto.DefaultMaxPayload)
|
|
if err != nil {
|
|
return
|
|
}
|
|
s.registerReverseHelloFrame(conn, frame)
|
|
if s.handleProductionForwardFrame(ctx, stream, frame) {
|
|
continue
|
|
}
|
|
if s.handleWebIngressForwardFrame(ctx, stream, frame) {
|
|
continue
|
|
}
|
|
if s.handleFabricControlForwardFrame(ctx, stream, frame) {
|
|
continue
|
|
}
|
|
if s.handleSyntheticForwardFrame(ctx, conn, stream, frame) {
|
|
continue
|
|
}
|
|
if s.fabricFrameHandler != nil {
|
|
handled, err := s.fabricFrameHandler(ctx, sender, frame)
|
|
if err != nil {
|
|
_ = conn.CloseWithError(2, err.Error())
|
|
return
|
|
}
|
|
if handled {
|
|
continue
|
|
}
|
|
}
|
|
event, responses, err := session.HandleFrame(frame)
|
|
if err != nil {
|
|
_ = conn.CloseWithError(2, err.Error())
|
|
return
|
|
}
|
|
if event.Type != fabricproto.SessionEventNone {
|
|
s.logFabricSession(FabricSessionEventLogEntry{
|
|
Event: "fabric_session_event",
|
|
SessionEvent: event.Type,
|
|
StreamID: event.StreamID,
|
|
Sequence: event.Sequence,
|
|
TrafficClass: event.TrafficClass,
|
|
AcceptedBy: "quic",
|
|
RemoteAddr: conn.RemoteAddr().String(),
|
|
})
|
|
}
|
|
for _, response := range responses {
|
|
if err := fabricproto.WriteFrame(stream, response); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
type quicStreamFrameSender struct {
|
|
stream *quic.Stream
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func (s *quicStreamFrameSender) SendFrame(ctx context.Context, frame fabricproto.Frame) error {
|
|
if s.stream == nil {
|
|
return fmt.Errorf("quic fabric stream is closed")
|
|
}
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if deadline, ok := ctx.Deadline(); ok {
|
|
_ = s.stream.SetWriteDeadline(deadline)
|
|
} else {
|
|
_ = s.stream.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
|
}
|
|
return fabricproto.WriteFrame(s.stream, frame)
|
|
}
|
|
|
|
func (s *QUICFabricServer) registerReverseHelloFrame(conn *quic.Conn, frame fabricproto.Frame) {
|
|
reverseTransport := s.getReverseTransport()
|
|
if s == nil || reverseTransport == nil || conn == nil || frame.Type != fabricproto.FramePing {
|
|
return
|
|
}
|
|
payload := string(frame.Payload)
|
|
if !strings.HasPrefix(payload, fabricQUICReverseHelloPrefix) {
|
|
return
|
|
}
|
|
peerID := strings.TrimPrefix(payload, fabricQUICReverseHelloPrefix)
|
|
reverseTransport.RegisterReverseConn(peerID, conn)
|
|
s.logFabricSession(FabricSessionEventLogEntry{
|
|
Event: "fabric_session_quic_reverse_registered",
|
|
AcceptedBy: "quic_reverse_hello",
|
|
RemoteAddr: conn.RemoteAddr().String(),
|
|
PeerID: peerID,
|
|
})
|
|
}
|
|
|
|
type quicProductionForwardResponse struct {
|
|
Result ProductionForwardResult `json:"result,omitempty"`
|
|
Error string `json:"error,omitempty"`
|
|
}
|
|
|
|
type quicSyntheticForwardResponse struct {
|
|
Envelope SyntheticEnvelope `json:"envelope,omitempty"`
|
|
Error string `json:"error,omitempty"`
|
|
}
|
|
|
|
type quicWebIngressForwardResponse struct {
|
|
Payload json.RawMessage `json:"payload,omitempty"`
|
|
Error string `json:"error,omitempty"`
|
|
}
|
|
|
|
type quicFabricControlForwardResponse struct {
|
|
Payload json.RawMessage `json:"payload,omitempty"`
|
|
Error string `json:"error,omitempty"`
|
|
}
|
|
|
|
func (s *QUICFabricServer) handleProductionForwardFrame(ctx context.Context, stream *quic.Stream, frame fabricproto.Frame) bool {
|
|
if frame.Type != fabricproto.FrameData || frame.StreamID != ProductionForwardQUICStreamID {
|
|
return false
|
|
}
|
|
response := quicProductionForwardResponse{}
|
|
if s == nil || s.productionForwardHandler == nil {
|
|
response.Error = ErrForwardRuntimeUnavailable.Error()
|
|
} else {
|
|
var envelope ProductionEnvelope
|
|
if err := json.Unmarshal(frame.Payload, &envelope); err != nil {
|
|
response.Error = "invalid production mesh envelope"
|
|
} else if result, err := s.productionForwardHandler(ctx, envelope); err != nil {
|
|
response.Error = err.Error()
|
|
} else {
|
|
response.Result = result
|
|
}
|
|
}
|
|
payload, err := json.Marshal(response)
|
|
if err != nil {
|
|
return true
|
|
}
|
|
_ = fabricproto.WriteFrame(stream, fabricproto.Frame{
|
|
Type: fabricproto.FrameData,
|
|
TrafficClass: fabricproto.TrafficClassReliable,
|
|
StreamID: ProductionForwardQUICStreamID,
|
|
Sequence: frame.Sequence,
|
|
Payload: payload,
|
|
})
|
|
return true
|
|
}
|
|
|
|
func (s *QUICFabricServer) handleWebIngressForwardFrame(ctx context.Context, stream *quic.Stream, frame fabricproto.Frame) bool {
|
|
if frame.Type != fabricproto.FrameData || frame.StreamID != WebIngressForwardQUICStreamID {
|
|
return false
|
|
}
|
|
response := quicWebIngressForwardResponse{}
|
|
if s == nil || s.webIngressForwardHandler == nil {
|
|
response.Error = ErrForwardRuntimeUnavailable.Error()
|
|
} else if payload, err := s.webIngressForwardHandler(ctx, append([]byte(nil), frame.Payload...)); err != nil {
|
|
response.Error = err.Error()
|
|
} else {
|
|
response.Payload = append(json.RawMessage(nil), payload...)
|
|
}
|
|
payload, err := json.Marshal(response)
|
|
if err != nil {
|
|
return true
|
|
}
|
|
_ = fabricproto.WriteFrame(stream, fabricproto.Frame{
|
|
Type: fabricproto.FrameData,
|
|
TrafficClass: fabricproto.TrafficClassReliable,
|
|
StreamID: WebIngressForwardQUICStreamID,
|
|
Sequence: frame.Sequence,
|
|
Payload: payload,
|
|
})
|
|
return true
|
|
}
|
|
|
|
func (s *QUICFabricServer) handleFabricControlForwardFrame(ctx context.Context, stream *quic.Stream, frame fabricproto.Frame) bool {
|
|
if frame.Type != fabricproto.FrameData || frame.StreamID != FabricControlForwardQUICStreamID {
|
|
return false
|
|
}
|
|
response := quicFabricControlForwardResponse{}
|
|
if s == nil || s.fabricControlHandler == nil {
|
|
response.Error = ErrForwardRuntimeUnavailable.Error()
|
|
} else if payload, err := s.fabricControlHandler(ctx, append([]byte(nil), frame.Payload...)); err != nil {
|
|
response.Error = err.Error()
|
|
} else {
|
|
response.Payload = append(json.RawMessage(nil), payload...)
|
|
}
|
|
payload, err := json.Marshal(response)
|
|
if err != nil {
|
|
return true
|
|
}
|
|
_ = fabricproto.WriteFrame(stream, fabricproto.Frame{
|
|
Type: fabricproto.FrameData,
|
|
TrafficClass: fabricproto.TrafficClassReliable,
|
|
StreamID: FabricControlForwardQUICStreamID,
|
|
Sequence: frame.Sequence,
|
|
Payload: payload,
|
|
})
|
|
return true
|
|
}
|
|
|
|
func (s *QUICFabricServer) handleSyntheticForwardFrame(ctx context.Context, conn *quic.Conn, stream *quic.Stream, frame fabricproto.Frame) bool {
|
|
if frame.Type != fabricproto.FrameData || frame.StreamID != SyntheticForwardQUICStreamID {
|
|
return false
|
|
}
|
|
response := quicSyntheticForwardResponse{}
|
|
if s == nil || s.syntheticForwardHandler == nil {
|
|
response.Error = ErrMeshRuntimeDisabled.Error()
|
|
} else {
|
|
var envelope SyntheticEnvelope
|
|
if err := json.Unmarshal(frame.Payload, &envelope); err != nil {
|
|
response.Error = "invalid synthetic mesh envelope"
|
|
} else if ack, err := s.syntheticForwardHandler(ctx, envelope); err != nil {
|
|
response.Error = err.Error()
|
|
} else {
|
|
s.registerReversePeerConn(envelope.From.NodeID, conn)
|
|
response.Envelope = ack
|
|
}
|
|
}
|
|
payload, err := json.Marshal(response)
|
|
if err != nil {
|
|
return true
|
|
}
|
|
_ = fabricproto.WriteFrame(stream, fabricproto.Frame{
|
|
Type: fabricproto.FrameData,
|
|
TrafficClass: fabricproto.TrafficClassReliable,
|
|
StreamID: SyntheticForwardQUICStreamID,
|
|
Sequence: frame.Sequence,
|
|
Payload: payload,
|
|
})
|
|
return true
|
|
}
|
|
|
|
func (s *QUICFabricServer) registerReversePeerConn(peerID string, conn *quic.Conn) {
|
|
reverseTransport := s.getReverseTransport()
|
|
if s == nil || reverseTransport == nil || conn == nil {
|
|
return
|
|
}
|
|
reverseTransport.RegisterReverseConn(peerID, conn)
|
|
}
|
|
|
|
func (s *QUICFabricServer) getReverseTransport() *QUICFabricTransport {
|
|
if s == nil {
|
|
return nil
|
|
}
|
|
s.reverseMu.RLock()
|
|
defer s.reverseMu.RUnlock()
|
|
return s.reverseTransport
|
|
}
|
|
|
|
func (s *QUICFabricServer) logFabricSession(entry FabricSessionEventLogEntry) {
|
|
if s != nil && s.logger != nil {
|
|
s.logger(entry)
|
|
}
|
|
}
|