Files
rdp-proxy/agents/rap-node-agent/internal/vpnruntime/fabric_session_transport.go
T

155 lines
3.8 KiB
Go

package vpnruntime
import (
"context"
"errors"
"sync/atomic"
"time"
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/fabricproto"
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/mesh"
)
// FabricSessionFrameSender is the outbound half of a fabric session.
// FabricSessionPacketTransport.SendGatewayPacketBatch passes each encoded VPN
// packet data frame to Send, together with the caller's context, and returns
// whatever error Send reports.
type FabricSessionFrameSender interface {
	Send(ctx context.Context, frame fabricproto.Frame) error
}
// FabricSessionFrameReceiver is the inbound half of a fabric session, exposed
// as a pair of channels that FabricSessionPacketTransport selects over.
type FabricSessionFrameReceiver interface {
	// Frames returns the channel of inbound frames. The transport treats a
	// closed channel as the end of the frame stream.
	Frames() <-chan fabricproto.Frame
	// Errors returns the channel of inbound receive errors. A non-nil value
	// aborts the transport's current receive loop; a closed channel only stops
	// the transport from watching for errors.
	Errors() <-chan error
}
// FabricSessionPacketTransport moves batches of VPN packets over a fabric
// session. Outbound batches are encoded as VPN packet data frames and handed to
// Sender. Inbound data frames from Receiver are either returned directly to the
// caller or delivered into Inbox.
type FabricSessionPacketTransport struct {
	// sequence numbers outbound frames; it is incremented with
	// atomic.AddUint64, so the first frame sent carries 1.
	//
	// It must stay the FIRST field: on 386, ARM and 32-bit MIPS the 64-bit
	// sync/atomic functions require 64-bit alignment, which Go only guarantees
	// for the first word of an allocated struct. (If this struct is ever
	// embedded in another struct, it must likewise be that struct's first
	// field, or sequence should become an atomic.Uint64.)
	sequence uint64

	// Sender carries outbound frames; required for sending.
	Sender FabricSessionFrameSender
	// Receiver supplies inbound frames and errors. It is optional for
	// ReceiveGatewayPacketBatch, which then reads only from Inbox, and it is
	// required for RunFrameIngress.
	Receiver FabricSessionFrameReceiver
	// Inbox is read by VPN connection ID and direction, and it receives data
	// frames that were not returned directly to a caller. It is required for
	// receiving.
	Inbox *FabricPacketInbox

	// StreamID identifies the fabric stream. It must be non-zero to send.
	// When it is non-zero, inbound frames on any other stream are skipped.
	StreamID uint64
	// VPNConnectionID identifies the VPN connection; must be non-empty to send.
	VPNConnectionID string
	// SendDirection labels outbound frames; empty means
	// FabricDirectionGatewayToClient.
	SendDirection string
	// ReceiveDirection selects which inbound packets are returned to the
	// caller; empty means FabricDirectionClientToGateway.
	ReceiveDirection string
	// TrafficClass is copied into every outbound frame.
	TrafficClass string
}
// SendGatewayPacketBatch sends packets to the peer as a single fabric VPN
// packet data frame via Sender.
//
// The batch first goes through cleanPacketBatch. If nothing remains, the call
// is a no-op that returns nil, even on a nil or unconfigured transport.
//
// Errors:
//   - mesh.ErrForwardRuntimeUnavailable when the transport or Sender is nil.
//   - An identity error when StreamID is zero or VPNConnectionID is empty.
//   - Otherwise, any error from frame construction or from Sender.Send.
//
// Each frame carries the next value of the transport's atomically incremented
// sequence counter. Frames are labelled with SendDirection, or
// FabricDirectionGatewayToClient when SendDirection is empty.
func (t *FabricSessionPacketTransport) SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error {
	batch := cleanPacketBatch(packets)
	if len(batch) == 0 {
		return nil
	}
	if t == nil || t.Sender == nil {
		return mesh.ErrForwardRuntimeUnavailable
	}
	if t.VPNConnectionID == "" || t.StreamID == 0 {
		return errors.New("fabric session packet transport identity is incomplete")
	}

	input := FabricVPNPacketFrameInput{
		StreamID:        t.StreamID,
		VPNConnectionID: t.VPNConnectionID,
		Direction:       FabricDirectionGatewayToClient,
		TrafficClass:    t.TrafficClass,
		Packets:         batch,
	}
	if t.SendDirection != "" {
		input.Direction = t.SendDirection
	}
	input.Sequence = atomic.AddUint64(&t.sequence, 1)

	frame, err := NewFabricVPNPacketDataFrame(input)
	if err != nil {
		return err
	}
	return t.Sender.Send(ctx, frame)
}
// ReceiveGatewayPacketBatch waits for the next batch of packets addressed to
// this transport's VPN connection in its receive direction. The receive
// direction is ReceiveDirection, or FabricDirectionClientToGateway when it is
// empty.
//
// The call proceeds in this order:
//  1. Inbox is probed for 5ms.
//  2. If that yields packets or an error, they are returned as-is.
//  3. With no Receiver configured, it waits on Inbox for timeout and returns
//     that result.
//  4. Otherwise it reads frames from Receiver until one matches
//     VPNConnectionID and the receive direction. A non-positive timeout means
//     25s on this path.
//
// While reading from Receiver:
//   - Non-data frames and frames on a different stream (when StreamID is
//     non-zero) are skipped.
//   - Data frames for other connections or directions are delivered to Inbox.
//   - If the frame channel closes, one more 5ms Inbox probe decides the
//     result.
//
// It returns (nil, nil) when the timeout elapses and ctx.Err() when ctx is
// done. It also returns the first non-nil error from Receiver's error channel,
// from frame decoding, or from Inbox delivery.
func (t *FabricSessionPacketTransport) ReceiveGatewayPacketBatch(ctx context.Context, timeout time.Duration) ([][]byte, error) {
	if t == nil || t.Inbox == nil {
		return nil, mesh.ErrForwardRuntimeUnavailable
	}
	inbox := t.Inbox
	connID := t.VPNConnectionID
	want := t.ReceiveDirection
	if want == "" {
		want = FabricDirectionClientToGateway
	}

	// Serve anything already buffered before touching the live frame stream.
	buffered, err := inbox.Receive(ctx, connID, want, 5*time.Millisecond)
	if err != nil || len(buffered) > 0 {
		return buffered, err
	}
	if t.Receiver == nil {
		return inbox.Receive(ctx, connID, want, timeout)
	}

	wait := timeout
	if wait <= 0 {
		wait = 25 * time.Second
	}
	deadline := time.NewTimer(wait)
	defer deadline.Stop()

	frameCh, errCh := t.Receiver.Frames(), t.Receiver.Errors()
	for {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-deadline.C:
			return nil, nil
		case recvErr, open := <-errCh:
			switch {
			case !open:
				// Receiving from a nil channel blocks forever, which removes
				// this case from the select once the channel is closed.
				errCh = nil
			case recvErr != nil:
				return nil, recvErr
			}
		case frame, open := <-frameCh:
			if !open {
				// Frame stream ended: settle the call with one short Inbox probe.
				return inbox.Receive(ctx, connID, want, 5*time.Millisecond)
			}
			if frame.Type != fabricproto.FrameData {
				continue
			}
			if t.StreamID != 0 && frame.StreamID != t.StreamID {
				continue
			}
			payload, decodeErr := DecodeFabricVPNPacketDataFrame(frame)
			if decodeErr != nil {
				return nil, decodeErr
			}
			if payload.VPNConnectionID == connID && payload.Direction == want {
				return cleanPacketBatch(payload.Packets), nil
			}
			// Not ours: hand the frame to Inbox for whoever reads that
			// connection/direction.
			if deliverErr := inbox.DeliverFabricSessionFrame(ctx, frame); deliverErr != nil {
				return nil, deliverErr
			}
		}
	}
}
// RunFrameIngress pumps data frames from Receiver into Inbox until ctx is done,
// Receiver's frame channel closes, or an error occurs.
//
// Non-data frames are skipped. When StreamID is non-zero, frames on any other
// stream are skipped too. If Receiver's error channel closes, the loop stops
// watching it and keeps running.
//
// It returns:
//   - mesh.ErrForwardRuntimeUnavailable when the transport, Receiver or Inbox
//     is nil.
//   - ctx.Err() on cancellation.
//   - nil when the frame channel closes.
//   - Otherwise, the first non-nil error from Receiver or from Inbox delivery.
func (t *FabricSessionPacketTransport) RunFrameIngress(ctx context.Context) error {
	if t == nil || t.Receiver == nil || t.Inbox == nil {
		return mesh.ErrForwardRuntimeUnavailable
	}
	incoming, failures := t.Receiver.Frames(), t.Receiver.Errors()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case failure, open := <-failures:
			if !open {
				// A nil channel is never ready, so this case goes quiet.
				failures = nil
			} else if failure != nil {
				return failure
			}
		case frame, open := <-incoming:
			if !open {
				return nil
			}
			skip := frame.Type != fabricproto.FrameData ||
				(t.StreamID != 0 && frame.StreamID != t.StreamID)
			if skip {
				continue
			}
			if err := t.Inbox.DeliverFabricSessionFrame(ctx, frame); err != nil {
				return err
			}
		}
	}
}