Files
rdp-proxy/agents/rap-node-agent/internal/fabricproto/frame.go
T
2026-05-16 00:10:04 +03:00

217 lines
5.3 KiB
Go

package fabricproto
import (
"encoding/binary"
"errors"
"fmt"
"io"
)
// Wire-format constants shared by the frame encoder and decoder.
const (
	// Magic is the 32-bit marker stored big-endian in the first four header
	// bytes; its bytes spell "RAPF" in ASCII.
	Magic uint32 = 0x52415046

	// Version is the protocol version written to header byte 4. Headers
	// carrying any other value are rejected when parsed.
	Version uint8 = 1

	// HeaderSize is the fixed length, in bytes, of an encoded frame header.
	HeaderSize = 32

	// DefaultMaxPayload is the payload limit (1 MiB) applied when encoding
	// and whenever a caller passes a non-positive limit.
	DefaultMaxPayload = 1 << 20
)
// FrameType identifies the kind of a fabric frame. It is carried as a single
// byte at offset 5 of the encoded header.
type FrameType uint8

// Frame type wire values. They are spelled out explicitly because they are
// part of the wire format: FrameHello is the lowest known value and
// FrameGoAway the highest, which is what KnownFrameType relies on.
const (
	FrameHello        FrameType = 1
	FrameAuth         FrameType = 2
	FrameSessionReady FrameType = 3
	FrameOpenStream   FrameType = 4
	FrameData         FrameType = 5
	FrameAck          FrameType = 6
	FramePing         FrameType = 7
	FramePong         FrameType = 8
	FrameRouteUpdate  FrameType = 9
	FrameStreamCredit FrameType = 10
	FrameNodePressure FrameType = 11
	FrameCloseStream  FrameType = 12
	FrameResetStream  FrameType = 13
	FrameGoAway       FrameType = 14
)
// TrafficClass labels a frame with a traffic category. It is carried as a
// single byte at offset 8 of the encoded header. The zero value is not one of
// the named classes but is accepted by ValidateFrame.
type TrafficClass uint8

