Files
rdp-proxy/workers/rdp-worker/src/dataplane/direct_wss_server.cpp
T
2026-04-28 22:29:50 +03:00

1303 lines
56 KiB
C++

#include "rdp_worker/dataplane/direct_wss_server.hpp"
#include <algorithm>
#include <array>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <deque>
#include <functional>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <string_view>
#include <thread>
#include <cstdlib>
#include <utility>
#include <boost/asio/buffer.hpp>
#include <boost/asio/ip/address.hpp>
#include <boost/beast/core.hpp>
#include <boost/beast/http.hpp>
#include <boost/beast/websocket.hpp>
#include <boost/beast/websocket/ssl.hpp>
#include "rdp_worker/common/json.hpp"
namespace rdp_worker::dataplane {
namespace asio = boost::asio;
namespace beast = boost::beast;
namespace http = beast::http;
namespace websocket = beast::websocket;
using tcp = asio::ip::tcp;
namespace {
// Capacity of the reliable (non-droppable) event deque; the oldest entry is evicted when full.
constexpr std::size_t kMaxReliableOutboundEvents{256};
// Capacity of the ordered render-frame deque; overflowing it discards the queue and asks for a repair.
constexpr std::size_t kMaxOrderedRenderFrames{256};
// Minimum spacing between binary render frames that are neither region updates nor interactive.
constexpr std::chrono::milliseconds kMinBinaryRenderFrameInterval{100};
// Minimum spacing between binary region frames (and interactive frames).
constexpr std::chrono::milliseconds kMinBinaryRegionRenderFrameInterval{33};
// Render transport identifier. NOTE(review): not referenced in the visible part of this file.
constexpr std::string_view kBinaryRenderTransportV1{"binary_v1"};
// Color modes accepted from the client; anything other than grayscale maps to full color.
constexpr std::string_view kColorModeFullColor{"full_color"};
constexpr std::string_view kColorModeGrayscale{"grayscale"};
// `message_type` values written into binary render frame headers.
constexpr std::string_view kRenderFrameFullMessageType{"render.frame.full"};
constexpr std::string_view kRenderFrameRegionMessageType{"render.frame.region"};
// Result of encoding one worker frame for the binary render transport, plus the
// diagnostics reported in logs. Field order matters: instances are built by aggregate
// initialization in WorkerEventToBinaryRenderFrame.
struct BinaryRenderFrameMessage {
  std::string message;        // full wire message: "RAP2" prefix, JSON header, pixel payload
  std::string message_type;   // "render.frame.full" or "render.frame.region"
  std::string color_mode;     // color mode actually applied to the pixels
  std::string update_kind;    // "full" or "region"
  bool grayscale_conversion_applied = false;
  bool dirty_region_applied = false;
  std::size_t raw_frame_bytes_before = 0;  // pixel bytes before cropping / conversion
  std::size_t raw_frame_bytes_after = 0;   // pixel bytes actually sent
  std::size_t binary_direct_bytes = 0;     // pixel bytes actually sent (same as above when built)
  std::size_t full_frame_bytes = 0;        // desktop_width * desktop_height * 4
  std::size_t region_bytes = 0;            // pixel bytes for region updates, 0 for full frames
  double region_savings_percent = 0.0;
  std::int64_t conversion_time_ms = 0;     // time spent in the color conversion step
  std::int64_t diff_time_ms = 0;           // copied from the worker payload timing fields
  std::string render_update_reason;        // worker-provided "capture_source"
  std::string fallback_to_full_frame_reason;  // empty unless a region request became a full frame
};
// Current system_clock time as whole milliseconds since the clock's epoch (the Unix epoch
// on mainstream platforms).
std::int64_t UnixMillisecondsNow() {
  using std::chrono::duration_cast;
  using std::chrono::milliseconds;
  const auto elapsed_since_epoch = std::chrono::system_clock::now().time_since_epoch();
  return duration_cast<milliseconds>(elapsed_since_epoch).count();
}
// Appends `value` to `output` as two bytes, least-significant byte first.
void AppendUint16Le(std::string& output, std::uint16_t value) {
  const char little_endian[2] = {
      static_cast<char>(value & 0xFFu),
      static_cast<char>((value >> 8) & 0xFFu),
  };
  output.append(little_endian, sizeof(little_endian));
}
// Appends `value` to `output` as four bytes, least-significant byte first.
void AppendUint32Le(std::string& output, std::uint32_t value) {
  for (unsigned int shift = 0; shift < 32; shift += 8) {
    output.push_back(static_cast<char>((value >> shift) & 0xFFu));
  }
}
// Encodes `size` bytes starting at `data` as standard (RFC 4648, '+' and '/') base64 with
// '=' padding. `data` may be null when `size` is 0.
std::string Base64Encode(const std::uint8_t* data, std::size_t size) {
  static constexpr char kAlphabet[] =
      "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
  std::string encoded;
  encoded.reserve(((size + 2) / 3) * 4);

  // Whole 3-byte groups map to four symbols with no padding.
  std::size_t pos = 0;
  for (; pos + 3 <= size; pos += 3) {
    const std::uint32_t group = (std::uint32_t{data[pos]} << 16) |
                                (std::uint32_t{data[pos + 1]} << 8) |
                                std::uint32_t{data[pos + 2]};
    encoded += kAlphabet[(group >> 18) & 0x3F];
    encoded += kAlphabet[(group >> 12) & 0x3F];
    encoded += kAlphabet[(group >> 6) & 0x3F];
    encoded += kAlphabet[group & 0x3F];
  }

  // A trailing 1- or 2-byte group is zero-extended and padded with '='.
  const std::size_t tail = size - pos;
  if (tail != 0) {
    std::uint32_t group = std::uint32_t{data[pos]} << 16;
    if (tail == 2) {
      group |= std::uint32_t{data[pos + 1]} << 8;
    }
    encoded += kAlphabet[(group >> 18) & 0x3F];
    encoded += kAlphabet[(group >> 12) & 0x3F];
    encoded += (tail == 2) ? kAlphabet[(group >> 6) & 0x3F] : '=';
    encoded += '=';
  }
  return encoded;
}
// Decodes standard base64 ('+' and '/') into raw bytes.
// Decoding stops at the first '=' (padding and anything after it are ignored). Any other
// character outside the alphabet, including whitespace, makes the whole input invalid.
// Returns std::nullopt on invalid input; an empty input decodes to an empty string.
std::optional<std::string> DecodeBase64ToBytes(const std::string& input) {
  static constexpr unsigned char kInvalid = 255;
  // Function-local static initialization is thread-safe, unlike the previous lazily filled
  // table guarded by a plain bool, which raced when several writer threads decoded at once.
  static const std::array<unsigned char, 256> kDecodeTable = [] {
    std::array<unsigned char, 256> table{};
    table.fill(kInvalid);
    constexpr std::string_view alphabet =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    for (std::size_t i = 0; i < alphabet.size(); ++i) {
      table[static_cast<unsigned char>(alphabet[i])] = static_cast<unsigned char>(i);
    }
    return table;
  }();

  std::string output;
  output.reserve((input.size() * 3) / 4);
  // Unsigned accumulator: its high bits are allowed to wrap away because at most the low
  // 12 bits are ever read. The previous signed `int` overflowed (UB) after a few symbols.
  std::uint32_t accumulator = 0;
  int pending_bits = -8;
  for (const unsigned char ch : input) {
    if (ch == '=') {
      break;
    }
    const unsigned char sextet = kDecodeTable[ch];
    if (sextet == kInvalid) {
      return std::nullopt;
    }
    accumulator = (accumulator << 6) | sextet;
    pending_bits += 6;
    if (pending_bits >= 0) {
      output.push_back(static_cast<char>((accumulator >> pending_bits) & 0xFFu));
      pending_bits -= 8;
    }
  }
  return output;
}
// Decodes a URL query component: "%XX" (two hexadecimal digits) becomes the byte 0xXX and
// '+' becomes a space. A '%' that is not followed by two hex digits is copied literally.
std::string UrlDecode(std::string_view value) {
  // Numeric value of one hexadecimal digit, or -1 when `ch` is not [0-9A-Fa-f].
  // (std::strtol was used before; it also accepted a sign and leading whitespace, so
  // "%-1" decoded to 0xFF and "% 1" decoded to 0x01.)
  const auto hex_digit = [](char ch) -> int {
    if (ch >= '0' && ch <= '9') {
      return ch - '0';
    }
    if (ch >= 'a' && ch <= 'f') {
      return ch - 'a' + 10;
    }
    if (ch >= 'A' && ch <= 'F') {
      return ch - 'A' + 10;
    }
    return -1;
  };
  std::string output;
  output.reserve(value.size());
  for (std::size_t i = 0; i < value.size(); ++i) {
    const char ch = value[i];
    if (ch == '%' && i + 2 < value.size()) {
      const int high = hex_digit(value[i + 1]);
      const int low = hex_digit(value[i + 2]);
      if (high >= 0 && low >= 0) {
        output.push_back(static_cast<char>((high << 4) | low));
        i += 2;
        continue;
      }
    }
    output.push_back(ch == '+' ? ' ' : ch);
  }
  return output;
}
// Returns the URL-decoded value of the first `name=value` pair in the query string of
// `target` (the part after '?'). Pairs are separated by '&'. Returns an empty string when
// there is no query, no matching key, or the matching field has no '='.
std::string QueryParam(std::string_view target, std::string_view name) {
  constexpr auto npos = std::string_view::npos;
  const auto query_start = target.find('?');
  if (query_start == npos) {
    return {};
  }
  std::size_t cursor = query_start + 1;
  while (cursor < target.size()) {
    const auto separator = target.find('&', cursor);
    const auto field_end = (separator == npos) ? target.size() : separator;
    const std::string_view field = target.substr(cursor, field_end - cursor);
    const auto equals = field.find('=');
    if (equals != npos && field.substr(0, equals) == name) {
      return UrlDecode(field.substr(equals + 1));
    }
    if (separator == npos) {
      break;
    }
    cursor = separator + 1;
  }
  return {};
}
// Returns `target` with any query string (from the first '?' onward) removed.
std::string PathOnly(std::string_view target) {
  // npos is the largest size_t, so min() clamps "no '?'" to the full length.
  const std::size_t path_length = std::min(target.find('?'), target.size());
  return std::string(target.substr(0, path_length));
}
// Maps a client-requested color mode onto a supported one: "grayscale" is honoured,
// every other value (including empty) becomes "full_color".
std::string NormalizeColorMode(const std::string& requested) {
  if (requested == kColorModeGrayscale) {
    return std::string(kColorModeGrayscale);
  }
  return std::string(kColorModeFullColor);
}
// Converts BGRA pixels to gray in place using integer luma weights (77 R + 150 G + 29 B) / 256.
// The alpha byte is left unchanged. Trailing bytes that do not form a whole 4-byte pixel are
// not touched.
void ApplyGrayscaleBgra(std::string& frame_bytes) {
  constexpr std::size_t kBytesPerPixel = 4;
  const std::size_t whole_pixels = frame_bytes.size() / kBytesPerPixel;
  for (std::size_t pixel = 0; pixel < whole_pixels; ++pixel) {
    char* const px = frame_bytes.data() + pixel * kBytesPerPixel;
    const unsigned int blue = static_cast<unsigned char>(px[0]);
    const unsigned int green = static_cast<unsigned char>(px[1]);
    const unsigned int red = static_cast<unsigned char>(px[2]);
    const unsigned int luma = (red * 77U + green * 150U + blue * 29U) >> 8U;
    const char gray = static_cast<char>(luma);
    px[0] = gray;
    px[1] = gray;
    px[2] = gray;
  }
}
// Reads numeric field `key` from `payload` as an int, truncating toward zero.
// Returns `fallback` when the key is missing, or when the number is NaN or too large to fit
// in an int.
int JsonIntOr(const common::JsonObject& payload, const std::string& key, int fallback) {
  const auto number = common::GetNumber(payload, key);
  if (!number.has_value()) {
    return fallback;
  }
  const double raw = static_cast<double>(*number);
  // Floating-to-int conversion is undefined behaviour unless the truncated value fits, i.e.
  // raw must lie strictly inside (INT_MIN - 1, INT_MAX + 1). Payload numbers are not
  // range-checked upstream in this file. The negated form also sends NaN (every comparison
  // false) to the fallback.
  constexpr double kLowerExclusive = static_cast<double>(std::numeric_limits<int>::min()) - 1.0;
  constexpr double kUpperExclusive = static_cast<double>(std::numeric_limits<int>::max()) + 1.0;
  if (!(raw > kLowerExclusive && raw < kUpperExclusive)) {
    return fallback;
  }
  return static_cast<int>(raw);
}
// Copies the BGRA rectangle (region_x, region_y, region_width x region_height) out of
// `source` into `output`, packed tightly (output stride = region_width * 4).
// `source` holds full_height rows of full_stride bytes each.
// Returns false, leaving `output` untouched, when:
//  - any dimension is non-positive or the origin is negative;
//  - the rectangle leaves the full_width x full_height frame;
//  - a region row does not fit inside one source row; or
//  - `source` is shorter than full_stride * full_height.
bool TryCropBgraRegion(const std::string& source,
                       int full_width,
                       int full_height,
                       int full_stride,
                       int region_x,
                       int region_y,
                       int region_width,
                       int region_height,
                       std::string& output) {
  constexpr std::int64_t kBytesPerPixel = 4;
  if (full_width <= 0 || full_height <= 0 || full_stride <= 0 ||
      region_x < 0 || region_y < 0 || region_width <= 0 || region_height <= 0) {
    return false;
  }
  // Bounds arithmetic is widened to 64 bits. These values come from the event payload, and
  // the previous int arithmetic could overflow (UB). In practice a huge region_x wrapped
  // negative, passed the checks, and produced an out-of-bounds source offset.
  const std::int64_t region_right = static_cast<std::int64_t>(region_x) + region_width;
  const std::int64_t region_bottom = static_cast<std::int64_t>(region_y) + region_height;
  if (region_right > full_width || region_bottom > full_height) {
    return false;
  }
  const std::int64_t region_row_bytes = static_cast<std::int64_t>(region_width) * kBytesPerPixel;
  const std::int64_t row_end_bytes = static_cast<std::int64_t>(region_x) * kBytesPerPixel + region_row_bytes;
  if (row_end_bytes > full_stride) {
    return false;
  }
  const auto required_source_bytes = static_cast<std::size_t>(full_stride) * static_cast<std::size_t>(full_height);
  if (source.size() < required_source_bytes) {
    return false;
  }
  const auto row_bytes = static_cast<std::size_t>(region_row_bytes);
  const auto column_offset = static_cast<std::size_t>(region_x) * static_cast<std::size_t>(kBytesPerPixel);
  output.assign(row_bytes * static_cast<std::size_t>(region_height), '\0');
  for (int row = 0; row < region_height; ++row) {
    // region_y + row < full_height here, so this int addition cannot overflow.
    const std::size_t source_offset =
        static_cast<std::size_t>(region_y + row) * static_cast<std::size_t>(full_stride) + column_offset;
    const std::size_t target_offset = static_cast<std::size_t>(row) * row_bytes;
    std::copy_n(source.data() + source_offset, row_bytes, output.data() + target_offset);
  }
  return true;
}
template <typename Stream>
void WriteReject(Stream& stream, http::status status, const std::string& reason) {
http::response<http::string_body> response{status, 11};
response.set(http::field::server, "rap-rdp-worker");
response.set(http::field::content_type, "application/json");
response.keep_alive(false);
response.body() = rdp_worker::common::SerializeJson(rdp_worker::common::JsonObject{
{"error", rdp_worker::common::JsonObject{
{"code", reason},
{"message_key", "errors.data_plane." + reason},
{"fallback_message", "Data plane connection rejected."},
}},
});
response.prepare_payload();
beast::error_code ignored;
http::write(stream, response, ignored);
}
std::string ChannelsToJsonArray(const std::vector<std::string>& channels) {
rdp_worker::common::JsonArray array;
for (const auto& channel : channels) {
array.emplace_back(channel);
}
return rdp_worker::common::SerializeJson(array);
}
// Translates an internal worker event type into the envelope "type" sent to the client.
// Lifecycle events all become "session.state"; unknown types pass through unchanged.
std::string WorkerEventTypeToEnvelopeType(const runtime::WorkerEvent& event) {
  using TypeMapping = std::pair<std::string_view, std::string_view>;
  static constexpr std::array<TypeMapping, 11> kDirectMappings{{
      {"session_frame", "session.frame"},
      {"session_cursor_updated", "cursor.update"},
      {"session_clipboard_text", "clipboard.text"},
      {"session_taken_over", "session.taken_over"},
      {"session_file_upload_completed", "file_upload.progress"},
      {"session_file_download_available", "file_download.available"},
      {"session_file_download_chunk", "file_download.chunk"},
      {"session_file_download_completed", "file_download.completed"},
      {"session_file_download_failed", "file_download.failed"},
      {"session_file_download_blocked", "file_download.blocked"},
      {"session_file_download_progress", "file_download.progress"},
  }};
  static constexpr std::array<std::string_view, 5> kSessionStateEvents{
      "session_failed",
      "session_terminated",
      "session_connected",
      "session_display_ready",
      "session_heartbeat",
  };
  const std::string_view worker_type = event.type;
  for (const auto& [from, to] : kDirectMappings) {
    if (worker_type == from) {
      return std::string(to);
    }
  }
  const bool is_state_event =
      std::find(kSessionStateEvents.begin(), kSessionStateEvents.end(), worker_type) != kSessionStateEvents.end();
  if (is_state_event) {
    return "session.state";
  }
  return event.type;
}
// Builds the "render" sub-object attached to session.state / session.frame envelopes.
// It always holds width/height. Optional payload fields are copied when present; the
// quality profile and state are each published under two key names.
common::JsonObject RenderPayloadFromWorkerEvent(const runtime::WorkerEvent& event) {
  const auto& source = event.payload;
  common::JsonObject render{
      {"width", event.width},
      {"height", event.height},
  };

  const auto quality_profile = common::GetString(source, "render_quality_profile");
  if (quality_profile.has_value()) {
    render["quality_profile"] = *quality_profile;
    render["render_quality_profile"] = *quality_profile;
  }
  const auto render_state = common::GetString(source, "render_state");
  if (render_state.has_value()) {
    render["state"] = *render_state;
    render["render_state"] = *render_state;
  }
  const auto frame_sequence = common::GetNumber(source, "frame_sequence");
  if (frame_sequence.has_value()) {
    render["frame_sequence"] = *frame_sequence;
  }
  const auto frame_format = common::GetString(source, "frame_format");
  if (frame_format.has_value()) {
    render["frame_format"] = *frame_format;
  }
  const auto cursor_x = common::GetNumber(source, "cursor_x");
  if (cursor_x.has_value()) {
    render["cursor_x"] = *cursor_x;
  }
  const auto cursor_y = common::GetNumber(source, "cursor_y");
  if (cursor_y.has_value()) {
    render["cursor_y"] = *cursor_y;
  }
  const auto cursor_visible = common::GetBool(source, "cursor_visible");
  if (cursor_visible.has_value()) {
    render["cursor_visible"] = *cursor_visible;
  }
  return render;
}
// Encodes a "session_frame" worker event as one binary render message.
// Wire layout built below:
//   "RAP2" (4 bytes) | uint16 LE 1 | uint16 LE 0 | uint32 LE header length |
//   uint32 LE payload length | JSON header | pixel payload
// NOTE(review): the meaning of the two uint16 fields (presumably version / flags) is not
// defined in this file — confirm against the client decoder.
// Pixel source: event.raw_frame_bytes when non-empty, otherwise base64 payload "frame_data".
// Region handling: a requested "region" update is sent as a region only when
// require_full_frame is false AND either the payload already holds exactly the region's
// bytes, or the region can be cropped from the full frame. Otherwise the frame is sent
// as "full" and fallback_to_full_frame_reason records why.
// Grayscale conversion is applied whenever the normalized color mode is "grayscale".
// Returns std::nullopt when:
//  - the event is not a frame;
//  - there is no frame data, or it is invalid base64; or
//  - the header or payload exceeds the uint32 length fields.
std::optional<BinaryRenderFrameMessage> WorkerEventToBinaryRenderFrame(const runtime::WorkerEvent& event,
const std::string& requested_color_mode,
bool require_full_frame) {
if (event.type != "session_frame") {
return std::nullopt;
}
// Obtain the pixel bytes: prefer the raw buffer, fall back to decoding base64 "frame_data".
std::string frame_bytes;
if (!event.raw_frame_bytes.empty()) {
frame_bytes.assign(reinterpret_cast<const char*>(event.raw_frame_bytes.data()), event.raw_frame_bytes.size());
} else {
const auto encoded_frame = common::GetString(event.payload, "frame_data").value_or("");
if (encoded_frame.empty()) {
return std::nullopt;
}
auto decoded = DecodeBase64ToBytes(encoded_frame);
if (!decoded.has_value()) {
return std::nullopt;
}
frame_bytes = std::move(*decoded);
}
// Frame geometry and metadata from the payload. The stride defaults to a tightly packed
// 4-bytes-per-pixel row; desktop dimensions default to the frame dimensions.
const std::size_t raw_frame_bytes_before = frame_bytes.size();
const auto frame_sequence = static_cast<std::int64_t>(common::GetNumber(event.payload, "frame_sequence").value_or(0));
const auto source_frame_width = JsonIntOr(event.payload, "frame_width", event.width);
const auto source_frame_height = JsonIntOr(event.payload, "frame_height", event.height);
const auto source_frame_stride = JsonIntOr(event.payload, "frame_stride", source_frame_width * 4);
const auto desktop_width = JsonIntOr(event.payload, "desktop_width", source_frame_width);
const auto desktop_height = JsonIntOr(event.payload, "desktop_height", source_frame_height);
const auto frame_format = common::GetString(event.payload, "frame_format").value_or("bgra32");
const auto quality_profile = common::GetString(event.payload, "render_quality_profile").value_or("balanced");
const auto input_correlation_id = common::GetString(event.payload, "input_correlation_id").value_or("");
const auto worker_frame_captured_at = common::GetString(event.payload, "worker_frame_captured_at").value_or("");
const auto requested_update_kind = common::GetString(event.payload, "frame_update_kind").value_or("full");
const auto render_update_reason = common::GetString(event.payload, "capture_source").value_or("");
const int region_x = JsonIntOr(event.payload, "region_x", 0);
const int region_y = JsonIntOr(event.payload, "region_y", 0);
const int region_width = JsonIntOr(event.payload, "region_width", 0);
const int region_height = JsonIntOr(event.payload, "region_height", 0);
const bool frame_data_is_region = common::GetBool(event.payload, "frame_data_is_region").value_or(false);
// Size of an uncropped desktop frame at 4 bytes per pixel; used only for savings metrics.
const auto full_frame_bytes = static_cast<std::size_t>(std::max(0, desktop_width)) *
static_cast<std::size_t>(std::max(0, desktop_height)) * 4U;
// Reported diff time: "periodic_change_poll_ms" if present, else "capture_total_ms", else 0.
const auto diff_time_ms = static_cast<std::int64_t>(
common::GetNumber(event.payload, "periodic_change_poll_ms")
.value_or(common::GetNumber(event.payload, "capture_total_ms").value_or(0)));
std::string update_kind = "full";
std::string fallback_to_full_frame_reason;
int output_frame_width = source_frame_width;
int output_frame_height = source_frame_height;
int output_frame_stride = source_frame_stride;
bool dirty_region_applied = false;
// A region cannot be sent before the caller has delivered a full baseline frame.
if (require_full_frame && requested_update_kind == "region") {
fallback_to_full_frame_reason = "baseline_required";
} else if (!require_full_frame && requested_update_kind == "region") {
// Case 1: the payload already holds exactly region_width * region_height BGRA pixels.
if (frame_data_is_region &&
region_width > 0 && region_height > 0 &&
frame_bytes.size() == static_cast<std::size_t>(region_width) * static_cast<std::size_t>(region_height) * 4U) {
output_frame_width = region_width;
output_frame_height = region_height;
output_frame_stride = region_width * 4;
update_kind = "region";
dirty_region_applied = true;
} else {
// Case 2: crop the region out of a full frame; on failure fall back to sending full.
std::string region_bytes;
if (TryCropBgraRegion(frame_bytes,
source_frame_width,
source_frame_height,
source_frame_stride,
region_x,
region_y,
region_width,
region_height,
region_bytes)) {
frame_bytes = std::move(region_bytes);
output_frame_width = region_width;
output_frame_height = region_height;
output_frame_stride = region_width * 4;
update_kind = "region";
dirty_region_applied = true;
}
}
if (!dirty_region_applied) {
fallback_to_full_frame_reason = "invalid_region_payload";
}
}
// Optional grayscale conversion, timed for diagnostics. It runs on whatever bytes are being
// sent. NOTE(review): it assumes BGRA layout regardless of frame_format — confirm.
const auto conversion_started = std::chrono::steady_clock::now();
const std::string applied_color_mode = NormalizeColorMode(requested_color_mode);
const bool grayscale = applied_color_mode == kColorModeGrayscale;
if (grayscale) {
ApplyGrayscaleBgra(frame_bytes);
}
const auto conversion_finished = std::chrono::steady_clock::now();
const auto conversion_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
conversion_finished - conversion_started)
.count();
const auto message_type = update_kind == "region"
? std::string(kRenderFrameRegionMessageType)
: std::string(kRenderFrameFullMessageType);
const std::size_t region_bytes = update_kind == "region" ? frame_bytes.size() : 0;
const double region_savings_percent = update_kind == "region" && full_frame_bytes > 0
? 100.0 - ((static_cast<double>(region_bytes) * 100.0) / static_cast<double>(full_frame_bytes))
: 0.0;
// JSON header describing the payload that follows it. For full frames the region_*
// fields describe the whole output frame at origin (0, 0).
common::JsonObject header{
{"protocol_version", 1},
{"session_id", event.session_id},
{"channel", "render"},
{"message_type", message_type},
{"legacy_message_type", "session.frame"},
{"sequence", static_cast<double>(frame_sequence)},
{"timestamp", static_cast<double>(UnixMillisecondsNow())},
{"flags", 0},
{"payload_length", static_cast<double>(frame_bytes.size())},
{"frame_width", output_frame_width},
{"frame_height", output_frame_height},
{"frame_stride", output_frame_stride},
{"frame_format", frame_format},
{"frame_update_kind", update_kind},
{"desktop_width", desktop_width},
{"desktop_height", desktop_height},
{"region_x", dirty_region_applied ? region_x : 0},
{"region_y", dirty_region_applied ? region_y : 0},
{"region_width", dirty_region_applied ? region_width : output_frame_width},
{"region_height", dirty_region_applied ? region_height : output_frame_height},
// Both branches yield output_frame_stride; the conditional is redundant.
{"region_stride", dirty_region_applied ? output_frame_stride : output_frame_stride},
{"region_format", "BGRA32"},
{"color_mode", applied_color_mode},
{"quality_profile", quality_profile},
{"original_frame_format", frame_format},
{"output_frame_format", frame_format},
{"raw_frame_bytes", static_cast<double>(raw_frame_bytes_before)},
{"binary_direct_bytes", static_cast<double>(frame_bytes.size())},
{"full_frame_bytes", static_cast<double>(full_frame_bytes)},
{"region_bytes", static_cast<double>(region_bytes)},
{"region_savings_percent", region_savings_percent},
{"diff_time_ms", static_cast<double>(diff_time_ms)},
{"render_update_reason", render_update_reason},
{"full_frame_sent", update_kind == "full"},
{"region_frame_sent", update_kind == "region"},
};
if (!fallback_to_full_frame_reason.empty()) {
header["fallback_to_full_frame_reason"] = fallback_to_full_frame_reason;
}
if (!input_correlation_id.empty()) {
header["input_correlation_id"] = input_correlation_id;
}
if (!worker_frame_captured_at.empty()) {
header["worker_frame_captured_at"] = worker_frame_captured_at;
}
const std::string header_json = common::SerializeJson(header);
// Both lengths must fit the 32-bit fields of the prefix.
if (header_json.size() > std::numeric_limits<std::uint32_t>::max() ||
frame_bytes.size() > std::numeric_limits<std::uint32_t>::max()) {
return std::nullopt;
}
// 16-byte fixed prefix, then header, then pixels.
std::string message;
message.reserve(16 + header_json.size() + frame_bytes.size());
message.append("RAP2", 4);
AppendUint16Le(message, 1);
AppendUint16Le(message, 0);
AppendUint32Le(message, static_cast<std::uint32_t>(header_json.size()));
AppendUint32Le(message, static_cast<std::uint32_t>(frame_bytes.size()));
message.append(header_json);
message.append(frame_bytes);
// Aggregate init: the order must match the field order of BinaryRenderFrameMessage.
return BinaryRenderFrameMessage{
std::move(message),
message_type,
applied_color_mode,
update_kind,
grayscale,
dirty_region_applied,
raw_frame_bytes_before,
frame_bytes.size(),
frame_bytes.size(),
full_frame_bytes,
region_bytes,
region_savings_percent,
conversion_time_ms,
diff_time_ms,
render_update_reason,
fallback_to_full_frame_reason,
};
}
// Wraps a worker event in the JSON text envelope {type, session_id, payload}.
// The payload is a copy of event.payload, plus "reason" when set, plus type-specific
// additions:
//  - session.taken_over: state = "taken_over".
//  - session.state: state ("failed" / "terminated" / the payload's non-empty state /
//    "active") and a render sub-object.
//  - session.frame: a render sub-object, and base64 frame_data when only raw bytes are
//    available.
//  - file_upload.progress: completed progress fields.
//  - file_download.*: direction = "server_to_client".
common::JsonObject WorkerEventToEnvelope(const runtime::WorkerEvent& event) {
  const std::string envelope_type = WorkerEventTypeToEnvelopeType(event);
  common::JsonObject payload = event.payload;
  if (!event.reason.empty()) {
    payload["reason"] = event.reason;
  }

  constexpr std::string_view kFileDownloadPrefix = "file_download.";
  if (envelope_type == "session.taken_over") {
    payload["state"] = "taken_over";
  } else if (envelope_type == "session.state") {
    std::string lifecycle_state;
    if (event.type == "session_failed") {
      lifecycle_state = "failed";
    } else if (event.type == "session_terminated") {
      lifecycle_state = "terminated";
    } else {
      const auto reported_state = common::GetString(payload, "state");
      lifecycle_state = (reported_state.has_value() && !reported_state->empty())
                            ? *reported_state
                            : std::string("active");
    }
    payload["state"] = lifecycle_state;
    payload["render"] = RenderPayloadFromWorkerEvent(event);
  } else if (envelope_type == "session.frame") {
    payload["render"] = RenderPayloadFromWorkerEvent(event);
    const bool needs_inline_frame_data =
        !event.raw_frame_bytes.empty() && common::GetString(payload, "frame_data").value_or("").empty();
    if (needs_inline_frame_data) {
      payload["frame_data"] = Base64Encode(event.raw_frame_bytes.data(), event.raw_frame_bytes.size());
    }
  } else if (envelope_type == "file_upload.progress") {
    const double file_size = common::GetNumber(payload, "file_size").value_or(0);
    payload["received"] = file_size;
    payload["total"] = file_size;
    payload["status"] = "completed";
  } else if (std::string_view(envelope_type).substr(0, kFileDownloadPrefix.size()) == kFileDownloadPrefix) {
    payload["direction"] = "server_to_client";
  }

  return common::JsonObject{
      {"type", envelope_type},
      {"session_id", event.session_id},
      {"payload", payload},
  };
}
// True when the worker event carries a rendered frame.
bool IsFrameEvent(const runtime::WorkerEvent& event) {
  return event.type == "session_frame";
}
// True for frame events flagged "direct_attach_baseline" (queued as non-droppable).
bool IsDirectAttachBaselineFrameEvent(const runtime::WorkerEvent& event) {
  if (!IsFrameEvent(event)) {
    return false;
  }
  return common::GetBool(event.payload, "direct_attach_baseline").value_or(false);
}
// True for frame events whose "frame_update_kind" is "region".
bool IsRegionFrameEvent(const runtime::WorkerEvent& event) {
  if (!IsFrameEvent(event)) {
    return false;
  }
  const auto update_kind = common::GetString(event.payload, "frame_update_kind").value_or("full");
  return update_kind == "region";
}
// True for frame events whose "frame_update_kind" is "full" or absent.
bool IsFullFrameEvent(const runtime::WorkerEvent& event) {
  if (!IsFrameEvent(event)) {
    return false;
  }
  const auto update_kind = common::GetString(event.payload, "frame_update_kind").value_or("full");
  return update_kind == "full";
}
// True for cursor position/shape update events.
bool IsCursorEvent(const runtime::WorkerEvent& event) {
  return event.type == "session_cursor_updated";
}
// Minimum delay the writer keeps between binary frames, based on the next queued frame:
// interactive frames and region updates use the short interval, everything else the long one.
std::chrono::milliseconds BinaryRenderFrameInterval(const runtime::WorkerEvent& event) {
  const bool use_fast_interval =
      common::GetBool(event.payload, "interactive_frame").value_or(false) ||
      common::GetString(event.payload, "frame_update_kind").value_or("full") == "region";
  return use_fast_interval ? kMinBinaryRegionRenderFrameInterval : kMinBinaryRenderFrameInterval;
}
// True for client envelopes of type "input" whose payload has kind "mouse" and action "move".
bool IsMouseMoveEnvelope(const common::JsonObject& envelope) {
  if (common::GetString(envelope, "type").value_or("") != "input") {
    return false;
  }
  const auto* input_payload = common::GetObject(envelope, "payload");
  if (input_payload == nullptr) {
    return false;
  }
  const bool is_mouse = common::GetString(*input_payload, "kind").value_or("") == "mouse";
  return is_mouse && common::GetString(*input_payload, "action").value_or("") == "move";
}
// True when the token's allowed_channels covers the channel used by `envelope_type`.
// Heartbeats travel on the "control" channel; every other type is its own channel name.
bool ChannelAllowed(const DataPlaneTokenClaims& claims, const std::string& envelope_type) {
  const std::string channel = (envelope_type == "heartbeat") ? std::string("control") : envelope_type;
  const auto& allowed = claims.allowed_channels;
  return std::any_of(allowed.begin(), allowed.end(),
                     [&channel](const auto& candidate) { return candidate == channel; });
}
// For a disallowed "file_download" client envelope, builds the serialized
// "file_download.blocked" reply. It echoes the payload's transfer_id / file_id, or empty
// strings when absent. Returns std::nullopt for every other envelope type.
std::optional<std::string> BuildBlockedEnvelopeForDisallowedChannel(const common::JsonObject& envelope,
                                                                    const DataPlaneTokenClaims& claims) {
  if (common::GetString(envelope, "type").value_or("") != "file_download") {
    return std::nullopt;
  }
  std::string transfer_id;
  std::string file_id;
  if (const auto* request_payload = common::GetObject(envelope, "payload"); request_payload != nullptr) {
    transfer_id = common::GetString(*request_payload, "transfer_id").value_or("");
    file_id = common::GetString(*request_payload, "file_id").value_or("");
  }
  common::JsonObject blocked_payload{
      {"direction", "server_to_client"},
      {"transfer_id", transfer_id},
      {"file_id", file_id},
      {"reason", "access denied"},
  };
  return common::SerializeJson(common::JsonObject{
      {"type", "file_download.blocked"},
      {"session_id", claims.session_id},
      {"payload", std::move(blocked_payload)},
  });
}
std::string BuildTakenOverEnvelopeForStaleAttachment(const DataPlaneTokenClaims& claims,
const std::string& current_attachment_id) {
return common::SerializeJson(common::JsonObject{
{"type", "session.taken_over"},
{"session_id", claims.session_id},
{"payload", common::JsonObject{
{"message", "controller binding changed"},
{"reason", "attachment_mismatch"},
{"attachment_id", claims.attachment_id},
{"current_attachment_id", current_attachment_id},
}},
{"event", common::JsonObject{
{"code", "session.taken_over"},
{"message_key", "events.session.taken_over"},
{"fallback_message", "This session was taken over from another device."},
{"details", common::JsonObject{
{"reason", "attachment_mismatch"},
}},
}},
});
}
template <typename WebSocket>
bool WriteTextEnvelope(WebSocket& ws,
std::mutex& websocket_io_mutex,
const std::string& envelope,
std::shared_ptr<common::Logger> logger,
const std::string& session_id,
const std::string& reason) {
beast::error_code write_ec;
std::lock_guard<std::mutex> write_lock(websocket_io_mutex);
ws.text(true);
ws.write(asio::buffer(envelope), write_ec);
if (write_ec) {
logger->Warn("direct data-plane envelope write failed session=" + session_id +
" reason=" + reason +
" error=" + write_ec.message());
return false;
}
return true;
}
// Parses a client text message and normalizes it into the internal envelope shape:
//  - Heartbeats ("heartbeat" / "control.ping") become {type: "heartbeat"}.
//  - "file_upload.<action>" and "file_download.<action>" collapse into
//    {type: "file_upload" | "file_download", payload: {..., action}}.
//  - input/clipboard/control and any other typed envelope with a payload pass through.
// Returns std::nullopt when the message is not a JSON object, has no type, or (for
// non-heartbeats) has no payload object.
std::optional<common::JsonObject> NormalizeClientEnvelope(const std::string& message) {
  auto parsed = common::ParseJson(message);
  if (!parsed.IsObject()) {
    return std::nullopt;
  }
  common::JsonObject envelope = parsed.AsObject();
  const std::string type = common::GetString(envelope, "type").value_or("");
  if (type.empty()) {
    return std::nullopt;
  }
  if (type == "heartbeat" || type == "control.ping") {
    return common::JsonObject{{"type", "heartbeat"}};
  }
  const auto* payload = common::GetObject(envelope, "payload");
  if (payload == nullptr) {
    return std::nullopt;
  }
  if (type == "input" || type == "clipboard" || type == "control") {
    return envelope;
  }

  // Dotted client message types that map onto a channel type plus an "action" field.
  struct ActionAlias {
    std::string_view client_type;
    std::string_view channel_type;
    std::string_view action;
  };
  static constexpr std::array<ActionAlias, 6> kActionAliases{{
      {"file_upload.start", "file_upload", "start"},
      {"file_upload.chunk", "file_upload", "chunk"},
      {"file_upload.cancel", "file_upload", "cancel"},
      {"file_download.start", "file_download", "start"},
      {"file_download.ack", "file_download", "ack"},
      {"file_download.cancel", "file_download", "cancel"},
  }};
  for (const auto& alias : kActionAliases) {
    if (type != alias.client_type) {
      continue;
    }
    common::JsonObject normalized_payload = *payload;
    normalized_payload["action"] = std::string(alias.action);
    return common::JsonObject{{"type", std::string(alias.channel_type)}, {"payload", normalized_payload}};
  }
  return envelope;
}
// Stamps the token's binding claims (session, attachment, user, organization, worker,
// resource) into the envelope's payload object, overwriting same-named keys.
// It does nothing when the envelope has no payload object.
void AddBindingClaimsToEnvelope(common::JsonObject& envelope, const DataPlaneTokenClaims& claims) {
  const auto* original_payload = common::GetObject(envelope, "payload");
  if (original_payload == nullptr) {
    return;
  }
  common::JsonObject bound_payload = *original_payload;
  bound_payload["session_id"] = claims.session_id;
  bound_payload["attachment_id"] = claims.attachment_id;
  bound_payload["user_id"] = claims.user_id;
  bound_payload["organization_id"] = claims.organization_id;
  bound_payload["worker_id"] = claims.worker_id;
  bound_payload["resource_id"] = claims.resource_id;
  envelope["payload"] = std::move(bound_payload);
}
template <typename WebSocket>
class DirectWssEventSink final : public runtime::DirectEventSink {
public:
// Creates a sink that serializes worker events onto `ws` from a dedicated writer thread
// (started by Start()).
// `ws` and `websocket_io_mutex` are stored by reference, so both must outlive this sink.
// `requested_color_mode` is normalized via NormalizeColorMode().
// `repair_request` is invoked with a reason string when the ordered render-frame queue
// overflows.
// NOTE(review): binary_render_enabled_ is declared outside the visible member list; members
// are initialized in declaration order, not in the order listed below — confirm.
DirectWssEventSink(WebSocket& ws,
std::mutex& websocket_io_mutex,
std::shared_ptr<common::Logger> logger,
std::string session_id,
std::string attachment_id,
bool binary_render_enabled,
std::string requested_color_mode,
std::function<void(std::string)> repair_request)
: ws_(ws),
websocket_io_mutex_(websocket_io_mutex),
logger_(std::move(logger)),
session_id_(std::move(session_id)),
attachment_id_(std::move(attachment_id)),
binary_render_enabled_(binary_render_enabled),
requested_color_mode_(NormalizeColorMode(requested_color_mode)),
repair_request_(std::move(repair_request)) {}
// Stops the writer thread (dropping queued events) and joins it.
~DirectWssEventSink() override {
Stop();
}
// Launches the writer thread running WriterLoop().
// NOTE(review): calling Start() again while a previous writer thread is still joinable
// would assign to a joinable std::thread, which calls std::terminate.
void Start() {
writer_thread_ = std::thread(&DirectWssEventSink::WriterLoop, this);
}
// Attachment id this sink was created for.
std::string AttachmentId() const override {
return attachment_id_;
}
// Marks the sink stopped, discards every queued event, wakes the writer thread and waits
// for it to exit. Safe to call more than once from the owning thread: later calls find the
// thread already joined.
void Stop() {
  std::unique_lock<std::mutex> queue_lock(mutex_);
  stopped_ = true;
  reliable_events_.clear();
  latest_cursor_.reset();
  ordered_frame_events_.clear();
  queue_lock.unlock();

  condition_.notify_all();
  if (!writer_thread_.joinable()) {
    return;
  }
  writer_thread_.join();
}
// Queues a worker event for the writer thread; events arriving after Stop() are ignored.
// Queueing policy:
//  - Baseline frames ("direct_attach_baseline") go to the reliable queue and discard
//    pending ordered frames.
//  - Full frames replace all pending ordered frames.
//  - Region frames append to the ordered queue. On overflow the queue and the incoming
//    region are dropped, and repair_request_ is asked for a repair.
//  - Other frame kinds replace the ordered queue.
//  - Cursor updates coalesce to the latest one.
//  - Everything else is appended to the bounded reliable queue, evicting the oldest entry.
void EnqueueEvent(const runtime::WorkerEvent& event) override {
  // The repair callback is invoked only after mutex_ is released. Previously it ran while
  // mutex_ was held, so a callback that re-entered this sink (for example to enqueue a
  // replacement full frame) self-deadlocked on the non-recursive std::mutex.
  bool repair_needed = false;
  {
    std::lock_guard<std::mutex> lock(mutex_);
    if (stopped_) {
      return;
    }
    if (IsFrameEvent(event)) {
      if (IsDirectAttachBaselineFrameEvent(event)) {
        ordered_frame_events_.clear();
        if (reliable_events_.size() >= kMaxReliableOutboundEvents) {
          reliable_events_.pop_front();
          ++reliable_dropped_;
        }
        reliable_events_.push_back(event);
        ++frames_queued_;
        logger_->Info("direct data-plane baseline frame queued as non-droppable session=" + session_id_);
      } else {
        if (IsFullFrameEvent(event)) {
          if (!ordered_frame_events_.empty()) {
            frames_dropped_due_to_backpressure_ += ordered_frame_events_.size();
            logger_->Info("render.region_queue direct writer cleared pending frames for full frame session=" +
                          session_id_ +
                          " cleared=" + std::to_string(ordered_frame_events_.size()));
          }
          ordered_frame_events_.clear();
          ordered_frame_events_.push_back(event);
        } else if (IsRegionFrameEvent(event)) {
          if (ordered_frame_events_.size() >= kMaxOrderedRenderFrames) {
            // Regions are deltas: dropping any of them invalidates the rest, so the whole
            // queue (plus this event) is discarded and a repair is requested.
            const auto dropped = ordered_frame_events_.size() + 1;
            frames_dropped_due_to_backpressure_ += dropped;
            ordered_frame_events_.clear();
            logger_->Warn("render.region_queue direct writer overflow session=" + session_id_ +
                          " dropped=" + std::to_string(dropped) +
                          " repair_requested=true");
            repair_needed = true;
          } else {
            ordered_frame_events_.push_back(event);
          }
        } else {
          ordered_frame_events_.clear();
          ordered_frame_events_.push_back(event);
        }
        ++frames_queued_;
      }
    } else if (IsCursorEvent(event)) {
      if (latest_cursor_.has_value()) {
        ++cursor_updates_dropped_due_to_backpressure_;
      }
      latest_cursor_ = event;
      ++cursor_updates_queued_;
    } else {
      if (reliable_events_.size() >= kMaxReliableOutboundEvents) {
        reliable_events_.pop_front();
        ++reliable_dropped_;
      }
      reliable_events_.push_back(event);
    }
  }
  condition_.notify_one();
  if (repair_needed && repair_request_) {
    repair_request_("direct_writer_region_queue_overflow");
  }
}
private:
// Writer-thread body (started by Start()). Each iteration takes one event, with priority:
//  1. the reliable queue (FIFO),
//  2. the latest coalesced cursor update,
//  3. the ordered render-frame queue.
// Binary render frames are throttled to BinaryRenderFrameInterval(). The actual WebSocket
// write happens outside mutex_, under websocket_io_mutex_. Returns when stopped_ is set,
// or after the first failed write (which also stops the sink and clears its queues).
void WriterLoop() {
while (true) {
std::optional<runtime::WorkerEvent> event;
{
std::unique_lock<std::mutex> lock(mutex_);
condition_.wait(lock, [&]() {
return stopped_ || !reliable_events_.empty() || latest_cursor_.has_value() || !ordered_frame_events_.empty();
});
if (stopped_) {
return;
}
if (!reliable_events_.empty()) {
event = std::move(reliable_events_.front());
reliable_events_.pop_front();
} else if (latest_cursor_.has_value()) {
event = std::move(*latest_cursor_);
latest_cursor_.reset();
++cursor_updates_sent_;
} else if (!ordered_frame_events_.empty()) {
// Throttle: if the next frame is not yet due, sleep until it is. The sleep ends early
// for stop, reliable events or cursor updates; then re-run the selection from the top.
if (binary_render_enabled_ && last_binary_frame_sent_at_.time_since_epoch().count() != 0) {
const auto next_allowed_frame_at = last_binary_frame_sent_at_ + BinaryRenderFrameInterval(ordered_frame_events_.front());
const auto now = std::chrono::steady_clock::now();
if (now < next_allowed_frame_at) {
condition_.wait_until(lock, next_allowed_frame_at, [&]() {
return stopped_ || !reliable_events_.empty() || latest_cursor_.has_value();
});
continue;
}
}
event = std::move(ordered_frame_events_.front());
ordered_frame_events_.pop_front();
++frames_sent_;
}
}
if (!event.has_value()) {
continue;
}
// Frames use the binary transport when enabled. Until a binary "full" frame has been sent,
// region requests are forced to full (require_full_frame). A frame that cannot be encoded
// in binary (nullopt) falls back to the JSON text envelope.
const bool binary_frame = binary_render_enabled_ && IsFrameEvent(*event);
const auto binary_payload = binary_frame
? WorkerEventToBinaryRenderFrame(*event, requested_color_mode_, !baseline_binary_frame_sent_)
: std::nullopt;
const std::string encoded = binary_payload.has_value()
? binary_payload->message
: common::SerializeJson(WorkerEventToEnvelope(*event));
beast::error_code ec;
{
// The message type (binary/text) is set and used under the same I/O lock.
std::lock_guard<std::mutex> write_lock(websocket_io_mutex_);
if (binary_payload.has_value()) {
ws_.binary(true);
} else {
ws_.text(true);
}
ws_.write(asio::buffer(encoded), ec);
}
if (ec) {
// A failed write stops the sink: queued events are discarded and the thread exits.
logger_->Warn("direct data-plane write failed session=" + session_id_ +
" frame=" + (binary_payload.has_value() ? std::string("true") : std::string("false")) +
" reason=" + ec.message());
{
std::lock_guard<std::mutex> lock(mutex_);
stopped_ = true;
reliable_events_.clear();
latest_cursor_.reset();
ordered_frame_events_.clear();
}
condition_.notify_all();
return;
}
if (binary_payload.has_value()) {
last_binary_frame_sent_at_ = std::chrono::steady_clock::now();
if (binary_payload->update_kind == "full") {
baseline_binary_frame_sent_ = true;
}
binary_render_bytes_sent_ += encoded.size();
logger_->Info("direct data-plane render frame prepared session=" + session_id_ +
" requested_color_mode=" + requested_color_mode_ +
" applied_color_mode=" + binary_payload->color_mode +
" message_type=" + binary_payload->message_type +
" frame_update_kind=" + binary_payload->update_kind +
" full_frame_sent=" + (binary_payload->update_kind == "full" ? std::string("true") : std::string("false")) +
" region_frame_sent=" + (binary_payload->update_kind == "region" ? std::string("true") : std::string("false")) +
" dirty_region_applied=" + (binary_payload->dirty_region_applied ? std::string("true") : std::string("false")) +
" grayscale_conversion_applied=" + (binary_payload->grayscale_conversion_applied ? std::string("true") : std::string("false")) +
" raw_frame_bytes_before=" + std::to_string(binary_payload->raw_frame_bytes_before) +
" raw_frame_bytes_after=" + std::to_string(binary_payload->raw_frame_bytes_after) +
" full_frame_bytes=" + std::to_string(binary_payload->full_frame_bytes) +
" region_bytes=" + std::to_string(binary_payload->region_bytes) +
" region_savings_percent=" + std::to_string(binary_payload->region_savings_percent) +
" binary_direct_bytes=" + std::to_string(binary_payload->binary_direct_bytes) +
" diff_time_ms=" + std::to_string(binary_payload->diff_time_ms) +
" render_update_reason=" + binary_payload->render_update_reason +
" fallback_to_full_frame_reason=" + binary_payload->fallback_to_full_frame_reason +
" conversion_time_ms=" + std::to_string(binary_payload->conversion_time_ms));
} else if (IsFrameEvent(*event)) {
json_render_bytes_sent_ += encoded.size();
}
LogRateIfNeeded();
}
}
void LogRateIfNeeded() {
const auto now = std::chrono::steady_clock::now();
if (last_rate_log_at_.time_since_epoch().count() == 0) {
last_rate_log_at_ = now;
return;
}
if (now - last_rate_log_at_ < std::chrono::seconds(2)) {
return;
}
const double seconds = std::max(0.001, std::chrono::duration<double>(now - last_rate_log_at_).count());
logger_->Info("direct data-plane outbound rate session=" + session_id_ +
" frames_queued_per_second=" + std::to_string(static_cast<double>(frames_queued_) / seconds) +
" frames_sent_per_second=" + std::to_string(static_cast<double>(frames_sent_) / seconds) +
" cursor_queued_per_second=" + std::to_string(static_cast<double>(cursor_updates_queued_) / seconds) +
" cursor_sent_per_second=" + std::to_string(static_cast<double>(cursor_updates_sent_) / seconds) +
" binary_render_bytes_per_second=" + std::to_string(static_cast<double>(binary_render_bytes_sent_) / seconds) +
" json_render_bytes_per_second=" + std::to_string(static_cast<double>(json_render_bytes_sent_) / seconds) +
" frames_dropped_due_to_backpressure=" + std::to_string(frames_dropped_due_to_backpressure_) +
" cursor_dropped_due_to_backpressure=" + std::to_string(cursor_updates_dropped_due_to_backpressure_) +
" reliable_dropped=" + std::to_string(reliable_dropped_));
frames_queued_ = 0;
frames_sent_ = 0;
cursor_updates_queued_ = 0;
cursor_updates_sent_ = 0;
binary_render_bytes_sent_ = 0;
json_render_bytes_sent_ = 0;
frames_dropped_due_to_backpressure_ = 0;
cursor_updates_dropped_due_to_backpressure_ = 0;
reliable_dropped_ = 0;
last_rate_log_at_ = now;
}
// Borrowed, not owned: the stream and I/O mutex handed in by the connection handler (presumably
// HandleConnection's stack locals). Both must outlive this sink. HandleConnection locks the same
// mutex around ws.read(); presumably the sink's writes lock it too, so stream access is serialized.
WebSocket& ws_;
std::mutex& websocket_io_mutex_;
std::shared_ptr<common::Logger> logger_;
// Identifiers included in log output (session_id_, requested_color_mode_ appear in render logs).
std::string session_id_;
std::string attachment_id_;
std::string requested_color_mode_;
// Callback asking for a full-frame repair. HandleConnection wires it to
// runtime->RequestDirectFullFrameRepair (presumably stored here by the constructor).
std::function<void(std::string)> repair_request_;
// Guards the outbound queues and stopped_ below; condition_ is notified once the sink stops.
std::mutex mutex_;
std::condition_variable condition_;
// Outbound backlog: reliable events, at most one pending cursor update, and ordered render frames
// (presumably bounded by kMaxReliableOutboundEvents / kMaxOrderedRenderFrames — confirm).
std::deque<runtime::WorkerEvent> reliable_events_;
std::optional<runtime::WorkerEvent> latest_cursor_;
std::deque<runtime::WorkerEvent> ordered_frame_events_;
std::thread writer_thread_;
// Set under mutex_ when the sink shuts down; the queues are cleared at the same time.
bool stopped_{false};
bool binary_render_enabled_{false};
// Becomes true once a binary frame with update_kind "full" has been sent.
bool baseline_binary_frame_sent_{false};
// Start of the current rate-log window (epoch until LogRateIfNeeded() first runs).
std::chrono::steady_clock::time_point last_rate_log_at_{};
// When the most recent binary render frame was sent.
std::chrono::steady_clock::time_point last_binary_frame_sent_at_{};
// Per-window counters: reported and then reset to zero by LogRateIfNeeded().
std::size_t frames_queued_{0};
std::size_t frames_sent_{0};
std::size_t cursor_updates_queued_{0};
std::size_t cursor_updates_sent_{0};
std::size_t binary_render_bytes_sent_{0};
std::size_t json_render_bytes_sent_{0};
std::size_t frames_dropped_due_to_backpressure_{0};
std::size_t cursor_updates_dropped_due_to_backpressure_{0};
std::size_t reliable_dropped_{0};
};
} // namespace
// Stores configuration and collaborators. No socket is opened and no thread is started until
// Start().
DirectWssServer::DirectWssServer(config::Config config,
std::shared_ptr<runtime::SessionManager> session_manager,
std::shared_ptr<common::Logger> logger)
: config_(std::move(config)),
session_manager_(std::move(session_manager)),
logger_(std::move(logger)),
// Reads the member config_, not the moved-from parameter. This relies on config_ being declared
// before token_validator_ in the class, because members initialize in declaration order
// (confirm in the header).
token_validator_(config_.data_plane_public_key_pem, config_.worker_id),
// NOTE(review): Asio's tlsv12_server method likely pins the protocol to TLS 1.2 only (no TLS 1.3).
// Confirm this is intended.
ssl_context_(asio::ssl::context::tlsv12_server) {}
// Shuts the listener down through Stop(), which is safe to repeat, so the accept thread is joined
// before any member is destroyed.
DirectWssServer::~DirectWssServer() {
  this->Stop();
}
void DirectWssServer::Start() {
if (!config_.data_plane_enabled) {
logger_->Info("direct data-plane WSS endpoint disabled");
return;
}
if (config_.data_plane_public_key_pem.empty() ||
config_.data_plane_tls_cert_file.empty() ||
config_.data_plane_tls_key_file.empty()) {
logger_->Warn("direct data-plane WSS endpoint not started because token public key or TLS certificate/key is missing");
return;
}
ssl_context_.set_options(asio::ssl::context::default_workarounds |
asio::ssl::context::no_sslv2 |
asio::ssl::context::no_sslv3 |
asio::ssl::context::single_dh_use);
ssl_context_.use_certificate_chain_file(config_.data_plane_tls_cert_file);
ssl_context_.use_private_key_file(config_.data_plane_tls_key_file, asio::ssl::context::pem);
thread_ = std::thread(&DirectWssServer::Run, this);
}
void DirectWssServer::Stop() {
stop_requested_.store(true);
if (acceptor_.has_value()) {
beast::error_code ignored;
acceptor_->close(ignored);
}
io_context_.stop();
if (thread_.joinable()) {
thread_.join();
}
}
void DirectWssServer::Run() {
try {
const auto address = asio::ip::make_address(config_.data_plane_listen_host);
const tcp::endpoint endpoint{address, static_cast<unsigned short>(config_.data_plane_listen_port)};
acceptor_.emplace(io_context_);
acceptor_->open(endpoint.protocol());
acceptor_->set_option(asio::socket_base::reuse_address(true));
acceptor_->bind(endpoint);
acceptor_->listen(asio::socket_base::max_listen_connections);
logger_->Info("direct data-plane WSS endpoint listening host=" + config_.data_plane_listen_host +
" port=" + std::to_string(config_.data_plane_listen_port) +
" path=/rap/v1/data-plane");
while (!stop_requested_.load()) {
beast::error_code ec;
tcp::socket socket{io_context_};
acceptor_->accept(socket, ec);
if (ec) {
if (!stop_requested_.load()) {
logger_->Warn("direct data-plane accept failed reason=" + ec.message());
}
continue;
}
std::thread(&DirectWssServer::HandleConnection, this, std::move(socket)).detach();
}
} catch (const std::exception& error) {
logger_->Error(std::string("direct data-plane WSS endpoint stopped reason=") + error.what());
}
}
// Serves one direct data-plane client end to end: TLS handshake, HTTP upgrade routing,
// token + jti validation, runtime binding, WebSocket accept, then the inbound read loop.
// Outbound traffic goes through a DirectWssEventSink registered with the bound runtime.
//
// This is the entry point of a detached per-connection thread started by Run(). An exception
// escaping a thread entry point calls std::terminate and would take down the whole worker
// process. Every failure is therefore contained here and logged as a warning.
void DirectWssServer::HandleConnection(tcp::socket socket) {
  try {
    beast::error_code ec;
    asio::ssl::stream<tcp::socket> tls_stream{std::move(socket), ssl_context_};
    tls_stream.handshake(asio::ssl::stream_base::server, ec);
    if (ec) {
      logger_->Warn("direct data-plane TLS handshake failed reason=" + ec.message());
      return;
    }
    beast::flat_buffer buffer;
    http::request<http::string_body> request;
    http::read(tls_stream, buffer, request, ec);
    if (ec) {
      logger_->Warn("direct data-plane HTTP upgrade read failed reason=" + ec.message());
      return;
    }
    const std::string target = std::string(request.target());
    if (PathOnly(target) != "/rap/v1/data-plane") {
      WriteReject(tls_stream, http::status::not_found, "unknown_path");
      return;
    }
    if (!websocket::is_upgrade(request)) {
      WriteReject(tls_stream, http::status::upgrade_required, "websocket_required");
      return;
    }
    const auto token = QueryParam(target, "data_plane_token");
    const bool binary_render_requested = QueryParam(target, "render_transport") == kBinaryRenderTransportV1;
    // Only the binary transport negotiates a color mode; the JSON transport always uses full color.
    const std::string requested_color_mode = binary_render_requested
                                                 ? NormalizeColorMode(QueryParam(target, "color_mode"))
                                                 : std::string(kColorModeFullColor);
    const auto validation = token_validator_.Validate(token);
    if (!validation.ok) {
      logger_->Warn("event=token_validation_failed reason=" + validation.reason);
      WriteReject(tls_stream, http::status::unauthorized, validation.reason);
      return;
    }
    logger_->Info("event=token_validation_success session=" + validation.claims.session_id +
                  " attachment=" + validation.claims.attachment_id +
                  " worker=" + validation.claims.worker_id +
                  " jti=" + validation.claims.jti);
    // Tokens are single-use: a jti that was already consumed is rejected as a replay.
    if (!ConsumeJti(validation.claims)) {
      logger_->Warn("event=jti_replay_rejected session=" + validation.claims.session_id +
                    " attachment=" + validation.claims.attachment_id +
                    " jti=" + validation.claims.jti);
      WriteReject(tls_stream, http::status::unauthorized, "jti_replay_rejected");
      return;
    }
    std::string bind_reason;
    auto runtime = session_manager_->BindDirectDataPlaneRuntime(validation.claims, bind_reason);
    if (!runtime) {
      logger_->Warn("event=data_plane_bind_failed session=" + validation.claims.session_id +
                    " attachment=" + validation.claims.attachment_id +
                    " reason=" + bind_reason);
      WriteReject(tls_stream, http::status::not_found, bind_reason);
      return;
    }
    websocket::stream<asio::ssl::stream<tcp::socket>> ws{std::move(tls_stream)};
    ws.accept(request, ec);
    if (ec) {
      logger_->Warn("direct data-plane WebSocket accept failed session=" + validation.claims.session_id +
                    " reason=" + ec.message());
      return;
    }
    logger_->Info("event=data_plane_bind_success session=" + validation.claims.session_id +
                  " attachment=" + validation.claims.attachment_id +
                  " channels=" + ChannelsToJsonArray(validation.claims.allowed_channels) +
                  " render_transport=" + (binary_render_requested ? std::string("binary_v1") : std::string("json_base64")) +
                  " requested_color_mode=" + requested_color_mode +
                  " applied_color_mode=" + requested_color_mode);
    ws.text(true);
    ws.write(asio::buffer(rdp_worker::common::SerializeJson(rdp_worker::common::JsonObject{
                 {"type", "data_plane.attached"},
                 {"session_id", validation.claims.session_id},
                 {"attachment_id", validation.claims.attachment_id},
                 {"worker_id", validation.claims.worker_id},
                 {"render_transport", binary_render_requested ? "binary_v1" : "json_base64"},
                 {"color_mode", requested_color_mode},
             })),
             ec);
    if (ec) {
      return;
    }
    // Shared with the sink, which, like `ws`, it holds by reference. The reader below locks it
    // around ws.read() so stream access does not overlap with the sink.
    std::mutex websocket_io_mutex;
    auto sink = std::make_shared<DirectWssEventSink<websocket::stream<asio::ssl::stream<tcp::socket>>>>(
        ws,
        websocket_io_mutex,
        logger_,
        validation.claims.session_id,
        validation.claims.attachment_id,
        binary_render_requested,
        requested_color_mode,
        [runtime](std::string reason) {
          runtime->RequestDirectFullFrameRepair(std::move(reason));
        });
    sink->Start();
    try {
      runtime->AddDirectEventSink(sink);
      while (!stop_requested_.load()) {
        // Poll first, so the blocking read below (which holds websocket_io_mutex) only starts
        // when bytes are pending.
        // NOTE(review): available() only counts bytes still in the kernel socket buffer. Data
        // already pulled into the TLS or WebSocket layers' internal buffers (e.g. a second frame
        // that arrived in the same TCP read) is likely not reported, so such a message may wait
        // until more bytes arrive. Confirm.
        beast::error_code available_ec;
        const auto available = beast::get_lowest_layer(ws).available(available_ec);
        if (available_ec) {
          logger_->Warn("direct data-plane socket availability check failed session=" + validation.claims.session_id +
                        " reason=" + available_ec.message());
          break;
        }
        if (available == 0) {
          std::this_thread::sleep_for(std::chrono::milliseconds(5));
          continue;
        }
        beast::flat_buffer message;
        {
          std::lock_guard<std::mutex> read_lock(websocket_io_mutex);
          ws.read(message, ec);
        }
        if (ec == websocket::error::closed) {
          break;
        }
        if (ec) {
          logger_->Warn("direct data-plane WebSocket read failed session=" + validation.claims.session_id +
                        " reason=" + ec.message());
          break;
        }
        const auto payload = beast::buffers_to_string(message.data());
        auto envelope = NormalizeClientEnvelope(payload);
        if (!envelope.has_value()) {
          logger_->Warn("direct data-plane ignored malformed client envelope session=" + validation.claims.session_id);
          continue;
        }
        const auto type = common::GetString(*envelope, "type").value_or("");
        if (type == "heartbeat") {
          continue;
        }
        if (!ChannelAllowed(validation.claims, type)) {
          logger_->Warn("direct data-plane envelope rejected by token channel scope session=" + validation.claims.session_id +
                        " type=" + type);
          if (auto blocked = BuildBlockedEnvelopeForDisallowedChannel(*envelope, validation.claims); blocked.has_value()) {
            if (!WriteTextEnvelope(ws,
                                   websocket_io_mutex,
                                   *blocked,
                                   logger_,
                                   validation.claims.session_id,
                                   "channel_scope_blocked")) {
              break;
            }
          }
          continue;
        }
        // A newer attachment has taken over the session: tell this client and drop its input.
        const auto snapshot = runtime->Snapshot();
        if (snapshot.attachment_id != validation.claims.attachment_id) {
          logger_->Warn("direct data-plane envelope rejected because attachment is no longer current session=" +
                        validation.claims.session_id +
                        " envelope_attachment=" + validation.claims.attachment_id +
                        " current_attachment=" + snapshot.attachment_id +
                        " type=" + type);
          const auto taken_over = BuildTakenOverEnvelopeForStaleAttachment(validation.claims, snapshot.attachment_id);
          if (!WriteTextEnvelope(ws,
                                 websocket_io_mutex,
                                 taken_over,
                                 logger_,
                                 validation.claims.session_id,
                                 "attachment_mismatch")) {
            break;
          }
          continue;
        }
        AddBindingClaimsToEnvelope(*envelope, validation.claims);
        const bool mouse_move = IsMouseMoveEnvelope(*envelope);
        if (!runtime->EnqueueDirectEnvelope(std::move(*envelope))) {
          logger_->Warn("direct data-plane inbound queue full session=" + validation.claims.session_id +
                        " type=" + type +
                        " mouse_move=" + (mouse_move ? std::string("true") : std::string("false")));
        }
      }
    } catch (...) {
      // The sink holds references to `ws` and `websocket_io_mutex`, both locals of this frame.
      // Stop it before unwinding destroys them, then let the outer handler log the failure.
      sink->Stop();
      throw;
    }
    sink->Stop();
    logger_->Info("direct data-plane WebSocket detached session=" + validation.claims.session_id +
                  " attachment=" + validation.claims.attachment_id);
  } catch (const std::exception& error) {
    logger_->Warn(std::string("direct data-plane connection handler failed reason=") + error.what());
  } catch (...) {
    logger_->Warn("direct data-plane connection handler failed reason=unknown_exception");
  }
}
// Records a token's jti as used. Returns true the first time a jti is seen, and false when it was
// already consumed (a replay). An id is remembered until shortly after its token expires; expired
// ids are purged on every call. Thread-safe via jti_mutex_.
bool DirectWssServer::ConsumeJti(const DataPlaneTokenClaims& claims) {
  // Upper bound on remembered ids, so memory stays bounded under a burst of attachments.
  constexpr std::size_t kMaxTrackedJtis = 4096;
  // Extra retention past the token's own expiry. This presumably covers any clock-skew leeway
  // token_validator_ applies when checking expiry; confirm.
  constexpr auto kJtiRetentionGrace = std::chrono::seconds(5);

  std::lock_guard<std::mutex> lock(jti_mutex_);
  const auto now = std::chrono::system_clock::now();
  for (auto iterator = used_jti_.begin(); iterator != used_jti_.end();) {
    if (iterator->second <= now) {
      iterator = used_jti_.erase(iterator);
    } else {
      ++iterator;
    }
  }
  if (used_jti_.find(claims.jti) != used_jti_.end()) {
    return false;
  }
  if (used_jti_.size() >= kMaxTrackedJtis) {
    // Evict the id whose retention ends soonest, since it has the shortest replay window left.
    // Erasing begin() instead would drop an arbitrary entry (hash map) or the lexicographically
    // smallest id (ordered map). That entry could be a long-lived token, making it replayable again.
    const auto soonest = std::min_element(
        used_jti_.begin(), used_jti_.end(),
        [](const auto& lhs, const auto& rhs) { return lhs.second < rhs.second; });
    used_jti_.erase(soonest);
  }
  const auto expires_at = std::chrono::system_clock::time_point(std::chrono::seconds(claims.expires_at_unix));
  used_jti_.emplace(claims.jti, expires_at + kJtiRetentionGrace);
  return true;
}
} // namespace rdp_worker::dataplane