// NOTE(review): In the copy of this file that was reviewed, every `<...>` span
// (include targets, template arguments) had been stripped and all newlines were
// collapsed. Includes and template arguments below were reconstructed from how
// each name is used. Two inferences should be confirmed against version control:
//   * JSON numeric fields are cast to `double`: the original casts even the
//     already-int64 `diff_time_ms` before inserting it, which implies a
//     floating-point JSON number type.
//   * Logger / session-manager type names were not recoverable, so the code is
//     written generically over them (see WriteTextEnvelope, the sink constructor
//     and DirectWssServer's constructor).
#include "rdp_worker/dataplane/direct_wss_server.hpp"

#include <algorithm>
#include <array>
#include <atomic>
#include <chrono>
#include <cmath>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <deque>
#include <exception>
#include <functional>
#include <limits>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <string_view>
#include <thread>
#include <utility>
#include <vector>

#include <boost/asio.hpp>
#include <boost/asio/ssl.hpp>
#include <boost/beast.hpp>
#include <boost/beast/ssl.hpp>
#include <boost/beast/websocket/ssl.hpp>

#include "rdp_worker/common/json.hpp"

namespace rdp_worker::dataplane {

namespace asio = boost::asio;
namespace beast = boost::beast;
namespace http = beast::http;
namespace websocket = beast::websocket;
using tcp = asio::ip::tcp;

namespace {

// Maximum queued non-frame ("reliable") events per connection; the oldest is dropped when full.
constexpr std::size_t kMaxReliableOutboundEvents = 256;
// Maximum pending ordered render frames before the queue is flushed and a repair is requested.
constexpr std::size_t kMaxOrderedRenderFrames = 256;
// Minimum spacing between consecutive binary full-frame sends.
constexpr auto kMinBinaryRenderFrameInterval = std::chrono::milliseconds(100);
// Minimum spacing between consecutive binary region / interactive frame sends.
constexpr auto kMinBinaryRegionRenderFrameInterval = std::chrono::milliseconds(33);
// Pause after a failed accept() so a persistent error (e.g. fd exhaustion) cannot spin the loop.
constexpr auto kAcceptRetryDelay = std::chrono::milliseconds(100);
constexpr std::string_view kBinaryRenderTransportV1 = "binary_v1";
constexpr std::string_view kColorModeFullColor = "full_color";
constexpr std::string_view kColorModeGrayscale = "grayscale";
constexpr std::string_view kRenderFrameFullMessageType = "render.frame.full";
constexpr std::string_view kRenderFrameRegionMessageType = "render.frame.region";

// An encoded binary render message plus the metrics that are logged about it.
struct BinaryRenderFrameMessage {
  std::string message;       // Wire bytes: 16-byte "RAP2" prefix, JSON header, pixel payload.
  std::string message_type;  // render.frame.full or render.frame.region.
  std::string color_mode;    // Color mode actually applied.
  std::string update_kind;   // "full" or "region".
  bool grayscale_conversion_applied{false};
  bool dirty_region_applied{false};
  std::size_t raw_frame_bytes_before{0};
  std::size_t raw_frame_bytes_after{0};
  std::size_t binary_direct_bytes{0};
  std::size_t full_frame_bytes{0};
  std::size_t region_bytes{0};
  double region_savings_percent{0.0};
  std::int64_t conversion_time_ms{0};
  std::int64_t diff_time_ms{0};
  std::string render_update_reason;
  std::string fallback_to_full_frame_reason;
};

// Wall-clock time in milliseconds since the Unix epoch.
std::int64_t UnixMillisecondsNow() {
  return std::chrono::duration_cast<std::chrono::milliseconds>(
             std::chrono::system_clock::now().time_since_epoch())
      .count();
}

// Renders a flag the way the structured log lines expect it.
std::string BoolString(bool value) {
  return value ? std::string("true") : std::string("false");
}

// Appends `value` to `output` as two little-endian bytes.
void AppendUint16Le(std::string& output, std::uint16_t value) {
  output.push_back(static_cast<char>(value & 0xFF));
  output.push_back(static_cast<char>((value >> 8) & 0xFF));
}

// Appends `value` to `output` as four little-endian bytes.
void AppendUint32Le(std::string& output, std::uint32_t value) {
  output.push_back(static_cast<char>(value & 0xFF));
  output.push_back(static_cast<char>((value >> 8) & 0xFF));
  output.push_back(static_cast<char>((value >> 16) & 0xFF));
  output.push_back(static_cast<char>((value >> 24) & 0xFF));
}

// Standard (RFC 4648, padded) base64 encoding of `size` bytes starting at `data`.
std::string Base64Encode(const std::uint8_t* data, std::size_t size) {
  static constexpr std::array<char, 64> alphabet{
      'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
      'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f',
      'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
      'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'};
  std::string output;
  output.reserve(((size + 2) / 3) * 4);
  for (std::size_t index = 0; index < size;) {
    const std::size_t remaining = size - index;
    const std::uint32_t octet_a = data[index++];
    const std::uint32_t octet_b = remaining > 1 ? data[index++] : 0;
    const std::uint32_t octet_c = remaining > 2 ? data[index++] : 0;
    const std::uint32_t triple = (octet_a << 16) | (octet_b << 8) | octet_c;
    output.push_back(alphabet[(triple >> 18) & 0x3F]);
    output.push_back(alphabet[(triple >> 12) & 0x3F]);
    output.push_back(remaining > 1 ? alphabet[(triple >> 6) & 0x3F] : '=');
    output.push_back(remaining > 2 ? alphabet[triple & 0x3F] : '=');
  }
  return output;
}

// Decodes standard base64. Decoding stops at the first '='; any other character
// outside the alphabet makes the whole input invalid (std::nullopt).
std::optional<std::string> DecodeBase64ToBytes(const std::string& input) {
  static constexpr unsigned char kInvalid = 255;
  // Function-local static initialisation is thread-safe (C++11). The previous
  // lazily-filled table guarded by a plain bool raced when several per-connection
  // writer threads decoded frames concurrently.
  static const std::array<unsigned char, 256> table = [] {
    std::array<unsigned char, 256> built{};
    built.fill(kInvalid);
    constexpr std::string_view alphabet =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    for (std::size_t i = 0; i < alphabet.size(); ++i) {
      built[static_cast<unsigned char>(alphabet[i])] = static_cast<unsigned char>(i);
    }
    return built;
  }();

  std::string output;
  output.reserve((input.size() * 3) / 4);
  // Unsigned so that bits shifted out of the top wrap harmlessly; only the low
  // 12 bits are ever read. A signed accumulator overflowed after ~5 characters (UB).
  std::uint32_t accumulator = 0;
  int pending_bits = -8;
  for (const unsigned char ch : input) {
    if (ch == '=') {
      break;
    }
    const unsigned char sextet = table[ch];
    if (sextet == kInvalid) {
      return std::nullopt;
    }
    accumulator = (accumulator << 6) | sextet;
    pending_bits += 6;
    if (pending_bits >= 0) {
      output.push_back(static_cast<char>((accumulator >> pending_bits) & 0xFFU));
      pending_bits -= 8;
    }
  }
  return output;
}

// Value of an ASCII hex digit, or -1 if `ch` is not one.
int HexDigitValue(char ch) {
  if (ch >= '0' && ch <= '9') {
    return ch - '0';
  }
  if (ch >= 'a' && ch <= 'f') {
    return ch - 'a' + 10;
  }
  if (ch >= 'A' && ch <= 'F') {
    return ch - 'A' + 10;
  }
  return -1;
}

// Decodes application/x-www-form-urlencoded text: "%XX" (two hex digits) becomes
// a byte and '+' becomes a space. Malformed escapes are copied through literally.
std::string UrlDecode(std::string_view value) {
  std::string output;
  output.reserve(value.size());
  for (std::size_t i = 0; i < value.size(); ++i) {
    if (value[i] == '%' && i + 2 < value.size()) {
      // Require two real hex digits; strtol also accepted whitespace and signs,
      // so "%-1" decoded to 0xFF and "% 5" to 0x05.
      const int high = HexDigitValue(value[i + 1]);
      const int low = HexDigitValue(value[i + 2]);
      if (high >= 0 && low >= 0) {
        output.push_back(static_cast<char>((high << 4) | low));
        i += 2;
        continue;
      }
    }
    output.push_back(value[i] == '+' ? ' ' : value[i]);
  }
  return output;
}

// Returns the URL-decoded value of the first query parameter called `name` in
// `target`, or an empty string if there is none.
std::string QueryParam(std::string_view target, std::string_view name) {
  const auto question = target.find('?');
  if (question == std::string_view::npos) {
    return {};
  }
  std::string_view query = target.substr(question + 1);
  while (!query.empty()) {
    const auto amp = query.find('&');
    const auto pair = query.substr(0, amp);
    const auto equals = pair.find('=');
    if (equals != std::string_view::npos && pair.substr(0, equals) == name) {
      return UrlDecode(pair.substr(equals + 1));
    }
    if (amp == std::string_view::npos) {
      break;
    }
    query = query.substr(amp + 1);
  }
  return {};
}

// The request target without its query string.
std::string PathOnly(std::string_view target) {
  const auto question = target.find('?');
  return std::string(target.substr(0, question == std::string_view::npos ? target.size() : question));
}

// Maps any requested color mode to a supported one: "grayscale" or "full_color".
std::string NormalizeColorMode(const std::string& requested) {
  return requested == kColorModeGrayscale ? std::string(kColorModeGrayscale)
                                          : std::string(kColorModeFullColor);
}

// In-place luma conversion of a BGRA buffer (weights 29/150/77 out of 256); alpha is kept.
void ApplyGrayscaleBgra(std::string& frame_bytes) {
  for (std::size_t index = 0; index + 3 < frame_bytes.size(); index += 4) {
    const auto blue = static_cast<unsigned char>(frame_bytes[index]);
    const auto green = static_cast<unsigned char>(frame_bytes[index + 1]);
    const auto red = static_cast<unsigned char>(frame_bytes[index + 2]);
    const auto gray = static_cast<char>(((static_cast<unsigned int>(red) * 77U) +
                                         (static_cast<unsigned int>(green) * 150U) +
                                         (static_cast<unsigned int>(blue) * 29U)) >>
                                        8U);
    frame_bytes[index] = gray;
    frame_bytes[index + 1] = gray;
    frame_bytes[index + 2] = gray;
  }
}

// Reads `key` as an int, using `fallback` when it is absent or not finite.
// Out-of-range values are clamped, because a double -> int conversion outside
// the int range is undefined behaviour.
int JsonIntOr(const common::JsonObject& payload, const std::string& key, int fallback) {
  const auto number = static_cast<double>(common::GetNumber(payload, key).value_or(fallback));
  if (!std::isfinite(number)) {
    return fallback;
  }
  const double clamped = std::clamp(number, static_cast<double>(std::numeric_limits<int>::min()),
                                    static_cast<double>(std::numeric_limits<int>::max()));
  return static_cast<int>(clamped);
}

// Copies the (region_x, region_y, region_width x region_height) rectangle out of a
// BGRA32 frame with row pitch `full_stride` into `output` as a tightly packed buffer.
// Returns false (leaving `output` untouched) if the geometry is invalid or `source`
// is too small.
bool TryCropBgraRegion(const std::string& source, int full_width, int full_height, int full_stride,
                       int region_x, int region_y, int region_width, int region_height,
                       std::string& output) {
  // Sums are widened so hostile or garbage geometry cannot overflow int.
  if (full_width <= 0 || full_height <= 0 || full_stride <= 0 || region_x < 0 || region_y < 0 ||
      region_width <= 0 || region_height <= 0 ||
      static_cast<std::int64_t>(region_x) + region_width > full_width ||
      static_cast<std::int64_t>(region_y) + region_height > full_height) {
    return false;
  }
  constexpr std::size_t kBytesPerPixel = 4;
  const std::size_t region_stride = static_cast<std::size_t>(region_width) * kBytesPerPixel;
  const std::size_t region_offset = static_cast<std::size_t>(region_x) * kBytesPerPixel;
  const std::size_t source_stride = static_cast<std::size_t>(full_stride);
  const std::size_t required_source_bytes = source_stride * static_cast<std::size_t>(full_height);
  if (source.size() < required_source_bytes || source_stride < region_offset + region_stride) {
    return false;
  }
  output.assign(region_stride * static_cast<std::size_t>(region_height), '\0');
  for (int row = 0; row < region_height; ++row) {
    const std::size_t source_offset =
        static_cast<std::size_t>(region_y + row) * source_stride + region_offset;
    const std::size_t target_offset = static_cast<std::size_t>(row) * region_stride;
    std::copy_n(source.data() + source_offset, region_stride, output.data() + target_offset);
  }
  return true;
}

// Writes a JSON error response for a rejected upgrade request. Write errors are ignored.
template <typename Stream>
void WriteReject(Stream& stream, http::status status, const std::string& reason) {
  http::response<http::string_body> response{status, 11};
  response.set(http::field::server, "rap-rdp-worker");
  response.set(http::field::content_type, "application/json");
  response.keep_alive(false);
  response.body() = rdp_worker::common::SerializeJson(rdp_worker::common::JsonObject{
      {"error", rdp_worker::common::JsonObject{
                    {"code", reason},
                    {"message_key", "errors.data_plane." + reason},
                    {"fallback_message", "Data plane connection rejected."},
                }},
  });
  response.prepare_payload();
  beast::error_code ignored;
  http::write(stream, response, ignored);
}

// Serializes a list of channel names as a JSON array.
std::string ChannelsToJsonArray(const std::vector<std::string>& channels) {
  rdp_worker::common::JsonArray array;
  for (const auto& channel : channels) {
    array.emplace_back(channel);
  }
  return rdp_worker::common::SerializeJson(array);
}

// Maps an internal worker event type to the envelope type sent to clients.
// Unknown event types pass through unchanged.
std::string WorkerEventTypeToEnvelopeType(const runtime::WorkerEvent& event) {
  if (event.type == "session_frame") {
    return "session.frame";
  }
  if (event.type == "session_cursor_updated") {
    return "cursor.update";
  }
  if (event.type == "session_clipboard_text") {
    return "clipboard.text";
  }
  if (event.type == "session_taken_over") {
    return "session.taken_over";
  }
  if (event.type == "session_file_upload_completed") {
    return "file_upload.progress";
  }
  if (event.type == "session_file_download_available") {
    return "file_download.available";
  }
  if (event.type == "session_file_download_chunk") {
    return "file_download.chunk";
  }
  if (event.type == "session_file_download_completed") {
    return "file_download.completed";
  }
  if (event.type == "session_file_download_failed") {
    return "file_download.failed";
  }
  if (event.type == "session_file_download_blocked") {
    return "file_download.blocked";
  }
  if (event.type == "session_file_download_progress") {
    return "file_download.progress";
  }
  if (event.type == "session_failed" || event.type == "session_terminated" ||
      event.type == "session_connected" || event.type == "session_display_ready" ||
      event.type == "session_heartbeat") {
    return "session.state";
  }
  return event.type;
}

// Builds the "render" sub-object (dimensions plus any render metadata present in the payload).
common::JsonObject RenderPayloadFromWorkerEvent(const runtime::WorkerEvent& event) {
  common::JsonObject render{
      {"width", event.width},
      {"height", event.height},
  };
  if (auto value = common::GetString(event.payload, "render_quality_profile"); value.has_value()) {
    render["quality_profile"] = *value;
    render["render_quality_profile"] = *value;
  }
  if (auto value = common::GetString(event.payload, "render_state"); value.has_value()) {
    render["state"] = *value;
    render["render_state"] = *value;
  }
  if (auto value = common::GetNumber(event.payload, "frame_sequence"); value.has_value()) {
    render["frame_sequence"] = *value;
  }
  if (auto value = common::GetString(event.payload, "frame_format"); value.has_value()) {
    render["frame_format"] = *value;
  }
  if (auto value = common::GetNumber(event.payload, "cursor_x"); value.has_value()) {
    render["cursor_x"] = *value;
  }
  if (auto value = common::GetNumber(event.payload, "cursor_y"); value.has_value()) {
    render["cursor_y"] = *value;
  }
  if (auto value = common::GetBool(event.payload, "cursor_visible"); value.has_value()) {
    render["cursor_visible"] = *value;
  }
  return render;
}

// Encodes a "session_frame" event as a binary RAP2 render message:
//   "RAP2" | u16 version=1 | u16 flags=0 | u32 header_len | u32 payload_len | header JSON | pixels
// (all integers little-endian). A region update is emitted only when
// `require_full_frame` is false and the event carries a usable region; otherwise
// the frame is labelled full. The requested color mode is applied. Returns
// std::nullopt for non-frame events, missing/undecodable pixel data, or sizes that
// do not fit in u32.
std::optional<BinaryRenderFrameMessage> WorkerEventToBinaryRenderFrame(const runtime::WorkerEvent& event,
                                                                       const std::string& requested_color_mode,
                                                                       bool require_full_frame) {
  if (event.type != "session_frame") {
    return std::nullopt;
  }

  // Prefer the raw pixel buffer; otherwise fall back to the base64 "frame_data" field.
  std::string frame_bytes;
  if (!event.raw_frame_bytes.empty()) {
    frame_bytes.assign(reinterpret_cast<const char*>(event.raw_frame_bytes.data()),
                       event.raw_frame_bytes.size());
  } else {
    const auto encoded_frame = common::GetString(event.payload, "frame_data").value_or("");
    if (encoded_frame.empty()) {
      return std::nullopt;
    }
    auto decoded = DecodeBase64ToBytes(encoded_frame);
    if (!decoded.has_value()) {
      return std::nullopt;
    }
    frame_bytes = std::move(*decoded);
  }

  const std::size_t raw_frame_bytes_before = frame_bytes.size();
  const auto frame_sequence =
      static_cast<std::int64_t>(common::GetNumber(event.payload, "frame_sequence").value_or(0));
  const auto source_frame_width = JsonIntOr(event.payload, "frame_width", event.width);
  const auto source_frame_height = JsonIntOr(event.payload, "frame_height", event.height);
  const auto source_frame_stride = JsonIntOr(event.payload, "frame_stride", source_frame_width * 4);
  const auto desktop_width = JsonIntOr(event.payload, "desktop_width", source_frame_width);
  const auto desktop_height = JsonIntOr(event.payload, "desktop_height", source_frame_height);
  const auto frame_format = common::GetString(event.payload, "frame_format").value_or("bgra32");
  const auto quality_profile = common::GetString(event.payload, "render_quality_profile").value_or("balanced");
  const auto input_correlation_id = common::GetString(event.payload, "input_correlation_id").value_or("");
  const auto worker_frame_captured_at = common::GetString(event.payload, "worker_frame_captured_at").value_or("");
  const auto requested_update_kind = common::GetString(event.payload, "frame_update_kind").value_or("full");
  const auto render_update_reason = common::GetString(event.payload, "capture_source").value_or("");
  const int region_x = JsonIntOr(event.payload, "region_x", 0);
  const int region_y = JsonIntOr(event.payload, "region_y", 0);
  const int region_width = JsonIntOr(event.payload, "region_width", 0);
  const int region_height = JsonIntOr(event.payload, "region_height", 0);
  const bool frame_data_is_region = common::GetBool(event.payload, "frame_data_is_region").value_or(false);
  const auto full_frame_bytes = static_cast<std::size_t>(std::max(0, desktop_width)) *
                                static_cast<std::size_t>(std::max(0, desktop_height)) * 4U;
  const auto diff_time_ms = static_cast<std::int64_t>(
      common::GetNumber(event.payload, "periodic_change_poll_ms")
          .value_or(common::GetNumber(event.payload, "capture_total_ms").value_or(0)));

  std::string update_kind = "full";
  std::string fallback_to_full_frame_reason;
  int output_frame_width = source_frame_width;
  int output_frame_height = source_frame_height;
  int output_frame_stride = source_frame_stride;
  bool dirty_region_applied = false;

  // NOTE(review): when `frame_data_is_region` is true but the frame falls back to
  // "full" (baseline required or invalid region), the payload still holds only the
  // region's pixels while the header advertises the full source dimensions.
  // Confirm how clients handle that mismatch.
  if (require_full_frame && requested_update_kind == "region") {
    fallback_to_full_frame_reason = "baseline_required";
  } else if (!require_full_frame && requested_update_kind == "region") {
    if (frame_data_is_region && region_width > 0 && region_height > 0 &&
        frame_bytes.size() ==
            static_cast<std::size_t>(region_width) * static_cast<std::size_t>(region_height) * 4U) {
      // The worker already sent just the region's pixels.
      output_frame_width = region_width;
      output_frame_height = region_height;
      output_frame_stride = region_width * 4;
      update_kind = "region";
      dirty_region_applied = true;
    } else {
      // The worker sent a full frame plus a dirty rectangle: crop it here.
      std::string region_bytes;
      if (TryCropBgraRegion(frame_bytes, source_frame_width, source_frame_height, source_frame_stride,
                            region_x, region_y, region_width, region_height, region_bytes)) {
        frame_bytes = std::move(region_bytes);
        output_frame_width = region_width;
        output_frame_height = region_height;
        output_frame_stride = region_width * 4;
        update_kind = "region";
        dirty_region_applied = true;
      }
    }
    if (!dirty_region_applied) {
      fallback_to_full_frame_reason = "invalid_region_payload";
    }
  }

  const auto conversion_started = std::chrono::steady_clock::now();
  const std::string applied_color_mode = NormalizeColorMode(requested_color_mode);
  const bool grayscale = applied_color_mode == kColorModeGrayscale;
  if (grayscale) {
    ApplyGrayscaleBgra(frame_bytes);
  }
  const auto conversion_finished = std::chrono::steady_clock::now();
  const auto conversion_time_ms =
      std::chrono::duration_cast<std::chrono::milliseconds>(conversion_finished - conversion_started).count();

  const auto message_type = update_kind == "region" ? std::string(kRenderFrameRegionMessageType)
                                                    : std::string(kRenderFrameFullMessageType);
  const std::size_t region_bytes = update_kind == "region" ? frame_bytes.size() : 0;
  const double region_savings_percent =
      update_kind == "region" && full_frame_bytes > 0
          ? 100.0 - ((static_cast<double>(region_bytes) * 100.0) / static_cast<double>(full_frame_bytes))
          : 0.0;

  common::JsonObject header{
      {"protocol_version", 1},
      {"session_id", event.session_id},
      {"channel", "render"},
      {"message_type", message_type},
      {"legacy_message_type", "session.frame"},
      {"sequence", static_cast<double>(frame_sequence)},
      {"timestamp", static_cast<double>(UnixMillisecondsNow())},
      {"flags", 0},
      {"payload_length", static_cast<double>(frame_bytes.size())},
      {"frame_width", output_frame_width},
      {"frame_height", output_frame_height},
      {"frame_stride", output_frame_stride},
      {"frame_format", frame_format},
      {"frame_update_kind", update_kind},
      {"desktop_width", desktop_width},
      {"desktop_height", desktop_height},
      {"region_x", dirty_region_applied ? region_x : 0},
      {"region_y", dirty_region_applied ? region_y : 0},
      {"region_width", dirty_region_applied ? region_width : output_frame_width},
      {"region_height", dirty_region_applied ? region_height : output_frame_height},
      // Both branches of the former ternary yielded output_frame_stride.
      {"region_stride", output_frame_stride},
      {"region_format", "BGRA32"},
      {"color_mode", applied_color_mode},
      {"quality_profile", quality_profile},
      {"original_frame_format", frame_format},
      {"output_frame_format", frame_format},
      {"raw_frame_bytes", static_cast<double>(raw_frame_bytes_before)},
      {"binary_direct_bytes", static_cast<double>(frame_bytes.size())},
      {"full_frame_bytes", static_cast<double>(full_frame_bytes)},
      {"region_bytes", static_cast<double>(region_bytes)},
      {"region_savings_percent", region_savings_percent},
      {"diff_time_ms", static_cast<double>(diff_time_ms)},
      {"render_update_reason", render_update_reason},
      {"full_frame_sent", update_kind == "full"},
      {"region_frame_sent", update_kind == "region"},
  };
  if (!fallback_to_full_frame_reason.empty()) {
    header["fallback_to_full_frame_reason"] = fallback_to_full_frame_reason;
  }
  if (!input_correlation_id.empty()) {
    header["input_correlation_id"] = input_correlation_id;
  }
  if (!worker_frame_captured_at.empty()) {
    header["worker_frame_captured_at"] = worker_frame_captured_at;
  }

  const std::string header_json = common::SerializeJson(header);
  if (header_json.size() > std::numeric_limits<std::uint32_t>::max() ||
      frame_bytes.size() > std::numeric_limits<std::uint32_t>::max()) {
    return std::nullopt;
  }

  std::string message;
  message.reserve(16 + header_json.size() + frame_bytes.size());
  message.append("RAP2", 4);
  AppendUint16Le(message, 1);
  AppendUint16Le(message, 0);
  AppendUint32Le(message, static_cast<std::uint32_t>(header_json.size()));
  AppendUint32Le(message, static_cast<std::uint32_t>(frame_bytes.size()));
  message.append(header_json);
  message.append(frame_bytes);

  return BinaryRenderFrameMessage{
      std::move(message),
      message_type,
      applied_color_mode,
      update_kind,
      grayscale,
      dirty_region_applied,
      raw_frame_bytes_before,
      frame_bytes.size(),
      frame_bytes.size(),
      full_frame_bytes,
      region_bytes,
      region_savings_percent,
      conversion_time_ms,
      diff_time_ms,
      render_update_reason,
      fallback_to_full_frame_reason,
  };
}

// Builds the JSON envelope {type, session_id, payload} sent for a worker event on the
// text path, enriching the payload per envelope type (state, render info, base64
// frame data, file-transfer fields).
common::JsonObject WorkerEventToEnvelope(const runtime::WorkerEvent& event) {
  common::JsonObject payload = event.payload;
  if (!event.reason.empty()) {
    payload["reason"] = event.reason;
  }
  const std::string envelope_type = WorkerEventTypeToEnvelopeType(event);
  if (envelope_type == "session.taken_over") {
    payload["state"] = "taken_over";
  } else if (envelope_type == "session.state") {
    std::string state = "active";
    if (event.type == "session_failed") {
      state = "failed";
    } else if (event.type == "session_terminated") {
      state = "terminated";
    } else if (auto value = common::GetString(payload, "state"); value.has_value() && !value->empty()) {
      state = *value;
    }
    payload["state"] = state;
    payload["render"] = RenderPayloadFromWorkerEvent(event);
  } else if (envelope_type == "session.frame") {
    payload["render"] = RenderPayloadFromWorkerEvent(event);
    if (!event.raw_frame_bytes.empty() && common::GetString(payload, "frame_data").value_or("").empty()) {
      payload["frame_data"] = Base64Encode(event.raw_frame_bytes.data(), event.raw_frame_bytes.size());
    }
  } else if (envelope_type == "file_upload.progress") {
    const double file_size = common::GetNumber(payload, "file_size").value_or(0);
    payload["received"] = file_size;
    payload["total"] = file_size;
    payload["status"] = "completed";
  } else if (envelope_type.rfind("file_download.", 0) == 0) {
    payload["direction"] = "server_to_client";
  }
  return common::JsonObject{
      {"type", envelope_type},
      {"session_id", event.session_id},
      {"payload", payload},
  };
}

bool IsFrameEvent(const runtime::WorkerEvent& event) {
  return event.type == "session_frame";
}

// A frame the worker flagged as the baseline for a newly attached client.
bool IsDirectAttachBaselineFrameEvent(const runtime::WorkerEvent& event) {
  return IsFrameEvent(event) && common::GetBool(event.payload, "direct_attach_baseline").value_or(false);
}

bool IsRegionFrameEvent(const runtime::WorkerEvent& event) {
  return IsFrameEvent(event) &&
         common::GetString(event.payload, "frame_update_kind").value_or("full") == "region";
}

bool IsFullFrameEvent(const runtime::WorkerEvent& event) {
  return IsFrameEvent(event) &&
         common::GetString(event.payload, "frame_update_kind").value_or("full") == "full";
}

bool IsCursorEvent(const runtime::WorkerEvent& event) {
  return event.type == "session_cursor_updated";
}

// Minimum spacing before sending `event` as a binary frame: the shorter interval
// for interactive or region frames, the longer one for full frames.
std::chrono::milliseconds BinaryRenderFrameInterval(const runtime::WorkerEvent& event) {
  if (common::GetBool(event.payload, "interactive_frame").value_or(false)) {
    return kMinBinaryRegionRenderFrameInterval;
  }
  const auto update_kind = common::GetString(event.payload, "frame_update_kind").value_or("full");
  if (update_kind == "region") {
    return kMinBinaryRegionRenderFrameInterval;
  }
  return kMinBinaryRenderFrameInterval;
}

// True for an {"type":"input","payload":{"kind":"mouse","action":"move"}} envelope.
bool IsMouseMoveEnvelope(const common::JsonObject& envelope) {
  if (common::GetString(envelope, "type").value_or("") != "input") {
    return false;
  }
  const auto* payload = common::GetObject(envelope, "payload");
  return payload != nullptr && common::GetString(*payload, "kind").value_or("") == "mouse" &&
         common::GetString(*payload, "action").value_or("") == "move";
}

// Whether the token claims permit a client envelope of `envelope_type`.
bool ChannelAllowed(const DataPlaneTokenClaims& claims, const std::string& envelope_type) {
  // Heartbeats ride on the control channel; every other envelope type names its own channel.
  const std::string channel = envelope_type == "heartbeat" ? std::string("control") : envelope_type;
  return std::find(claims.allowed_channels.begin(), claims.allowed_channels.end(), channel) !=
         claims.allowed_channels.end();
}

// For a disallowed "file_download" request, builds the file_download.blocked reply
// the client should receive; std::nullopt for every other envelope type.
std::optional<std::string> BuildBlockedEnvelopeForDisallowedChannel(const common::JsonObject& envelope,
                                                                    const DataPlaneTokenClaims& claims) {
  const auto type = common::GetString(envelope, "type").value_or("");
  if (type != "file_download") {
    return std::nullopt;
  }
  const auto* payload = common::GetObject(envelope, "payload");
  const std::string transfer_id =
      payload != nullptr ? common::GetString(*payload, "transfer_id").value_or("") : "";
  const std::string file_id = payload != nullptr ? common::GetString(*payload, "file_id").value_or("") : "";
  return common::SerializeJson(common::JsonObject{
      {"type", "file_download.blocked"},
      {"session_id", claims.session_id},
      {"payload", common::JsonObject{
                      {"direction", "server_to_client"},
                      {"transfer_id", transfer_id},
                      {"file_id", file_id},
                      {"reason", "access denied"},
                  }},
  });
}

// Serialized session.taken_over envelope for a client whose attachment id no longer
// matches the session's current one.
std::string BuildTakenOverEnvelopeForStaleAttachment(const DataPlaneTokenClaims& claims,
                                                     const std::string& current_attachment_id) {
  return common::SerializeJson(common::JsonObject{
      {"type", "session.taken_over"},
      {"session_id", claims.session_id},
      {"payload", common::JsonObject{
                      {"message", "controller binding changed"},
                      {"reason", "attachment_mismatch"},
                      {"attachment_id", claims.attachment_id},
                      {"current_attachment_id", current_attachment_id},
                  }},
      {"event", common::JsonObject{
                    {"code", "session.taken_over"},
                    {"message_key", "events.session.taken_over"},
                    {"fallback_message", "This session was taken over from another device."},
                    {"details", common::JsonObject{
                                    {"reason", "attachment_mismatch"},
                                }},
                }},
  });
}

// Writes `envelope` as a text frame while holding `websocket_io_mutex`. Logs and
// returns false on failure. Generic over the logger type, which only needs Warn().
template <typename WebSocket, typename Logger>
bool WriteTextEnvelope(WebSocket& ws, std::mutex& websocket_io_mutex, const std::string& envelope,
                       const std::shared_ptr<Logger>& logger, const std::string& session_id,
                       const std::string& reason) {
  beast::error_code write_ec;
  std::lock_guard<std::mutex> write_lock(websocket_io_mutex);
  ws.text(true);
  ws.write(asio::buffer(envelope), write_ec);
  if (write_ec) {
    logger->Warn("direct data-plane envelope write failed session=" + session_id + " reason=" + reason +
                 " error=" + write_ec.message());
    return false;
  }
  return true;
}

// Parses a client message and normalizes legacy/dotted types:
//   heartbeat / control.ping              -> {"type":"heartbeat"}
//   file_upload.{start,chunk,cancel}      -> {"type":"file_upload","payload":{..., "action":...}}
//   file_download.{start,ack,cancel}      -> {"type":"file_download","payload":{..., "action":...}}
// Returns std::nullopt for non-objects, a missing type, or (except heartbeats) a missing payload.
std::optional<common::JsonObject> NormalizeClientEnvelope(const std::string& message) {
  auto parsed = common::ParseJson(message);
  if (!parsed.IsObject()) {
    return std::nullopt;
  }
  common::JsonObject envelope = parsed.AsObject();
  const auto type = common::GetString(envelope, "type").value_or("");
  const auto* payload = common::GetObject(envelope, "payload");
  if (type.empty()) {
    return std::nullopt;
  }
  if (type == "heartbeat" || type == "control.ping") {
    return common::JsonObject{{"type", "heartbeat"}};
  }
  if (payload == nullptr) {
    return std::nullopt;
  }
  if (type == "input" || type == "clipboard" || type == "control") {
    return envelope;
  }
  if (type == "file_upload.start" || type == "file_upload.chunk" || type == "file_upload.cancel") {
    common::JsonObject normalized_payload = *payload;
    if (type == "file_upload.start") {
      normalized_payload["action"] = "start";
    } else if (type == "file_upload.chunk") {
      normalized_payload["action"] = "chunk";
    } else {
      normalized_payload["action"] = "cancel";
    }
    return common::JsonObject{{"type", "file_upload"}, {"payload", normalized_payload}};
  }
  if (type == "file_download.start" || type == "file_download.ack" || type == "file_download.cancel") {
    common::JsonObject normalized_payload = *payload;
    if (type == "file_download.start") {
      normalized_payload["action"] = "start";
    } else if (type == "file_download.ack") {
      normalized_payload["action"] = "ack";
    } else {
      normalized_payload["action"] = "cancel";
    }
    return common::JsonObject{{"type", "file_download"}, {"payload", normalized_payload}};
  }
  return envelope;
}

// Stamps the token's binding identifiers into envelope["payload"] (no-op without a payload).
void AddBindingClaimsToEnvelope(common::JsonObject& envelope, const DataPlaneTokenClaims& claims) {
  auto* payload = common::GetObject(envelope, "payload");
  if (payload == nullptr) {
    return;
  }
  common::JsonObject enriched = *payload;
  enriched["session_id"] = claims.session_id;
  enriched["attachment_id"] = claims.attachment_id;
  enriched["user_id"] = claims.user_id;
  enriched["organization_id"] = claims.organization_id;
  enriched["worker_id"] = claims.worker_id;
  enriched["resource_id"] = claims.resource_id;
  envelope["payload"] = std::move(enriched);
}

// Per-connection outbound queue with a dedicated writer thread.
//
// Three queues, drained in priority order by WriterLoop:
//   * reliable events (and direct-attach baseline frames) - bounded, oldest dropped;
//   * the latest cursor update only (older ones are coalesced away);
//   * ordered render frames - a full frame replaces anything pending; region frames
//     accumulate up to kMaxOrderedRenderFrames, after which the queue is flushed and
//     `repair_request` is invoked.
// Binary render frames are rate-limited per BinaryRenderFrameInterval.
template <typename WebSocket>
class DirectWssEventSink final : public runtime::DirectEventSink {
 public:
  // `logger` may be any shared_ptr whose pointee has Info(std::string) / Warn(std::string).
  template <typename Logger>
  DirectWssEventSink(WebSocket& ws, std::mutex& websocket_io_mutex, std::shared_ptr<Logger> logger,
                     std::string session_id, std::string attachment_id, bool binary_render_enabled,
                     std::string requested_color_mode,
                     std::function<void(const std::string&)> repair_request)
      : ws_(ws),
        websocket_io_mutex_(websocket_io_mutex),
        log_info_([logger](const std::string& message) { logger->Info(message); }),
        log_warn_([logger](const std::string& message) { logger->Warn(message); }),
        session_id_(std::move(session_id)),
        attachment_id_(std::move(attachment_id)),
        binary_render_enabled_(binary_render_enabled),
        requested_color_mode_(NormalizeColorMode(requested_color_mode)),
        repair_request_(std::move(repair_request)) {}

