Initial project snapshot

This commit is contained in:
2026-04-28 22:29:50 +03:00
commit 8ba0561f4f
365 changed files with 91832 additions and 0 deletions
@@ -0,0 +1,329 @@
package authority
import (
"context"
"crypto/ed25519"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/example/remote-access-platform/backend/internal/platform/config"
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
)
const (
ModeStrict = "strict"
ModeLegacy = "legacy"
ActivationSchemaVersion = "rap.installation.activation.v1"
PlatformRoleUser = "user"
PlatformRoleAdmin = "platform_admin"
PlatformRoleRecoveryAdmin = "platform_recovery_admin"
)
var (
ErrInvalidAuthorityMode = errors.New("invalid installation authority mode")
ErrProductRootKeyNeeded = errors.New("product root public key is required")
ErrInvalidActivation = errors.New("invalid installation activation")
ErrInvalidGrant = errors.New("invalid platform role grant")
)
type ActivationPayload struct {
SchemaVersion string `json:"schema_version"`
InstallID string `json:"install_id"`
OwnerEmail string `json:"owner_email"`
PlatformRole string `json:"platform_role"`
IssuedAt time.Time `json:"issued_at"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
Nonce string `json:"nonce,omitempty"`
Environment string `json:"environment,omitempty"`
}
type Verifier struct {
mode string
rootPublicKey ed25519.PublicKey
rootFingerprint string
allowInsecureBootstrap bool
now func() time.Time
}
func NewVerifier(cfg config.InstallationConfig) (*Verifier, error) {
mode := strings.ToLower(strings.TrimSpace(cfg.AuthorityMode))
if mode == "" {
mode = ModeLegacy
}
verifier := &Verifier{
mode: mode,
allowInsecureBootstrap: cfg.AllowInsecureBootstrap,
now: time.Now,
}
switch mode {
case ModeLegacy:
return verifier, nil
case ModeStrict:
publicKey, err := decodeEd25519PublicKey(cfg.ProductRootPublicKeyBase64)
if err != nil {
return nil, err
}
verifier.rootPublicKey = publicKey
fingerprint := sha256.Sum256(publicKey)
verifier.rootFingerprint = hex.EncodeToString(fingerprint[:])
return verifier, nil
default:
return nil, fmt.Errorf("%w: %s", ErrInvalidAuthorityMode, mode)
}
}
func (v *Verifier) Mode() string {
if v == nil || v.mode == "" {
return ModeLegacy
}
return v.mode
}
func (v *Verifier) Strict() bool {
return v != nil && v.mode == ModeStrict
}
func (v *Verifier) AllowInsecureBootstrap() bool {
return v != nil && v.allowInsecureBootstrap
}
func (v *Verifier) RootFingerprint() string {
if v == nil {
return ""
}
return v.rootFingerprint
}
func (v *Verifier) VerifyActivation(payload json.RawMessage, signature string) (ActivationPayload, error) {
if v == nil || !v.Strict() {
return ActivationPayload{}, ErrProductRootKeyNeeded
}
activation, canonical, err := parseActivationPayload(payload)
if err != nil {
return ActivationPayload{}, err
}
if err := activation.validate(v.now().UTC()); err != nil {
return ActivationPayload{}, err
}
if err := v.verifySignature(canonical, signature); err != nil {
return ActivationPayload{}, fmt.Errorf("%w: %v", ErrInvalidActivation, err)
}
return activation, nil
}
func (v *Verifier) VerifyPlatformRoleGrant(payload json.RawMessage, signature, expectedInstallID, expectedEmail, expectedRole string) (ActivationPayload, error) {
activation, err := v.VerifyActivation(payload, signature)
if err != nil {
return ActivationPayload{}, fmt.Errorf("%w: %v", ErrInvalidGrant, err)
}
if activation.InstallID != strings.TrimSpace(expectedInstallID) {
return ActivationPayload{}, fmt.Errorf("%w: install_id mismatch", ErrInvalidGrant)
}
if !strings.EqualFold(activation.OwnerEmail, strings.TrimSpace(expectedEmail)) {
return ActivationPayload{}, fmt.Errorf("%w: owner_email mismatch", ErrInvalidGrant)
}
if activation.PlatformRole != strings.TrimSpace(expectedRole) {
return ActivationPayload{}, fmt.Errorf("%w: platform_role mismatch", ErrInvalidGrant)
}
return activation, nil
}
func CanonicalJSON(raw json.RawMessage) ([]byte, error) {
if len(raw) == 0 {
return nil, fmt.Errorf("%w: empty payload", ErrInvalidActivation)
}
var value any
if err := json.Unmarshal(raw, &value); err != nil {
return nil, fmt.Errorf("%w: invalid json: %v", ErrInvalidActivation, err)
}
canonical, err := json.Marshal(value)
if err != nil {
return nil, fmt.Errorf("%w: canonical json: %v", ErrInvalidActivation, err)
}
return canonical, nil
}
func EffectivePlatformRole(ctx context.Context, db postgresplatform.DBTX, verifier *Verifier, userID string) (string, error) {
userID = strings.TrimSpace(userID)
if userID == "" {
return PlatformRoleUser, nil
}
if verifier == nil || !verifier.Strict() {
return legacyPlatformRole(ctx, db, userID)
}
var email string
if err := db.QueryRow(ctx, `SELECT email FROM users WHERE id = $1::uuid`, userID).Scan(&email); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return PlatformRoleUser, nil
}
return "", fmt.Errorf("get user email for platform grant: %w", err)
}
rows, err := db.Query(ctx, `
SELECT prg.role, prg.install_id, prg.grant_payload, prg.grant_signature
FROM platform_role_grants prg
JOIN installation_authority ia
ON ia.id = 1
AND ia.install_id = prg.install_id
AND ia.authority_state = 'active'
WHERE prg.user_id = $1::uuid
AND prg.revoked_at IS NULL
AND (prg.expires_at IS NULL OR prg.expires_at > NOW())
ORDER BY CASE prg.role
WHEN 'platform_recovery_admin' THEN 0
WHEN 'platform_admin' THEN 1
ELSE 2
END, prg.granted_at DESC
`, userID)
if err != nil {
return "", fmt.Errorf("query platform role grants: %w", err)
}
defer rows.Close()
bestRole := PlatformRoleUser
for rows.Next() {
var role, installID, signature string
var payload []byte
if err := rows.Scan(&role, &installID, &payload, &signature); err != nil {
return "", fmt.Errorf("scan platform role grant: %w", err)
}
if _, err := verifier.VerifyPlatformRoleGrant(json.RawMessage(payload), signature, installID, email, role); err != nil {
continue
}
if role == PlatformRoleRecoveryAdmin {
return role, nil
}
if role == PlatformRoleAdmin {
bestRole = role
}
}
if err := rows.Err(); err != nil {
return "", fmt.Errorf("iterate platform role grants: %w", err)
}
return bestRole, nil
}
func legacyPlatformRole(ctx context.Context, db postgresplatform.DBTX, userID string) (string, error) {
var role string
if err := db.QueryRow(ctx, `SELECT platform_role FROM users WHERE id = $1::uuid`, userID).Scan(&role); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return PlatformRoleUser, nil
}
return "", fmt.Errorf("get platform role: %w", err)
}
if role == "" {
return PlatformRoleUser, nil
}
return role, nil
}
func parseActivationPayload(raw json.RawMessage) (ActivationPayload, []byte, error) {
canonical, err := CanonicalJSON(raw)
if err != nil {
return ActivationPayload{}, nil, err
}
var activation ActivationPayload
if err := json.Unmarshal(canonical, &activation); err != nil {
return ActivationPayload{}, nil, fmt.Errorf("%w: decode activation: %v", ErrInvalidActivation, err)
}
activation.SchemaVersion = strings.TrimSpace(activation.SchemaVersion)
activation.InstallID = strings.TrimSpace(activation.InstallID)
activation.OwnerEmail = strings.ToLower(strings.TrimSpace(activation.OwnerEmail))
activation.PlatformRole = strings.TrimSpace(activation.PlatformRole)
activation.Nonce = strings.TrimSpace(activation.Nonce)
activation.Environment = strings.TrimSpace(activation.Environment)
return activation, canonical, nil
}
func (p ActivationPayload) validate(now time.Time) error {
if p.SchemaVersion != ActivationSchemaVersion {
return fmt.Errorf("%w: schema_version must be %s", ErrInvalidActivation, ActivationSchemaVersion)
}
if p.InstallID == "" {
return fmt.Errorf("%w: install_id is required", ErrInvalidActivation)
}
if p.OwnerEmail == "" || !strings.Contains(p.OwnerEmail, "@") {
return fmt.Errorf("%w: owner_email is required", ErrInvalidActivation)
}
switch p.PlatformRole {
case PlatformRoleAdmin, PlatformRoleRecoveryAdmin:
default:
return fmt.Errorf("%w: platform_role must be platform_admin or platform_recovery_admin", ErrInvalidActivation)
}
if p.IssuedAt.IsZero() {
return fmt.Errorf("%w: issued_at is required", ErrInvalidActivation)
}
if p.IssuedAt.After(now.Add(5 * time.Minute)) {
return fmt.Errorf("%w: issued_at is too far in the future", ErrInvalidActivation)
}
if p.ExpiresAt != nil && !p.ExpiresAt.After(now) {
return fmt.Errorf("%w: activation expired", ErrInvalidActivation)
}
return nil
}
func (v *Verifier) verifySignature(payload []byte, signatureText string) error {
signature, err := decodeBase64(strings.TrimSpace(signatureText))
if err != nil {
return fmt.Errorf("signature must be base64 encoded: %w", err)
}
if len(signature) != ed25519.SignatureSize {
return fmt.Errorf("signature must decode to %d bytes", ed25519.SignatureSize)
}
if !ed25519.Verify(v.rootPublicKey, payload, signature) {
return errors.New("signature verification failed")
}
return nil
}
func decodeEd25519PublicKey(value string) (ed25519.PublicKey, error) {
value = strings.TrimSpace(value)
if value == "" {
return nil, ErrProductRootKeyNeeded
}
if block, _ := pem.Decode([]byte(value)); block != nil {
parsed, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("parse product root public key PEM: %w", err)
}
publicKey, ok := parsed.(ed25519.PublicKey)
if !ok {
return nil, fmt.Errorf("product root public key PEM must contain an Ed25519 public key")
}
return publicKey, nil
}
decoded, err := decodeBase64(value)
if err != nil {
return nil, fmt.Errorf("product root public key must be base64 encoded: %w", err)
}
if len(decoded) != ed25519.PublicKeySize {
return nil, fmt.Errorf("product root public key must decode to %d bytes", ed25519.PublicKeySize)
}
return ed25519.PublicKey(decoded), nil
}
func decodeBase64(value string) ([]byte, error) {
decoded, err := base64.StdEncoding.DecodeString(value)
if err == nil {
return decoded, nil
}
decoded, rawErr := base64.RawStdEncoding.DecodeString(value)
if rawErr == nil {
return decoded, nil
}
return nil, err
}
@@ -0,0 +1,85 @@
package authority
import (
"crypto/ed25519"
"encoding/base64"
"encoding/json"
"strings"
"testing"
"time"
"github.com/example/remote-access-platform/backend/internal/platform/config"
)
func TestVerifierAcceptsSignedActivation(t *testing.T) {
publicKey, privateKey, err := ed25519.GenerateKey(nil)
if err != nil {
t.Fatalf("generate key: %v", err)
}
verifier, err := NewVerifier(config.InstallationConfig{
AuthorityMode: ModeStrict,
ProductRootPublicKeyBase64: base64.StdEncoding.EncodeToString(publicKey),
})
if err != nil {
t.Fatalf("NewVerifier: %v", err)
}
verifier.now = func() time.Time { return time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC) }
payload := json.RawMessage(`{
"platform_role":"platform_admin",
"owner_email":"Owner@Example.test",
"install_id":"install-1",
"schema_version":"rap.installation.activation.v1",
"issued_at":"2026-04-28T11:00:00Z",
"expires_at":"2026-04-29T11:00:00Z"
}`)
canonical, err := CanonicalJSON(payload)
if err != nil {
t.Fatalf("CanonicalJSON: %v", err)
}
signature := base64.StdEncoding.EncodeToString(ed25519.Sign(privateKey, canonical))
activation, err := verifier.VerifyActivation(payload, signature)
if err != nil {
t.Fatalf("VerifyActivation: %v", err)
}
if activation.OwnerEmail != "owner@example.test" || activation.PlatformRole != PlatformRoleAdmin {
t.Fatalf("unexpected activation: %+v", activation)
}
if verifier.RootFingerprint() == "" {
t.Fatal("expected root fingerprint")
}
}
func TestVerifierRejectsTamperedActivation(t *testing.T) {
publicKey, privateKey, err := ed25519.GenerateKey(nil)
if err != nil {
t.Fatalf("generate key: %v", err)
}
verifier, err := NewVerifier(config.InstallationConfig{
AuthorityMode: ModeStrict,
ProductRootPublicKeyBase64: base64.StdEncoding.EncodeToString(publicKey),
})
if err != nil {
t.Fatalf("NewVerifier: %v", err)
}
verifier.now = func() time.Time { return time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC) }
payload := json.RawMessage(`{
"schema_version":"rap.installation.activation.v1",
"install_id":"install-1",
"owner_email":"owner@example.test",
"platform_role":"platform_admin",
"issued_at":"2026-04-28T11:00:00Z"
}`)
canonical, err := CanonicalJSON(payload)
if err != nil {
t.Fatalf("CanonicalJSON: %v", err)
}
signature := base64.StdEncoding.EncodeToString(ed25519.Sign(privateKey, canonical))
tampered := json.RawMessage(strings.Replace(string(payload), "platform_admin", "platform_recovery_admin", 1))
if _, err := verifier.VerifyActivation(tampered, signature); err == nil {
t.Fatal("expected tampered activation to fail")
}
}
@@ -0,0 +1,186 @@
package clusterauth
import (
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
)
const (
AuthoritySchemaVersion = "rap.cluster_authority.v1"
SignatureSchemaVersion = "rap.cluster_authority.signature.v1"
AlgorithmEd25519 = "ed25519"
)
var (
ErrInvalidKey = errors.New("invalid cluster authority key")
ErrInvalidSignature = errors.New("invalid cluster authority signature")
ErrInvalidPayload = errors.New("invalid cluster authority payload")
)
type KeyPair struct {
PublicKeyB64 string
PrivateKeyB64 string
Fingerprint string
}
type Signature struct {
SchemaVersion string `json:"schema_version"`
Algorithm string `json:"algorithm"`
KeyFingerprint string `json:"key_fingerprint"`
Signature string `json:"signature"`
SignedAt time.Time `json:"signed_at"`
}
func GenerateKeyPair() (KeyPair, error) {
publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return KeyPair{}, err
}
fingerprint := Fingerprint(publicKey)
return KeyPair{
PublicKeyB64: base64.StdEncoding.EncodeToString(publicKey),
PrivateKeyB64: base64.StdEncoding.EncodeToString(privateKey),
Fingerprint: fingerprint,
}, nil
}
func Fingerprint(publicKey ed25519.PublicKey) string {
sum := sha256.Sum256(publicKey)
return "rap-ca-ed25519-" + hex.EncodeToString(sum[:16])
}
func FingerprintFromBase64(publicKeyB64 string) (string, error) {
publicKey, err := DecodePublicKey(publicKeyB64)
if err != nil {
return "", err
}
return Fingerprint(publicKey), nil
}
func SignRaw(privateKeyB64 string, payload json.RawMessage, signedAt time.Time) (Signature, error) {
privateKey, err := DecodePrivateKey(privateKeyB64)
if err != nil {
return Signature{}, err
}
canonical, err := CanonicalJSON(payload)
if err != nil {
return Signature{}, err
}
publicKey, ok := privateKey.Public().(ed25519.PublicKey)
if !ok {
return Signature{}, ErrInvalidKey
}
signature := ed25519.Sign(privateKey, canonical)
return Signature{
SchemaVersion: SignatureSchemaVersion,
Algorithm: AlgorithmEd25519,
KeyFingerprint: Fingerprint(publicKey),
Signature: base64.StdEncoding.EncodeToString(signature),
SignedAt: signedAt.UTC(),
}, nil
}
func SignPayload(privateKeyB64 string, payload any, signedAt time.Time) (json.RawMessage, Signature, error) {
raw, err := json.Marshal(payload)
if err != nil {
return nil, Signature{}, fmt.Errorf("%w: marshal: %v", ErrInvalidPayload, err)
}
signature, err := SignRaw(privateKeyB64, raw, signedAt)
if err != nil {
return nil, Signature{}, err
}
return json.RawMessage(raw), signature, nil
}
func VerifyRaw(publicKeyB64 string, payload json.RawMessage, signature Signature) error {
if signature.SchemaVersion != SignatureSchemaVersion {
return fmt.Errorf("%w: schema_version must be %s", ErrInvalidSignature, SignatureSchemaVersion)
}
if signature.Algorithm != AlgorithmEd25519 {
return fmt.Errorf("%w: algorithm must be %s", ErrInvalidSignature, AlgorithmEd25519)
}
publicKey, err := DecodePublicKey(publicKeyB64)
if err != nil {
return err
}
if signature.KeyFingerprint != Fingerprint(publicKey) {
return fmt.Errorf("%w: key fingerprint mismatch", ErrInvalidSignature)
}
canonical, err := CanonicalJSON(payload)
if err != nil {
return err
}
decodedSignature, err := decodeBase64(strings.TrimSpace(signature.Signature))
if err != nil || len(decodedSignature) != ed25519.SignatureSize {
return fmt.Errorf("%w: signature must be base64 ed25519 signature", ErrInvalidSignature)
}
if !ed25519.Verify(publicKey, canonical, decodedSignature) {
return ErrInvalidSignature
}
return nil
}
func CanonicalJSON(raw json.RawMessage) ([]byte, error) {
if len(raw) == 0 {
return nil, fmt.Errorf("%w: empty payload", ErrInvalidPayload)
}
var value any
if err := json.Unmarshal(raw, &value); err != nil {
return nil, fmt.Errorf("%w: invalid json: %v", ErrInvalidPayload, err)
}
canonical, err := json.Marshal(value)
if err != nil {
return nil, fmt.Errorf("%w: canonical json: %v", ErrInvalidPayload, err)
}
return canonical, nil
}
func HashRaw(raw json.RawMessage) (string, error) {
canonical, err := CanonicalJSON(raw)
if err != nil {
return "", err
}
sum := sha256.Sum256(canonical)
return hex.EncodeToString(sum[:]), nil
}
func DecodePublicKey(value string) (ed25519.PublicKey, error) {
decoded, err := decodeBase64(strings.TrimSpace(value))
if err != nil {
return nil, fmt.Errorf("%w: public key must be base64 encoded", ErrInvalidKey)
}
if len(decoded) != ed25519.PublicKeySize {
return nil, fmt.Errorf("%w: public key must decode to %d bytes", ErrInvalidKey, ed25519.PublicKeySize)
}
return ed25519.PublicKey(decoded), nil
}
func DecodePrivateKey(value string) (ed25519.PrivateKey, error) {
decoded, err := decodeBase64(strings.TrimSpace(value))
if err != nil {
return nil, fmt.Errorf("%w: private key must be base64 encoded", ErrInvalidKey)
}
if len(decoded) != ed25519.PrivateKeySize {
return nil, fmt.Errorf("%w: private key must decode to %d bytes", ErrInvalidKey, ed25519.PrivateKeySize)
}
return ed25519.PrivateKey(decoded), nil
}
func decodeBase64(value string) ([]byte, error) {
if value == "" {
return nil, errors.New("empty base64 value")
}
decoded, err := base64.StdEncoding.DecodeString(value)
if err == nil {
return decoded, nil
}
return base64.RawStdEncoding.DecodeString(value)
}
@@ -0,0 +1,44 @@
package clusterauth
import (
"encoding/json"
"errors"
"testing"
"time"
)
func TestSignAndVerifyRawPayload(t *testing.T) {
keys, err := GenerateKeyPair()
if err != nil {
t.Fatalf("GenerateKeyPair: %v", err)
}
payload := json.RawMessage(`{"cluster_id":"cluster-1","schema_version":"test.v1","value":1}`)
signature, err := SignRaw(keys.PrivateKeyB64, payload, time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC))
if err != nil {
t.Fatalf("SignRaw: %v", err)
}
if signature.KeyFingerprint != keys.Fingerprint {
t.Fatalf("fingerprint = %q, want %q", signature.KeyFingerprint, keys.Fingerprint)
}
if err := VerifyRaw(keys.PublicKeyB64, payload, signature); err != nil {
t.Fatalf("VerifyRaw: %v", err)
}
}
func TestVerifyRawRejectsTamperedPayload(t *testing.T) {
keys, err := GenerateKeyPair()
if err != nil {
t.Fatalf("GenerateKeyPair: %v", err)
}
payload := json.RawMessage(`{"cluster_id":"cluster-1","schema_version":"test.v1","value":1}`)
signature, err := SignRaw(keys.PrivateKeyB64, payload, time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC))
if err != nil {
t.Fatalf("SignRaw: %v", err)
}
tampered := json.RawMessage(`{"cluster_id":"cluster-1","schema_version":"test.v1","value":2}`)
if err := VerifyRaw(keys.PublicKeyB64, tampered, signature); !errors.Is(err, ErrInvalidSignature) {
t.Fatalf("err = %v, want ErrInvalidSignature", err)
}
}
+307
View File
@@ -0,0 +1,307 @@
package config
import (
"encoding/base64"
"fmt"
"os"
"strconv"
"strings"
"time"
)
type Config struct {
App AppConfig
HTTP HTTPConfig
Postgres PostgresConfig
Redis RedisConfig
Auth AuthConfig
Installation InstallationConfig
DataPlane DataPlaneConfig
Secret SecretConfig
Session SessionConfig
Worker WorkerConfig
WebSocket WebSocketConfig
}
type AppConfig struct {
Name string
Env string
}
type HTTPConfig struct {
Host string
Port int
ReadTimeout time.Duration
WriteTimeout time.Duration
IdleTimeout time.Duration
ShutdownTimeout time.Duration
}
type PostgresConfig struct {
DSN string
MaxConns int32
MinConns int32
ConnectTimeout time.Duration
}
type RedisConfig struct {
Addr string
Password string
DB int
DialTimeout time.Duration
}
type AuthConfig struct {
AccessTokenTTL time.Duration
RefreshTokenTTL time.Duration
Issuer string
AccessTokenSecret string
RefreshHashSecret string
}
type InstallationConfig struct {
AuthorityMode string
ProductRootPublicKeyBase64 string
ProductRootPublicKeyFile string
AllowInsecureBootstrap bool
}
type DataPlaneConfig struct {
TokenTTL time.Duration
TokenPrivateKeyPEM string
TokenPrivateKeyFile string
BackendGatewayURL string
DirectWorkerWSSURLTemplate string
DirectWorkerJSONRuntime bool
DirectWorkerBinaryRender bool
DirectWorkerTLSTrustMode string
DirectWorkerTLSCARef string
}
type SecretConfig struct {
EncryptionKeyBase64 string
EncryptionKeyFile string
EncryptionKeyID string
}
type SessionConfig struct {
HeartbeatTTL time.Duration
DetachGracePeriod time.Duration
AttachTokenTTL time.Duration
LiveStateTTL time.Duration
RecoveryBatchSize int
}
type WorkerConfig struct {
LeaseTTL time.Duration
HeartbeatTTL time.Duration
StaleLeaseGracePeriod time.Duration
}
type WebSocketConfig struct {
WriteTimeout time.Duration
PingInterval time.Duration
PongWait time.Duration
}
func Load() (Config, error) {
cfg := Config{
App: AppConfig{
Name: getEnv("APP_NAME", "rap-api"),
Env: getEnv("APP_ENV", "development"),
},
HTTP: HTTPConfig{
Host: getEnv("HTTP_HOST", "0.0.0.0"),
Port: getInt("HTTP_PORT", 8080),
ReadTimeout: getDuration("HTTP_READ_TIMEOUT", 15*time.Second),
WriteTimeout: getDuration("HTTP_WRITE_TIMEOUT", 15*time.Second),
IdleTimeout: getDuration("HTTP_IDLE_TIMEOUT", 60*time.Second),
ShutdownTimeout: getDuration("HTTP_SHUTDOWN_TIMEOUT", 10*time.Second),
},
Postgres: PostgresConfig{
DSN: getEnv("POSTGRES_DSN", ""),
MaxConns: int32(getInt("POSTGRES_MAX_CONNS", 20)),
MinConns: int32(getInt("POSTGRES_MIN_CONNS", 2)),
ConnectTimeout: getDuration("POSTGRES_CONNECT_TIMEOUT", 5*time.Second),
},
Redis: RedisConfig{
Addr: getEnv("REDIS_ADDR", "localhost:6379"),
Password: getEnv("REDIS_PASSWORD", ""),
DB: getInt("REDIS_DB", 0),
DialTimeout: getDuration("REDIS_DIAL_TIMEOUT", 5*time.Second),
},
Auth: AuthConfig{
AccessTokenTTL: getDuration("AUTH_ACCESS_TOKEN_TTL", 15*time.Minute),
RefreshTokenTTL: getDuration("AUTH_REFRESH_TOKEN_TTL", 30*24*time.Hour),
Issuer: getEnv("AUTH_ISSUER", "rap-api"),
AccessTokenSecret: getEnv("AUTH_ACCESS_TOKEN_SECRET", ""),
RefreshHashSecret: getEnv("AUTH_REFRESH_HASH_SECRET", ""),
},
Installation: InstallationConfig{
AuthorityMode: getEnv("INSTALLATION_AUTHORITY_MODE", ""),
ProductRootPublicKeyBase64: getEnv("INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_B64", ""),
ProductRootPublicKeyFile: getEnv("INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_FILE", ""),
AllowInsecureBootstrap: getBool("INSTALLATION_INSECURE_BOOTSTRAP_ENABLED", false),
},
DataPlane: DataPlaneConfig{
TokenTTL: getDuration("DATA_PLANE_TOKEN_TTL", 1*time.Minute),
TokenPrivateKeyPEM: getEnv("DATA_PLANE_TOKEN_PRIVATE_KEY_PEM", ""),
TokenPrivateKeyFile: getEnv("DATA_PLANE_TOKEN_PRIVATE_KEY_FILE", ""),
BackendGatewayURL: getEnv("DATA_PLANE_BACKEND_GATEWAY_URL", "/api/v1/gateway/ws"),
DirectWorkerWSSURLTemplate: getEnv("DATA_PLANE_DIRECT_WORKER_WSS_URL_TEMPLATE", ""),
DirectWorkerJSONRuntime: getBool("DATA_PLANE_DIRECT_WORKER_JSON_RUNTIME", false),
DirectWorkerBinaryRender: getBool("DATA_PLANE_DIRECT_WORKER_BINARY_RENDER", false),
DirectWorkerTLSTrustMode: getEnv("DATA_PLANE_DIRECT_WORKER_TLS_TRUST_MODE", "smoke_insecure"),
DirectWorkerTLSCARef: getEnv("DATA_PLANE_DIRECT_WORKER_TLS_CA_REF", ""),
},
Secret: SecretConfig{
EncryptionKeyBase64: getEnv("SECRET_ENCRYPTION_KEY_B64", ""),
EncryptionKeyFile: getEnv("SECRET_ENCRYPTION_KEY_FILE", ""),
EncryptionKeyID: getEnv("SECRET_ENCRYPTION_KEY_ID", "local-v1"),
},
Session: SessionConfig{
HeartbeatTTL: getDuration("SESSION_HEARTBEAT_TTL", 90*time.Second),
DetachGracePeriod: getDuration("SESSION_DETACH_GRACE_PERIOD", 30*time.Minute),
AttachTokenTTL: getDuration("SESSION_ATTACH_TOKEN_TTL", 2*time.Minute),
LiveStateTTL: getDuration("SESSION_LIVE_STATE_TTL", 2*time.Minute),
RecoveryBatchSize: getInt("SESSION_RECOVERY_BATCH_SIZE", 100),
},
Worker: WorkerConfig{
LeaseTTL: getDuration("WORKER_LEASE_TTL", 45*time.Second),
HeartbeatTTL: getDuration("WORKER_HEARTBEAT_TTL", 15*time.Second),
StaleLeaseGracePeriod: getDuration("WORKER_STALE_LEASE_GRACE_PERIOD", 30*time.Second),
},
WebSocket: WebSocketConfig{
WriteTimeout: getDuration("WEBSOCKET_WRITE_TIMEOUT", 10*time.Second),
PingInterval: getDuration("WEBSOCKET_PING_INTERVAL", 20*time.Second),
PongWait: getDuration("WEBSOCKET_PONG_WAIT", 40*time.Second),
},
}
if cfg.Postgres.DSN == "" {
return Config{}, fmt.Errorf("POSTGRES_DSN is required")
}
if cfg.Auth.AccessTokenSecret == "" {
return Config{}, fmt.Errorf("AUTH_ACCESS_TOKEN_SECRET is required")
}
if cfg.Auth.RefreshHashSecret == "" {
return Config{}, fmt.Errorf("AUTH_REFRESH_HASH_SECRET is required")
}
if cfg.Installation.ProductRootPublicKeyBase64 == "" && cfg.Installation.ProductRootPublicKeyFile != "" {
publicKey, err := os.ReadFile(cfg.Installation.ProductRootPublicKeyFile)
if err != nil {
return Config{}, fmt.Errorf("read INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_FILE: %w", err)
}
cfg.Installation.ProductRootPublicKeyBase64 = strings.TrimSpace(string(publicKey))
}
cfg.Installation.AuthorityMode = normalizeInstallationAuthorityMode(cfg.Installation.AuthorityMode, cfg.Installation.ProductRootPublicKeyBase64)
if isProductionEnv(cfg.App.Env) && cfg.Installation.AuthorityMode != "strict" {
return Config{}, fmt.Errorf("INSTALLATION_AUTHORITY_MODE=strict with INSTALLATION_PRODUCT_ROOT_PUBLIC_KEY_B64 or file is required in production")
}
if cfg.DataPlane.TokenPrivateKeyPEM == "" && cfg.DataPlane.TokenPrivateKeyFile != "" {
privateKey, err := os.ReadFile(cfg.DataPlane.TokenPrivateKeyFile)
if err != nil {
return Config{}, fmt.Errorf("read DATA_PLANE_TOKEN_PRIVATE_KEY_FILE: %w", err)
}
cfg.DataPlane.TokenPrivateKeyPEM = string(privateKey)
}
if cfg.Secret.EncryptionKeyBase64 == "" && cfg.Secret.EncryptionKeyFile != "" {
secretKey, err := os.ReadFile(cfg.Secret.EncryptionKeyFile)
if err != nil {
return Config{}, fmt.Errorf("read SECRET_ENCRYPTION_KEY_FILE: %w", err)
}
cfg.Secret.EncryptionKeyBase64 = strings.TrimSpace(string(secretKey))
}
if cfg.Secret.EncryptionKeyBase64 != "" {
decoded, err := base64.StdEncoding.DecodeString(cfg.Secret.EncryptionKeyBase64)
if err != nil {
if decodedRaw, rawErr := base64.RawStdEncoding.DecodeString(cfg.Secret.EncryptionKeyBase64); rawErr == nil {
decoded = decodedRaw
} else {
return Config{}, fmt.Errorf("SECRET_ENCRYPTION_KEY_B64 must be base64 encoded: %w", err)
}
}
if len(decoded) != 32 {
return Config{}, fmt.Errorf("SECRET_ENCRYPTION_KEY_B64 must decode to 32 bytes for AES-256-GCM")
}
}
if isProductionEnv(cfg.App.Env) && cfg.Secret.EncryptionKeyBase64 == "" {
return Config{}, fmt.Errorf("SECRET_ENCRYPTION_KEY_B64 or SECRET_ENCRYPTION_KEY_FILE is required in production")
}
return cfg, nil
}
func normalizeInstallationAuthorityMode(mode string, rootPublicKey string) string {
mode = strings.ToLower(strings.TrimSpace(mode))
switch mode {
case "strict", "legacy":
return mode
case "":
if strings.TrimSpace(rootPublicKey) != "" {
return "strict"
}
return "legacy"
default:
return mode
}
}
func isProductionEnv(appEnv string) bool {
switch strings.ToLower(strings.TrimSpace(appEnv)) {
case "production", "prod":
return true
default:
return false
}
}
func getEnv(key, fallback string) string {
if value := os.Getenv(key); value != "" {
return value
}
return fallback
}
func getInt(key string, fallback int) int {
value := os.Getenv(key)
if value == "" {
return fallback
}
parsed, err := strconv.Atoi(value)
if err != nil {
return fallback
}
return parsed
}
func getBool(key string, fallback bool) bool {
value := os.Getenv(key)
if value == "" {
return fallback
}
switch value {
case "1", "true", "TRUE", "yes", "on":
return true
case "0", "false", "FALSE", "no", "off":
return false
default:
return fallback
}
}
func getDuration(key string, fallback time.Duration) time.Duration {
value := os.Getenv(key)
if value == "" {
return fallback
}
parsed, err := time.ParseDuration(value)
if err != nil {
return fallback
}
return parsed
}
@@ -0,0 +1,20 @@
package httpserver
import (
"fmt"
"net/http"
"time"
"github.com/example/remote-access-platform/backend/internal/platform/config"
)
func New(cfg config.HTTPConfig, handler http.Handler) *http.Server {
return &http.Server{
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
Handler: handler,
ReadHeaderTimeout: 5 * time.Second,
ReadTimeout: cfg.ReadTimeout,
WriteTimeout: cfg.WriteTimeout,
IdleTimeout: cfg.IdleTimeout,
}
}
+45
View File
@@ -0,0 +1,45 @@
package httpx
import (
"encoding/json"
"net/http"
)
func WriteJSON(w http.ResponseWriter, status int, payload any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(payload)
}
func WriteError(w http.ResponseWriter, status int, message string) {
traceID := ensureTraceID(w)
WriteJSON(w, status, ErrorResponse{
Error: NewErrorMessage(status, message, nil, traceID),
})
}
func WriteErrorMessage(w http.ResponseWriter, status int, message any) {
traceID := ensureTraceID(w)
switch payload := message.(type) {
case string:
WriteJSON(w, status, ErrorResponse{
Error: NewErrorMessage(status, payload, nil, traceID),
})
case ErrorResponse:
payload.Error.TraceID = traceID
WriteJSON(w, status, payload)
case *ErrorResponse:
if payload == nil {
WriteJSON(w, status, ErrorResponse{
Error: NewErrorMessage(status, "", nil, traceID),
})
return
}
payload.Error.TraceID = traceID
WriteJSON(w, status, payload)
default:
WriteJSON(w, status, ErrorResponse{
Error: NewErrorMessage(status, "Request failed.", nil, traceID),
})
}
}
+131
View File
@@ -0,0 +1,131 @@
package httpx
import (
"net/http"
"strings"
"unicode"
"github.com/google/uuid"
messagecontracts "github.com/example/remote-access-platform/backend/pkg/contracts/message"
)
type ErrorResponse struct {
Error messagecontracts.Message `json:"error"`
}
func NewMessage(code, messageKey, fallbackMessage string, details map[string]any, traceID string) messagecontracts.Message {
if traceID == "" {
traceID = uuid.NewString()
}
if details == nil {
details = map[string]any{}
}
return messagecontracts.Message{
Code: code,
MessageKey: messageKey,
FallbackMessage: fallbackMessage,
Details: details,
TraceID: traceID,
}
}
func NewErrorMessage(status int, fallbackMessage string, details map[string]any, traceID string) messagecontracts.Message {
normalizedFallback, normalizedDetails := normalizeErrorFallback(status, fallbackMessage, details)
code := deriveErrorCode(status, normalizedFallback)
return NewMessage(code, "errors."+code, normalizedFallback, normalizedDetails, traceID)
}
func ensureTraceID(w http.ResponseWriter) string {
traceID := w.Header().Get("X-Trace-Id")
if traceID == "" {
traceID = uuid.NewString()
w.Header().Set("X-Trace-Id", traceID)
}
return traceID
}
func normalizeErrorFallback(status int, fallbackMessage string, details map[string]any) (string, map[string]any) {
if details == nil {
details = map[string]any{}
}
details["http_status"] = status
if status >= http.StatusInternalServerError {
return "An internal server error occurred.", details
}
trimmed := strings.TrimSpace(fallbackMessage)
switch strings.ToLower(trimmed) {
case "forbidden", "access denied":
return "Access denied.", details
}
if field, ok := extractRequiredField(trimmed); ok {
details["field"] = field
}
return trimmed, details
}
func deriveErrorCode(status int, fallbackMessage string) string {
switch strings.ToLower(strings.TrimSpace(fallbackMessage)) {
case "invalid credentials":
return "auth.invalid_credentials"
case "session expired. please sign in again.":
return "auth.session_expired"
case "access denied.":
return "common.access_denied"
}
statusPrefix := map[int]string{
http.StatusBadRequest: "bad_request",
http.StatusUnauthorized: "unauthorized",
http.StatusForbidden: "forbidden",
http.StatusNotFound: "not_found",
http.StatusConflict: "conflict",
http.StatusUnprocessableEntity: "unprocessable_entity",
http.StatusInternalServerError: "internal_server_error",
}[status]
if statusPrefix == "" {
statusPrefix = "http_" + strings.ReplaceAll(http.StatusText(status), " ", "_")
statusPrefix = strings.ToLower(statusPrefix)
}
slug := slugifyMessage(fallbackMessage)
if slug == "" {
slug = "message"
}
if status >= http.StatusInternalServerError {
return "common." + statusPrefix
}
return statusPrefix + "." + slug
}
func slugifyMessage(input string) string {
var builder strings.Builder
lastUnderscore := false
for _, r := range strings.ToLower(strings.TrimSpace(input)) {
if unicode.IsLetter(r) || unicode.IsDigit(r) {
builder.WriteRune(r)
lastUnderscore = false
continue
}
if !lastUnderscore {
builder.WriteRune('_')
lastUnderscore = true
}
}
return strings.Trim(builder.String(), "_")
}
func extractRequiredField(message string) (string, bool) {
const suffix = " is required"
if !strings.HasSuffix(strings.ToLower(message), suffix) {
return "", false
}
field := strings.TrimSpace(message[:len(message)-len(suffix)])
field = strings.ReplaceAll(field, " ", "_")
field = strings.ToLower(field)
return field, field != ""
}
@@ -0,0 +1,17 @@
package logging
import (
"log/slog"
"os"
)
func New(env string) *slog.Logger {
level := slog.LevelInfo
if env == "development" {
level = slog.LevelDebug
}
return slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
Level: level,
}))
}
@@ -0,0 +1,38 @@
package module
import (
"log/slog"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/redis/go-redis/v9"
"github.com/example/remote-access-platform/backend/internal/platform/config"
)
type Dependencies struct {
Config Config
Infra Infra
}
type Config struct {
App config.AppConfig
Auth config.AuthConfig
Installation config.InstallationConfig
DataPlane config.DataPlaneConfig
Secret config.SecretConfig
Session config.SessionConfig
Worker config.WorkerConfig
WebSocket config.WebSocketConfig
}
type Infra struct {
Logger *slog.Logger
DB *pgxpool.Pool
Redis *redis.Client
}
type Module interface {
Name() string
RegisterRoutes(router chi.Router)
}
@@ -0,0 +1,33 @@
package postgres
import (
"context"
"fmt"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/example/remote-access-platform/backend/internal/platform/config"
)
func Open(ctx context.Context, cfg config.PostgresConfig) (*pgxpool.Pool, error) {
poolConfig, err := pgxpool.ParseConfig(cfg.DSN)
if err != nil {
return nil, fmt.Errorf("parse postgres dsn: %w", err)
}
poolConfig.MaxConns = cfg.MaxConns
poolConfig.MinConns = cfg.MinConns
poolConfig.ConnConfig.ConnectTimeout = cfg.ConnectTimeout
pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
if err != nil {
return nil, fmt.Errorf("create postgres pool: %w", err)
}
if err := pool.Ping(ctx); err != nil {
pool.Close()
return nil, fmt.Errorf("ping postgres: %w", err)
}
return pool, nil
}
+34
View File
@@ -0,0 +1,34 @@
package postgres
import (
"context"
"fmt"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
)
type DBTX interface {
Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
}
func WithTransaction(ctx context.Context, pool *pgxpool.Pool, fn func(tx pgx.Tx) error) error {
tx, err := pool.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
return fmt.Errorf("begin transaction: %w", err)
}
if err := fn(tx); err != nil {
_ = tx.Rollback(ctx)
return err
}
if err := tx.Commit(ctx); err != nil {
return fmt.Errorf("commit transaction: %w", err)
}
return nil
}
+26
View File
@@ -0,0 +1,26 @@
package redis
import (
"context"
"fmt"
goredis "github.com/redis/go-redis/v9"
"github.com/example/remote-access-platform/backend/internal/platform/config"
)
func Open(ctx context.Context, cfg config.RedisConfig) (*goredis.Client, error) {
client := goredis.NewClient(&goredis.Options{
Addr: cfg.Addr,
Password: cfg.Password,
DB: cfg.DB,
DialTimeout: cfg.DialTimeout,
})
if err := client.Ping(ctx).Err(); err != nil {
_ = client.Close()
return nil, fmt.Errorf("ping redis: %w", err)
}
return client, nil
}
+220
View File
@@ -0,0 +1,220 @@
package runtime
import (
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"time"
"github.com/go-chi/chi/v5"
chimiddleware "github.com/go-chi/chi/v5/middleware"
"github.com/example/remote-access-platform/backend/internal/modules/auth"
"github.com/example/remote-access-platform/backend/internal/modules/cluster"
"github.com/example/remote-access-platform/backend/internal/modules/identitysource"
"github.com/example/remote-access-platform/backend/internal/modules/node"
"github.com/example/remote-access-platform/backend/internal/modules/nodeagent"
"github.com/example/remote-access-platform/backend/internal/modules/organization"
"github.com/example/remote-access-platform/backend/internal/modules/resource"
"github.com/example/remote-access-platform/backend/internal/modules/sessionbroker"
"github.com/example/remote-access-platform/backend/internal/modules/sessiongateway"
"github.com/example/remote-access-platform/backend/internal/modules/worker"
"github.com/example/remote-access-platform/backend/internal/platform/authority"
"github.com/example/remote-access-platform/backend/internal/platform/config"
"github.com/example/remote-access-platform/backend/internal/platform/httpserver"
"github.com/example/remote-access-platform/backend/internal/platform/logging"
"github.com/example/remote-access-platform/backend/internal/platform/module"
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
redisplatform "github.com/example/remote-access-platform/backend/internal/platform/redis"
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
)
type App struct {
cfg config.Config
logger *slog.Logger
httpServer *http.Server
workers []backgroundRunner
db closeFunc
redis closeFunc
}
type closeFunc func() error
type backgroundRunner func(context.Context) error
func NewApp(ctx context.Context) (*App, error) {
cfg, err := config.Load()
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
logger := logging.New(cfg.App.Env)
db, err := postgresplatform.Open(ctx, cfg.Postgres)
if err != nil {
return nil, err
}
redisClient, err := redisplatform.Open(ctx, cfg.Redis)
if err != nil {
db.Close()
return nil, err
}
authorityVerifier, err := authority.NewVerifier(cfg.Installation)
if err != nil {
redisClient.Close()
db.Close()
return nil, fmt.Errorf("create installation authority verifier: %w", err)
}
deps := module.Dependencies{
Config: module.Config{
App: cfg.App,
Auth: cfg.Auth,
Installation: cfg.Installation,
DataPlane: cfg.DataPlane,
Secret: cfg.Secret,
Session: cfg.Session,
Worker: cfg.Worker,
WebSocket: cfg.WebSocket,
},
Infra: module.Infra{
Logger: logger,
DB: db,
Redis: redisClient,
},
}
workerStore := worker.NewRedisStore(redisClient)
workerService := worker.NewService(deps, workerStore)
authStore := auth.NewPostgresStore(db)
authTx := auth.NewPostgresTransactor(db)
authService := auth.NewService(deps, authStore, authTx, authorityVerifier)
var resourceSecretStore *secrets.ResourceSecretStore
if cfg.Secret.EncryptionKeyBase64 != "" {
secretEncryptor, err := secrets.NewEncryptor(cfg.Secret.EncryptionKeyBase64, cfg.Secret.EncryptionKeyID)
if err != nil {
redisClient.Close()
db.Close()
return nil, fmt.Errorf("create resource secret encryptor: %w", err)
}
resourceSecretStore = secrets.NewResourceSecretStore(db, secretEncryptor)
}
brokerStore := sessionbroker.NewPostgresStore(db, authorityVerifier)
brokerTx := sessionbroker.NewPostgresTransactor(db, authorityVerifier)
liveStateStore := sessionbroker.NewRedisLiveStateStore(redisClient)
brokerService := sessionbroker.NewService(deps, brokerStore, brokerTx, liveStateStore, workerService, resourceSecretStore)
workerEvents := worker.NewEventProcessor(redisClient, brokerService)
leaseMonitor := worker.NewLeaseMonitor(workerService, brokerService, cfg.Worker.StaleLeaseGracePeriod)
brokerModule := sessionbroker.NewModule(brokerService)
authModule := auth.NewModule(deps, authService)
clusterModule := cluster.NewModule(deps, authorityVerifier)
organizationModule := organization.NewModule(deps)
identitySourceModule := identitysource.NewModule(deps)
resourceModule := resource.NewModule(deps, resourceSecretStore)
nodeModule := node.NewModule(deps)
nodeAgentModule := nodeagent.NewModule(deps)
sessionGatewayModule := sessiongateway.NewModule(deps, brokerModule.Service(), workerService)
router := buildRouter(
logger,
authModule,
clusterModule,
organizationModule,
identitySourceModule,
resourceModule,
brokerModule,
nodeModule,
nodeAgentModule,
sessionGatewayModule,
)
return &App{
cfg: cfg,
logger: logger,
httpServer: httpserver.New(cfg.HTTP, router),
workers: []backgroundRunner{workerEvents.Run, leaseMonitor.Run},
db: func() error {
db.Close()
return nil
},
redis: redisClient.Close,
}, nil
}
func (a *App) Run(ctx context.Context) error {
errCh := make(chan error, 1)
go func() {
a.logger.Info("http server starting", "addr", a.httpServer.Addr, "service", a.cfg.App.Name)
if err := a.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
errCh <- err
return
}
errCh <- nil
}()
for _, runner := range a.workers {
runner := runner
go func() {
if err := runner(ctx); err != nil {
errCh <- err
}
}()
}
select {
case <-ctx.Done():
a.logger.Info("shutdown signal received")
case err := <-errCh:
if err != nil {
return err
}
return nil
}
shutdownCtx, cancel := context.WithTimeout(context.Background(), a.cfg.HTTP.ShutdownTimeout)
defer cancel()
if err := a.httpServer.Shutdown(shutdownCtx); err != nil {
return fmt.Errorf("shutdown http server: %w", err)
}
if err := a.redis(); err != nil {
return fmt.Errorf("close redis: %w", err)
}
if err := a.db(); err != nil {
return fmt.Errorf("close postgres: %w", err)
}
a.logger.Info("app stopped", "at", time.Now().UTC())
return nil
}
func buildRouter(logger *slog.Logger, modules ...module.Module) http.Handler {
router := chi.NewRouter()
router.Use(chimiddleware.RequestID)
router.Use(chimiddleware.RealIP)
router.Use(chimiddleware.Recoverer)
router.Use(chimiddleware.Timeout(60 * time.Second))
router.Use(chimiddleware.Heartbeat("/healthz"))
router.Get("/readyz", func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ready"))
})
router.Route("/api/v1", func(r chi.Router) {
for _, mod := range modules {
logger.Info("register module routes", "module", mod.Name())
mod.RegisterRoutes(r)
}
})
return router
}
@@ -0,0 +1,37 @@
package secrets
import (
"encoding/json"
"fmt"
)
type AssignmentSecretMergeResult struct {
Metadata map[string]any
Keys []string
}
func MergeResourceSecretIntoAssignmentMetadata(metadata map[string]any, payload json.RawMessage) (AssignmentSecretMergeResult, error) {
if metadata == nil {
metadata = map[string]any{}
}
var secretPayload map[string]any
if err := json.Unmarshal(payload, &secretPayload); err != nil {
return AssignmentSecretMergeResult{}, fmt.Errorf("decode resolved resource secret: %w", err)
}
resource, _ := metadata["resource"].(map[string]any)
if resource == nil {
resource = map[string]any{}
metadata["resource"] = resource
}
resourceMetadata, _ := resource["metadata"].(map[string]any)
if resourceMetadata == nil {
resourceMetadata = map[string]any{}
resource["metadata"] = resourceMetadata
}
keys := make([]string, 0, len(secretPayload))
for key, value := range secretPayload {
resourceMetadata[key] = value
keys = append(keys, key)
}
return AssignmentSecretMergeResult{Metadata: metadata, Keys: keys}, nil
}
@@ -0,0 +1,113 @@
package secrets
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"io"
"strings"
)
const AlgorithmAES256GCM = "AES-256-GCM"
var (
ErrSecretEncryptionKeyMissing = errors.New("secret encryption key is not configured")
ErrSecretPayloadInvalid = errors.New("secret payload must be a json object")
)
type Encryptor struct {
aead cipher.AEAD
keyID string
}
type EncryptedPayload struct {
Algorithm string
KeyID string
Nonce []byte
Ciphertext []byte
PayloadSHA256 string
}
func NewEncryptor(masterKeyBase64, keyID string) (*Encryptor, error) {
masterKeyBase64 = strings.TrimSpace(masterKeyBase64)
if masterKeyBase64 == "" {
return nil, ErrSecretEncryptionKeyMissing
}
key, err := base64.StdEncoding.DecodeString(masterKeyBase64)
if err != nil {
if rawKey, rawErr := base64.RawStdEncoding.DecodeString(masterKeyBase64); rawErr == nil {
key = rawKey
} else {
return nil, fmt.Errorf("decode secret encryption key: %w", err)
}
}
if len(key) != 32 {
return nil, fmt.Errorf("secret encryption key must decode to 32 bytes")
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("create secret cipher: %w", err)
}
aead, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("create secret gcm: %w", err)
}
if strings.TrimSpace(keyID) == "" {
keyID = "local-v1"
}
return &Encryptor{aead: aead, keyID: keyID}, nil
}
func (e *Encryptor) KeyID() string {
if e == nil {
return ""
}
return e.keyID
}
func (e *Encryptor) Encrypt(plaintext, aad []byte) (EncryptedPayload, error) {
if e == nil {
return EncryptedPayload{}, ErrSecretEncryptionKeyMissing
}
nonce := make([]byte, e.aead.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return EncryptedPayload{}, fmt.Errorf("generate secret nonce: %w", err)
}
hash := sha256.Sum256(plaintext)
return EncryptedPayload{
Algorithm: AlgorithmAES256GCM,
KeyID: e.keyID,
Nonce: nonce,
Ciphertext: e.aead.Seal(nil, nonce, plaintext, aad),
PayloadSHA256: hex.EncodeToString(hash[:]),
}, nil
}
func (e *Encryptor) Decrypt(payload EncryptedPayload, aad []byte) ([]byte, error) {
if e == nil {
return nil, ErrSecretEncryptionKeyMissing
}
if payload.Algorithm != "" && payload.Algorithm != AlgorithmAES256GCM {
return nil, fmt.Errorf("unsupported secret algorithm %q", payload.Algorithm)
}
plaintext, err := e.aead.Open(nil, payload.Nonce, payload.Ciphertext, aad)
if err != nil {
return nil, fmt.Errorf("decrypt secret payload: %w", err)
}
return plaintext, nil
}
func ResourceSecretAAD(organizationID, resourceID, secretRef, protocol string) []byte {
return []byte(strings.Join([]string{
"rap-resource-secret-v1",
strings.TrimSpace(organizationID),
strings.TrimSpace(resourceID),
strings.TrimSpace(secretRef),
strings.ToLower(strings.TrimSpace(protocol)),
}, "|"))
}
@@ -0,0 +1,65 @@
package secrets
import (
"encoding/base64"
"encoding/json"
"testing"
)
func TestEncryptorRoundTrip(t *testing.T) {
key := base64.StdEncoding.EncodeToString([]byte("0123456789abcdef0123456789abcdef"))
encryptor, err := NewEncryptor(key, "test-key")
if err != nil {
t.Fatalf("NewEncryptor returned error: %v", err)
}
aad := ResourceSecretAAD("org-1", "resource-1", "rap-secret://test", "rdp")
encrypted, err := encryptor.Encrypt([]byte(`{"username":"user","password":"secret"}`), aad)
if err != nil {
t.Fatalf("Encrypt returned error: %v", err)
}
plaintext, err := encryptor.Decrypt(encrypted, aad)
if err != nil {
t.Fatalf("Decrypt returned error: %v", err)
}
if string(plaintext) != `{"username":"user","password":"secret"}` {
t.Fatalf("unexpected plaintext: %s", plaintext)
}
}
func TestEncryptorRejectsWrongAAD(t *testing.T) {
key := base64.StdEncoding.EncodeToString([]byte("0123456789abcdef0123456789abcdef"))
encryptor, err := NewEncryptor(key, "test-key")
if err != nil {
t.Fatalf("NewEncryptor returned error: %v", err)
}
encrypted, err := encryptor.Encrypt([]byte(`{"password":"secret"}`), ResourceSecretAAD("org-1", "resource-1", "ref", "rdp"))
if err != nil {
t.Fatalf("Encrypt returned error: %v", err)
}
if _, err := encryptor.Decrypt(encrypted, ResourceSecretAAD("org-2", "resource-1", "ref", "rdp")); err == nil {
t.Fatalf("expected decrypt with wrong aad to fail")
}
}
func TestMergeResourceSecretIntoAssignmentMetadata(t *testing.T) {
metadata := map[string]any{
"resource": map[string]any{
"id": "resource-1",
"metadata": map[string]any{
"rdp_host": "host",
},
},
}
merged, err := MergeResourceSecretIntoAssignmentMetadata(metadata, json.RawMessage(`{"username":"user","password":"secret","domain":"corp"}`))
if err != nil {
t.Fatalf("MergeResourceSecretIntoAssignmentMetadata returned error: %v", err)
}
resource := merged.Metadata["resource"].(map[string]any)
resourceMetadata := resource["metadata"].(map[string]any)
if resourceMetadata["rdp_host"] != "host" {
t.Fatalf("existing metadata was not preserved")
}
if resourceMetadata["username"] != "user" || resourceMetadata["password"] != "secret" || resourceMetadata["domain"] != "corp" {
t.Fatalf("secret payload was not merged: %#v", resourceMetadata)
}
}
@@ -0,0 +1,132 @@
package secrets
import (
"encoding/json"
"errors"
"fmt"
"slices"
"sort"
"strings"
)
var (
ErrPlaintextResourceCredentials = errors.New("plaintext resource credentials are not allowed in metadata in production")
ErrMissingResourceSecretRef = errors.New("secret_ref is required for this resource protocol in production")
)
var credentialKeyFragments = []string{
"accesstoken",
"clientsecret",
"credential",
"credentials",
"domain",
"password",
"privatekey",
"refreshtoken",
"secret",
"secrets",
"token",
"user",
"username",
}
var safeReferenceKeys = []string{
"certificateverificationmode",
"renderqualityprofile",
"secretref",
"secretreference",
"vaultref",
}
func ValidateResourceSecretReadiness(protocol string, secretRef *string, metadata json.RawMessage, appEnv string) error {
if !IsProductionEnv(appEnv) {
return nil
}
paths, err := PlaintextCredentialMetadataPaths(metadata)
if err != nil {
return err
}
if len(paths) > 0 {
return fmt.Errorf("%w: %s", ErrPlaintextResourceCredentials, strings.Join(paths, ", "))
}
if ResourceProtocolRequiresSecretRef(protocol) && (secretRef == nil || strings.TrimSpace(*secretRef) == "") {
return ErrMissingResourceSecretRef
}
return nil
}
func IsProductionEnv(appEnv string) bool {
switch strings.ToLower(strings.TrimSpace(appEnv)) {
case "prod", "production":
return true
default:
return false
}
}
func ResourceProtocolRequiresSecretRef(protocol string) bool {
switch strings.ToLower(strings.TrimSpace(protocol)) {
case "rdp", "vnc", "ssh":
return true
default:
return false
}
}
func PlaintextCredentialMetadataPaths(raw json.RawMessage) ([]string, error) {
if len(raw) == 0 {
return nil, nil
}
var value any
if err := json.Unmarshal(raw, &value); err != nil {
return nil, errors.New("metadata must be valid json")
}
metadata, ok := value.(map[string]any)
if !ok {
return nil, errors.New("metadata must be a json object")
}
var paths []string
collectCredentialPaths(metadata, "", &paths)
sort.Strings(paths)
return slices.Compact(paths), nil
}
func collectCredentialPaths(value any, prefix string, paths *[]string) {
switch typed := value.(type) {
case map[string]any:
for key, child := range typed {
path := key
if prefix != "" {
path = prefix + "." + key
}
if isCredentialMetadataKey(key) {
*paths = append(*paths, path)
}
collectCredentialPaths(child, path, paths)
}
case []any:
for index, child := range typed {
collectCredentialPaths(child, fmt.Sprintf("%s[%d]", prefix, index), paths)
}
}
}
func isCredentialMetadataKey(key string) bool {
normalized := normalizeMetadataKey(key)
if slices.Contains(safeReferenceKeys, normalized) {
return false
}
for _, fragment := range credentialKeyFragments {
if normalized == fragment || strings.HasSuffix(normalized, fragment) {
return true
}
}
return false
}
func normalizeMetadataKey(key string) string {
key = strings.ToLower(strings.TrimSpace(key))
replacer := strings.NewReplacer("_", "", "-", "", " ", "", ".", "")
return replacer.Replace(key)
}
@@ -0,0 +1,52 @@
package secrets
import (
"encoding/json"
"errors"
"slices"
"testing"
)
func TestValidateResourceSecretReadinessAllowsPlaintextInDevelopment(t *testing.T) {
metadata := json.RawMessage(`{"username":"m","password":"secret"}`)
if err := ValidateResourceSecretReadiness("rdp", nil, metadata, "development"); err != nil {
t.Fatalf("development metadata should remain allowed for smoke/dev: %v", err)
}
}
func TestValidateResourceSecretReadinessRejectsPlaintextCredentialsInProduction(t *testing.T) {
metadata := json.RawMessage(`{"rdp_host":"host","credentials":{"username":"m","password":"secret"}}`)
err := ValidateResourceSecretReadiness("rdp", stringPtr("vault://org/resource"), metadata, "production")
if !errors.Is(err, ErrPlaintextResourceCredentials) {
t.Fatalf("expected plaintext credential rejection, got %v", err)
}
paths, err := PlaintextCredentialMetadataPaths(metadata)
if err != nil {
t.Fatalf("metadata paths: %v", err)
}
for _, expected := range []string{"credentials", "credentials.password", "credentials.username"} {
if !slices.Contains(paths, expected) {
t.Fatalf("expected sensitive path %q in %v", expected, paths)
}
}
}
func TestValidateResourceSecretReadinessRequiresSecretRefForProductionRDP(t *testing.T) {
metadata := json.RawMessage(`{"rdp_host":"host","rdp_port":3389}`)
err := ValidateResourceSecretReadiness("rdp", nil, metadata, "production")
if !errors.Is(err, ErrMissingResourceSecretRef) {
t.Fatalf("expected missing secret_ref rejection, got %v", err)
}
}
func TestValidateResourceSecretReadinessAllowsProductionSecretRef(t *testing.T) {
metadata := json.RawMessage(`{"rdp_host":"host","rdp_port":3389,"secret_ref":"vault://org/resource"}`)
if err := ValidateResourceSecretReadiness("rdp", stringPtr("vault://org/resource"), metadata, "production"); err != nil {
t.Fatalf("production secret_ref metadata should be accepted: %v", err)
}
}
func stringPtr(value string) *string {
return &value
}
@@ -0,0 +1,259 @@
package secrets
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/jackc/pgx/v5"
postgresplatform "github.com/example/remote-access-platform/backend/internal/platform/postgres"
)
var (
ErrResourceSecretNotFound = errors.New("resource secret not found")
ErrSecretAccessDenied = errors.New("resource secret access denied")
ErrSecretLeaseRequired = errors.New("resource secret resolution requires lease proof")
)
type ResourceSecretStore struct {
db postgresplatform.DBTX
encryptor *Encryptor
now func() time.Time
}
type ResourceSecretResolver interface {
ResolveForSession(ctx context.Context, req ResolveResourceSecretRequest) (*ResolvedResourceSecret, error)
}
type ResourceSecretDescriptor struct {
ID string `json:"id"`
OrganizationID string `json:"organization_id"`
ResourceID string `json:"resource_id"`
SecretRef string `json:"secret_ref"`
Protocol string `json:"protocol"`
Version int `json:"version"`
KeyID string `json:"key_id"`
Algorithm string `json:"algorithm"`
Metadata json.RawMessage `json:"metadata"`
CreatedAt time.Time `json:"created_at"`
RotatedAt *time.Time `json:"rotated_at,omitempty"`
}
type UpsertResourceSecretCommand struct {
OrganizationID string
ResourceID string
Protocol string
SecretRef string
Payload json.RawMessage
Metadata json.RawMessage
ActorUserID string
}
type ResolveResourceSecretRequest struct {
SecretRef string
OrganizationID string
ResourceID string
SessionID string
WorkerID string
LeaseID string
}
type ResolvedResourceSecret struct {
Descriptor ResourceSecretDescriptor
Payload json.RawMessage
}
func NewResourceSecretStore(db postgresplatform.DBTX, encryptor *Encryptor) *ResourceSecretStore {
return &ResourceSecretStore{db: db, encryptor: encryptor, now: time.Now}
}
func (s *ResourceSecretStore) WithDB(db postgresplatform.DBTX) *ResourceSecretStore {
if s == nil {
return nil
}
return &ResourceSecretStore{db: db, encryptor: s.encryptor, now: s.now}
}
func DefaultResourceSecretRef(organizationID, resourceID string) string {
return "rap-secret://org/" + strings.TrimSpace(organizationID) + "/resources/" + strings.TrimSpace(resourceID) + "/primary"
}
func (s *ResourceSecretStore) Upsert(ctx context.Context, cmd UpsertResourceSecretCommand) (*ResourceSecretDescriptor, error) {
if s == nil || s.encryptor == nil {
return nil, ErrSecretEncryptionKeyMissing
}
payload, err := normalizeJSONObject(cmd.Payload)
if err != nil {
return nil, err
}
metadata, err := normalizeJSONObjectAllowEmpty(cmd.Metadata)
if err != nil {
return nil, err
}
secretRef := strings.TrimSpace(cmd.SecretRef)
if secretRef == "" {
secretRef = DefaultResourceSecretRef(cmd.OrganizationID, cmd.ResourceID)
}
protocol := strings.ToLower(strings.TrimSpace(cmd.Protocol))
encrypted, err := s.encryptor.Encrypt(payload, ResourceSecretAAD(cmd.OrganizationID, cmd.ResourceID, secretRef, protocol))
if err != nil {
return nil, err
}
now := s.now().UTC()
const query = `
INSERT INTO resource_secrets (
organization_id, resource_id, secret_ref, protocol, version, key_id,
algorithm, nonce, ciphertext, payload_sha256, metadata, created_by_user_id,
created_at, rotated_at
) VALUES (
$1::uuid, $2::uuid, $3, $4, 1, $5,
$6, $7, $8, $9, $10::jsonb, NULLIF($11, '')::uuid,
$12, NULL
)
ON CONFLICT (resource_id) DO UPDATE SET
secret_ref = EXCLUDED.secret_ref,
protocol = EXCLUDED.protocol,
version = resource_secrets.version + 1,
key_id = EXCLUDED.key_id,
algorithm = EXCLUDED.algorithm,
nonce = EXCLUDED.nonce,
ciphertext = EXCLUDED.ciphertext,
payload_sha256 = EXCLUDED.payload_sha256,
metadata = EXCLUDED.metadata,
created_by_user_id = EXCLUDED.created_by_user_id,
rotated_at = EXCLUDED.created_at
RETURNING id::text, organization_id::text, resource_id::text, secret_ref,
protocol, version, key_id, algorithm, metadata, created_at, rotated_at
`
var descriptor ResourceSecretDescriptor
if err := s.db.QueryRow(ctx, query,
cmd.OrganizationID,
cmd.ResourceID,
secretRef,
protocol,
encrypted.KeyID,
encrypted.Algorithm,
encrypted.Nonce,
encrypted.Ciphertext,
encrypted.PayloadSHA256,
metadata,
cmd.ActorUserID,
now,
).Scan(
&descriptor.ID,
&descriptor.OrganizationID,
&descriptor.ResourceID,
&descriptor.SecretRef,
&descriptor.Protocol,
&descriptor.Version,
&descriptor.KeyID,
&descriptor.Algorithm,
&descriptor.Metadata,
&descriptor.CreatedAt,
&descriptor.RotatedAt,
); err != nil {
return nil, fmt.Errorf("upsert resource secret: %w", err)
}
return &descriptor, nil
}
func (s *ResourceSecretStore) ResolveForSession(ctx context.Context, req ResolveResourceSecretRequest) (*ResolvedResourceSecret, error) {
if s == nil || s.encryptor == nil {
return nil, ErrSecretEncryptionKeyMissing
}
if strings.TrimSpace(req.LeaseID) == "" {
return nil, ErrSecretLeaseRequired
}
const query = `
SELECT sec.id::text, sec.organization_id::text, sec.resource_id::text, sec.secret_ref,
sec.protocol, sec.version, sec.key_id, sec.algorithm, sec.metadata,
sec.created_at, sec.rotated_at, sec.nonce, sec.ciphertext,
rs.organization_id::text, rs.resource_id::text, COALESCE(rs.worker_id, ''), rs.state
FROM resource_secrets sec
JOIN remote_sessions rs ON rs.resource_id = sec.resource_id
WHERE sec.secret_ref = $1 AND rs.id = $2::uuid
`
var descriptor ResourceSecretDescriptor
var nonce, ciphertext []byte
var sessionOrganizationID, sessionResourceID, sessionWorkerID, sessionState string
if err := s.db.QueryRow(ctx, query, req.SecretRef, req.SessionID).Scan(
&descriptor.ID,
&descriptor.OrganizationID,
&descriptor.ResourceID,
&descriptor.SecretRef,
&descriptor.Protocol,
&descriptor.Version,
&descriptor.KeyID,
&descriptor.Algorithm,
&descriptor.Metadata,
&descriptor.CreatedAt,
&descriptor.RotatedAt,
&nonce,
&ciphertext,
&sessionOrganizationID,
&sessionResourceID,
&sessionWorkerID,
&sessionState,
); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, ErrResourceSecretNotFound
}
return nil, fmt.Errorf("resolve resource secret: %w", err)
}
if descriptor.OrganizationID != req.OrganizationID ||
descriptor.ResourceID != req.ResourceID ||
sessionOrganizationID != req.OrganizationID ||
sessionResourceID != req.ResourceID ||
sessionWorkerID != req.WorkerID ||
!secretResolvableSessionState(sessionState) {
return nil, ErrSecretAccessDenied
}
plaintext, err := s.encryptor.Decrypt(EncryptedPayload{
Algorithm: descriptor.Algorithm,
KeyID: descriptor.KeyID,
Nonce: nonce,
Ciphertext: ciphertext,
}, ResourceSecretAAD(descriptor.OrganizationID, descriptor.ResourceID, descriptor.SecretRef, descriptor.Protocol))
if err != nil {
return nil, err
}
return &ResolvedResourceSecret{
Descriptor: descriptor,
Payload: json.RawMessage(plaintext),
}, nil
}
func normalizeJSONObject(raw json.RawMessage) (json.RawMessage, error) {
if len(raw) == 0 || !json.Valid(raw) {
return nil, ErrSecretPayloadInvalid
}
var decoded map[string]any
if err := json.Unmarshal(raw, &decoded); err != nil {
return nil, ErrSecretPayloadInvalid
}
encoded, err := json.Marshal(decoded)
if err != nil {
return nil, err
}
return json.RawMessage(encoded), nil
}
func normalizeJSONObjectAllowEmpty(raw json.RawMessage) (json.RawMessage, error) {
if len(raw) == 0 {
return json.RawMessage(`{}`), nil
}
return normalizeJSONObject(raw)
}
func secretResolvableSessionState(state string) bool {
switch state {
case "starting", "active", "reconnecting":
return true
default:
return false
}
}