Files
rdp-proxy/agents/rap-node-agent/internal/mesh/client.go
T
2026-05-16 00:40:59 +03:00

405 lines
10 KiB
Go

package mesh
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/example/remote-access-platform/agents/rap-node-agent/internal/fabricproto"
"github.com/gorilla/websocket"
)
// Client talks to a mesh peer: it POSTs JSON to the peer's /mesh/v1 HTTP
// endpoints and opens fabric-session WebSocket connections to the same peer.
type Client struct {
// BaseURL is the peer's root URL; endpoint paths such as "/mesh/v1/health"
// are appended to it. The WebSocket helpers accept http, https, ws and wss
// schemes and rewrite http/https to ws/wss.
BaseURL string
// HTTPClient is used for the JSON endpoints; when nil, http.DefaultClient is
// used instead.
HTTPClient *http.Client
}
// FabricSessionDialOptions configures DialFabricSession and OpenFabricSession.
type FabricSessionDialOptions struct {
// Token, when non-blank after trimming whitespace, is sent as the
// X-RAP-Fabric-Session-Token handshake header, replacing any value for that
// header supplied in Header.
Token string
// Header holds extra handshake headers; it is copied, never mutated.
Header http.Header
// Dialer, when non-nil, is used as-is (Timeout does not alter it). When nil,
// a copy of websocket.DefaultDialer is used.
Dialer *websocket.Dialer
// Timeout, when positive, replaces the default dialer's HandshakeTimeout and
// becomes the per-read/per-write deadline of the opened session whenever a
// call's context carries no deadline of its own.
Timeout time.Duration
// MaxPayload is the payload limit handed to fabricproto.UnmarshalFrame when
// reading frames; zero or negative selects fabricproto.DefaultMaxPayload.
MaxPayload int
}
// FabricSessionClient is one open fabric-session WebSocket that carries
// fabricproto frames as binary messages.
type FabricSessionClient struct {
// conn is the underlying WebSocket connection.
conn *websocket.Conn
// timeout is the fallback per-operation deadline used when a call's context
// has no deadline; zero or negative means no fallback.
timeout time.Duration
// maxPayload is passed to fabricproto.UnmarshalFrame on every read.
maxPayload int
// readMu and writeMu serialize readers and writers independently, so one
// ReadFrame and one WriteFrame may run at the same time.
readMu sync.Mutex
writeMu sync.Mutex
}
// FabricSessionPumpOptions sizes the channels created by StartPump. Any value
// that is zero or negative is replaced by StartPump's default (64 outbound,
// 64 inbound, 8 errors).
type FabricSessionPumpOptions struct {
// OutboundBuffer is the capacity of the queue that Send feeds.
OutboundBuffer int
// InboundBuffer is the capacity of the channel returned by Frames.
InboundBuffer int
// ErrorBuffer is the capacity of the channel returned by Errors; errors
// reported while it is full are dropped.
ErrorBuffer int
}
// FabricSessionPump drives a FabricSessionClient with one writer goroutine and
// one reader goroutine (started by StartPump), exchanging frames over channels.
type FabricSessionPump struct {
// session is the connection being pumped; Close closes it.
session *FabricSessionClient
// outbound buffers frames queued by Send for the writer goroutine.
outbound chan fabricproto.Frame
// inbound carries frames read from the connection and is exposed via Frames.
// The pump code never closes it.
inbound chan fabricproto.Frame
// errors receives I/O and context errors and is exposed via Errors. Reports
// made while it is full are dropped; the pump code never closes it.
errors chan error
// done is closed exactly once, by Close, to stop both goroutines and Send.
done chan struct{}
// cancel cancels the context the two goroutines run under.
cancel context.CancelFunc
// closeMu makes Close run its body only once (it is a sync.Once, not a mutex).
closeMu sync.Once
}
// NewClient returns a Client for the mesh peer at baseURL. The client gets its
// own http.Client with a 5-second overall request timeout, which bounds the
// JSON endpoint calls in addition to any context deadline.
func NewClient(baseURL string) Client {
	const requestTimeout = 5 * time.Second
	httpClient := &http.Client{Timeout: requestTimeout}
	return Client{HTTPClient: httpClient, BaseURL: baseURL}
}
// SendHealth POSTs message as JSON to <BaseURL>/mesh/v1/health and decodes the
// peer's HealthAck from the response body.
//
// BaseURL is trimmed of surrounding whitespace and trailing slashes before the
// path is appended, matching fabricSessionWebSocketURL. As a result,
// "https://peer/" and "https://peer" address the same endpoint rather than the
// former yielding a "//mesh/v1/health" request path.
//
// A non-2xx status is returned as an error. When c.HTTPClient is nil,
// http.DefaultClient is used; it has no timeout of its own, so ctx is then the
// only bound on the request.
func (c Client) SendHealth(ctx context.Context, message HealthMessage) (HealthAck, error) {
	payload, err := json.Marshal(message)
	if err != nil {
		return HealthAck{}, fmt.Errorf("encode mesh health message: %w", err)
	}
	endpoint := strings.TrimRight(strings.TrimSpace(c.BaseURL), "/") + "/mesh/v1/health"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return HealthAck{}, err
	}
	req.Header.Set("Content-Type", "application/json")
	httpClient := c.HTTPClient
	if httpClient == nil {
		httpClient = http.DefaultClient
	}
	resp, err := httpClient.Do(req)
	if err != nil {
		return HealthAck{}, err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return HealthAck{}, fmt.Errorf("mesh health rejected with status %d", resp.StatusCode)
	}
	var ack HealthAck
	if err := json.NewDecoder(resp.Body).Decode(&ack); err != nil {
		return HealthAck{}, fmt.Errorf("decode mesh health ack: %w", err)
	}
	return ack, nil
}
// SendSynthetic POSTs envelope as JSON to <BaseURL>/mesh/v1/synthetic/probe
// and decodes the peer's reply, which uses the same SyntheticEnvelope shape.
//
// BaseURL is trimmed of surrounding whitespace and trailing slashes before the
// path is appended, matching fabricSessionWebSocketURL, so a trailing "/" does
// not produce a "//mesh/..." request path.
//
// A non-2xx status is returned as an error. When c.HTTPClient is nil,
// http.DefaultClient (no timeout of its own) is used and ctx bounds the call.
func (c Client) SendSynthetic(ctx context.Context, envelope SyntheticEnvelope) (SyntheticEnvelope, error) {
	payload, err := json.Marshal(envelope)
	if err != nil {
		return SyntheticEnvelope{}, fmt.Errorf("encode mesh synthetic envelope: %w", err)
	}
	endpoint := strings.TrimRight(strings.TrimSpace(c.BaseURL), "/") + "/mesh/v1/synthetic/probe"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return SyntheticEnvelope{}, err
	}
	req.Header.Set("Content-Type", "application/json")
	httpClient := c.HTTPClient
	if httpClient == nil {
		httpClient = http.DefaultClient
	}
	resp, err := httpClient.Do(req)
	if err != nil {
		return SyntheticEnvelope{}, err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return SyntheticEnvelope{}, fmt.Errorf("mesh synthetic probe rejected with status %d", resp.StatusCode)
	}
	var ack SyntheticEnvelope
	if err := json.NewDecoder(resp.Body).Decode(&ack); err != nil {
		return SyntheticEnvelope{}, fmt.Errorf("decode mesh synthetic reply: %w", err)
	}
	return ack, nil
}
// SendProduction POSTs envelope as JSON to <BaseURL>/mesh/v1/forward and
// decodes the peer's ProductionForwardResult from the response body.
//
// BaseURL is trimmed of surrounding whitespace and trailing slashes before the
// path is appended, matching fabricSessionWebSocketURL, so a trailing "/" does
// not produce a "//mesh/..." request path.
//
// A non-2xx status is returned as an error. When c.HTTPClient is nil,
// http.DefaultClient (no timeout of its own) is used and ctx bounds the call.
func (c Client) SendProduction(ctx context.Context, envelope ProductionEnvelope) (ProductionForwardResult, error) {
	payload, err := json.Marshal(envelope)
	if err != nil {
		return ProductionForwardResult{}, fmt.Errorf("encode mesh production envelope: %w", err)
	}
	endpoint := strings.TrimRight(strings.TrimSpace(c.BaseURL), "/") + "/mesh/v1/forward"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return ProductionForwardResult{}, err
	}
	req.Header.Set("Content-Type", "application/json")
	httpClient := c.HTTPClient
	if httpClient == nil {
		httpClient = http.DefaultClient
	}
	resp, err := httpClient.Do(req)
	if err != nil {
		return ProductionForwardResult{}, err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return ProductionForwardResult{}, fmt.Errorf("mesh production forward rejected with status %d", resp.StatusCode)
	}
	var result ProductionForwardResult
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return ProductionForwardResult{}, fmt.Errorf("decode mesh production forward result: %w", err)
	}
	return result, nil
}
// DialFabricSession performs the WebSocket handshake for a fabric session at
// <BaseURL>/mesh/v1/fabric/session/ws. It returns the raw connection together
// with the handshake response.
//
// The handshake carries a copy of opts.Header, plus the trimmed opts.Token in
// X-RAP-Fabric-Session-Token when the token is not blank. opts.Dialer is used
// when provided. Otherwise a copy of websocket.DefaultDialer is used, with its
// HandshakeTimeout replaced by opts.Timeout when that is positive.
func (c Client) DialFabricSession(ctx context.Context, opts FabricSessionDialOptions) (*websocket.Conn, *http.Response, error) {
	wsURL, urlErr := c.fabricSessionWebSocketURL()
	if urlErr != nil {
		return nil, nil, urlErr
	}

	handshakeHeader := cloneHeader(opts.Header)
	if token := strings.TrimSpace(opts.Token); token != "" {
		handshakeHeader.Set("X-RAP-Fabric-Session-Token", token)
	}

	dialer := opts.Dialer
	if dialer == nil {
		// Work on a copy so the timeout override never leaks into other users
		// of the shared package-level default dialer.
		fallback := *websocket.DefaultDialer
		if opts.Timeout > 0 {
			fallback.HandshakeTimeout = opts.Timeout
		}
		dialer = &fallback
	}
	return dialer.DialContext(ctx, wsURL, handshakeHeader)
}
// OpenFabricSession dials a fabric session (see DialFabricSession) and wraps
// the connection in a FabricSessionClient.
//
// When the dial fails and the dialer also returned an HTTP response (e.g. a
// rejected handshake), the error is wrapped with that response's status code
// and the response is returned too. A non-positive opts.MaxPayload falls back
// to fabricproto.DefaultMaxPayload.
func (c Client) OpenFabricSession(ctx context.Context, opts FabricSessionDialOptions) (*FabricSessionClient, *http.Response, error) {
	conn, resp, err := c.DialFabricSession(ctx, opts)
	switch {
	case err != nil && resp != nil:
		return nil, resp, fmt.Errorf("fabric session websocket rejected with status %d: %w", resp.StatusCode, err)
	case err != nil:
		return nil, nil, err
	}

	session := &FabricSessionClient{
		conn:       conn,
		timeout:    opts.Timeout,
		maxPayload: fabricproto.DefaultMaxPayload,
	}
	if opts.MaxPayload > 0 {
		session.maxPayload = opts.MaxPayload
	}
	return session, resp, nil
}
// SendFabricSessionFrame is a one-shot helper. It opens a new fabric session,
// sends frame, waits for a single reply frame, and closes the session before
// returning. Any error from closing the session is ignored.
func (c Client) SendFabricSessionFrame(ctx context.Context, opts FabricSessionDialOptions, frame fabricproto.Frame) (fabricproto.Frame, error) {
	session, _, err := c.OpenFabricSession(ctx, opts)
	if err != nil {
		return fabricproto.Frame{}, err
	}
	defer func() { _ = session.Close() }()

	reply, err := session.RoundTrip(ctx, frame)
	return reply, err
}
// Close closes the underlying WebSocket connection. It is a no-op returning
// nil for a nil client or a client without a connection. Close does not clear
// c.conn, so later reads and writes reach the closed connection and fail there.
func (c *FabricSessionClient) Close() error {
	if c != nil && c.conn != nil {
		return c.conn.Close()
	}
	return nil
}
// WriteFrame encodes frame with fabricproto.MarshalFrame and sends it as one
// binary WebSocket message.
//
// Only a nil client or nil connection is reported as "closed"; after Close the
// write fails in the WebSocket layer instead. ctx contributes only its
// deadline (see applyWriteDeadline); cancellation alone does not interrupt a
// blocked write.
func (c *FabricSessionClient) WriteFrame(ctx context.Context, frame fabricproto.Frame) error {
	if c == nil || c.conn == nil {
		return fmt.Errorf("fabric session client is closed")
	}
	encoded, encodeErr := fabricproto.MarshalFrame(frame)
	if encodeErr != nil {
		return encodeErr
	}

	// Encoding is done before taking the lock. Only the deadline update and
	// the write itself must be exclusive, because the WebSocket connection
	// supports just one concurrent writer.
	c.writeMu.Lock()
	defer c.writeMu.Unlock()
	c.applyWriteDeadline(ctx)
	return c.conn.WriteMessage(websocket.BinaryMessage, encoded)
}
// ReadFrame reads one WebSocket message and decodes it with
// fabricproto.UnmarshalFrame, using c.maxPayload as the limit. Any message
// type other than binary is rejected with an error.
//
// Only a nil client or nil connection is reported as "closed". ctx contributes
// only its deadline (see applyReadDeadline); cancellation without a deadline
// does not interrupt a blocked read.
func (c *FabricSessionClient) ReadFrame(ctx context.Context) (fabricproto.Frame, error) {
	if c == nil || c.conn == nil {
		return fabricproto.Frame{}, fmt.Errorf("fabric session client is closed")
	}

	c.readMu.Lock()
	defer c.readMu.Unlock()
	c.applyReadDeadline(ctx)

	kind, raw, readErr := c.conn.ReadMessage()
	switch {
	case readErr != nil:
		return fabricproto.Frame{}, readErr
	case kind != websocket.BinaryMessage:
		return fabricproto.Frame{}, fmt.Errorf("fabric session websocket returned non-binary message type %d", kind)
	}
	return fabricproto.UnmarshalFrame(raw, c.maxPayload)
}
// RoundTrip writes frame and then reads the next frame from the session.
//
// NOTE: this is not atomic. The write lock and the read lock are taken
// separately, so concurrent RoundTrip or ReadFrame callers on the same session
// may receive each other's replies. Serialize callers if request/response
// pairing matters.
func (c *FabricSessionClient) RoundTrip(ctx context.Context, frame fabricproto.Frame) (fabricproto.Frame, error) {
	var reply fabricproto.Frame
	err := c.WriteFrame(ctx, frame)
	if err == nil {
		reply, err = c.ReadFrame(ctx)
	}
	return reply, err
}
// StartPump starts a writer goroutine and a reader goroutine over c and returns
// the pump that controls them. Non-positive buffer sizes in opts default to
// 64 (outbound), 64 (inbound) and 8 (errors).
//
// Both goroutines run under a context derived from ctx. When either one exits,
// it calls the pump's Close, which also closes c, so the pump owns the session
// from this point on.
func (c *FabricSessionClient) StartPump(ctx context.Context, opts FabricSessionPumpOptions) *FabricSessionPump {
	orDefault := func(size, fallback int) int {
		if size <= 0 {
			return fallback
		}
		return size
	}

	pumpCtx, cancel := context.WithCancel(ctx)
	pump := &FabricSessionPump{
		session:  c,
		outbound: make(chan fabricproto.Frame, orDefault(opts.OutboundBuffer, 64)),
		inbound:  make(chan fabricproto.Frame, orDefault(opts.InboundBuffer, 64)),
		errors:   make(chan error, orDefault(opts.ErrorBuffer, 8)),
		done:     make(chan struct{}),
		cancel:   cancel,
	}
	go pump.writeLoop(pumpCtx)
	go pump.readLoop(pumpCtx)
	return pump
}
// Send queues frame for the pump's writer goroutine. It returns one of:
//   - nil once the frame is buffered (not once it is written to the
//     connection);
//   - ctx.Err() if ctx is done first;
//   - an error if the pump has been closed.
//
// Frames still buffered when the pump closes are discarded by the writer.
//
// Cancellation and shutdown are checked before the enqueue is attempted. When
// several select cases are ready, Go picks one at random. Without these
// pre-checks, a Send on an already-closed pump, or with an already-cancelled
// ctx, could report success for a frame that will never be written whenever
// the buffer had room.
func (p *FabricSessionPump) Send(ctx context.Context, frame fabricproto.Frame) error {
	if p == nil {
		return fmt.Errorf("fabric session pump is nil")
	}
	if err := ctx.Err(); err != nil {
		return err
	}
	select {
	case <-p.done:
		return fmt.Errorf("fabric session pump is closed")
	default:
	}
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-p.done:
		return fmt.Errorf("fabric session pump is closed")
	case p.outbound <- frame:
		return nil
	}
}
// Frames returns the receive-only channel of frames read from the session. It
// returns nil for a nil pump. The pump never closes this channel, so watch
// Errors (or your own shutdown signal) to learn when reading has stopped.
func (p *FabricSessionPump) Frames() <-chan fabricproto.Frame {
	var ch <-chan fabricproto.Frame
	if p != nil {
		ch = p.inbound
	}
	return ch
}
// Errors returns the receive-only channel on which the pump reports I/O and
// context errors. It returns nil for a nil pump. Reports are dropped while the
// channel's buffer is full.
func (p *FabricSessionPump) Errors() <-chan error {
	var ch <-chan error
	if p != nil {
		ch = p.errors
	}
	return ch
}
// Close stops the pump: it closes done, cancels the goroutines' context and
// closes the underlying session. Only the first call does this work and
// returns the session's close error; later calls return nil. A nil pump is a
// no-op.
func (p *FabricSessionPump) Close() error {
	if p == nil {
		return nil
	}
	var closeErr error
	p.closeMu.Do(func() {
		close(p.done)
		p.cancel()
		closeErr = p.session.Close()
	})
	return closeErr
}
// writeLoop drains the outbound queue into the session, one frame at a time.
// It exits when:
//   - the pump is closed (silently);
//   - ctx is done (reporting ctx.Err());
//   - a write fails (reporting the error).
//
// On exit it closes the whole pump, which also unblocks readLoop by closing
// the connection.
func (p *FabricSessionPump) writeLoop(ctx context.Context) {
	defer p.Close()
	for {
		var frame fabricproto.Frame
		select {
		case <-p.done:
			return
		case <-ctx.Done():
			p.reportError(ctx.Err())
			return
		case frame = <-p.outbound:
		}
		if err := p.session.WriteFrame(ctx, frame); err != nil {
			p.reportError(err)
			return
		}
	}
}
// readLoop reads frames from the session and delivers them on inbound. It
// exits when:
//   - a read fails (reporting the error; this includes the connection being
//     closed by Close);
//   - the pump is closed while a frame is waiting for delivery;
//   - ctx is done (reporting ctx.Err()).
//
// On exit it closes the whole pump. Every read uses ctx for its deadline, so a
// session timeout also acts as an idle timeout for the pump.
func (p *FabricSessionPump) readLoop(ctx context.Context) {
	defer p.Close()
	for {
		frame, err := p.session.ReadFrame(ctx)
		if err != nil {
			p.reportError(err)
			return
		}
		select {
		case p.inbound <- frame:
			// Delivered; go back to reading.
		case <-p.done:
			return
		case <-ctx.Done():
			p.reportError(ctx.Err())
			return
		}
	}
}
// reportError publishes a non-nil err on the pump's error channel without
// blocking. When the buffer is full the error is dropped, so a consumer that
// is not listening can never stall the I/O goroutines.
func (p *FabricSessionPump) reportError(err error) {
	if err != nil {
		select {
		case p.errors <- err:
		default:
			// Buffer full: drop the report rather than block.
		}
	}
}
// applyReadDeadline sets the deadline for the next read on c.conn. Precedence:
//  1. the context's deadline, if it has one;
//  2. otherwise now + c.timeout, when a positive timeout is configured;
//  3. otherwise no deadline.
//
// The last case explicitly clears the deadline. Deadlines persist on the
// connection across calls, so without clearing, a read made with a
// deadline-free context would inherit a stale (possibly already expired)
// deadline left by an earlier call.
//
// The SetReadDeadline error is deliberately ignored; any underlying connection
// problem is expected to surface on the read that follows.
func (c *FabricSessionClient) applyReadDeadline(ctx context.Context) {
	var deadline time.Time // zero value means "no deadline"
	if ctxDeadline, ok := ctx.Deadline(); ok {
		deadline = ctxDeadline
	} else if c.timeout > 0 {
		deadline = time.Now().Add(c.timeout)
	}
	_ = c.conn.SetReadDeadline(deadline)
}
// applyWriteDeadline sets the deadline for the next write on c.conn.
// Precedence:
//  1. the context's deadline, if it has one;
//  2. otherwise now + c.timeout, when a positive timeout is configured;
//  3. otherwise no deadline.
//
// The last case explicitly clears the deadline. Deadlines persist on the
// connection across calls, so without clearing, a write made with a
// deadline-free context would inherit a stale (possibly already expired)
// deadline left by an earlier call.
//
// The SetWriteDeadline error is deliberately ignored; any underlying
// connection problem is expected to surface on the write that follows.
func (c *FabricSessionClient) applyWriteDeadline(ctx context.Context) {
	var deadline time.Time // zero value means "no deadline"
	if ctxDeadline, ok := ctx.Deadline(); ok {
		deadline = ctxDeadline
	} else if c.timeout > 0 {
		deadline = time.Now().Add(c.timeout)
	}
	_ = c.conn.SetWriteDeadline(deadline)
}
// fabricSessionWebSocketURL derives the fabric-session WebSocket endpoint from
// BaseURL:
//   - surrounding whitespace is trimmed;
//   - http/https become ws/wss, and ws/wss are kept;
//   - "/mesh/v1/fabric/session/ws" is appended to the path, with trailing
//     slashes removed first;
//   - any query or fragment is dropped.
//
// An empty BaseURL, an unparsable URL, or any other scheme is an error.
func (c Client) fabricSessionWebSocketURL() (string, error) {
	raw := strings.TrimSpace(c.BaseURL)
	if raw == "" {
		return "", fmt.Errorf("mesh base url is required")
	}
	u, err := url.Parse(raw)
	if err != nil {
		return "", err
	}

	switch u.Scheme {
	case "http", "ws":
		u.Scheme = "ws"
	case "https", "wss":
		u.Scheme = "wss"
	default:
		return "", fmt.Errorf("unsupported mesh base url scheme %q", u.Scheme)
	}

	u.Path = strings.TrimRight(u.Path, "/") + "/mesh/v1/fabric/session/ws"
	u.RawQuery, u.Fragment = "", ""
	return u.String(), nil
}
// cloneHeader returns an independent copy of header that is always non-nil,
// even for a nil input, so callers can Set on it safely. Values are added with
// Header.Add, so keys come out in canonical MIME form.
func cloneHeader(header http.Header) http.Header {
	cloned := make(http.Header, len(header))
	for name, values := range header {
		for i := range values {
			cloned.Add(name, values[i])
		}
	}
	return cloned
}