  ~DirectWssEventSink() override { Stop(); }

  // Launches the writer thread. Call at most once.
  void Start() { writer_thread_ = std::thread(&DirectWssEventSink::WriterLoop, this); }

  std::string AttachmentId() const override { return attachment_id_; }

  // Discards pending events, stops the writer thread and waits for it to exit.
  void Stop() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      StopLocked();
    }
    condition_.notify_all();
    if (writer_thread_.joinable()) {
      writer_thread_.join();
    }
  }

  void EnqueueEvent(const runtime::WorkerEvent& event) override {
    bool request_repair = false;
    {
      std::lock_guard<std::mutex> lock(mutex_);
      if (stopped_) {
        return;
      }
      if (IsDirectAttachBaselineFrameEvent(event)) {
        // The baseline must reach the client, so it rides the reliable queue and
        // supersedes any pending ordered frames.
        ordered_frame_events_.clear();
        PushReliableLocked(event);
        ++counters_.frames_queued;
        log_info_("direct data-plane baseline frame queued as non-droppable session=" + session_id_);
      } else if (IsFrameEvent(event)) {
        request_repair = EnqueueOrderedFrameLocked(event);
        ++counters_.frames_queued;
      } else if (IsCursorEvent(event)) {
        if (latest_cursor_.has_value()) {
          ++counters_.cursor_updates_dropped_due_to_backpressure;
        }
        latest_cursor_ = event;
        ++counters_.cursor_updates_queued;
      } else {
        PushReliableLocked(event);
      }
    }
    condition_.notify_one();
    // Invoked outside mutex_ so a callback that re-enters EnqueueEvent cannot deadlock.
    if (request_repair && repair_request_) {
      repair_request_("direct_writer_region_queue_overflow");
    }
  }

