Files
rdp-proxy/agents/rap-node-agent/internal/fabricproto/websocket_test.go
T

152 lines
4.4 KiB
Go

package fabricproto
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gorilla/websocket"
)
// TestRunWebSocketHandlesBinaryFrames sends an open-stream frame followed by a
// data frame as binary websocket messages. It checks that the OnEvent handler
// observes a stream-opened event and then a data event, and that the peer
// writes back an ack carrying the data frame's stream ID and sequence.
func TestRunWebSocketHandlesBinaryFrames(t *testing.T) {
	// Buffered so the handler, which runs on the server side, does not have to
	// wait for the test goroutine to pick up each event.
	events := make(chan SessionEventType, 4)
	server := newFabricWebSocketTestServer(t, TransportLoop{
		Session: NewSession(SessionConfig{}),
		OnEvent: func(event SessionEvent) ([]Frame, error) {
			events <- event.Type
			return nil, nil
		},
	})
	defer server.Close()

	conn := dialFabricWebSocket(t, server.URL)
	defer conn.Close()

	// nextEvent bounds the wait for a handler event so that a loop which never
	// delivers one fails this test promptly instead of blocking until the
	// global test timeout fires.
	nextEvent := func() SessionEventType {
		t.Helper()
		var got SessionEventType
		select {
		case got = <-events:
		case <-time.After(2 * time.Second):
			t.Fatal("timed out waiting for session event")
		}
		return got
	}

	writeWebSocketFrame(t, conn, Frame{Type: FrameOpenStream, TrafficClass: TrafficClassInteractive, StreamID: 7})
	writeWebSocketFrame(t, conn, Frame{Type: FrameData, TrafficClass: TrafficClassInteractive, StreamID: 7, Sequence: 9, Payload: []byte("hello")})

	if got := nextEvent(); got != SessionEventStreamOpened {
		t.Fatalf("first event = %s, want stream_opened", got)
	}
	if got := nextEvent(); got != SessionEventData {
		t.Fatalf("second event = %s, want data", got)
	}

	ack := readWebSocketFrame(t, conn)
	if ack.Type != FrameAck || ack.StreamID != 7 || ack.Sequence != 9 {
		t.Fatalf("ack = %+v", ack)
	}
}
func TestRunWebSocketPongAndHandlerResponse(t *testing.T) {
server := newFabricWebSocketTestServer(t, TransportLoop{
Session: NewSession(SessionConfig{}),
OnEvent: func(event SessionEvent) ([]Frame, error) {
if event.Type != SessionEventPing {
return nil, nil
}
return []Frame{{Type: FrameSessionReady, Payload: []byte("ready")}}, nil
},
})
defer server.Close()
conn := dialFabricWebSocket(t, server.URL)
defer conn.Close()
writeWebSocketFrame(t, conn, Frame{Type: FramePing, Sequence: 4, Payload: []byte("probe")})
pong := readWebSocketFrame(t, conn)
if pong.Type != FramePong || pong.Sequence != 4 || string(pong.Payload) != "probe" {
t.Fatalf("pong = %+v", pong)
}
ready := readWebSocketFrame(t, conn)
if ready.Type != FrameSessionReady || string(ready.Payload) != "ready" {
t.Fatalf("ready = %+v", ready)
}
}
func TestRunWebSocketIgnoresTextMessages(t *testing.T) {
server := newFabricWebSocketTestServer(t, TransportLoop{Session: NewSession(SessionConfig{})})
defer server.Close()
conn := dialFabricWebSocket(t, server.URL)
defer conn.Close()
if err := conn.WriteMessage(websocket.TextMessage, []byte("ignore me")); err != nil {
t.Fatalf("write text: %v", err)
}
writeWebSocketFrame(t, conn, Frame{Type: FramePing, Sequence: 1})
pong := readWebSocketFrame(t, conn)
if pong.Type != FramePong || pong.Sequence != 1 {
t.Fatalf("pong = %+v", pong)
}
}
// newFabricWebSocketTestServer starts an httptest server whose handler upgrades
// each request to a websocket (any origin is accepted) and then runs loop over
// the upgraded connection with a one-second write timeout.
//
// If the upgrade fails the handler returns without running the loop. The
// loop's own error is discarded, and the connection is closed when the handler
// returns. The caller is responsible for closing the returned server.
func newFabricWebSocketTestServer(t *testing.T, loop TransportLoop) *httptest.Server {
	t.Helper()

	acceptAnyOrigin := func(*http.Request) bool { return true }
	upgrader := websocket.Upgrader{CheckOrigin: acceptAnyOrigin}

	handle := func(w http.ResponseWriter, r *http.Request) {
		ws, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return
		}
		defer ws.Close()

		cfg := WebSocketTransportConfig{WriteTimeout: time.Second}
		// Best effort: the loop's result is intentionally not inspected here.
		_ = loop.RunWebSocket(r.Context(), ws, cfg)
	}

	return httptest.NewServer(http.HandlerFunc(handle))
}
// dialFabricWebSocket opens a websocket client connection to the server at
// httpURL, failing the test if the dial does not succeed.
//
// The leading "http" of httpURL is replaced with "ws", so an "http://" URL
// becomes "ws://" and an "https://" URL becomes "wss://". The caller owns the
// returned connection and must close it.
func dialFabricWebSocket(t *testing.T, httpURL string) *websocket.Conn {
	t.Helper()

	const httpPrefix = "http"
	target := "ws" + httpURL[len(httpPrefix):]

	client, _, dialErr := websocket.DefaultDialer.Dial(target, nil)
	if dialErr != nil {
		t.Fatalf("dial websocket: %v", dialErr)
	}
	return client
}
// writeWebSocketFrame encodes frame with MarshalFrame and sends the result as
// one binary websocket message on conn. A marshal or write failure fails the
// test immediately.
func writeWebSocketFrame(t *testing.T, conn *websocket.Conn, frame Frame) {
	t.Helper()

	wire, marshalErr := MarshalFrame(frame)
	if marshalErr != nil {
		t.Fatalf("marshal frame: %v", marshalErr)
	}

	writeErr := conn.WriteMessage(websocket.BinaryMessage, wire)
	if writeErr != nil {
		t.Fatalf("write websocket frame: %v", writeErr)
	}
}
// readWebSocketFrame reads the next websocket message from conn and decodes it
// into a Frame, failing the test if no message arrives within two seconds, if
// the read fails, if the message is not binary, or if it cannot be decoded
// with UnmarshalFrame under DefaultMaxPayload.
func readWebSocketFrame(t *testing.T, conn *websocket.Conn) Frame {
	t.Helper()

	type readOutcome struct {
		kind int
		data []byte
		err  error
	}

	waitCtx, stop := context.WithTimeout(context.Background(), 2*time.Second)
	defer stop()

	// The blocking read runs on its own goroutine so the wait can be bounded.
	// The channel has capacity one so the reader can always deliver its result
	// and exit, even after this function has given up waiting.
	outcomes := make(chan readOutcome, 1)
	go func() {
		kind, data, err := conn.ReadMessage()
		outcomes <- readOutcome{kind: kind, data: data, err: err}
	}()

	var outcome readOutcome
	select {
	case outcome = <-outcomes:
	case <-waitCtx.Done():
		t.Fatal("timed out waiting for websocket frame")
	}

	switch {
	case outcome.err != nil:
		t.Fatalf("read websocket frame: %v", outcome.err)
	case outcome.kind != websocket.BinaryMessage:
		t.Fatalf("message type = %d, want binary", outcome.kind)
	}

	decoded, decodeErr := UnmarshalFrame(outcome.data, DefaultMaxPayload)
	if decodeErr != nil {
		t.Fatalf("unmarshal websocket frame: %v", decodeErr)
	}
	return decoded
}