1303 lines
56 KiB
C++
1303 lines
56 KiB
C++
#include "rdp_worker/dataplane/direct_wss_server.hpp"
|
|
|
|
#include <algorithm>
|
|
#include <array>
|
|
#include <chrono>
|
|
#include <condition_variable>
|
|
#include <cstdint>
|
|
#include <deque>
|
|
#include <functional>
|
|
#include <limits>
|
|
#include <sstream>
|
|
#include <stdexcept>
|
|
#include <string_view>
|
|
#include <thread>
|
|
#include <cstdlib>
|
|
#include <utility>
|
|
|
|
#include <boost/asio/buffer.hpp>
|
|
#include <boost/asio/ip/address.hpp>
|
|
#include <boost/beast/core.hpp>
|
|
#include <boost/beast/http.hpp>
|
|
#include <boost/beast/websocket.hpp>
|
|
#include <boost/beast/websocket/ssl.hpp>
|
|
|
|
#include "rdp_worker/common/json.hpp"
|
|
|
|
namespace rdp_worker::dataplane {
|
|
|
|
namespace asio = boost::asio;
|
|
namespace beast = boost::beast;
|
|
namespace http = beast::http;
|
|
namespace websocket = beast::websocket;
|
|
using tcp = asio::ip::tcp;
|
|
|
|
namespace {
|
|
|
|
constexpr std::size_t kMaxReliableOutboundEvents = 256;
|
|
constexpr std::size_t kMaxOrderedRenderFrames = 256;
|
|
constexpr auto kMinBinaryRenderFrameInterval = std::chrono::milliseconds(100);
|
|
constexpr auto kMinBinaryRegionRenderFrameInterval = std::chrono::milliseconds(33);
|
|
constexpr std::string_view kBinaryRenderTransportV1 = "binary_v1";
|
|
constexpr std::string_view kColorModeFullColor = "full_color";
|
|
constexpr std::string_view kColorModeGrayscale = "grayscale";
|
|
constexpr std::string_view kRenderFrameFullMessageType = "render.frame.full";
|
|
constexpr std::string_view kRenderFrameRegionMessageType = "render.frame.region";
|
|
|
|
struct BinaryRenderFrameMessage {
|
|
std::string message;
|
|
std::string message_type;
|
|
std::string color_mode;
|
|
std::string update_kind;
|
|
bool grayscale_conversion_applied{false};
|
|
bool dirty_region_applied{false};
|
|
std::size_t raw_frame_bytes_before{0};
|
|
std::size_t raw_frame_bytes_after{0};
|
|
std::size_t binary_direct_bytes{0};
|
|
std::size_t full_frame_bytes{0};
|
|
std::size_t region_bytes{0};
|
|
double region_savings_percent{0.0};
|
|
std::int64_t conversion_time_ms{0};
|
|
std::int64_t diff_time_ms{0};
|
|
std::string render_update_reason;
|
|
std::string fallback_to_full_frame_reason;
|
|
};
|
|
|
|
std::int64_t UnixMillisecondsNow() {
|
|
return std::chrono::duration_cast<std::chrono::milliseconds>(
|
|
std::chrono::system_clock::now().time_since_epoch())
|
|
.count();
|
|
}
|
|
|
|
void AppendUint16Le(std::string& output, std::uint16_t value) {
|
|
output.push_back(static_cast<char>(value & 0xFF));
|
|
output.push_back(static_cast<char>((value >> 8) & 0xFF));
|
|
}
|
|
|
|
void AppendUint32Le(std::string& output, std::uint32_t value) {
|
|
output.push_back(static_cast<char>(value & 0xFF));
|
|
output.push_back(static_cast<char>((value >> 8) & 0xFF));
|
|
output.push_back(static_cast<char>((value >> 16) & 0xFF));
|
|
output.push_back(static_cast<char>((value >> 24) & 0xFF));
|
|
}
|
|
|
|
std::string Base64Encode(const std::uint8_t* data, std::size_t size) {
|
|
static constexpr std::array<char, 64> alphabet{
|
|
'A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P',
|
|
'Q','R','S','T','U','V','W','X','Y','Z','a','b','c','d','e','f',
|
|
'g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v',
|
|
'w','x','y','z','0','1','2','3','4','5','6','7','8','9','+','/'
|
|
};
|
|
std::string output;
|
|
output.reserve(((size + 2) / 3) * 4);
|
|
for (std::size_t index = 0; index < size; ) {
|
|
const std::size_t remaining = size - index;
|
|
const std::uint32_t octet_a = data[index++];
|
|
const std::uint32_t octet_b = remaining > 1 ? data[index++] : 0;
|
|
const std::uint32_t octet_c = remaining > 2 ? data[index++] : 0;
|
|
const std::uint32_t triple = (octet_a << 16) | (octet_b << 8) | octet_c;
|
|
|
|
output.push_back(alphabet[(triple >> 18) & 0x3F]);
|
|
output.push_back(alphabet[(triple >> 12) & 0x3F]);
|
|
output.push_back(remaining > 1 ? alphabet[(triple >> 6) & 0x3F] : '=');
|
|
output.push_back(remaining > 2 ? alphabet[triple & 0x3F] : '=');
|
|
}
|
|
return output;
|
|
}
|
|
|
|
std::optional<std::string> DecodeBase64ToBytes(const std::string& input) {
|
|
static constexpr unsigned char kInvalid = 255;
|
|
static unsigned char table[256];
|
|
static bool initialized = false;
|
|
if (!initialized) {
|
|
std::fill(std::begin(table), std::end(table), kInvalid);
|
|
const std::string alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
|
for (std::size_t i = 0; i < alphabet.size(); ++i) {
|
|
table[static_cast<unsigned char>(alphabet[i])] = static_cast<unsigned char>(i);
|
|
}
|
|
initialized = true;
|
|
}
|
|
|
|
std::string output;
|
|
output.reserve((input.size() * 3) / 4);
|
|
int value = 0;
|
|
int bits = -8;
|
|
for (unsigned char ch : input) {
|
|
if (ch == '=') {
|
|
break;
|
|
}
|
|
if (table[ch] == kInvalid) {
|
|
return std::nullopt;
|
|
}
|
|
value = (value << 6) + table[ch];
|
|
bits += 6;
|
|
if (bits >= 0) {
|
|
output.push_back(static_cast<char>((value >> bits) & 0xFF));
|
|
bits -= 8;
|
|
}
|
|
}
|
|
return output;
|
|
}
|
|
|
|
std::string UrlDecode(std::string_view value) {
|
|
std::string output;
|
|
output.reserve(value.size());
|
|
for (std::size_t i = 0; i < value.size(); ++i) {
|
|
if (value[i] == '%' && i + 2 < value.size()) {
|
|
const std::string hex{value[i + 1], value[i + 2]};
|
|
char* end = nullptr;
|
|
const long decoded = std::strtol(hex.c_str(), &end, 16);
|
|
if (end != nullptr && *end == '\0') {
|
|
output.push_back(static_cast<char>(decoded));
|
|
i += 2;
|
|
continue;
|
|
}
|
|
}
|
|
output.push_back(value[i] == '+' ? ' ' : value[i]);
|
|
}
|
|
return output;
|
|
}
|
|
|
|
std::string QueryParam(std::string_view target, std::string_view name) {
|
|
const auto question = target.find('?');
|
|
if (question == std::string_view::npos) {
|
|
return {};
|
|
}
|
|
std::string_view query = target.substr(question + 1);
|
|
while (!query.empty()) {
|
|
const auto amp = query.find('&');
|
|
const auto pair = query.substr(0, amp);
|
|
const auto equals = pair.find('=');
|
|
if (equals != std::string_view::npos && pair.substr(0, equals) == name) {
|
|
return UrlDecode(pair.substr(equals + 1));
|
|
}
|
|
if (amp == std::string_view::npos) {
|
|
break;
|
|
}
|
|
query = query.substr(amp + 1);
|
|
}
|
|
return {};
|
|
}
|
|
|
|
std::string PathOnly(std::string_view target) {
|
|
const auto question = target.find('?');
|
|
return std::string(target.substr(0, question == std::string_view::npos ? target.size() : question));
|
|
}
|
|
|
|
std::string NormalizeColorMode(const std::string& requested) {
|
|
return requested == kColorModeGrayscale ? std::string(kColorModeGrayscale) : std::string(kColorModeFullColor);
|
|
}
|
|
|
|
void ApplyGrayscaleBgra(std::string& frame_bytes) {
|
|
for (std::size_t index = 0; index + 3 < frame_bytes.size(); index += 4) {
|
|
const auto blue = static_cast<unsigned char>(frame_bytes[index]);
|
|
const auto green = static_cast<unsigned char>(frame_bytes[index + 1]);
|
|
const auto red = static_cast<unsigned char>(frame_bytes[index + 2]);
|
|
const auto gray = static_cast<char>(((static_cast<unsigned int>(red) * 77U) +
|
|
(static_cast<unsigned int>(green) * 150U) +
|
|
(static_cast<unsigned int>(blue) * 29U)) >> 8U);
|
|
frame_bytes[index] = gray;
|
|
frame_bytes[index + 1] = gray;
|
|
frame_bytes[index + 2] = gray;
|
|
}
|
|
}
|
|
|
|
int JsonIntOr(const common::JsonObject& payload, const std::string& key, int fallback) {
|
|
return static_cast<int>(common::GetNumber(payload, key).value_or(fallback));
|
|
}
|
|
|
|
bool TryCropBgraRegion(const std::string& source,
|
|
int full_width,
|
|
int full_height,
|
|
int full_stride,
|
|
int region_x,
|
|
int region_y,
|
|
int region_width,
|
|
int region_height,
|
|
std::string& output) {
|
|
if (full_width <= 0 || full_height <= 0 || full_stride <= 0 ||
|
|
region_x < 0 || region_y < 0 || region_width <= 0 || region_height <= 0 ||
|
|
region_x + region_width > full_width || region_y + region_height > full_height) {
|
|
return false;
|
|
}
|
|
const int region_stride = region_width * 4;
|
|
const auto required_source_bytes = static_cast<std::size_t>(full_stride) * static_cast<std::size_t>(full_height);
|
|
if (source.size() < required_source_bytes || full_stride < region_x * 4 + region_stride) {
|
|
return false;
|
|
}
|
|
output.assign(static_cast<std::size_t>(region_stride) * static_cast<std::size_t>(region_height), '\0');
|
|
for (int row = 0; row < region_height; ++row) {
|
|
const auto source_offset = static_cast<std::size_t>(region_y + row) * static_cast<std::size_t>(full_stride) +
|
|
static_cast<std::size_t>(region_x * 4);
|
|
const auto target_offset = static_cast<std::size_t>(row) * static_cast<std::size_t>(region_stride);
|
|
std::copy_n(source.data() + source_offset, region_stride, output.data() + target_offset);
|
|
}
|
|
return true;
|
|
}
|
|
|
|
template <typename Stream>
|
|
void WriteReject(Stream& stream, http::status status, const std::string& reason) {
|
|
http::response<http::string_body> response{status, 11};
|
|
response.set(http::field::server, "rap-rdp-worker");
|
|
response.set(http::field::content_type, "application/json");
|
|
response.keep_alive(false);
|
|
response.body() = rdp_worker::common::SerializeJson(rdp_worker::common::JsonObject{
|
|
{"error", rdp_worker::common::JsonObject{
|
|
{"code", reason},
|
|
{"message_key", "errors.data_plane." + reason},
|
|
{"fallback_message", "Data plane connection rejected."},
|
|
}},
|
|
});
|
|
response.prepare_payload();
|
|
beast::error_code ignored;
|
|
http::write(stream, response, ignored);
|
|
}
|
|
|
|
std::string ChannelsToJsonArray(const std::vector<std::string>& channels) {
|
|
rdp_worker::common::JsonArray array;
|
|
for (const auto& channel : channels) {
|
|
array.emplace_back(channel);
|
|
}
|
|
return rdp_worker::common::SerializeJson(array);
|
|
}
|
|
|
|
std::string WorkerEventTypeToEnvelopeType(const runtime::WorkerEvent& event) {
|
|
if (event.type == "session_frame") {
|
|
return "session.frame";
|
|
}
|
|
if (event.type == "session_cursor_updated") {
|
|
return "cursor.update";
|
|
}
|
|
if (event.type == "session_clipboard_text") {
|
|
return "clipboard.text";
|
|
}
|
|
if (event.type == "session_taken_over") {
|
|
return "session.taken_over";
|
|
}
|
|
if (event.type == "session_file_upload_completed") {
|
|
return "file_upload.progress";
|
|
}
|
|
if (event.type == "session_file_download_available") {
|
|
return "file_download.available";
|
|
}
|
|
if (event.type == "session_file_download_chunk") {
|
|
return "file_download.chunk";
|
|
}
|
|
if (event.type == "session_file_download_completed") {
|
|
return "file_download.completed";
|
|
}
|
|
if (event.type == "session_file_download_failed") {
|
|
return "file_download.failed";
|
|
}
|
|
if (event.type == "session_file_download_blocked") {
|
|
return "file_download.blocked";
|
|
}
|
|
if (event.type == "session_file_download_progress") {
|
|
return "file_download.progress";
|
|
}
|
|
if (event.type == "session_failed" || event.type == "session_terminated" ||
|
|
event.type == "session_connected" || event.type == "session_display_ready" ||
|
|
event.type == "session_heartbeat") {
|
|
return "session.state";
|
|
}
|
|
return event.type;
|
|
}
|
|
|
|
common::JsonObject RenderPayloadFromWorkerEvent(const runtime::WorkerEvent& event) {
|
|
common::JsonObject render{
|
|
{"width", event.width},
|
|
{"height", event.height},
|
|
};
|
|
if (auto value = common::GetString(event.payload, "render_quality_profile"); value.has_value()) {
|
|
render["quality_profile"] = *value;
|
|
render["render_quality_profile"] = *value;
|
|
}
|
|
if (auto value = common::GetString(event.payload, "render_state"); value.has_value()) {
|
|
render["state"] = *value;
|
|
render["render_state"] = *value;
|
|
}
|
|
if (auto value = common::GetNumber(event.payload, "frame_sequence"); value.has_value()) {
|
|
render["frame_sequence"] = *value;
|
|
}
|
|
if (auto value = common::GetString(event.payload, "frame_format"); value.has_value()) {
|
|
render["frame_format"] = *value;
|
|
}
|
|
if (auto value = common::GetNumber(event.payload, "cursor_x"); value.has_value()) {
|
|
render["cursor_x"] = *value;
|
|
}
|
|
if (auto value = common::GetNumber(event.payload, "cursor_y"); value.has_value()) {
|
|
render["cursor_y"] = *value;
|
|
}
|
|
if (auto value = common::GetBool(event.payload, "cursor_visible"); value.has_value()) {
|
|
render["cursor_visible"] = *value;
|
|
}
|
|
return render;
|
|
}
|
|
|
|
std::optional<BinaryRenderFrameMessage> WorkerEventToBinaryRenderFrame(const runtime::WorkerEvent& event,
|
|
const std::string& requested_color_mode,
|
|
bool require_full_frame) {
|
|
if (event.type != "session_frame") {
|
|
return std::nullopt;
|
|
}
|
|
std::string frame_bytes;
|
|
if (!event.raw_frame_bytes.empty()) {
|
|
frame_bytes.assign(reinterpret_cast<const char*>(event.raw_frame_bytes.data()), event.raw_frame_bytes.size());
|
|
} else {
|
|
const auto encoded_frame = common::GetString(event.payload, "frame_data").value_or("");
|
|
if (encoded_frame.empty()) {
|
|
return std::nullopt;
|
|
}
|
|
auto decoded = DecodeBase64ToBytes(encoded_frame);
|
|
if (!decoded.has_value()) {
|
|
return std::nullopt;
|
|
}
|
|
frame_bytes = std::move(*decoded);
|
|
}
|
|
|
|
const std::size_t raw_frame_bytes_before = frame_bytes.size();
|
|
const auto frame_sequence = static_cast<std::int64_t>(common::GetNumber(event.payload, "frame_sequence").value_or(0));
|
|
const auto source_frame_width = JsonIntOr(event.payload, "frame_width", event.width);
|
|
const auto source_frame_height = JsonIntOr(event.payload, "frame_height", event.height);
|
|
const auto source_frame_stride = JsonIntOr(event.payload, "frame_stride", source_frame_width * 4);
|
|
const auto desktop_width = JsonIntOr(event.payload, "desktop_width", source_frame_width);
|
|
const auto desktop_height = JsonIntOr(event.payload, "desktop_height", source_frame_height);
|
|
const auto frame_format = common::GetString(event.payload, "frame_format").value_or("bgra32");
|
|
const auto quality_profile = common::GetString(event.payload, "render_quality_profile").value_or("balanced");
|
|
const auto input_correlation_id = common::GetString(event.payload, "input_correlation_id").value_or("");
|
|
const auto worker_frame_captured_at = common::GetString(event.payload, "worker_frame_captured_at").value_or("");
|
|
const auto requested_update_kind = common::GetString(event.payload, "frame_update_kind").value_or("full");
|
|
const auto render_update_reason = common::GetString(event.payload, "capture_source").value_or("");
|
|
const int region_x = JsonIntOr(event.payload, "region_x", 0);
|
|
const int region_y = JsonIntOr(event.payload, "region_y", 0);
|
|
const int region_width = JsonIntOr(event.payload, "region_width", 0);
|
|
const int region_height = JsonIntOr(event.payload, "region_height", 0);
|
|
const bool frame_data_is_region = common::GetBool(event.payload, "frame_data_is_region").value_or(false);
|
|
const auto full_frame_bytes = static_cast<std::size_t>(std::max(0, desktop_width)) *
|
|
static_cast<std::size_t>(std::max(0, desktop_height)) * 4U;
|
|
const auto diff_time_ms = static_cast<std::int64_t>(
|
|
common::GetNumber(event.payload, "periodic_change_poll_ms")
|
|
.value_or(common::GetNumber(event.payload, "capture_total_ms").value_or(0)));
|
|
|
|
std::string update_kind = "full";
|
|
std::string fallback_to_full_frame_reason;
|
|
int output_frame_width = source_frame_width;
|
|
int output_frame_height = source_frame_height;
|
|
int output_frame_stride = source_frame_stride;
|
|
bool dirty_region_applied = false;
|
|
if (require_full_frame && requested_update_kind == "region") {
|
|
fallback_to_full_frame_reason = "baseline_required";
|
|
} else if (!require_full_frame && requested_update_kind == "region") {
|
|
if (frame_data_is_region &&
|
|
region_width > 0 && region_height > 0 &&
|
|
frame_bytes.size() == static_cast<std::size_t>(region_width) * static_cast<std::size_t>(region_height) * 4U) {
|
|
output_frame_width = region_width;
|
|
output_frame_height = region_height;
|
|
output_frame_stride = region_width * 4;
|
|
update_kind = "region";
|
|
dirty_region_applied = true;
|
|
} else {
|
|
std::string region_bytes;
|
|
if (TryCropBgraRegion(frame_bytes,
|
|
source_frame_width,
|
|
source_frame_height,
|
|
source_frame_stride,
|
|
region_x,
|
|
region_y,
|
|
region_width,
|
|
region_height,
|
|
region_bytes)) {
|
|
frame_bytes = std::move(region_bytes);
|
|
output_frame_width = region_width;
|
|
output_frame_height = region_height;
|
|
output_frame_stride = region_width * 4;
|
|
update_kind = "region";
|
|
dirty_region_applied = true;
|
|
}
|
|
}
|
|
if (!dirty_region_applied) {
|
|
fallback_to_full_frame_reason = "invalid_region_payload";
|
|
}
|
|
}
|
|
|
|
const auto conversion_started = std::chrono::steady_clock::now();
|
|
const std::string applied_color_mode = NormalizeColorMode(requested_color_mode);
|
|
const bool grayscale = applied_color_mode == kColorModeGrayscale;
|
|
if (grayscale) {
|
|
ApplyGrayscaleBgra(frame_bytes);
|
|
}
|
|
const auto conversion_finished = std::chrono::steady_clock::now();
|
|
const auto conversion_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
|
|
conversion_finished - conversion_started)
|
|
.count();
|
|
const auto message_type = update_kind == "region"
|
|
? std::string(kRenderFrameRegionMessageType)
|
|
: std::string(kRenderFrameFullMessageType);
|
|
const std::size_t region_bytes = update_kind == "region" ? frame_bytes.size() : 0;
|
|
const double region_savings_percent = update_kind == "region" && full_frame_bytes > 0
|
|
? 100.0 - ((static_cast<double>(region_bytes) * 100.0) / static_cast<double>(full_frame_bytes))
|
|
: 0.0;
|
|
|
|
common::JsonObject header{
|
|
{"protocol_version", 1},
|
|
{"session_id", event.session_id},
|
|
{"channel", "render"},
|
|
{"message_type", message_type},
|
|
{"legacy_message_type", "session.frame"},
|
|
{"sequence", static_cast<double>(frame_sequence)},
|
|
{"timestamp", static_cast<double>(UnixMillisecondsNow())},
|
|
{"flags", 0},
|
|
{"payload_length", static_cast<double>(frame_bytes.size())},
|
|
{"frame_width", output_frame_width},
|
|
{"frame_height", output_frame_height},
|
|
{"frame_stride", output_frame_stride},
|
|
{"frame_format", frame_format},
|
|
{"frame_update_kind", update_kind},
|
|
{"desktop_width", desktop_width},
|
|
{"desktop_height", desktop_height},
|
|
{"region_x", dirty_region_applied ? region_x : 0},
|
|
{"region_y", dirty_region_applied ? region_y : 0},
|
|
{"region_width", dirty_region_applied ? region_width : output_frame_width},
|
|
{"region_height", dirty_region_applied ? region_height : output_frame_height},
|
|
{"region_stride", dirty_region_applied ? output_frame_stride : output_frame_stride},
|
|
{"region_format", "BGRA32"},
|
|
{"color_mode", applied_color_mode},
|
|
{"quality_profile", quality_profile},
|
|
{"original_frame_format", frame_format},
|
|
{"output_frame_format", frame_format},
|
|
{"raw_frame_bytes", static_cast<double>(raw_frame_bytes_before)},
|
|
{"binary_direct_bytes", static_cast<double>(frame_bytes.size())},
|
|
{"full_frame_bytes", static_cast<double>(full_frame_bytes)},
|
|
{"region_bytes", static_cast<double>(region_bytes)},
|
|
{"region_savings_percent", region_savings_percent},
|
|
{"diff_time_ms", static_cast<double>(diff_time_ms)},
|
|
{"render_update_reason", render_update_reason},
|
|
{"full_frame_sent", update_kind == "full"},
|
|
{"region_frame_sent", update_kind == "region"},
|
|
};
|
|
if (!fallback_to_full_frame_reason.empty()) {
|
|
header["fallback_to_full_frame_reason"] = fallback_to_full_frame_reason;
|
|
}
|
|
if (!input_correlation_id.empty()) {
|
|
header["input_correlation_id"] = input_correlation_id;
|
|
}
|
|
if (!worker_frame_captured_at.empty()) {
|
|
header["worker_frame_captured_at"] = worker_frame_captured_at;
|
|
}
|
|
|
|
const std::string header_json = common::SerializeJson(header);
|
|
if (header_json.size() > std::numeric_limits<std::uint32_t>::max() ||
|
|
frame_bytes.size() > std::numeric_limits<std::uint32_t>::max()) {
|
|
return std::nullopt;
|
|
}
|
|
|
|
std::string message;
|
|
message.reserve(16 + header_json.size() + frame_bytes.size());
|
|
message.append("RAP2", 4);
|
|
AppendUint16Le(message, 1);
|
|
AppendUint16Le(message, 0);
|
|
AppendUint32Le(message, static_cast<std::uint32_t>(header_json.size()));
|
|
AppendUint32Le(message, static_cast<std::uint32_t>(frame_bytes.size()));
|
|
message.append(header_json);
|
|
message.append(frame_bytes);
|
|
return BinaryRenderFrameMessage{
|
|
std::move(message),
|
|
message_type,
|
|
applied_color_mode,
|
|
update_kind,
|
|
grayscale,
|
|
dirty_region_applied,
|
|
raw_frame_bytes_before,
|
|
frame_bytes.size(),
|
|
frame_bytes.size(),
|
|
full_frame_bytes,
|
|
region_bytes,
|
|
region_savings_percent,
|
|
conversion_time_ms,
|
|
diff_time_ms,
|
|
render_update_reason,
|
|
fallback_to_full_frame_reason,
|
|
};
|
|
}
|
|
|
|
common::JsonObject WorkerEventToEnvelope(const runtime::WorkerEvent& event) {
|
|
common::JsonObject payload = event.payload;
|
|
if (!event.reason.empty()) {
|
|
payload["reason"] = event.reason;
|
|
}
|
|
|
|
const std::string envelope_type = WorkerEventTypeToEnvelopeType(event);
|
|
if (envelope_type == "session.taken_over") {
|
|
payload["state"] = "taken_over";
|
|
} else if (envelope_type == "session.state") {
|
|
std::string state = "active";
|
|
if (event.type == "session_failed") {
|
|
state = "failed";
|
|
} else if (event.type == "session_terminated") {
|
|
state = "terminated";
|
|
} else if (auto value = common::GetString(payload, "state"); value.has_value() && !value->empty()) {
|
|
state = *value;
|
|
}
|
|
payload["state"] = state;
|
|
payload["render"] = RenderPayloadFromWorkerEvent(event);
|
|
} else if (envelope_type == "session.frame") {
|
|
payload["render"] = RenderPayloadFromWorkerEvent(event);
|
|
if (!event.raw_frame_bytes.empty() && common::GetString(payload, "frame_data").value_or("").empty()) {
|
|
payload["frame_data"] = Base64Encode(event.raw_frame_bytes.data(), event.raw_frame_bytes.size());
|
|
}
|
|
} else if (envelope_type == "file_upload.progress") {
|
|
const double file_size = common::GetNumber(payload, "file_size").value_or(0);
|
|
payload["received"] = file_size;
|
|
payload["total"] = file_size;
|
|
payload["status"] = "completed";
|
|
} else if (envelope_type.rfind("file_download.", 0) == 0) {
|
|
payload["direction"] = "server_to_client";
|
|
}
|
|
|
|
return common::JsonObject{
|
|
{"type", envelope_type},
|
|
{"session_id", event.session_id},
|
|
{"payload", payload},
|
|
};
|
|
}
|
|
|
|
bool IsFrameEvent(const runtime::WorkerEvent& event) {
|
|
return event.type == "session_frame";
|
|
}
|
|
|
|
bool IsDirectAttachBaselineFrameEvent(const runtime::WorkerEvent& event) {
|
|
return IsFrameEvent(event) &&
|
|
common::GetBool(event.payload, "direct_attach_baseline").value_or(false);
|
|
}
|
|
|
|
bool IsRegionFrameEvent(const runtime::WorkerEvent& event) {
|
|
return IsFrameEvent(event) &&
|
|
common::GetString(event.payload, "frame_update_kind").value_or("full") == "region";
|
|
}
|
|
|
|
bool IsFullFrameEvent(const runtime::WorkerEvent& event) {
|
|
return IsFrameEvent(event) &&
|
|
common::GetString(event.payload, "frame_update_kind").value_or("full") == "full";
|
|
}
|
|
|
|
bool IsCursorEvent(const runtime::WorkerEvent& event) {
|
|
return event.type == "session_cursor_updated";
|
|
}
|
|
|
|
std::chrono::milliseconds BinaryRenderFrameInterval(const runtime::WorkerEvent& event) {
|
|
if (common::GetBool(event.payload, "interactive_frame").value_or(false)) {
|
|
return kMinBinaryRegionRenderFrameInterval;
|
|
}
|
|
const auto update_kind = common::GetString(event.payload, "frame_update_kind").value_or("full");
|
|
if (update_kind == "region") {
|
|
return kMinBinaryRegionRenderFrameInterval;
|
|
}
|
|
return kMinBinaryRenderFrameInterval;
|
|
}
|
|
|
|
bool IsMouseMoveEnvelope(const common::JsonObject& envelope) {
|
|
if (common::GetString(envelope, "type").value_or("") != "input") {
|
|
return false;
|
|
}
|
|
const auto* payload = common::GetObject(envelope, "payload");
|
|
return payload != nullptr &&
|
|
common::GetString(*payload, "kind").value_or("") == "mouse" &&
|
|
common::GetString(*payload, "action").value_or("") == "move";
|
|
}
|
|
|
|
bool ChannelAllowed(const DataPlaneTokenClaims& claims, const std::string& envelope_type) {
|
|
std::string channel = envelope_type;
|
|
if (envelope_type == "file_upload") {
|
|
channel = "file_upload";
|
|
} else if (envelope_type == "file_download") {
|
|
channel = "file_download";
|
|
} else if (envelope_type == "heartbeat") {
|
|
channel = "control";
|
|
}
|
|
return std::find(claims.allowed_channels.begin(), claims.allowed_channels.end(), channel) != claims.allowed_channels.end();
|
|
}
|
|
|
|
std::optional<std::string> BuildBlockedEnvelopeForDisallowedChannel(const common::JsonObject& envelope,
|
|
const DataPlaneTokenClaims& claims) {
|
|
const auto type = common::GetString(envelope, "type").value_or("");
|
|
if (type != "file_download") {
|
|
return std::nullopt;
|
|
}
|
|
|
|
const auto* payload = common::GetObject(envelope, "payload");
|
|
const std::string transfer_id = payload != nullptr ? common::GetString(*payload, "transfer_id").value_or("") : "";
|
|
const std::string file_id = payload != nullptr ? common::GetString(*payload, "file_id").value_or("") : "";
|
|
return common::SerializeJson(common::JsonObject{
|
|
{"type", "file_download.blocked"},
|
|
{"session_id", claims.session_id},
|
|
{"payload", common::JsonObject{
|
|
{"direction", "server_to_client"},
|
|
{"transfer_id", transfer_id},
|
|
{"file_id", file_id},
|
|
{"reason", "access denied"},
|
|
}},
|
|
});
|
|
}
|
|
|
|
std::string BuildTakenOverEnvelopeForStaleAttachment(const DataPlaneTokenClaims& claims,
|
|
const std::string& current_attachment_id) {
|
|
return common::SerializeJson(common::JsonObject{
|
|
{"type", "session.taken_over"},
|
|
{"session_id", claims.session_id},
|
|
{"payload", common::JsonObject{
|
|
{"message", "controller binding changed"},
|
|
{"reason", "attachment_mismatch"},
|
|
{"attachment_id", claims.attachment_id},
|
|
{"current_attachment_id", current_attachment_id},
|
|
}},
|
|
{"event", common::JsonObject{
|
|
{"code", "session.taken_over"},
|
|
{"message_key", "events.session.taken_over"},
|
|
{"fallback_message", "This session was taken over from another device."},
|
|
{"details", common::JsonObject{
|
|
{"reason", "attachment_mismatch"},
|
|
}},
|
|
}},
|
|
});
|
|
}
|
|
|
|
template <typename WebSocket>
|
|
bool WriteTextEnvelope(WebSocket& ws,
|
|
std::mutex& websocket_io_mutex,
|
|
const std::string& envelope,
|
|
std::shared_ptr<common::Logger> logger,
|
|
const std::string& session_id,
|
|
const std::string& reason) {
|
|
beast::error_code write_ec;
|
|
std::lock_guard<std::mutex> write_lock(websocket_io_mutex);
|
|
ws.text(true);
|
|
ws.write(asio::buffer(envelope), write_ec);
|
|
if (write_ec) {
|
|
logger->Warn("direct data-plane envelope write failed session=" + session_id +
|
|
" reason=" + reason +
|
|
" error=" + write_ec.message());
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
std::optional<common::JsonObject> NormalizeClientEnvelope(const std::string& message) {
|
|
auto parsed = common::ParseJson(message);
|
|
if (!parsed.IsObject()) {
|
|
return std::nullopt;
|
|
}
|
|
common::JsonObject envelope = parsed.AsObject();
|
|
const auto type = common::GetString(envelope, "type").value_or("");
|
|
const auto* payload = common::GetObject(envelope, "payload");
|
|
if (type.empty()) {
|
|
return std::nullopt;
|
|
}
|
|
if (type == "heartbeat" || type == "control.ping") {
|
|
return common::JsonObject{{"type", "heartbeat"}};
|
|
}
|
|
if (payload == nullptr) {
|
|
return std::nullopt;
|
|
}
|
|
if (type == "input" || type == "clipboard" || type == "control") {
|
|
return envelope;
|
|
}
|
|
if (type == "file_upload.start" || type == "file_upload.chunk" || type == "file_upload.cancel") {
|
|
common::JsonObject normalized_payload = *payload;
|
|
if (type == "file_upload.start") {
|
|
normalized_payload["action"] = "start";
|
|
} else if (type == "file_upload.chunk") {
|
|
normalized_payload["action"] = "chunk";
|
|
} else {
|
|
normalized_payload["action"] = "cancel";
|
|
}
|
|
return common::JsonObject{{"type", "file_upload"}, {"payload", normalized_payload}};
|
|
}
|
|
if (type == "file_download.start" || type == "file_download.ack" || type == "file_download.cancel") {
|
|
common::JsonObject normalized_payload = *payload;
|
|
if (type == "file_download.start") {
|
|
normalized_payload["action"] = "start";
|
|
} else if (type == "file_download.ack") {
|
|
normalized_payload["action"] = "ack";
|
|
} else {
|
|
normalized_payload["action"] = "cancel";
|
|
}
|
|
return common::JsonObject{{"type", "file_download"}, {"payload", normalized_payload}};
|
|
}
|
|
return envelope;
|
|
}
|
|
|
|
void AddBindingClaimsToEnvelope(common::JsonObject& envelope, const DataPlaneTokenClaims& claims) {
|
|
auto* payload = common::GetObject(envelope, "payload");
|
|
if (payload == nullptr) {
|
|
return;
|
|
}
|
|
common::JsonObject enriched = *payload;
|
|
enriched["session_id"] = claims.session_id;
|
|
enriched["attachment_id"] = claims.attachment_id;
|
|
enriched["user_id"] = claims.user_id;
|
|
enriched["organization_id"] = claims.organization_id;
|
|
enriched["worker_id"] = claims.worker_id;
|
|
enriched["resource_id"] = claims.resource_id;
|
|
envelope["payload"] = std::move(enriched);
|
|
}
|
|
|
|
template <typename WebSocket>
|
|
class DirectWssEventSink final : public runtime::DirectEventSink {
|
|
public:
|
|
DirectWssEventSink(WebSocket& ws,
|
|
std::mutex& websocket_io_mutex,
|
|
std::shared_ptr<common::Logger> logger,
|
|
std::string session_id,
|
|
std::string attachment_id,
|
|
bool binary_render_enabled,
|
|
std::string requested_color_mode,
|
|
std::function<void(std::string)> repair_request)
|
|
: ws_(ws),
|
|
websocket_io_mutex_(websocket_io_mutex),
|
|
logger_(std::move(logger)),
|
|
session_id_(std::move(session_id)),
|
|
attachment_id_(std::move(attachment_id)),
|
|
binary_render_enabled_(binary_render_enabled),
|
|
requested_color_mode_(NormalizeColorMode(requested_color_mode)),
|
|
repair_request_(std::move(repair_request)) {}
|
|
|
|
~DirectWssEventSink() override {
|
|
Stop();
|
|
}
|
|
|
|
void Start() {
|
|
writer_thread_ = std::thread(&DirectWssEventSink::WriterLoop, this);
|
|
}
|
|
|
|
std::string AttachmentId() const override {
|
|
return attachment_id_;
|
|
}
|
|
|
|
void Stop() {
|
|
{
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
stopped_ = true;
|
|
reliable_events_.clear();
|
|
latest_cursor_.reset();
|
|
ordered_frame_events_.clear();
|
|
}
|
|
condition_.notify_all();
|
|
if (writer_thread_.joinable()) {
|
|
writer_thread_.join();
|
|
}
|
|
}
|
|
|
|
void EnqueueEvent(const runtime::WorkerEvent& event) override {
|
|
{
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
if (stopped_) {
|
|
return;
|
|
}
|
|
if (IsFrameEvent(event)) {
|
|
if (IsDirectAttachBaselineFrameEvent(event)) {
|
|
ordered_frame_events_.clear();
|
|
if (reliable_events_.size() >= kMaxReliableOutboundEvents) {
|
|
reliable_events_.pop_front();
|
|
++reliable_dropped_;
|
|
}
|
|
reliable_events_.push_back(event);
|
|
++frames_queued_;
|
|
logger_->Info("direct data-plane baseline frame queued as non-droppable session=" + session_id_);
|
|
condition_.notify_one();
|
|
return;
|
|
}
|
|
if (IsFullFrameEvent(event)) {
|
|
if (!ordered_frame_events_.empty()) {
|
|
frames_dropped_due_to_backpressure_ += ordered_frame_events_.size();
|
|
logger_->Info("render.region_queue direct writer cleared pending frames for full frame session=" +
|
|
session_id_ +
|
|
" cleared=" + std::to_string(ordered_frame_events_.size()));
|
|
}
|
|
ordered_frame_events_.clear();
|
|
ordered_frame_events_.push_back(event);
|
|
} else if (IsRegionFrameEvent(event)) {
|
|
if (ordered_frame_events_.size() >= kMaxOrderedRenderFrames) {
|
|
const auto dropped = ordered_frame_events_.size() + 1;
|
|
frames_dropped_due_to_backpressure_ += dropped;
|
|
ordered_frame_events_.clear();
|
|
logger_->Warn("render.region_queue direct writer overflow session=" + session_id_ +
|
|
" dropped=" + std::to_string(dropped) +
|
|
" repair_requested=true");
|
|
if (repair_request_) {
|
|
repair_request_("direct_writer_region_queue_overflow");
|
|
}
|
|
} else {
|
|
ordered_frame_events_.push_back(event);
|
|
}
|
|
} else {
|
|
ordered_frame_events_.clear();
|
|
ordered_frame_events_.push_back(event);
|
|
}
|
|
++frames_queued_;
|
|
} else if (IsCursorEvent(event)) {
|
|
if (latest_cursor_.has_value()) {
|
|
++cursor_updates_dropped_due_to_backpressure_;
|
|
}
|
|
latest_cursor_ = event;
|
|
++cursor_updates_queued_;
|
|
} else {
|
|
if (reliable_events_.size() >= kMaxReliableOutboundEvents) {
|
|
reliable_events_.pop_front();
|
|
++reliable_dropped_;
|
|
}
|
|
reliable_events_.push_back(event);
|
|
}
|
|
}
|
|
condition_.notify_one();
|
|
}
|
|
|
|
private:
|
|
void WriterLoop() {
|
|
while (true) {
|
|
std::optional<runtime::WorkerEvent> event;
|
|
{
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
condition_.wait(lock, [&]() {
|
|
return stopped_ || !reliable_events_.empty() || latest_cursor_.has_value() || !ordered_frame_events_.empty();
|
|
});
|
|
if (stopped_) {
|
|
return;
|
|
}
|
|
if (!reliable_events_.empty()) {
|
|
event = std::move(reliable_events_.front());
|
|
reliable_events_.pop_front();
|
|
} else if (latest_cursor_.has_value()) {
|
|
event = std::move(*latest_cursor_);
|
|
latest_cursor_.reset();
|
|
++cursor_updates_sent_;
|
|
} else if (!ordered_frame_events_.empty()) {
|
|
if (binary_render_enabled_ && last_binary_frame_sent_at_.time_since_epoch().count() != 0) {
|
|
const auto next_allowed_frame_at = last_binary_frame_sent_at_ + BinaryRenderFrameInterval(ordered_frame_events_.front());
|
|
const auto now = std::chrono::steady_clock::now();
|
|
if (now < next_allowed_frame_at) {
|
|
condition_.wait_until(lock, next_allowed_frame_at, [&]() {
|
|
return stopped_ || !reliable_events_.empty() || latest_cursor_.has_value();
|
|
});
|
|
continue;
|
|
}
|
|
}
|
|
event = std::move(ordered_frame_events_.front());
|
|
ordered_frame_events_.pop_front();
|
|
++frames_sent_;
|
|
}
|
|
}
|
|
if (!event.has_value()) {
|
|
continue;
|
|
}
|
|
const bool binary_frame = binary_render_enabled_ && IsFrameEvent(*event);
|
|
const auto binary_payload = binary_frame
|
|
? WorkerEventToBinaryRenderFrame(*event, requested_color_mode_, !baseline_binary_frame_sent_)
|
|
: std::nullopt;
|
|
const std::string encoded = binary_payload.has_value()
|
|
? binary_payload->message
|
|
: common::SerializeJson(WorkerEventToEnvelope(*event));
|
|
beast::error_code ec;
|
|
{
|
|
std::lock_guard<std::mutex> write_lock(websocket_io_mutex_);
|
|
if (binary_payload.has_value()) {
|
|
ws_.binary(true);
|
|
} else {
|
|
ws_.text(true);
|
|
}
|
|
ws_.write(asio::buffer(encoded), ec);
|
|
}
|
|
if (ec) {
|
|
logger_->Warn("direct data-plane write failed session=" + session_id_ +
|
|
" frame=" + (binary_payload.has_value() ? std::string("true") : std::string("false")) +
|
|
" reason=" + ec.message());
|
|
{
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
stopped_ = true;
|
|
reliable_events_.clear();
|
|
latest_cursor_.reset();
|
|
ordered_frame_events_.clear();
|
|
}
|
|
condition_.notify_all();
|
|
return;
|
|
}
|
|
if (binary_payload.has_value()) {
|
|
last_binary_frame_sent_at_ = std::chrono::steady_clock::now();
|
|
if (binary_payload->update_kind == "full") {
|
|
baseline_binary_frame_sent_ = true;
|
|
}
|
|
binary_render_bytes_sent_ += encoded.size();
|
|
logger_->Info("direct data-plane render frame prepared session=" + session_id_ +
|
|
" requested_color_mode=" + requested_color_mode_ +
|
|
" applied_color_mode=" + binary_payload->color_mode +
|
|
" message_type=" + binary_payload->message_type +
|
|
" frame_update_kind=" + binary_payload->update_kind +
|
|
" full_frame_sent=" + (binary_payload->update_kind == "full" ? std::string("true") : std::string("false")) +
|
|
" region_frame_sent=" + (binary_payload->update_kind == "region" ? std::string("true") : std::string("false")) +
|
|
" dirty_region_applied=" + (binary_payload->dirty_region_applied ? std::string("true") : std::string("false")) +
|
|
" grayscale_conversion_applied=" + (binary_payload->grayscale_conversion_applied ? std::string("true") : std::string("false")) +
|
|
" raw_frame_bytes_before=" + std::to_string(binary_payload->raw_frame_bytes_before) +
|
|
" raw_frame_bytes_after=" + std::to_string(binary_payload->raw_frame_bytes_after) +
|
|
" full_frame_bytes=" + std::to_string(binary_payload->full_frame_bytes) +
|
|
" region_bytes=" + std::to_string(binary_payload->region_bytes) +
|
|
" region_savings_percent=" + std::to_string(binary_payload->region_savings_percent) +
|
|
" binary_direct_bytes=" + std::to_string(binary_payload->binary_direct_bytes) +
|
|
" diff_time_ms=" + std::to_string(binary_payload->diff_time_ms) +
|
|
" render_update_reason=" + binary_payload->render_update_reason +
|
|
" fallback_to_full_frame_reason=" + binary_payload->fallback_to_full_frame_reason +
|
|
" conversion_time_ms=" + std::to_string(binary_payload->conversion_time_ms));
|
|
} else if (IsFrameEvent(*event)) {
|
|
json_render_bytes_sent_ += encoded.size();
|
|
}
|
|
LogRateIfNeeded();
|
|
}
|
|
}
|
|
|
|
void LogRateIfNeeded() {
|
|
const auto now = std::chrono::steady_clock::now();
|
|
if (last_rate_log_at_.time_since_epoch().count() == 0) {
|
|
last_rate_log_at_ = now;
|
|
return;
|
|
}
|
|
if (now - last_rate_log_at_ < std::chrono::seconds(2)) {
|
|
return;
|
|
}
|
|
const double seconds = std::max(0.001, std::chrono::duration<double>(now - last_rate_log_at_).count());
|
|
logger_->Info("direct data-plane outbound rate session=" + session_id_ +
|
|
" frames_queued_per_second=" + std::to_string(static_cast<double>(frames_queued_) / seconds) +
|
|
" frames_sent_per_second=" + std::to_string(static_cast<double>(frames_sent_) / seconds) +
|
|
" cursor_queued_per_second=" + std::to_string(static_cast<double>(cursor_updates_queued_) / seconds) +
|
|
" cursor_sent_per_second=" + std::to_string(static_cast<double>(cursor_updates_sent_) / seconds) +
|
|
" binary_render_bytes_per_second=" + std::to_string(static_cast<double>(binary_render_bytes_sent_) / seconds) +
|
|
" json_render_bytes_per_second=" + std::to_string(static_cast<double>(json_render_bytes_sent_) / seconds) +
|
|
" frames_dropped_due_to_backpressure=" + std::to_string(frames_dropped_due_to_backpressure_) +
|
|
" cursor_dropped_due_to_backpressure=" + std::to_string(cursor_updates_dropped_due_to_backpressure_) +
|
|
" reliable_dropped=" + std::to_string(reliable_dropped_));
|
|
frames_queued_ = 0;
|
|
frames_sent_ = 0;
|
|
cursor_updates_queued_ = 0;
|
|
cursor_updates_sent_ = 0;
|
|
binary_render_bytes_sent_ = 0;
|
|
json_render_bytes_sent_ = 0;
|
|
frames_dropped_due_to_backpressure_ = 0;
|
|
cursor_updates_dropped_due_to_backpressure_ = 0;
|
|
reliable_dropped_ = 0;
|
|
last_rate_log_at_ = now;
|
|
}
|
|
|
|
WebSocket& ws_;
|
|
std::mutex& websocket_io_mutex_;
|
|
std::shared_ptr<common::Logger> logger_;
|
|
std::string session_id_;
|
|
std::string attachment_id_;
|
|
std::string requested_color_mode_;
|
|
std::function<void(std::string)> repair_request_;
|
|
std::mutex mutex_;
|
|
std::condition_variable condition_;
|
|
std::deque<runtime::WorkerEvent> reliable_events_;
|
|
std::optional<runtime::WorkerEvent> latest_cursor_;
|
|
std::deque<runtime::WorkerEvent> ordered_frame_events_;
|
|
std::thread writer_thread_;
|
|
bool stopped_{false};
|
|
bool binary_render_enabled_{false};
|
|
bool baseline_binary_frame_sent_{false};
|
|
std::chrono::steady_clock::time_point last_rate_log_at_{};
|
|
std::chrono::steady_clock::time_point last_binary_frame_sent_at_{};
|
|
std::size_t frames_queued_{0};
|
|
std::size_t frames_sent_{0};
|
|
std::size_t cursor_updates_queued_{0};
|
|
std::size_t cursor_updates_sent_{0};
|
|
std::size_t binary_render_bytes_sent_{0};
|
|
std::size_t json_render_bytes_sent_{0};
|
|
std::size_t frames_dropped_due_to_backpressure_{0};
|
|
std::size_t cursor_updates_dropped_due_to_backpressure_{0};
|
|
std::size_t reliable_dropped_{0};
|
|
};
|
|
|
|
} // namespace
|
|
|
|
DirectWssServer::DirectWssServer(config::Config config,
|
|
std::shared_ptr<runtime::SessionManager> session_manager,
|
|
std::shared_ptr<common::Logger> logger)
|
|
: config_(std::move(config)),
|
|
session_manager_(std::move(session_manager)),
|
|
logger_(std::move(logger)),
|
|
token_validator_(config_.data_plane_public_key_pem, config_.worker_id),
|
|
ssl_context_(asio::ssl::context::tlsv12_server) {}
|
|
|
|
DirectWssServer::~DirectWssServer() {
|
|
Stop();
|
|
}
|
|
|
|
void DirectWssServer::Start() {
|
|
if (!config_.data_plane_enabled) {
|
|
logger_->Info("direct data-plane WSS endpoint disabled");
|
|
return;
|
|
}
|
|
if (config_.data_plane_public_key_pem.empty() ||
|
|
config_.data_plane_tls_cert_file.empty() ||
|
|
config_.data_plane_tls_key_file.empty()) {
|
|
logger_->Warn("direct data-plane WSS endpoint not started because token public key or TLS certificate/key is missing");
|
|
return;
|
|
}
|
|
|
|
ssl_context_.set_options(asio::ssl::context::default_workarounds |
|
|
asio::ssl::context::no_sslv2 |
|
|
asio::ssl::context::no_sslv3 |
|
|
asio::ssl::context::single_dh_use);
|
|
ssl_context_.use_certificate_chain_file(config_.data_plane_tls_cert_file);
|
|
ssl_context_.use_private_key_file(config_.data_plane_tls_key_file, asio::ssl::context::pem);
|
|
|
|
thread_ = std::thread(&DirectWssServer::Run, this);
|
|
}
|
|
|
|
void DirectWssServer::Stop() {
|
|
stop_requested_.store(true);
|
|
if (acceptor_.has_value()) {
|
|
beast::error_code ignored;
|
|
acceptor_->close(ignored);
|
|
}
|
|
io_context_.stop();
|
|
if (thread_.joinable()) {
|
|
thread_.join();
|
|
}
|
|
}
|
|
|
|
void DirectWssServer::Run() {
|
|
try {
|
|
const auto address = asio::ip::make_address(config_.data_plane_listen_host);
|
|
const tcp::endpoint endpoint{address, static_cast<unsigned short>(config_.data_plane_listen_port)};
|
|
acceptor_.emplace(io_context_);
|
|
acceptor_->open(endpoint.protocol());
|
|
acceptor_->set_option(asio::socket_base::reuse_address(true));
|
|
acceptor_->bind(endpoint);
|
|
acceptor_->listen(asio::socket_base::max_listen_connections);
|
|
logger_->Info("direct data-plane WSS endpoint listening host=" + config_.data_plane_listen_host +
|
|
" port=" + std::to_string(config_.data_plane_listen_port) +
|
|
" path=/rap/v1/data-plane");
|
|
|
|
while (!stop_requested_.load()) {
|
|
beast::error_code ec;
|
|
tcp::socket socket{io_context_};
|
|
acceptor_->accept(socket, ec);
|
|
if (ec) {
|
|
if (!stop_requested_.load()) {
|
|
logger_->Warn("direct data-plane accept failed reason=" + ec.message());
|
|
}
|
|
continue;
|
|
}
|
|
std::thread(&DirectWssServer::HandleConnection, this, std::move(socket)).detach();
|
|
}
|
|
} catch (const std::exception& error) {
|
|
logger_->Error(std::string("direct data-plane WSS endpoint stopped reason=") + error.what());
|
|
}
|
|
}
|
|
|
|
void DirectWssServer::HandleConnection(tcp::socket socket) {
|
|
beast::error_code ec;
|
|
asio::ssl::stream<tcp::socket> tls_stream{std::move(socket), ssl_context_};
|
|
tls_stream.handshake(asio::ssl::stream_base::server, ec);
|
|
if (ec) {
|
|
logger_->Warn("direct data-plane TLS handshake failed reason=" + ec.message());
|
|
return;
|
|
}
|
|
|
|
beast::flat_buffer buffer;
|
|
http::request<http::string_body> request;
|
|
http::read(tls_stream, buffer, request, ec);
|
|
if (ec) {
|
|
logger_->Warn("direct data-plane HTTP upgrade read failed reason=" + ec.message());
|
|
return;
|
|
}
|
|
|
|
const std::string target = std::string(request.target());
|
|
if (PathOnly(target) != "/rap/v1/data-plane") {
|
|
WriteReject(tls_stream, http::status::not_found, "unknown_path");
|
|
return;
|
|
}
|
|
if (!websocket::is_upgrade(request)) {
|
|
WriteReject(tls_stream, http::status::upgrade_required, "websocket_required");
|
|
return;
|
|
}
|
|
|
|
const auto token = QueryParam(target, "data_plane_token");
|
|
const bool binary_render_requested = QueryParam(target, "render_transport") == kBinaryRenderTransportV1;
|
|
const std::string requested_color_mode = binary_render_requested
|
|
? NormalizeColorMode(QueryParam(target, "color_mode"))
|
|
: std::string(kColorModeFullColor);
|
|
const auto validation = token_validator_.Validate(token);
|
|
if (!validation.ok) {
|
|
logger_->Warn("event=token_validation_failed reason=" + validation.reason);
|
|
WriteReject(tls_stream, http::status::unauthorized, validation.reason);
|
|
return;
|
|
}
|
|
logger_->Info("event=token_validation_success session=" + validation.claims.session_id +
|
|
" attachment=" + validation.claims.attachment_id +
|
|
" worker=" + validation.claims.worker_id +
|
|
" jti=" + validation.claims.jti);
|
|
|
|
if (!ConsumeJti(validation.claims)) {
|
|
logger_->Warn("event=jti_replay_rejected session=" + validation.claims.session_id +
|
|
" attachment=" + validation.claims.attachment_id +
|
|
" jti=" + validation.claims.jti);
|
|
WriteReject(tls_stream, http::status::unauthorized, "jti_replay_rejected");
|
|
return;
|
|
}
|
|
|
|
std::string bind_reason;
|
|
auto runtime = session_manager_->BindDirectDataPlaneRuntime(validation.claims, bind_reason);
|
|
if (!runtime) {
|
|
logger_->Warn("event=data_plane_bind_failed session=" + validation.claims.session_id +
|
|
" attachment=" + validation.claims.attachment_id +
|
|
" reason=" + bind_reason);
|
|
WriteReject(tls_stream, http::status::not_found, bind_reason);
|
|
return;
|
|
}
|
|
|
|
websocket::stream<asio::ssl::stream<tcp::socket>> ws{std::move(tls_stream)};
|
|
ws.accept(request, ec);
|
|
if (ec) {
|
|
logger_->Warn("direct data-plane WebSocket accept failed session=" + validation.claims.session_id +
|
|
" reason=" + ec.message());
|
|
return;
|
|
}
|
|
|
|
logger_->Info("event=data_plane_bind_success session=" + validation.claims.session_id +
|
|
" attachment=" + validation.claims.attachment_id +
|
|
" channels=" + ChannelsToJsonArray(validation.claims.allowed_channels) +
|
|
" render_transport=" + (binary_render_requested ? std::string("binary_v1") : std::string("json_base64")) +
|
|
" requested_color_mode=" + requested_color_mode +
|
|
" applied_color_mode=" + requested_color_mode);
|
|
ws.text(true);
|
|
ws.write(asio::buffer(rdp_worker::common::SerializeJson(rdp_worker::common::JsonObject{
|
|
{"type", "data_plane.attached"},
|
|
{"session_id", validation.claims.session_id},
|
|
{"attachment_id", validation.claims.attachment_id},
|
|
{"worker_id", validation.claims.worker_id},
|
|
{"render_transport", binary_render_requested ? "binary_v1" : "json_base64"},
|
|
{"color_mode", requested_color_mode},
|
|
})),
|
|
ec);
|
|
if (ec) {
|
|
return;
|
|
}
|
|
|
|
std::mutex websocket_io_mutex;
|
|
auto sink = std::make_shared<DirectWssEventSink<websocket::stream<asio::ssl::stream<tcp::socket>>>>(
|
|
ws,
|
|
websocket_io_mutex,
|
|
logger_,
|
|
validation.claims.session_id,
|
|
validation.claims.attachment_id,
|
|
binary_render_requested,
|
|
requested_color_mode,
|
|
[runtime](std::string reason) {
|
|
runtime->RequestDirectFullFrameRepair(std::move(reason));
|
|
});
|
|
sink->Start();
|
|
runtime->AddDirectEventSink(sink);
|
|
|
|
while (!stop_requested_.load()) {
|
|
beast::error_code available_ec;
|
|
const auto available = beast::get_lowest_layer(ws).available(available_ec);
|
|
if (available_ec) {
|
|
logger_->Warn("direct data-plane socket availability check failed session=" + validation.claims.session_id +
|
|
" reason=" + available_ec.message());
|
|
break;
|
|
}
|
|
if (available == 0) {
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(5));
|
|
continue;
|
|
}
|
|
|
|
beast::flat_buffer message;
|
|
{
|
|
std::lock_guard<std::mutex> read_lock(websocket_io_mutex);
|
|
ws.read(message, ec);
|
|
}
|
|
if (ec == websocket::error::closed) {
|
|
break;
|
|
}
|
|
if (ec) {
|
|
logger_->Warn("direct data-plane WebSocket read failed session=" + validation.claims.session_id +
|
|
" reason=" + ec.message());
|
|
break;
|
|
}
|
|
const auto payload = beast::buffers_to_string(message.data());
|
|
auto envelope = NormalizeClientEnvelope(payload);
|
|
if (!envelope.has_value()) {
|
|
logger_->Warn("direct data-plane ignored malformed client envelope session=" + validation.claims.session_id);
|
|
continue;
|
|
}
|
|
const auto type = common::GetString(*envelope, "type").value_or("");
|
|
if (type == "heartbeat") {
|
|
continue;
|
|
}
|
|
if (!ChannelAllowed(validation.claims, type)) {
|
|
logger_->Warn("direct data-plane envelope rejected by token channel scope session=" + validation.claims.session_id +
|
|
" type=" + type);
|
|
if (auto blocked = BuildBlockedEnvelopeForDisallowedChannel(*envelope, validation.claims); blocked.has_value()) {
|
|
if (!WriteTextEnvelope(ws,
|
|
websocket_io_mutex,
|
|
*blocked,
|
|
logger_,
|
|
validation.claims.session_id,
|
|
"channel_scope_blocked")) {
|
|
break;
|
|
}
|
|
}
|
|
continue;
|
|
}
|
|
const auto snapshot = runtime->Snapshot();
|
|
if (snapshot.attachment_id != validation.claims.attachment_id) {
|
|
logger_->Warn("direct data-plane envelope rejected because attachment is no longer current session=" +
|
|
validation.claims.session_id +
|
|
" envelope_attachment=" + validation.claims.attachment_id +
|
|
" current_attachment=" + snapshot.attachment_id +
|
|
" type=" + type);
|
|
const auto taken_over = BuildTakenOverEnvelopeForStaleAttachment(validation.claims, snapshot.attachment_id);
|
|
if (!WriteTextEnvelope(ws,
|
|
websocket_io_mutex,
|
|
taken_over,
|
|
logger_,
|
|
validation.claims.session_id,
|
|
"attachment_mismatch")) {
|
|
break;
|
|
}
|
|
continue;
|
|
}
|
|
AddBindingClaimsToEnvelope(*envelope, validation.claims);
|
|
const bool mouse_move = IsMouseMoveEnvelope(*envelope);
|
|
if (!runtime->EnqueueDirectEnvelope(std::move(*envelope))) {
|
|
logger_->Warn("direct data-plane inbound queue full session=" + validation.claims.session_id +
|
|
" type=" + type +
|
|
" mouse_move=" + (mouse_move ? std::string("true") : std::string("false")));
|
|
}
|
|
}
|
|
sink->Stop();
|
|
logger_->Info("direct data-plane WebSocket detached session=" + validation.claims.session_id +
|
|
" attachment=" + validation.claims.attachment_id);
|
|
}
|
|
|
|
bool DirectWssServer::ConsumeJti(const DataPlaneTokenClaims& claims) {
|
|
std::lock_guard<std::mutex> lock(jti_mutex_);
|
|
const auto now = std::chrono::system_clock::now();
|
|
for (auto iterator = used_jti_.begin(); iterator != used_jti_.end();) {
|
|
if (iterator->second <= now) {
|
|
iterator = used_jti_.erase(iterator);
|
|
} else {
|
|
++iterator;
|
|
}
|
|
}
|
|
if (used_jti_.find(claims.jti) != used_jti_.end()) {
|
|
return false;
|
|
}
|
|
if (used_jti_.size() >= 4096) {
|
|
used_jti_.erase(used_jti_.begin());
|
|
}
|
|
const auto expires_at = std::chrono::system_clock::time_point(std::chrono::seconds(claims.expires_at_unix));
|
|
used_jti_.emplace(claims.jti, expires_at + std::chrono::seconds(5));
|
|
return true;
|
|
}
|
|
|
|
} // namespace rdp_worker::dataplane
|