 private:
  // Outbound statistics reported by LogRateIfNeeded. Guarded by mutex_.
  struct OutboundCounters {
    std::size_t frames_queued{0};
    std::size_t frames_sent{0};
    std::size_t cursor_updates_queued{0};
    std::size_t cursor_updates_sent{0};
    std::size_t binary_render_bytes_sent{0};
    std::size_t json_render_bytes_sent{0};
    std::size_t frames_dropped_due_to_backpressure{0};
    std::size_t cursor_updates_dropped_due_to_backpressure{0};
    std::size_t reliable_dropped{0};
  };

  // Marks the sink stopped and discards pending events. Requires mutex_.
  void StopLocked() {
    stopped_ = true;
    reliable_events_.clear();
    latest_cursor_.reset();
    ordered_frame_events_.clear();
  }

  // Appends to the bounded reliable queue, evicting the oldest entry when full. Requires mutex_.
  void PushReliableLocked(const runtime::WorkerEvent& event) {
    if (reliable_events_.size() >= kMaxReliableOutboundEvents) {
      reliable_events_.pop_front();
      ++counters_.reliable_dropped;
    }
    reliable_events_.push_back(event);
  }

  // Queues a non-baseline frame. Requires mutex_. Returns true when the region queue
  // overflowed and a repair (fresh full frame) should be requested.
  bool EnqueueOrderedFrameLocked(const runtime::WorkerEvent& event) {
    if (IsFullFrameEvent(event)) {
      // A full frame makes every pending frame obsolete.
      if (!ordered_frame_events_.empty()) {
        counters_.frames_dropped_due_to_backpressure += ordered_frame_events_.size();
        log_info_("render.region_queue direct writer cleared pending frames for full frame session=" +
                  session_id_ + " cleared=" + std::to_string(ordered_frame_events_.size()));
      }
      ordered_frame_events_.clear();
      ordered_frame_events_.push_back(event);
      return false;
    }
    if (IsRegionFrameEvent(event)) {
      if (ordered_frame_events_.size() >= kMaxOrderedRenderFrames) {
        // Regions are deltas: once one is lost the rest are useless, so drop them
        // all (including this one) and ask for a fresh full frame.
        const auto dropped = ordered_frame_events_.size() + 1;
        counters_.frames_dropped_due_to_backpressure += dropped;
        ordered_frame_events_.clear();
        log_warn_("render.region_queue direct writer overflow session=" + session_id_ +
                  " dropped=" + std::to_string(dropped) + " repair_requested=true");
        return true;
      }
      ordered_frame_events_.push_back(event);
      return false;
    }
    // Any other update kind replaces whatever is pending.
    ordered_frame_events_.clear();
    ordered_frame_events_.push_back(event);
    return false;
  }

