Initial project snapshot
This commit is contained in:
@@ -0,0 +1,104 @@
|
||||
#include "rdp_worker/adapter/adapter_event_router.hpp"
|
||||
|
||||
#include "rdp_worker/common/json.hpp"
|
||||
|
||||
namespace rdp_worker::adapter {
|
||||
|
||||
namespace {
|
||||
|
||||
AdapterEventDescriptor MakeDescriptor(AdapterChannel channel,
|
||||
std::string_view type,
|
||||
bool adapter_origin) {
|
||||
return AdapterEventDescriptor{
|
||||
channel,
|
||||
type,
|
||||
adapter_origin,
|
||||
IsReliable(channel),
|
||||
IsDroppable(channel),
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
AdapterEventDescriptor AdapterEventRouter::DescribeRenderNotification(
|
||||
const runtime::RenderNotification& notification) const {
|
||||
if (notification.type == "session_frame") {
|
||||
const auto update_kind = common::GetString(notification.payload, "frame_update_kind").value_or("full");
|
||||
return MakeDescriptor(
|
||||
AdapterChannel::kDisplay,
|
||||
update_kind == "region" ? "display.region_bgra" : "display.baseline_full_bgra",
|
||||
true);
|
||||
}
|
||||
|
||||
if (notification.type == "session_cursor_updated") {
|
||||
return MakeDescriptor(AdapterChannel::kCursor, "cursor.update", true);
|
||||
}
|
||||
|
||||
if (notification.type == "session_render_resized") {
|
||||
return MakeDescriptor(AdapterChannel::kDisplay, "display.resize", true);
|
||||
}
|
||||
|
||||
if (notification.type == "session_render_dirty") {
|
||||
return MakeDescriptor(AdapterChannel::kDisplay, "display.dirty", true);
|
||||
}
|
||||
|
||||
return MakeDescriptor(AdapterChannel::kTelemetry, notification.type, true);
|
||||
}
|
||||
|
||||
AdapterEventDescriptor AdapterEventRouter::DescribeClipboardNotification(
|
||||
const runtime::ClipboardNotification&) const {
|
||||
return MakeDescriptor(AdapterChannel::kClipboard, "clipboard.server_text", true);
|
||||
}
|
||||
|
||||
AdapterEventDescriptor AdapterEventRouter::DescribeClientEnvelope(std::string_view envelope_type,
|
||||
std::string_view payload_kind,
|
||||
std::string_view payload_action) const {
|
||||
if (envelope_type == "input") {
|
||||
if (payload_kind == "mouse" && payload_action == "move") {
|
||||
return MakeDescriptor(AdapterChannel::kInput, "input.pointer_move", false);
|
||||
}
|
||||
if (payload_kind == "mouse") {
|
||||
return MakeDescriptor(AdapterChannel::kInput, "input.pointer", false);
|
||||
}
|
||||
if (payload_kind == "keyboard") {
|
||||
return MakeDescriptor(AdapterChannel::kInput, "input.keyboard", false);
|
||||
}
|
||||
if (payload_kind == "focus") {
|
||||
return MakeDescriptor(AdapterChannel::kInput, "input.focus", false);
|
||||
}
|
||||
return MakeDescriptor(AdapterChannel::kInput, "input.unknown", false);
|
||||
}
|
||||
|
||||
if (envelope_type == "clipboard") {
|
||||
return MakeDescriptor(AdapterChannel::kClipboard, "clipboard.client_text", false);
|
||||
}
|
||||
|
||||
if (envelope_type == "file_upload") {
|
||||
return MakeDescriptor(AdapterChannel::kFileTransfer, "file_upload.client_to_server", false);
|
||||
}
|
||||
|
||||
if (envelope_type == "control") {
|
||||
return MakeDescriptor(AdapterChannel::kControl, "control.command", false);
|
||||
}
|
||||
|
||||
return MakeDescriptor(AdapterChannel::kTelemetry, envelope_type, false);
|
||||
}
|
||||
|
||||
std::string AdapterEventDescriptorLogLine(const AdapterEventDescriptor& descriptor) {
|
||||
std::string line;
|
||||
line.reserve(160);
|
||||
line += "adapter_event channel=";
|
||||
line += ChannelName(descriptor.channel);
|
||||
line += " type=";
|
||||
line += descriptor.normalized_type;
|
||||
line += " origin=";
|
||||
line += descriptor.adapter_origin ? "adapter" : "client";
|
||||
line += " reliable=";
|
||||
line += descriptor.reliable ? "true" : "false";
|
||||
line += " droppable=";
|
||||
line += descriptor.droppable ? "true" : "false";
|
||||
return line;
|
||||
}
|
||||
|
||||
} // namespace rdp_worker::adapter
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
#include "rdp_worker/adapter/rdp_adapter_runtime.hpp"
|
||||
|
||||
#include <utility>
|
||||
|
||||
namespace rdp_worker::adapter {
|
||||
|
||||
RdpAdapterRuntime::RdpAdapterRuntime(std::shared_ptr<common::Logger> logger)
|
||||
: logger_(std::move(logger)),
|
||||
freerdp_(logger_) {}
|
||||
|
||||
bool RdpAdapterRuntime::Start(const runtime::ConnectionSpec& spec) {
|
||||
logger_->Info("rdp_adapter.runtime_start substrate=freerdp resource_id=" + spec.resource_id +
|
||||
" host=" + spec.host +
|
||||
" render_quality_profile=" + spec.render_quality_profile);
|
||||
lifecycle_logged_ = true;
|
||||
return freerdp_.Start(spec);
|
||||
}
|
||||
|
||||
void RdpAdapterRuntime::Disconnect(bool terminate) {
|
||||
logger_->Info("rdp_adapter.runtime_disconnect terminate=" + (terminate ? std::string("true") : std::string("false")));
|
||||
freerdp_.Disconnect(terminate);
|
||||
}
|
||||
|
||||
bool RdpAdapterRuntime::IsConnected() const {
|
||||
return freerdp_.IsConnected();
|
||||
}
|
||||
|
||||
bool RdpAdapterRuntime::PumpEvents(std::chrono::milliseconds timeout) {
|
||||
if (!lifecycle_logged_) {
|
||||
logger_->Info("rdp_adapter.event_pump_start substrate=freerdp");
|
||||
lifecycle_logged_ = true;
|
||||
}
|
||||
return freerdp_.PumpEvents(timeout);
|
||||
}
|
||||
|
||||
int RdpAdapterRuntime::DesktopWidth() const {
|
||||
return freerdp_.DesktopWidth();
|
||||
}
|
||||
|
||||
int RdpAdapterRuntime::DesktopHeight() const {
|
||||
return freerdp_.DesktopHeight();
|
||||
}
|
||||
|
||||
bool RdpAdapterRuntime::SendFocusEvent(bool focused) {
|
||||
TraceClientEnvelope("input", "focus", focused ? "focus_in" : "focus_out");
|
||||
return freerdp_.SendFocusEvent(focused);
|
||||
}
|
||||
|
||||
bool RdpAdapterRuntime::SendKeyboardInput(uint16_t scan_code, bool key_down, bool extended) {
|
||||
TraceClientEnvelope("input", "keyboard", key_down ? "key_down" : "key_up");
|
||||
return freerdp_.SendKeyboardInput(scan_code, key_down, extended);
|
||||
}
|
||||
|
||||
bool RdpAdapterRuntime::SendMouseMove(double normalized_x, double normalized_y) {
|
||||
TraceClientEnvelope("input", "mouse", "move");
|
||||
return freerdp_.SendMouseMove(normalized_x, normalized_y);
|
||||
}
|
||||
|
||||
bool RdpAdapterRuntime::SendMouseButton(const std::string& button,
|
||||
bool pressed,
|
||||
double normalized_x,
|
||||
double normalized_y) {
|
||||
TraceClientEnvelope("input", "mouse", pressed ? "button_down" : "button_up");
|
||||
return freerdp_.SendMouseButton(button, pressed, normalized_x, normalized_y);
|
||||
}
|
||||
|
||||
bool RdpAdapterRuntime::SendMouseWheel(int wheel_delta, bool horizontal, double normalized_x, double normalized_y) {
|
||||
TraceClientEnvelope("input", "mouse", "wheel");
|
||||
return freerdp_.SendMouseWheel(wheel_delta, horizontal, normalized_x, normalized_y);
|
||||
}
|
||||
|
||||
bool RdpAdapterRuntime::SetClipboardText(const std::string& text) {
|
||||
TraceClientEnvelope("clipboard", "text", "client_to_server");
|
||||
return freerdp_.SetClipboardText(text);
|
||||
}
|
||||
|
||||
void RdpAdapterRuntime::MarkInputAppliedForGraphicsTrace(const std::string& correlation_id) {
|
||||
freerdp_.MarkInputAppliedForGraphicsTrace(correlation_id);
|
||||
}
|
||||
|
||||
std::optional<runtime::RenderNotification> RdpAdapterRuntime::CaptureFullFrameNotification(
|
||||
const std::string& state,
|
||||
const std::string& capture_source) {
|
||||
auto notification = freerdp_.CaptureFullFrameNotification(state, capture_source);
|
||||
if (notification.has_value()) {
|
||||
TraceAdapterEvent(event_router_.DescribeRenderNotification(*notification));
|
||||
}
|
||||
return notification;
|
||||
}
|
||||
|
||||
std::optional<runtime::RenderNotification> RdpAdapterRuntime::PopRenderNotification() {
|
||||
auto notification = freerdp_.PopRenderNotification();
|
||||
if (notification.has_value()) {
|
||||
TraceAdapterEvent(event_router_.DescribeRenderNotification(*notification));
|
||||
}
|
||||
return notification;
|
||||
}
|
||||
|
||||
std::optional<runtime::ClipboardNotification> RdpAdapterRuntime::PopClipboardNotification() {
|
||||
auto notification = freerdp_.PopClipboardNotification();
|
||||
if (notification.has_value()) {
|
||||
TraceAdapterEvent(event_router_.DescribeClipboardNotification(*notification));
|
||||
}
|
||||
return notification;
|
||||
}
|
||||
|
||||
const std::string& RdpAdapterRuntime::RenderQualityProfile() const {
|
||||
return freerdp_.RenderQualityProfile();
|
||||
}
|
||||
|
||||
const AdapterEventRouter& RdpAdapterRuntime::EventRouter() const {
|
||||
return event_router_;
|
||||
}
|
||||
|
||||
void RdpAdapterRuntime::TraceClientEnvelope(std::string_view envelope_type,
|
||||
std::string_view payload_kind,
|
||||
std::string_view payload_action) {
|
||||
const auto descriptor = event_router_.DescribeClientEnvelope(envelope_type, payload_kind, payload_action);
|
||||
if (descriptor.channel == AdapterChannel::kInput && descriptor.normalized_type == "input.pointer_move") {
|
||||
return;
|
||||
}
|
||||
TraceAdapterEvent(descriptor);
|
||||
}
|
||||
|
||||
void RdpAdapterRuntime::TraceAdapterEvent(const AdapterEventDescriptor& descriptor) {
|
||||
if (descriptor.channel == AdapterChannel::kDisplay &&
|
||||
descriptor.normalized_type != "display.resize" &&
|
||||
descriptor.normalized_type != "display.baseline_full_bgra") {
|
||||
return;
|
||||
}
|
||||
logger_->Info(AdapterEventDescriptorLogLine(descriptor));
|
||||
}
|
||||
|
||||
} // namespace rdp_worker::adapter
|
||||
@@ -0,0 +1,97 @@
|
||||
#include "rdp_worker/adapter/service_adapter_protocol.hpp"
|
||||
|
||||
namespace rdp_worker::adapter {
|
||||
|
||||
std::optional<ChannelSpec> FindChannelSpec(std::string_view name) {
|
||||
for (const auto& spec : AllChannelSpecs()) {
|
||||
if (spec.name == name) {
|
||||
return spec;
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::string_view ChannelName(AdapterChannel channel) {
|
||||
for (const auto& spec : AllChannelSpecs()) {
|
||||
if (spec.channel == channel) {
|
||||
return spec.name;
|
||||
}
|
||||
}
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
std::string_view DirectionName(ChannelDirection direction) {
|
||||
switch (direction) {
|
||||
case ChannelDirection::kClientToAdapter:
|
||||
return "client_to_adapter";
|
||||
case ChannelDirection::kAdapterToClient:
|
||||
return "adapter_to_client";
|
||||
case ChannelDirection::kBidirectional:
|
||||
return "bidirectional";
|
||||
}
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
std::string_view ReliabilityName(ChannelReliability reliability) {
|
||||
switch (reliability) {
|
||||
case ChannelReliability::kReliableOrdered:
|
||||
return "reliable_ordered";
|
||||
case ChannelReliability::kReliableChunked:
|
||||
return "reliable_chunked";
|
||||
case ChannelReliability::kDroppableLatest:
|
||||
return "droppable_latest";
|
||||
case ChannelReliability::kAdaptiveDroppable:
|
||||
return "adaptive_droppable";
|
||||
case ChannelReliability::kSampledDroppable:
|
||||
return "sampled_droppable";
|
||||
}
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
int PriorityValue(ChannelPriority priority) {
|
||||
return static_cast<int>(priority);
|
||||
}
|
||||
|
||||
bool IsDroppable(AdapterChannel channel) {
|
||||
for (const auto& spec : AllChannelSpecs()) {
|
||||
if (spec.channel == channel) {
|
||||
return spec.stale_updates_droppable;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsReliable(AdapterChannel channel) {
|
||||
for (const auto& spec : AllChannelSpecs()) {
|
||||
if (spec.channel == channel) {
|
||||
return spec.reliability == ChannelReliability::kReliableOrdered ||
|
||||
spec.reliability == ChannelReliability::kReliableChunked;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ValidateAdapterChannelInvariants() {
|
||||
const auto input = FindChannelSpec("input");
|
||||
if (!input.has_value() || input->priority != ChannelPriority::kCritical || input->may_block_input) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const auto& spec : AllChannelSpecs()) {
|
||||
if (spec.channel != AdapterChannel::kInput &&
|
||||
PriorityValue(spec.priority) <= PriorityValue(ChannelPriority::kCritical)) {
|
||||
return false;
|
||||
}
|
||||
if (spec.may_block_input) {
|
||||
return false;
|
||||
}
|
||||
if ((spec.channel == AdapterChannel::kDisplay || spec.channel == AdapterChannel::kCursor) &&
|
||||
!spec.stale_updates_droppable) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace rdp_worker::adapter
|
||||
|
||||
@@ -0,0 +1,401 @@
|
||||
#include "rdp_worker/common/json.hpp"
|
||||
|
||||
#include <cctype>
|
||||
#include <cstdint>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace rdp_worker::common {
|
||||
|
||||
JsonValue::JsonValue() : value(nullptr) {}
|
||||
JsonValue::JsonValue(std::nullptr_t) : value(nullptr) {}
|
||||
JsonValue::JsonValue(bool input) : value(input) {}
|
||||
JsonValue::JsonValue(double input) : value(input) {}
|
||||
JsonValue::JsonValue(int input) : value(static_cast<double>(input)) {}
|
||||
JsonValue::JsonValue(const char* input) : value(std::string(input)) {}
|
||||
JsonValue::JsonValue(std::string input) : value(std::move(input)) {}
|
||||
JsonValue::JsonValue(JsonArray input) : value(std::move(input)) {}
|
||||
JsonValue::JsonValue(JsonObject input) : value(std::move(input)) {}
|
||||
|
||||
bool JsonValue::IsObject() const { return std::holds_alternative<JsonObject>(value); }
|
||||
bool JsonValue::IsArray() const { return std::holds_alternative<JsonArray>(value); }
|
||||
bool JsonValue::IsString() const { return std::holds_alternative<std::string>(value); }
|
||||
bool JsonValue::IsBool() const { return std::holds_alternative<bool>(value); }
|
||||
bool JsonValue::IsNumber() const { return std::holds_alternative<double>(value); }
|
||||
const JsonObject& JsonValue::AsObject() const { return std::get<JsonObject>(value); }
|
||||
const JsonArray& JsonValue::AsArray() const { return std::get<JsonArray>(value); }
|
||||
const std::string& JsonValue::AsString() const { return std::get<std::string>(value); }
|
||||
bool JsonValue::AsBool() const { return std::get<bool>(value); }
|
||||
double JsonValue::AsNumber() const { return std::get<double>(value); }
|
||||
|
||||
namespace {
|
||||
|
||||
void AppendUtf8(std::string& output, std::uint32_t codepoint) {
|
||||
if (codepoint <= 0x7F) {
|
||||
output.push_back(static_cast<char>(codepoint));
|
||||
} else if (codepoint <= 0x7FF) {
|
||||
output.push_back(static_cast<char>(0xC0 | (codepoint >> 6)));
|
||||
output.push_back(static_cast<char>(0x80 | (codepoint & 0x3F)));
|
||||
} else if (codepoint <= 0xFFFF) {
|
||||
output.push_back(static_cast<char>(0xE0 | (codepoint >> 12)));
|
||||
output.push_back(static_cast<char>(0x80 | ((codepoint >> 6) & 0x3F)));
|
||||
output.push_back(static_cast<char>(0x80 | (codepoint & 0x3F)));
|
||||
} else if (codepoint <= 0x10FFFF) {
|
||||
output.push_back(static_cast<char>(0xF0 | (codepoint >> 18)));
|
||||
output.push_back(static_cast<char>(0x80 | ((codepoint >> 12) & 0x3F)));
|
||||
output.push_back(static_cast<char>(0x80 | ((codepoint >> 6) & 0x3F)));
|
||||
output.push_back(static_cast<char>(0x80 | (codepoint & 0x3F)));
|
||||
} else {
|
||||
throw std::runtime_error("invalid JSON unicode escape");
|
||||
}
|
||||
}
|
||||
|
||||
class Parser {
|
||||
public:
|
||||
explicit Parser(const std::string& input) : input_(input), index_(0) {}
|
||||
|
||||
JsonValue Parse() {
|
||||
SkipWhitespace();
|
||||
JsonValue value = ParseValue();
|
||||
SkipWhitespace();
|
||||
if (index_ != input_.size()) {
|
||||
throw std::runtime_error("unexpected trailing JSON data");
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
private:
|
||||
JsonValue ParseValue() {
|
||||
SkipWhitespace();
|
||||
if (Match("null")) {
|
||||
return JsonValue(nullptr);
|
||||
}
|
||||
if (Match("true")) {
|
||||
return JsonValue(true);
|
||||
}
|
||||
if (Match("false")) {
|
||||
return JsonValue(false);
|
||||
}
|
||||
if (Peek() == '"') {
|
||||
return JsonValue(ParseString());
|
||||
}
|
||||
if (Peek() == '{') {
|
||||
return JsonValue(ParseObject());
|
||||
}
|
||||
if (Peek() == '[') {
|
||||
return JsonValue(ParseArray());
|
||||
}
|
||||
if (Peek() == '-' || std::isdigit(static_cast<unsigned char>(Peek()))) {
|
||||
return JsonValue(ParseNumber());
|
||||
}
|
||||
throw std::runtime_error("unexpected JSON token");
|
||||
}
|
||||
|
||||
JsonObject ParseObject() {
|
||||
Expect('{');
|
||||
JsonObject object;
|
||||
SkipWhitespace();
|
||||
if (Peek() == '}') {
|
||||
Advance();
|
||||
return object;
|
||||
}
|
||||
while (true) {
|
||||
const std::string key = ParseString();
|
||||
SkipWhitespace();
|
||||
Expect(':');
|
||||
object.emplace(key, ParseValue());
|
||||
SkipWhitespace();
|
||||
if (Peek() == '}') {
|
||||
Advance();
|
||||
break;
|
||||
}
|
||||
Expect(',');
|
||||
}
|
||||
return object;
|
||||
}
|
||||
|
||||
JsonArray ParseArray() {
|
||||
Expect('[');
|
||||
JsonArray array;
|
||||
SkipWhitespace();
|
||||
if (Peek() == ']') {
|
||||
Advance();
|
||||
return array;
|
||||
}
|
||||
while (true) {
|
||||
array.emplace_back(ParseValue());
|
||||
SkipWhitespace();
|
||||
if (Peek() == ']') {
|
||||
Advance();
|
||||
break;
|
||||
}
|
||||
Expect(',');
|
||||
}
|
||||
return array;
|
||||
}
|
||||
|
||||
std::string ParseString() {
|
||||
Expect('"');
|
||||
std::string output;
|
||||
while (index_ < input_.size()) {
|
||||
const char current = input_[index_++];
|
||||
if (current == '"') {
|
||||
return output;
|
||||
}
|
||||
if (current == '\\') {
|
||||
if (index_ >= input_.size()) {
|
||||
throw std::runtime_error("invalid JSON escape");
|
||||
}
|
||||
const char escaped = input_[index_++];
|
||||
switch (escaped) {
|
||||
case '"':
|
||||
case '\\':
|
||||
case '/':
|
||||
output.push_back(escaped);
|
||||
break;
|
||||
case 'b':
|
||||
output.push_back('\b');
|
||||
break;
|
||||
case 'f':
|
||||
output.push_back('\f');
|
||||
break;
|
||||
case 'n':
|
||||
output.push_back('\n');
|
||||
break;
|
||||
case 'r':
|
||||
output.push_back('\r');
|
||||
break;
|
||||
case 't':
|
||||
output.push_back('\t');
|
||||
break;
|
||||
case 'u':
|
||||
AppendUtf8(output, ParseUnicodeEscape());
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("unsupported JSON escape");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
output.push_back(current);
|
||||
}
|
||||
throw std::runtime_error("unterminated JSON string");
|
||||
}
|
||||
|
||||
std::uint32_t ParseUnicodeEscape() {
|
||||
const std::uint32_t first = ParseHexQuad();
|
||||
if (first >= 0xD800 && first <= 0xDBFF) {
|
||||
if (index_ + 1 >= input_.size() || input_[index_] != '\\' || input_[index_ + 1] != 'u') {
|
||||
throw std::runtime_error("invalid JSON unicode surrogate");
|
||||
}
|
||||
index_ += 2;
|
||||
const std::uint32_t second = ParseHexQuad();
|
||||
if (second < 0xDC00 || second > 0xDFFF) {
|
||||
throw std::runtime_error("invalid JSON unicode surrogate");
|
||||
}
|
||||
return 0x10000 + (((first - 0xD800) << 10) | (second - 0xDC00));
|
||||
}
|
||||
if (first >= 0xDC00 && first <= 0xDFFF) {
|
||||
throw std::runtime_error("invalid JSON unicode surrogate");
|
||||
}
|
||||
return first;
|
||||
}
|
||||
|
||||
std::uint32_t ParseHexQuad() {
|
||||
if (index_ + 4 > input_.size()) {
|
||||
throw std::runtime_error("invalid JSON unicode escape");
|
||||
}
|
||||
std::uint32_t value = 0;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
const char ch = input_[index_++];
|
||||
value <<= 4;
|
||||
if (ch >= '0' && ch <= '9') {
|
||||
value |= static_cast<std::uint32_t>(ch - '0');
|
||||
} else if (ch >= 'a' && ch <= 'f') {
|
||||
value |= static_cast<std::uint32_t>(ch - 'a' + 10);
|
||||
} else if (ch >= 'A' && ch <= 'F') {
|
||||
value |= static_cast<std::uint32_t>(ch - 'A' + 10);
|
||||
} else {
|
||||
throw std::runtime_error("invalid JSON unicode escape");
|
||||
}
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
double ParseNumber() {
|
||||
const std::size_t start = index_;
|
||||
if (Peek() == '-') {
|
||||
Advance();
|
||||
}
|
||||
while (std::isdigit(static_cast<unsigned char>(Peek()))) {
|
||||
Advance();
|
||||
}
|
||||
if (Peek() == '.') {
|
||||
Advance();
|
||||
while (std::isdigit(static_cast<unsigned char>(Peek()))) {
|
||||
Advance();
|
||||
}
|
||||
}
|
||||
return std::stod(input_.substr(start, index_ - start));
|
||||
}
|
||||
|
||||
void SkipWhitespace() {
|
||||
while (index_ < input_.size() && std::isspace(static_cast<unsigned char>(input_[index_]))) {
|
||||
++index_;
|
||||
}
|
||||
}
|
||||
|
||||
char Peek() const {
|
||||
if (index_ >= input_.size()) {
|
||||
return '\0';
|
||||
}
|
||||
return input_[index_];
|
||||
}
|
||||
|
||||
void Advance() {
|
||||
if (index_ < input_.size()) {
|
||||
++index_;
|
||||
}
|
||||
}
|
||||
|
||||
void Expect(char expected) {
|
||||
SkipWhitespace();
|
||||
if (Peek() != expected) {
|
||||
throw std::runtime_error("unexpected JSON character");
|
||||
}
|
||||
Advance();
|
||||
}
|
||||
|
||||
bool Match(const char* keyword) {
|
||||
const std::size_t length = std::char_traits<char>::length(keyword);
|
||||
if (input_.substr(index_, length) == keyword) {
|
||||
index_ += length;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const std::string& input_;
|
||||
std::size_t index_;
|
||||
};
|
||||
|
||||
std::string Escape(const std::string& value) {
|
||||
std::ostringstream output;
|
||||
for (const char ch : value) {
|
||||
switch (ch) {
|
||||
case '"':
|
||||
output << "\\\"";
|
||||
break;
|
||||
case '\\':
|
||||
output << "\\\\";
|
||||
break;
|
||||
case '\n':
|
||||
output << "\\n";
|
||||
break;
|
||||
case '\r':
|
||||
output << "\\r";
|
||||
break;
|
||||
case '\t':
|
||||
output << "\\t";
|
||||
break;
|
||||
default:
|
||||
output << ch;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return output.str();
|
||||
}
|
||||
|
||||
std::string SerializeInternal(const JsonValue& value) {
|
||||
if (std::holds_alternative<std::nullptr_t>(value.value)) {
|
||||
return "null";
|
||||
}
|
||||
if (std::holds_alternative<bool>(value.value)) {
|
||||
return std::get<bool>(value.value) ? "true" : "false";
|
||||
}
|
||||
if (std::holds_alternative<double>(value.value)) {
|
||||
std::ostringstream output;
|
||||
output << std::get<double>(value.value);
|
||||
return output.str();
|
||||
}
|
||||
if (std::holds_alternative<std::string>(value.value)) {
|
||||
return "\"" + Escape(std::get<std::string>(value.value)) + "\"";
|
||||
}
|
||||
if (std::holds_alternative<JsonArray>(value.value)) {
|
||||
std::ostringstream output;
|
||||
output << "[";
|
||||
const auto& array = std::get<JsonArray>(value.value);
|
||||
for (std::size_t i = 0; i < array.size(); ++i) {
|
||||
if (i > 0) {
|
||||
output << ",";
|
||||
}
|
||||
output << SerializeInternal(array[i]);
|
||||
}
|
||||
output << "]";
|
||||
return output.str();
|
||||
}
|
||||
std::ostringstream output;
|
||||
output << "{";
|
||||
const auto& object = std::get<JsonObject>(value.value);
|
||||
bool first = true;
|
||||
for (const auto& [key, child] : object) {
|
||||
if (!first) {
|
||||
output << ",";
|
||||
}
|
||||
first = false;
|
||||
output << "\"" << Escape(key) << "\":" << SerializeInternal(child);
|
||||
}
|
||||
output << "}";
|
||||
return output.str();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
JsonValue ParseJson(const std::string& input) {
|
||||
return Parser(input).Parse();
|
||||
}
|
||||
|
||||
std::string SerializeJson(const JsonValue& value) {
|
||||
return SerializeInternal(value);
|
||||
}
|
||||
|
||||
std::optional<std::string> GetString(const JsonObject& object, const std::string& key) {
|
||||
auto iterator = object.find(key);
|
||||
if (iterator == object.end() || !iterator->second.IsString()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return iterator->second.AsString();
|
||||
}
|
||||
|
||||
std::optional<bool> GetBool(const JsonObject& object, const std::string& key) {
|
||||
auto iterator = object.find(key);
|
||||
if (iterator == object.end() || !iterator->second.IsBool()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return iterator->second.AsBool();
|
||||
}
|
||||
|
||||
std::optional<double> GetNumber(const JsonObject& object, const std::string& key) {
|
||||
auto iterator = object.find(key);
|
||||
if (iterator == object.end() || !iterator->second.IsNumber()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return iterator->second.AsNumber();
|
||||
}
|
||||
|
||||
const JsonObject* GetObject(const JsonObject& object, const std::string& key) {
|
||||
auto iterator = object.find(key);
|
||||
if (iterator == object.end() || !iterator->second.IsObject()) {
|
||||
return nullptr;
|
||||
}
|
||||
return &iterator->second.AsObject();
|
||||
}
|
||||
|
||||
const JsonArray* GetArray(const JsonObject& object, const std::string& key) {
|
||||
auto iterator = object.find(key);
|
||||
if (iterator == object.end() || !iterator->second.IsArray()) {
|
||||
return nullptr;
|
||||
}
|
||||
return &iterator->second.AsArray();
|
||||
}
|
||||
|
||||
} // namespace rdp_worker::common
|
||||
@@ -0,0 +1,52 @@
|
||||
#include "rdp_worker/common/logger.hpp"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "rdp_worker/common/time.hpp"
|
||||
|
||||
namespace rdp_worker::common {
|
||||
|
||||
namespace {
|
||||
|
||||
std::string LevelToString(LogLevel level) {
|
||||
switch (level) {
|
||||
case LogLevel::kDebug:
|
||||
return "DEBUG";
|
||||
case LogLevel::kInfo:
|
||||
return "INFO";
|
||||
case LogLevel::kWarn:
|
||||
return "WARN";
|
||||
case LogLevel::kError:
|
||||
return "ERROR";
|
||||
}
|
||||
return "INFO";
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Logger::Logger(std::string service_name) : service_name_(std::move(service_name)) {}
|
||||
|
||||
void Logger::Debug(const std::string& message) {
|
||||
Write(LogLevel::kDebug, message);
|
||||
}
|
||||
|
||||
void Logger::Info(const std::string& message) {
|
||||
Write(LogLevel::kInfo, message);
|
||||
}
|
||||
|
||||
void Logger::Warn(const std::string& message) {
|
||||
Write(LogLevel::kWarn, message);
|
||||
}
|
||||
|
||||
void Logger::Error(const std::string& message) {
|
||||
Write(LogLevel::kError, message);
|
||||
}
|
||||
|
||||
void Logger::Write(LogLevel level, const std::string& message) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
std::cout << "{\"ts\":\"" << ToRfc3339(NowUtc()) << "\",\"service\":\"" << service_name_
|
||||
<< "\",\"level\":\"" << LevelToString(level) << "\",\"message\":\"" << message
|
||||
<< "\"}" << std::endl;
|
||||
}
|
||||
|
||||
} // namespace rdp_worker::common
|
||||
@@ -0,0 +1,42 @@
|
||||
#include "rdp_worker/common/time.hpp"
|
||||
|
||||
#include <ctime>
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace rdp_worker::common {
|
||||
|
||||
Clock::time_point NowUtc() {
|
||||
return Clock::now();
|
||||
}
|
||||
|
||||
std::string ToRfc3339(Clock::time_point time_point) {
|
||||
const std::time_t raw = Clock::to_time_t(time_point);
|
||||
std::tm utc_tm{};
|
||||
#if defined(_WIN32)
|
||||
gmtime_s(&utc_tm, &raw);
|
||||
#else
|
||||
gmtime_r(&raw, &utc_tm);
|
||||
#endif
|
||||
std::ostringstream output;
|
||||
output << std::put_time(&utc_tm, "%Y-%m-%dT%H:%M:%SZ");
|
||||
return output.str();
|
||||
}
|
||||
|
||||
Clock::time_point ParseRfc3339(const std::string& value) {
|
||||
std::tm utc_tm{};
|
||||
std::istringstream input(value);
|
||||
input >> std::get_time(&utc_tm, "%Y-%m-%dT%H:%M:%SZ");
|
||||
if (input.fail()) {
|
||||
throw std::runtime_error("failed to parse RFC3339 timestamp: " + value);
|
||||
}
|
||||
#if defined(_WIN32)
|
||||
const std::time_t raw = _mkgmtime(&utc_tm);
|
||||
#else
|
||||
const std::time_t raw = timegm(&utc_tm);
|
||||
#endif
|
||||
return Clock::from_time_t(raw);
|
||||
}
|
||||
|
||||
} // namespace rdp_worker::common
|
||||
@@ -0,0 +1,89 @@
|
||||
#include "rdp_worker/config/config.hpp"
|
||||
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace rdp_worker::config {
|
||||
|
||||
namespace {
|
||||
|
||||
std::string GetEnvOrDefault(const char* key, const char* fallback) {
|
||||
const char* value = std::getenv(key);
|
||||
if (value == nullptr || std::string(value).empty()) {
|
||||
return fallback;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
int GetInt(const char* key, int fallback) {
|
||||
const char* value = std::getenv(key);
|
||||
if (value == nullptr || std::string(value).empty()) {
|
||||
return fallback;
|
||||
}
|
||||
return std::stoi(value);
|
||||
}
|
||||
|
||||
bool GetBool(const char* key, bool fallback) {
|
||||
const char* value = std::getenv(key);
|
||||
if (value == nullptr || std::string(value).empty()) {
|
||||
return fallback;
|
||||
}
|
||||
const std::string raw(value);
|
||||
return raw == "1" || raw == "true" || raw == "TRUE" || raw == "yes" || raw == "on";
|
||||
}
|
||||
|
||||
std::string ReadFileIfConfigured(const std::string& path) {
|
||||
if (path.empty()) {
|
||||
return "";
|
||||
}
|
||||
std::ifstream input(path);
|
||||
if (!input.good()) {
|
||||
throw std::runtime_error("failed to read file " + path);
|
||||
}
|
||||
std::stringstream buffer;
|
||||
buffer << input.rdbuf();
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
std::vector<std::string> Split(const std::string& value) {
|
||||
std::vector<std::string> parts;
|
||||
std::stringstream stream(value);
|
||||
std::string part;
|
||||
while (std::getline(stream, part, ',')) {
|
||||
if (!part.empty()) {
|
||||
parts.push_back(part);
|
||||
}
|
||||
}
|
||||
return parts;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Config LoadFromEnv() {
|
||||
Config config{};
|
||||
config.worker_id = GetEnvOrDefault("RDP_WORKER_ID", "rdp-worker-1");
|
||||
config.redis_host = GetEnvOrDefault("RDP_WORKER_REDIS_HOST", "127.0.0.1");
|
||||
config.redis_port = GetInt("RDP_WORKER_REDIS_PORT", 6379);
|
||||
config.redis_password = GetEnvOrDefault("RDP_WORKER_REDIS_PASSWORD", "");
|
||||
config.redis_db = GetInt("RDP_WORKER_REDIS_DB", 0);
|
||||
config.worker_heartbeat_interval = std::chrono::seconds(GetInt("RDP_WORKER_HEARTBEAT_INTERVAL_SECONDS", 5));
|
||||
config.lease_renew_interval = std::chrono::seconds(GetInt("RDP_WORKER_LEASE_RENEW_INTERVAL_SECONDS", 10));
|
||||
config.assignment_poll_interval = std::chrono::seconds(GetInt("RDP_WORKER_ASSIGNMENT_POLL_INTERVAL_SECONDS", 2));
|
||||
config.insecure_skip_verify = GetBool("RDP_WORKER_INSECURE_SKIP_VERIFY", false);
|
||||
config.capabilities = Split(GetEnvOrDefault("RDP_WORKER_CAPABILITIES", "adaptive-quality,dirty-rects,clipboard,file-transfer"));
|
||||
config.data_plane_enabled = GetBool("RDP_WORKER_DATA_PLANE_ENABLED", false);
|
||||
config.data_plane_listen_host = GetEnvOrDefault("RDP_WORKER_DATA_PLANE_LISTEN_HOST", "0.0.0.0");
|
||||
config.data_plane_listen_port = GetInt("RDP_WORKER_DATA_PLANE_LISTEN_PORT", 8443);
|
||||
config.data_plane_public_key_pem = GetEnvOrDefault("RDP_WORKER_DATA_PLANE_PUBLIC_KEY_PEM", "");
|
||||
config.data_plane_public_key_file = GetEnvOrDefault("RDP_WORKER_DATA_PLANE_PUBLIC_KEY_FILE", "");
|
||||
if (config.data_plane_public_key_pem.empty()) {
|
||||
config.data_plane_public_key_pem = ReadFileIfConfigured(config.data_plane_public_key_file);
|
||||
}
|
||||
config.data_plane_tls_cert_file = GetEnvOrDefault("RDP_WORKER_DATA_PLANE_TLS_CERT_FILE", "");
|
||||
config.data_plane_tls_key_file = GetEnvOrDefault("RDP_WORKER_DATA_PLANE_TLS_KEY_FILE", "");
|
||||
return config;
|
||||
}
|
||||
|
||||
} // namespace rdp_worker::config
|
||||
@@ -0,0 +1,300 @@
|
||||
#include "rdp_worker/coordination/control_plane.hpp"
|
||||
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "rdp_worker/common/logger.hpp"
|
||||
#include "rdp_worker/common/time.hpp"
|
||||
|
||||
namespace rdp_worker::coordination {
|
||||
|
||||
using common::GetArray;
|
||||
using common::GetBool;
|
||||
using common::GetNumber;
|
||||
using common::GetObject;
|
||||
using common::GetString;
|
||||
using common::JsonArray;
|
||||
using common::JsonObject;
|
||||
using common::JsonValue;
|
||||
|
||||
ControlPlane::ControlPlane(config::Config config, std::shared_ptr<common::Logger> logger)
|
||||
: config_(std::move(config)),
|
||||
logger_(std::move(logger)),
|
||||
redis_(std::make_unique<RedisClient>(config_.redis_host, config_.redis_port, config_.redis_password, config_.redis_db)) {}
|
||||
|
||||
void ControlPlane::Connect() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
redis_->Connect();
|
||||
}
|
||||
|
||||
void ControlPlane::RegisterWorker() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
redis_->Set("worker:registration:" + config_.worker_id, WorkerRegistrationPayload(), config_.worker_heartbeat_interval * 3);
|
||||
redis_->SAdd("worker:registrations", config_.worker_id);
|
||||
}
|
||||
|
||||
void ControlPlane::ReleaseOwnedLeasesOnStartup() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
int released = 0;
|
||||
for (const auto& lease_id : redis_->SMembers("worker:leases")) {
|
||||
auto encoded = redis_->Get("worker:lease:" + lease_id);
|
||||
if (!encoded.has_value()) {
|
||||
redis_->SRem("worker:leases", lease_id);
|
||||
continue;
|
||||
}
|
||||
auto lease = ParseLease(common::ParseJson(*encoded).AsObject());
|
||||
if (lease.worker_id != config_.worker_id) {
|
||||
continue;
|
||||
}
|
||||
redis_->Delete("worker:lease:" + lease_id);
|
||||
redis_->SRem("worker:leases", lease_id);
|
||||
if (!lease.session_id.empty()) {
|
||||
redis_->Delete("worker:session-lease:" + lease.session_id);
|
||||
redis_->Delete("worker:queue:" + lease.session_id);
|
||||
}
|
||||
++released;
|
||||
}
|
||||
if (released > 0) {
|
||||
logger_->Warn("released stale owned worker leases on startup worker=" + config_.worker_id +
|
||||
" released_count=" + std::to_string(released));
|
||||
}
|
||||
}
|
||||
|
||||
void ControlPlane::SendHeartbeat() {
|
||||
RegisterWorker();
|
||||
}
|
||||
|
||||
std::optional<runtime::Assignment> ControlPlane::PollAssignment(std::chrono::seconds timeout) {
|
||||
RedisClient stream_client(config_.redis_host, config_.redis_port, config_.redis_password, config_.redis_db);
|
||||
stream_client.Connect();
|
||||
auto entry = stream_client.BLPop("worker:control:" + config_.worker_id, timeout);
|
||||
if (!entry.has_value() || entry->size() != 2) {
|
||||
return std::nullopt;
|
||||
}
|
||||
const JsonObject object = common::ParseJson((*entry)[1]).AsObject();
|
||||
return ParseAssignment(object);
|
||||
}
|
||||
|
||||
std::optional<common::JsonObject> ControlPlane::PollSessionEnvelope(const std::string& session_id, std::chrono::seconds timeout) {
|
||||
RedisClient stream_client(config_.redis_host, config_.redis_port, config_.redis_password, config_.redis_db);
|
||||
stream_client.Connect();
|
||||
auto entry = stream_client.BLPop("worker:queue:" + session_id, timeout);
|
||||
if (!entry.has_value() || entry->size() != 2) {
|
||||
return std::nullopt;
|
||||
}
|
||||
auto object = common::ParseJson((*entry)[1]).AsObject();
|
||||
const JsonObject* payload = GetObject(object, "payload");
|
||||
if (payload != nullptr) {
|
||||
const std::string type = GetString(object, "type").value_or("");
|
||||
const std::string correlation_id = GetString(*payload, "correlation_id").value_or("");
|
||||
if (type == "input" && !correlation_id.empty()) {
|
||||
logger_->Info("input.trace worker_queue_pop session=" + session_id +
|
||||
" correlation_id=" + correlation_id +
|
||||
" trace_stage=worker_queue_pop");
|
||||
}
|
||||
}
|
||||
return object;
|
||||
}
|
||||
|
||||
std::vector<common::JsonObject> ControlPlane::DrainSessionEnvelopes(const std::string& session_id, std::size_t max_count) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
std::vector<common::JsonObject> output;
|
||||
output.reserve(max_count);
|
||||
const std::string key = "worker:queue:" + session_id;
|
||||
for (std::size_t i = 0; i < max_count; ++i) {
|
||||
auto encoded = redis_->LPop(key);
|
||||
if (!encoded.has_value()) {
|
||||
break;
|
||||
}
|
||||
auto object = common::ParseJson(*encoded).AsObject();
|
||||
const JsonObject* payload = GetObject(object, "payload");
|
||||
if (payload != nullptr) {
|
||||
const std::string type = GetString(object, "type").value_or("");
|
||||
const std::string correlation_id = GetString(*payload, "correlation_id").value_or("");
|
||||
if (type == "input" && !correlation_id.empty()) {
|
||||
logger_->Info("input.trace worker_queue_pop session=" + session_id +
|
||||
" correlation_id=" + correlation_id +
|
||||
" trace_stage=worker_queue_pop");
|
||||
}
|
||||
}
|
||||
output.push_back(std::move(object));
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
int64_t ControlPlane::SessionEnvelopeQueueLength(const std::string& session_id) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return redis_->LLen("worker:queue:" + session_id);
|
||||
}
|
||||
|
||||
std::optional<runtime::WorkerLease> ControlPlane::GetLeaseBySession(const std::string& session_id) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
auto lease_id = redis_->Get("worker:session-lease:" + session_id);
|
||||
if (!lease_id.has_value()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
auto encoded = redis_->Get("worker:lease:" + *lease_id);
|
||||
if (!encoded.has_value()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return ParseLease(common::ParseJson(*encoded).AsObject());
|
||||
}
|
||||
|
||||
void ControlPlane::RenewLease(const runtime::WorkerLease& lease) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
redis_->Set("worker:lease:" + lease.lease_id, LeasePayload(lease), std::chrono::seconds(45));
|
||||
redis_->Set("worker:session-lease:" + lease.session_id, lease.lease_id, std::chrono::seconds(45));
|
||||
}
|
||||
|
||||
void ControlPlane::ReleaseLease(const runtime::WorkerLease& lease) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
redis_->Delete("worker:lease:" + lease.lease_id);
|
||||
redis_->Delete("worker:session-lease:" + lease.session_id);
|
||||
}
|
||||
|
||||
void ControlPlane::PublishEvent(const runtime::WorkerEvent& event) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
const std::string encoded = EventPayload(event);
|
||||
redis_->RPush("worker:events", encoded);
|
||||
redis_->Expire("worker:events", std::chrono::minutes(10));
|
||||
}
|
||||
|
||||
runtime::Assignment ControlPlane::ParseAssignment(const JsonObject& object) const {
|
||||
runtime::Assignment assignment{};
|
||||
assignment.session_id = GetString(object, "session_id").value_or("");
|
||||
assignment.worker_id = GetString(object, "worker_id").value_or("");
|
||||
assignment.attachment_id = GetString(object, "attachment_id").value_or("");
|
||||
assignment.user_id = GetString(object, "user_id").value_or("");
|
||||
assignment.device_id = GetString(object, "device_id").value_or("");
|
||||
assignment.takeover_of = GetString(object, "takeover_of");
|
||||
const std::string state = GetString(object, "state").value_or("starting");
|
||||
if (state == "active") {
|
||||
assignment.state = runtime::SessionState::kActive;
|
||||
} else if (state == "detached") {
|
||||
assignment.state = runtime::SessionState::kDetached;
|
||||
} else if (state == "reconnecting") {
|
||||
assignment.state = runtime::SessionState::kReconnecting;
|
||||
} else {
|
||||
assignment.state = runtime::SessionState::kStarting;
|
||||
}
|
||||
|
||||
const JsonObject* metadata = GetObject(object, "metadata");
|
||||
if (metadata == nullptr) {
|
||||
throw std::runtime_error("assignment metadata is required");
|
||||
}
|
||||
const JsonObject* resource = GetObject(*metadata, "resource");
|
||||
if (resource == nullptr) {
|
||||
throw std::runtime_error("assignment resource metadata is required");
|
||||
}
|
||||
assignment.organization_id = GetString(*resource, "organization_id").value_or("");
|
||||
assignment.connection.resource_id = GetString(*resource, "id").value_or("");
|
||||
assignment.connection.resource_name = GetString(*resource, "name").value_or("");
|
||||
assignment.connection.host = GetString(*resource, "address").value_or("");
|
||||
assignment.connection.port = 3389;
|
||||
assignment.connection.username = "";
|
||||
assignment.connection.password = "";
|
||||
assignment.connection.domain = "";
|
||||
assignment.connection.certificate_verification_mode = GetString(*resource, "certificate_verification_mode").value_or("strict");
|
||||
assignment.connection.render_quality_profile = GetString(*resource, "render_quality_profile").value_or("balanced");
|
||||
assignment.connection.insecure_skip_verify = config_.insecure_skip_verify;
|
||||
const JsonObject* resource_meta = GetObject(*resource, "metadata");
|
||||
if (resource_meta != nullptr) {
|
||||
assignment.connection.host = GetString(*resource_meta, "rdp_host").value_or(assignment.connection.host);
|
||||
assignment.connection.port = static_cast<uint16_t>(GetNumber(*resource_meta, "rdp_port").value_or(3389));
|
||||
assignment.connection.username = GetString(*resource_meta, "username").value_or("");
|
||||
assignment.connection.password = GetString(*resource_meta, "password").value_or("");
|
||||
assignment.connection.domain = GetString(*resource_meta, "domain").value_or("");
|
||||
assignment.connection.certificate_verification_mode =
|
||||
GetString(*resource_meta, "certificate_verification_mode").value_or(assignment.connection.certificate_verification_mode);
|
||||
assignment.connection.render_quality_profile =
|
||||
GetString(*resource_meta, "render_quality_profile").value_or(assignment.connection.render_quality_profile);
|
||||
}
|
||||
const JsonObject* policy = GetObject(*metadata, "policy");
|
||||
if (policy != nullptr) {
|
||||
assignment.policy.detach_grace_period = std::chrono::seconds(static_cast<int>(GetNumber(*policy, "detach_grace_period_seconds").value_or(1800)));
|
||||
assignment.policy.clipboard_mode = GetString(*policy, "clipboard_mode").value_or("disabled");
|
||||
if (assignment.policy.clipboard_mode.empty()) {
|
||||
assignment.policy.clipboard_mode = GetBool(*policy, "clipboard_enabled").value_or(false) ? "bidirectional" : "disabled";
|
||||
}
|
||||
assignment.policy.file_transfer_mode = GetString(*policy, "file_transfer_mode").value_or("disabled");
|
||||
if (assignment.policy.file_transfer_mode.empty()) {
|
||||
assignment.policy.file_transfer_mode = GetBool(*policy, "file_transfer_enabled").value_or(false) ? "client_to_server" : "disabled";
|
||||
}
|
||||
}
|
||||
return assignment;
|
||||
}
|
||||
|
||||
runtime::WorkerLease ControlPlane::ParseLease(const JsonObject& object) const {
|
||||
runtime::WorkerLease lease{};
|
||||
lease.lease_id = GetString(object, "lease_id").value_or("");
|
||||
lease.worker_id = GetString(object, "worker_id").value_or("");
|
||||
lease.session_id = GetString(object, "session_id").value_or("");
|
||||
lease.resource_id = GetString(object, "resource_id").value_or("");
|
||||
lease.control_stream = GetString(object, "control_stream").value_or("");
|
||||
lease.expires_at = GetString(object, "expires_at").value_or("");
|
||||
if (const JsonArray* capabilities = GetArray(object, "capabilities"); capabilities != nullptr) {
|
||||
for (const auto& item : *capabilities) {
|
||||
if (item.IsString()) {
|
||||
lease.capabilities.push_back(item.AsString());
|
||||
}
|
||||
}
|
||||
}
|
||||
return lease;
|
||||
}
|
||||
|
||||
std::string ControlPlane::WorkerRegistrationPayload() const {
|
||||
JsonArray capabilities;
|
||||
for (const auto& item : config_.capabilities) {
|
||||
capabilities.emplace_back(item);
|
||||
}
|
||||
return common::SerializeJson(JsonObject{
|
||||
{"worker_id", config_.worker_id},
|
||||
{"protocol", "rdp"},
|
||||
{"status", "online"},
|
||||
{"capabilities", capabilities},
|
||||
{"control_stream", "worker://control/" + config_.worker_id},
|
||||
{"last_heartbeat_at", common::ToRfc3339(common::NowUtc())},
|
||||
});
|
||||
}
|
||||
|
||||
std::string ControlPlane::LeasePayload(const runtime::WorkerLease& lease) const {
|
||||
JsonArray capabilities;
|
||||
for (const auto& capability : lease.capabilities) {
|
||||
capabilities.emplace_back(capability);
|
||||
}
|
||||
return common::SerializeJson(JsonObject{
|
||||
{"lease_id", lease.lease_id},
|
||||
{"worker_id", lease.worker_id},
|
||||
{"protocol", "rdp"},
|
||||
{"resource_id", lease.resource_id},
|
||||
{"session_id", lease.session_id},
|
||||
{"capabilities", capabilities},
|
||||
{"control_stream", lease.control_stream},
|
||||
{"expires_at", common::ToRfc3339(common::NowUtc() + std::chrono::seconds(45))},
|
||||
});
|
||||
}
|
||||
|
||||
std::string ControlPlane::EventPayload(const runtime::WorkerEvent& event) const {
|
||||
JsonObject payload{
|
||||
{"type", event.type},
|
||||
{"session_id", event.session_id},
|
||||
{"worker_id", event.worker_id},
|
||||
};
|
||||
JsonObject detail;
|
||||
if (!event.reason.empty()) {
|
||||
detail.emplace("reason", event.reason);
|
||||
}
|
||||
if (event.width > 0) {
|
||||
detail.emplace("width", event.width);
|
||||
}
|
||||
if (event.height > 0) {
|
||||
detail.emplace("height", event.height);
|
||||
}
|
||||
for (const auto& [key, value] : event.payload) {
|
||||
detail[key] = value;
|
||||
}
|
||||
payload.emplace("payload", detail);
|
||||
return common::SerializeJson(payload);
|
||||
}
|
||||
|
||||
} // namespace rdp_worker::coordination
|
||||
@@ -0,0 +1,264 @@
|
||||
#include "rdp_worker/coordination/redis_client.hpp"
|
||||
|
||||
#include <cstring>
|
||||
#include <stdexcept>
|
||||
#include <string_view>
|
||||
|
||||
#if defined(_WIN32)
|
||||
#include <winsock2.h>
|
||||
#include <ws2tcpip.h>
|
||||
#else
|
||||
#include <arpa/inet.h>
|
||||
#include <netdb.h>
|
||||
#include <sys/socket.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
namespace rdp_worker::coordination {
|
||||
|
||||
namespace {
|
||||
|
||||
void EnsureSuccess(bool condition, const std::string& message) {
|
||||
if (!condition) {
|
||||
throw std::runtime_error(message);
|
||||
}
|
||||
}
|
||||
|
||||
void CloseSocket(int socket_fd) {
|
||||
if (socket_fd < 0) {
|
||||
return;
|
||||
}
|
||||
#if defined(_WIN32)
|
||||
closesocket(socket_fd);
|
||||
#else
|
||||
close(socket_fd);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RedisReply::IsNull() const { return std::holds_alternative<std::nullptr_t>(value); }
|
||||
bool RedisReply::IsString() const { return std::holds_alternative<std::string>(value); }
|
||||
bool RedisReply::IsInteger() const { return std::holds_alternative<int64_t>(value); }
|
||||
bool RedisReply::IsArray() const { return std::holds_alternative<Array>(value); }
|
||||
const std::string& RedisReply::AsString() const { return std::get<std::string>(value); }
|
||||
int64_t RedisReply::AsInteger() const { return std::get<int64_t>(value); }
|
||||
const RedisReply::Array& RedisReply::AsArray() const { return std::get<Array>(value); }
|
||||
|
||||
RedisClient::RedisClient(std::string host, int port, std::string password, int db)
|
||||
: host_(std::move(host)),
|
||||
port_(port),
|
||||
password_(std::move(password)),
|
||||
db_(db),
|
||||
socket_fd_(-1) {}
|
||||
|
||||
RedisClient::~RedisClient() {
|
||||
Close();
|
||||
}
|
||||
|
||||
void RedisClient::Connect() {
|
||||
#if defined(_WIN32)
|
||||
WSADATA wsa_data{};
|
||||
EnsureSuccess(WSAStartup(MAKEWORD(2, 2), &wsa_data) == 0, "WSAStartup failed");
|
||||
#endif
|
||||
addrinfo hints{};
|
||||
hints.ai_family = AF_UNSPEC;
|
||||
hints.ai_socktype = SOCK_STREAM;
|
||||
|
||||
addrinfo* result = nullptr;
|
||||
EnsureSuccess(getaddrinfo(host_.c_str(), std::to_string(port_).c_str(), &hints, &result) == 0, "getaddrinfo failed");
|
||||
|
||||
for (addrinfo* node = result; node != nullptr; node = node->ai_next) {
|
||||
socket_fd_ = static_cast<int>(socket(node->ai_family, node->ai_socktype, node->ai_protocol));
|
||||
if (socket_fd_ < 0) {
|
||||
continue;
|
||||
}
|
||||
if (connect(socket_fd_, node->ai_addr, static_cast<int>(node->ai_addrlen)) == 0) {
|
||||
break;
|
||||
}
|
||||
CloseSocket(socket_fd_);
|
||||
socket_fd_ = -1;
|
||||
}
|
||||
freeaddrinfo(result);
|
||||
EnsureSuccess(socket_fd_ >= 0, "failed to connect to Redis");
|
||||
|
||||
if (!password_.empty()) {
|
||||
Command({"AUTH", password_});
|
||||
}
|
||||
if (db_ != 0) {
|
||||
Command({"SELECT", std::to_string(db_)});
|
||||
}
|
||||
}
|
||||
|
||||
void RedisClient::Close() {
|
||||
CloseSocket(socket_fd_);
|
||||
socket_fd_ = -1;
|
||||
#if defined(_WIN32)
|
||||
WSACleanup();
|
||||
#endif
|
||||
}
|
||||
|
||||
RedisReply RedisClient::Command(const std::vector<std::string>& parts) {
|
||||
WriteAll(EncodeCommand(parts));
|
||||
return ReadReply();
|
||||
}
|
||||
|
||||
std::optional<std::string> RedisClient::Get(const std::string& key) {
|
||||
RedisReply reply = Command({"GET", key});
|
||||
if (reply.IsNull()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return reply.AsString();
|
||||
}
|
||||
|
||||
void RedisClient::Set(const std::string& key, const std::string& value, std::chrono::seconds ttl) {
|
||||
Command({"SET", key, value, "EX", std::to_string(ttl.count())});
|
||||
}
|
||||
|
||||
void RedisClient::SAdd(const std::string& key, const std::string& value) {
|
||||
Command({"SADD", key, value});
|
||||
}
|
||||
|
||||
void RedisClient::SRem(const std::string& key, const std::string& value) {
|
||||
Command({"SREM", key, value});
|
||||
}
|
||||
|
||||
std::vector<std::string> RedisClient::SMembers(const std::string& key) {
|
||||
RedisReply reply = Command({"SMEMBERS", key});
|
||||
std::vector<std::string> output;
|
||||
if (!reply.IsArray()) {
|
||||
return output;
|
||||
}
|
||||
for (const auto& item : reply.AsArray()) {
|
||||
if (item.IsString()) {
|
||||
output.push_back(item.AsString());
|
||||
}
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
std::optional<std::vector<std::string>> RedisClient::BLPop(const std::string& key, std::chrono::seconds timeout) {
|
||||
RedisReply reply = Command({"BLPOP", key, std::to_string(timeout.count())});
|
||||
if (reply.IsNull()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
std::vector<std::string> output;
|
||||
for (const auto& item : reply.AsArray()) {
|
||||
if (item.IsString()) {
|
||||
output.push_back(item.AsString());
|
||||
}
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
std::optional<std::string> RedisClient::LPop(const std::string& key) {
|
||||
RedisReply reply = Command({"LPOP", key});
|
||||
if (reply.IsNull()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return reply.AsString();
|
||||
}
|
||||
|
||||
int64_t RedisClient::LLen(const std::string& key) {
|
||||
RedisReply reply = Command({"LLEN", key});
|
||||
if (!reply.IsInteger()) {
|
||||
return 0;
|
||||
}
|
||||
return reply.AsInteger();
|
||||
}
|
||||
|
||||
void RedisClient::RPush(const std::string& key, const std::string& value) {
|
||||
Command({"RPUSH", key, value});
|
||||
}
|
||||
|
||||
void RedisClient::Expire(const std::string& key, std::chrono::seconds ttl) {
|
||||
Command({"EXPIRE", key, std::to_string(ttl.count())});
|
||||
}
|
||||
|
||||
void RedisClient::Delete(const std::string& key) {
|
||||
Command({"DEL", key});
|
||||
}
|
||||
|
||||
std::string RedisClient::ReadLine() {
|
||||
std::string output;
|
||||
char ch = '\0';
|
||||
while (true) {
|
||||
const int received = recv(socket_fd_, &ch, 1, 0);
|
||||
EnsureSuccess(received == 1, "failed to read from Redis");
|
||||
if (ch == '\r') {
|
||||
char lf = '\0';
|
||||
EnsureSuccess(recv(socket_fd_, &lf, 1, 0) == 1 && lf == '\n', "invalid Redis line ending");
|
||||
break;
|
||||
}
|
||||
output.push_back(ch);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
std::string RedisClient::ReadBytes(std::size_t count) {
|
||||
std::string output(count, '\0');
|
||||
std::size_t offset = 0;
|
||||
while (offset < count) {
|
||||
const int received = recv(socket_fd_, output.data() + offset, static_cast<int>(count - offset), 0);
|
||||
EnsureSuccess(received > 0, "failed to read bulk Redis payload");
|
||||
offset += static_cast<std::size_t>(received);
|
||||
}
|
||||
char suffix[2];
|
||||
EnsureSuccess(recv(socket_fd_, suffix, 2, 0) == 2 && suffix[0] == '\r' && suffix[1] == '\n', "invalid Redis bulk suffix");
|
||||
return output;
|
||||
}
|
||||
|
||||
RedisReply RedisClient::ReadReply() {
|
||||
const std::string line = ReadLine();
|
||||
EnsureSuccess(!line.empty(), "empty Redis reply");
|
||||
|
||||
const char prefix = line[0];
|
||||
const std::string payload = line.substr(1);
|
||||
switch (prefix) {
|
||||
case '+':
|
||||
return RedisReply{payload};
|
||||
case ':':
|
||||
return RedisReply{std::stoll(payload)};
|
||||
case '$': {
|
||||
const long long size = std::stoll(payload);
|
||||
if (size < 0) {
|
||||
return RedisReply{nullptr};
|
||||
}
|
||||
return RedisReply{ReadBytes(static_cast<std::size_t>(size))};
|
||||
}
|
||||
case '*': {
|
||||
const long long size = std::stoll(payload);
|
||||
if (size < 0) {
|
||||
return RedisReply{nullptr};
|
||||
}
|
||||
RedisReply::Array values;
|
||||
for (long long i = 0; i < size; ++i) {
|
||||
values.push_back(ReadReply());
|
||||
}
|
||||
return RedisReply{values};
|
||||
}
|
||||
case '-':
|
||||
throw std::runtime_error("Redis error: " + payload);
|
||||
default:
|
||||
throw std::runtime_error("unknown Redis reply type");
|
||||
}
|
||||
}
|
||||
|
||||
void RedisClient::WriteAll(const std::string& data) {
|
||||
std::size_t offset = 0;
|
||||
while (offset < data.size()) {
|
||||
const int sent = send(socket_fd_, data.data() + offset, static_cast<int>(data.size() - offset), 0);
|
||||
EnsureSuccess(sent > 0, "failed to send Redis command");
|
||||
offset += static_cast<std::size_t>(sent);
|
||||
}
|
||||
}
|
||||
|
||||
std::string RedisClient::EncodeCommand(const std::vector<std::string>& parts) const {
|
||||
std::string output = "*" + std::to_string(parts.size()) + "\r\n";
|
||||
for (const auto& part : parts) {
|
||||
output += "$" + std::to_string(part.size()) + "\r\n" + part + "\r\n";
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
} // namespace rdp_worker::coordination
|
||||
@@ -0,0 +1,102 @@
|
||||
#include "rdp_worker/cursor/cursor_adapter.hpp"
|
||||
|
||||
namespace rdp_worker::cursor {
|
||||
|
||||
CursorUpdate CursorAdapter::MakePosition(std::uint64_t sequence,
|
||||
int desktop_width,
|
||||
int desktop_height,
|
||||
int x,
|
||||
int y,
|
||||
bool visible) const {
|
||||
CursorUpdate update;
|
||||
update.kind = CursorUpdateKind::kPosition;
|
||||
update.sequence = sequence;
|
||||
update.desktop_width = desktop_width;
|
||||
update.desktop_height = desktop_height;
|
||||
update.x = x;
|
||||
update.y = y;
|
||||
update.visible = visible;
|
||||
return update;
|
||||
}
|
||||
|
||||
CursorUpdate CursorAdapter::MakeSystem(std::uint64_t sequence,
|
||||
int desktop_width,
|
||||
int desktop_height,
|
||||
int x,
|
||||
int y,
|
||||
std::uint32_t system_type) const {
|
||||
CursorUpdate update = MakePosition(sequence, desktop_width, desktop_height, x, y, system_type != 0);
|
||||
update.kind = CursorUpdateKind::kSystem;
|
||||
update.shape_changed = true;
|
||||
update.system_type = system_type;
|
||||
return update;
|
||||
}
|
||||
|
||||
CursorUpdate CursorAdapter::MakeColor(std::uint64_t sequence,
|
||||
int desktop_width,
|
||||
int desktop_height,
|
||||
int x,
|
||||
int y,
|
||||
int width,
|
||||
int height,
|
||||
int cache_index,
|
||||
std::uint64_t mask_bytes) const {
|
||||
CursorUpdate update = MakePosition(sequence, desktop_width, desktop_height, x, y, true);
|
||||
update.kind = CursorUpdateKind::kColor;
|
||||
update.shape_changed = true;
|
||||
update.width = width;
|
||||
update.height = height;
|
||||
update.cache_index = cache_index;
|
||||
update.mask_bytes = mask_bytes;
|
||||
return update;
|
||||
}
|
||||
|
||||
CursorUpdate CursorAdapter::MakeNew(std::uint64_t sequence,
|
||||
int desktop_width,
|
||||
int desktop_height,
|
||||
int x,
|
||||
int y,
|
||||
int width,
|
||||
int height,
|
||||
int cache_index,
|
||||
int xor_bpp,
|
||||
std::uint64_t mask_bytes) const {
|
||||
CursorUpdate update = MakeColor(sequence, desktop_width, desktop_height, x, y, width, height, cache_index, mask_bytes);
|
||||
update.kind = CursorUpdateKind::kNew;
|
||||
update.xor_bpp = xor_bpp;
|
||||
return update;
|
||||
}
|
||||
|
||||
CursorUpdate CursorAdapter::MakeCached(std::uint64_t sequence,
|
||||
int desktop_width,
|
||||
int desktop_height,
|
||||
int x,
|
||||
int y,
|
||||
int cache_index) const {
|
||||
CursorUpdate update = MakePosition(sequence, desktop_width, desktop_height, x, y, true);
|
||||
update.kind = CursorUpdateKind::kCached;
|
||||
update.shape_changed = true;
|
||||
update.cache_index = cache_index;
|
||||
return update;
|
||||
}
|
||||
|
||||
CursorUpdate CursorAdapter::MakeLarge(std::uint64_t sequence,
|
||||
int desktop_width,
|
||||
int desktop_height,
|
||||
int x,
|
||||
int y,
|
||||
int width,
|
||||
int height,
|
||||
int cache_index,
|
||||
int hot_spot_x,
|
||||
int hot_spot_y,
|
||||
int xor_bpp,
|
||||
std::uint64_t mask_bytes) const {
|
||||
CursorUpdate update = MakeNew(sequence, desktop_width, desktop_height, x, y, width, height, cache_index, xor_bpp, mask_bytes);
|
||||
update.kind = CursorUpdateKind::kLarge;
|
||||
update.hot_spot_x = hot_spot_x;
|
||||
update.hot_spot_y = hot_spot_y;
|
||||
return update;
|
||||
}
|
||||
|
||||
} // namespace rdp_worker::cursor
|
||||
@@ -0,0 +1,48 @@
|
||||
#include "rdp_worker/cursor/cursor_update.hpp"
|
||||
|
||||
namespace rdp_worker::cursor {
|
||||
|
||||
const char* CursorUpdateKindName(CursorUpdateKind kind) {
|
||||
switch (kind) {
|
||||
case CursorUpdateKind::kPosition:
|
||||
return "position";
|
||||
case CursorUpdateKind::kSystem:
|
||||
return "system";
|
||||
case CursorUpdateKind::kColor:
|
||||
return "color";
|
||||
case CursorUpdateKind::kNew:
|
||||
return "new";
|
||||
case CursorUpdateKind::kCached:
|
||||
return "cached";
|
||||
case CursorUpdateKind::kLarge:
|
||||
return "large";
|
||||
}
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
common::JsonObject CursorUpdateToPayload(const CursorUpdate& update,
|
||||
const std::string& render_quality_profile) {
|
||||
return common::JsonObject{
|
||||
{"render_quality_profile", render_quality_profile},
|
||||
{"render_state", "cursor"},
|
||||
{"cursor_update_kind", CursorUpdateKindName(update.kind)},
|
||||
{"cursor_sequence", static_cast<double>(update.sequence)},
|
||||
{"width", update.desktop_width},
|
||||
{"height", update.desktop_height},
|
||||
{"cursor_x", update.x},
|
||||
{"cursor_y", update.y},
|
||||
{"cursor_visible", update.visible},
|
||||
{"cursor_shape_changed", update.shape_changed},
|
||||
{"cursor_cache_index", update.cache_index},
|
||||
{"cursor_hot_spot_x", update.hot_spot_x},
|
||||
{"cursor_hot_spot_y", update.hot_spot_y},
|
||||
{"cursor_width", update.width},
|
||||
{"cursor_height", update.height},
|
||||
{"cursor_xor_bpp", update.xor_bpp},
|
||||
{"cursor_mask_bytes", static_cast<double>(update.mask_bytes)},
|
||||
{"cursor_system_type", static_cast<double>(update.system_type)},
|
||||
{"dirty_rectangles", 0},
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace rdp_worker::cursor
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,240 @@
|
||||
#include "rdp_worker/dataplane/token_validator.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <chrono>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <string_view>
|
||||
#include <utility>
|
||||
|
||||
#include <openssl/bio.h>
|
||||
#include <openssl/evp.h>
|
||||
#include <openssl/pem.h>
|
||||
|
||||
#include "rdp_worker/common/json.hpp"
|
||||
|
||||
namespace rdp_worker::dataplane {
|
||||
|
||||
namespace {
|
||||
|
||||
std::vector<std::string> SplitJwt(const std::string& token) {
|
||||
std::vector<std::string> parts;
|
||||
std::size_t start = 0;
|
||||
while (true) {
|
||||
const std::size_t dot = token.find('.', start);
|
||||
parts.push_back(token.substr(start, dot == std::string::npos ? std::string::npos : dot - start));
|
||||
if (dot == std::string::npos) {
|
||||
break;
|
||||
}
|
||||
start = dot + 1;
|
||||
}
|
||||
return parts;
|
||||
}
|
||||
|
||||
std::vector<std::uint8_t> Base64UrlDecode(const std::string& input) {
|
||||
static constexpr unsigned char kInvalid = 255;
|
||||
std::array<unsigned char, 256> table{};
|
||||
table.fill(kInvalid);
|
||||
const std::string alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
|
||||
for (std::size_t i = 0; i < alphabet.size(); ++i) {
|
||||
table[static_cast<unsigned char>(alphabet[i])] = static_cast<unsigned char>(i);
|
||||
}
|
||||
|
||||
std::vector<std::uint8_t> out;
|
||||
int value = 0;
|
||||
int value_bits = -8;
|
||||
for (unsigned char ch : input) {
|
||||
if (ch == '=') {
|
||||
break;
|
||||
}
|
||||
if (table[ch] == kInvalid) {
|
||||
throw std::runtime_error("invalid base64url token segment");
|
||||
}
|
||||
value = (value << 6) + table[ch];
|
||||
value_bits += 6;
|
||||
if (value_bits >= 0) {
|
||||
out.push_back(static_cast<std::uint8_t>((value >> value_bits) & 0xFF));
|
||||
value_bits -= 8;
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
std::string DecodeStringSegment(const std::string& input) {
|
||||
const auto bytes = Base64UrlDecode(input);
|
||||
return std::string(reinterpret_cast<const char*>(bytes.data()), bytes.size());
|
||||
}
|
||||
|
||||
struct EvpKeyDeleter {
|
||||
void operator()(EVP_PKEY* key) const {
|
||||
EVP_PKEY_free(key);
|
||||
}
|
||||
};
|
||||
|
||||
struct BioDeleter {
|
||||
void operator()(BIO* bio) const {
|
||||
BIO_free(bio);
|
||||
}
|
||||
};
|
||||
|
||||
struct EvpMdCtxDeleter {
|
||||
void operator()(EVP_MD_CTX* context) const {
|
||||
EVP_MD_CTX_free(context);
|
||||
}
|
||||
};
|
||||
|
||||
using EvpKeyPtr = std::unique_ptr<EVP_PKEY, EvpKeyDeleter>;
|
||||
using BioPtr = std::unique_ptr<BIO, BioDeleter>;
|
||||
using EvpMdCtxPtr = std::unique_ptr<EVP_MD_CTX, EvpMdCtxDeleter>;
|
||||
|
||||
bool VerifyRs256(const std::string& public_key_pem, const std::string& signing_input, const std::vector<std::uint8_t>& signature) {
|
||||
BioPtr bio(BIO_new_mem_buf(public_key_pem.data(), static_cast<int>(public_key_pem.size())));
|
||||
if (!bio) {
|
||||
throw std::runtime_error("public_key_bio_unavailable");
|
||||
}
|
||||
EvpKeyPtr key(PEM_read_bio_PUBKEY(bio.get(), nullptr, nullptr, nullptr));
|
||||
if (!key) {
|
||||
throw std::runtime_error("public_key_parse_failed");
|
||||
}
|
||||
EvpMdCtxPtr context(EVP_MD_CTX_new());
|
||||
if (!context) {
|
||||
throw std::runtime_error("signature_context_unavailable");
|
||||
}
|
||||
if (EVP_DigestVerifyInit(context.get(), nullptr, EVP_sha256(), nullptr, key.get()) != 1) {
|
||||
throw std::runtime_error("signature_verify_init_failed");
|
||||
}
|
||||
if (EVP_DigestVerifyUpdate(context.get(), signing_input.data(), signing_input.size()) != 1) {
|
||||
throw std::runtime_error("signature_verify_update_failed");
|
||||
}
|
||||
return EVP_DigestVerifyFinal(context.get(), signature.data(), signature.size()) == 1;
|
||||
}
|
||||
|
||||
std::int64_t UnixNow() {
|
||||
return std::chrono::duration_cast<std::chrono::seconds>(
|
||||
std::chrono::system_clock::now().time_since_epoch())
|
||||
.count();
|
||||
}
|
||||
|
||||
bool IsKnownChannel(const std::string& channel) {
|
||||
return channel == "control" ||
|
||||
channel == "input" ||
|
||||
channel == "render" ||
|
||||
channel == "clipboard" ||
|
||||
channel == "file_upload" ||
|
||||
channel == "file_download" ||
|
||||
channel == "telemetry";
|
||||
}
|
||||
|
||||
bool ArrayContainsString(const rdp_worker::common::JsonArray* array, const std::string& expected) {
|
||||
if (array == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return std::any_of(array->begin(), array->end(), [&](const auto& item) {
|
||||
return item.IsString() && item.AsString() == expected;
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<std::string> ParseAllowedChannels(const rdp_worker::common::JsonObject& payload) {
|
||||
const auto* array = rdp_worker::common::GetArray(payload, "allowed_channels");
|
||||
if (array == nullptr || array->empty()) {
|
||||
throw std::runtime_error("allowed_channels is required");
|
||||
}
|
||||
std::vector<std::string> channels;
|
||||
for (const auto& item : *array) {
|
||||
if (!item.IsString() || !IsKnownChannel(item.AsString())) {
|
||||
throw std::runtime_error("allowed_channels contains unsupported channel");
|
||||
}
|
||||
channels.push_back(item.AsString());
|
||||
}
|
||||
return channels;
|
||||
}
|
||||
|
||||
std::string RequiredString(const rdp_worker::common::JsonObject& object, const std::string& key) {
|
||||
const auto value = rdp_worker::common::GetString(object, key);
|
||||
if (!value.has_value() || value->empty()) {
|
||||
throw std::runtime_error(key + " is required");
|
||||
}
|
||||
return *value;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
DataPlaneTokenValidator::DataPlaneTokenValidator(std::string public_key_pem, std::string expected_worker_id)
|
||||
: public_key_pem_(std::move(public_key_pem)),
|
||||
expected_worker_id_(std::move(expected_worker_id)) {}
|
||||
|
||||
TokenValidationResult DataPlaneTokenValidator::Validate(const std::string& token) const {
|
||||
TokenValidationResult result{};
|
||||
try {
|
||||
if (public_key_pem_.empty()) {
|
||||
result.reason = "token_public_key_not_configured";
|
||||
return result;
|
||||
}
|
||||
const auto parts = SplitJwt(token);
|
||||
if (parts.size() != 3 || parts[0].empty() || parts[1].empty() || parts[2].empty()) {
|
||||
result.reason = "malformed_token";
|
||||
return result;
|
||||
}
|
||||
|
||||
const auto header = rdp_worker::common::ParseJson(DecodeStringSegment(parts[0])).AsObject();
|
||||
if (rdp_worker::common::GetString(header, "alg").value_or("") != "RS256" ||
|
||||
rdp_worker::common::GetString(header, "typ").value_or("JWT") != "JWT") {
|
||||
result.reason = "unsupported_token_header";
|
||||
return result;
|
||||
}
|
||||
|
||||
const std::string signing_input = parts[0] + "." + parts[1];
|
||||
const auto actual_signature = Base64UrlDecode(parts[2]);
|
||||
if (!VerifyRs256(public_key_pem_, signing_input, actual_signature)) {
|
||||
result.reason = "invalid_signature";
|
||||
return result;
|
||||
}
|
||||
|
||||
const auto payload = rdp_worker::common::ParseJson(DecodeStringSegment(parts[1])).AsObject();
|
||||
DataPlaneTokenClaims claims{};
|
||||
claims.session_id = RequiredString(payload, "session_id");
|
||||
claims.attachment_id = RequiredString(payload, "attachment_id");
|
||||
claims.user_id = RequiredString(payload, "user_id");
|
||||
claims.organization_id = RequiredString(payload, "organization_id");
|
||||
claims.worker_id = RequiredString(payload, "worker_id");
|
||||
claims.resource_id = RequiredString(payload, "resource_id");
|
||||
claims.jti = RequiredString(payload, "jti");
|
||||
claims.allowed_channels = ParseAllowedChannels(payload);
|
||||
claims.expires_at_unix = static_cast<std::int64_t>(rdp_worker::common::GetNumber(payload, "exp").value_or(0));
|
||||
|
||||
const auto now = UnixNow();
|
||||
if (claims.expires_at_unix <= now) {
|
||||
result.reason = "token_expired";
|
||||
return result;
|
||||
}
|
||||
const auto not_before = static_cast<std::int64_t>(rdp_worker::common::GetNumber(payload, "nbf").value_or(0));
|
||||
if (not_before > 0 && not_before > now) {
|
||||
result.reason = "token_not_yet_valid";
|
||||
return result;
|
||||
}
|
||||
if (!expected_worker_id_.empty() && claims.worker_id != expected_worker_id_) {
|
||||
result.reason = "wrong_worker";
|
||||
return result;
|
||||
}
|
||||
if (!ArrayContainsString(rdp_worker::common::GetArray(payload, "aud"), "rap-data-plane") ||
|
||||
!ArrayContainsString(rdp_worker::common::GetArray(payload, "aud"), "worker:" + claims.worker_id)) {
|
||||
result.reason = "invalid_audience";
|
||||
return result;
|
||||
}
|
||||
if (std::find(claims.allowed_channels.begin(), claims.allowed_channels.end(), "control") == claims.allowed_channels.end()) {
|
||||
result.reason = "control_channel_not_allowed";
|
||||
return result;
|
||||
}
|
||||
|
||||
result.ok = true;
|
||||
result.claims = std::move(claims);
|
||||
return result;
|
||||
} catch (const std::exception& error) {
|
||||
result.reason = error.what();
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace rdp_worker::dataplane
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,74 @@
|
||||
#include "rdp_worker/graphics/graphics_adapter.hpp"
|
||||
|
||||
#include <utility>
|
||||
|
||||
namespace rdp_worker::graphics {
|
||||
|
||||
GraphicsAdapter::GraphicsAdapter(GraphicsAdapterPolicy policy)
|
||||
: policy_(policy) {}
|
||||
|
||||
const GraphicsAdapterPolicy& GraphicsAdapter::Policy() const {
|
||||
return policy_;
|
||||
}
|
||||
|
||||
RenderUpdate GraphicsAdapter::MakeFullBgraFrame(std::uint64_t sequence,
|
||||
int width,
|
||||
int height,
|
||||
int stride,
|
||||
std::vector<std::uint8_t> pixels,
|
||||
bool baseline) const {
|
||||
RenderUpdate update;
|
||||
update.kind = RenderUpdateKind::kFullBgraFrame;
|
||||
update.sequence = sequence;
|
||||
update.desktop_width = width;
|
||||
update.desktop_height = height;
|
||||
update.frame_width = width;
|
||||
update.frame_height = height;
|
||||
update.stride = stride;
|
||||
update.region = Rect{0, 0, width, height};
|
||||
update.payload = std::move(pixels);
|
||||
update.baseline = baseline;
|
||||
update.droppable = !baseline;
|
||||
return update;
|
||||
}
|
||||
|
||||
std::optional<RenderUpdate> GraphicsAdapter::TryMakeBgraRegion(std::uint64_t sequence,
|
||||
int desktop_width,
|
||||
int desktop_height,
|
||||
int stride,
|
||||
Rect region,
|
||||
std::vector<std::uint8_t> pixels) const {
|
||||
if (!RegionAllowed(desktop_width, desktop_height, region)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
RenderUpdate update;
|
||||
update.kind = RenderUpdateKind::kBgraRegion;
|
||||
update.sequence = sequence;
|
||||
update.desktop_width = desktop_width;
|
||||
update.desktop_height = desktop_height;
|
||||
update.frame_width = region.width;
|
||||
update.frame_height = region.height;
|
||||
update.stride = stride;
|
||||
update.region = region;
|
||||
update.payload = std::move(pixels);
|
||||
update.baseline = false;
|
||||
update.droppable = true;
|
||||
return update;
|
||||
}
|
||||
|
||||
bool GraphicsAdapter::RegionAllowed(int desktop_width, int desktop_height, const Rect& region) const {
|
||||
if (desktop_width <= 0 || desktop_height <= 0 ||
|
||||
region.x < 0 || region.y < 0 ||
|
||||
region.width <= 0 || region.height <= 0 ||
|
||||
region.x + region.width > desktop_width ||
|
||||
region.y + region.height > desktop_height) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const long long desktop_area = static_cast<long long>(desktop_width) * static_cast<long long>(desktop_height);
|
||||
const long long region_area = static_cast<long long>(region.width) * static_cast<long long>(region.height);
|
||||
return region_area * 100 <= desktop_area * policy_.max_region_area_percent;
|
||||
}
|
||||
|
||||
} // namespace rdp_worker::graphics
|
||||
@@ -0,0 +1,38 @@
|
||||
#include "rdp_worker/graphics/render_update.hpp"
|
||||
|
||||
namespace rdp_worker::graphics {
|
||||
|
||||
const char* RenderUpdateKindName(RenderUpdateKind kind) {
|
||||
switch (kind) {
|
||||
case RenderUpdateKind::kFullBgraFrame:
|
||||
return "full_bgra_frame";
|
||||
case RenderUpdateKind::kBgraRegion:
|
||||
return "bgra_region";
|
||||
case RenderUpdateKind::kSurfaceCreate:
|
||||
return "surface_create";
|
||||
case RenderUpdateKind::kSurfaceDelete:
|
||||
return "surface_delete";
|
||||
case RenderUpdateKind::kSurfaceBits:
|
||||
return "surface_bits";
|
||||
case RenderUpdateKind::kEncodedFrame:
|
||||
return "encoded_frame";
|
||||
case RenderUpdateKind::kCursorUpdate:
|
||||
return "cursor_update";
|
||||
}
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
bool IsFullFrameUpdate(const RenderUpdate& update) {
|
||||
return update.kind == RenderUpdateKind::kFullBgraFrame;
|
||||
}
|
||||
|
||||
bool IsRegionUpdate(const RenderUpdate& update) {
|
||||
return update.kind == RenderUpdateKind::kBgraRegion;
|
||||
}
|
||||
|
||||
bool IsEncodedUpdate(const RenderUpdate& update) {
|
||||
return update.kind == RenderUpdateKind::kEncodedFrame ||
|
||||
update.kind == RenderUpdateKind::kSurfaceBits;
|
||||
}
|
||||
|
||||
} // namespace rdp_worker::graphics
|
||||
@@ -0,0 +1,56 @@
|
||||
#include <atomic>
|
||||
#include <csignal>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
|
||||
#include "rdp_worker/common/logger.hpp"
|
||||
#include "rdp_worker/config/config.hpp"
|
||||
#include "rdp_worker/coordination/control_plane.hpp"
|
||||
#include "rdp_worker/dataplane/direct_wss_server.hpp"
|
||||
#include "rdp_worker/runtime/session_manager.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
std::atomic<bool> g_stop{false};
|
||||
|
||||
void HandleSignal(int) {
|
||||
g_stop.store(true);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main() {
|
||||
std::signal(SIGINT, HandleSignal);
|
||||
std::signal(SIGTERM, HandleSignal);
|
||||
|
||||
auto logger = std::make_shared<rdp_worker::common::Logger>("rdp-worker");
|
||||
const auto config = rdp_worker::config::LoadFromEnv();
|
||||
auto control_plane = std::make_shared<rdp_worker::coordination::ControlPlane>(config, logger);
|
||||
control_plane->Connect();
|
||||
control_plane->ReleaseOwnedLeasesOnStartup();
|
||||
control_plane->RegisterWorker();
|
||||
|
||||
auto session_manager = std::make_shared<rdp_worker::runtime::SessionManager>(control_plane, logger);
|
||||
rdp_worker::dataplane::DirectWssServer direct_wss_server(config, session_manager, logger);
|
||||
direct_wss_server.Start();
|
||||
|
||||
std::thread heartbeat_thread([&]() {
|
||||
while (!g_stop.load()) {
|
||||
control_plane->SendHeartbeat();
|
||||
std::this_thread::sleep_for(config.worker_heartbeat_interval);
|
||||
}
|
||||
});
|
||||
|
||||
while (!g_stop.load()) {
|
||||
if (auto assignment = control_plane->PollAssignment(config.assignment_poll_interval); assignment.has_value()) {
|
||||
session_manager->ApplyAssignment(*assignment);
|
||||
}
|
||||
}
|
||||
|
||||
direct_wss_server.Stop();
|
||||
session_manager->StopAll();
|
||||
if (heartbeat_thread.joinable()) {
|
||||
heartbeat_thread.join();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
#include "rdp_worker/runtime/direct_bind_policy.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
namespace rdp_worker::runtime {
|
||||
|
||||
namespace {
|
||||
|
||||
bool ClipboardAllowsServerOrClient(const std::string& mode) {
|
||||
return mode == "client_to_server" || mode == "server_to_client" || mode == "bidirectional";
|
||||
}
|
||||
|
||||
bool FileTransferAllowsClientToServer(const std::string& mode) {
|
||||
return mode == "client_to_server" || mode == "bidirectional";
|
||||
}
|
||||
|
||||
bool FileTransferAllowsServerToClient(const std::string& mode) {
|
||||
return mode == "server_to_client" || mode == "bidirectional";
|
||||
}
|
||||
|
||||
std::vector<std::string> RuntimeAllowedChannels(const Assignment& assignment) {
|
||||
std::vector<std::string> channels{"control", "input", "render", "telemetry"};
|
||||
if (ClipboardAllowsServerOrClient(assignment.policy.clipboard_mode)) {
|
||||
channels.push_back("clipboard");
|
||||
}
|
||||
if (FileTransferAllowsClientToServer(assignment.policy.file_transfer_mode)) {
|
||||
channels.push_back("file_upload");
|
||||
}
|
||||
if (FileTransferAllowsServerToClient(assignment.policy.file_transfer_mode)) {
|
||||
channels.push_back("file_download");
|
||||
}
|
||||
return channels;
|
||||
}
|
||||
|
||||
bool RequestedChannelsAllowed(const std::vector<std::string>& requested, const std::vector<std::string>& allowed) {
|
||||
return std::all_of(requested.begin(), requested.end(), [&](const auto& channel) {
|
||||
return std::find(allowed.begin(), allowed.end(), channel) != allowed.end();
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
DirectBindValidationResult ValidateDirectDataPlaneBind(const Assignment& assignment,
|
||||
const dataplane::DataPlaneTokenClaims& claims) {
|
||||
if (assignment.state != SessionState::kStarting &&
|
||||
assignment.state != SessionState::kActive &&
|
||||
assignment.state != SessionState::kReconnecting) {
|
||||
return {false, "session_not_attachable"};
|
||||
}
|
||||
if (assignment.worker_id != claims.worker_id) {
|
||||
return {false, "worker_mismatch"};
|
||||
}
|
||||
if (assignment.attachment_id != claims.attachment_id) {
|
||||
return {false, "attachment_mismatch"};
|
||||
}
|
||||
if (assignment.user_id != claims.user_id) {
|
||||
return {false, "user_mismatch"};
|
||||
}
|
||||
if (assignment.organization_id != claims.organization_id) {
|
||||
return {false, "organization_mismatch"};
|
||||
}
|
||||
if (assignment.connection.resource_id != claims.resource_id) {
|
||||
return {false, "resource_mismatch"};
|
||||
}
|
||||
if (!RequestedChannelsAllowed(claims.allowed_channels, RuntimeAllowedChannels(assignment))) {
|
||||
return {false, "channels_exceed_runtime_policy"};
|
||||
}
|
||||
return {true, ""};
|
||||
}
|
||||
|
||||
} // namespace rdp_worker::runtime
|
||||
@@ -0,0 +1,63 @@
|
||||
#include "rdp_worker/runtime/session_manager.hpp"
|
||||
|
||||
#include "rdp_worker/runtime/direct_bind_policy.hpp"
|
||||
|
||||
namespace rdp_worker::runtime {
|
||||
|
||||
SessionManager::SessionManager(std::shared_ptr<coordination::ControlPlane> control_plane,
|
||||
std::shared_ptr<common::Logger> logger)
|
||||
: control_plane_(std::move(control_plane)),
|
||||
logger_(std::move(logger)) {}
|
||||
|
||||
void SessionManager::ApplyAssignment(const Assignment& assignment) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
const auto iterator = sessions_.find(assignment.session_id);
|
||||
if (iterator != sessions_.end()) {
|
||||
iterator->second->ApplyAssignment(assignment);
|
||||
logger_->Info("updated assignment for existing session " + assignment.session_id);
|
||||
return;
|
||||
}
|
||||
|
||||
auto runtime = std::make_shared<SessionRuntime>(assignment, control_plane_, logger_);
|
||||
runtime->Start();
|
||||
sessions_.emplace(assignment.session_id, runtime);
|
||||
logger_->Info("started new runtime for session " + assignment.session_id);
|
||||
}
|
||||
|
||||
void SessionManager::StopAll() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
for (auto& [_, runtime] : sessions_) {
|
||||
runtime->Stop(true, "worker_shutdown");
|
||||
}
|
||||
sessions_.clear();
|
||||
}
|
||||
|
||||
bool SessionManager::BindDirectDataPlaneAttachment(const dataplane::DataPlaneTokenClaims& claims, std::string& reason) {
|
||||
return BindDirectDataPlaneRuntime(claims, reason) != nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<SessionRuntime> SessionManager::BindDirectDataPlaneRuntime(const dataplane::DataPlaneTokenClaims& claims, std::string& reason) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
const auto iterator = sessions_.find(claims.session_id);
|
||||
if (iterator == sessions_.end()) {
|
||||
reason = "missing_runtime";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const Assignment snapshot = iterator->second->Snapshot();
|
||||
const auto validation = ValidateDirectDataPlaneBind(snapshot, claims);
|
||||
if (!validation.ok) {
|
||||
reason = validation.reason;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
reason.clear();
|
||||
logger_->Info("event=data_plane_bind_success session=" + claims.session_id +
|
||||
" attachment=" + claims.attachment_id +
|
||||
" user=" + claims.user_id +
|
||||
" organization=" + claims.organization_id +
|
||||
" resource=" + claims.resource_id);
|
||||
return iterator->second;
|
||||
}
|
||||
|
||||
} // namespace rdp_worker::runtime
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,44 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "rdp_worker/cursor/cursor_adapter.hpp"
|
||||
#include "rdp_worker/cursor/cursor_update.hpp"
|
||||
|
||||
int main() {
|
||||
rdp_worker::cursor::CursorAdapter adapter;
|
||||
|
||||
auto position = adapter.MakePosition(1, 1280, 720, 10, 20, true);
|
||||
if (position.kind != rdp_worker::cursor::CursorUpdateKind::kPosition ||
|
||||
!position.visible ||
|
||||
position.x != 10 ||
|
||||
position.y != 20) {
|
||||
std::cerr << "cursor position normalization failed\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto hidden = adapter.MakeSystem(2, 1280, 720, 10, 20, 0);
|
||||
if (hidden.kind != rdp_worker::cursor::CursorUpdateKind::kSystem ||
|
||||
hidden.visible ||
|
||||
!hidden.shape_changed) {
|
||||
std::cerr << "cursor system visibility normalization failed\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto large = adapter.MakeLarge(3, 1280, 720, 5, 6, 32, 32, 7, 1, 2, 32, 128);
|
||||
if (large.kind != rdp_worker::cursor::CursorUpdateKind::kLarge ||
|
||||
large.cache_index != 7 ||
|
||||
large.hot_spot_x != 1 ||
|
||||
large.hot_spot_y != 2 ||
|
||||
large.mask_bytes != 128) {
|
||||
std::cerr << "cursor shape normalization failed\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto payload = rdp_worker::cursor::CursorUpdateToPayload(large, "balanced");
|
||||
if (payload.empty()) {
|
||||
std::cerr << "cursor payload conversion failed\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::cout << "cursor_adapter_probe ok\n";
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "rdp_worker/runtime/direct_bind_policy.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
std::string ArgValue(int argc, char** argv, const std::string& name) {
|
||||
for (int i = 1; i + 1 < argc; ++i) {
|
||||
if (argv[i] == name) {
|
||||
return argv[i + 1];
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
rdp_worker::runtime::Assignment BaseAssignment() {
|
||||
rdp_worker::runtime::Assignment assignment{};
|
||||
assignment.session_id = "session-1";
|
||||
assignment.worker_id = "rdp-worker-1";
|
||||
assignment.attachment_id = "attachment-current";
|
||||
assignment.user_id = "user-1";
|
||||
assignment.organization_id = "org-1";
|
||||
assignment.state = rdp_worker::runtime::SessionState::kActive;
|
||||
assignment.connection.resource_id = "resource-1";
|
||||
assignment.policy.clipboard_mode = "disabled";
|
||||
assignment.policy.file_transfer_mode = "disabled";
|
||||
return assignment;
|
||||
}
|
||||
|
||||
rdp_worker::dataplane::DataPlaneTokenClaims BaseClaims() {
|
||||
rdp_worker::dataplane::DataPlaneTokenClaims claims{};
|
||||
claims.session_id = "session-1";
|
||||
claims.worker_id = "rdp-worker-1";
|
||||
claims.attachment_id = "attachment-current";
|
||||
claims.user_id = "user-1";
|
||||
claims.organization_id = "org-1";
|
||||
claims.resource_id = "resource-1";
|
||||
claims.allowed_channels = {"control", "input", "render", "telemetry"};
|
||||
claims.jti = "jti-1";
|
||||
claims.expires_at_unix = 4102444800;
|
||||
return claims;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
const auto scenario = ArgValue(argc, argv, "--scenario");
|
||||
if (scenario.empty()) {
|
||||
std::cerr << "usage: rdp-worker-dataplane-bind-probe --scenario valid|starting|wrong-worker|wrong-attachment|wrong-user|wrong-organization|wrong-resource|channels-too-broad|failed-state|terminated-state\n";
|
||||
return 2;
|
||||
}
|
||||
|
||||
auto assignment = BaseAssignment();
|
||||
auto claims = BaseClaims();
|
||||
std::string expected_reason;
|
||||
if (scenario == "valid") {
|
||||
expected_reason = "";
|
||||
} else if (scenario == "starting") {
|
||||
assignment.state = rdp_worker::runtime::SessionState::kStarting;
|
||||
expected_reason = "";
|
||||
} else if (scenario == "wrong-attachment") {
|
||||
claims.attachment_id = "attachment-old";
|
||||
expected_reason = "attachment_mismatch";
|
||||
} else if (scenario == "wrong-worker") {
|
||||
claims.worker_id = "rdp-worker-other";
|
||||
expected_reason = "worker_mismatch";
|
||||
} else if (scenario == "wrong-user") {
|
||||
claims.user_id = "user-other";
|
||||
expected_reason = "user_mismatch";
|
||||
} else if (scenario == "wrong-organization") {
|
||||
claims.organization_id = "org-other";
|
||||
expected_reason = "organization_mismatch";
|
||||
} else if (scenario == "wrong-resource") {
|
||||
claims.resource_id = "resource-other";
|
||||
expected_reason = "resource_mismatch";
|
||||
} else if (scenario == "channels-too-broad") {
|
||||
claims.allowed_channels.push_back("clipboard");
|
||||
expected_reason = "channels_exceed_runtime_policy";
|
||||
} else if (scenario == "failed-state") {
|
||||
assignment.state = rdp_worker::runtime::SessionState::kFailed;
|
||||
expected_reason = "session_not_attachable";
|
||||
} else if (scenario == "terminated-state") {
|
||||
assignment.state = rdp_worker::runtime::SessionState::kTerminated;
|
||||
expected_reason = "session_not_attachable";
|
||||
} else {
|
||||
std::cerr << "unknown scenario=" << scenario << "\n";
|
||||
return 2;
|
||||
}
|
||||
|
||||
const auto result = rdp_worker::runtime::ValidateDirectDataPlaneBind(assignment, claims);
|
||||
if ((scenario == "valid" || scenario == "starting") && result.ok) {
|
||||
std::cout << "PASS scenario=" << scenario << "\n";
|
||||
return 0;
|
||||
}
|
||||
if (!expected_reason.empty() && !result.ok && result.reason == expected_reason) {
|
||||
std::cout << "PASS scenario=" << scenario << " reason=" << result.reason << "\n";
|
||||
return 0;
|
||||
}
|
||||
std::cerr << "FAIL scenario=" << scenario << " ok=" << (result.ok ? "true" : "false") << " reason=" << result.reason << "\n";
|
||||
return 1;
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "rdp_worker/dataplane/token_validator.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
std::string ArgValue(int argc, char** argv, const std::string& name) {
|
||||
for (int i = 1; i + 1 < argc; ++i) {
|
||||
if (argv[i] == name) {
|
||||
return argv[i + 1];
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
std::string ReadFile(const std::string& path) {
|
||||
std::ifstream input(path);
|
||||
if (!input.good()) {
|
||||
return {};
|
||||
}
|
||||
std::stringstream buffer;
|
||||
buffer << input.rdbuf();
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
const auto token = ArgValue(argc, argv, "--token");
|
||||
auto public_key = ArgValue(argc, argv, "--public-key-pem");
|
||||
const auto public_key_file = ArgValue(argc, argv, "--public-key-file");
|
||||
if (public_key.empty() && !public_key_file.empty()) {
|
||||
public_key = ReadFile(public_key_file);
|
||||
}
|
||||
const auto worker_id = ArgValue(argc, argv, "--worker-id");
|
||||
if (token.empty() || public_key.empty() || worker_id.empty()) {
|
||||
std::cerr << "usage: rdp-worker-dataplane-token-probe --token <jwt> --public-key-file <public.pem> --worker-id <worker_id>\n";
|
||||
return 2;
|
||||
}
|
||||
|
||||
rdp_worker::dataplane::DataPlaneTokenValidator validator(public_key, worker_id);
|
||||
const auto result = validator.Validate(token);
|
||||
if (!result.ok) {
|
||||
std::cerr << "FAIL reason=" << result.reason << "\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::cout << "PASS session_id=" << result.claims.session_id
|
||||
<< " attachment_id=" << result.claims.attachment_id
|
||||
<< " worker_id=" << result.claims.worker_id
|
||||
<< " resource_id=" << result.claims.resource_id
|
||||
<< " channels=" << result.claims.allowed_channels.size()
|
||||
<< "\n";
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "rdp_worker/graphics/graphics_adapter.hpp"
|
||||
|
||||
int main() {
|
||||
rdp_worker::graphics::GraphicsAdapter adapter;
|
||||
|
||||
auto full = adapter.MakeFullBgraFrame(1, 1280, 720, 5120, std::vector<std::uint8_t>(1280 * 720 * 4), true);
|
||||
if (!rdp_worker::graphics::IsFullFrameUpdate(full) || !full.baseline || full.droppable) {
|
||||
std::cerr << "full frame baseline policy failed\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto region = adapter.TryMakeBgraRegion(
|
||||
2,
|
||||
1280,
|
||||
720,
|
||||
200 * 4,
|
||||
rdp_worker::graphics::Rect{10, 20, 200, 100},
|
||||
std::vector<std::uint8_t>(200 * 100 * 4));
|
||||
if (!region.has_value() || !rdp_worker::graphics::IsRegionUpdate(*region) || !region->droppable) {
|
||||
std::cerr << "region policy failed\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto too_large = adapter.TryMakeBgraRegion(
|
||||
3,
|
||||
1280,
|
||||
720,
|
||||
1280 * 4,
|
||||
rdp_worker::graphics::Rect{0, 0, 1280, 720},
|
||||
std::vector<std::uint8_t>(1280 * 720 * 4));
|
||||
if (too_large.has_value()) {
|
||||
std::cerr << "oversized region should be rejected\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::cout << "graphics_adapter_probe ok\n";
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "rdp_worker/adapter/service_adapter_protocol.hpp"
|
||||
|
||||
int main() {
|
||||
using namespace rdp_worker::adapter;
|
||||
|
||||
for (const auto& spec : AllChannelSpecs()) {
|
||||
std::cout << spec.name
|
||||
<< " direction=" << DirectionName(spec.direction)
|
||||
<< " reliability=" << ReliabilityName(spec.reliability)
|
||||
<< " priority=" << PriorityValue(spec.priority)
|
||||
<< " droppable=" << (spec.stale_updates_droppable ? "true" : "false")
|
||||
<< " may_block_input=" << (spec.may_block_input ? "true" : "false")
|
||||
<< '\n';
|
||||
}
|
||||
|
||||
if (!ValidateAdapterChannelInvariants()) {
|
||||
std::cerr << "adapter channel invariants failed\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user