Initial project snapshot

This commit is contained in:
2026-04-28 22:29:50 +03:00
commit 8ba0561f4f
365 changed files with 91832 additions and 0 deletions
@@ -0,0 +1,104 @@
#include "rdp_worker/adapter/adapter_event_router.hpp"
#include "rdp_worker/common/json.hpp"
namespace rdp_worker::adapter {
namespace {
AdapterEventDescriptor MakeDescriptor(AdapterChannel channel,
std::string_view type,
bool adapter_origin) {
return AdapterEventDescriptor{
channel,
type,
adapter_origin,
IsReliable(channel),
IsDroppable(channel),
};
}
} // namespace
AdapterEventDescriptor AdapterEventRouter::DescribeRenderNotification(
const runtime::RenderNotification& notification) const {
if (notification.type == "session_frame") {
const auto update_kind = common::GetString(notification.payload, "frame_update_kind").value_or("full");
return MakeDescriptor(
AdapterChannel::kDisplay,
update_kind == "region" ? "display.region_bgra" : "display.baseline_full_bgra",
true);
}
if (notification.type == "session_cursor_updated") {
return MakeDescriptor(AdapterChannel::kCursor, "cursor.update", true);
}
if (notification.type == "session_render_resized") {
return MakeDescriptor(AdapterChannel::kDisplay, "display.resize", true);
}
if (notification.type == "session_render_dirty") {
return MakeDescriptor(AdapterChannel::kDisplay, "display.dirty", true);
}
return MakeDescriptor(AdapterChannel::kTelemetry, notification.type, true);
}
AdapterEventDescriptor AdapterEventRouter::DescribeClipboardNotification(
const runtime::ClipboardNotification&) const {
return MakeDescriptor(AdapterChannel::kClipboard, "clipboard.server_text", true);
}
AdapterEventDescriptor AdapterEventRouter::DescribeClientEnvelope(std::string_view envelope_type,
std::string_view payload_kind,
std::string_view payload_action) const {
if (envelope_type == "input") {
if (payload_kind == "mouse" && payload_action == "move") {
return MakeDescriptor(AdapterChannel::kInput, "input.pointer_move", false);
}
if (payload_kind == "mouse") {
return MakeDescriptor(AdapterChannel::kInput, "input.pointer", false);
}
if (payload_kind == "keyboard") {
return MakeDescriptor(AdapterChannel::kInput, "input.keyboard", false);
}
if (payload_kind == "focus") {
return MakeDescriptor(AdapterChannel::kInput, "input.focus", false);
}
return MakeDescriptor(AdapterChannel::kInput, "input.unknown", false);
}
if (envelope_type == "clipboard") {
return MakeDescriptor(AdapterChannel::kClipboard, "clipboard.client_text", false);
}
if (envelope_type == "file_upload") {
return MakeDescriptor(AdapterChannel::kFileTransfer, "file_upload.client_to_server", false);
}
if (envelope_type == "control") {
return MakeDescriptor(AdapterChannel::kControl, "control.command", false);
}
return MakeDescriptor(AdapterChannel::kTelemetry, envelope_type, false);
}
std::string AdapterEventDescriptorLogLine(const AdapterEventDescriptor& descriptor) {
std::string line;
line.reserve(160);
line += "adapter_event channel=";
line += ChannelName(descriptor.channel);
line += " type=";
line += descriptor.normalized_type;
line += " origin=";
line += descriptor.adapter_origin ? "adapter" : "client";
line += " reliable=";
line += descriptor.reliable ? "true" : "false";
line += " droppable=";
line += descriptor.droppable ? "true" : "false";
return line;
}
} // namespace rdp_worker::adapter
@@ -0,0 +1,134 @@
#include "rdp_worker/adapter/rdp_adapter_runtime.hpp"
#include <utility>
namespace rdp_worker::adapter {
RdpAdapterRuntime::RdpAdapterRuntime(std::shared_ptr<common::Logger> logger)
: logger_(std::move(logger)),
freerdp_(logger_) {}
bool RdpAdapterRuntime::Start(const runtime::ConnectionSpec& spec) {
logger_->Info("rdp_adapter.runtime_start substrate=freerdp resource_id=" + spec.resource_id +
" host=" + spec.host +
" render_quality_profile=" + spec.render_quality_profile);
lifecycle_logged_ = true;
return freerdp_.Start(spec);
}
void RdpAdapterRuntime::Disconnect(bool terminate) {
logger_->Info("rdp_adapter.runtime_disconnect terminate=" + (terminate ? std::string("true") : std::string("false")));
freerdp_.Disconnect(terminate);
}
bool RdpAdapterRuntime::IsConnected() const {
return freerdp_.IsConnected();
}
bool RdpAdapterRuntime::PumpEvents(std::chrono::milliseconds timeout) {
if (!lifecycle_logged_) {
logger_->Info("rdp_adapter.event_pump_start substrate=freerdp");
lifecycle_logged_ = true;
}
return freerdp_.PumpEvents(timeout);
}
int RdpAdapterRuntime::DesktopWidth() const {
return freerdp_.DesktopWidth();
}
int RdpAdapterRuntime::DesktopHeight() const {
return freerdp_.DesktopHeight();
}
bool RdpAdapterRuntime::SendFocusEvent(bool focused) {
TraceClientEnvelope("input", "focus", focused ? "focus_in" : "focus_out");
return freerdp_.SendFocusEvent(focused);
}
bool RdpAdapterRuntime::SendKeyboardInput(uint16_t scan_code, bool key_down, bool extended) {
TraceClientEnvelope("input", "keyboard", key_down ? "key_down" : "key_up");
return freerdp_.SendKeyboardInput(scan_code, key_down, extended);
}
bool RdpAdapterRuntime::SendMouseMove(double normalized_x, double normalized_y) {
TraceClientEnvelope("input", "mouse", "move");
return freerdp_.SendMouseMove(normalized_x, normalized_y);
}
bool RdpAdapterRuntime::SendMouseButton(const std::string& button,
bool pressed,
double normalized_x,
double normalized_y) {
TraceClientEnvelope("input", "mouse", pressed ? "button_down" : "button_up");
return freerdp_.SendMouseButton(button, pressed, normalized_x, normalized_y);
}
bool RdpAdapterRuntime::SendMouseWheel(int wheel_delta, bool horizontal, double normalized_x, double normalized_y) {
TraceClientEnvelope("input", "mouse", "wheel");
return freerdp_.SendMouseWheel(wheel_delta, horizontal, normalized_x, normalized_y);
}
bool RdpAdapterRuntime::SetClipboardText(const std::string& text) {
TraceClientEnvelope("clipboard", "text", "client_to_server");
return freerdp_.SetClipboardText(text);
}
void RdpAdapterRuntime::MarkInputAppliedForGraphicsTrace(const std::string& correlation_id) {
freerdp_.MarkInputAppliedForGraphicsTrace(correlation_id);
}
std::optional<runtime::RenderNotification> RdpAdapterRuntime::CaptureFullFrameNotification(
const std::string& state,
const std::string& capture_source) {
auto notification = freerdp_.CaptureFullFrameNotification(state, capture_source);
if (notification.has_value()) {
TraceAdapterEvent(event_router_.DescribeRenderNotification(*notification));
}
return notification;
}
std::optional<runtime::RenderNotification> RdpAdapterRuntime::PopRenderNotification() {
auto notification = freerdp_.PopRenderNotification();
if (notification.has_value()) {
TraceAdapterEvent(event_router_.DescribeRenderNotification(*notification));
}
return notification;
}
std::optional<runtime::ClipboardNotification> RdpAdapterRuntime::PopClipboardNotification() {
auto notification = freerdp_.PopClipboardNotification();
if (notification.has_value()) {
TraceAdapterEvent(event_router_.DescribeClipboardNotification(*notification));
}
return notification;
}
const std::string& RdpAdapterRuntime::RenderQualityProfile() const {
return freerdp_.RenderQualityProfile();
}
const AdapterEventRouter& RdpAdapterRuntime::EventRouter() const {
return event_router_;
}
void RdpAdapterRuntime::TraceClientEnvelope(std::string_view envelope_type,
std::string_view payload_kind,
std::string_view payload_action) {
const auto descriptor = event_router_.DescribeClientEnvelope(envelope_type, payload_kind, payload_action);
if (descriptor.channel == AdapterChannel::kInput && descriptor.normalized_type == "input.pointer_move") {
return;
}
TraceAdapterEvent(descriptor);
}
void RdpAdapterRuntime::TraceAdapterEvent(const AdapterEventDescriptor& descriptor) {
if (descriptor.channel == AdapterChannel::kDisplay &&
descriptor.normalized_type != "display.resize" &&
descriptor.normalized_type != "display.baseline_full_bgra") {
return;
}
logger_->Info(AdapterEventDescriptorLogLine(descriptor));
}
} // namespace rdp_worker::adapter
@@ -0,0 +1,97 @@
#include "rdp_worker/adapter/service_adapter_protocol.hpp"
namespace rdp_worker::adapter {
std::optional<ChannelSpec> FindChannelSpec(std::string_view name) {
for (const auto& spec : AllChannelSpecs()) {
if (spec.name == name) {
return spec;
}
}
return std::nullopt;
}
std::string_view ChannelName(AdapterChannel channel) {
for (const auto& spec : AllChannelSpecs()) {
if (spec.channel == channel) {
return spec.name;
}
}
return "unknown";
}
std::string_view DirectionName(ChannelDirection direction) {
switch (direction) {
case ChannelDirection::kClientToAdapter:
return "client_to_adapter";
case ChannelDirection::kAdapterToClient:
return "adapter_to_client";
case ChannelDirection::kBidirectional:
return "bidirectional";
}
return "unknown";
}
std::string_view ReliabilityName(ChannelReliability reliability) {
switch (reliability) {
case ChannelReliability::kReliableOrdered:
return "reliable_ordered";
case ChannelReliability::kReliableChunked:
return "reliable_chunked";
case ChannelReliability::kDroppableLatest:
return "droppable_latest";
case ChannelReliability::kAdaptiveDroppable:
return "adaptive_droppable";
case ChannelReliability::kSampledDroppable:
return "sampled_droppable";
}
return "unknown";
}
int PriorityValue(ChannelPriority priority) {
return static_cast<int>(priority);
}
bool IsDroppable(AdapterChannel channel) {
for (const auto& spec : AllChannelSpecs()) {
if (spec.channel == channel) {
return spec.stale_updates_droppable;
}
}
return false;
}
bool IsReliable(AdapterChannel channel) {
for (const auto& spec : AllChannelSpecs()) {
if (spec.channel == channel) {
return spec.reliability == ChannelReliability::kReliableOrdered ||
spec.reliability == ChannelReliability::kReliableChunked;
}
}
return false;
}
bool ValidateAdapterChannelInvariants() {
const auto input = FindChannelSpec("input");
if (!input.has_value() || input->priority != ChannelPriority::kCritical || input->may_block_input) {
return false;
}
for (const auto& spec : AllChannelSpecs()) {
if (spec.channel != AdapterChannel::kInput &&
PriorityValue(spec.priority) <= PriorityValue(ChannelPriority::kCritical)) {
return false;
}
if (spec.may_block_input) {
return false;
}
if ((spec.channel == AdapterChannel::kDisplay || spec.channel == AdapterChannel::kCursor) &&
!spec.stale_updates_droppable) {
return false;
}
}
return true;
}
} // namespace rdp_worker::adapter
+401
View File
@@ -0,0 +1,401 @@
#include "rdp_worker/common/json.hpp"
#include <cctype>
#include <cstdint>
#include <sstream>
#include <stdexcept>
namespace rdp_worker::common {
JsonValue::JsonValue() : value(nullptr) {}
JsonValue::JsonValue(std::nullptr_t) : value(nullptr) {}
JsonValue::JsonValue(bool input) : value(input) {}
JsonValue::JsonValue(double input) : value(input) {}
JsonValue::JsonValue(int input) : value(static_cast<double>(input)) {}
JsonValue::JsonValue(const char* input) : value(std::string(input)) {}
JsonValue::JsonValue(std::string input) : value(std::move(input)) {}
JsonValue::JsonValue(JsonArray input) : value(std::move(input)) {}
JsonValue::JsonValue(JsonObject input) : value(std::move(input)) {}
bool JsonValue::IsObject() const { return std::holds_alternative<JsonObject>(value); }
bool JsonValue::IsArray() const { return std::holds_alternative<JsonArray>(value); }
bool JsonValue::IsString() const { return std::holds_alternative<std::string>(value); }
bool JsonValue::IsBool() const { return std::holds_alternative<bool>(value); }
bool JsonValue::IsNumber() const { return std::holds_alternative<double>(value); }
const JsonObject& JsonValue::AsObject() const { return std::get<JsonObject>(value); }
const JsonArray& JsonValue::AsArray() const { return std::get<JsonArray>(value); }
const std::string& JsonValue::AsString() const { return std::get<std::string>(value); }
bool JsonValue::AsBool() const { return std::get<bool>(value); }
double JsonValue::AsNumber() const { return std::get<double>(value); }
namespace {
void AppendUtf8(std::string& output, std::uint32_t codepoint) {
if (codepoint <= 0x7F) {
output.push_back(static_cast<char>(codepoint));
} else if (codepoint <= 0x7FF) {
output.push_back(static_cast<char>(0xC0 | (codepoint >> 6)));
output.push_back(static_cast<char>(0x80 | (codepoint & 0x3F)));
} else if (codepoint <= 0xFFFF) {
output.push_back(static_cast<char>(0xE0 | (codepoint >> 12)));
output.push_back(static_cast<char>(0x80 | ((codepoint >> 6) & 0x3F)));
output.push_back(static_cast<char>(0x80 | (codepoint & 0x3F)));
} else if (codepoint <= 0x10FFFF) {
output.push_back(static_cast<char>(0xF0 | (codepoint >> 18)));
output.push_back(static_cast<char>(0x80 | ((codepoint >> 12) & 0x3F)));
output.push_back(static_cast<char>(0x80 | ((codepoint >> 6) & 0x3F)));
output.push_back(static_cast<char>(0x80 | (codepoint & 0x3F)));
} else {
throw std::runtime_error("invalid JSON unicode escape");
}
}
class Parser {
public:
explicit Parser(const std::string& input) : input_(input), index_(0) {}
JsonValue Parse() {
SkipWhitespace();
JsonValue value = ParseValue();
SkipWhitespace();
if (index_ != input_.size()) {
throw std::runtime_error("unexpected trailing JSON data");
}
return value;
}
private:
JsonValue ParseValue() {
SkipWhitespace();
if (Match("null")) {
return JsonValue(nullptr);
}
if (Match("true")) {
return JsonValue(true);
}
if (Match("false")) {
return JsonValue(false);
}
if (Peek() == '"') {
return JsonValue(ParseString());
}
if (Peek() == '{') {
return JsonValue(ParseObject());
}
if (Peek() == '[') {
return JsonValue(ParseArray());
}
if (Peek() == '-' || std::isdigit(static_cast<unsigned char>(Peek()))) {
return JsonValue(ParseNumber());
}
throw std::runtime_error("unexpected JSON token");
}
JsonObject ParseObject() {
Expect('{');
JsonObject object;
SkipWhitespace();
if (Peek() == '}') {
Advance();
return object;
}
while (true) {
const std::string key = ParseString();
SkipWhitespace();
Expect(':');
object.emplace(key, ParseValue());
SkipWhitespace();
if (Peek() == '}') {
Advance();
break;
}
Expect(',');
}
return object;
}
JsonArray ParseArray() {
Expect('[');
JsonArray array;
SkipWhitespace();
if (Peek() == ']') {
Advance();
return array;
}
while (true) {
array.emplace_back(ParseValue());
SkipWhitespace();
if (Peek() == ']') {
Advance();
break;
}
Expect(',');
}
return array;
}
std::string ParseString() {
Expect('"');
std::string output;
while (index_ < input_.size()) {
const char current = input_[index_++];
if (current == '"') {
return output;
}
if (current == '\\') {
if (index_ >= input_.size()) {
throw std::runtime_error("invalid JSON escape");
}
const char escaped = input_[index_++];
switch (escaped) {
case '"':
case '\\':
case '/':
output.push_back(escaped);
break;
case 'b':
output.push_back('\b');
break;
case 'f':
output.push_back('\f');
break;
case 'n':
output.push_back('\n');
break;
case 'r':
output.push_back('\r');
break;
case 't':
output.push_back('\t');
break;
case 'u':
AppendUtf8(output, ParseUnicodeEscape());
break;
default:
throw std::runtime_error("unsupported JSON escape");
}
continue;
}
output.push_back(current);
}
throw std::runtime_error("unterminated JSON string");
}
std::uint32_t ParseUnicodeEscape() {
const std::uint32_t first = ParseHexQuad();
if (first >= 0xD800 && first <= 0xDBFF) {
if (index_ + 1 >= input_.size() || input_[index_] != '\\' || input_[index_ + 1] != 'u') {
throw std::runtime_error("invalid JSON unicode surrogate");
}
index_ += 2;
const std::uint32_t second = ParseHexQuad();
if (second < 0xDC00 || second > 0xDFFF) {
throw std::runtime_error("invalid JSON unicode surrogate");
}
return 0x10000 + (((first - 0xD800) << 10) | (second - 0xDC00));
}
if (first >= 0xDC00 && first <= 0xDFFF) {
throw std::runtime_error("invalid JSON unicode surrogate");
}
return first;
}
std::uint32_t ParseHexQuad() {
if (index_ + 4 > input_.size()) {
throw std::runtime_error("invalid JSON unicode escape");
}
std::uint32_t value = 0;
for (int i = 0; i < 4; ++i) {
const char ch = input_[index_++];
value <<= 4;
if (ch >= '0' && ch <= '9') {
value |= static_cast<std::uint32_t>(ch - '0');
} else if (ch >= 'a' && ch <= 'f') {
value |= static_cast<std::uint32_t>(ch - 'a' + 10);
} else if (ch >= 'A' && ch <= 'F') {
value |= static_cast<std::uint32_t>(ch - 'A' + 10);
} else {
throw std::runtime_error("invalid JSON unicode escape");
}
}
return value;
}
double ParseNumber() {
const std::size_t start = index_;
if (Peek() == '-') {
Advance();
}
while (std::isdigit(static_cast<unsigned char>(Peek()))) {
Advance();
}
if (Peek() == '.') {
Advance();
while (std::isdigit(static_cast<unsigned char>(Peek()))) {
Advance();
}
}
return std::stod(input_.substr(start, index_ - start));
}
void SkipWhitespace() {
while (index_ < input_.size() && std::isspace(static_cast<unsigned char>(input_[index_]))) {
++index_;
}
}
char Peek() const {
if (index_ >= input_.size()) {
return '\0';
}
return input_[index_];
}
void Advance() {
if (index_ < input_.size()) {
++index_;
}
}
void Expect(char expected) {
SkipWhitespace();
if (Peek() != expected) {
throw std::runtime_error("unexpected JSON character");
}
Advance();
}
bool Match(const char* keyword) {
const std::size_t length = std::char_traits<char>::length(keyword);
if (input_.substr(index_, length) == keyword) {
index_ += length;
return true;
}
return false;
}
const std::string& input_;
std::size_t index_;
};
std::string Escape(const std::string& value) {
std::ostringstream output;
for (const char ch : value) {
switch (ch) {
case '"':
output << "\\\"";
break;
case '\\':
output << "\\\\";
break;
case '\n':
output << "\\n";
break;
case '\r':
output << "\\r";
break;
case '\t':
output << "\\t";
break;
default:
output << ch;
break;
}
}
return output.str();
}
std::string SerializeInternal(const JsonValue& value) {
if (std::holds_alternative<std::nullptr_t>(value.value)) {
return "null";
}
if (std::holds_alternative<bool>(value.value)) {
return std::get<bool>(value.value) ? "true" : "false";
}
if (std::holds_alternative<double>(value.value)) {
std::ostringstream output;
output << std::get<double>(value.value);
return output.str();
}
if (std::holds_alternative<std::string>(value.value)) {
return "\"" + Escape(std::get<std::string>(value.value)) + "\"";
}
if (std::holds_alternative<JsonArray>(value.value)) {
std::ostringstream output;
output << "[";
const auto& array = std::get<JsonArray>(value.value);
for (std::size_t i = 0; i < array.size(); ++i) {
if (i > 0) {
output << ",";
}
output << SerializeInternal(array[i]);
}
output << "]";
return output.str();
}
std::ostringstream output;
output << "{";
const auto& object = std::get<JsonObject>(value.value);
bool first = true;
for (const auto& [key, child] : object) {
if (!first) {
output << ",";
}
first = false;
output << "\"" << Escape(key) << "\":" << SerializeInternal(child);
}
output << "}";
return output.str();
}
} // namespace
JsonValue ParseJson(const std::string& input) {
return Parser(input).Parse();
}
std::string SerializeJson(const JsonValue& value) {
return SerializeInternal(value);
}
std::optional<std::string> GetString(const JsonObject& object, const std::string& key) {
auto iterator = object.find(key);
if (iterator == object.end() || !iterator->second.IsString()) {
return std::nullopt;
}
return iterator->second.AsString();
}
std::optional<bool> GetBool(const JsonObject& object, const std::string& key) {
auto iterator = object.find(key);
if (iterator == object.end() || !iterator->second.IsBool()) {
return std::nullopt;
}
return iterator->second.AsBool();
}
std::optional<double> GetNumber(const JsonObject& object, const std::string& key) {
auto iterator = object.find(key);
if (iterator == object.end() || !iterator->second.IsNumber()) {
return std::nullopt;
}
return iterator->second.AsNumber();
}
const JsonObject* GetObject(const JsonObject& object, const std::string& key) {
auto iterator = object.find(key);
if (iterator == object.end() || !iterator->second.IsObject()) {
return nullptr;
}
return &iterator->second.AsObject();
}
const JsonArray* GetArray(const JsonObject& object, const std::string& key) {
auto iterator = object.find(key);
if (iterator == object.end() || !iterator->second.IsArray()) {
return nullptr;
}
return &iterator->second.AsArray();
}
} // namespace rdp_worker::common
+52
View File
@@ -0,0 +1,52 @@
#include "rdp_worker/common/logger.hpp"
#include <iostream>
#include "rdp_worker/common/time.hpp"
namespace rdp_worker::common {
namespace {
std::string LevelToString(LogLevel level) {
switch (level) {
case LogLevel::kDebug:
return "DEBUG";
case LogLevel::kInfo:
return "INFO";
case LogLevel::kWarn:
return "WARN";
case LogLevel::kError:
return "ERROR";
}
return "INFO";
}
} // namespace
Logger::Logger(std::string service_name) : service_name_(std::move(service_name)) {}
void Logger::Debug(const std::string& message) {
Write(LogLevel::kDebug, message);
}
void Logger::Info(const std::string& message) {
Write(LogLevel::kInfo, message);
}
void Logger::Warn(const std::string& message) {
Write(LogLevel::kWarn, message);
}
void Logger::Error(const std::string& message) {
Write(LogLevel::kError, message);
}
void Logger::Write(LogLevel level, const std::string& message) {
std::lock_guard<std::mutex> lock(mutex_);
std::cout << "{\"ts\":\"" << ToRfc3339(NowUtc()) << "\",\"service\":\"" << service_name_
<< "\",\"level\":\"" << LevelToString(level) << "\",\"message\":\"" << message
<< "\"}" << std::endl;
}
} // namespace rdp_worker::common
+42
View File
@@ -0,0 +1,42 @@
#include "rdp_worker/common/time.hpp"
#include <ctime>
#include <iomanip>
#include <sstream>
#include <stdexcept>
namespace rdp_worker::common {
Clock::time_point NowUtc() {
return Clock::now();
}
std::string ToRfc3339(Clock::time_point time_point) {
const std::time_t raw = Clock::to_time_t(time_point);
std::tm utc_tm{};
#if defined(_WIN32)
gmtime_s(&utc_tm, &raw);
#else
gmtime_r(&raw, &utc_tm);
#endif
std::ostringstream output;
output << std::put_time(&utc_tm, "%Y-%m-%dT%H:%M:%SZ");
return output.str();
}
Clock::time_point ParseRfc3339(const std::string& value) {
std::tm utc_tm{};
std::istringstream input(value);
input >> std::get_time(&utc_tm, "%Y-%m-%dT%H:%M:%SZ");
if (input.fail()) {
throw std::runtime_error("failed to parse RFC3339 timestamp: " + value);
}
#if defined(_WIN32)
const std::time_t raw = _mkgmtime(&utc_tm);
#else
const std::time_t raw = timegm(&utc_tm);
#endif
return Clock::from_time_t(raw);
}
} // namespace rdp_worker::common
+89
View File
@@ -0,0 +1,89 @@
#include "rdp_worker/config/config.hpp"
#include <cstdlib>
#include <fstream>
#include <sstream>
#include <stdexcept>
namespace rdp_worker::config {
namespace {
std::string GetEnvOrDefault(const char* key, const char* fallback) {
const char* value = std::getenv(key);
if (value == nullptr || std::string(value).empty()) {
return fallback;
}
return value;
}
int GetInt(const char* key, int fallback) {
const char* value = std::getenv(key);
if (value == nullptr || std::string(value).empty()) {
return fallback;
}
return std::stoi(value);
}
bool GetBool(const char* key, bool fallback) {
const char* value = std::getenv(key);
if (value == nullptr || std::string(value).empty()) {
return fallback;
}
const std::string raw(value);
return raw == "1" || raw == "true" || raw == "TRUE" || raw == "yes" || raw == "on";
}
std::string ReadFileIfConfigured(const std::string& path) {
if (path.empty()) {
return "";
}
std::ifstream input(path);
if (!input.good()) {
throw std::runtime_error("failed to read file " + path);
}
std::stringstream buffer;
buffer << input.rdbuf();
return buffer.str();
}
std::vector<std::string> Split(const std::string& value) {
std::vector<std::string> parts;
std::stringstream stream(value);
std::string part;
while (std::getline(stream, part, ',')) {
if (!part.empty()) {
parts.push_back(part);
}
}
return parts;
}
} // namespace
Config LoadFromEnv() {
Config config{};
config.worker_id = GetEnvOrDefault("RDP_WORKER_ID", "rdp-worker-1");
config.redis_host = GetEnvOrDefault("RDP_WORKER_REDIS_HOST", "127.0.0.1");
config.redis_port = GetInt("RDP_WORKER_REDIS_PORT", 6379);
config.redis_password = GetEnvOrDefault("RDP_WORKER_REDIS_PASSWORD", "");
config.redis_db = GetInt("RDP_WORKER_REDIS_DB", 0);
config.worker_heartbeat_interval = std::chrono::seconds(GetInt("RDP_WORKER_HEARTBEAT_INTERVAL_SECONDS", 5));
config.lease_renew_interval = std::chrono::seconds(GetInt("RDP_WORKER_LEASE_RENEW_INTERVAL_SECONDS", 10));
config.assignment_poll_interval = std::chrono::seconds(GetInt("RDP_WORKER_ASSIGNMENT_POLL_INTERVAL_SECONDS", 2));
config.insecure_skip_verify = GetBool("RDP_WORKER_INSECURE_SKIP_VERIFY", false);
config.capabilities = Split(GetEnvOrDefault("RDP_WORKER_CAPABILITIES", "adaptive-quality,dirty-rects,clipboard,file-transfer"));
config.data_plane_enabled = GetBool("RDP_WORKER_DATA_PLANE_ENABLED", false);
config.data_plane_listen_host = GetEnvOrDefault("RDP_WORKER_DATA_PLANE_LISTEN_HOST", "0.0.0.0");
config.data_plane_listen_port = GetInt("RDP_WORKER_DATA_PLANE_LISTEN_PORT", 8443);
config.data_plane_public_key_pem = GetEnvOrDefault("RDP_WORKER_DATA_PLANE_PUBLIC_KEY_PEM", "");
config.data_plane_public_key_file = GetEnvOrDefault("RDP_WORKER_DATA_PLANE_PUBLIC_KEY_FILE", "");
if (config.data_plane_public_key_pem.empty()) {
config.data_plane_public_key_pem = ReadFileIfConfigured(config.data_plane_public_key_file);
}
config.data_plane_tls_cert_file = GetEnvOrDefault("RDP_WORKER_DATA_PLANE_TLS_CERT_FILE", "");
config.data_plane_tls_key_file = GetEnvOrDefault("RDP_WORKER_DATA_PLANE_TLS_KEY_FILE", "");
return config;
}
} // namespace rdp_worker::config
@@ -0,0 +1,300 @@
#include "rdp_worker/coordination/control_plane.hpp"
#include <sstream>
#include <stdexcept>
#include "rdp_worker/common/logger.hpp"
#include "rdp_worker/common/time.hpp"
namespace rdp_worker::coordination {
using common::GetArray;
using common::GetBool;
using common::GetNumber;
using common::GetObject;
using common::GetString;
using common::JsonArray;
using common::JsonObject;
using common::JsonValue;
ControlPlane::ControlPlane(config::Config config, std::shared_ptr<common::Logger> logger)
: config_(std::move(config)),
logger_(std::move(logger)),
redis_(std::make_unique<RedisClient>(config_.redis_host, config_.redis_port, config_.redis_password, config_.redis_db)) {}
void ControlPlane::Connect() {
std::lock_guard<std::mutex> lock(mutex_);
redis_->Connect();
}
void ControlPlane::RegisterWorker() {
std::lock_guard<std::mutex> lock(mutex_);
redis_->Set("worker:registration:" + config_.worker_id, WorkerRegistrationPayload(), config_.worker_heartbeat_interval * 3);
redis_->SAdd("worker:registrations", config_.worker_id);
}
void ControlPlane::ReleaseOwnedLeasesOnStartup() {
std::lock_guard<std::mutex> lock(mutex_);
int released = 0;
for (const auto& lease_id : redis_->SMembers("worker:leases")) {
auto encoded = redis_->Get("worker:lease:" + lease_id);
if (!encoded.has_value()) {
redis_->SRem("worker:leases", lease_id);
continue;
}
auto lease = ParseLease(common::ParseJson(*encoded).AsObject());
if (lease.worker_id != config_.worker_id) {
continue;
}
redis_->Delete("worker:lease:" + lease_id);
redis_->SRem("worker:leases", lease_id);
if (!lease.session_id.empty()) {
redis_->Delete("worker:session-lease:" + lease.session_id);
redis_->Delete("worker:queue:" + lease.session_id);
}
++released;
}
if (released > 0) {
logger_->Warn("released stale owned worker leases on startup worker=" + config_.worker_id +
" released_count=" + std::to_string(released));
}
}
void ControlPlane::SendHeartbeat() {
RegisterWorker();
}
std::optional<runtime::Assignment> ControlPlane::PollAssignment(std::chrono::seconds timeout) {
RedisClient stream_client(config_.redis_host, config_.redis_port, config_.redis_password, config_.redis_db);
stream_client.Connect();
auto entry = stream_client.BLPop("worker:control:" + config_.worker_id, timeout);
if (!entry.has_value() || entry->size() != 2) {
return std::nullopt;
}
const JsonObject object = common::ParseJson((*entry)[1]).AsObject();
return ParseAssignment(object);
}
std::optional<common::JsonObject> ControlPlane::PollSessionEnvelope(const std::string& session_id, std::chrono::seconds timeout) {
RedisClient stream_client(config_.redis_host, config_.redis_port, config_.redis_password, config_.redis_db);
stream_client.Connect();
auto entry = stream_client.BLPop("worker:queue:" + session_id, timeout);
if (!entry.has_value() || entry->size() != 2) {
return std::nullopt;
}
auto object = common::ParseJson((*entry)[1]).AsObject();
const JsonObject* payload = GetObject(object, "payload");
if (payload != nullptr) {
const std::string type = GetString(object, "type").value_or("");
const std::string correlation_id = GetString(*payload, "correlation_id").value_or("");
if (type == "input" && !correlation_id.empty()) {
logger_->Info("input.trace worker_queue_pop session=" + session_id +
" correlation_id=" + correlation_id +
" trace_stage=worker_queue_pop");
}
}
return object;
}
std::vector<common::JsonObject> ControlPlane::DrainSessionEnvelopes(const std::string& session_id, std::size_t max_count) {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<common::JsonObject> output;
output.reserve(max_count);
const std::string key = "worker:queue:" + session_id;
for (std::size_t i = 0; i < max_count; ++i) {
auto encoded = redis_->LPop(key);
if (!encoded.has_value()) {
break;
}
auto object = common::ParseJson(*encoded).AsObject();
const JsonObject* payload = GetObject(object, "payload");
if (payload != nullptr) {
const std::string type = GetString(object, "type").value_or("");
const std::string correlation_id = GetString(*payload, "correlation_id").value_or("");
if (type == "input" && !correlation_id.empty()) {
logger_->Info("input.trace worker_queue_pop session=" + session_id +
" correlation_id=" + correlation_id +
" trace_stage=worker_queue_pop");
}
}
output.push_back(std::move(object));
}
return output;
}
int64_t ControlPlane::SessionEnvelopeQueueLength(const std::string& session_id) {
std::lock_guard<std::mutex> lock(mutex_);
return redis_->LLen("worker:queue:" + session_id);
}
std::optional<runtime::WorkerLease> ControlPlane::GetLeaseBySession(const std::string& session_id) {
std::lock_guard<std::mutex> lock(mutex_);
auto lease_id = redis_->Get("worker:session-lease:" + session_id);
if (!lease_id.has_value()) {
return std::nullopt;
}
auto encoded = redis_->Get("worker:lease:" + *lease_id);
if (!encoded.has_value()) {
return std::nullopt;
}
return ParseLease(common::ParseJson(*encoded).AsObject());
}
void ControlPlane::RenewLease(const runtime::WorkerLease& lease) {
std::lock_guard<std::mutex> lock(mutex_);
redis_->Set("worker:lease:" + lease.lease_id, LeasePayload(lease), std::chrono::seconds(45));
redis_->Set("worker:session-lease:" + lease.session_id, lease.lease_id, std::chrono::seconds(45));
}
void ControlPlane::ReleaseLease(const runtime::WorkerLease& lease) {
std::lock_guard<std::mutex> lock(mutex_);
redis_->Delete("worker:lease:" + lease.lease_id);
redis_->Delete("worker:session-lease:" + lease.session_id);
}
void ControlPlane::PublishEvent(const runtime::WorkerEvent& event) {
std::lock_guard<std::mutex> lock(mutex_);
const std::string encoded = EventPayload(event);
redis_->RPush("worker:events", encoded);
redis_->Expire("worker:events", std::chrono::minutes(10));
}
runtime::Assignment ControlPlane::ParseAssignment(const JsonObject& object) const {
runtime::Assignment assignment{};
assignment.session_id = GetString(object, "session_id").value_or("");
assignment.worker_id = GetString(object, "worker_id").value_or("");
assignment.attachment_id = GetString(object, "attachment_id").value_or("");
assignment.user_id = GetString(object, "user_id").value_or("");
assignment.device_id = GetString(object, "device_id").value_or("");
assignment.takeover_of = GetString(object, "takeover_of");
const std::string state = GetString(object, "state").value_or("starting");
if (state == "active") {
assignment.state = runtime::SessionState::kActive;
} else if (state == "detached") {
assignment.state = runtime::SessionState::kDetached;
} else if (state == "reconnecting") {
assignment.state = runtime::SessionState::kReconnecting;
} else {
assignment.state = runtime::SessionState::kStarting;
}
const JsonObject* metadata = GetObject(object, "metadata");
if (metadata == nullptr) {
throw std::runtime_error("assignment metadata is required");
}
const JsonObject* resource = GetObject(*metadata, "resource");
if (resource == nullptr) {
throw std::runtime_error("assignment resource metadata is required");
}
assignment.organization_id = GetString(*resource, "organization_id").value_or("");
assignment.connection.resource_id = GetString(*resource, "id").value_or("");
assignment.connection.resource_name = GetString(*resource, "name").value_or("");
assignment.connection.host = GetString(*resource, "address").value_or("");
assignment.connection.port = 3389;
assignment.connection.username = "";
assignment.connection.password = "";
assignment.connection.domain = "";
assignment.connection.certificate_verification_mode = GetString(*resource, "certificate_verification_mode").value_or("strict");
assignment.connection.render_quality_profile = GetString(*resource, "render_quality_profile").value_or("balanced");
assignment.connection.insecure_skip_verify = config_.insecure_skip_verify;
const JsonObject* resource_meta = GetObject(*resource, "metadata");
if (resource_meta != nullptr) {
assignment.connection.host = GetString(*resource_meta, "rdp_host").value_or(assignment.connection.host);
assignment.connection.port = static_cast<uint16_t>(GetNumber(*resource_meta, "rdp_port").value_or(3389));
assignment.connection.username = GetString(*resource_meta, "username").value_or("");
assignment.connection.password = GetString(*resource_meta, "password").value_or("");
assignment.connection.domain = GetString(*resource_meta, "domain").value_or("");
assignment.connection.certificate_verification_mode =
GetString(*resource_meta, "certificate_verification_mode").value_or(assignment.connection.certificate_verification_mode);
assignment.connection.render_quality_profile =
GetString(*resource_meta, "render_quality_profile").value_or(assignment.connection.render_quality_profile);
}
const JsonObject* policy = GetObject(*metadata, "policy");
if (policy != nullptr) {
assignment.policy.detach_grace_period = std::chrono::seconds(static_cast<int>(GetNumber(*policy, "detach_grace_period_seconds").value_or(1800)));
assignment.policy.clipboard_mode = GetString(*policy, "clipboard_mode").value_or("disabled");
if (assignment.policy.clipboard_mode.empty()) {
assignment.policy.clipboard_mode = GetBool(*policy, "clipboard_enabled").value_or(false) ? "bidirectional" : "disabled";
}
assignment.policy.file_transfer_mode = GetString(*policy, "file_transfer_mode").value_or("disabled");
if (assignment.policy.file_transfer_mode.empty()) {
assignment.policy.file_transfer_mode = GetBool(*policy, "file_transfer_enabled").value_or(false) ? "client_to_server" : "disabled";
}
}
return assignment;
}
runtime::WorkerLease ControlPlane::ParseLease(const JsonObject& object) const {
runtime::WorkerLease lease{};
lease.lease_id = GetString(object, "lease_id").value_or("");
lease.worker_id = GetString(object, "worker_id").value_or("");
lease.session_id = GetString(object, "session_id").value_or("");
lease.resource_id = GetString(object, "resource_id").value_or("");
lease.control_stream = GetString(object, "control_stream").value_or("");
lease.expires_at = GetString(object, "expires_at").value_or("");
if (const JsonArray* capabilities = GetArray(object, "capabilities"); capabilities != nullptr) {
for (const auto& item : *capabilities) {
if (item.IsString()) {
lease.capabilities.push_back(item.AsString());
}
}
}
return lease;
}
std::string ControlPlane::WorkerRegistrationPayload() const {
JsonArray capabilities;
for (const auto& item : config_.capabilities) {
capabilities.emplace_back(item);
}
return common::SerializeJson(JsonObject{
{"worker_id", config_.worker_id},
{"protocol", "rdp"},
{"status", "online"},
{"capabilities", capabilities},
{"control_stream", "worker://control/" + config_.worker_id},
{"last_heartbeat_at", common::ToRfc3339(common::NowUtc())},
});
}
std::string ControlPlane::LeasePayload(const runtime::WorkerLease& lease) const {
JsonArray capabilities;
for (const auto& capability : lease.capabilities) {
capabilities.emplace_back(capability);
}
return common::SerializeJson(JsonObject{
{"lease_id", lease.lease_id},
{"worker_id", lease.worker_id},
{"protocol", "rdp"},
{"resource_id", lease.resource_id},
{"session_id", lease.session_id},
{"capabilities", capabilities},
{"control_stream", lease.control_stream},
{"expires_at", common::ToRfc3339(common::NowUtc() + std::chrono::seconds(45))},
});
}
std::string ControlPlane::EventPayload(const runtime::WorkerEvent& event) const {
JsonObject payload{
{"type", event.type},
{"session_id", event.session_id},
{"worker_id", event.worker_id},
};
JsonObject detail;
if (!event.reason.empty()) {
detail.emplace("reason", event.reason);
}
if (event.width > 0) {
detail.emplace("width", event.width);
}
if (event.height > 0) {
detail.emplace("height", event.height);
}
for (const auto& [key, value] : event.payload) {
detail[key] = value;
}
payload.emplace("payload", detail);
return common::SerializeJson(payload);
}
} // namespace rdp_worker::coordination
@@ -0,0 +1,264 @@
#include "rdp_worker/coordination/redis_client.hpp"
#include <cstring>
#include <stdexcept>
#include <string_view>
#if defined(_WIN32)
#include <winsock2.h>
#include <ws2tcpip.h>
#else
#include <arpa/inet.h>
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>
#endif
namespace rdp_worker::coordination {
namespace {
void EnsureSuccess(bool condition, const std::string& message) {
if (!condition) {
throw std::runtime_error(message);
}
}
void CloseSocket(int socket_fd) {
if (socket_fd < 0) {
return;
}
#if defined(_WIN32)
closesocket(socket_fd);
#else
close(socket_fd);
#endif
}
} // namespace
bool RedisReply::IsNull() const { return std::holds_alternative<std::nullptr_t>(value); }
bool RedisReply::IsString() const { return std::holds_alternative<std::string>(value); }
bool RedisReply::IsInteger() const { return std::holds_alternative<int64_t>(value); }
bool RedisReply::IsArray() const { return std::holds_alternative<Array>(value); }
const std::string& RedisReply::AsString() const { return std::get<std::string>(value); }
int64_t RedisReply::AsInteger() const { return std::get<int64_t>(value); }
const RedisReply::Array& RedisReply::AsArray() const { return std::get<Array>(value); }
RedisClient::RedisClient(std::string host, int port, std::string password, int db)
: host_(std::move(host)),
port_(port),
password_(std::move(password)),
db_(db),
socket_fd_(-1) {}
RedisClient::~RedisClient() {
Close();
}
void RedisClient::Connect() {
#if defined(_WIN32)
WSADATA wsa_data{};
EnsureSuccess(WSAStartup(MAKEWORD(2, 2), &wsa_data) == 0, "WSAStartup failed");
#endif
addrinfo hints{};
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
addrinfo* result = nullptr;
EnsureSuccess(getaddrinfo(host_.c_str(), std::to_string(port_).c_str(), &hints, &result) == 0, "getaddrinfo failed");
for (addrinfo* node = result; node != nullptr; node = node->ai_next) {
socket_fd_ = static_cast<int>(socket(node->ai_family, node->ai_socktype, node->ai_protocol));
if (socket_fd_ < 0) {
continue;
}
if (connect(socket_fd_, node->ai_addr, static_cast<int>(node->ai_addrlen)) == 0) {
break;
}
CloseSocket(socket_fd_);
socket_fd_ = -1;
}
freeaddrinfo(result);
EnsureSuccess(socket_fd_ >= 0, "failed to connect to Redis");
if (!password_.empty()) {
Command({"AUTH", password_});
}
if (db_ != 0) {
Command({"SELECT", std::to_string(db_)});
}
}
void RedisClient::Close() {
CloseSocket(socket_fd_);
socket_fd_ = -1;
#if defined(_WIN32)
WSACleanup();
#endif
}
RedisReply RedisClient::Command(const std::vector<std::string>& parts) {
WriteAll(EncodeCommand(parts));
return ReadReply();
}
std::optional<std::string> RedisClient::Get(const std::string& key) {
RedisReply reply = Command({"GET", key});
if (reply.IsNull()) {
return std::nullopt;
}
return reply.AsString();
}
void RedisClient::Set(const std::string& key, const std::string& value, std::chrono::seconds ttl) {
Command({"SET", key, value, "EX", std::to_string(ttl.count())});
}
void RedisClient::SAdd(const std::string& key, const std::string& value) {
Command({"SADD", key, value});
}
void RedisClient::SRem(const std::string& key, const std::string& value) {
Command({"SREM", key, value});
}
std::vector<std::string> RedisClient::SMembers(const std::string& key) {
RedisReply reply = Command({"SMEMBERS", key});
std::vector<std::string> output;
if (!reply.IsArray()) {
return output;
}
for (const auto& item : reply.AsArray()) {
if (item.IsString()) {
output.push_back(item.AsString());
}
}
return output;
}
std::optional<std::vector<std::string>> RedisClient::BLPop(const std::string& key, std::chrono::seconds timeout) {
RedisReply reply = Command({"BLPOP", key, std::to_string(timeout.count())});
if (reply.IsNull()) {
return std::nullopt;
}
std::vector<std::string> output;
for (const auto& item : reply.AsArray()) {
if (item.IsString()) {
output.push_back(item.AsString());
}
}
return output;
}
std::optional<std::string> RedisClient::LPop(const std::string& key) {
RedisReply reply = Command({"LPOP", key});
if (reply.IsNull()) {
return std::nullopt;
}
return reply.AsString();
}
int64_t RedisClient::LLen(const std::string& key) {
RedisReply reply = Command({"LLEN", key});
if (!reply.IsInteger()) {
return 0;
}
return reply.AsInteger();
}
void RedisClient::RPush(const std::string& key, const std::string& value) {
Command({"RPUSH", key, value});
}
void RedisClient::Expire(const std::string& key, std::chrono::seconds ttl) {
Command({"EXPIRE", key, std::to_string(ttl.count())});
}
void RedisClient::Delete(const std::string& key) {
Command({"DEL", key});
}
std::string RedisClient::ReadLine() {
std::string output;
char ch = '\0';
while (true) {
const int received = recv(socket_fd_, &ch, 1, 0);
EnsureSuccess(received == 1, "failed to read from Redis");
if (ch == '\r') {
char lf = '\0';
EnsureSuccess(recv(socket_fd_, &lf, 1, 0) == 1 && lf == '\n', "invalid Redis line ending");
break;
}
output.push_back(ch);
}
return output;
}
std::string RedisClient::ReadBytes(std::size_t count) {
std::string output(count, '\0');
std::size_t offset = 0;
while (offset < count) {
const int received = recv(socket_fd_, output.data() + offset, static_cast<int>(count - offset), 0);
EnsureSuccess(received > 0, "failed to read bulk Redis payload");
offset += static_cast<std::size_t>(received);
}
char suffix[2];
EnsureSuccess(recv(socket_fd_, suffix, 2, 0) == 2 && suffix[0] == '\r' && suffix[1] == '\n', "invalid Redis bulk suffix");
return output;
}
RedisReply RedisClient::ReadReply() {
const std::string line = ReadLine();
EnsureSuccess(!line.empty(), "empty Redis reply");
const char prefix = line[0];
const std::string payload = line.substr(1);
switch (prefix) {
case '+':
return RedisReply{payload};
case ':':
return RedisReply{std::stoll(payload)};
case '$': {
const long long size = std::stoll(payload);
if (size < 0) {
return RedisReply{nullptr};
}
return RedisReply{ReadBytes(static_cast<std::size_t>(size))};
}
case '*': {
const long long size = std::stoll(payload);
if (size < 0) {
return RedisReply{nullptr};
}
RedisReply::Array values;
for (long long i = 0; i < size; ++i) {
values.push_back(ReadReply());
}
return RedisReply{values};
}
case '-':
throw std::runtime_error("Redis error: " + payload);
default:
throw std::runtime_error("unknown Redis reply type");
}
}
void RedisClient::WriteAll(const std::string& data) {
std::size_t offset = 0;
while (offset < data.size()) {
const int sent = send(socket_fd_, data.data() + offset, static_cast<int>(data.size() - offset), 0);
EnsureSuccess(sent > 0, "failed to send Redis command");
offset += static_cast<std::size_t>(sent);
}
}
std::string RedisClient::EncodeCommand(const std::vector<std::string>& parts) const {
std::string output = "*" + std::to_string(parts.size()) + "\r\n";
for (const auto& part : parts) {
output += "$" + std::to_string(part.size()) + "\r\n" + part + "\r\n";
}
return output;
}
} // namespace rdp_worker::coordination
@@ -0,0 +1,102 @@
#include "rdp_worker/cursor/cursor_adapter.hpp"
namespace rdp_worker::cursor {
CursorUpdate CursorAdapter::MakePosition(std::uint64_t sequence,
int desktop_width,
int desktop_height,
int x,
int y,
bool visible) const {
CursorUpdate update;
update.kind = CursorUpdateKind::kPosition;
update.sequence = sequence;
update.desktop_width = desktop_width;
update.desktop_height = desktop_height;
update.x = x;
update.y = y;
update.visible = visible;
return update;
}
CursorUpdate CursorAdapter::MakeSystem(std::uint64_t sequence,
int desktop_width,
int desktop_height,
int x,
int y,
std::uint32_t system_type) const {
CursorUpdate update = MakePosition(sequence, desktop_width, desktop_height, x, y, system_type != 0);
update.kind = CursorUpdateKind::kSystem;
update.shape_changed = true;
update.system_type = system_type;
return update;
}
CursorUpdate CursorAdapter::MakeColor(std::uint64_t sequence,
int desktop_width,
int desktop_height,
int x,
int y,
int width,
int height,
int cache_index,
std::uint64_t mask_bytes) const {
CursorUpdate update = MakePosition(sequence, desktop_width, desktop_height, x, y, true);
update.kind = CursorUpdateKind::kColor;
update.shape_changed = true;
update.width = width;
update.height = height;
update.cache_index = cache_index;
update.mask_bytes = mask_bytes;
return update;
}
CursorUpdate CursorAdapter::MakeNew(std::uint64_t sequence,
int desktop_width,
int desktop_height,
int x,
int y,
int width,
int height,
int cache_index,
int xor_bpp,
std::uint64_t mask_bytes) const {
CursorUpdate update = MakeColor(sequence, desktop_width, desktop_height, x, y, width, height, cache_index, mask_bytes);
update.kind = CursorUpdateKind::kNew;
update.xor_bpp = xor_bpp;
return update;
}
CursorUpdate CursorAdapter::MakeCached(std::uint64_t sequence,
int desktop_width,
int desktop_height,
int x,
int y,
int cache_index) const {
CursorUpdate update = MakePosition(sequence, desktop_width, desktop_height, x, y, true);
update.kind = CursorUpdateKind::kCached;
update.shape_changed = true;
update.cache_index = cache_index;
return update;
}
CursorUpdate CursorAdapter::MakeLarge(std::uint64_t sequence,
int desktop_width,
int desktop_height,
int x,
int y,
int width,
int height,
int cache_index,
int hot_spot_x,
int hot_spot_y,
int xor_bpp,
std::uint64_t mask_bytes) const {
CursorUpdate update = MakeNew(sequence, desktop_width, desktop_height, x, y, width, height, cache_index, xor_bpp, mask_bytes);
update.kind = CursorUpdateKind::kLarge;
update.hot_spot_x = hot_spot_x;
update.hot_spot_y = hot_spot_y;
return update;
}
} // namespace rdp_worker::cursor
@@ -0,0 +1,48 @@
#include "rdp_worker/cursor/cursor_update.hpp"
namespace rdp_worker::cursor {
const char* CursorUpdateKindName(CursorUpdateKind kind) {
switch (kind) {
case CursorUpdateKind::kPosition:
return "position";
case CursorUpdateKind::kSystem:
return "system";
case CursorUpdateKind::kColor:
return "color";
case CursorUpdateKind::kNew:
return "new";
case CursorUpdateKind::kCached:
return "cached";
case CursorUpdateKind::kLarge:
return "large";
}
return "unknown";
}
common::JsonObject CursorUpdateToPayload(const CursorUpdate& update,
const std::string& render_quality_profile) {
return common::JsonObject{
{"render_quality_profile", render_quality_profile},
{"render_state", "cursor"},
{"cursor_update_kind", CursorUpdateKindName(update.kind)},
{"cursor_sequence", static_cast<double>(update.sequence)},
{"width", update.desktop_width},
{"height", update.desktop_height},
{"cursor_x", update.x},
{"cursor_y", update.y},
{"cursor_visible", update.visible},
{"cursor_shape_changed", update.shape_changed},
{"cursor_cache_index", update.cache_index},
{"cursor_hot_spot_x", update.hot_spot_x},
{"cursor_hot_spot_y", update.hot_spot_y},
{"cursor_width", update.width},
{"cursor_height", update.height},
{"cursor_xor_bpp", update.xor_bpp},
{"cursor_mask_bytes", static_cast<double>(update.mask_bytes)},
{"cursor_system_type", static_cast<double>(update.system_type)},
{"dirty_rectangles", 0},
};
}
} // namespace rdp_worker::cursor
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,240 @@
#include "rdp_worker/dataplane/token_validator.hpp"
#include <algorithm>
#include <array>
#include <chrono>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string_view>
#include <utility>
#include <openssl/bio.h>
#include <openssl/evp.h>
#include <openssl/pem.h>
#include "rdp_worker/common/json.hpp"
namespace rdp_worker::dataplane {
namespace {
std::vector<std::string> SplitJwt(const std::string& token) {
std::vector<std::string> parts;
std::size_t start = 0;
while (true) {
const std::size_t dot = token.find('.', start);
parts.push_back(token.substr(start, dot == std::string::npos ? std::string::npos : dot - start));
if (dot == std::string::npos) {
break;
}
start = dot + 1;
}
return parts;
}
std::vector<std::uint8_t> Base64UrlDecode(const std::string& input) {
static constexpr unsigned char kInvalid = 255;
std::array<unsigned char, 256> table{};
table.fill(kInvalid);
const std::string alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
for (std::size_t i = 0; i < alphabet.size(); ++i) {
table[static_cast<unsigned char>(alphabet[i])] = static_cast<unsigned char>(i);
}
std::vector<std::uint8_t> out;
int value = 0;
int value_bits = -8;
for (unsigned char ch : input) {
if (ch == '=') {
break;
}
if (table[ch] == kInvalid) {
throw std::runtime_error("invalid base64url token segment");
}
value = (value << 6) + table[ch];
value_bits += 6;
if (value_bits >= 0) {
out.push_back(static_cast<std::uint8_t>((value >> value_bits) & 0xFF));
value_bits -= 8;
}
}
return out;
}
std::string DecodeStringSegment(const std::string& input) {
const auto bytes = Base64UrlDecode(input);
return std::string(reinterpret_cast<const char*>(bytes.data()), bytes.size());
}
struct EvpKeyDeleter {
void operator()(EVP_PKEY* key) const {
EVP_PKEY_free(key);
}
};
struct BioDeleter {
void operator()(BIO* bio) const {
BIO_free(bio);
}
};
struct EvpMdCtxDeleter {
void operator()(EVP_MD_CTX* context) const {
EVP_MD_CTX_free(context);
}
};
using EvpKeyPtr = std::unique_ptr<EVP_PKEY, EvpKeyDeleter>;
using BioPtr = std::unique_ptr<BIO, BioDeleter>;
using EvpMdCtxPtr = std::unique_ptr<EVP_MD_CTX, EvpMdCtxDeleter>;
bool VerifyRs256(const std::string& public_key_pem, const std::string& signing_input, const std::vector<std::uint8_t>& signature) {
BioPtr bio(BIO_new_mem_buf(public_key_pem.data(), static_cast<int>(public_key_pem.size())));
if (!bio) {
throw std::runtime_error("public_key_bio_unavailable");
}
EvpKeyPtr key(PEM_read_bio_PUBKEY(bio.get(), nullptr, nullptr, nullptr));
if (!key) {
throw std::runtime_error("public_key_parse_failed");
}
EvpMdCtxPtr context(EVP_MD_CTX_new());
if (!context) {
throw std::runtime_error("signature_context_unavailable");
}
if (EVP_DigestVerifyInit(context.get(), nullptr, EVP_sha256(), nullptr, key.get()) != 1) {
throw std::runtime_error("signature_verify_init_failed");
}
if (EVP_DigestVerifyUpdate(context.get(), signing_input.data(), signing_input.size()) != 1) {
throw std::runtime_error("signature_verify_update_failed");
}
return EVP_DigestVerifyFinal(context.get(), signature.data(), signature.size()) == 1;
}
std::int64_t UnixNow() {
return std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();
}
bool IsKnownChannel(const std::string& channel) {
return channel == "control" ||
channel == "input" ||
channel == "render" ||
channel == "clipboard" ||
channel == "file_upload" ||
channel == "file_download" ||
channel == "telemetry";
}
bool ArrayContainsString(const rdp_worker::common::JsonArray* array, const std::string& expected) {
if (array == nullptr) {
return false;
}
return std::any_of(array->begin(), array->end(), [&](const auto& item) {
return item.IsString() && item.AsString() == expected;
});
}
std::vector<std::string> ParseAllowedChannels(const rdp_worker::common::JsonObject& payload) {
const auto* array = rdp_worker::common::GetArray(payload, "allowed_channels");
if (array == nullptr || array->empty()) {
throw std::runtime_error("allowed_channels is required");
}
std::vector<std::string> channels;
for (const auto& item : *array) {
if (!item.IsString() || !IsKnownChannel(item.AsString())) {
throw std::runtime_error("allowed_channels contains unsupported channel");
}
channels.push_back(item.AsString());
}
return channels;
}
std::string RequiredString(const rdp_worker::common::JsonObject& object, const std::string& key) {
const auto value = rdp_worker::common::GetString(object, key);
if (!value.has_value() || value->empty()) {
throw std::runtime_error(key + " is required");
}
return *value;
}
} // namespace
DataPlaneTokenValidator::DataPlaneTokenValidator(std::string public_key_pem, std::string expected_worker_id)
: public_key_pem_(std::move(public_key_pem)),
expected_worker_id_(std::move(expected_worker_id)) {}
TokenValidationResult DataPlaneTokenValidator::Validate(const std::string& token) const {
TokenValidationResult result{};
try {
if (public_key_pem_.empty()) {
result.reason = "token_public_key_not_configured";
return result;
}
const auto parts = SplitJwt(token);
if (parts.size() != 3 || parts[0].empty() || parts[1].empty() || parts[2].empty()) {
result.reason = "malformed_token";
return result;
}
const auto header = rdp_worker::common::ParseJson(DecodeStringSegment(parts[0])).AsObject();
if (rdp_worker::common::GetString(header, "alg").value_or("") != "RS256" ||
rdp_worker::common::GetString(header, "typ").value_or("JWT") != "JWT") {
result.reason = "unsupported_token_header";
return result;
}
const std::string signing_input = parts[0] + "." + parts[1];
const auto actual_signature = Base64UrlDecode(parts[2]);
if (!VerifyRs256(public_key_pem_, signing_input, actual_signature)) {
result.reason = "invalid_signature";
return result;
}
const auto payload = rdp_worker::common::ParseJson(DecodeStringSegment(parts[1])).AsObject();
DataPlaneTokenClaims claims{};
claims.session_id = RequiredString(payload, "session_id");
claims.attachment_id = RequiredString(payload, "attachment_id");
claims.user_id = RequiredString(payload, "user_id");
claims.organization_id = RequiredString(payload, "organization_id");
claims.worker_id = RequiredString(payload, "worker_id");
claims.resource_id = RequiredString(payload, "resource_id");
claims.jti = RequiredString(payload, "jti");
claims.allowed_channels = ParseAllowedChannels(payload);
claims.expires_at_unix = static_cast<std::int64_t>(rdp_worker::common::GetNumber(payload, "exp").value_or(0));
const auto now = UnixNow();
if (claims.expires_at_unix <= now) {
result.reason = "token_expired";
return result;
}
const auto not_before = static_cast<std::int64_t>(rdp_worker::common::GetNumber(payload, "nbf").value_or(0));
if (not_before > 0 && not_before > now) {
result.reason = "token_not_yet_valid";
return result;
}
if (!expected_worker_id_.empty() && claims.worker_id != expected_worker_id_) {
result.reason = "wrong_worker";
return result;
}
if (!ArrayContainsString(rdp_worker::common::GetArray(payload, "aud"), "rap-data-plane") ||
!ArrayContainsString(rdp_worker::common::GetArray(payload, "aud"), "worker:" + claims.worker_id)) {
result.reason = "invalid_audience";
return result;
}
if (std::find(claims.allowed_channels.begin(), claims.allowed_channels.end(), "control") == claims.allowed_channels.end()) {
result.reason = "control_channel_not_allowed";
return result;
}
result.ok = true;
result.claims = std::move(claims);
return result;
} catch (const std::exception& error) {
result.reason = error.what();
return result;
}
}
} // namespace rdp_worker::dataplane
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,74 @@
#include "rdp_worker/graphics/graphics_adapter.hpp"
#include <utility>
namespace rdp_worker::graphics {
GraphicsAdapter::GraphicsAdapter(GraphicsAdapterPolicy policy)
: policy_(policy) {}
const GraphicsAdapterPolicy& GraphicsAdapter::Policy() const {
return policy_;
}
RenderUpdate GraphicsAdapter::MakeFullBgraFrame(std::uint64_t sequence,
int width,
int height,
int stride,
std::vector<std::uint8_t> pixels,
bool baseline) const {
RenderUpdate update;
update.kind = RenderUpdateKind::kFullBgraFrame;
update.sequence = sequence;
update.desktop_width = width;
update.desktop_height = height;
update.frame_width = width;
update.frame_height = height;
update.stride = stride;
update.region = Rect{0, 0, width, height};
update.payload = std::move(pixels);
update.baseline = baseline;
update.droppable = !baseline;
return update;
}
std::optional<RenderUpdate> GraphicsAdapter::TryMakeBgraRegion(std::uint64_t sequence,
int desktop_width,
int desktop_height,
int stride,
Rect region,
std::vector<std::uint8_t> pixels) const {
if (!RegionAllowed(desktop_width, desktop_height, region)) {
return std::nullopt;
}
RenderUpdate update;
update.kind = RenderUpdateKind::kBgraRegion;
update.sequence = sequence;
update.desktop_width = desktop_width;
update.desktop_height = desktop_height;
update.frame_width = region.width;
update.frame_height = region.height;
update.stride = stride;
update.region = region;
update.payload = std::move(pixels);
update.baseline = false;
update.droppable = true;
return update;
}
bool GraphicsAdapter::RegionAllowed(int desktop_width, int desktop_height, const Rect& region) const {
if (desktop_width <= 0 || desktop_height <= 0 ||
region.x < 0 || region.y < 0 ||
region.width <= 0 || region.height <= 0 ||
region.x + region.width > desktop_width ||
region.y + region.height > desktop_height) {
return false;
}
const long long desktop_area = static_cast<long long>(desktop_width) * static_cast<long long>(desktop_height);
const long long region_area = static_cast<long long>(region.width) * static_cast<long long>(region.height);
return region_area * 100 <= desktop_area * policy_.max_region_area_percent;
}
} // namespace rdp_worker::graphics
@@ -0,0 +1,38 @@
#include "rdp_worker/graphics/render_update.hpp"
namespace rdp_worker::graphics {
const char* RenderUpdateKindName(RenderUpdateKind kind) {
switch (kind) {
case RenderUpdateKind::kFullBgraFrame:
return "full_bgra_frame";
case RenderUpdateKind::kBgraRegion:
return "bgra_region";
case RenderUpdateKind::kSurfaceCreate:
return "surface_create";
case RenderUpdateKind::kSurfaceDelete:
return "surface_delete";
case RenderUpdateKind::kSurfaceBits:
return "surface_bits";
case RenderUpdateKind::kEncodedFrame:
return "encoded_frame";
case RenderUpdateKind::kCursorUpdate:
return "cursor_update";
}
return "unknown";
}
bool IsFullFrameUpdate(const RenderUpdate& update) {
return update.kind == RenderUpdateKind::kFullBgraFrame;
}
bool IsRegionUpdate(const RenderUpdate& update) {
return update.kind == RenderUpdateKind::kBgraRegion;
}
bool IsEncodedUpdate(const RenderUpdate& update) {
return update.kind == RenderUpdateKind::kEncodedFrame ||
update.kind == RenderUpdateKind::kSurfaceBits;
}
} // namespace rdp_worker::graphics
+56
View File
@@ -0,0 +1,56 @@
#include <atomic>
#include <csignal>
#include <memory>
#include <thread>
#include "rdp_worker/common/logger.hpp"
#include "rdp_worker/config/config.hpp"
#include "rdp_worker/coordination/control_plane.hpp"
#include "rdp_worker/dataplane/direct_wss_server.hpp"
#include "rdp_worker/runtime/session_manager.hpp"
namespace {
std::atomic<bool> g_stop{false};
void HandleSignal(int) {
g_stop.store(true);
}
} // namespace
int main() {
std::signal(SIGINT, HandleSignal);
std::signal(SIGTERM, HandleSignal);
auto logger = std::make_shared<rdp_worker::common::Logger>("rdp-worker");
const auto config = rdp_worker::config::LoadFromEnv();
auto control_plane = std::make_shared<rdp_worker::coordination::ControlPlane>(config, logger);
control_plane->Connect();
control_plane->ReleaseOwnedLeasesOnStartup();
control_plane->RegisterWorker();
auto session_manager = std::make_shared<rdp_worker::runtime::SessionManager>(control_plane, logger);
rdp_worker::dataplane::DirectWssServer direct_wss_server(config, session_manager, logger);
direct_wss_server.Start();
std::thread heartbeat_thread([&]() {
while (!g_stop.load()) {
control_plane->SendHeartbeat();
std::this_thread::sleep_for(config.worker_heartbeat_interval);
}
});
while (!g_stop.load()) {
if (auto assignment = control_plane->PollAssignment(config.assignment_poll_interval); assignment.has_value()) {
session_manager->ApplyAssignment(*assignment);
}
}
direct_wss_server.Stop();
session_manager->StopAll();
if (heartbeat_thread.joinable()) {
heartbeat_thread.join();
}
return 0;
}
@@ -0,0 +1,72 @@
#include "rdp_worker/runtime/direct_bind_policy.hpp"
#include <algorithm>
#include <vector>
namespace rdp_worker::runtime {
namespace {
bool ClipboardAllowsServerOrClient(const std::string& mode) {
return mode == "client_to_server" || mode == "server_to_client" || mode == "bidirectional";
}
bool FileTransferAllowsClientToServer(const std::string& mode) {
return mode == "client_to_server" || mode == "bidirectional";
}
bool FileTransferAllowsServerToClient(const std::string& mode) {
return mode == "server_to_client" || mode == "bidirectional";
}
std::vector<std::string> RuntimeAllowedChannels(const Assignment& assignment) {
std::vector<std::string> channels{"control", "input", "render", "telemetry"};
if (ClipboardAllowsServerOrClient(assignment.policy.clipboard_mode)) {
channels.push_back("clipboard");
}
if (FileTransferAllowsClientToServer(assignment.policy.file_transfer_mode)) {
channels.push_back("file_upload");
}
if (FileTransferAllowsServerToClient(assignment.policy.file_transfer_mode)) {
channels.push_back("file_download");
}
return channels;
}
bool RequestedChannelsAllowed(const std::vector<std::string>& requested, const std::vector<std::string>& allowed) {
return std::all_of(requested.begin(), requested.end(), [&](const auto& channel) {
return std::find(allowed.begin(), allowed.end(), channel) != allowed.end();
});
}
} // namespace
DirectBindValidationResult ValidateDirectDataPlaneBind(const Assignment& assignment,
const dataplane::DataPlaneTokenClaims& claims) {
if (assignment.state != SessionState::kStarting &&
assignment.state != SessionState::kActive &&
assignment.state != SessionState::kReconnecting) {
return {false, "session_not_attachable"};
}
if (assignment.worker_id != claims.worker_id) {
return {false, "worker_mismatch"};
}
if (assignment.attachment_id != claims.attachment_id) {
return {false, "attachment_mismatch"};
}
if (assignment.user_id != claims.user_id) {
return {false, "user_mismatch"};
}
if (assignment.organization_id != claims.organization_id) {
return {false, "organization_mismatch"};
}
if (assignment.connection.resource_id != claims.resource_id) {
return {false, "resource_mismatch"};
}
if (!RequestedChannelsAllowed(claims.allowed_channels, RuntimeAllowedChannels(assignment))) {
return {false, "channels_exceed_runtime_policy"};
}
return {true, ""};
}
} // namespace rdp_worker::runtime
@@ -0,0 +1,63 @@
#include "rdp_worker/runtime/session_manager.hpp"
#include "rdp_worker/runtime/direct_bind_policy.hpp"
namespace rdp_worker::runtime {
SessionManager::SessionManager(std::shared_ptr<coordination::ControlPlane> control_plane,
std::shared_ptr<common::Logger> logger)
: control_plane_(std::move(control_plane)),
logger_(std::move(logger)) {}
void SessionManager::ApplyAssignment(const Assignment& assignment) {
std::lock_guard<std::mutex> lock(mutex_);
const auto iterator = sessions_.find(assignment.session_id);
if (iterator != sessions_.end()) {
iterator->second->ApplyAssignment(assignment);
logger_->Info("updated assignment for existing session " + assignment.session_id);
return;
}
auto runtime = std::make_shared<SessionRuntime>(assignment, control_plane_, logger_);
runtime->Start();
sessions_.emplace(assignment.session_id, runtime);
logger_->Info("started new runtime for session " + assignment.session_id);
}
void SessionManager::StopAll() {
std::lock_guard<std::mutex> lock(mutex_);
for (auto& [_, runtime] : sessions_) {
runtime->Stop(true, "worker_shutdown");
}
sessions_.clear();
}
bool SessionManager::BindDirectDataPlaneAttachment(const dataplane::DataPlaneTokenClaims& claims, std::string& reason) {
return BindDirectDataPlaneRuntime(claims, reason) != nullptr;
}
std::shared_ptr<SessionRuntime> SessionManager::BindDirectDataPlaneRuntime(const dataplane::DataPlaneTokenClaims& claims, std::string& reason) {
std::lock_guard<std::mutex> lock(mutex_);
const auto iterator = sessions_.find(claims.session_id);
if (iterator == sessions_.end()) {
reason = "missing_runtime";
return nullptr;
}
const Assignment snapshot = iterator->second->Snapshot();
const auto validation = ValidateDirectDataPlaneBind(snapshot, claims);
if (!validation.ok) {
reason = validation.reason;
return nullptr;
}
reason.clear();
logger_->Info("event=data_plane_bind_success session=" + claims.session_id +
" attachment=" + claims.attachment_id +
" user=" + claims.user_id +
" organization=" + claims.organization_id +
" resource=" + claims.resource_id);
return iterator->second;
}
} // namespace rdp_worker::runtime
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,44 @@
#include <iostream>
#include "rdp_worker/cursor/cursor_adapter.hpp"
#include "rdp_worker/cursor/cursor_update.hpp"
int main() {
rdp_worker::cursor::CursorAdapter adapter;
auto position = adapter.MakePosition(1, 1280, 720, 10, 20, true);
if (position.kind != rdp_worker::cursor::CursorUpdateKind::kPosition ||
!position.visible ||
position.x != 10 ||
position.y != 20) {
std::cerr << "cursor position normalization failed\n";
return 1;
}
auto hidden = adapter.MakeSystem(2, 1280, 720, 10, 20, 0);
if (hidden.kind != rdp_worker::cursor::CursorUpdateKind::kSystem ||
hidden.visible ||
!hidden.shape_changed) {
std::cerr << "cursor system visibility normalization failed\n";
return 1;
}
auto large = adapter.MakeLarge(3, 1280, 720, 5, 6, 32, 32, 7, 1, 2, 32, 128);
if (large.kind != rdp_worker::cursor::CursorUpdateKind::kLarge ||
large.cache_index != 7 ||
large.hot_spot_x != 1 ||
large.hot_spot_y != 2 ||
large.mask_bytes != 128) {
std::cerr << "cursor shape normalization failed\n";
return 1;
}
auto payload = rdp_worker::cursor::CursorUpdateToPayload(large, "balanced");
if (payload.empty()) {
std::cerr << "cursor payload conversion failed\n";
return 1;
}
std::cout << "cursor_adapter_probe ok\n";
return 0;
}
@@ -0,0 +1,102 @@
#include <iostream>
#include <string>
#include "rdp_worker/runtime/direct_bind_policy.hpp"
namespace {
std::string ArgValue(int argc, char** argv, const std::string& name) {
for (int i = 1; i + 1 < argc; ++i) {
if (argv[i] == name) {
return argv[i + 1];
}
}
return {};
}
rdp_worker::runtime::Assignment BaseAssignment() {
rdp_worker::runtime::Assignment assignment{};
assignment.session_id = "session-1";
assignment.worker_id = "rdp-worker-1";
assignment.attachment_id = "attachment-current";
assignment.user_id = "user-1";
assignment.organization_id = "org-1";
assignment.state = rdp_worker::runtime::SessionState::kActive;
assignment.connection.resource_id = "resource-1";
assignment.policy.clipboard_mode = "disabled";
assignment.policy.file_transfer_mode = "disabled";
return assignment;
}
rdp_worker::dataplane::DataPlaneTokenClaims BaseClaims() {
rdp_worker::dataplane::DataPlaneTokenClaims claims{};
claims.session_id = "session-1";
claims.worker_id = "rdp-worker-1";
claims.attachment_id = "attachment-current";
claims.user_id = "user-1";
claims.organization_id = "org-1";
claims.resource_id = "resource-1";
claims.allowed_channels = {"control", "input", "render", "telemetry"};
claims.jti = "jti-1";
claims.expires_at_unix = 4102444800;
return claims;
}
} // namespace
int main(int argc, char** argv) {
const auto scenario = ArgValue(argc, argv, "--scenario");
if (scenario.empty()) {
std::cerr << "usage: rdp-worker-dataplane-bind-probe --scenario valid|starting|wrong-worker|wrong-attachment|wrong-user|wrong-organization|wrong-resource|channels-too-broad|failed-state|terminated-state\n";
return 2;
}
auto assignment = BaseAssignment();
auto claims = BaseClaims();
std::string expected_reason;
if (scenario == "valid") {
expected_reason = "";
} else if (scenario == "starting") {
assignment.state = rdp_worker::runtime::SessionState::kStarting;
expected_reason = "";
} else if (scenario == "wrong-attachment") {
claims.attachment_id = "attachment-old";
expected_reason = "attachment_mismatch";
} else if (scenario == "wrong-worker") {
claims.worker_id = "rdp-worker-other";
expected_reason = "worker_mismatch";
} else if (scenario == "wrong-user") {
claims.user_id = "user-other";
expected_reason = "user_mismatch";
} else if (scenario == "wrong-organization") {
claims.organization_id = "org-other";
expected_reason = "organization_mismatch";
} else if (scenario == "wrong-resource") {
claims.resource_id = "resource-other";
expected_reason = "resource_mismatch";
} else if (scenario == "channels-too-broad") {
claims.allowed_channels.push_back("clipboard");
expected_reason = "channels_exceed_runtime_policy";
} else if (scenario == "failed-state") {
assignment.state = rdp_worker::runtime::SessionState::kFailed;
expected_reason = "session_not_attachable";
} else if (scenario == "terminated-state") {
assignment.state = rdp_worker::runtime::SessionState::kTerminated;
expected_reason = "session_not_attachable";
} else {
std::cerr << "unknown scenario=" << scenario << "\n";
return 2;
}
const auto result = rdp_worker::runtime::ValidateDirectDataPlaneBind(assignment, claims);
if ((scenario == "valid" || scenario == "starting") && result.ok) {
std::cout << "PASS scenario=" << scenario << "\n";
return 0;
}
if (!expected_reason.empty() && !result.ok && result.reason == expected_reason) {
std::cout << "PASS scenario=" << scenario << " reason=" << result.reason << "\n";
return 0;
}
std::cerr << "FAIL scenario=" << scenario << " ok=" << (result.ok ? "true" : "false") << " reason=" << result.reason << "\n";
return 1;
}
@@ -0,0 +1,58 @@
#include <iostream>
#include <fstream>
#include <sstream>
#include <string>
#include "rdp_worker/dataplane/token_validator.hpp"
namespace {
std::string ArgValue(int argc, char** argv, const std::string& name) {
for (int i = 1; i + 1 < argc; ++i) {
if (argv[i] == name) {
return argv[i + 1];
}
}
return {};
}
std::string ReadFile(const std::string& path) {
std::ifstream input(path);
if (!input.good()) {
return {};
}
std::stringstream buffer;
buffer << input.rdbuf();
return buffer.str();
}
} // namespace
int main(int argc, char** argv) {
const auto token = ArgValue(argc, argv, "--token");
auto public_key = ArgValue(argc, argv, "--public-key-pem");
const auto public_key_file = ArgValue(argc, argv, "--public-key-file");
if (public_key.empty() && !public_key_file.empty()) {
public_key = ReadFile(public_key_file);
}
const auto worker_id = ArgValue(argc, argv, "--worker-id");
if (token.empty() || public_key.empty() || worker_id.empty()) {
std::cerr << "usage: rdp-worker-dataplane-token-probe --token <jwt> --public-key-file <public.pem> --worker-id <worker_id>\n";
return 2;
}
rdp_worker::dataplane::DataPlaneTokenValidator validator(public_key, worker_id);
const auto result = validator.Validate(token);
if (!result.ok) {
std::cerr << "FAIL reason=" << result.reason << "\n";
return 1;
}
std::cout << "PASS session_id=" << result.claims.session_id
<< " attachment_id=" << result.claims.attachment_id
<< " worker_id=" << result.claims.worker_id
<< " resource_id=" << result.claims.resource_id
<< " channels=" << result.claims.allowed_channels.size()
<< "\n";
return 0;
}
@@ -0,0 +1,41 @@
#include <iostream>
#include <vector>
#include "rdp_worker/graphics/graphics_adapter.hpp"
int main() {
rdp_worker::graphics::GraphicsAdapter adapter;
auto full = adapter.MakeFullBgraFrame(1, 1280, 720, 5120, std::vector<std::uint8_t>(1280 * 720 * 4), true);
if (!rdp_worker::graphics::IsFullFrameUpdate(full) || !full.baseline || full.droppable) {
std::cerr << "full frame baseline policy failed\n";
return 1;
}
auto region = adapter.TryMakeBgraRegion(
2,
1280,
720,
200 * 4,
rdp_worker::graphics::Rect{10, 20, 200, 100},
std::vector<std::uint8_t>(200 * 100 * 4));
if (!region.has_value() || !rdp_worker::graphics::IsRegionUpdate(*region) || !region->droppable) {
std::cerr << "region policy failed\n";
return 1;
}
auto too_large = adapter.TryMakeBgraRegion(
3,
1280,
720,
1280 * 4,
rdp_worker::graphics::Rect{0, 0, 1280, 720},
std::vector<std::uint8_t>(1280 * 720 * 4));
if (too_large.has_value()) {
std::cerr << "oversized region should be rejected\n";
return 1;
}
std::cout << "graphics_adapter_probe ok\n";
return 0;
}
@@ -0,0 +1,25 @@
#include <iostream>
#include "rdp_worker/adapter/service_adapter_protocol.hpp"
int main() {
using namespace rdp_worker::adapter;
for (const auto& spec : AllChannelSpecs()) {
std::cout << spec.name
<< " direction=" << DirectionName(spec.direction)
<< " reliability=" << ReliabilityName(spec.reliability)
<< " priority=" << PriorityValue(spec.priority)
<< " droppable=" << (spec.stale_updates_droppable ? "true" : "false")
<< " may_block_input=" << (spec.may_block_input ? "true" : "false")
<< '\n';
}
if (!ValidateAdapterChannelInvariants()) {
std::cerr << "adapter channel invariants failed\n";
return 1;
}
return 0;
}