  // Writer thread body: drains the queues in priority order until stopped or a write fails.
  void WriterLoop() {
    while (true) {
      std::optional<runtime::WorkerEvent> event;
      {
        std::unique_lock<std::mutex> lock(mutex_);
        condition_.wait(lock, [&]() {
          return stopped_ || !reliable_events_.empty() || latest_cursor_.has_value() ||
                 !ordered_frame_events_.empty();
        });
        if (stopped_) {
          return;
        }
        if (!reliable_events_.empty()) {
          event = std::move(reliable_events_.front());
          reliable_events_.pop_front();
        } else if (latest_cursor_.has_value()) {
          event = std::move(*latest_cursor_);
          latest_cursor_.reset();
          ++counters_.cursor_updates_sent;
        } else if (!ordered_frame_events_.empty()) {
          if (binary_render_enabled_ && last_binary_frame_sent_at_.time_since_epoch().count() != 0) {
            const auto next_allowed_frame_at =
                last_binary_frame_sent_at_ + BinaryRenderFrameInterval(ordered_frame_events_.front());
            if (std::chrono::steady_clock::now() < next_allowed_frame_at) {
              // Wait until the frame is due, waking early for stop, reliable events or cursors.
              condition_.wait_until(lock, next_allowed_frame_at, [&]() {
                return stopped_ || !reliable_events_.empty() || latest_cursor_.has_value();
              });
              continue;
            }
          }
          event = std::move(ordered_frame_events_.front());
          ordered_frame_events_.pop_front();
          ++counters_.frames_sent;
        }
      }
      if (!event.has_value()) {
        continue;
      }

      const bool binary_frame = binary_render_enabled_ && IsFrameEvent(*event);
      auto binary_payload = binary_frame ? WorkerEventToBinaryRenderFrame(*event, requested_color_mode_,
                                                                          !baseline_binary_frame_sent_)
                                         : std::nullopt;
      // Move (not copy) the potentially multi-megabyte frame into the send buffer.
      std::string encoded = binary_payload.has_value() ? std::move(binary_payload->message)
                                                       : common::SerializeJson(WorkerEventToEnvelope(*event));

      beast::error_code ec;
      {
        std::lock_guard<std::mutex> write_lock(websocket_io_mutex_);
        if (binary_payload.has_value()) {
          ws_.binary(true);
        } else {
          ws_.text(true);
        }
        ws_.write(asio::buffer(encoded), ec);
      }
      if (ec) {
        log_warn_("direct data-plane write failed session=" + session_id_ +
                  " frame=" + BoolString(binary_payload.has_value()) + " reason=" + ec.message());
        {
          std::lock_guard<std::mutex> lock(mutex_);
          StopLocked();
        }
        condition_.notify_all();
        return;
      }

      if (binary_payload.has_value()) {
        last_binary_frame_sent_at_ = std::chrono::steady_clock::now();
        if (binary_payload->update_kind == "full") {
          baseline_binary_frame_sent_ = true;
        }
        {
          std::lock_guard<std::mutex> lock(mutex_);
          counters_.binary_render_bytes_sent += encoded.size();
        }
        LogBinaryFrame(*binary_payload);
      } else if (IsFrameEvent(*event)) {
        std::lock_guard<std::mutex> lock(mutex_);
        counters_.json_render_bytes_sent += encoded.size();
      }
      LogRateIfNeeded();
    }
  }

