package fabricproto

import (
	"context"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/gorilla/websocket"
)

// webSocketTestWaitTimeout bounds every blocking wait in these tests so that a
// missing session event or frame fails the test quickly instead of hanging
// until the global `go test` timeout.
const webSocketTestWaitTimeout = 2 * time.Second

// TestRunWebSocketHandlesBinaryFrames checks that binary WebSocket messages are
// decoded as frames, surfaced to OnEvent in order (stream opened, then data),
// and that the data frame is acknowledged with its stream ID and sequence.
func TestRunWebSocketHandlesBinaryFrames(t *testing.T) {
	events := make(chan SessionEventType, 4)
	server := newFabricWebSocketTestServer(t, TransportLoop{
		Session: NewSession(SessionConfig{}),
		OnEvent: func(event SessionEvent) ([]Frame, error) {
			events <- event.Type
			return nil, nil
		},
	})
	conn := dialFabricWebSocket(t, server.URL)

	writeWebSocketFrame(t, conn, Frame{Type: FrameOpenStream, TrafficClass: TrafficClassInteractive, StreamID: 7})
	writeWebSocketFrame(t, conn, Frame{Type: FrameData, TrafficClass: TrafficClassInteractive, StreamID: 7, Sequence: 9, Payload: []byte("hello")})

	if got := waitForWebSocketEvent(t, events); got != SessionEventStreamOpened {
		t.Fatalf("first event = %s, want stream_opened", got)
	}
	if got := waitForWebSocketEvent(t, events); got != SessionEventData {
		t.Fatalf("second event = %s, want data", got)
	}

	ack := readWebSocketFrame(t, conn)
	if ack.Type != FrameAck || ack.StreamID != 7 || ack.Sequence != 9 {
		t.Fatalf("ack = %+v", ack)
	}
}

// TestRunWebSocketPongAndHandlerResponse checks that a ping is answered with a
// pong echoing its sequence and payload, followed by the frames returned from
// the OnEvent handler for the ping event.
func TestRunWebSocketPongAndHandlerResponse(t *testing.T) {
	server := newFabricWebSocketTestServer(t, TransportLoop{
		Session: NewSession(SessionConfig{}),
		OnEvent: func(event SessionEvent) ([]Frame, error) {
			if event.Type != SessionEventPing {
				return nil, nil
			}
			return []Frame{{Type: FrameSessionReady, Payload: []byte("ready")}}, nil
		},
	})
	conn := dialFabricWebSocket(t, server.URL)

	writeWebSocketFrame(t, conn, Frame{Type: FramePing, Sequence: 4, Payload: []byte("probe")})

	pong := readWebSocketFrame(t, conn)
	if pong.Type != FramePong || pong.Sequence != 4 || string(pong.Payload) != "probe" {
		t.Fatalf("pong = %+v", pong)
	}

	ready := readWebSocketFrame(t, conn)
	if ready.Type != FrameSessionReady || string(ready.Payload) != "ready" {
		t.Fatalf("ready = %+v", ready)
	}
}

// TestRunWebSocketIgnoresTextMessages checks that a text WebSocket message does
// not break the loop: a binary ping sent afterwards is still answered.
func TestRunWebSocketIgnoresTextMessages(t *testing.T) {
	server := newFabricWebSocketTestServer(t, TransportLoop{Session: NewSession(SessionConfig{})})
	conn := dialFabricWebSocket(t, server.URL)

	if err := conn.WriteMessage(websocket.TextMessage, []byte("ignore me")); err != nil {
		t.Fatalf("write text: %v", err)
	}
	writeWebSocketFrame(t, conn, Frame{Type: FramePing, Sequence: 1})

	pong := readWebSocketFrame(t, conn)
	if pong.Type != FramePong || pong.Sequence != 1 {
		t.Fatalf("pong = %+v", pong)
	}
}

// newFabricWebSocketTestServer starts an httptest server whose handler upgrades
// each request to a WebSocket and runs loop.RunWebSocket on it with a
// one-second write timeout. The server is closed automatically when the test
// finishes; callers may still Close it earlier, since Close is safe to call
// more than once.
func newFabricWebSocketTestServer(t *testing.T, loop TransportLoop) *httptest.Server {
	t.Helper()
	upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			// Upgrade has already replied to the client with an HTTP error.
			return
		}
		defer conn.Close()
		// The error is deliberately dropped. This handler runs on a hijacked
		// connection that httptest.Server.Close does not wait for, so it can
		// finish after the test has returned, when using t is no longer
		// allowed. Failures are detected by the client-side assertions instead.
		_ = loop.RunWebSocket(r.Context(), conn, WebSocketTransportConfig{WriteTimeout: time.Second})
	}))
	t.Cleanup(server.Close)
	return server
}

// dialFabricWebSocket opens a client WebSocket to the httptest server at
// httpURL. It swaps the "http" scheme prefix for "ws", so https:// becomes
// wss://. The connection is closed automatically when the test finishes.
func dialFabricWebSocket(t *testing.T, httpURL string) *websocket.Conn {
	t.Helper()
	wsURL := "ws" + httpURL[len("http"):]
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		t.Fatalf("dial websocket: %v", err)
	}
	// Registered after the server's cleanup, so it runs first (LIFO): the
	// client hangs up before the server shuts down. Extra Close calls made by
	// callers are harmless.
	t.Cleanup(func() { _ = conn.Close() })
	return conn
}

// writeWebSocketFrame marshals frame and sends it as a single binary
// WebSocket message, failing the test on any error.
func writeWebSocketFrame(t *testing.T, conn *websocket.Conn, frame Frame) {
	t.Helper()
	encoded, err := MarshalFrame(frame)
	if err != nil {
		t.Fatalf("marshal frame: %v", err)
	}
	if err := conn.WriteMessage(websocket.BinaryMessage, encoded); err != nil {
		t.Fatalf("write websocket frame: %v", err)
	}
}

// readWebSocketFrame reads the next WebSocket message and decodes it as a
// Frame. The test fails if no message arrives within webSocketTestWaitTimeout,
// if the message is not binary, or if it does not decode under
// DefaultMaxPayload.
func readWebSocketFrame(t *testing.T, conn *websocket.Conn) Frame {
	t.Helper()
	// A read deadline bounds the wait without spawning a reader goroutine that
	// would stay blocked on the connection after a timeout.
	if err := conn.SetReadDeadline(time.Now().Add(webSocketTestWaitTimeout)); err != nil {
		t.Fatalf("set read deadline: %v", err)
	}
	messageType, payload, err := conn.ReadMessage()
	if err != nil {
		t.Fatalf("read websocket frame: %v", err)
	}
	if messageType != websocket.BinaryMessage {
		t.Fatalf("message type = %d, want binary", messageType)
	}
	frame, err := UnmarshalFrame(payload, DefaultMaxPayload)
	if err != nil {
		t.Fatalf("unmarshal websocket frame: %v", err)
	}
	return frame
}

// waitForWebSocketEvent returns the next event type from events. The test
// fails if none arrives within webSocketTestWaitTimeout.
func waitForWebSocketEvent(t *testing.T, events <-chan SessionEventType) (event SessionEventType) {
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), webSocketTestWaitTimeout)
	defer cancel()
	select {
	case event = <-events:
	case <-ctx.Done():
		t.Fatal("timed out waiting for session event")
	}
	return event
}