Use gated fabric sessions for VPN transport
This commit is contained in:
@@ -70,7 +70,52 @@ func (t *FabricSessionPacketTransport) ReceiveGatewayPacketBatch(ctx context.Con
|
||||
if direction == "" {
|
||||
direction = FabricDirectionClientToGateway
|
||||
}
|
||||
return t.Inbox.Receive(ctx, t.VPNConnectionID, direction, timeout)
|
||||
if packets, err := t.Inbox.Receive(ctx, t.VPNConnectionID, direction, 5*time.Millisecond); err != nil || len(packets) > 0 {
|
||||
return packets, err
|
||||
}
|
||||
if t.Receiver == nil {
|
||||
return t.Inbox.Receive(ctx, t.VPNConnectionID, direction, timeout)
|
||||
}
|
||||
if timeout <= 0 {
|
||||
timeout = 25 * time.Second
|
||||
}
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
frames := t.Receiver.Frames()
|
||||
errorsCh := t.Receiver.Errors()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil, nil
|
||||
case err, ok := <-errorsCh:
|
||||
if !ok {
|
||||
errorsCh = nil
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case frame, ok := <-frames:
|
||||
if !ok {
|
||||
return t.Inbox.Receive(ctx, t.VPNConnectionID, direction, 5*time.Millisecond)
|
||||
}
|
||||
if frame.Type != fabricproto.FrameData || (t.StreamID != 0 && frame.StreamID != t.StreamID) {
|
||||
continue
|
||||
}
|
||||
payload, err := DecodeFabricVPNPacketDataFrame(frame)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if payload.VPNConnectionID == t.VPNConnectionID && payload.Direction == direction {
|
||||
return cleanPacketBatch(payload.Packets), nil
|
||||
}
|
||||
if err := t.Inbox.DeliverFabricSessionFrame(ctx, frame); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *FabricSessionPacketTransport) RunFrameIngress(ctx context.Context) error {
|
||||
|
||||
@@ -265,6 +265,40 @@ func TestFabricSessionPacketTransportRunFrameIngressDeliversInbox(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFabricSessionPacketTransportReceiveReadsPumpFrames(t *testing.T) {
|
||||
inbox := NewFabricPacketInbox(4)
|
||||
receiver := memoryFabricSessionReceiver{
|
||||
frames: make(chan fabricproto.Frame, 1),
|
||||
errors: make(chan error, 1),
|
||||
}
|
||||
transport := &FabricSessionPacketTransport{
|
||||
Receiver: receiver,
|
||||
Inbox: inbox,
|
||||
StreamID: 711,
|
||||
VPNConnectionID: "vpn-1",
|
||||
ReceiveDirection: FabricDirectionClientToGateway,
|
||||
}
|
||||
frame, err := NewFabricVPNPacketDataFrame(FabricVPNPacketFrameInput{
|
||||
StreamID: 711,
|
||||
Sequence: 1,
|
||||
VPNConnectionID: "vpn-1",
|
||||
Direction: FabricDirectionClientToGateway,
|
||||
Packets: [][]byte{[]byte("request")},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("new fabric vpn frame: %v", err)
|
||||
}
|
||||
receiver.frames <- frame
|
||||
|
||||
packets, err := transport.ReceiveGatewayPacketBatch(context.Background(), time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("receive gateway packet: %v", err)
|
||||
}
|
||||
if len(packets) != 1 || string(packets[0]) != "request" {
|
||||
t.Fatalf("packets = %#v", packets)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFabricSessionPacketTransportIngressIgnoresOtherStreams(t *testing.T) {
|
||||
inbox := NewFabricPacketInbox(4)
|
||||
receiver := memoryFabricSessionReceiver{
|
||||
|
||||
@@ -193,6 +193,8 @@ func (g *Gateway) Snapshot() map[string]any {
|
||||
|
||||
func (g *Gateway) transportName() string {
|
||||
switch g.Transport.(type) {
|
||||
case *FabricSessionPacketTransport:
|
||||
return "fabric_session"
|
||||
case *FabricPacketTransport:
|
||||
return "fabric_mesh"
|
||||
case *LocalPacketTransport:
|
||||
|
||||
Reference in New Issue
Block a user