  // Emits the per-frame diagnostic line for a binary render frame that was just sent.
  void LogBinaryFrame(const BinaryRenderFrameMessage& frame) {
    log_info_("direct data-plane render frame prepared session=" + session_id_ +
              " requested_color_mode=" + requested_color_mode_ + " applied_color_mode=" + frame.color_mode +
              " message_type=" + frame.message_type + " frame_update_kind=" + frame.update_kind +
              " full_frame_sent=" + BoolString(frame.update_kind == "full") +
              " region_frame_sent=" + BoolString(frame.update_kind == "region") +
              " dirty_region_applied=" + BoolString(frame.dirty_region_applied) +
              " grayscale_conversion_applied=" + BoolString(frame.grayscale_conversion_applied) +
              " raw_frame_bytes_before=" + std::to_string(frame.raw_frame_bytes_before) +
              " raw_frame_bytes_after=" + std::to_string(frame.raw_frame_bytes_after) +
              " full_frame_bytes=" + std::to_string(frame.full_frame_bytes) +
              " region_bytes=" + std::to_string(frame.region_bytes) +
              " region_savings_percent=" + std::to_string(frame.region_savings_percent) +
              " binary_direct_bytes=" + std::to_string(frame.binary_direct_bytes) +
              " diff_time_ms=" + std::to_string(frame.diff_time_ms) +
              " render_update_reason=" + frame.render_update_reason +
              " fallback_to_full_frame_reason=" + frame.fallback_to_full_frame_reason +
              " conversion_time_ms=" + std::to_string(frame.conversion_time_ms));
  }

