Files
rdp-proxy/agents/rap-node-agent/internal/webingress/receiver.go
T

220 lines
7.0 KiB
Go

package webingress
import (
	"context"
	"crypto/ed25519"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"sort"
	"strings"
	"time"
)
// Schema identifiers exchanged over the web-ingress fabric service channel.
const (
	// SignedFabricServiceChannelEnvelopeSchema is the schema_version a signed
	// envelope must declare for FabricRuntimeReceiver to accept it.
	SignedFabricServiceChannelEnvelopeSchema = "rap.web_ingress.signed_fabric_service_channel_envelope.v1"

	// FabricRuntimeResponseSchema is the schema_version written into every
	// encoded runtime response.
	FabricRuntimeResponseSchema = "rap.web_ingress.fabric_runtime_response.v1"
)

// Sentinel errors produced by FabricRuntimeReceiver, frequently wrapped with
// additional context via %w; match them with errors.Is.
var (
	// ErrFabricEnvelopeSignatureInvalid reports a malformed or unverifiable
	// envelope: bad JSON, schema/contract mismatch, unsupported algorithm,
	// undecodable data, or a signature that does not verify.
	ErrFabricEnvelopeSignatureInvalid = errors.New("web ingress fabric envelope signature invalid")

	// ErrFabricEnvelopeUnauthorized reports an envelope rejected by policy:
	// untrusted signing key, scope or service-class mismatch, or a timestamp
	// outside the allowed clock skew.
	ErrFabricEnvelopeUnauthorized = errors.New("web ingress fabric envelope unauthorized")

	// ErrFabricEnvelopeRuntimeRequired reports a receiver with no Handler.
	ErrFabricEnvelopeRuntimeRequired = errors.New("web ingress fabric runtime handler required")
)
// EnvelopeKeyResolver maps a signing key ID to its ed25519 public key.
//
// FabricRuntimeReceiver treats ok == false, or a key whose length is not
// ed25519.PublicKeySize, as an untrusted signing key, and returns a non-nil
// error from PublicKey to its caller unchanged.
type EnvelopeKeyResolver interface {
	PublicKey(ctx context.Context, keyID string) (ed25519.PublicKey, bool, error)
}

// EnvelopeRuntimeHandler serves the request carried by an envelope.
// FabricRuntimeReceiver only calls it after the envelope has been verified.
type EnvelopeRuntimeHandler interface {
	HandleFabricRequest(ctx context.Context, request FabricRequest) (FabricResponse, error)
}

// RuntimeHandlerFunc lets an ordinary function act as an
// EnvelopeRuntimeHandler, in the manner of http.HandlerFunc.
type RuntimeHandlerFunc func(ctx context.Context, request FabricRequest) (FabricResponse, error)

// HandleFabricRequest calls fn(ctx, request) and returns its results.
func (fn RuntimeHandlerFunc) HandleFabricRequest(ctx context.Context, request FabricRequest) (FabricResponse, error) {
	return fn(ctx, request)
}
// ReceiverConfig is the acceptance policy a FabricRuntimeReceiver applies to
// envelopes. All fields are optional: the zero value accepts any scope and
// any service class and uses the default clock skew.
type ReceiverConfig struct {
	// ServiceType names the service type this receiver is configured for.
	// NOTE(review): confirm where it is compared against the envelope's
	// service type.
	ServiceType string

	// Scope, when non-blank, is the exact scope every envelope must carry.
	// This value is trimmed before comparison; the envelope's scope is not.
	Scope string

	// ServiceClasses, when non-empty, is the allowlist of envelope service
	// classes.
	ServiceClasses []string

	// MaxClockSkew bounds how far envelope timestamps may be from the
	// receiver's clock in either direction; zero or negative means 5 minutes.
	MaxClockSkew time.Duration
}
// FabricRuntimeReceiver verifies signed fabric service-channel envelopes and
// forwards the enclosed request to Handler.
//
// Handler and Keys must be set: ReceiveResponse fails without a Handler, and
// verification fails without a key resolver.
type FabricRuntimeReceiver struct {
	// Config is the acceptance policy checked against each envelope.
	Config ReceiverConfig

	// Keys resolves the signing key ID named in each envelope's signature.
	Keys EnvelopeKeyResolver

	// Handler serves requests whose envelopes passed verification.
	Handler EnvelopeRuntimeHandler

	// Now overrides the clock used for timestamp skew checks; nil means
	// time.Now.
	Now func() time.Time
}
// StaticEnvelopeKeyResolver is a fixed, in-memory EnvelopeKeyResolver keyed by
// signing key ID.
type StaticEnvelopeKeyResolver map[string]ed25519.PublicKey

