Initial project snapshot
This commit is contained in:
@@ -0,0 +1,329 @@
|
||||
package authority
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/config"
|
||||
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
|
||||
)
|
||||
|
||||
const (
|
||||
ModeStrict = "strict"
|
||||
ModeLegacy = "legacy"
|
||||
|
||||
ActivationSchemaVersion = "rap.installation.activation.v1"
|
||||
|
||||
PlatformRoleUser = "user"
|
||||
PlatformRoleAdmin = "platform_admin"
|
||||
PlatformRoleRecoveryAdmin = "platform_recovery_admin"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidAuthorityMode = errors.New("invalid installation authority mode")
|
||||
ErrProductRootKeyNeeded = errors.New("product root public key is required")
|
||||
ErrInvalidActivation = errors.New("invalid installation activation")
|
||||
ErrInvalidGrant = errors.New("invalid platform role grant")
|
||||
)
|
||||
|
||||
type ActivationPayload struct {
|
||||
SchemaVersion string `json:"schema_version"`
|
||||
InstallID string `json:"install_id"`
|
||||
OwnerEmail string `json:"owner_email"`
|
||||
PlatformRole string `json:"platform_role"`
|
||||
IssuedAt time.Time `json:"issued_at"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
Environment string `json:"environment,omitempty"`
|
||||
}
|
||||
|
||||
type Verifier struct {
|
||||
mode string
|
||||
rootPublicKey ed25519.PublicKey
|
||||
rootFingerprint string
|
||||
allowInsecureBootstrap bool
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
func NewVerifier(cfg config.InstallationConfig) (*Verifier, error) {
|
||||
mode := strings.ToLower(strings.TrimSpace(cfg.AuthorityMode))
|
||||
if mode == "" {
|
||||
mode = ModeLegacy
|
||||
}
|
||||
verifier := &Verifier{
|
||||
mode: mode,
|
||||
allowInsecureBootstrap: cfg.AllowInsecureBootstrap,
|
||||
now: time.Now,
|
||||
}
|
||||
|
||||
switch mode {
|
||||
case ModeLegacy:
|
||||
return verifier, nil
|
||||
case ModeStrict:
|
||||
publicKey, err := decodeEd25519PublicKey(cfg.ProductRootPublicKeyBase64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
verifier.rootPublicKey = publicKey
|
||||
fingerprint := sha256.Sum256(publicKey)
|
||||
verifier.rootFingerprint = hex.EncodeToString(fingerprint[:])
|
||||
return verifier, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: %s", ErrInvalidAuthorityMode, mode)
|
||||
}
|
||||
}
|
||||
|
||||
func (v *Verifier) Mode() string {
|
||||
if v == nil || v.mode == "" {
|
||||
return ModeLegacy
|
||||
}
|
||||
return v.mode
|
||||
}
|
||||
|
||||
func (v *Verifier) Strict() bool {
|
||||
return v != nil && v.mode == ModeStrict
|
||||
}
|
||||
|
||||
func (v *Verifier) AllowInsecureBootstrap() bool {
|
||||
return v != nil && v.allowInsecureBootstrap
|
||||
}
|
||||
|
||||
func (v *Verifier) RootFingerprint() string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
return v.rootFingerprint
|
||||
}
|
||||
|
||||
func (v *Verifier) VerifyActivation(payload json.RawMessage, signature string) (ActivationPayload, error) {
|
||||
if v == nil || !v.Strict() {
|
||||
return ActivationPayload{}, ErrProductRootKeyNeeded
|
||||
}
|
||||
activation, canonical, err := parseActivationPayload(payload)
|
||||
if err != nil {
|
||||
return ActivationPayload{}, err
|
||||
}
|
||||
if err := activation.validate(v.now().UTC()); err != nil {
|
||||
return ActivationPayload{}, err
|
||||
}
|
||||
if err := v.verifySignature(canonical, signature); err != nil {
|
||||
return ActivationPayload{}, fmt.Errorf("%w: %v", ErrInvalidActivation, err)
|
||||
}
|
||||
return activation, nil
|
||||
}
|
||||
|
||||
func (v *Verifier) VerifyPlatformRoleGrant(payload json.RawMessage, signature, expectedInstallID, expectedEmail, expectedRole string) (ActivationPayload, error) {
|
||||
activation, err := v.VerifyActivation(payload, signature)
|
||||
if err != nil {
|
||||
return ActivationPayload{}, fmt.Errorf("%w: %v", ErrInvalidGrant, err)
|
||||
}
|
||||
if activation.InstallID != strings.TrimSpace(expectedInstallID) {
|
||||
return ActivationPayload{}, fmt.Errorf("%w: install_id mismatch", ErrInvalidGrant)
|
||||
}
|
||||
if !strings.EqualFold(activation.OwnerEmail, strings.TrimSpace(expectedEmail)) {
|
||||
return ActivationPayload{}, fmt.Errorf("%w: owner_email mismatch", ErrInvalidGrant)
|
||||
}
|
||||
if activation.PlatformRole != strings.TrimSpace(expectedRole) {
|
||||
return ActivationPayload{}, fmt.Errorf("%w: platform_role mismatch", ErrInvalidGrant)
|
||||
}
|
||||
return activation, nil
|
||||
}
|
||||
|
||||
func CanonicalJSON(raw json.RawMessage) ([]byte, error) {
|
||||
if len(raw) == 0 {
|
||||
return nil, fmt.Errorf("%w: empty payload", ErrInvalidActivation)
|
||||
}
|
||||
var value any
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid json: %v", ErrInvalidActivation, err)
|
||||
}
|
||||
canonical, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: canonical json: %v", ErrInvalidActivation, err)
|
||||
}
|
||||
return canonical, nil
|
||||
}
|
||||
|
||||
func EffectivePlatformRole(ctx context.Context, db postgresplatform.DBTX, verifier *Verifier, userID string) (string, error) {
|
||||
userID = strings.TrimSpace(userID)
|
||||
if userID == "" {
|
||||
return PlatformRoleUser, nil
|
||||
}
|
||||
if verifier == nil || !verifier.Strict() {
|
||||
return legacyPlatformRole(ctx, db, userID)
|
||||
}
|
||||
|
||||
var email string
|
||||
if err := db.QueryRow(ctx, `SELECT email FROM users WHERE id = $1::uuid`, userID).Scan(&email); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return PlatformRoleUser, nil
|
||||
}
|
||||
return "", fmt.Errorf("get user email for platform grant: %w", err)
|
||||
}
|
||||
|
||||
rows, err := db.Query(ctx, `
|
||||
SELECT prg.role, prg.install_id, prg.grant_payload, prg.grant_signature
|
||||
FROM platform_role_grants prg
|
||||
JOIN installation_authority ia
|
||||
ON ia.id = 1
|
||||
AND ia.install_id = prg.install_id
|
||||
AND ia.authority_state = 'active'
|
||||
WHERE prg.user_id = $1::uuid
|
||||
AND prg.revoked_at IS NULL
|
||||
AND (prg.expires_at IS NULL OR prg.expires_at > NOW())
|
||||
ORDER BY CASE prg.role
|
||||
WHEN 'platform_recovery_admin' THEN 0
|
||||
WHEN 'platform_admin' THEN 1
|
||||
ELSE 2
|
||||
END, prg.granted_at DESC
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("query platform role grants: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
bestRole := PlatformRoleUser
|
||||
for rows.Next() {
|
||||
var role, installID, signature string
|
||||
var payload []byte
|
||||
if err := rows.Scan(&role, &installID, &payload, &signature); err != nil {
|
||||
return "", fmt.Errorf("scan platform role grant: %w", err)
|
||||
}
|
||||
if _, err := verifier.VerifyPlatformRoleGrant(json.RawMessage(payload), signature, installID, email, role); err != nil {
|
||||
continue
|
||||
}
|
||||
if role == PlatformRoleRecoveryAdmin {
|
||||
return role, nil
|
||||
}
|
||||
if role == PlatformRoleAdmin {
|
||||
bestRole = role
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return "", fmt.Errorf("iterate platform role grants: %w", err)
|
||||
}
|
||||
return bestRole, nil
|
||||
}
|
||||
|
||||
func legacyPlatformRole(ctx context.Context, db postgresplatform.DBTX, userID string) (string, error) {
|
||||
var role string
|
||||
if err := db.QueryRow(ctx, `SELECT platform_role FROM users WHERE id = $1::uuid`, userID).Scan(&role); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return PlatformRoleUser, nil
|
||||
}
|
||||
return "", fmt.Errorf("get platform role: %w", err)
|
||||
}
|
||||
if role == "" {
|
||||
return PlatformRoleUser, nil
|
||||
}
|
||||
return role, nil
|
||||
}
|
||||
|
||||
func parseActivationPayload(raw json.RawMessage) (ActivationPayload, []byte, error) {
|
||||
canonical, err := CanonicalJSON(raw)
|
||||
if err != nil {
|
||||
return ActivationPayload{}, nil, err
|
||||
}
|
||||
var activation ActivationPayload
|
||||
if err := json.Unmarshal(canonical, &activation); err != nil {
|
||||
return ActivationPayload{}, nil, fmt.Errorf("%w: decode activation: %v", ErrInvalidActivation, err)
|
||||
}
|
||||
activation.SchemaVersion = strings.TrimSpace(activation.SchemaVersion)
|
||||
activation.InstallID = strings.TrimSpace(activation.InstallID)
|
||||
activation.OwnerEmail = strings.ToLower(strings.TrimSpace(activation.OwnerEmail))
|
||||
activation.PlatformRole = strings.TrimSpace(activation.PlatformRole)
|
||||
activation.Nonce = strings.TrimSpace(activation.Nonce)
|
||||
activation.Environment = strings.TrimSpace(activation.Environment)
|
||||
return activation, canonical, nil
|
||||
}
|
||||
|
||||
func (p ActivationPayload) validate(now time.Time) error {
|
||||
if p.SchemaVersion != ActivationSchemaVersion {
|
||||
return fmt.Errorf("%w: schema_version must be %s", ErrInvalidActivation, ActivationSchemaVersion)
|
||||
}
|
||||
if p.InstallID == "" {
|
||||
return fmt.Errorf("%w: install_id is required", ErrInvalidActivation)
|
||||
}
|
||||
if p.OwnerEmail == "" || !strings.Contains(p.OwnerEmail, "@") {
|
||||
return fmt.Errorf("%w: owner_email is required", ErrInvalidActivation)
|
||||
}
|
||||
switch p.PlatformRole {
|
||||
case PlatformRoleAdmin, PlatformRoleRecoveryAdmin:
|
||||
default:
|
||||
return fmt.Errorf("%w: platform_role must be platform_admin or platform_recovery_admin", ErrInvalidActivation)
|
||||
}
|
||||
if p.IssuedAt.IsZero() {
|
||||
return fmt.Errorf("%w: issued_at is required", ErrInvalidActivation)
|
||||
}
|
||||
if p.IssuedAt.After(now.Add(5 * time.Minute)) {
|
||||
return fmt.Errorf("%w: issued_at is too far in the future", ErrInvalidActivation)
|
||||
}
|
||||
if p.ExpiresAt != nil && !p.ExpiresAt.After(now) {
|
||||
return fmt.Errorf("%w: activation expired", ErrInvalidActivation)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *Verifier) verifySignature(payload []byte, signatureText string) error {
|
||||
signature, err := decodeBase64(strings.TrimSpace(signatureText))
|
||||
if err != nil {
|
||||
return fmt.Errorf("signature must be base64 encoded: %w", err)
|
||||
}
|
||||
if len(signature) != ed25519.SignatureSize {
|
||||
return fmt.Errorf("signature must decode to %d bytes", ed25519.SignatureSize)
|
||||
}
|
||||
if !ed25519.Verify(v.rootPublicKey, payload, signature) {
|
||||
return errors.New("signature verification failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeEd25519PublicKey(value string) (ed25519.PublicKey, error) {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return nil, ErrProductRootKeyNeeded
|
||||
}
|
||||
if block, _ := pem.Decode([]byte(value)); block != nil {
|
||||
parsed, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse product root public key PEM: %w", err)
|
||||
}
|
||||
publicKey, ok := parsed.(ed25519.PublicKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("product root public key PEM must contain an Ed25519 public key")
|
||||
}
|
||||
return publicKey, nil
|
||||
}
|
||||
decoded, err := decodeBase64(value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("product root public key must be base64 encoded: %w", err)
|
||||
}
|
||||
if len(decoded) != ed25519.PublicKeySize {
|
||||
return nil, fmt.Errorf("product root public key must decode to %d bytes", ed25519.PublicKeySize)
|
||||
}
|
||||
return ed25519.PublicKey(decoded), nil
|
||||
}
|
||||
|
||||
func decodeBase64(value string) ([]byte, error) {
|
||||
decoded, err := base64.StdEncoding.DecodeString(value)
|
||||
if err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
decoded, rawErr := base64.RawStdEncoding.DecodeString(value)
|
||||
if rawErr == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package authority
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/config"
|
||||
)
|
||||
|
||||
func TestVerifierAcceptsSignedActivation(t *testing.T) {
|
||||
publicKey, privateKey, err := ed25519.GenerateKey(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("generate key: %v", err)
|
||||
}
|
||||
verifier, err := NewVerifier(config.InstallationConfig{
|
||||
AuthorityMode: ModeStrict,
|
||||
ProductRootPublicKeyBase64: base64.StdEncoding.EncodeToString(publicKey),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewVerifier: %v", err)
|
||||
}
|
||||
verifier.now = func() time.Time { return time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC) }
|
||||
|
||||
payload := json.RawMessage(`{
|
||||
"platform_role":"platform_admin",
|
||||
"owner_email":"Owner@Example.test",
|
||||
"install_id":"install-1",
|
||||
"schema_version":"rap.installation.activation.v1",
|
||||
"issued_at":"2026-04-28T11:00:00Z",
|
||||
"expires_at":"2026-04-29T11:00:00Z"
|
||||
}`)
|
||||
canonical, err := CanonicalJSON(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("CanonicalJSON: %v", err)
|
||||
}
|
||||
signature := base64.StdEncoding.EncodeToString(ed25519.Sign(privateKey, canonical))
|
||||
|
||||
activation, err := verifier.VerifyActivation(payload, signature)
|
||||
if err != nil {
|
||||
t.Fatalf("VerifyActivation: %v", err)
|
||||
}
|
||||
if activation.OwnerEmail != "owner@example.test" || activation.PlatformRole != PlatformRoleAdmin {
|
||||
t.Fatalf("unexpected activation: %+v", activation)
|
||||
}
|
||||
if verifier.RootFingerprint() == "" {
|
||||
t.Fatal("expected root fingerprint")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifierRejectsTamperedActivation(t *testing.T) {
|
||||
publicKey, privateKey, err := ed25519.GenerateKey(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("generate key: %v", err)
|
||||
}
|
||||
verifier, err := NewVerifier(config.InstallationConfig{
|
||||
AuthorityMode: ModeStrict,
|
||||
ProductRootPublicKeyBase64: base64.StdEncoding.EncodeToString(publicKey),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewVerifier: %v", err)
|
||||
}
|
||||
verifier.now = func() time.Time { return time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC) }
|
||||
|
||||
payload := json.RawMessage(`{
|
||||
"schema_version":"rap.installation.activation.v1",
|
||||
"install_id":"install-1",
|
||||
"owner_email":"owner@example.test",
|
||||
"platform_role":"platform_admin",
|
||||
"issued_at":"2026-04-28T11:00:00Z"
|
||||
}`)
|
||||
canonical, err := CanonicalJSON(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("CanonicalJSON: %v", err)
|
||||
}
|
||||
signature := base64.StdEncoding.EncodeToString(ed25519.Sign(privateKey, canonical))
|
||||
tampered := json.RawMessage(strings.Replace(string(payload), "platform_admin", "platform_recovery_admin", 1))
|
||||
|
||||
if _, err := verifier.VerifyActivation(tampered, signature); err == nil {
|
||||
t.Fatal("expected tampered activation to fail")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package clusterauth
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
AuthoritySchemaVersion = "rap.cluster_authority.v1"
|
||||
SignatureSchemaVersion = "rap.cluster_authority.signature.v1"
|
||||
AlgorithmEd25519 = "ed25519"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidKey = errors.New("invalid cluster authority key")
|
||||
ErrInvalidSignature = errors.New("invalid cluster authority signature")
|
||||
ErrInvalidPayload = errors.New("invalid cluster authority payload")
|
||||
)
|
||||
|
||||
type KeyPair struct {
|
||||
PublicKeyB64 string
|
||||
PrivateKeyB64 string
|
||||
Fingerprint string
|
||||
}
|
||||
|
||||
type Signature struct {
|
||||
SchemaVersion string `json:"schema_version"`
|
||||
Algorithm string `json:"algorithm"`
|
||||
KeyFingerprint string `json:"key_fingerprint"`
|
||||
Signature string `json:"signature"`
|
||||
SignedAt time.Time `json:"signed_at"`
|
||||
}
|
||||
|
||||
func GenerateKeyPair() (KeyPair, error) {
|
||||
publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return KeyPair{}, err
|
||||
}
|
||||
fingerprint := Fingerprint(publicKey)
|
||||
return KeyPair{
|
||||
PublicKeyB64: base64.StdEncoding.EncodeToString(publicKey),
|
||||
PrivateKeyB64: base64.StdEncoding.EncodeToString(privateKey),
|
||||
Fingerprint: fingerprint,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func Fingerprint(publicKey ed25519.PublicKey) string {
|
||||
sum := sha256.Sum256(publicKey)
|
||||
return "rap-ca-ed25519-" + hex.EncodeToString(sum[:16])
|
||||
}
|
||||
|
||||
func FingerprintFromBase64(publicKeyB64 string) (string, error) {
|
||||
publicKey, err := DecodePublicKey(publicKeyB64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return Fingerprint(publicKey), nil
|
||||
}
|
||||
|
||||
func SignRaw(privateKeyB64 string, payload json.RawMessage, signedAt time.Time) (Signature, error) {
|
||||
privateKey, err := DecodePrivateKey(privateKeyB64)
|
||||
if err != nil {
|
||||
return Signature{}, err
|
||||
}
|
||||
canonical, err := CanonicalJSON(payload)
|
||||
if err != nil {
|
||||
return Signature{}, err
|
||||
}
|
||||
publicKey, ok := privateKey.Public().(ed25519.PublicKey)
|
||||
if !ok {
|
||||
return Signature{}, ErrInvalidKey
|
||||
}
|
||||
signature := ed25519.Sign(privateKey, canonical)
|
||||
return Signature{
|
||||
SchemaVersion: SignatureSchemaVersion,
|
||||
Algorithm: AlgorithmEd25519,
|
||||
KeyFingerprint: Fingerprint(publicKey),
|
||||
Signature: base64.StdEncoding.EncodeToString(signature),
|
||||
SignedAt: signedAt.UTC(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func SignPayload(privateKeyB64 string, payload any, signedAt time.Time) (json.RawMessage, Signature, error) {
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, Signature{}, fmt.Errorf("%w: marshal: %v", ErrInvalidPayload, err)
|
||||
}
|
||||
signature, err := SignRaw(privateKeyB64, raw, signedAt)
|
||||
if err != nil {
|
||||
return nil, Signature{}, err
|
||||
}
|
||||
return json.RawMessage(raw), signature, nil
|
||||
}
|
||||
|
||||
func VerifyRaw(publicKeyB64 string, payload json.RawMessage, signature Signature) error {
|
||||
if signature.SchemaVersion != SignatureSchemaVersion {
|
||||
return fmt.Errorf("%w: schema_version must be %s", ErrInvalidSignature, SignatureSchemaVersion)
|
||||
}
|
||||
if signature.Algorithm != AlgorithmEd25519 {
|
||||
return fmt.Errorf("%w: algorithm must be %s", ErrInvalidSignature, AlgorithmEd25519)
|
||||
}
|
||||
publicKey, err := DecodePublicKey(publicKeyB64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if signature.KeyFingerprint != Fingerprint(publicKey) {
|
||||
return fmt.Errorf("%w: key fingerprint mismatch", ErrInvalidSignature)
|
||||
}
|
||||
canonical, err := CanonicalJSON(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decodedSignature, err := decodeBase64(strings.TrimSpace(signature.Signature))
|
||||
if err != nil || len(decodedSignature) != ed25519.SignatureSize {
|
||||
return fmt.Errorf("%w: signature must be base64 ed25519 signature", ErrInvalidSignature)
|
||||
}
|
||||
if !ed25519.Verify(publicKey, canonical, decodedSignature) {
|
||||
return ErrInvalidSignature
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func CanonicalJSON(raw json.RawMessage) ([]byte, error) {
|
||||
if len(raw) == 0 {
|
||||
return nil, fmt.Errorf("%w: empty payload", ErrInvalidPayload)
|
||||
}
|
||||
var value any
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid json: %v", ErrInvalidPayload, err)
|
||||
}
|
||||
canonical, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: canonical json: %v", ErrInvalidPayload, err)
|
||||
}
|
||||
return canonical, nil
|
||||
}
|
||||
|
||||
func HashRaw(raw json.RawMessage) (string, error) {
|
||||
canonical, err := CanonicalJSON(raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sum := sha256.Sum256(canonical)
|
||||
return hex.EncodeToString(sum[:]), nil
|
||||
}
|
||||
|
||||
func DecodePublicKey(value string) (ed25519.PublicKey, error) {
|
||||
decoded, err := decodeBase64(strings.TrimSpace(value))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: public key must be base64 encoded", ErrInvalidKey)
|
||||
}
|
||||
if len(decoded) != ed25519.PublicKeySize {
|
||||
return nil, fmt.Errorf("%w: public key must decode to %d bytes", ErrInvalidKey, ed25519.PublicKeySize)
|
||||
}
|
||||
return ed25519.PublicKey(decoded), nil
|
||||
}
|
||||
|
||||
func DecodePrivateKey(value string) (ed25519.PrivateKey, error) {
|
||||
decoded, err := decodeBase64(strings.TrimSpace(value))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: private key must be base64 encoded", ErrInvalidKey)
|
||||
}
|
||||
if len(decoded) != ed25519.PrivateKeySize {
|
||||
return nil, fmt.Errorf("%w: private key must decode to %d bytes", ErrInvalidKey, ed25519.PrivateKeySize)
|
||||
}
|
||||
return ed25519.PrivateKey(decoded), nil
|
||||
}
|
||||
|
||||
func decodeBase64(value string) ([]byte, error) {
|
||||
if value == "" {
|
||||
return nil, errors.New("empty base64 value")
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(value)
|
||||
if err == nil {
|
||||
return decoded, nil
|
||||
}
|
||||
return base64.RawStdEncoding.DecodeString(value)
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package clusterauth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSignAndVerifyRawPayload(t *testing.T) {
|
||||
keys, err := GenerateKeyPair()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKeyPair: %v", err)
|
||||
}
|
||||
payload := json.RawMessage(`{"cluster_id":"cluster-1","schema_version":"test.v1","value":1}`)
|
||||
|
||||
signature, err := SignRaw(keys.PrivateKeyB64, payload, time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC))
|
||||
if err != nil {
|
||||
t.Fatalf("SignRaw: %v", err)
|
||||
}
|
||||
if signature.KeyFingerprint != keys.Fingerprint {
|
||||
t.Fatalf("fingerprint = %q, want %q", signature.KeyFingerprint, keys.Fingerprint)
|
||||
}
|
||||
if err := VerifyRaw(keys.PublicKeyB64, payload, signature); err != nil {
|
||||
t.Fatalf("VerifyRaw: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyRawRejectsTamperedPayload(t *testing.T) {
|
||||
keys, err := GenerateKeyPair()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKeyPair: %v", err)
|
||||
}
|
||||
payload := json.RawMessage(`{"cluster_id":"cluster-1","schema_version":"test.v1","value":1}`)
|
||||
signature, err := SignRaw(keys.PrivateKeyB64, payload, time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC))
|
||||
if err != nil {
|
||||
t.Fatalf("SignRaw: %v", err)
|
||||
}
|
||||
|
||||
tampered := json.RawMessage(`{"cluster_id":"cluster-1","schema_version":"test.v1","value":2}`)
|
||||
if err := VerifyRaw(keys.PublicKeyB64, tampered, signature); !errors.Is(err, ErrInvalidSignature) {
|
||||
t.Fatalf("err = %v, want ErrInvalidSignature", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,307 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
App AppConfig
|
||||
HTTP HTTPConfig
|
||||
Postgres PostgresConfig
|
||||
Redis RedisConfig
|
||||
Auth AuthConfig
|
||||
Installation InstallationConfig
|
||||
DataPlane DataPlaneConfig
|
||||
Secret SecretConfig
|
||||
Session SessionConfig
|
||||
Worker WorkerConfig
|
||||
WebSocket WebSocketConfig
|
||||
}
|
||||
|
||||
type AppConfig struct {
|
||||
Name string
|
||||
Env string
|
||||
}
|
||||
|
||||
type HTTPConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
IdleTimeout time.Duration
|
||||
ShutdownTimeout time.Duration
|
||||
}
|
||||
|
||||
type PostgresConfig struct {
|
||||
DSN string
|
||||
MaxConns int32
|
||||
MinConns int32
|
||||
ConnectTimeout time.Duration
|
||||
}
|
||||
|
||||
type RedisConfig struct {
|
||||
Addr string
|
||||
Password string
|
||||
DB int
|
||||
DialTimeout time.Duration
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
AccessTokenTTL time.Duration
|
||||
RefreshTokenTTL time.Duration
|
||||
Issuer string
|
||||
AccessTokenSecret string
|
||||
RefreshHashSecret string
|
||||
}
|
||||
|
||||
type InstallationConfig struct {
|
||||
AuthorityMode string
|
||||
ProductRootPublicKeyBase64 string
|
||||
ProductRootPublicKeyFile string
|
||||
AllowInsecureBootstrap bool
|
||||
}
|
||||
|
||||
type DataPlaneConfig struct {
|
||||
TokenTTL time.Duration
|
||||
TokenPrivateKeyPEM string
|
||||
TokenPrivateKeyFile string
|
||||
BackendGatewayURL string
|
||||
DirectWorkerWSSURLTemplate string
|
||||
DirectWorkerJSONRuntime bool
|
||||
DirectWorkerBinaryRender bool
|
||||
DirectWorkerTLSTrustMode string
|
||||
DirectWorkerTLSCARef string
|
||||
}
|
||||
|
||||
type SecretConfig struct {
|
||||
EncryptionKeyBase64 string
|
||||
EncryptionKeyFile string
|
||||
EncryptionKeyID string
|
||||
}
|
||||
|
||||
type SessionConfig struct {
|
||||
HeartbeatTTL time.Duration
|
||||
DetachGracePeriod time.Duration
|
||||
AttachTokenTTL time.Duration
|
||||
LiveStateTTL time.Duration
|
||||
RecoveryBatchSize int
|
||||
}
|
||||
|
||||
type WorkerConfig struct {
|
||||
LeaseTTL time.Duration
|
||||
HeartbeatTTL time.Duration
|
||||
StaleLeaseGracePeriod time.Duration
|
||||
}
|
||||
|
||||
type WebSocketConfig struct {
|
||||
WriteTimeout time.Duration
|
||||
PingInterval time.Duration
|
||||
PongWait time.Duration
|
||||
}
|
||||
|
||||
func Load() (Config, error) {
|
||||
cfg := Config{
|
||||
App: AppConfig{
|
||||
Name: getEnv("APP_NAME", "rap-api"),
|
||||
Env: getEnv("APP_ENV", "development"),
|
||||
},
|
||||
HTTP: HTTPConfig{
|
||||
Host: getEnv("HTTP_HOST", "0.0.0.0"),
|
||||
Port: getInt("HTTP_PORT", 8080),
|
||||
ReadTimeout: getDuration("HTTP_READ_TIMEOUT", 15*time.Second),
|
||||
WriteTimeout: getDuration("HTTP_WRITE_TIMEOUT", 15*time.Second),
|
||||
IdleTimeout: getDuration("HTTP_IDLE_TIMEOUT", 60*time.Second),
|
||||
ShutdownTimeout: getDuration("HTTP_SHUTDOWN_TIMEOUT", 10*time.Second),
|
||||
},
|
||||
Postgres: PostgresConfig{
|
||||
DSN: getEnv("POSTGRES_DSN", ""),
|
||||
MaxConns: int32(getInt("POSTGRES_MAX_CONNS", 20)),
|
||||
MinConns: int32(getInt("POSTGRES_MIN_CONNS", 2)),
|
||||
ConnectTimeout: getDuration("POSTGRES_CONNECT_TIMEOUT", 5*time.Second),
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Addr: getEnv("REDIS_ADDR", "localhost:6379"),
|
||||
Password: getEnv("REDIS_PASSWORD", ""),
|
||||
DB: getInt("REDIS_DB", 0),
|
||||
DialTimeout: getDuration("REDIS_DIAL_TIMEOUT", 5*time.Second),
|
||||
},
|
||||
Auth: AuthConfig{
|
||||
AccessTokenTTL: getDuration("AUTH_ACCESS_TOKEN_TTL", 15*time.Minute),
|
||||
RefreshTokenTTL: getDuration("AUTH_REFRESH_TOKEN_TTL", 30*24*time.Hour),
|
||||
Issuer: getEnv("AUTH_ISSUER", "rap-api"),
|
||||
AccessTokenSecret: getEnv("AUTH_ACCESS_TOKEN_SECRET", ""),
|
||||
RefreshHashSecret: getEnv("AUTH_REFRESH_HASH_SECRET", ""),
|
||||
},
|
||||
Installation: InstallationConfig{
|
||||
AuthorityMode: getEnv("INSTALLATION_AUTHORITY_MODE", ""),
|
||||
ProductRootPublicKeyBase64: getEnv("INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_B64", ""),
|
||||
ProductRootPublicKeyFile: getEnv("INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_FILE", ""),
|
||||
AllowInsecureBootstrap: getBool("INSTALLATION_INSECURE_BOOTSTRAP_ENABLED", false),
|
||||
},
|
||||
DataPlane: DataPlaneConfig{
|
||||
TokenTTL: getDuration("DATA_PLANE_TOKEN_TTL", 1*time.Minute),
|
||||
TokenPrivateKeyPEM: getEnv("DATA_PLANE_TOKEN_PRIVATE_KEY_PEM", ""),
|
||||
TokenPrivateKeyFile: getEnv("DATA_PLANE_TOKEN_PRIVATE_KEY_FILE", ""),
|
||||
BackendGatewayURL: getEnv("DATA_PLANE_BACKEND_GATEWAY_URL", "/api/v1/gateway/ws"),
|
||||
DirectWorkerWSSURLTemplate: getEnv("DATA_PLANE_DIRECT_WORKER_WSS_URL_TEMPLATE", ""),
|
||||
DirectWorkerJSONRuntime: getBool("DATA_PLANE_DIRECT_WORKER_JSON_RUNTIME", false),
|
||||
DirectWorkerBinaryRender: getBool("DATA_PLANE_DIRECT_WORKER_BINARY_RENDER", false),
|
||||
DirectWorkerTLSTrustMode: getEnv("DATA_PLANE_DIRECT_WORKER_TLS_TRUST_MODE", "smoke_insecure"),
|
||||
DirectWorkerTLSCARef: getEnv("DATA_PLANE_DIRECT_WORKER_TLS_CA_REF", ""),
|
||||
},
|
||||
Secret: SecretConfig{
|
||||
EncryptionKeyBase64: getEnv("SECRET_ENCRYPTION_KEY_B64", ""),
|
||||
EncryptionKeyFile: getEnv("SECRET_ENCRYPTION_KEY_FILE", ""),
|
||||
EncryptionKeyID: getEnv("SECRET_ENCRYPTION_KEY_ID", "local-v1"),
|
||||
},
|
||||
Session: SessionConfig{
|
||||
HeartbeatTTL: getDuration("SESSION_HEARTBEAT_TTL", 90*time.Second),
|
||||
DetachGracePeriod: getDuration("SESSION_DETACH_GRACE_PERIOD", 30*time.Minute),
|
||||
AttachTokenTTL: getDuration("SESSION_ATTACH_TOKEN_TTL", 2*time.Minute),
|
||||
LiveStateTTL: getDuration("SESSION_LIVE_STATE_TTL", 2*time.Minute),
|
||||
RecoveryBatchSize: getInt("SESSION_RECOVERY_BATCH_SIZE", 100),
|
||||
},
|
||||
Worker: WorkerConfig{
|
||||
LeaseTTL: getDuration("WORKER_LEASE_TTL", 45*time.Second),
|
||||
HeartbeatTTL: getDuration("WORKER_HEARTBEAT_TTL", 15*time.Second),
|
||||
StaleLeaseGracePeriod: getDuration("WORKER_STALE_LEASE_GRACE_PERIOD", 30*time.Second),
|
||||
},
|
||||
WebSocket: WebSocketConfig{
|
||||
WriteTimeout: getDuration("WEBSOCKET_WRITE_TIMEOUT", 10*time.Second),
|
||||
PingInterval: getDuration("WEBSOCKET_PING_INTERVAL", 20*time.Second),
|
||||
PongWait: getDuration("WEBSOCKET_PONG_WAIT", 40*time.Second),
|
||||
},
|
||||
}
|
||||
|
||||
if cfg.Postgres.DSN == "" {
|
||||
return Config{}, fmt.Errorf("POSTGRES_DSN is required")
|
||||
}
|
||||
if cfg.Auth.AccessTokenSecret == "" {
|
||||
return Config{}, fmt.Errorf("AUTH_ACCESS_TOKEN_SECRET is required")
|
||||
}
|
||||
if cfg.Auth.RefreshHashSecret == "" {
|
||||
return Config{}, fmt.Errorf("AUTH_REFRESH_HASH_SECRET is required")
|
||||
}
|
||||
if cfg.Installation.ProductRootPublicKeyBase64 == "" && cfg.Installation.ProductRootPublicKeyFile != "" {
|
||||
publicKey, err := os.ReadFile(cfg.Installation.ProductRootPublicKeyFile)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("read INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_FILE: %w", err)
|
||||
}
|
||||
cfg.Installation.ProductRootPublicKeyBase64 = strings.TrimSpace(string(publicKey))
|
||||
}
|
||||
cfg.Installation.AuthorityMode = normalizeInstallationAuthorityMode(cfg.Installation.AuthorityMode, cfg.Installation.ProductRootPublicKeyBase64)
|
||||
if isProductionEnv(cfg.App.Env) && cfg.Installation.AuthorityMode != "strict" {
|
||||
return Config{}, fmt.Errorf("INSTALLATION_AUTHORITY_MODE=strict with INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_B64 or file is required in production")
|
||||
}
|
||||
if cfg.DataPlane.TokenPrivateKeyPEM == "" && cfg.DataPlane.TokenPrivateKeyFile != "" {
|
||||
privateKey, err := os.ReadFile(cfg.DataPlane.TokenPrivateKeyFile)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("read DATA_PLANE_TOKEN_PRIVATE_KEY_FILE: %w", err)
|
||||
}
|
||||
cfg.DataPlane.TokenPrivateKeyPEM = string(privateKey)
|
||||
}
|
||||
if cfg.Secret.EncryptionKeyBase64 == "" && cfg.Secret.EncryptionKeyFile != "" {
|
||||
secretKey, err := os.ReadFile(cfg.Secret.EncryptionKeyFile)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("read SECRET_ENCRYPTION_KEY_FILE: %w", err)
|
||||
}
|
||||
cfg.Secret.EncryptionKeyBase64 = strings.TrimSpace(string(secretKey))
|
||||
}
|
||||
if cfg.Secret.EncryptionKeyBase64 != "" {
|
||||
decoded, err := base64.StdEncoding.DecodeString(cfg.Secret.EncryptionKeyBase64)
|
||||
if err != nil {
|
||||
if decodedRaw, rawErr := base64.RawStdEncoding.DecodeString(cfg.Secret.EncryptionKeyBase64); rawErr == nil {
|
||||
decoded = decodedRaw
|
||||
} else {
|
||||
return Config{}, fmt.Errorf("SECRET_ENCRYPTION_KEY_B64 must be base64 encoded: %w", err)
|
||||
}
|
||||
}
|
||||
if len(decoded) != 32 {
|
||||
return Config{}, fmt.Errorf("SECRET_ENCRYPTION_KEY_B64 must decode to 32 bytes for AES-256-GCM")
|
||||
}
|
||||
}
|
||||
if isProductionEnv(cfg.App.Env) && cfg.Secret.EncryptionKeyBase64 == "" {
|
||||
return Config{}, fmt.Errorf("SECRET_ENCRYPTION_KEY_B64 or SECRET_ENCRYPTION_KEY_FILE is required in production")
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func normalizeInstallationAuthorityMode(mode string, rootPublicKey string) string {
|
||||
mode = strings.ToLower(strings.TrimSpace(mode))
|
||||
switch mode {
|
||||
case "strict", "legacy":
|
||||
return mode
|
||||
case "":
|
||||
if strings.TrimSpace(rootPublicKey) != "" {
|
||||
return "strict"
|
||||
}
|
||||
return "legacy"
|
||||
default:
|
||||
return mode
|
||||
}
|
||||
}
|
||||
|
||||
func isProductionEnv(appEnv string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(appEnv)) {
|
||||
case "production", "prod":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func getEnv(key, fallback string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func getInt(key string, fallback int) int {
|
||||
value := os.Getenv(key)
|
||||
if value == "" {
|
||||
return fallback
|
||||
}
|
||||
|
||||
parsed, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
|
||||
return parsed
|
||||
}
|
||||
|
||||
func getBool(key string, fallback bool) bool {
|
||||
value := os.Getenv(key)
|
||||
if value == "" {
|
||||
return fallback
|
||||
}
|
||||
switch value {
|
||||
case "1", "true", "TRUE", "yes", "on":
|
||||
return true
|
||||
case "0", "false", "FALSE", "no", "off":
|
||||
return false
|
||||
default:
|
||||
return fallback
|
||||
}
|
||||
}
|
||||
|
||||
func getDuration(key string, fallback time.Duration) time.Duration {
|
||||
value := os.Getenv(key)
|
||||
if value == "" {
|
||||
return fallback
|
||||
}
|
||||
|
||||
parsed, err := time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
|
||||
return parsed
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package httpserver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/config"
|
||||
)
|
||||
|
||||
func New(cfg config.HTTPConfig, handler http.Handler) *http.Server {
|
||||
return &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||
Handler: handler,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
ReadTimeout: cfg.ReadTimeout,
|
||||
WriteTimeout: cfg.WriteTimeout,
|
||||
IdleTimeout: cfg.IdleTimeout,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func WriteJSON(w http.ResponseWriter, status int, payload any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(payload)
|
||||
}
|
||||
|
||||
func WriteError(w http.ResponseWriter, status int, message string) {
|
||||
traceID := ensureTraceID(w)
|
||||
WriteJSON(w, status, ErrorResponse{
|
||||
Error: NewErrorMessage(status, message, nil, traceID),
|
||||
})
|
||||
}
|
||||
|
||||
func WriteErrorMessage(w http.ResponseWriter, status int, message any) {
|
||||
traceID := ensureTraceID(w)
|
||||
switch payload := message.(type) {
|
||||
case string:
|
||||
WriteJSON(w, status, ErrorResponse{
|
||||
Error: NewErrorMessage(status, payload, nil, traceID),
|
||||
})
|
||||
case ErrorResponse:
|
||||
payload.Error.TraceID = traceID
|
||||
WriteJSON(w, status, payload)
|
||||
case *ErrorResponse:
|
||||
if payload == nil {
|
||||
WriteJSON(w, status, ErrorResponse{
|
||||
Error: NewErrorMessage(status, "", nil, traceID),
|
||||
})
|
||||
return
|
||||
}
|
||||
payload.Error.TraceID = traceID
|
||||
WriteJSON(w, status, payload)
|
||||
default:
|
||||
WriteJSON(w, status, ErrorResponse{
|
||||
Error: NewErrorMessage(status, "Request failed.", nil, traceID),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
messagecontracts "github.com/example/remote-access-platform/backend/pkg/contracts/message"
|
||||
)
|
||||
|
||||
type ErrorResponse struct {
|
||||
Error messagecontracts.Message `json:"error"`
|
||||
}
|
||||
|
||||
func NewMessage(code, messageKey, fallbackMessage string, details map[string]any, traceID string) messagecontracts.Message {
|
||||
if traceID == "" {
|
||||
traceID = uuid.NewString()
|
||||
}
|
||||
if details == nil {
|
||||
details = map[string]any{}
|
||||
}
|
||||
return messagecontracts.Message{
|
||||
Code: code,
|
||||
MessageKey: messageKey,
|
||||
FallbackMessage: fallbackMessage,
|
||||
Details: details,
|
||||
TraceID: traceID,
|
||||
}
|
||||
}
|
||||
|
||||
func NewErrorMessage(status int, fallbackMessage string, details map[string]any, traceID string) messagecontracts.Message {
|
||||
normalizedFallback, normalizedDetails := normalizeErrorFallback(status, fallbackMessage, details)
|
||||
code := deriveErrorCode(status, normalizedFallback)
|
||||
return NewMessage(code, "errors."+code, normalizedFallback, normalizedDetails, traceID)
|
||||
}
|
||||
|
||||
func ensureTraceID(w http.ResponseWriter) string {
|
||||
traceID := w.Header().Get("X-Trace-Id")
|
||||
if traceID == "" {
|
||||
traceID = uuid.NewString()
|
||||
w.Header().Set("X-Trace-Id", traceID)
|
||||
}
|
||||
return traceID
|
||||
}
|
||||
|
||||
func normalizeErrorFallback(status int, fallbackMessage string, details map[string]any) (string, map[string]any) {
|
||||
if details == nil {
|
||||
details = map[string]any{}
|
||||
}
|
||||
details["http_status"] = status
|
||||
|
||||
if status >= http.StatusInternalServerError {
|
||||
return "An internal server error occurred.", details
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(fallbackMessage)
|
||||
switch strings.ToLower(trimmed) {
|
||||
case "forbidden", "access denied":
|
||||
return "Access denied.", details
|
||||
}
|
||||
|
||||
if field, ok := extractRequiredField(trimmed); ok {
|
||||
details["field"] = field
|
||||
}
|
||||
|
||||
return trimmed, details
|
||||
}
|
||||
|
||||
func deriveErrorCode(status int, fallbackMessage string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(fallbackMessage)) {
|
||||
case "invalid credentials":
|
||||
return "auth.invalid_credentials"
|
||||
case "session expired. please sign in again.":
|
||||
return "auth.session_expired"
|
||||
case "access denied.":
|
||||
return "common.access_denied"
|
||||
}
|
||||
|
||||
statusPrefix := map[int]string{
|
||||
http.StatusBadRequest: "bad_request",
|
||||
http.StatusUnauthorized: "unauthorized",
|
||||
http.StatusForbidden: "forbidden",
|
||||
http.StatusNotFound: "not_found",
|
||||
http.StatusConflict: "conflict",
|
||||
http.StatusUnprocessableEntity: "unprocessable_entity",
|
||||
http.StatusInternalServerError: "internal_server_error",
|
||||
}[status]
|
||||
if statusPrefix == "" {
|
||||
statusPrefix = "http_" + strings.ReplaceAll(http.StatusText(status), " ", "_")
|
||||
statusPrefix = strings.ToLower(statusPrefix)
|
||||
}
|
||||
|
||||
slug := slugifyMessage(fallbackMessage)
|
||||
if slug == "" {
|
||||
slug = "message"
|
||||
}
|
||||
if status >= http.StatusInternalServerError {
|
||||
return "common." + statusPrefix
|
||||
}
|
||||
return statusPrefix + "." + slug
|
||||
}
|
||||
|
||||
func slugifyMessage(input string) string {
|
||||
var builder strings.Builder
|
||||
lastUnderscore := false
|
||||
for _, r := range strings.ToLower(strings.TrimSpace(input)) {
|
||||
if unicode.IsLetter(r) || unicode.IsDigit(r) {
|
||||
builder.WriteRune(r)
|
||||
lastUnderscore = false
|
||||
continue
|
||||
}
|
||||
if !lastUnderscore {
|
||||
builder.WriteRune('_')
|
||||
lastUnderscore = true
|
||||
}
|
||||
}
|
||||
return strings.Trim(builder.String(), "_")
|
||||
}
|
||||
|
||||
func extractRequiredField(message string) (string, bool) {
|
||||
const suffix = " is required"
|
||||
if !strings.HasSuffix(strings.ToLower(message), suffix) {
|
||||
return "", false
|
||||
}
|
||||
field := strings.TrimSpace(message[:len(message)-len(suffix)])
|
||||
field = strings.ReplaceAll(field, " ", "_")
|
||||
field = strings.ToLower(field)
|
||||
return field, field != ""
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
)
|
||||
|
||||
func New(env string) *slog.Logger {
|
||||
level := slog.LevelInfo
|
||||
if env == "development" {
|
||||
level = slog.LevelDebug
|
||||
}
|
||||
|
||||
return slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
}))
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package module
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/config"
|
||||
)
|
||||
|
||||
type Dependencies struct {
|
||||
Config Config
|
||||
Infra Infra
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
App config.AppConfig
|
||||
Auth config.AuthConfig
|
||||
Installation config.InstallationConfig
|
||||
DataPlane config.DataPlaneConfig
|
||||
Secret config.SecretConfig
|
||||
Session config.SessionConfig
|
||||
Worker config.WorkerConfig
|
||||
WebSocket config.WebSocketConfig
|
||||
}
|
||||
|
||||
type Infra struct {
|
||||
Logger *slog.Logger
|
||||
DB *pgxpool.Pool
|
||||
Redis *redis.Client
|
||||
}
|
||||
|
||||
type Module interface {
|
||||
Name() string
|
||||
RegisterRoutes(router chi.Router)
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/config"
|
||||
)
|
||||
|
||||
func Open(ctx context.Context, cfg config.PostgresConfig) (*pgxpool.Pool, error) {
|
||||
poolConfig, err := pgxpool.ParseConfig(cfg.DSN)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse postgres dsn: %w", err)
|
||||
}
|
||||
|
||||
poolConfig.MaxConns = cfg.MaxConns
|
||||
poolConfig.MinConns = cfg.MinConns
|
||||
poolConfig.ConnConfig.ConnectTimeout = cfg.ConnectTimeout
|
||||
|
||||
pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create postgres pool: %w", err)
|
||||
}
|
||||
|
||||
if err := pool.Ping(ctx); err != nil {
|
||||
pool.Close()
|
||||
return nil, fmt.Errorf("ping postgres: %w", err)
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type DBTX interface {
|
||||
Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
|
||||
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
|
||||
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
|
||||
}
|
||||
|
||||
func WithTransaction(ctx context.Context, pool *pgxpool.Pool, fn func(tx pgx.Tx) error) error {
|
||||
tx, err := pool.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
_ = tx.Rollback(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
goredis "github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/config"
|
||||
)
|
||||
|
||||
func Open(ctx context.Context, cfg config.RedisConfig) (*goredis.Client, error) {
|
||||
client := goredis.NewClient(&goredis.Options{
|
||||
Addr: cfg.Addr,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
DialTimeout: cfg.DialTimeout,
|
||||
})
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, fmt.Errorf("ping redis: %w", err)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
@@ -0,0 +1,220 @@
|
||||
package runtime
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
chimiddleware "github.com/go-chi/chi/v5/middleware"
|
||||
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/auth"
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/cluster"
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/identitysource"
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/node"
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/nodeagent"
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/organization"
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/resource"
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/sessionbroker"
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/sessiongateway"
|
||||
"github.com/example/remote-access-platform/backend/internal/modules/worker"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/authority"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/config"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/httpserver"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/logging"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/module"
|
||||
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
|
||||
redisplatform "github.com/example/remote-access-platform/backend/internal/platform/redis"
|
||||
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
|
||||
)
|
||||
|
||||
type App struct {
|
||||
cfg config.Config
|
||||
logger *slog.Logger
|
||||
httpServer *http.Server
|
||||
workers []backgroundRunner
|
||||
db closeFunc
|
||||
redis closeFunc
|
||||
}
|
||||
|
||||
type closeFunc func() error
|
||||
type backgroundRunner func(context.Context) error
|
||||
|
||||
func NewApp(ctx context.Context) (*App, error) {
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
logger := logging.New(cfg.App.Env)
|
||||
|
||||
db, err := postgresplatform.Open(ctx, cfg.Postgres)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
redisClient, err := redisplatform.Open(ctx, cfg.Redis)
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authorityVerifier, err := authority.NewVerifier(cfg.Installation)
|
||||
if err != nil {
|
||||
redisClient.Close()
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("create installation authority verifier: %w", err)
|
||||
}
|
||||
|
||||
deps := module.Dependencies{
|
||||
Config: module.Config{
|
||||
App: cfg.App,
|
||||
Auth: cfg.Auth,
|
||||
Installation: cfg.Installation,
|
||||
DataPlane: cfg.DataPlane,
|
||||
Secret: cfg.Secret,
|
||||
Session: cfg.Session,
|
||||
Worker: cfg.Worker,
|
||||
WebSocket: cfg.WebSocket,
|
||||
},
|
||||
Infra: module.Infra{
|
||||
Logger: logger,
|
||||
DB: db,
|
||||
Redis: redisClient,
|
||||
},
|
||||
}
|
||||
|
||||
workerStore := worker.NewRedisStore(redisClient)
|
||||
workerService := worker.NewService(deps, workerStore)
|
||||
authStore := auth.NewPostgresStore(db)
|
||||
authTx := auth.NewPostgresTransactor(db)
|
||||
authService := auth.NewService(deps, authStore, authTx, authorityVerifier)
|
||||
var resourceSecretStore *secrets.ResourceSecretStore
|
||||
if cfg.Secret.EncryptionKeyBase64 != "" {
|
||||
secretEncryptor, err := secrets.NewEncryptor(cfg.Secret.EncryptionKeyBase64, cfg.Secret.EncryptionKeyID)
|
||||
if err != nil {
|
||||
redisClient.Close()
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("create resource secret encryptor: %w", err)
|
||||
}
|
||||
resourceSecretStore = secrets.NewResourceSecretStore(db, secretEncryptor)
|
||||
}
|
||||
|
||||
brokerStore := sessionbroker.NewPostgresStore(db, authorityVerifier)
|
||||
brokerTx := sessionbroker.NewPostgresTransactor(db, authorityVerifier)
|
||||
liveStateStore := sessionbroker.NewRedisLiveStateStore(redisClient)
|
||||
brokerService := sessionbroker.NewService(deps, brokerStore, brokerTx, liveStateStore, workerService, resourceSecretStore)
|
||||
workerEvents := worker.NewEventProcessor(redisClient, brokerService)
|
||||
leaseMonitor := worker.NewLeaseMonitor(workerService, brokerService, cfg.Worker.StaleLeaseGracePeriod)
|
||||
|
||||
brokerModule := sessionbroker.NewModule(brokerService)
|
||||
authModule := auth.NewModule(deps, authService)
|
||||
clusterModule := cluster.NewModule(deps, authorityVerifier)
|
||||
organizationModule := organization.NewModule(deps)
|
||||
identitySourceModule := identitysource.NewModule(deps)
|
||||
resourceModule := resource.NewModule(deps, resourceSecretStore)
|
||||
nodeModule := node.NewModule(deps)
|
||||
nodeAgentModule := nodeagent.NewModule(deps)
|
||||
sessionGatewayModule := sessiongateway.NewModule(deps, brokerModule.Service(), workerService)
|
||||
|
||||
router := buildRouter(
|
||||
logger,
|
||||
authModule,
|
||||
clusterModule,
|
||||
organizationModule,
|
||||
identitySourceModule,
|
||||
resourceModule,
|
||||
brokerModule,
|
||||
nodeModule,
|
||||
nodeAgentModule,
|
||||
sessionGatewayModule,
|
||||
)
|
||||
|
||||
return &App{
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
httpServer: httpserver.New(cfg.HTTP, router),
|
||||
workers: []backgroundRunner{workerEvents.Run, leaseMonitor.Run},
|
||||
db: func() error {
|
||||
db.Close()
|
||||
return nil
|
||||
},
|
||||
redis: redisClient.Close,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *App) Run(ctx context.Context) error {
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
a.logger.Info("http server starting", "addr", a.httpServer.Addr, "service", a.cfg.App.Name)
|
||||
if err := a.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
errCh <- nil
|
||||
}()
|
||||
|
||||
for _, runner := range a.workers {
|
||||
runner := runner
|
||||
go func() {
|
||||
if err := runner(ctx); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
a.logger.Info("shutdown signal received")
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), a.cfg.HTTP.ShutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := a.httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
return fmt.Errorf("shutdown http server: %w", err)
|
||||
}
|
||||
|
||||
if err := a.redis(); err != nil {
|
||||
return fmt.Errorf("close redis: %w", err)
|
||||
}
|
||||
|
||||
if err := a.db(); err != nil {
|
||||
return fmt.Errorf("close postgres: %w", err)
|
||||
}
|
||||
|
||||
a.logger.Info("app stopped", "at", time.Now().UTC())
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildRouter(logger *slog.Logger, modules ...module.Module) http.Handler {
|
||||
router := chi.NewRouter()
|
||||
router.Use(chimiddleware.RequestID)
|
||||
router.Use(chimiddleware.RealIP)
|
||||
router.Use(chimiddleware.Recoverer)
|
||||
router.Use(chimiddleware.Timeout(60 * time.Second))
|
||||
router.Use(chimiddleware.Heartbeat("/healthz"))
|
||||
|
||||
router.Get("/readyz", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ready"))
|
||||
})
|
||||
|
||||
router.Route("/api/v1", func(r chi.Router) {
|
||||
for _, mod := range modules {
|
||||
logger.Info("register module routes", "module", mod.Name())
|
||||
mod.RegisterRoutes(r)
|
||||
}
|
||||
})
|
||||
|
||||
return router
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package secrets
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type AssignmentSecretMergeResult struct {
|
||||
Metadata map[string]any
|
||||
Keys []string
|
||||
}
|
||||
|
||||
func MergeResourceSecretIntoAssignmentMetadata(metadata map[string]any, payload json.RawMessage) (AssignmentSecretMergeResult, error) {
|
||||
if metadata == nil {
|
||||
metadata = map[string]any{}
|
||||
}
|
||||
var secretPayload map[string]any
|
||||
if err := json.Unmarshal(payload, &secretPayload); err != nil {
|
||||
return AssignmentSecretMergeResult{}, fmt.Errorf("decode resolved resource secret: %w", err)
|
||||
}
|
||||
resource, _ := metadata["resource"].(map[string]any)
|
||||
if resource == nil {
|
||||
resource = map[string]any{}
|
||||
metadata["resource"] = resource
|
||||
}
|
||||
resourceMetadata, _ := resource["metadata"].(map[string]any)
|
||||
if resourceMetadata == nil {
|
||||
resourceMetadata = map[string]any{}
|
||||
resource["metadata"] = resourceMetadata
|
||||
}
|
||||
keys := make([]string, 0, len(secretPayload))
|
||||
for key, value := range secretPayload {
|
||||
resourceMetadata[key] = value
|
||||
keys = append(keys, key)
|
||||
}
|
||||
return AssignmentSecretMergeResult{Metadata: metadata, Keys: keys}, nil
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
package secrets
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const AlgorithmAES256GCM = "AES-256-GCM"
|
||||
|
||||
var (
|
||||
ErrSecretEncryptionKeyMissing = errors.New("secret encryption key is not configured")
|
||||
ErrSecretPayloadInvalid = errors.New("secret payload must be a json object")
|
||||
)
|
||||
|
||||
type Encryptor struct {
|
||||
aead cipher.AEAD
|
||||
keyID string
|
||||
}
|
||||
|
||||
type EncryptedPayload struct {
|
||||
Algorithm string
|
||||
KeyID string
|
||||
Nonce []byte
|
||||
Ciphertext []byte
|
||||
PayloadSHA256 string
|
||||
}
|
||||
|
||||
func NewEncryptor(masterKeyBase64, keyID string) (*Encryptor, error) {
|
||||
masterKeyBase64 = strings.TrimSpace(masterKeyBase64)
|
||||
if masterKeyBase64 == "" {
|
||||
return nil, ErrSecretEncryptionKeyMissing
|
||||
}
|
||||
key, err := base64.StdEncoding.DecodeString(masterKeyBase64)
|
||||
if err != nil {
|
||||
if rawKey, rawErr := base64.RawStdEncoding.DecodeString(masterKeyBase64); rawErr == nil {
|
||||
key = rawKey
|
||||
} else {
|
||||
return nil, fmt.Errorf("decode secret encryption key: %w", err)
|
||||
}
|
||||
}
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("secret encryption key must decode to 32 bytes")
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create secret cipher: %w", err)
|
||||
}
|
||||
aead, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create secret gcm: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(keyID) == "" {
|
||||
keyID = "local-v1"
|
||||
}
|
||||
return &Encryptor{aead: aead, keyID: keyID}, nil
|
||||
}
|
||||
|
||||
func (e *Encryptor) KeyID() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return e.keyID
|
||||
}
|
||||
|
||||
func (e *Encryptor) Encrypt(plaintext, aad []byte) (EncryptedPayload, error) {
|
||||
if e == nil {
|
||||
return EncryptedPayload{}, ErrSecretEncryptionKeyMissing
|
||||
}
|
||||
nonce := make([]byte, e.aead.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return EncryptedPayload{}, fmt.Errorf("generate secret nonce: %w", err)
|
||||
}
|
||||
hash := sha256.Sum256(plaintext)
|
||||
return EncryptedPayload{
|
||||
Algorithm: AlgorithmAES256GCM,
|
||||
KeyID: e.keyID,
|
||||
Nonce: nonce,
|
||||
Ciphertext: e.aead.Seal(nil, nonce, plaintext, aad),
|
||||
PayloadSHA256: hex.EncodeToString(hash[:]),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *Encryptor) Decrypt(payload EncryptedPayload, aad []byte) ([]byte, error) {
|
||||
if e == nil {
|
||||
return nil, ErrSecretEncryptionKeyMissing
|
||||
}
|
||||
if payload.Algorithm != "" && payload.Algorithm != AlgorithmAES256GCM {
|
||||
return nil, fmt.Errorf("unsupported secret algorithm %q", payload.Algorithm)
|
||||
}
|
||||
plaintext, err := e.aead.Open(nil, payload.Nonce, payload.Ciphertext, aad)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt secret payload: %w", err)
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
func ResourceSecretAAD(organizationID, resourceID, secretRef, protocol string) []byte {
|
||||
return []byte(strings.Join([]string{
|
||||
"rap-resource-secret-v1",
|
||||
strings.TrimSpace(organizationID),
|
||||
strings.TrimSpace(resourceID),
|
||||
strings.TrimSpace(secretRef),
|
||||
strings.ToLower(strings.TrimSpace(protocol)),
|
||||
}, "|"))
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package secrets
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEncryptorRoundTrip(t *testing.T) {
|
||||
key := base64.StdEncoding.EncodeToString([]byte("0123456789abcdef0123456789abcdef"))
|
||||
encryptor, err := NewEncryptor(key, "test-key")
|
||||
if err != nil {
|
||||
t.Fatalf("NewEncryptor returned error: %v", err)
|
||||
}
|
||||
aad := ResourceSecretAAD("org-1", "resource-1", "rap-secret://test", "rdp")
|
||||
encrypted, err := encryptor.Encrypt([]byte(`{"username":"user","password":"secret"}`), aad)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt returned error: %v", err)
|
||||
}
|
||||
plaintext, err := encryptor.Decrypt(encrypted, aad)
|
||||
if err != nil {
|
||||
t.Fatalf("Decrypt returned error: %v", err)
|
||||
}
|
||||
if string(plaintext) != `{"username":"user","password":"secret"}` {
|
||||
t.Fatalf("unexpected plaintext: %s", plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptorRejectsWrongAAD(t *testing.T) {
|
||||
key := base64.StdEncoding.EncodeToString([]byte("0123456789abcdef0123456789abcdef"))
|
||||
encryptor, err := NewEncryptor(key, "test-key")
|
||||
if err != nil {
|
||||
t.Fatalf("NewEncryptor returned error: %v", err)
|
||||
}
|
||||
encrypted, err := encryptor.Encrypt([]byte(`{"password":"secret"}`), ResourceSecretAAD("org-1", "resource-1", "ref", "rdp"))
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt returned error: %v", err)
|
||||
}
|
||||
if _, err := encryptor.Decrypt(encrypted, ResourceSecretAAD("org-2", "resource-1", "ref", "rdp")); err == nil {
|
||||
t.Fatalf("expected decrypt with wrong aad to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeResourceSecretIntoAssignmentMetadata(t *testing.T) {
|
||||
metadata := map[string]any{
|
||||
"resource": map[string]any{
|
||||
"id": "resource-1",
|
||||
"metadata": map[string]any{
|
||||
"rdp_host": "host",
|
||||
},
|
||||
},
|
||||
}
|
||||
merged, err := MergeResourceSecretIntoAssignmentMetadata(metadata, json.RawMessage(`{"username":"user","password":"secret","domain":"corp"}`))
|
||||
if err != nil {
|
||||
t.Fatalf("MergeResourceSecretIntoAssignmentMetadata returned error: %v", err)
|
||||
}
|
||||
resource := merged.Metadata["resource"].(map[string]any)
|
||||
resourceMetadata := resource["metadata"].(map[string]any)
|
||||
if resourceMetadata["rdp_host"] != "host" {
|
||||
t.Fatalf("existing metadata was not preserved")
|
||||
}
|
||||
if resourceMetadata["username"] != "user" || resourceMetadata["password"] != "secret" || resourceMetadata["domain"] != "corp" {
|
||||
t.Fatalf("secret payload was not merged: %#v", resourceMetadata)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
package secrets
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPlaintextResourceCredentials = errors.New("plaintext resource credentials are not allowed in metadata in production")
|
||||
ErrMissingResourceSecretRef = errors.New("secret_ref is required for this resource protocol in production")
|
||||
)
|
||||
|
||||
var credentialKeyFragments = []string{
|
||||
"accesstoken",
|
||||
"clientsecret",
|
||||
"credential",
|
||||
"credentials",
|
||||
"domain",
|
||||
"password",
|
||||
"privatekey",
|
||||
"refreshtoken",
|
||||
"secret",
|
||||
"secrets",
|
||||
"token",
|
||||
"user",
|
||||
"username",
|
||||
}
|
||||
|
||||
var safeReferenceKeys = []string{
|
||||
"certificateverificationmode",
|
||||
"renderqualityprofile",
|
||||
"secretref",
|
||||
"secretreference",
|
||||
"vaultref",
|
||||
}
|
||||
|
||||
func ValidateResourceSecretReadiness(protocol string, secretRef *string, metadata json.RawMessage, appEnv string) error {
|
||||
if !IsProductionEnv(appEnv) {
|
||||
return nil
|
||||
}
|
||||
|
||||
paths, err := PlaintextCredentialMetadataPaths(metadata)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(paths) > 0 {
|
||||
return fmt.Errorf("%w: %s", ErrPlaintextResourceCredentials, strings.Join(paths, ", "))
|
||||
}
|
||||
if ResourceProtocolRequiresSecretRef(protocol) && (secretRef == nil || strings.TrimSpace(*secretRef) == "") {
|
||||
return ErrMissingResourceSecretRef
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func IsProductionEnv(appEnv string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(appEnv)) {
|
||||
case "prod", "production":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func ResourceProtocolRequiresSecretRef(protocol string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(protocol)) {
|
||||
case "rdp", "vnc", "ssh":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func PlaintextCredentialMetadataPaths(raw json.RawMessage) ([]string, error) {
|
||||
if len(raw) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var value any
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
return nil, errors.New("metadata must be valid json")
|
||||
}
|
||||
metadata, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("metadata must be a json object")
|
||||
}
|
||||
var paths []string
|
||||
collectCredentialPaths(metadata, "", &paths)
|
||||
sort.Strings(paths)
|
||||
return slices.Compact(paths), nil
|
||||
}
|
||||
|
||||
func collectCredentialPaths(value any, prefix string, paths *[]string) {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
for key, child := range typed {
|
||||
path := key
|
||||
if prefix != "" {
|
||||
path = prefix + "." + key
|
||||
}
|
||||
if isCredentialMetadataKey(key) {
|
||||
*paths = append(*paths, path)
|
||||
}
|
||||
collectCredentialPaths(child, path, paths)
|
||||
}
|
||||
case []any:
|
||||
for index, child := range typed {
|
||||
collectCredentialPaths(child, fmt.Sprintf("%s[%d]", prefix, index), paths)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isCredentialMetadataKey(key string) bool {
|
||||
normalized := normalizeMetadataKey(key)
|
||||
if slices.Contains(safeReferenceKeys, normalized) {
|
||||
return false
|
||||
}
|
||||
for _, fragment := range credentialKeyFragments {
|
||||
if normalized == fragment || strings.HasSuffix(normalized, fragment) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeMetadataKey(key string) string {
|
||||
key = strings.ToLower(strings.TrimSpace(key))
|
||||
replacer := strings.NewReplacer("_", "", "-", "", " ", "", ".", "")
|
||||
return replacer.Replace(key)
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package secrets
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidateResourceSecretReadinessAllowsPlaintextInDevelopment(t *testing.T) {
|
||||
metadata := json.RawMessage(`{"username":"m","password":"secret"}`)
|
||||
if err := ValidateResourceSecretReadiness("rdp", nil, metadata, "development"); err != nil {
|
||||
t.Fatalf("development metadata should remain allowed for smoke/dev: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateResourceSecretReadinessRejectsPlaintextCredentialsInProduction(t *testing.T) {
|
||||
metadata := json.RawMessage(`{"rdp_host":"host","credentials":{"username":"m","password":"secret"}}`)
|
||||
err := ValidateResourceSecretReadiness("rdp", stringPtr("vault://org/resource"), metadata, "production")
|
||||
if !errors.Is(err, ErrPlaintextResourceCredentials) {
|
||||
t.Fatalf("expected plaintext credential rejection, got %v", err)
|
||||
}
|
||||
|
||||
paths, err := PlaintextCredentialMetadataPaths(metadata)
|
||||
if err != nil {
|
||||
t.Fatalf("metadata paths: %v", err)
|
||||
}
|
||||
for _, expected := range []string{"credentials", "credentials.password", "credentials.username"} {
|
||||
if !slices.Contains(paths, expected) {
|
||||
t.Fatalf("expected sensitive path %q in %v", expected, paths)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateResourceSecretReadinessRequiresSecretRefForProductionRDP(t *testing.T) {
|
||||
metadata := json.RawMessage(`{"rdp_host":"host","rdp_port":3389}`)
|
||||
err := ValidateResourceSecretReadiness("rdp", nil, metadata, "production")
|
||||
if !errors.Is(err, ErrMissingResourceSecretRef) {
|
||||
t.Fatalf("expected missing secret_ref rejection, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateResourceSecretReadinessAllowsProductionSecretRef(t *testing.T) {
|
||||
metadata := json.RawMessage(`{"rdp_host":"host","rdp_port":3389,"secret_ref":"vault://org/resource"}`)
|
||||
if err := ValidateResourceSecretReadiness("rdp", stringPtr("vault://org/resource"), metadata, "production"); err != nil {
|
||||
t.Fatalf("production secret_ref metadata should be accepted: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func stringPtr(value string) *string {
|
||||
return &value
|
||||
}
|
||||
@@ -0,0 +1,259 @@
|
||||
package secrets
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrResourceSecretNotFound = errors.New("resource secret not found")
|
||||
ErrSecretAccessDenied = errors.New("resource secret access denied")
|
||||
ErrSecretLeaseRequired = errors.New("resource secret resolution requires lease proof")
|
||||
)
|
||||
|
||||
type ResourceSecretStore struct {
|
||||
db postgresplatform.DBTX
|
||||
encryptor *Encryptor
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
type ResourceSecretResolver interface {
|
||||
ResolveForSession(ctx context.Context, req ResolveResourceSecretRequest) (*ResolvedResourceSecret, error)
|
||||
}
|
||||
|
||||
type ResourceSecretDescriptor struct {
|
||||
ID string `json:"id"`
|
||||
OrganizationID string `json:"organization_id"`
|
||||
ResourceID string `json:"resource_id"`
|
||||
SecretRef string `json:"secret_ref"`
|
||||
Protocol string `json:"protocol"`
|
||||
Version int `json:"version"`
|
||||
KeyID string `json:"key_id"`
|
||||
Algorithm string `json:"algorithm"`
|
||||
Metadata json.RawMessage `json:"metadata"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
RotatedAt *time.Time `json:"rotated_at,omitempty"`
|
||||
}
|
||||
|
||||
type UpsertResourceSecretCommand struct {
|
||||
OrganizationID string
|
||||
ResourceID string
|
||||
Protocol string
|
||||
SecretRef string
|
||||
Payload json.RawMessage
|
||||
Metadata json.RawMessage
|
||||
ActorUserID string
|
||||
}
|
||||
|
||||
type ResolveResourceSecretRequest struct {
|
||||
SecretRef string
|
||||
OrganizationID string
|
||||
ResourceID string
|
||||
SessionID string
|
||||
WorkerID string
|
||||
LeaseID string
|
||||
}
|
||||
|
||||
type ResolvedResourceSecret struct {
|
||||
Descriptor ResourceSecretDescriptor
|
||||
Payload json.RawMessage
|
||||
}
|
||||
|
||||
func NewResourceSecretStore(db postgresplatform.DBTX, encryptor *Encryptor) *ResourceSecretStore {
|
||||
return &ResourceSecretStore{db: db, encryptor: encryptor, now: time.Now}
|
||||
}
|
||||
|
||||
func (s *ResourceSecretStore) WithDB(db postgresplatform.DBTX) *ResourceSecretStore {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return &ResourceSecretStore{db: db, encryptor: s.encryptor, now: s.now}
|
||||
}
|
||||
|
||||
func DefaultResourceSecretRef(organizationID, resourceID string) string {
|
||||
return "rap-secret://org/" + strings.TrimSpace(organizationID) + "/resources/" + strings.TrimSpace(resourceID) + "/primary"
|
||||
}
|
||||
|
||||
func (s *ResourceSecretStore) Upsert(ctx context.Context, cmd UpsertResourceSecretCommand) (*ResourceSecretDescriptor, error) {
|
||||
if s == nil || s.encryptor == nil {
|
||||
return nil, ErrSecretEncryptionKeyMissing
|
||||
}
|
||||
payload, err := normalizeJSONObject(cmd.Payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
metadata, err := normalizeJSONObjectAllowEmpty(cmd.Metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
secretRef := strings.TrimSpace(cmd.SecretRef)
|
||||
if secretRef == "" {
|
||||
secretRef = DefaultResourceSecretRef(cmd.OrganizationID, cmd.ResourceID)
|
||||
}
|
||||
protocol := strings.ToLower(strings.TrimSpace(cmd.Protocol))
|
||||
encrypted, err := s.encryptor.Encrypt(payload, ResourceSecretAAD(cmd.OrganizationID, cmd.ResourceID, secretRef, protocol))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
now := s.now().UTC()
|
||||
const query = `
|
||||
INSERT INTO resource_secrets (
|
||||
organization_id, resource_id, secret_ref, protocol, version, key_id,
|
||||
algorithm, nonce, ciphertext, payload_sha256, metadata, created_by_user_id,
|
||||
created_at, rotated_at
|
||||
) VALUES (
|
||||
$1::uuid, $2::uuid, $3, $4, 1, $5,
|
||||
$6, $7, $8, $9, $10::jsonb, NULLIF($11, '')::uuid,
|
||||
$12, NULL
|
||||
)
|
||||
ON CONFLICT (resource_id) DO UPDATE SET
|
||||
secret_ref = EXCLUDED.secret_ref,
|
||||
protocol = EXCLUDED.protocol,
|
||||
version = resource_secrets.version + 1,
|
||||
key_id = EXCLUDED.key_id,
|
||||
algorithm = EXCLUDED.algorithm,
|
||||
nonce = EXCLUDED.nonce,
|
||||
ciphertext = EXCLUDED.ciphertext,
|
||||
payload_sha256 = EXCLUDED.payload_sha256,
|
||||
metadata = EXCLUDED.metadata,
|
||||
created_by_user_id = EXCLUDED.created_by_user_id,
|
||||
rotated_at = EXCLUDED.created_at
|
||||
RETURNING id::text, organization_id::text, resource_id::text, secret_ref,
|
||||
protocol, version, key_id, algorithm, metadata, created_at, rotated_at
|
||||
`
|
||||
var descriptor ResourceSecretDescriptor
|
||||
if err := s.db.QueryRow(ctx, query,
|
||||
cmd.OrganizationID,
|
||||
cmd.ResourceID,
|
||||
secretRef,
|
||||
protocol,
|
||||
encrypted.KeyID,
|
||||
encrypted.Algorithm,
|
||||
encrypted.Nonce,
|
||||
encrypted.Ciphertext,
|
||||
encrypted.PayloadSHA256,
|
||||
metadata,
|
||||
cmd.ActorUserID,
|
||||
now,
|
||||
).Scan(
|
||||
&descriptor.ID,
|
||||
&descriptor.OrganizationID,
|
||||
&descriptor.ResourceID,
|
||||
&descriptor.SecretRef,
|
||||
&descriptor.Protocol,
|
||||
&descriptor.Version,
|
||||
&descriptor.KeyID,
|
||||
&descriptor.Algorithm,
|
||||
&descriptor.Metadata,
|
||||
&descriptor.CreatedAt,
|
||||
&descriptor.RotatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("upsert resource secret: %w", err)
|
||||
}
|
||||
return &descriptor, nil
|
||||
}
|
||||
|
||||
func (s *ResourceSecretStore) ResolveForSession(ctx context.Context, req ResolveResourceSecretRequest) (*ResolvedResourceSecret, error) {
|
||||
if s == nil || s.encryptor == nil {
|
||||
return nil, ErrSecretEncryptionKeyMissing
|
||||
}
|
||||
if strings.TrimSpace(req.LeaseID) == "" {
|
||||
return nil, ErrSecretLeaseRequired
|
||||
}
|
||||
const query = `
|
||||
SELECT sec.id::text, sec.organization_id::text, sec.resource_id::text, sec.secret_ref,
|
||||
sec.protocol, sec.version, sec.key_id, sec.algorithm, sec.metadata,
|
||||
sec.created_at, sec.rotated_at, sec.nonce, sec.ciphertext,
|
||||
rs.organization_id::text, rs.resource_id::text, COALESCE(rs.worker_id, ''), rs.state
|
||||
FROM resource_secrets sec
|
||||
JOIN remote_sessions rs ON rs.resource_id = sec.resource_id
|
||||
WHERE sec.secret_ref = $1 AND rs.id = $2::uuid
|
||||
`
|
||||
var descriptor ResourceSecretDescriptor
|
||||
var nonce, ciphertext []byte
|
||||
var sessionOrganizationID, sessionResourceID, sessionWorkerID, sessionState string
|
||||
if err := s.db.QueryRow(ctx, query, req.SecretRef, req.SessionID).Scan(
|
||||
&descriptor.ID,
|
||||
&descriptor.OrganizationID,
|
||||
&descriptor.ResourceID,
|
||||
&descriptor.SecretRef,
|
||||
&descriptor.Protocol,
|
||||
&descriptor.Version,
|
||||
&descriptor.KeyID,
|
||||
&descriptor.Algorithm,
|
||||
&descriptor.Metadata,
|
||||
&descriptor.CreatedAt,
|
||||
&descriptor.RotatedAt,
|
||||
&nonce,
|
||||
&ciphertext,
|
||||
&sessionOrganizationID,
|
||||
&sessionResourceID,
|
||||
&sessionWorkerID,
|
||||
&sessionState,
|
||||
); err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrResourceSecretNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("resolve resource secret: %w", err)
|
||||
}
|
||||
if descriptor.OrganizationID != req.OrganizationID ||
|
||||
descriptor.ResourceID != req.ResourceID ||
|
||||
sessionOrganizationID != req.OrganizationID ||
|
||||
sessionResourceID != req.ResourceID ||
|
||||
sessionWorkerID != req.WorkerID ||
|
||||
!secretResolvableSessionState(sessionState) {
|
||||
return nil, ErrSecretAccessDenied
|
||||
}
|
||||
plaintext, err := s.encryptor.Decrypt(EncryptedPayload{
|
||||
Algorithm: descriptor.Algorithm,
|
||||
KeyID: descriptor.KeyID,
|
||||
Nonce: nonce,
|
||||
Ciphertext: ciphertext,
|
||||
}, ResourceSecretAAD(descriptor.OrganizationID, descriptor.ResourceID, descriptor.SecretRef, descriptor.Protocol))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ResolvedResourceSecret{
|
||||
Descriptor: descriptor,
|
||||
Payload: json.RawMessage(plaintext),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeJSONObject(raw json.RawMessage) (json.RawMessage, error) {
|
||||
if len(raw) == 0 || !json.Valid(raw) {
|
||||
return nil, ErrSecretPayloadInvalid
|
||||
}
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(raw, &decoded); err != nil {
|
||||
return nil, ErrSecretPayloadInvalid
|
||||
}
|
||||
encoded, err := json.Marshal(decoded)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.RawMessage(encoded), nil
|
||||
}
|
||||
|
||||
func normalizeJSONObjectAllowEmpty(raw json.RawMessage) (json.RawMessage, error) {
|
||||
if len(raw) == 0 {
|
||||
return json.RawMessage(`{}`), nil
|
||||
}
|
||||
return normalizeJSONObject(raw)
|
||||
}
|
||||
|
||||
func secretResolvableSessionState(state string) bool {
|
||||
switch state {
|
||||
case "starting", "active", "reconnecting":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user