  // At most every 2 s, logs per-second rates for the elapsed window and resets the counters.
  void LogRateIfNeeded() {
    const auto now = std::chrono::steady_clock::now();
    if (last_rate_log_at_.time_since_epoch().count() == 0) {
      last_rate_log_at_ = now;
      return;
    }
    if (now - last_rate_log_at_ < std::chrono::seconds(2)) {
      return;
    }
    const double seconds = std::max(0.001, std::chrono::duration<double>(now - last_rate_log_at_).count());

    // EnqueueEvent updates these counters under mutex_ on another thread, so the
    // snapshot and reset must happen under the same lock.
    OutboundCounters window;
    {
      std::lock_guard<std::mutex> lock(mutex_);
      window = std::exchange(counters_, OutboundCounters{});
    }
    const auto per_second = [seconds](std::size_t count) {
      return std::to_string(static_cast<double>(count) / seconds);
    };
    log_info_("direct data-plane outbound rate session=" + session_id_ +
              " frames_queued_per_second=" + per_second(window.frames_queued) +
              " frames_sent_per_second=" + per_second(window.frames_sent) +
              " cursor_queued_per_second=" + per_second(window.cursor_updates_queued) +
              " cursor_sent_per_second=" + per_second(window.cursor_updates_sent) +
              " binary_render_bytes_per_second=" + per_second(window.binary_render_bytes_sent) +
              " json_render_bytes_per_second=" + per_second(window.json_render_bytes_sent) +
              " frames_dropped_due_to_backpressure=" + std::to_string(window.frames_dropped_due_to_backpressure) +
              " cursor_dropped_due_to_backpressure=" +
              std::to_string(window.cursor_updates_dropped_due_to_backpressure) +
              " reliable_dropped=" + std::to_string(window.reliable_dropped));
    last_rate_log_at_ = now;
  }