// PublicKey implements EnvelopeKeyResolver. Surrounding whitespace in keyID is
// ignored. A hit returns a copy of the stored key, so callers cannot modify the
// map's contents through it. The returned error is always nil.
func (r StaticEnvelopeKeyResolver) PublicKey(_ context.Context, keyID string) (ed25519.PublicKey, bool, error) {
	stored, found := r[strings.TrimSpace(keyID)]
	if found {
		// append onto nil: detaches from the map's backing array (and keeps a
		// stored nil/empty key as nil).
		return append(ed25519.PublicKey(nil), stored...), true, nil
	}
	return nil, false, nil
}
// Receive is ReceiveResponse followed by JSON encoding of the handler's
// response as a FabricRuntimeResponseSchema document. Errors from
// ReceiveResponse are returned unchanged, with a nil payload.
func (r FabricRuntimeReceiver) Receive(ctx context.Context, payload []byte) ([]byte, error) {
	resp, recvErr := r.ReceiveResponse(ctx, payload)
	if recvErr != nil {
		return nil, recvErr
	}
	encoded, encodeErr := encodeFabricRuntimeResponse(resp)
	return encoded, encodeErr
}
// ReceiveResponse decodes payload as a SignedFabricServiceChannelEnvelope,
// verifies it, converts it to a FabricRequest and returns whatever Handler
// produces for that request.
//
// It returns ErrFabricEnvelopeRuntimeRequired, before looking at payload, when
// Handler is nil, and an error wrapping ErrFabricEnvelopeSignatureInvalid when
// payload is not valid JSON. Verification, conversion and handler errors are
// returned as-is.
func (r FabricRuntimeReceiver) ReceiveResponse(ctx context.Context, payload []byte) (FabricResponse, error) {
	handler := r.Handler
	if handler == nil {
		return FabricResponse{}, ErrFabricEnvelopeRuntimeRequired
	}

	var signed SignedFabricServiceChannelEnvelope
	if decodeErr := json.Unmarshal(payload, &signed); decodeErr != nil {
		// Only the sentinel is reported; the decoder's message is dropped.
		return FabricResponse{}, fmt.Errorf("%w: invalid signed envelope json", ErrFabricEnvelopeSignatureInvalid)
	}
	if verifyErr := r.verify(ctx, signed); verifyErr != nil {
		return FabricResponse{}, verifyErr
	}

	req, convertErr := requestFromEnvelope(signed.Envelope)
	if convertErr != nil {
		return FabricResponse{}, convertErr
	}
	return handler.HandleFabricRequest(ctx, req)
}
// verify authenticates signed and then checks its claims against r.Config.
//
// Checks run in three phases:
//  1. contract: both schema versions match and the envelope carries a
//     non-blank scope and service class;
//  2. authenticity: Keys resolves the (trimmed) key ID to a trusted ed25519
//     key, the algorithm is "ed25519", and the base64 signature verifies over
//     the JSON encoding of signed.Envelope;
//  3. authorization: configured scope, configured service type, the service
//     class allowlist, and the timestamp skew window (verifyClock).
//
// Authenticity is established before any policy decision. This means policy
// is only evaluated over signed claims, and an unsigned caller cannot probe
// which policy rule would reject it.
//
// Failures wrap ErrFabricEnvelopeSignatureInvalid or
// ErrFabricEnvelopeUnauthorized. Errors from the key resolver and from
// re-encoding the envelope are returned unwrapped.
func (r FabricRuntimeReceiver) verify(ctx context.Context, signed SignedFabricServiceChannelEnvelope) error {
	envelope := signed.Envelope

	// Phase 1: structural contract.
	if signed.SchemaVersion != SignedFabricServiceChannelEnvelopeSchema {
		return fmt.Errorf("%w: signed schema mismatch", ErrFabricEnvelopeSignatureInvalid)
	}
	if envelope.SchemaVersion != FabricServiceChannelEnvelopeSchema ||
		strings.TrimSpace(envelope.Scope) == "" ||
		strings.TrimSpace(envelope.ServiceClass) == "" {
		return fmt.Errorf("%w: envelope contract invalid", ErrFabricEnvelopeSignatureInvalid)
	}

	// Phase 2: authenticity.
	if r.Keys == nil {
		return fmt.Errorf("%w: key resolver required", ErrFabricEnvelopeSignatureInvalid)
	}
	keyID := strings.TrimSpace(signed.Signature.KeyID)
	publicKey, ok, err := r.Keys.PublicKey(ctx, keyID)
	if err != nil {
		return err
	}
	if !ok || len(publicKey) != ed25519.PublicKeySize {
		return fmt.Errorf("%w: signing key not trusted", ErrFabricEnvelopeUnauthorized)
	}
	if signed.Signature.Alg != "ed25519" {
		return fmt.Errorf("%w: algorithm mismatch", ErrFabricEnvelopeSignatureInvalid)
	}
	signature, err := decodeEnvelopeBase64(signed.Signature.Signature)
	if err != nil || len(signature) != ed25519.SignatureSize {
		return fmt.Errorf("%w: signature must be base64 ed25519", ErrFabricEnvelopeSignatureInvalid)
	}
	// NOTE(review): the signature is checked over this process's json.Marshal
	// of the decoded envelope, not over the received bytes. Signer and
	// receiver must therefore agree on the Go struct encoding, and JSON fields
	// unknown to the struct are not covered by the signature.
	canonical, err := json.Marshal(envelope)
	if err != nil {
		return err
	}
	if !ed25519.Verify(publicKey, canonical, signature) {
		return ErrFabricEnvelopeSignatureInvalid
	}

	// Phase 3: authorization over the now-authenticated claims.
	if scope := strings.TrimSpace(r.Config.Scope); scope != "" && envelope.Scope != scope {
		return fmt.Errorf("%w: scope mismatch", ErrFabricEnvelopeUnauthorized)
	}
	// A configured service type is enforced like a configured scope, so an
	// envelope signed for another service type never reaches Handler.
	if serviceType := strings.TrimSpace(r.Config.ServiceType); serviceType != "" && envelope.ServiceType != serviceType {
		return fmt.Errorf("%w: service type mismatch", ErrFabricEnvelopeUnauthorized)
	}
	if len(r.Config.ServiceClasses) > 0 && !contains(r.Config.ServiceClasses, envelope.ServiceClass) {
		return fmt.Errorf("%w: service class not allowed", ErrFabricEnvelopeUnauthorized)
	}
	return r.verifyClock(envelope)
}
// verifyClock rejects an envelope whose ObservedAt or EnvelopedAt timestamp
// is more than the allowed skew away from the receiver's clock.
//
// The allowed skew is Config.MaxClockSkew, or 5 minutes when that is not
// positive. The clock is r.Now when set, otherwise time.Now. A timestamp that
// does not parse as RFC 3339 yields a wrapped ErrFabricEnvelopeSignatureInvalid,
// and one outside the window yields a wrapped ErrFabricEnvelopeUnauthorized.
//
// NOTE(review): blank timestamps are skipped, so an envelope carrying neither
// is not freshness-checked at all. Confirm that senders always populate them.
func (r FabricRuntimeReceiver) verifyClock(envelope FabricServiceChannelEnvelope) error {
	skew := 5 * time.Minute
	if configured := r.Config.MaxClockSkew; configured > 0 {
		skew = configured
	}

	clock := time.Now
	if r.Now != nil {
		clock = r.Now
	}
	now := clock().UTC()
	earliest, latest := now.Add(-skew), now.Add(skew)

	stamps := [...]string{envelope.ObservedAt, envelope.EnvelopedAt}
	for i := range stamps {
		raw := stamps[i]
		if strings.TrimSpace(raw) == "" {
			continue
		}
		at, err := time.Parse(time.RFC3339Nano, raw)
		if err != nil {
			return fmt.Errorf("%w: invalid envelope timestamp", ErrFabricEnvelopeSignatureInvalid)
		}
		if at.Before(earliest) || at.After(latest) {
			return fmt.Errorf("%w: envelope timestamp outside skew", ErrFabricEnvelopeUnauthorized)
		}
	}
	return nil
}
// requestFromEnvelope converts an envelope into the FabricRequest handed to the
// runtime handler.
//
// BodyBase64 is decoded with standard padded base64; an invalid encoding
// yields a wrapped ErrFabricEnvelopeSignatureInvalid. Only headers accepted by
// safeRequestHeader are copied, and http.Header.Add canonicalises their names.
// The remaining string fields are copied verbatim.
func requestFromEnvelope(envelope FabricServiceChannelEnvelope) (FabricRequest, error) {
	// DecodeString never fails on "", so a missing body becomes an empty slice.
	body, decodeErr := base64.StdEncoding.DecodeString(envelope.BodyBase64)
	if decodeErr != nil {
		return FabricRequest{}, fmt.Errorf("%w: invalid body_b64", ErrFabricEnvelopeSignatureInvalid)
	}

	headers := make(http.Header, len(envelope.Headers))
	for name, values := range envelope.Headers {
		if !safeRequestHeader(name) {
			continue
		}
		for _, v := range values {
			headers.Add(name, v)
		}
	}

	// The parse error is discarded: a blank or unparseable value yields the
	// zero time. When reached via ReceiveResponse, verifyClock has already
	// rejected non-blank values that fail to parse.
	observedAt, _ := time.Parse(time.RFC3339Nano, envelope.ObservedAt)

	return FabricRequest{
		SchemaVersion: envelope.RequestSchema,
		Method:        envelope.Method,
		Path:          envelope.Path,
		Query:         envelope.Query,
		Host:          envelope.Host,
		ServiceType:   envelope.ServiceType,
		Scope:         envelope.Scope,
		ServiceClass:  envelope.ServiceClass,
		Headers:       headers,
		Body:          body,
		ObservedAt:    observedAt,
	}, nil
}
// encodeFabricRuntimeResponse serialises response as a
// FabricRuntimeResponseSchema JSON document of the form
//
//	{"schema_version": ..., "status_code": ..., "headers": {...}, "body_b64": ...}
//
// Only headers accepted by safeResponseHeader and carrying at least one value
// are emitted, under their canonical names. Names that differ only in case
// (e.g. "x-foo" and "X-Foo") are merged rather than one silently replacing the
// other. Values follow the byte-wise sorted order of the original names.
// "headers" is omitted when nothing survives. The body is standard base64 and
// omitted when empty. A status code outside 100-599, including 0, is encoded
// as 200.
func encodeFabricRuntimeResponse(response FabricResponse) ([]byte, error) {
	// Case variants canonicalise to the same key. Visiting names in sorted
	// order makes the merged value order deterministic instead of depending on
	// map iteration order.
	names := make([]string, 0, len(response.Headers))
	for name := range response.Headers {
		names = append(names, name)
	}
	sort.Strings(names)

	headers := make(map[string][]string, len(names))
	for _, name := range names {
		if !safeResponseHeader(name) {
			continue
		}
		values := response.Headers[name]
		if len(values) == 0 {
			continue
		}
		canonical := http.CanonicalHeaderKey(name)
		// append copies the values, so the encoded map never aliases
		// response.Headers.
		headers[canonical] = append(headers[canonical], values...)
	}

	statusCode := response.StatusCode
	if statusCode < 100 || statusCode > 599 {
		statusCode = http.StatusOK
	}

	payload := struct {
		SchemaVersion string              `json:"schema_version"`
		StatusCode    int                 `json:"status_code"`
		Headers       map[string][]string `json:"headers,omitempty"`
		BodyBase64    string              `json:"body_b64,omitempty"`
	}{
		SchemaVersion: FabricRuntimeResponseSchema,
		StatusCode:    statusCode,
		BodyBase64:    base64.StdEncoding.EncodeToString(response.Body),
	}
	if len(headers) > 0 {
		payload.Headers = headers
	}
	return json.Marshal(payload)
}