// Traffic class wire values. They are spelled out explicitly because they are
// part of the wire format: TrafficClassControl is the lowest known value and
// TrafficClassDroppable the highest.
const (
	TrafficClassControl     TrafficClass = 1
	TrafficClassDNS         TrafficClass = 2
	TrafficClassInteractive TrafficClass = 3
	TrafficClassReliable    TrafficClass = 4
	TrafficClassBulk        TrafficClass = 5
	TrafficClassDroppable   TrafficClass = 6
)
// Frame is the in-memory form of one fabric frame: the header fields that
// writeHeader encodes and parseHeader decodes, plus the payload bytes.
// All multi-byte header fields are big-endian on the wire.
type Frame struct {
// Type is the frame kind, stored in header byte 5.
Type FrameType
// Flags is a 16-bit field stored in header bytes 6-7. This file copies it
// through without interpreting any bits.
Flags uint16
// TrafficClass is stored in header byte 8. Zero passes validation; any
// other value must be a known traffic class.
TrafficClass TrafficClass
// StreamID is stored in header bytes 12-19. It must be non-zero for the
// frame types listed in requiresStream.
StreamID uint64
// Sequence is stored in header bytes 20-27. This file does not validate it.
Sequence uint64
// Payload is the frame body. Its length is stored in header bytes 28-31
// and may not exceed the applicable maximum payload size.
Payload []byte
}
// Sentinel errors reported by the frame codec. ErrInvalidPayloadLen is always
// returned wrapped with extra detail, so match errors with errors.Is.
var (
// ErrInvalidMagic is returned when the first four header bytes are not Magic.
ErrInvalidMagic = errors.New("invalid fabric frame magic")
// ErrUnsupportedVer is returned when header byte 4 differs from Version.
ErrUnsupportedVer = errors.New("unsupported fabric frame version")
// ErrUnknownFrameType is returned when the frame type fails KnownFrameType.
ErrUnknownFrameType = errors.New("unknown fabric frame type")
// ErrInvalidStreamID is returned when a frame type that requires a stream
// carries a zero StreamID.
ErrInvalidStreamID = errors.New("invalid fabric frame stream id")
// ErrInvalidPayloadLen is returned when a payload exceeds the allowed
// maximum, or when UnmarshalFrame finds that the header's length does not
// match the bytes supplied.
ErrInvalidPayloadLen = errors.New("invalid fabric frame payload length")
// ErrUnknownTraffic is returned when a non-zero traffic class fails
// KnownTrafficClass.
ErrUnknownTraffic = errors.New("unknown fabric traffic class")
)
// MarshalFrame encodes frame into a newly allocated byte slice holding the
// fixed-size header immediately followed by the payload.
//
// The frame is first checked with ValidateFrame against DefaultMaxPayload; a
// validation error is returned as-is together with a nil slice.
func MarshalFrame(frame Frame) ([]byte, error) {
	err := ValidateFrame(frame, DefaultMaxPayload)
	if err != nil {
		return nil, err
	}

	payloadLen := len(frame.Payload)
	// Start with a zeroed header and reserve exactly enough capacity for the
	// payload, so the append below never reallocates.
	encoded := make([]byte, HeaderSize, HeaderSize+payloadLen)
	writeHeader(encoded, frame, uint32(payloadLen))
	return append(encoded, frame.Payload...), nil
}
// WriteFrame validates frame against DefaultMaxPayload and writes it to w:
// first the encoded header, then the payload when there is one.
//
// Header and payload are written separately, so w receives at least two Write
// calls for a frame that has a payload. The first validation or write error
// is returned; a failure after the header was written leaves a partial frame
// on w.
func WriteFrame(w io.Writer, frame Frame) error {
	err := ValidateFrame(frame, DefaultMaxPayload)
	if err != nil {
		return err
	}

	var header [HeaderSize]byte
	writeHeader(header[:], frame, uint32(len(frame.Payload)))
	if err = writeFull(w, header[:]); err != nil {
		return err
	}

	if len(frame.Payload) > 0 {
		return writeFull(w, frame.Payload)
	}
	return nil
}
// ReadFrame reads exactly one frame from r.
//
// maxPayload bounds the payload length the header may announce; a value of
// zero or less selects DefaultMaxPayload. The payload buffer is allocated
// only after the header has been parsed and its length checked.
//
// Errors:
//   - io.EOF when r ends before the first header byte, i.e. a clean end of
//     stream between frames.
//   - io.ErrUnexpectedEOF when r ends anywhere inside a frame, including
//     directly after a complete header that announces a non-empty payload.
//   - any error from parseHeader for a malformed or oversized header.
//   - any other error returned by r.
func ReadFrame(r io.Reader, maxPayload int) (Frame, error) {
	if maxPayload <= 0 {
		maxPayload = DefaultMaxPayload
	}

	var header [HeaderSize]byte
	if _, err := io.ReadFull(r, header[:]); err != nil {
		return Frame{}, err
	}

	frame, payloadLength, err := parseHeader(header[:], maxPayload)
	if err != nil {
		return Frame{}, err
	}
	if payloadLength == 0 {
		return frame, nil
	}

	frame.Payload = make([]byte, payloadLength)
	if _, err := io.ReadFull(r, frame.Payload); err != nil {
		// io.ReadFull reports a bare io.EOF when not a single payload byte
		// could be read. The header is already consumed at this point, so the
		// frame is truncated; do not let it look like a clean end of stream.
		if errors.Is(err, io.EOF) {
			return Frame{}, io.ErrUnexpectedEOF
		}
		return Frame{}, err
	}
	return frame, nil
}
// UnmarshalFrame decodes a single frame from data, which must contain exactly
// one header followed by exactly the number of payload bytes the header
// announces.
//
// maxPayload bounds the announced payload length; a value of zero or less
// selects DefaultMaxPayload. The returned payload is a copy and does not
// alias data.
//
// It returns io.ErrUnexpectedEOF when data is shorter than a header, any
// error from parseHeader, or an error wrapping ErrInvalidPayloadLen when the
// announced and actual payload lengths differ.
func UnmarshalFrame(data []byte, maxPayload int) (Frame, error) {
	if len(data) < HeaderSize {
		return Frame{}, io.ErrUnexpectedEOF
	}

	limit := maxPayload
	if limit <= 0 {
		limit = DefaultMaxPayload
	}

	head, body := data[:HeaderSize], data[HeaderSize:]
	frame, declared, err := parseHeader(head, limit)
	if err != nil {
		return Frame{}, err
	}

	if actual := len(body); actual != declared {
		return Frame{}, fmt.Errorf("%w: header=%d actual=%d", ErrInvalidPayloadLen, declared, actual)
	}

	if declared > 0 {
		// Copy so the frame stays valid if the caller reuses data.
		frame.Payload = append([]byte(nil), body...)
	}
	return frame, nil
}
// ValidateFrame checks that frame is well-formed. maxPayload is the payload
// limit; a value of zero or less selects DefaultMaxPayload.
//
// Checks run in this order and the first failure is returned:
//  1. the frame type must be known (ErrUnknownFrameType);
//  2. the payload must not exceed the limit (wraps ErrInvalidPayloadLen);
//  3. stream-bound frame types need a non-zero StreamID (ErrInvalidStreamID);
//  4. a non-zero traffic class must be known (ErrUnknownTraffic).
func ValidateFrame(frame Frame, maxPayload int) error {
	limit := maxPayload
	if limit <= 0 {
		limit = DefaultMaxPayload
	}

	payloadLen := len(frame.Payload)
	switch {
	case !KnownFrameType(frame.Type):
		return ErrUnknownFrameType
	case payloadLen > limit:
		return fmt.Errorf("%w: %d > %d", ErrInvalidPayloadLen, payloadLen, limit)
	case frame.StreamID == 0 && requiresStream(frame.Type):
		return ErrInvalidStreamID
	case frame.TrafficClass != 0 && !KnownTrafficClass(frame.TrafficClass):
		return ErrUnknownTraffic
	}
	return nil
}
// KnownFrameType reports whether frameType lies in the range of defined frame
// types, from FrameHello through FrameGoAway inclusive.
func KnownFrameType(frameType FrameType) bool {
	if frameType < FrameHello || frameType > FrameGoAway {
		return false
	}
	return true
}
// KnownTrafficClass reports whether trafficClass is one of the named traffic
// classes. The zero value is not a named class, so it yields false.
func KnownTrafficClass(trafficClass TrafficClass) bool {
	switch trafficClass {
	case TrafficClassControl,
		TrafficClassDNS,
		TrafficClassInteractive,
		TrafficClassReliable,
		TrafficClassBulk,
		TrafficClassDroppable:
		return true
	}
	return false
}
// requiresStream reports whether frames of frameType must carry a non-zero
// StreamID. That is the case for open, data, ack, credit, close and reset
// frames; every other type yields false.
func requiresStream(frameType FrameType) bool {
	return frameType == FrameOpenStream ||
		frameType == FrameData ||
		frameType == FrameAck ||
		frameType == FrameStreamCredit ||
		frameType == FrameCloseStream ||
		frameType == FrameResetStream
}
// writeHeader encodes the header fields of frame, together with
// payloadLength, into header using big-endian byte order.
//
// header must be at least 32 bytes long or the function panics. Bytes 9
// through 11 are not written, so they keep whatever the caller's buffer
// already holds.
func writeHeader(header []byte, frame Frame, payloadLength uint32) {
	// Byte offsets of the header fields.
	const (
		offMagic    = 0
		offVersion  = 4
		offType     = 5
		offFlags    = 6
		offClass    = 8
		offStreamID = 12
		offSequence = 20
		offLength   = 28
	)

	be := binary.BigEndian
	be.PutUint32(header[offMagic:offMagic+4], Magic)
	header[offVersion] = Version
	header[offType] = byte(frame.Type)
	be.PutUint16(header[offFlags:offFlags+2], frame.Flags)
	header[offClass] = byte(frame.TrafficClass)
	be.PutUint64(header[offStreamID:offStreamID+8], frame.StreamID)
	be.PutUint64(header[offSequence:offSequence+8], frame.Sequence)
	be.PutUint32(header[offLength:offLength+4], payloadLength)
}
// parseHeader decodes and validates an encoded frame header.
//
// It returns the frame with its header fields populated (Payload is left
// nil) and the payload length announced by the header. maxPayload is the
// largest payload length accepted; callers normalise it to a positive value.
//
// Errors:
//   - io.ErrUnexpectedEOF when header is shorter than HeaderSize;
//   - ErrInvalidMagic or ErrUnsupportedVer for a bad magic or version;
//   - an error wrapping ErrInvalidPayloadLen when the announced payload
//     length exceeds maxPayload;
//   - any error from ValidateFrame for the decoded fields.
//
// Bytes 9 through 11 of the header are ignored.
func parseHeader(header []byte, maxPayload int) (Frame, int, error) {
	// Guard the fixed-offset reads below so a short slice yields an error
	// instead of an index-out-of-range panic.
	if len(header) < HeaderSize {
		return Frame{}, 0, io.ErrUnexpectedEOF
	}
	if binary.BigEndian.Uint32(header[0:4]) != Magic {
		return Frame{}, 0, ErrInvalidMagic
	}
	if header[4] != Version {
		return Frame{}, 0, ErrUnsupportedVer
	}

	frame := Frame{
		Type:         FrameType(header[5]),
		Flags:        binary.BigEndian.Uint16(header[6:8]),
		TrafficClass: TrafficClass(header[8]),
		StreamID:     binary.BigEndian.Uint64(header[12:20]),
		Sequence:     binary.BigEndian.Uint64(header[20:28]),
	}

	// Compare in int64 before narrowing to int. Where int is 32 bits wide, a
	// wire value of 2^31 or more would otherwise convert to a negative
	// number, pass the limit check, and make the caller's payload allocation
	// panic.
	wireLength := binary.BigEndian.Uint32(header[28:32])
	if int64(wireLength) > int64(maxPayload) {
		return Frame{}, 0, fmt.Errorf("%w: %d > %d", ErrInvalidPayloadLen, wireLength, maxPayload)
	}
	payloadLength := int(wireLength)

	if err := ValidateFrame(frame, maxPayload); err != nil {
		return Frame{}, 0, err
	}
	return frame, payloadLength, nil
}
// writeFull writes all of data to w, calling Write again with the unwritten
// remainder until every byte has been accepted. Empty data results in no
// Write call at all.
//
// It returns the first error reported by w. If w reports no error but also
// no progress (n <= 0) it returns io.ErrShortWrite. If w claims to have
// written more bytes than it was given, an error is returned instead of
// slicing past the end of the buffer.
func writeFull(w io.Writer, data []byte) error {
	for len(data) > 0 {
		n, err := w.Write(data)
		if err != nil {
			return err
		}
		if n <= 0 {
			return io.ErrShortWrite
		}
		// A writer that breaks the io.Writer contract (0 <= n <= len(p)) must
		// not be able to trigger a slice-bounds panic here.
		if n > len(data) {
			return fmt.Errorf("invalid write result: %d bytes reported for a %d-byte buffer", n, len(data))
		}
		data = data[n:]
	}
	return nil
}