Files
rdp-proxy/backend/internal/modules/resource/module.go
T
2026-05-12 21:02:29 +03:00

644 lines
22 KiB
Go

package resource
import (
"context"
"encoding/json"
"errors"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/example/remote-access-platform/backend/internal/platform/authority"
"github.com/example/remote-access-platform/backend/internal/platform/httpx"
"github.com/example/remote-access-platform/backend/internal/platform/module"
"github.com/example/remote-access-platform/backend/internal/platform/secrets"
)
// Certificate verification modes accepted in certificate_verification_mode.
const (
	CertificateVerificationModeStrict = "strict"
	CertificateVerificationModeIgnore = "ignore"
)

// Render quality profiles accepted in render_quality_profile.
const (
	RenderQualityProfileLowBandwidth = "low_bandwidth"
	RenderQualityProfileBalanced     = "balanced"
	RenderQualityProfileHighQuality  = "high_quality"
	RenderQualityProfileTextPriority = "text_priority"
)

// Clipboard directions accepted in clipboard_mode.
const (
	ClipboardModeDisabled       = "disabled"
	ClipboardModeClientToServer = "client_to_server"
	ClipboardModeServerToClient = "server_to_client"
	ClipboardModeBidirectional  = "bidirectional"
)

// File transfer directions accepted in file_transfer_mode.
const (
	FileTransferModeDisabled       = "disabled"
	FileTransferModeClientToServer = "client_to_server"
	FileTransferModeServerToClient = "server_to_client"
	FileTransferModeBidirectional  = "bidirectional"
)
// Module serves the /resources HTTP API: listing, reading, creating and
// updating resources, and storing per-resource secrets.
type Module struct {
	// db is the connection pool used for every query in this module.
	db *pgxpool.Pool
	// appEnv is forwarded to secrets.ValidateResourceSecretReadiness.
	appEnv string
	// secretStore persists resource secrets. It is nil when NewModule was
	// called without one, in which case the secret endpoint answers 503.
	secretStore *secrets.ResourceSecretStore
	// authority is passed to authority.EffectivePlatformRole when resolving a
	// user's platform role. NOTE(review): NewModule discards the error from
	// authority.NewVerifier, so this may be nil — confirm
	// EffectivePlatformRole tolerates that.
	authority *authority.Verifier
}
// Resource is the API representation of a remote resource as returned by the
// /resources endpoints.
//
// Most fields map to columns of the resources table. HasSecret reports
// whether a resource_secrets row exists for the resource. ClipboardMode and
// FileTransferMode come from resource_policies and are reported as
// "disabled" when no policy row exists. RenderQualityProfile has no column of
// its own; scanResource derives it from Metadata.
type Resource struct {
	ID                          string          `json:"id"`
	OrganizationID              string          `json:"organization_id"`
	Name                        string          `json:"name"`
	Address                     string          `json:"address"`
	Protocol                    string          `json:"protocol"`
	SecretRef                   *string         `json:"secret_ref,omitempty"`
	HasSecret                   bool            `json:"has_secret"`
	CertificateVerificationMode string          `json:"certificate_verification_mode"`
	RenderQualityProfile        string          `json:"render_quality_profile"`
	ClipboardMode               string          `json:"clipboard_mode"`
	FileTransferMode            string          `json:"file_transfer_mode"`
	Metadata                    json.RawMessage `json:"metadata"`
	CreatedAt                   time.Time       `json:"created_at"`
	UpdatedAt                   time.Time       `json:"updated_at"`
}
// upsertResourceRequest is the JSON body accepted by POST /resources and
// PUT /resources/{resourceID}. decodeUpsertRequest decodes and normalizes it.
//
// ActorUserID identifies the caller for the admin-level access check.
// CertificateVerificationMode is stored in its own column. Both it and
// RenderQualityProfile (which has no column) are also written into Metadata
// during normalization. ClipboardMode and FileTransferMode are persisted in
// resource_policies via upsertResourcePolicy.
type upsertResourceRequest struct {
	ActorUserID                 string          `json:"actor_user_id"`
	OrganizationID              string          `json:"organization_id"`
	Name                        string          `json:"name"`
	Address                     string          `json:"address"`
	Protocol                    string          `json:"protocol"`
	SecretRef                   *string         `json:"secret_ref"`
	CertificateVerificationMode string          `json:"certificate_verification_mode"`
	RenderQualityProfile        string          `json:"render_quality_profile"`
	ClipboardMode               string          `json:"clipboard_mode"`
	FileTransferMode            string          `json:"file_transfer_mode"`
	Metadata                    json.RawMessage `json:"metadata"`
}
// upsertResourceSecretRequest is the JSON body accepted by
// PUT /resources/{resourceID}/secret. Payload and Metadata are handed to the
// secret store's Upsert unchanged; their expected shape is defined there.
type upsertResourceSecretRequest struct {
	ActorUserID string          `json:"actor_user_id"`
	Payload     json.RawMessage `json:"payload"`
	Metadata    json.RawMessage `json:"metadata"`
}
// NewModule builds the resource module from the shared dependencies.
//
// secretStores is optional. Only its first element is used, and when it is
// empty the module has no secret store, so the secret endpoint answers 503.
// The error returned by authority.NewVerifier is discarded; whatever verifier
// value it produced is stored as-is.
func NewModule(deps module.Dependencies, secretStores ...*secrets.ResourceSecretStore) *Module {
	m := &Module{
		db:     deps.Infra.DB,
		appEnv: deps.Config.App.Env,
	}
	if len(secretStores) != 0 {
		m.secretStore = secretStores[0]
	}
	// NOTE(review): a verifier construction failure is silently ignored here;
	// reporting it would require changing NewModule's signature.
	m.authority, _ = authority.NewVerifier(deps.Config.Installation)
	return m
}
// Name returns this module's name, "resource".
func (m *Module) Name() string {
	const moduleName = "resource"
	return moduleName
}
// RegisterRoutes mounts the resource endpoints under /resources on router.
func (m *Module) RegisterRoutes(router chi.Router) {
	routes := func(sub chi.Router) {
		// Collection endpoints.
		sub.Get("/", m.listResources)
		sub.Post("/", m.createResource)

		// Single-resource endpoints.
		sub.Get("/{resourceID}", m.getResource)
		sub.Put("/{resourceID}", m.updateResource)
		sub.Put("/{resourceID}/secret", m.upsertResourceSecret)
	}
	router.Route("/resources", routes)
}
// listResources handles GET /resources.
//
// Query parameters:
//   - user_id (required): the user whose visibility rules are applied.
//   - organization_id (optional): restricts results to one organization.
//
// Users whose effective platform role is platform_admin or
// platform_recovery_admin see every resource. Everyone else sees only
// resources of organizations where they hold an active membership. Results
// are ordered newest first and are always encoded as a JSON array.
//
// NOTE(review): user_id is read from the query string — confirm that an
// upstream layer binds it to the authenticated caller.
func (m *Module) listResources(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	params := r.URL.Query()
	userID := params.Get("user_id")
	orgID := params.Get("organization_id")
	if userID == "" {
		httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
		return
	}

	role, err := m.getPlatformRole(ctx, userID)
	if err != nil {
		httpx.WriteError(w, http.StatusInternalServerError, err.Error())
		return
	}

	query := `
SELECT r.id, r.organization_id, r.name, r.address, r.protocol, r.secret_ref,
r.certificate_verification_mode, r.metadata, r.created_at, r.updated_at,
EXISTS (SELECT 1 FROM resource_secrets sec WHERE sec.resource_id = r.id) AS has_secret,
COALESCE(rp.clipboard_mode, 'disabled') AS clipboard_mode,
COALESCE(rp.file_transfer_mode, 'disabled') AS file_transfer_mode
FROM resources r
LEFT JOIN resource_policies rp ON rp.resource_id = r.id
`
	args := make([]any, 0, 2)
	switch role {
	case "platform_admin", "platform_recovery_admin":
		// Platform-wide administrators are not limited by membership.
		if orgID != "" {
			query += ` WHERE r.organization_id = $1`
			args = append(args, orgID)
		}
	default:
		// Everyone else is limited to organizations they actively belong to.
		query += `
INNER JOIN organization_memberships om ON om.organization_id = r.organization_id
WHERE om.user_id = $1 AND om.status = 'active'
`
		args = append(args, userID)
		if orgID != "" {
			query += ` AND r.organization_id = $2`
			args = append(args, orgID)
		}
	}
	query += ` ORDER BY r.created_at DESC`

	rows, err := m.db.Query(ctx, query, args...)
	if err != nil {
		httpx.WriteError(w, http.StatusInternalServerError, err.Error())
		return
	}
	defer rows.Close()

	// Non-nil so an empty result encodes as [] rather than null.
	resources := make([]Resource, 0)
	for rows.Next() {
		res, scanErr := scanResource(rows)
		if scanErr != nil {
			httpx.WriteError(w, http.StatusInternalServerError, scanErr.Error())
			return
		}
		resources = append(resources, res)
	}
	if err = rows.Err(); err != nil {
		httpx.WriteError(w, http.StatusInternalServerError, err.Error())
		return
	}
	httpx.WriteJSON(w, http.StatusOK, map[string]any{"resources": resources})
}
// getResource handles GET /resources/{resourceID}.
//
// The user_id query parameter is required. The resource is loaded first, and
// that user must then pass the non-admin ensureResourceAccess check for the
// resource's organization.
//
// Responses: 400 without user_id, 404 when the resource does not exist, 403
// when access is denied, 500 on other failures.
func (m *Module) getResource(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	userID := r.URL.Query().Get("user_id")
	if len(userID) == 0 {
		httpx.WriteError(w, http.StatusBadRequest, "user_id is required")
		return
	}

	found, err := m.getByID(ctx, chi.URLParam(r, "resourceID"))
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		httpx.WriteError(w, http.StatusNotFound, "resource not found")
		return
	case err != nil:
		httpx.WriteError(w, http.StatusInternalServerError, err.Error())
		return
	}

	if accessErr := m.ensureResourceAccess(ctx, found.OrganizationID, userID, false); accessErr != nil {
		httpx.WriteError(w, http.StatusForbidden, accessErr.Error())
		return
	}
	httpx.WriteJSON(w, http.StatusOK, map[string]any{"resource": found})
}
// createResource handles POST /resources.
//
// It decodes and normalizes the body, validates secret readiness for the
// requested protocol, and requires the actor to pass the admin-level
// ensureResourceAccess check for the target organization. It then inserts
// the resources row and its resource_policies row in one transaction.
// On success it responds 201 with the created resource.
//
// Responses: 400 for invalid input, 403 when access is denied, 500 for
// database failures.
func (m *Module) createResource(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	fail := func(status int, err error) { httpx.WriteError(w, status, err.Error()) }

	req, err := decodeUpsertRequest(r)
	if err != nil {
		fail(http.StatusBadRequest, err)
		return
	}
	if err = secrets.ValidateResourceSecretReadiness(req.Protocol, req.SecretRef, req.Metadata, m.appEnv); err != nil {
		fail(http.StatusBadRequest, err)
		return
	}

	createdAt := time.Now().UTC()
	created := Resource{
		ID:                          uuid.NewString(),
		OrganizationID:              req.OrganizationID,
		Name:                        req.Name,
		Address:                     req.Address,
		Protocol:                    req.Protocol,
		SecretRef:                   req.SecretRef,
		CertificateVerificationMode: req.CertificateVerificationMode,
		RenderQualityProfile:        req.RenderQualityProfile,
		ClipboardMode:               req.ClipboardMode,
		FileTransferMode:            req.FileTransferMode,
		Metadata:                    req.Metadata,
		CreatedAt:                   createdAt,
		UpdatedAt:                   createdAt,
	}
	if err = m.ensureResourceAccess(ctx, req.OrganizationID, req.ActorUserID, true); err != nil {
		fail(http.StatusForbidden, err)
		return
	}

	tx, err := m.db.Begin(ctx)
	if err != nil {
		fail(http.StatusInternalServerError, err)
		return
	}
	// Rolling back after a successful Commit is a no-op.
	defer tx.Rollback(ctx)

	_, err = tx.Exec(ctx, `
INSERT INTO resources (
id, organization_id, name, address, protocol, secret_ref, certificate_verification_mode, metadata, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9, $10)
`,
		created.ID,
		created.OrganizationID,
		created.Name,
		created.Address,
		created.Protocol,
		created.SecretRef,
		created.CertificateVerificationMode,
		[]byte(created.Metadata),
		created.CreatedAt,
		created.UpdatedAt,
	)
	if err != nil {
		fail(http.StatusInternalServerError, err)
		return
	}
	if err = upsertResourcePolicy(ctx, tx, created.ID, created.ClipboardMode, created.FileTransferMode, createdAt); err != nil {
		fail(http.StatusInternalServerError, err)
		return
	}
	if err = tx.Commit(ctx); err != nil {
		fail(http.StatusInternalServerError, err)
		return
	}
	httpx.WriteJSON(w, http.StatusCreated, map[string]any{"resource": created})
}
// updateResource handles PUT /resources/{resourceID}, replacing the stored
// fields with the normalized request body.
//
// The actor must pass the admin-level ensureResourceAccess check for the
// resource's current organization. When organization_id changes, the same
// check applies to the destination organization. The resources row and its
// policy row are written in one transaction. The resource is then re-read so
// the response reflects the stored state.
//
// Responses: 400 for invalid input, 404 when the resource does not exist, 403
// when access is denied, 500 for database failures.
func (m *Module) updateResource(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	fail := func(status int, err error) { httpx.WriteError(w, status, err.Error()) }

	req, err := decodeUpsertRequest(r)
	if err != nil {
		fail(http.StatusBadRequest, err)
		return
	}
	if err = secrets.ValidateResourceSecretReadiness(req.Protocol, req.SecretRef, req.Metadata, m.appEnv); err != nil {
		fail(http.StatusBadRequest, err)
		return
	}

	resourceID := chi.URLParam(r, "resourceID")
	existing, err := m.getByID(ctx, resourceID)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		httpx.WriteError(w, http.StatusNotFound, "resource not found")
		return
	case err != nil:
		fail(http.StatusInternalServerError, err)
		return
	}

	// Admin rights are required where the resource lives now and, if it is
	// being moved, where it is going.
	orgsToCheck := []string{existing.OrganizationID}
	if req.OrganizationID != existing.OrganizationID {
		orgsToCheck = append(orgsToCheck, req.OrganizationID)
	}
	for _, orgID := range orgsToCheck {
		if err = m.ensureResourceAccess(ctx, orgID, req.ActorUserID, true); err != nil {
			fail(http.StatusForbidden, err)
			return
		}
	}

	updatedAt := time.Now().UTC()
	tx, err := m.db.Begin(ctx)
	if err != nil {
		fail(http.StatusInternalServerError, err)
		return
	}
	// Rolling back after a successful Commit is a no-op.
	defer tx.Rollback(ctx)

	tag, err := tx.Exec(ctx, `
UPDATE resources
SET
organization_id = $2,
name = $3,
address = $4,
protocol = $5,
secret_ref = $6,
certificate_verification_mode = $7,
metadata = $8::jsonb,
updated_at = $9
WHERE id = $1
`,
		resourceID,
		req.OrganizationID,
		req.Name,
		req.Address,
		req.Protocol,
		req.SecretRef,
		req.CertificateVerificationMode,
		[]byte(req.Metadata),
		updatedAt,
	)
	if err != nil {
		fail(http.StatusInternalServerError, err)
		return
	}
	// The row may have been deleted between the lookup and the update.
	if tag.RowsAffected() == 0 {
		httpx.WriteError(w, http.StatusNotFound, "resource not found")
		return
	}
	if err = upsertResourcePolicy(ctx, tx, resourceID, req.ClipboardMode, req.FileTransferMode, updatedAt); err != nil {
		fail(http.StatusInternalServerError, err)
		return
	}
	if err = tx.Commit(ctx); err != nil {
		fail(http.StatusInternalServerError, err)
		return
	}

	updated, err := m.getByID(ctx, resourceID)
	if err != nil {
		fail(http.StatusInternalServerError, err)
		return
	}
	httpx.WriteJSON(w, http.StatusOK, map[string]any{"resource": updated})
}
// upsertResourceSecret handles PUT /resources/{resourceID}/secret.
//
// Within one transaction it does three things:
//   - stores the payload through the secret store under the resource's
//     default secret reference;
//   - points resources.secret_ref at the stored descriptor;
//   - records a "resource_secret_rotated" audit event.
//
// The actor must pass the admin-level ensureResourceAccess check for the
// resource's organization.
//
// Responses: 503 when no secret store is configured; 400 for a malformed
// body, a missing actor_user_id, or any error from the store's Upsert; 404
// for an unknown resource; 403 when access is denied; 500 for other failures.
func (m *Module) upsertResourceSecret(w http.ResponseWriter, r *http.Request) {
	if m.secretStore == nil {
		httpx.WriteError(w, http.StatusServiceUnavailable, "resource secret encryption is not configured")
		return
	}
	ctx := r.Context()
	fail := func(status int, err error) { httpx.WriteError(w, status, err.Error()) }

	var req upsertResourceSecretRequest
	if json.NewDecoder(r.Body).Decode(&req) != nil {
		httpx.WriteError(w, http.StatusBadRequest, "invalid resource secret payload")
		return
	}
	if req.ActorUserID == "" {
		httpx.WriteError(w, http.StatusBadRequest, "actor_user_id is required")
		return
	}

	target, err := m.getByID(ctx, chi.URLParam(r, "resourceID"))
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		httpx.WriteError(w, http.StatusNotFound, "resource not found")
		return
	case err != nil:
		fail(http.StatusInternalServerError, err)
		return
	}
	if err = m.ensureResourceAccess(ctx, target.OrganizationID, req.ActorUserID, true); err != nil {
		fail(http.StatusForbidden, err)
		return
	}

	tx, err := m.db.Begin(ctx)
	if err != nil {
		fail(http.StatusInternalServerError, err)
		return
	}
	// Rolling back after a successful Commit is a no-op.
	defer tx.Rollback(ctx)

	descriptor, err := m.secretStore.WithDB(tx).Upsert(ctx, secrets.UpsertResourceSecretCommand{
		OrganizationID: target.OrganizationID,
		ResourceID:     target.ID,
		Protocol:       target.Protocol,
		SecretRef:      secrets.DefaultResourceSecretRef(target.OrganizationID, target.ID),
		Payload:        req.Payload,
		Metadata:       req.Metadata,
		ActorUserID:    req.ActorUserID,
	})
	if err != nil {
		fail(http.StatusBadRequest, err)
		return
	}

	if _, err = tx.Exec(ctx, `
UPDATE resources
SET secret_ref = $2, updated_at = $3
WHERE id = $1::uuid
`, target.ID, descriptor.SecretRef, time.Now().UTC()); err != nil {
		fail(http.StatusInternalServerError, err)
		return
	}

	auditPayload := map[string]any{
		"resource_id":     target.ID,
		"organization_id": target.OrganizationID,
		"protocol":        target.Protocol,
		"version":         descriptor.Version,
		"secret_ref":      descriptor.SecretRef,
	}
	if err = writeAuditEvent(ctx, tx, "resource_secret_rotated", req.ActorUserID, "resource_secret", descriptor.SecretRef, auditPayload); err != nil {
		fail(http.StatusInternalServerError, err)
		return
	}
	if err = tx.Commit(ctx); err != nil {
		fail(http.StatusInternalServerError, err)
		return
	}
	httpx.WriteJSON(w, http.StatusOK, map[string]any{"secret": descriptor})
}
// decodeUpsertRequest reads the JSON body shared by the create and update
// endpoints and validates and normalizes it:
//   - name, actor_user_id, organization_id and address are required, and are
//     checked in that order;
//   - protocol defaults to "rdp";
//   - certificate_verification_mode, render_quality_profile, clipboard_mode
//     and file_transfer_mode are validated and defaulted by their
//     normalize* helpers;
//   - metadata is passed through normalizeMetadata together with the
//     normalized certificate mode and render profile.
//
// The returned error message is suitable for a 400 response.
func decodeUpsertRequest(r *http.Request) (*upsertResourceRequest, error) {
	req := new(upsertResourceRequest)
	if err := json.NewDecoder(r.Body).Decode(req); err != nil {
		return nil, errors.New("invalid resource payload")
	}

	required := []struct{ value, message string }{
		{req.Name, "name is required"},
		{req.ActorUserID, "actor_user_id is required"},
		{req.OrganizationID, "organization_id is required"},
		{req.Address, "address is required"},
	}
	for _, field := range required {
		if field.value == "" {
			return nil, errors.New(field.message)
		}
	}

	if req.Protocol == "" {
		req.Protocol = "rdp"
	}

	var err error
	if req.CertificateVerificationMode, err = normalizeCertificateVerificationMode(req.CertificateVerificationMode); err != nil {
		return nil, err
	}
	if req.RenderQualityProfile, err = normalizeRenderQualityProfile(req.RenderQualityProfile); err != nil {
		return nil, err
	}
	if req.ClipboardMode, err = normalizeClipboardMode(req.ClipboardMode); err != nil {
		return nil, err
	}
	if req.FileTransferMode, err = normalizeFileTransferMode(req.FileTransferMode); err != nil {
		return nil, err
	}
	// Runs last so the metadata sees the already-normalized mode and profile.
	if req.Metadata, err = normalizeMetadata(req.Metadata, req.CertificateVerificationMode, req.RenderQualityProfile); err != nil {
		return nil, err
	}
	return req, nil
}
// normalizeCertificateVerificationMode validates a
// certificate_verification_mode value. An empty string means "strict";
// otherwise the value must be "strict" or "ignore" and is returned unchanged.
func normalizeCertificateVerificationMode(mode string) (string, error) {
	if mode == "" {
		return CertificateVerificationModeStrict, nil
	}
	if mode == CertificateVerificationModeStrict || mode == CertificateVerificationModeIgnore {
		return mode, nil
	}
	return "", errors.New("certificate_verification_mode must be one of: strict, ignore")
}
// normalizeClipboardMode validates a clipboard_mode value. An empty string
// means "disabled"; any other value must be one of the ClipboardMode*
// constants and is returned unchanged.
func normalizeClipboardMode(mode string) (string, error) {
	if mode == "" {
		return ClipboardModeDisabled, nil
	}
	switch mode {
	case ClipboardModeDisabled, ClipboardModeClientToServer, ClipboardModeServerToClient, ClipboardModeBidirectional:
		return mode, nil
	}
	return "", errors.New("clipboard_mode must be one of: disabled, client_to_server, server_to_client, bidirectional")
}
// normalizeFileTransferMode validates a file_transfer_mode value. An empty
// string means "disabled"; any other value must be one of the
// FileTransferMode* constants and is returned unchanged.
func normalizeFileTransferMode(mode string) (string, error) {
	if mode == "" {
		return FileTransferModeDisabled, nil
	}
	switch mode {
	case FileTransferModeDisabled, FileTransferModeClientToServer, FileTransferModeServerToClient, FileTransferModeBidirectional:
		return mode, nil
	}
	return "", errors.New("file_transfer_mode must be one of: disabled, client_to_server, server_to_client, bidirectional")
}
// normalizeMetadata validates a resource's metadata payload and stamps the
// normalized certificate_verification_mode and render_quality_profile into it.
//
// Input handling:
//   - Empty input and a JSON null are both treated as an empty object.
//   - Anything else must be a JSON object. Invalid JSON is rejected with
//     "metadata must be valid json".
//   - Valid JSON that is not an object (an array, string, number, or bool)
//     is rejected with "metadata must be a json object".
//
// Existing keys are kept with their original JSON values. Values are held as
// raw JSON rather than decoded into interface{}, so large integers do not
// lose precision by passing through float64. The two stamped keys always
// overwrite any caller-supplied values. The result is compact JSON with keys
// in sorted order.
func normalizeMetadata(raw json.RawMessage, certificateVerificationMode, renderQualityProfile string) (json.RawMessage, error) {
	if len(raw) == 0 {
		raw = json.RawMessage(`{}`)
	}
	if !json.Valid(raw) {
		return nil, errors.New("metadata must be valid json")
	}
	var metadata map[string]json.RawMessage
	if err := json.Unmarshal(raw, &metadata); err != nil {
		return nil, errors.New("metadata must be a json object")
	}
	// A JSON null leaves the map nil, and writing to a nil map panics.
	if metadata == nil {
		metadata = make(map[string]json.RawMessage, 2)
	}
	for key, value := range map[string]string{
		"certificate_verification_mode": certificateVerificationMode,
		"render_quality_profile":        renderQualityProfile,
	} {
		encodedValue, err := json.Marshal(value)
		if err != nil {
			return nil, err
		}
		metadata[key] = encodedValue
	}
	encoded, err := json.Marshal(metadata)
	if err != nil {
		return nil, err
	}
	return json.RawMessage(encoded), nil
}
// getByID loads one resource together with its has_secret flag and its
// clipboard and file-transfer modes from resource_policies. Those modes are
// reported as "disabled" when no policy row exists.
//
// resourceID must parse as a UUID. Elsewhere in this module resources.id is
// compared against `$1::uuid`, so a malformed ID cannot match any row. Sending
// one to Postgres would instead fail the uuid conversion, and every caller
// would surface that as a 500 with the raw database error. Malformed IDs are
// therefore reported as pgx.ErrNoRows, which callers already map to 404.
func (m *Module) getByID(ctx context.Context, resourceID string) (Resource, error) {
	id, err := uuid.Parse(resourceID)
	if err != nil {
		return Resource{}, pgx.ErrNoRows
	}
	row := m.db.QueryRow(ctx, `
SELECT r.id, r.organization_id, r.name, r.address, r.protocol, r.secret_ref,
r.certificate_verification_mode, r.metadata, r.created_at, r.updated_at,
EXISTS (SELECT 1 FROM resource_secrets sec WHERE sec.resource_id = r.id) AS has_secret,
COALESCE(rp.clipboard_mode, 'disabled') AS clipboard_mode,
COALESCE(rp.file_transfer_mode, 'disabled') AS file_transfer_mode
FROM resources r
LEFT JOIN resource_policies rp ON rp.resource_id = r.id
WHERE r.id = $1
`, id.String())
	return scanResource(row)
}
// ensureResourceAccess checks whether userID may act on resources of orgID.
//
// Users whose effective platform role is platform_admin or
// platform_recovery_admin always pass. Anyone else needs an active
// membership in orgID. When adminRequired is true, that membership's role_id
// must also be org_owner or org_admin. Denials return an error whose message
// is "forbidden"; lookup failures are returned unchanged.
func (m *Module) ensureResourceAccess(ctx context.Context, orgID, userID string, adminRequired bool) error {
	platformRole, err := m.getPlatformRole(ctx, userID)
	if err != nil {
		return err
	}
	switch platformRole {
	case "platform_admin", "platform_recovery_admin":
		return nil
	}

	var orgRole string
	err = m.db.QueryRow(ctx, `
SELECT role_id
FROM organization_memberships
WHERE organization_id = $1 AND user_id = $2 AND status = 'active'
`, orgID, userID).Scan(&orgRole)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return errors.New("forbidden")
	case err != nil:
		return err
	}

	if !adminRequired {
		return nil
	}
	switch orgRole {
	case "org_owner", "org_admin":
		return nil
	default:
		return errors.New("forbidden")
	}
}
// getPlatformRole returns userID's effective platform role as computed by
// authority.EffectivePlatformRole from this module's pool and verifier.
func (m *Module) getPlatformRole(ctx context.Context, userID string) (string, error) {
	role, err := authority.EffectivePlatformRole(ctx, m.db, m.authority, userID)
	return role, err
}
// rowScanner is the single method scanResource needs from a query result. It
// lets scanResource accept both the single row returned by QueryRow (in
// getByID) and the row cursor returned by Query (in listResources).
type rowScanner interface {
	Scan(dest ...any) error
}
// scanResource copies one result row into a Resource and fills in defaults.
//
// Columns must be selected in this order: id, organization_id, name, address,
// protocol, secret_ref, certificate_verification_mode, metadata, created_at,
// updated_at, has_secret, clipboard_mode, file_transfer_mode.
//
// Defaults applied after scanning:
//   - missing metadata becomes `{}`;
//   - an empty certificate mode becomes "strict";
//   - empty clipboard and file-transfer modes become "disabled";
//   - RenderQualityProfile, which has no column of its own, is derived from
//     the metadata.
func scanResource(row rowScanner) (Resource, error) {
	var res Resource
	columns := []any{
		&res.ID,
		&res.OrganizationID,
		&res.Name,
		&res.Address,
		&res.Protocol,
		&res.SecretRef,
		&res.CertificateVerificationMode,
		&res.Metadata,
		&res.CreatedAt,
		&res.UpdatedAt,
		&res.HasSecret,
		&res.ClipboardMode,
		&res.FileTransferMode,
	}
	if err := row.Scan(columns...); err != nil {
		return Resource{}, err
	}

	if len(res.Metadata) == 0 {
		res.Metadata = json.RawMessage(`{}`)
	}
	// Never scanned, so always derived from metadata.
	res.RenderQualityProfile = renderQualityProfileFromMetadata(res.Metadata)
	if res.CertificateVerificationMode == "" {
		res.CertificateVerificationMode = CertificateVerificationModeStrict
	}
	if res.ClipboardMode == "" {
		res.ClipboardMode = ClipboardModeDisabled
	}
	if res.FileTransferMode == "" {
		res.FileTransferMode = FileTransferModeDisabled
	}
	return res, nil
}
// upsertResourcePolicy inserts or updates the resource_policies row for
// resourceID inside tx. It stores both the mode strings and derived boolean
// flags. On insert, now is used for both created_at and updated_at; on
// conflict only updated_at is refreshed.
//
// clipboard_enabled is true for any mode other than "disabled".
// file_transfer_enabled is true only for client_to_server and bidirectional.
// NOTE(review): server_to_client therefore stores file_transfer_enabled =
// false — confirm that asymmetry with the clipboard flag is intended.
func upsertResourcePolicy(ctx context.Context, tx pgx.Tx, resourceID, clipboardMode, fileTransferMode string, now time.Time) error {
	clipboardEnabled := clipboardMode != ClipboardModeDisabled

	var fileTransferEnabled bool
	switch fileTransferMode {
	case FileTransferModeClientToServer, FileTransferModeBidirectional:
		fileTransferEnabled = true
	}

	_, err := tx.Exec(ctx, `
INSERT INTO resource_policies (
resource_id, clipboard_enabled, clipboard_mode, file_transfer_enabled, file_transfer_mode, created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $6)
ON CONFLICT (resource_id) DO UPDATE SET
clipboard_enabled = EXCLUDED.clipboard_enabled,
clipboard_mode = EXCLUDED.clipboard_mode,
file_transfer_enabled = EXCLUDED.file_transfer_enabled,
file_transfer_mode = EXCLUDED.file_transfer_mode,
updated_at = EXCLUDED.updated_at
`,
		resourceID,
		clipboardEnabled,
		clipboardMode,
		fileTransferEnabled,
		fileTransferMode,
		now,
	)
	return err
}
// writeAuditEvent inserts one audit_events row inside tx.
//
// The event gets a fresh random ID and the current UTC time, and payload is
// stored as JSON. An empty actorUserID is stored as NULL, via NULLIF in the
// SQL. JSON encoding and insert errors are returned unchanged.
func writeAuditEvent(ctx context.Context, tx pgx.Tx, eventType, actorUserID, targetType, targetID string, payload map[string]any) error {
	body, err := json.Marshal(payload)
	if err != nil {
		return err
	}
	eventID := uuid.NewString()
	createdAt := time.Now().UTC()
	_, err = tx.Exec(ctx, `
INSERT INTO audit_events (
id, actor_user_id, event_type, target_type, target_id, payload, created_at
) VALUES (
$1::uuid, NULLIF($2, '')::uuid, $3, $4, $5, $6::jsonb, $7
)
`, eventID, actorUserID, eventType, targetType, targetID, body, createdAt)
	return err
}
// normalizeRenderQualityProfile validates a render_quality_profile value. An
// empty string means "balanced"; any other value must be one of the
// RenderQualityProfile* constants and is returned unchanged.
func normalizeRenderQualityProfile(profile string) (string, error) {
	if profile == "" {
		return RenderQualityProfileBalanced, nil
	}
	switch profile {
	case RenderQualityProfileLowBandwidth, RenderQualityProfileBalanced, RenderQualityProfileHighQuality, RenderQualityProfileTextPriority:
		return profile, nil
	}
	return "", errors.New("render_quality_profile must be one of: low_bandwidth, balanced, high_quality, text_priority")
}
// renderQualityProfileFromMetadata returns the render_quality_profile stored
// in a metadata object. It falls back to "balanced" in these cases:
//   - the metadata is empty or cannot be decoded as an object;
//   - the key is absent or is not a string;
//   - the stored string is not a known profile.
func renderQualityProfileFromMetadata(raw json.RawMessage) string {
	fallback := RenderQualityProfileBalanced
	var fields map[string]any
	// json.Unmarshal rejects empty input, so this also covers len(raw) == 0.
	if json.Unmarshal(raw, &fields) != nil {
		return fallback
	}
	stored, _ := fields["render_quality_profile"].(string)
	switch stored {
	case RenderQualityProfileLowBandwidth, RenderQualityProfileBalanced, RenderQualityProfileHighQuality, RenderQualityProfileTextPriority:
		return stored
	}
	return fallback
}