Refactor RDP proxy handling and update related tests
This commit is contained in:
@@ -0,0 +1,219 @@
|
||||
package webingress
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
SignedFabricServiceChannelEnvelopeSchema = "rap.web_ingress.signed_fabric_service_channel_envelope.v1"
|
||||
FabricRuntimeResponseSchema = "rap.web_ingress.fabric_runtime_response.v1"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrFabricEnvelopeSignatureInvalid = errors.New("web ingress fabric envelope signature invalid")
|
||||
ErrFabricEnvelopeUnauthorized = errors.New("web ingress fabric envelope unauthorized")
|
||||
ErrFabricEnvelopeRuntimeRequired = errors.New("web ingress fabric runtime handler required")
|
||||
)
|
||||
|
||||
type EnvelopeKeyResolver interface {
|
||||
PublicKey(ctx context.Context, keyID string) (ed25519.PublicKey, bool, error)
|
||||
}
|
||||
|
||||
type EnvelopeRuntimeHandler interface {
|
||||
HandleFabricRequest(ctx context.Context, request FabricRequest) (FabricResponse, error)
|
||||
}
|
||||
|
||||
type RuntimeHandlerFunc func(ctx context.Context, request FabricRequest) (FabricResponse, error)
|
||||
|
||||
func (f RuntimeHandlerFunc) HandleFabricRequest(ctx context.Context, request FabricRequest) (FabricResponse, error) {
|
||||
return f(ctx, request)
|
||||
}
|
||||
|
||||
type ReceiverConfig struct {
|
||||
ServiceType string
|
||||
Scope string
|
||||
ServiceClasses []string
|
||||
MaxClockSkew time.Duration
|
||||
}
|
||||
|
||||
type FabricRuntimeReceiver struct {
|
||||
Config ReceiverConfig
|
||||
Keys EnvelopeKeyResolver
|
||||
Handler EnvelopeRuntimeHandler
|
||||
Now func() time.Time
|
||||
}
|
||||
|
||||
type StaticEnvelopeKeyResolver map[string]ed25519.PublicKey
|
||||
|
||||
func (r StaticEnvelopeKeyResolver) PublicKey(_ context.Context, keyID string) (ed25519.PublicKey, bool, error) {
|
||||
key, ok := r[strings.TrimSpace(keyID)]
|
||||
if !ok {
|
||||
return nil, false, nil
|
||||
}
|
||||
return append(ed25519.PublicKey(nil), key...), true, nil
|
||||
}
|
||||
|
||||
func (r FabricRuntimeReceiver) Receive(ctx context.Context, payload []byte) ([]byte, error) {
|
||||
response, err := r.ReceiveResponse(ctx, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return encodeFabricRuntimeResponse(response)
|
||||
}
|
||||
|
||||
func (r FabricRuntimeReceiver) ReceiveResponse(ctx context.Context, payload []byte) (FabricResponse, error) {
|
||||
if r.Handler == nil {
|
||||
return FabricResponse{}, ErrFabricEnvelopeRuntimeRequired
|
||||
}
|
||||
var signed SignedFabricServiceChannelEnvelope
|
||||
if err := json.Unmarshal(payload, &signed); err != nil {
|
||||
return FabricResponse{}, fmt.Errorf("%w: invalid signed envelope json", ErrFabricEnvelopeSignatureInvalid)
|
||||
}
|
||||
if err := r.verify(ctx, signed); err != nil {
|
||||
return FabricResponse{}, err
|
||||
}
|
||||
request, err := requestFromEnvelope(signed.Envelope)
|
||||
if err != nil {
|
||||
return FabricResponse{}, err
|
||||
}
|
||||
return r.Handler.HandleFabricRequest(ctx, request)
|
||||
}
|
||||
|
||||
func (r FabricRuntimeReceiver) verify(ctx context.Context, signed SignedFabricServiceChannelEnvelope) error {
|
||||
if signed.SchemaVersion != SignedFabricServiceChannelEnvelopeSchema {
|
||||
return fmt.Errorf("%w: signed schema mismatch", ErrFabricEnvelopeSignatureInvalid)
|
||||
}
|
||||
if signed.Envelope.SchemaVersion != FabricServiceChannelEnvelopeSchema ||
|
||||
strings.TrimSpace(signed.Envelope.Scope) == "" ||
|
||||
strings.TrimSpace(signed.Envelope.ServiceClass) == "" {
|
||||
return fmt.Errorf("%w: envelope contract invalid", ErrFabricEnvelopeSignatureInvalid)
|
||||
}
|
||||
if scope := strings.TrimSpace(r.Config.Scope); scope != "" && signed.Envelope.Scope != scope {
|
||||
return fmt.Errorf("%w: scope mismatch", ErrFabricEnvelopeUnauthorized)
|
||||
}
|
||||
if len(r.Config.ServiceClasses) > 0 && !contains(r.Config.ServiceClasses, signed.Envelope.ServiceClass) {
|
||||
return fmt.Errorf("%w: service class not allowed", ErrFabricEnvelopeUnauthorized)
|
||||
}
|
||||
if err := r.verifyClock(signed.Envelope); err != nil {
|
||||
return err
|
||||
}
|
||||
if r.Keys == nil {
|
||||
return fmt.Errorf("%w: key resolver required", ErrFabricEnvelopeSignatureInvalid)
|
||||
}
|
||||
keyID := strings.TrimSpace(signed.Signature.KeyID)
|
||||
publicKey, ok, err := r.Keys.PublicKey(ctx, keyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok || len(publicKey) != ed25519.PublicKeySize {
|
||||
return fmt.Errorf("%w: signing key not trusted", ErrFabricEnvelopeUnauthorized)
|
||||
}
|
||||
if signed.Signature.Alg != "ed25519" {
|
||||
return fmt.Errorf("%w: algorithm mismatch", ErrFabricEnvelopeSignatureInvalid)
|
||||
}
|
||||
signature, err := decodeEnvelopeBase64(signed.Signature.Signature)
|
||||
if err != nil || len(signature) != ed25519.SignatureSize {
|
||||
return fmt.Errorf("%w: signature must be base64 ed25519", ErrFabricEnvelopeSignatureInvalid)
|
||||
}
|
||||
canonical, err := json.Marshal(signed.Envelope)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ed25519.Verify(publicKey, canonical, signature) {
|
||||
return ErrFabricEnvelopeSignatureInvalid
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r FabricRuntimeReceiver) verifyClock(envelope FabricServiceChannelEnvelope) error {
|
||||
maxSkew := r.Config.MaxClockSkew
|
||||
if maxSkew <= 0 {
|
||||
maxSkew = 5 * time.Minute
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if r.Now != nil {
|
||||
now = r.Now().UTC()
|
||||
}
|
||||
for _, value := range []string{envelope.ObservedAt, envelope.EnvelopedAt} {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
continue
|
||||
}
|
||||
parsed, err := time.Parse(time.RFC3339Nano, value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: invalid envelope timestamp", ErrFabricEnvelopeSignatureInvalid)
|
||||
}
|
||||
if parsed.After(now.Add(maxSkew)) || parsed.Before(now.Add(-maxSkew)) {
|
||||
return fmt.Errorf("%w: envelope timestamp outside skew", ErrFabricEnvelopeUnauthorized)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func requestFromEnvelope(envelope FabricServiceChannelEnvelope) (FabricRequest, error) {
|
||||
body, err := base64.StdEncoding.DecodeString(envelope.BodyBase64)
|
||||
if err != nil && envelope.BodyBase64 != "" {
|
||||
return FabricRequest{}, fmt.Errorf("%w: invalid body_b64", ErrFabricEnvelopeSignatureInvalid)
|
||||
}
|
||||
observedAt, _ := time.Parse(time.RFC3339Nano, envelope.ObservedAt)
|
||||
headers := http.Header{}
|
||||
for key, values := range envelope.Headers {
|
||||
if !safeRequestHeader(key) {
|
||||
continue
|
||||
}
|
||||
for _, value := range values {
|
||||
headers.Add(key, value)
|
||||
}
|
||||
}
|
||||
return FabricRequest{
|
||||
SchemaVersion: envelope.RequestSchema,
|
||||
Method: envelope.Method,
|
||||
Path: envelope.Path,
|
||||
Query: envelope.Query,
|
||||
Host: envelope.Host,
|
||||
ServiceType: envelope.ServiceType,
|
||||
Scope: envelope.Scope,
|
||||
ServiceClass: envelope.ServiceClass,
|
||||
Headers: headers,
|
||||
Body: body,
|
||||
ObservedAt: observedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func encodeFabricRuntimeResponse(response FabricResponse) ([]byte, error) {
|
||||
headers := map[string][]string{}
|
||||
for key, values := range response.Headers {
|
||||
if !safeResponseHeader(key) {
|
||||
continue
|
||||
}
|
||||
copied := append([]string(nil), values...)
|
||||
if len(copied) > 0 {
|
||||
headers[http.CanonicalHeaderKey(key)] = copied
|
||||
}
|
||||
}
|
||||
payload := struct {
|
||||
SchemaVersion string `json:"schema_version"`
|
||||
StatusCode int `json:"status_code"`
|
||||
Headers map[string][]string `json:"headers,omitempty"`
|
||||
BodyBase64 string `json:"body_b64,omitempty"`
|
||||
}{
|
||||
SchemaVersion: FabricRuntimeResponseSchema,
|
||||
StatusCode: response.StatusCode,
|
||||
Headers: headers,
|
||||
BodyBase64: base64.StdEncoding.EncodeToString(response.Body),
|
||||
}
|
||||
if payload.StatusCode < 100 || payload.StatusCode > 599 {
|
||||
payload.StatusCode = http.StatusOK
|
||||
}
|
||||
if len(payload.Headers) == 0 {
|
||||
payload.Headers = nil
|
||||
}
|
||||
return json.Marshal(payload)
|
||||
}
|
||||
Reference in New Issue
Block a user