95 lines
2.3 KiB
Go
95 lines
2.3 KiB
Go
package main
|
|
|
|
import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"net/url"
	"os"
	"os/signal"
	"strings"
	"time"

	"github.com/gorilla/websocket"
)
|
|
|
|
func main() {
|
|
var (
|
|
gatewayURL = flag.String("gateway-url", "ws://127.0.0.1:8080/api/v1/gateway/ws", "websocket gateway url")
|
|
attachToken = flag.String("attach-token", "", "short-lived attach token")
|
|
duration = flag.Duration("duration", 90*time.Second, "maximum time to keep the client attached")
|
|
heartbeatInterval = flag.Duration("heartbeat-interval", 10*time.Second, "client heartbeat interval")
|
|
exitOnTakenOver = flag.Bool("exit-on-taken-over", true, "exit successfully after receiving session.taken_over")
|
|
)
|
|
flag.Parse()
|
|
|
|
if *attachToken == "" {
|
|
log.Fatal("attach-token is required")
|
|
}
|
|
|
|
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
|
defer cancel()
|
|
if *duration > 0 {
|
|
var timeoutCancel context.CancelFunc
|
|
ctx, timeoutCancel = context.WithTimeout(ctx, *duration)
|
|
defer timeoutCancel()
|
|
}
|
|
|
|
targetURL, err := url.Parse(*gatewayURL)
|
|
if err != nil {
|
|
log.Fatalf("parse gateway url: %v", err)
|
|
}
|
|
query := targetURL.Query()
|
|
query.Set("attach_token", *attachToken)
|
|
targetURL.RawQuery = query.Encode()
|
|
|
|
conn, _, err := websocket.DefaultDialer.DialContext(ctx, targetURL.String(), nil)
|
|
if err != nil {
|
|
log.Fatalf("dial websocket: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
for {
|
|
_, payload, err := conn.ReadMessage()
|
|
if err != nil {
|
|
done <- err
|
|
return
|
|
}
|
|
fmt.Println(string(payload))
|
|
if *exitOnTakenOver && containsTakenOver(payload) {
|
|
done <- nil
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
heartbeatTicker := time.NewTicker(*heartbeatInterval)
|
|
defer heartbeatTicker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case err := <-done:
|
|
if err != nil {
|
|
log.Fatalf("websocket read: %v", err)
|
|
}
|
|
return
|
|
case <-heartbeatTicker.C:
|
|
if err := conn.WriteJSON(map[string]any{"type": "heartbeat"}); err != nil {
|
|
log.Fatalf("heartbeat write: %v", err)
|
|
}
|
|
case <-ctx.Done():
|
|
if err := ctx.Err(); err != nil && err != context.Canceled && err != context.DeadlineExceeded {
|
|
log.Fatalf("context ended: %v", err)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// containsTakenOver reports whether payload is a gateway message whose
// top-level "type" field is "session.taken_over".
//
// The payload is decoded as JSON, so formatting differences (whitespace
// around the colon, escaped characters) do not hide the event. A nested
// object that merely carries the same "type" value does not trigger it.
// Payloads that do not decode into a JSON object (non-JSON text, arrays,
// scalars) fall back to the raw substring match on
// `"type":"session.taken_over"`.
func containsTakenOver(payload []byte) bool {
	var msg struct {
		Type string `json:"type"`
	}
	if err := json.Unmarshal(payload, &msg); err != nil {
		return strings.Contains(string(payload), `"type":"session.taken_over"`)
	}
	return msg.Type == "session.taken_over"
}
|