Refactor RDP proxy handling and update related tests
This commit is contained in:
@@ -3,28 +3,50 @@ package mesh
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/fabricproto"
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
type QUICFabricServer struct {
|
||||
listener *quic.Listener
|
||||
logger FabricSessionEventLogger
|
||||
done chan struct{}
|
||||
closeOnce sync.Once
|
||||
listener *quic.Listener
|
||||
logger FabricSessionEventLogger
|
||||
reverseMu sync.RWMutex
|
||||
reverseTransport *QUICFabricTransport
|
||||
fabricFrameHandler FabricFrameHandler
|
||||
productionForwardHandler func(context.Context, ProductionEnvelope) (ProductionForwardResult, error)
|
||||
webIngressForwardHandler func(context.Context, []byte) ([]byte, error)
|
||||
fabricControlHandler func(context.Context, []byte) ([]byte, error)
|
||||
syntheticForwardHandler func(context.Context, SyntheticEnvelope) (SyntheticEnvelope, error)
|
||||
done chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
type QUICFabricServerConfig struct {
|
||||
ListenAddr string
|
||||
TLSConfig *tls.Config
|
||||
QUICConfig *quic.Config
|
||||
Logger FabricSessionEventLogger
|
||||
ListenAddr string
|
||||
TLSConfig *tls.Config
|
||||
QUICConfig *quic.Config
|
||||
Logger FabricSessionEventLogger
|
||||
ReverseTransport *QUICFabricTransport
|
||||
FabricFrameHandler FabricFrameHandler
|
||||
ProductionForwardHandler func(context.Context, ProductionEnvelope) (ProductionForwardResult, error)
|
||||
WebIngressForwardHandler func(context.Context, []byte) ([]byte, error)
|
||||
FabricControlHandler func(context.Context, []byte) ([]byte, error)
|
||||
SyntheticForwardHandler func(context.Context, SyntheticEnvelope) (SyntheticEnvelope, error)
|
||||
}
|
||||
|
||||
type FabricFrameSender interface {
|
||||
SendFrame(context.Context, fabricproto.Frame) error
|
||||
}
|
||||
|
||||
type FabricFrameHandler func(context.Context, FabricFrameSender, fabricproto.Frame) (bool, error)
|
||||
|
||||
func StartQUICFabricServer(ctx context.Context, cfg QUICFabricServerConfig) (*QUICFabricServer, error) {
|
||||
if cfg.ListenAddr == "" {
|
||||
return nil, fmt.Errorf("quic fabric listen addr is required")
|
||||
@@ -42,9 +64,15 @@ func StartQUICFabricServer(ctx context.Context, cfg QUICFabricServerConfig) (*QU
|
||||
return nil, err
|
||||
}
|
||||
server := &QUICFabricServer{
|
||||
listener: listener,
|
||||
logger: cfg.Logger,
|
||||
done: make(chan struct{}),
|
||||
listener: listener,
|
||||
logger: cfg.Logger,
|
||||
reverseTransport: cfg.ReverseTransport,
|
||||
fabricFrameHandler: cfg.FabricFrameHandler,
|
||||
productionForwardHandler: cfg.ProductionForwardHandler,
|
||||
webIngressForwardHandler: cfg.WebIngressForwardHandler,
|
||||
fabricControlHandler: cfg.FabricControlHandler,
|
||||
syntheticForwardHandler: cfg.SyntheticForwardHandler,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go server.acceptLoop(ctx)
|
||||
return server, nil
|
||||
@@ -57,6 +85,15 @@ func (s *QUICFabricServer) Addr() net.Addr {
|
||||
return s.listener.Addr()
|
||||
}
|
||||
|
||||
func (s *QUICFabricServer) SetReverseTransport(transport *QUICFabricTransport) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.reverseMu.Lock()
|
||||
s.reverseTransport = transport
|
||||
s.reverseMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *QUICFabricServer) Close() error {
|
||||
if s == nil {
|
||||
return nil
|
||||
@@ -95,6 +132,8 @@ func (s *QUICFabricServer) handleConn(ctx context.Context, conn *quic.Conn) {
|
||||
|
||||
func (s *QUICFabricServer) handleStream(ctx context.Context, conn *quic.Conn, stream *quic.Stream) {
|
||||
session := fabricproto.NewSession(fabricproto.SessionConfig{})
|
||||
sender := quicStreamFrameSender{stream: stream}
|
||||
defer func() { _ = stream.Close() }()
|
||||
s.logFabricSession(FabricSessionEventLogEntry{
|
||||
Event: "fabric_session_quic_stream_opened",
|
||||
AcceptedBy: "quic",
|
||||
@@ -116,6 +155,29 @@ func (s *QUICFabricServer) handleStream(ctx context.Context, conn *quic.Conn, st
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.registerReverseHelloFrame(conn, frame)
|
||||
if s.handleProductionForwardFrame(ctx, stream, frame) {
|
||||
continue
|
||||
}
|
||||
if s.handleWebIngressForwardFrame(ctx, stream, frame) {
|
||||
continue
|
||||
}
|
||||
if s.handleFabricControlForwardFrame(ctx, stream, frame) {
|
||||
continue
|
||||
}
|
||||
if s.handleSyntheticForwardFrame(ctx, conn, stream, frame) {
|
||||
continue
|
||||
}
|
||||
if s.fabricFrameHandler != nil {
|
||||
handled, err := s.fabricFrameHandler(ctx, sender, frame)
|
||||
if err != nil {
|
||||
_ = conn.CloseWithError(2, err.Error())
|
||||
return
|
||||
}
|
||||
if handled {
|
||||
continue
|
||||
}
|
||||
}
|
||||
event, responses, err := session.HandleFrame(frame)
|
||||
if err != nil {
|
||||
_ = conn.CloseWithError(2, err.Error())
|
||||
@@ -140,6 +202,196 @@ func (s *QUICFabricServer) handleStream(ctx context.Context, conn *quic.Conn, st
|
||||
}
|
||||
}
|
||||
|
||||
type quicStreamFrameSender struct {
|
||||
stream *quic.Stream
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (s quicStreamFrameSender) SendFrame(ctx context.Context, frame fabricproto.Frame) error {
|
||||
if s.stream == nil {
|
||||
return fmt.Errorf("quic fabric stream is closed")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
_ = s.stream.SetWriteDeadline(deadline)
|
||||
} else {
|
||||
_ = s.stream.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
||||
}
|
||||
return fabricproto.WriteFrame(s.stream, frame)
|
||||
}
|
||||
|
||||
func (s *QUICFabricServer) registerReverseHelloFrame(conn *quic.Conn, frame fabricproto.Frame) {
|
||||
reverseTransport := s.getReverseTransport()
|
||||
if s == nil || reverseTransport == nil || conn == nil || frame.Type != fabricproto.FramePing {
|
||||
return
|
||||
}
|
||||
payload := string(frame.Payload)
|
||||
if !strings.HasPrefix(payload, fabricQUICReverseHelloPrefix) {
|
||||
return
|
||||
}
|
||||
peerID := strings.TrimPrefix(payload, fabricQUICReverseHelloPrefix)
|
||||
reverseTransport.RegisterReverseConn(peerID, conn)
|
||||
s.logFabricSession(FabricSessionEventLogEntry{
|
||||
Event: "fabric_session_quic_reverse_registered",
|
||||
AcceptedBy: "quic_reverse_hello",
|
||||
RemoteAddr: conn.RemoteAddr().String(),
|
||||
PeerID: peerID,
|
||||
})
|
||||
}
|
||||
|
||||
type quicProductionForwardResponse struct {
|
||||
Result ProductionForwardResult `json:"result,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type quicSyntheticForwardResponse struct {
|
||||
Envelope SyntheticEnvelope `json:"envelope,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type quicWebIngressForwardResponse struct {
|
||||
Payload json.RawMessage `json:"payload,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type quicFabricControlForwardResponse struct {
|
||||
Payload json.RawMessage `json:"payload,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (s *QUICFabricServer) handleProductionForwardFrame(ctx context.Context, stream *quic.Stream, frame fabricproto.Frame) bool {
|
||||
if frame.Type != fabricproto.FrameData || frame.StreamID != ProductionForwardQUICStreamID {
|
||||
return false
|
||||
}
|
||||
response := quicProductionForwardResponse{}
|
||||
if s == nil || s.productionForwardHandler == nil {
|
||||
response.Error = ErrForwardRuntimeUnavailable.Error()
|
||||
} else {
|
||||
var envelope ProductionEnvelope
|
||||
if err := json.Unmarshal(frame.Payload, &envelope); err != nil {
|
||||
response.Error = "invalid production mesh envelope"
|
||||
} else if result, err := s.productionForwardHandler(ctx, envelope); err != nil {
|
||||
response.Error = err.Error()
|
||||
} else {
|
||||
response.Result = result
|
||||
}
|
||||
}
|
||||
payload, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
_ = fabricproto.WriteFrame(stream, fabricproto.Frame{
|
||||
Type: fabricproto.FrameData,
|
||||
TrafficClass: fabricproto.TrafficClassReliable,
|
||||
StreamID: ProductionForwardQUICStreamID,
|
||||
Sequence: frame.Sequence,
|
||||
Payload: payload,
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *QUICFabricServer) handleWebIngressForwardFrame(ctx context.Context, stream *quic.Stream, frame fabricproto.Frame) bool {
|
||||
if frame.Type != fabricproto.FrameData || frame.StreamID != WebIngressForwardQUICStreamID {
|
||||
return false
|
||||
}
|
||||
response := quicWebIngressForwardResponse{}
|
||||
if s == nil || s.webIngressForwardHandler == nil {
|
||||
response.Error = ErrForwardRuntimeUnavailable.Error()
|
||||
} else if payload, err := s.webIngressForwardHandler(ctx, append([]byte(nil), frame.Payload...)); err != nil {
|
||||
response.Error = err.Error()
|
||||
} else {
|
||||
response.Payload = append(json.RawMessage(nil), payload...)
|
||||
}
|
||||
payload, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
_ = fabricproto.WriteFrame(stream, fabricproto.Frame{
|
||||
Type: fabricproto.FrameData,
|
||||
TrafficClass: fabricproto.TrafficClassReliable,
|
||||
StreamID: WebIngressForwardQUICStreamID,
|
||||
Sequence: frame.Sequence,
|
||||
Payload: payload,
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *QUICFabricServer) handleFabricControlForwardFrame(ctx context.Context, stream *quic.Stream, frame fabricproto.Frame) bool {
|
||||
if frame.Type != fabricproto.FrameData || frame.StreamID != FabricControlForwardQUICStreamID {
|
||||
return false
|
||||
}
|
||||
response := quicFabricControlForwardResponse{}
|
||||
if s == nil || s.fabricControlHandler == nil {
|
||||
response.Error = ErrForwardRuntimeUnavailable.Error()
|
||||
} else if payload, err := s.fabricControlHandler(ctx, append([]byte(nil), frame.Payload...)); err != nil {
|
||||
response.Error = err.Error()
|
||||
} else {
|
||||
response.Payload = append(json.RawMessage(nil), payload...)
|
||||
}
|
||||
payload, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
_ = fabricproto.WriteFrame(stream, fabricproto.Frame{
|
||||
Type: fabricproto.FrameData,
|
||||
TrafficClass: fabricproto.TrafficClassReliable,
|
||||
StreamID: FabricControlForwardQUICStreamID,
|
||||
Sequence: frame.Sequence,
|
||||
Payload: payload,
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *QUICFabricServer) handleSyntheticForwardFrame(ctx context.Context, conn *quic.Conn, stream *quic.Stream, frame fabricproto.Frame) bool {
|
||||
if frame.Type != fabricproto.FrameData || frame.StreamID != SyntheticForwardQUICStreamID {
|
||||
return false
|
||||
}
|
||||
response := quicSyntheticForwardResponse{}
|
||||
if s == nil || s.syntheticForwardHandler == nil {
|
||||
response.Error = ErrMeshRuntimeDisabled.Error()
|
||||
} else {
|
||||
var envelope SyntheticEnvelope
|
||||
if err := json.Unmarshal(frame.Payload, &envelope); err != nil {
|
||||
response.Error = "invalid synthetic mesh envelope"
|
||||
} else if ack, err := s.syntheticForwardHandler(ctx, envelope); err != nil {
|
||||
response.Error = err.Error()
|
||||
} else {
|
||||
s.registerReversePeerConn(envelope.From.NodeID, conn)
|
||||
response.Envelope = ack
|
||||
}
|
||||
}
|
||||
payload, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
_ = fabricproto.WriteFrame(stream, fabricproto.Frame{
|
||||
Type: fabricproto.FrameData,
|
||||
TrafficClass: fabricproto.TrafficClassReliable,
|
||||
StreamID: SyntheticForwardQUICStreamID,
|
||||
Sequence: frame.Sequence,
|
||||
Payload: payload,
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *QUICFabricServer) registerReversePeerConn(peerID string, conn *quic.Conn) {
|
||||
reverseTransport := s.getReverseTransport()
|
||||
if s == nil || reverseTransport == nil || conn == nil {
|
||||
return
|
||||
}
|
||||
reverseTransport.RegisterReverseConn(peerID, conn)
|
||||
}
|
||||
|
||||
func (s *QUICFabricServer) getReverseTransport() *QUICFabricTransport {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
s.reverseMu.RLock()
|
||||
defer s.reverseMu.RUnlock()
|
||||
return s.reverseTransport
|
||||
}
|
||||
|
||||
func (s *QUICFabricServer) logFabricSession(entry FabricSessionEventLogEntry) {
|
||||
if s != nil && s.logger != nil {
|
||||
s.logger(entry)
|
||||
|
||||
Reference in New Issue
Block a user