  // Immutable after construction. Declaration order matches the constructor's initializer list.
  WebSocket& ws_;
  std::mutex& websocket_io_mutex_;
  std::function<void(const std::string&)> log_info_;
  std::function<void(const std::string&)> log_warn_;
  std::string session_id_;
  std::string attachment_id_;
  bool binary_render_enabled_{false};
  std::string requested_color_mode_;
  std::function<void(const std::string&)> repair_request_;

  std::mutex mutex_;
  std::condition_variable condition_;
  // Guarded by mutex_.
  std::deque<runtime::WorkerEvent> reliable_events_;
  std::optional<runtime::WorkerEvent> latest_cursor_;
  std::deque<runtime::WorkerEvent> ordered_frame_events_;
  bool stopped_{false};
  OutboundCounters counters_;

  std::thread writer_thread_;

  // Touched only by the writer thread.
  bool baseline_binary_frame_sent_{false};
  std::chrono::steady_clock::time_point last_rate_log_at_{};
  std::chrono::steady_clock::time_point last_binary_frame_sent_at_{};
};

}  // namespace

// NOTE(review): the session-manager and logger parameter types are written as the
// types of the members they initialise, because their spelled-out names could not
// be recovered during review. They must match the header's declaration; restore
// the explicit spelling when convenient.
DirectWssServer::DirectWssServer(config::Config config, decltype(session_manager_) session_manager,
                                 decltype(logger_) logger)
    : config_(std::move(config)),
      session_manager_(std::move(session_manager)),
      logger_(std::move(logger)),
      token_validator_(config_.data_plane_public_key_pem, config_.worker_id),
      ssl_context_(asio::ssl::context::tlsv12_server) {}

DirectWssServer::~DirectWssServer() {
  Stop();
}

// Configures TLS and starts the accept thread. Does nothing (just logs) when the data
// plane is disabled or the key/certificate configuration is incomplete. The ssl
// context calls throw boost::system::system_error if the cert/key files cannot be loaded.
void DirectWssServer::Start() {
  if (!config_.data_plane_enabled) {
    logger_->Info("direct data-plane WSS endpoint disabled");
    return;
  }
  if (config_.data_plane_public_key_pem.empty() || config_.data_plane_tls_cert_file.empty() ||
      config_.data_plane_tls_key_file.empty()) {
    logger_->Warn(
        "direct data-plane WSS endpoint not started because token public key or TLS certificate/key is missing");
    return;
  }
  ssl_context_.set_options(asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2 |
                           asio::ssl::context::no_sslv3 | asio::ssl::context::single_dh_use);
  ssl_context_.use_certificate_chain_file(config_.data_plane_tls_cert_file);
  ssl_context_.use_private_key_file(config_.data_plane_tls_key_file, asio::ssl::context::pem);
  thread_ = std::thread(&DirectWssServer::Run, this);
}

// Requests shutdown, closes the acceptor and joins the accept thread.
// NOTE(review): `acceptor_` is emplaced by Run() on the accept thread and closed here
// on another thread without synchronisation. Asio documents shared acceptor objects
// as unsafe for concurrent use, and io_context_.stop() does not interrupt a blocking
// synchronous accept(). Consider posting the close onto the acceptor's thread or
// switching to async_accept.
void DirectWssServer::Stop() {
  stop_requested_.store(true);
  if (acceptor_.has_value()) {
    beast::error_code ignored;
    acceptor_->close(ignored);
  }
  io_context_.stop();
  if (thread_.joinable()) {
    thread_.join();
  }
}

// Accept loop: binds the configured endpoint and hands each accepted socket to
// HandleConnection on its own detached thread.
// NOTE(review): detached connection threads capture `this` and are not awaited by
// Stop(), so they can outlive the server object. Track and join them, or share
// ownership of the server, before relying on destruction during live traffic.
void DirectWssServer::Run() {
  try {
    const auto address = asio::ip::make_address(config_.data_plane_listen_host);
    const tcp::endpoint endpoint{address, static_cast<unsigned short>(config_.data_plane_listen_port)};
    acceptor_.emplace(io_context_);
    acceptor_->open(endpoint.protocol());
    acceptor_->set_option(asio::socket_base::reuse_address(true));
    acceptor_->bind(endpoint);
    acceptor_->listen(asio::socket_base::max_listen_connections);
    logger_->Info("direct data-plane WSS endpoint listening host=" + config_.data_plane_listen_host +
                  " port=" + std::to_string(config_.data_plane_listen_port) + " path=/rap/v1/data-plane");
    while (!stop_requested_.load()) {
      beast::error_code ec;
      tcp::socket socket{io_context_};
      acceptor_->accept(socket, ec);
      if (ec) {
        if (!stop_requested_.load()) {
          logger_->Warn("direct data-plane accept failed reason=" + ec.message());
          // Without a pause a persistent error (e.g. EMFILE) spins a core and floods the log.
          std::this_thread::sleep_for(kAcceptRetryDelay);
        }
        continue;
      }
      std::thread(&DirectWssServer::HandleConnection, this, std::move(socket)).detach();
    }
  } catch (const std::exception& error) {
    logger_->Error(std::string("direct data-plane WSS endpoint stopped reason=") + error.what());
  }
}

