Refactor RDP proxy handling and update related tests
This commit is contained in:
@@ -0,0 +1,487 @@
|
||||
package mesh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/fabricproto"
|
||||
)
|
||||
|
||||
type FabricChannelRuntimeConfig struct {
|
||||
RouterConfig FabricChannelRouterConfig
|
||||
StreamID uint64
|
||||
TrafficClass fabricproto.TrafficClass
|
||||
Timeout time.Duration
|
||||
MaxPayload int
|
||||
RouteHealthTTL time.Duration
|
||||
}
|
||||
|
||||
type FabricChannelRuntime struct {
|
||||
Transport FabricTransport
|
||||
Router FabricChannelRouter
|
||||
Pressure *FabricRoutePressureTracker
|
||||
Health *FabricRouteHealthTracker
|
||||
Config FabricChannelRuntimeConfig
|
||||
}
|
||||
|
||||
type FabricChannelRuntimeResult struct {
|
||||
Channel FabricChannel
|
||||
BytesSent uint64
|
||||
BytesRecv uint64
|
||||
FramesSent uint64
|
||||
FramesRecv uint64
|
||||
AcksReceived uint64
|
||||
RouteEvents []FabricChannelRouteEvent
|
||||
RouteAttempts []string
|
||||
MigrationEvents int
|
||||
RoutePressure FabricRoutePressureSnapshot
|
||||
RouteHealth FabricRouteHealthSnapshot
|
||||
}
|
||||
|
||||
type FabricChannelRequestResponseResult struct {
|
||||
FabricChannelRuntimeResult
|
||||
ResponsePayload []byte
|
||||
}
|
||||
|
||||
func NewFabricChannelRuntime(transport FabricTransport, cfg FabricChannelRuntimeConfig) *FabricChannelRuntime {
|
||||
if cfg.StreamID == 0 {
|
||||
cfg.StreamID = 2
|
||||
}
|
||||
if cfg.TrafficClass == 0 {
|
||||
cfg.TrafficClass = fabricproto.TrafficClassBulk
|
||||
}
|
||||
if cfg.Timeout <= 0 {
|
||||
cfg.Timeout = 30 * time.Second
|
||||
}
|
||||
if cfg.MaxPayload <= 0 {
|
||||
cfg.MaxPayload = fabricproto.DefaultMaxPayload
|
||||
}
|
||||
return &FabricChannelRuntime{
|
||||
Transport: transport,
|
||||
Router: NewFabricChannelRouter(cfg.RouterConfig),
|
||||
Pressure: NewFabricRoutePressureTracker(),
|
||||
Health: NewFabricRouteHealthTracker(cfg.RouteHealthTTL),
|
||||
Config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *FabricChannelRuntime) SendReliable(ctx context.Context, spec FabricChannelSpec, routeSet FabricRouteSet, payloads [][]byte) (FabricChannelRuntimeResult, error) {
|
||||
if r == nil || r.Transport == nil {
|
||||
return FabricChannelRuntimeResult{}, ErrForwardRuntimeUnavailable
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
routeSet = r.routeSetForScheduling(routeSet)
|
||||
channel, event, err := r.Router.OpenChannel(spec, routeSet, now)
|
||||
if err != nil {
|
||||
return FabricChannelRuntimeResult{}, err
|
||||
}
|
||||
result := FabricChannelRuntimeResult{Channel: channel, RouteEvents: []FabricChannelRouteEvent{event}}
|
||||
sequence := uint64(0)
|
||||
index := 0
|
||||
for index < len(payloads) {
|
||||
routeSet = r.routeSetForScheduling(routeSet)
|
||||
route, ok := findFabricRoute(routeSet, channel.RouteID)
|
||||
if !ok {
|
||||
return result, ErrFabricRouteNotFound
|
||||
}
|
||||
result.RouteAttempts = append(result.RouteAttempts, route.RouteID)
|
||||
target, err := FabricTransportTargetForRoute(route)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
releaseRoute := r.acquireRoute(route.RouteID)
|
||||
session, err := r.Transport.Connect(ctx, target)
|
||||
if err != nil {
|
||||
releaseRoute()
|
||||
r.markRouteFailure(route.RouteID, err)
|
||||
updated, event, rerouteErr := r.Router.ObserveChannel(channel, routeSet, FabricChannelObservation{
|
||||
ChannelID: spec.ChannelID,
|
||||
RouteID: route.RouteID,
|
||||
Failed: true,
|
||||
Reason: "connect_failed",
|
||||
ObservedAt: time.Now().UTC(),
|
||||
}, time.Now().UTC())
|
||||
channel = updated
|
||||
result.Channel = channel
|
||||
if event.Type == FabricChannelRouteEventReroute {
|
||||
result.RouteEvents = append(result.RouteEvents, event)
|
||||
result.MigrationEvents++
|
||||
continue
|
||||
}
|
||||
if rerouteErr != nil {
|
||||
return result, rerouteErr
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
migrated, sendErr := r.sendOnSession(ctx, session, &channel, routeSet, route, payloads, &index, &sequence, &result)
|
||||
_ = session.Close()
|
||||
releaseRoute()
|
||||
result.Channel = channel
|
||||
if sendErr != nil {
|
||||
return result, sendErr
|
||||
}
|
||||
if !migrated {
|
||||
break
|
||||
}
|
||||
}
|
||||
result.Channel = channel
|
||||
result.RoutePressure = r.snapshotRoutePressure()
|
||||
result.RouteHealth = r.snapshotRouteHealth()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *FabricChannelRuntime) SendRequestResponse(ctx context.Context, spec FabricChannelSpec, routeSet FabricRouteSet, payload []byte) (FabricChannelRequestResponseResult, error) {
|
||||
if r == nil || r.Transport == nil {
|
||||
return FabricChannelRequestResponseResult{}, ErrForwardRuntimeUnavailable
|
||||
}
|
||||
if len(payload) > r.Config.MaxPayload {
|
||||
return FabricChannelRequestResponseResult{}, fmt.Errorf("%w: %d > %d", fabricproto.ErrInvalidPayloadLen, len(payload), r.Config.MaxPayload)
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
routeSet = r.routeSetForScheduling(routeSet)
|
||||
channel, event, err := r.Router.OpenChannel(spec, routeSet, now)
|
||||
if err != nil {
|
||||
return FabricChannelRequestResponseResult{}, err
|
||||
}
|
||||
result := FabricChannelRequestResponseResult{
|
||||
FabricChannelRuntimeResult: FabricChannelRuntimeResult{Channel: channel, RouteEvents: []FabricChannelRouteEvent{event}},
|
||||
}
|
||||
sequence := uint64(1)
|
||||
for {
|
||||
routeSet = r.routeSetForScheduling(routeSet)
|
||||
route, ok := findFabricRoute(routeSet, channel.RouteID)
|
||||
if !ok {
|
||||
return result, ErrFabricRouteNotFound
|
||||
}
|
||||
result.RouteAttempts = append(result.RouteAttempts, route.RouteID)
|
||||
target, err := FabricTransportTargetForRoute(route)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
releaseRoute := r.acquireRoute(route.RouteID)
|
||||
session, err := r.Transport.Connect(ctx, target)
|
||||
if err != nil {
|
||||
releaseRoute()
|
||||
r.markRouteFailure(route.RouteID, err)
|
||||
updated, routeEvent, rerouteErr := r.Router.ObserveChannel(channel, routeSet, FabricChannelObservation{
|
||||
ChannelID: spec.ChannelID,
|
||||
RouteID: route.RouteID,
|
||||
Failed: true,
|
||||
Reason: "connect_failed",
|
||||
ObservedAt: time.Now().UTC(),
|
||||
}, time.Now().UTC())
|
||||
channel = updated
|
||||
result.Channel = channel
|
||||
if routeEvent.Type == FabricChannelRouteEventReroute {
|
||||
result.RouteEvents = append(result.RouteEvents, routeEvent)
|
||||
result.MigrationEvents++
|
||||
continue
|
||||
}
|
||||
if rerouteErr != nil {
|
||||
return result, rerouteErr
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
response, ackMs, sendErr := r.sendRequestResponseOnSession(ctx, session, route.RouteID, spec.ChannelID, payload, sequence)
|
||||
_ = session.Close()
|
||||
releaseRoute()
|
||||
result.Channel = channel
|
||||
if sendErr == nil {
|
||||
r.markRouteSuccess(route.RouteID)
|
||||
result.BytesSent += uint64(len(payload))
|
||||
result.FramesSent++
|
||||
result.BytesRecv += uint64(len(response))
|
||||
result.FramesRecv++
|
||||
result.AcksReceived++
|
||||
updated, routeEvent, observeErr := r.Router.ObserveChannel(channel, routeSet, FabricChannelObservation{
|
||||
ChannelID: spec.ChannelID,
|
||||
RouteID: route.RouteID,
|
||||
AckLatencyMs: ackMs,
|
||||
BytesSent: uint64(len(payload)),
|
||||
FramesSent: 1,
|
||||
BytesRecv: uint64(len(response)),
|
||||
FramesRecv: 1,
|
||||
ObservedAt: time.Now().UTC(),
|
||||
}, time.Now().UTC())
|
||||
channel = updated
|
||||
result.Channel = channel
|
||||
if observeErr != nil {
|
||||
return result, observeErr
|
||||
}
|
||||
if routeEvent.Type == FabricChannelRouteEventReroute {
|
||||
result.RouteEvents = append(result.RouteEvents, routeEvent)
|
||||
result.MigrationEvents++
|
||||
}
|
||||
result.ResponsePayload = response
|
||||
result.RoutePressure = r.snapshotRoutePressure()
|
||||
result.RouteHealth = r.snapshotRouteHealth()
|
||||
return result, nil
|
||||
}
|
||||
r.markRouteFailure(route.RouteID, sendErr)
|
||||
updated, routeEvent, rerouteErr := r.Router.ObserveChannel(channel, routeSet, FabricChannelObservation{
|
||||
ChannelID: spec.ChannelID,
|
||||
RouteID: route.RouteID,
|
||||
Failed: true,
|
||||
Reason: "response_failed",
|
||||
ObservedAt: time.Now().UTC(),
|
||||
}, time.Now().UTC())
|
||||
channel = updated
|
||||
result.Channel = channel
|
||||
if routeEvent.Type == FabricChannelRouteEventReroute {
|
||||
result.RouteEvents = append(result.RouteEvents, routeEvent)
|
||||
result.MigrationEvents++
|
||||
continue
|
||||
}
|
||||
if rerouteErr != nil {
|
||||
return result, rerouteErr
|
||||
}
|
||||
return result, sendErr
|
||||
}
|
||||
}
|
||||
|
||||
func (r *FabricChannelRuntime) routeSetForScheduling(routeSet FabricRouteSet) FabricRouteSet {
|
||||
if r != nil && r.Health != nil {
|
||||
routeSet = r.Health.Apply(routeSet, time.Now().UTC())
|
||||
}
|
||||
return r.routeSetWithActiveChannels(routeSet)
|
||||
}
|
||||
|
||||
func (r *FabricChannelRuntime) routeSetWithActiveChannels(routeSet FabricRouteSet) FabricRouteSet {
|
||||
if r == nil || r.Pressure == nil {
|
||||
return routeSet
|
||||
}
|
||||
return r.Pressure.Apply(routeSet)
|
||||
}
|
||||
|
||||
func (r *FabricChannelRuntime) acquireRoute(routeID string) func() {
|
||||
if r == nil || r.Pressure == nil {
|
||||
return func() {}
|
||||
}
|
||||
return r.Pressure.Acquire(routeID)
|
||||
}
|
||||
|
||||
func (r *FabricChannelRuntime) snapshotRoutePressure() FabricRoutePressureSnapshot {
|
||||
if r == nil || r.Pressure == nil {
|
||||
return FabricRoutePressureSnapshot{}
|
||||
}
|
||||
return r.Pressure.SnapshotPressure()
|
||||
}
|
||||
|
||||
func (r *FabricChannelRuntime) snapshotRouteHealth() FabricRouteHealthSnapshot {
|
||||
if r == nil || r.Health == nil {
|
||||
return FabricRouteHealthSnapshot{}
|
||||
}
|
||||
return r.Health.Snapshot(time.Now().UTC())
|
||||
}
|
||||
|
||||
func (r *FabricChannelRuntime) markRouteFailure(routeID string, err error) {
|
||||
if r == nil || r.Health == nil || err == nil {
|
||||
return
|
||||
}
|
||||
r.Health.MarkFailure(routeID, err.Error(), time.Now().UTC())
|
||||
}
|
||||
|
||||
func (r *FabricChannelRuntime) markRouteSuccess(routeID string) {
|
||||
if r == nil || r.Health == nil {
|
||||
return
|
||||
}
|
||||
r.Health.MarkSuccess(routeID)
|
||||
}
|
||||
|
||||
func (r *FabricChannelRuntime) sendOnSession(ctx context.Context, session FabricTransportSession, channel *FabricChannel, routeSet FabricRouteSet, route FabricRoute, payloads [][]byte, index *int, sequence *uint64, result *FabricChannelRuntimeResult) (bool, error) {
|
||||
cfg := r.Config
|
||||
if err := session.Send(ctx, fabricproto.Frame{
|
||||
Type: fabricproto.FrameOpenStream,
|
||||
TrafficClass: cfg.TrafficClass,
|
||||
StreamID: cfg.StreamID,
|
||||
}); err != nil {
|
||||
r.markRouteFailure(route.RouteID, err)
|
||||
return false, err
|
||||
}
|
||||
for *index < len(payloads) {
|
||||
payload := payloads[*index]
|
||||
if len(payload) > cfg.MaxPayload {
|
||||
return false, fmt.Errorf("%w: %d > %d", fabricproto.ErrInvalidPayloadLen, len(payload), cfg.MaxPayload)
|
||||
}
|
||||
(*sequence)++
|
||||
if err := session.Send(ctx, fabricproto.Frame{
|
||||
Type: fabricproto.FrameData,
|
||||
TrafficClass: cfg.TrafficClass,
|
||||
StreamID: cfg.StreamID,
|
||||
Sequence: *sequence,
|
||||
Payload: payload,
|
||||
}); err != nil {
|
||||
r.markRouteFailure(route.RouteID, err)
|
||||
return false, err
|
||||
}
|
||||
ackOK, ackMs := waitForFabricRuntimeAck(ctx, session, cfg.StreamID, *sequence, cfg.Timeout)
|
||||
if !ackOK {
|
||||
r.markRouteFailure(route.RouteID, fmt.Errorf("ack_failed"))
|
||||
updated, event, err := r.Router.ObserveChannel(*channel, routeSet, FabricChannelObservation{
|
||||
ChannelID: channel.Spec.ChannelID,
|
||||
RouteID: route.RouteID,
|
||||
Failed: true,
|
||||
Reason: "ack_failed",
|
||||
ObservedAt: time.Now().UTC(),
|
||||
}, time.Now().UTC())
|
||||
*channel = updated
|
||||
if event.Type == FabricChannelRouteEventReroute {
|
||||
result.RouteEvents = append(result.RouteEvents, event)
|
||||
result.MigrationEvents++
|
||||
return true, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
r.markRouteSuccess(route.RouteID)
|
||||
*index++
|
||||
result.BytesSent += uint64(len(payload))
|
||||
result.FramesSent++
|
||||
result.AcksReceived++
|
||||
updated, event, err := r.Router.ObserveChannel(*channel, routeSet, FabricChannelObservation{
|
||||
ChannelID: channel.Spec.ChannelID,
|
||||
RouteID: route.RouteID,
|
||||
AckLatencyMs: ackMs,
|
||||
BytesSent: uint64(len(payload)),
|
||||
FramesSent: 1,
|
||||
ObservedAt: time.Now().UTC(),
|
||||
}, time.Now().UTC())
|
||||
*channel = updated
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if event.Type == FabricChannelRouteEventReroute {
|
||||
result.RouteEvents = append(result.RouteEvents, event)
|
||||
result.MigrationEvents++
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
_ = session.Send(context.Background(), fabricproto.Frame{
|
||||
Type: fabricproto.FrameCloseStream,
|
||||
TrafficClass: cfg.TrafficClass,
|
||||
StreamID: cfg.StreamID,
|
||||
})
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *FabricChannelRuntime) sendRequestResponseOnSession(ctx context.Context, session FabricTransportSession, routeID string, channelID string, payload []byte, sequence uint64) ([]byte, int64, error) {
|
||||
cfg := r.Config
|
||||
if err := session.Send(ctx, fabricproto.Frame{
|
||||
Type: fabricproto.FrameOpenStream,
|
||||
TrafficClass: cfg.TrafficClass,
|
||||
StreamID: cfg.StreamID,
|
||||
}); err != nil {
|
||||
r.markRouteFailure(routeID, err)
|
||||
return nil, 0, err
|
||||
}
|
||||
started := time.Now()
|
||||
if err := session.Send(ctx, fabricproto.Frame{
|
||||
Type: fabricproto.FrameData,
|
||||
TrafficClass: cfg.TrafficClass,
|
||||
StreamID: cfg.StreamID,
|
||||
Sequence: sequence,
|
||||
Payload: payload,
|
||||
}); err != nil {
|
||||
r.markRouteFailure(routeID, err)
|
||||
return nil, 0, err
|
||||
}
|
||||
waitCtx := ctx
|
||||
if cfg.Timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
waitCtx, cancel = context.WithTimeout(ctx, cfg.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-waitCtx.Done():
|
||||
return nil, 0, waitCtx.Err()
|
||||
case err, ok := <-session.Errors():
|
||||
if !ok {
|
||||
return nil, 0, ErrForwardPeerUnavailable
|
||||
}
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
case frame, ok := <-session.Frames():
|
||||
if !ok {
|
||||
return nil, 0, ErrForwardPeerUnavailable
|
||||
}
|
||||
if frame.Type != fabricproto.FrameData || frame.StreamID != cfg.StreamID || frame.Sequence != sequence {
|
||||
continue
|
||||
}
|
||||
_ = session.Send(context.Background(), fabricproto.Frame{
|
||||
Type: fabricproto.FrameCloseStream,
|
||||
TrafficClass: cfg.TrafficClass,
|
||||
StreamID: cfg.StreamID,
|
||||
})
|
||||
return append([]byte(nil), frame.Payload...), time.Since(started).Milliseconds(), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func FabricTransportTargetForRoute(route FabricRoute) (FabricTransportTarget, error) {
|
||||
if strings.TrimSpace(route.RouteID) == "" {
|
||||
return FabricTransportTarget{}, ErrFabricRouteNotFound
|
||||
}
|
||||
if route.RelayCount > 0 {
|
||||
for _, hop := range route.Hops {
|
||||
if hop.Mode != FabricRouteRelay {
|
||||
continue
|
||||
}
|
||||
if target, ok := fabricTransportTargetForHop(hop); ok {
|
||||
return target, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
for i := len(route.Hops) - 1; i >= 0; i-- {
|
||||
if target, ok := fabricTransportTargetForHop(route.Hops[i]); ok {
|
||||
return target, nil
|
||||
}
|
||||
}
|
||||
return FabricTransportTarget{}, fmt.Errorf("%w: route %s has no transport endpoint", ErrFabricRouteNotFound, route.RouteID)
|
||||
}
|
||||
|
||||
func fabricTransportTargetForHop(hop FabricRouteHop) (FabricTransportTarget, bool) {
|
||||
endpoint := strings.TrimSpace(hop.Address)
|
||||
if endpoint == "" {
|
||||
return FabricTransportTarget{}, false
|
||||
}
|
||||
transport := string(hop.Mode)
|
||||
if transport == "" {
|
||||
transport = "quic"
|
||||
}
|
||||
return FabricTransportTarget{
|
||||
EndpointID: hop.EndpointID,
|
||||
PeerID: strings.TrimSpace(hop.NodeID),
|
||||
Endpoint: endpoint,
|
||||
Transport: transport,
|
||||
PeerCertSHA256: strings.TrimSpace(hop.PeerCertSHA256),
|
||||
}, true
|
||||
}
|
||||
|
||||
func waitForFabricRuntimeAck(ctx context.Context, session FabricTransportSession, streamID uint64, sequence uint64, timeout time.Duration) (bool, int64) {
|
||||
started := time.Now()
|
||||
if timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, 0
|
||||
case err, ok := <-session.Errors():
|
||||
if !ok || err != nil {
|
||||
return false, 0
|
||||
}
|
||||
case frame, ok := <-session.Frames():
|
||||
if !ok {
|
||||
return false, 0
|
||||
}
|
||||
if frame.Type == fabricproto.FrameAck && frame.StreamID == streamID && frame.Sequence == sequence {
|
||||
return true, time.Since(started).Milliseconds()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user