// NOTE(review): HandleConnection continues past the reviewed chunk; its visible
// opening is reproduced unchanged.
void DirectWssServer::HandleConnection(tcp::socket socket) {
  beast::error_code ec;
  asio::ssl::stream 
tls_stream{std::move(socket), ssl_context_}; tls_stream.handshake(asio::ssl::stream_base::server, ec); if (ec) { logger_->Warn("direct data-plane TLS handshake failed reason=" + ec.message()); return; } beast::flat_buffer buffer; http::request request; http::read(tls_stream, buffer, request, ec); if (ec) { logger_->Warn("direct data-plane HTTP upgrade read failed reason=" + ec.message()); return; } const std::string target = std::string(request.target()); if (PathOnly(target) != "/rap/v1/data-plane") { WriteReject(tls_stream, http::status::not_found, "unknown_path"); return; } if (!websocket::is_upgrade(request)) { WriteReject(tls_stream, http::status::upgrade_required, "websocket_required"); return; } const auto token = QueryParam(target, "data_plane_token"); const bool binary_render_requested = QueryParam(target, "render_transport") == kBinaryRenderTransportV1; const std::string requested_color_mode = binary_render_requested ? NormalizeColorMode(QueryParam(target, "color_mode")) : std::string(kColorModeFullColor); const auto validation = token_validator_.Validate(token); if (!validation.ok) { logger_->Warn("event=token_validation_failed reason=" + validation.reason); WriteReject(tls_stream, http::status::unauthorized, validation.reason); return; } logger_->Info("event=token_validation_success session=" + validation.claims.session_id + " attachment=" + validation.claims.attachment_id + " worker=" + validation.claims.worker_id + " jti=" + validation.claims.jti); if (!ConsumeJti(validation.claims)) { logger_->Warn("event=jti_replay_rejected session=" + validation.claims.session_id + " attachment=" + validation.claims.attachment_id + " jti=" + validation.claims.jti); WriteReject(tls_stream, http::status::unauthorized, "jti_replay_rejected"); return; } std::string bind_reason; auto runtime = session_manager_->BindDirectDataPlaneRuntime(validation.claims, bind_reason); if (!runtime) { logger_->Warn("event=data_plane_bind_failed session=" + validation.claims.session_id 
+ " attachment=" + validation.claims.attachment_id + " reason=" + bind_reason); WriteReject(tls_stream, http::status::not_found, bind_reason); return; } websocket::stream> ws{std::move(tls_stream)}; ws.accept(request, ec); if (ec) { logger_->Warn("direct data-plane WebSocket accept failed session=" + validation.claims.session_id + " reason=" + ec.message()); return; } logger_->Info("event=data_plane_bind_success session=" + validation.claims.session_id + " attachment=" + validation.claims.attachment_id + " channels=" + ChannelsToJsonArray(validation.claims.allowed_channels) + " render_transport=" + (binary_render_requested ? std::string("binary_v1") : std::string("json_base64")) + " requested_color_mode=" + requested_color_mode + " applied_color_mode=" + requested_color_mode); ws.text(true); ws.write(asio::buffer(rdp_worker::common::SerializeJson(rdp_worker::common::JsonObject{ {"type", "data_plane.attached"}, {"session_id", validation.claims.session_id}, {"attachment_id", validation.claims.attachment_id}, {"worker_id", validation.claims.worker_id}, {"render_transport", binary_render_requested ? 
"binary_v1" : "json_base64"}, {"color_mode", requested_color_mode}, })), ec); if (ec) { return; } std::mutex websocket_io_mutex; auto sink = std::make_shared>>>( ws, websocket_io_mutex, logger_, validation.claims.session_id, validation.claims.attachment_id, binary_render_requested, requested_color_mode, [runtime](std::string reason) { runtime->RequestDirectFullFrameRepair(std::move(reason)); }); sink->Start(); runtime->AddDirectEventSink(sink); while (!stop_requested_.load()) { beast::error_code available_ec; const auto available = beast::get_lowest_layer(ws).available(available_ec); if (available_ec) { logger_->Warn("direct data-plane socket availability check failed session=" + validation.claims.session_id + " reason=" + available_ec.message()); break; } if (available == 0) { std::this_thread::sleep_for(std::chrono::milliseconds(5)); continue; } beast::flat_buffer message; { std::lock_guard read_lock(websocket_io_mutex); ws.read(message, ec); } if (ec == websocket::error::closed) { break; } if (ec) { logger_->Warn("direct data-plane WebSocket read failed session=" + validation.claims.session_id + " reason=" + ec.message()); break; } const auto payload = beast::buffers_to_string(message.data()); auto envelope = NormalizeClientEnvelope(payload); if (!envelope.has_value()) { logger_->Warn("direct data-plane ignored malformed client envelope session=" + validation.claims.session_id); continue; } const auto type = common::GetString(*envelope, "type").value_or(""); if (type == "heartbeat") { continue; } if (!ChannelAllowed(validation.claims, type)) { logger_->Warn("direct data-plane envelope rejected by token channel scope session=" + validation.claims.session_id + " type=" + type); if (auto blocked = BuildBlockedEnvelopeForDisallowedChannel(*envelope, validation.claims); blocked.has_value()) { if (!WriteTextEnvelope(ws, websocket_io_mutex, *blocked, logger_, validation.claims.session_id, "channel_scope_blocked")) { break; } } continue; } const auto snapshot = 
runtime->Snapshot(); if (snapshot.attachment_id != validation.claims.attachment_id) { logger_->Warn("direct data-plane envelope rejected because attachment is no longer current session=" + validation.claims.session_id + " envelope_attachment=" + validation.claims.attachment_id + " current_attachment=" + snapshot.attachment_id + " type=" + type); const auto taken_over = BuildTakenOverEnvelopeForStaleAttachment(validation.claims, snapshot.attachment_id); if (!WriteTextEnvelope(ws, websocket_io_mutex, taken_over, logger_, validation.claims.session_id, "attachment_mismatch")) { break; } continue; } AddBindingClaimsToEnvelope(*envelope, validation.claims); const bool mouse_move = IsMouseMoveEnvelope(*envelope); if (!runtime->EnqueueDirectEnvelope(std::move(*envelope))) { logger_->Warn("direct data-plane inbound queue full session=" + validation.claims.session_id + " type=" + type + " mouse_move=" + (mouse_move ? std::string("true") : std::string("false"))); } } sink->Stop(); logger_->Info("direct data-plane WebSocket detached session=" + validation.claims.session_id + " attachment=" + validation.claims.attachment_id); } bool DirectWssServer::ConsumeJti(const DataPlaneTokenClaims& claims) { std::lock_guard lock(jti_mutex_); const auto now = std::chrono::system_clock::now(); for (auto iterator = used_jti_.begin(); iterator != used_jti_.end();) { if (iterator->second <= now) { iterator = used_jti_.erase(iterator); } else { ++iterator; } } if (used_jti_.find(claims.jti) != used_jti_.end()) { return false; } if (used_jti_.size() >= 4096) { used_jti_.erase(used_jti_.begin()); } const auto expires_at = std::chrono::system_clock::time_point(std::chrono::seconds(claims.expires_at_unix)); used_jti_.emplace(claims.jti, expires_at + std::chrono::seconds(5)); return true; } } // namespace rdp_worker::dataplane