Initial project snapshot

This commit is contained in:
2026-04-28 22:29:50 +03:00
commit 8ba0561f4f
365 changed files with 91832 additions and 0 deletions
@@ -0,0 +1,300 @@
#include "rdp_worker/coordination/control_plane.hpp"
#include <sstream>
#include <stdexcept>
#include "rdp_worker/common/logger.hpp"
#include "rdp_worker/common/time.hpp"
namespace rdp_worker::coordination {
using common::GetArray;
using common::GetBool;
using common::GetNumber;
using common::GetObject;
using common::GetString;
using common::JsonArray;
using common::JsonObject;
using common::JsonValue;
ControlPlane::ControlPlane(config::Config config, std::shared_ptr<common::Logger> logger)
: config_(std::move(config)),
logger_(std::move(logger)),
redis_(std::make_unique<RedisClient>(config_.redis_host, config_.redis_port, config_.redis_password, config_.redis_db)) {}
void ControlPlane::Connect() {
std::lock_guard<std::mutex> lock(mutex_);
redis_->Connect();
}
void ControlPlane::RegisterWorker() {
std::lock_guard<std::mutex> lock(mutex_);
redis_->Set("worker:registration:" + config_.worker_id, WorkerRegistrationPayload(), config_.worker_heartbeat_interval * 3);
redis_->SAdd("worker:registrations", config_.worker_id);
}
void ControlPlane::ReleaseOwnedLeasesOnStartup() {
std::lock_guard<std::mutex> lock(mutex_);
int released = 0;
for (const auto& lease_id : redis_->SMembers("worker:leases")) {
auto encoded = redis_->Get("worker:lease:" + lease_id);
if (!encoded.has_value()) {
redis_->SRem("worker:leases", lease_id);
continue;
}
auto lease = ParseLease(common::ParseJson(*encoded).AsObject());
if (lease.worker_id != config_.worker_id) {
continue;
}
redis_->Delete("worker:lease:" + lease_id);
redis_->SRem("worker:leases", lease_id);
if (!lease.session_id.empty()) {
redis_->Delete("worker:session-lease:" + lease.session_id);
redis_->Delete("worker:queue:" + lease.session_id);
}
++released;
}
if (released > 0) {
logger_->Warn("released stale owned worker leases on startup worker=" + config_.worker_id +
" released_count=" + std::to_string(released));
}
}
void ControlPlane::SendHeartbeat() {
RegisterWorker();
}
std::optional<runtime::Assignment> ControlPlane::PollAssignment(std::chrono::seconds timeout) {
RedisClient stream_client(config_.redis_host, config_.redis_port, config_.redis_password, config_.redis_db);
stream_client.Connect();
auto entry = stream_client.BLPop("worker:control:" + config_.worker_id, timeout);
if (!entry.has_value() || entry->size() != 2) {
return std::nullopt;
}
const JsonObject object = common::ParseJson((*entry)[1]).AsObject();
return ParseAssignment(object);
}
std::optional<common::JsonObject> ControlPlane::PollSessionEnvelope(const std::string& session_id, std::chrono::seconds timeout) {
RedisClient stream_client(config_.redis_host, config_.redis_port, config_.redis_password, config_.redis_db);
stream_client.Connect();
auto entry = stream_client.BLPop("worker:queue:" + session_id, timeout);
if (!entry.has_value() || entry->size() != 2) {
return std::nullopt;
}
auto object = common::ParseJson((*entry)[1]).AsObject();
const JsonObject* payload = GetObject(object, "payload");
if (payload != nullptr) {
const std::string type = GetString(object, "type").value_or("");
const std::string correlation_id = GetString(*payload, "correlation_id").value_or("");
if (type == "input" && !correlation_id.empty()) {
logger_->Info("input.trace worker_queue_pop session=" + session_id +
" correlation_id=" + correlation_id +
" trace_stage=worker_queue_pop");
}
}
return object;
}
std::vector<common::JsonObject> ControlPlane::DrainSessionEnvelopes(const std::string& session_id, std::size_t max_count) {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<common::JsonObject> output;
output.reserve(max_count);
const std::string key = "worker:queue:" + session_id;
for (std::size_t i = 0; i < max_count; ++i) {
auto encoded = redis_->LPop(key);
if (!encoded.has_value()) {
break;
}
auto object = common::ParseJson(*encoded).AsObject();
const JsonObject* payload = GetObject(object, "payload");
if (payload != nullptr) {
const std::string type = GetString(object, "type").value_or("");
const std::string correlation_id = GetString(*payload, "correlation_id").value_or("");
if (type == "input" && !correlation_id.empty()) {
logger_->Info("input.trace worker_queue_pop session=" + session_id +
" correlation_id=" + correlation_id +
" trace_stage=worker_queue_pop");
}
}
output.push_back(std::move(object));
}
return output;
}
int64_t ControlPlane::SessionEnvelopeQueueLength(const std::string& session_id) {
std::lock_guard<std::mutex> lock(mutex_);
return redis_->LLen("worker:queue:" + session_id);
}
std::optional<runtime::WorkerLease> ControlPlane::GetLeaseBySession(const std::string& session_id) {
std::lock_guard<std::mutex> lock(mutex_);
auto lease_id = redis_->Get("worker:session-lease:" + session_id);
if (!lease_id.has_value()) {
return std::nullopt;
}
auto encoded = redis_->Get("worker:lease:" + *lease_id);
if (!encoded.has_value()) {
return std::nullopt;
}
return ParseLease(common::ParseJson(*encoded).AsObject());
}
void ControlPlane::RenewLease(const runtime::WorkerLease& lease) {
std::lock_guard<std::mutex> lock(mutex_);
redis_->Set("worker:lease:" + lease.lease_id, LeasePayload(lease), std::chrono::seconds(45));
redis_->Set("worker:session-lease:" + lease.session_id, lease.lease_id, std::chrono::seconds(45));
}
void ControlPlane::ReleaseLease(const runtime::WorkerLease& lease) {
std::lock_guard<std::mutex> lock(mutex_);
redis_->Delete("worker:lease:" + lease.lease_id);
redis_->Delete("worker:session-lease:" + lease.session_id);
}
void ControlPlane::PublishEvent(const runtime::WorkerEvent& event) {
std::lock_guard<std::mutex> lock(mutex_);
const std::string encoded = EventPayload(event);
redis_->RPush("worker:events", encoded);
redis_->Expire("worker:events", std::chrono::minutes(10));
}
runtime::Assignment ControlPlane::ParseAssignment(const JsonObject& object) const {
runtime::Assignment assignment{};
assignment.session_id = GetString(object, "session_id").value_or("");
assignment.worker_id = GetString(object, "worker_id").value_or("");
assignment.attachment_id = GetString(object, "attachment_id").value_or("");
assignment.user_id = GetString(object, "user_id").value_or("");
assignment.device_id = GetString(object, "device_id").value_or("");
assignment.takeover_of = GetString(object, "takeover_of");
const std::string state = GetString(object, "state").value_or("starting");
if (state == "active") {
assignment.state = runtime::SessionState::kActive;
} else if (state == "detached") {
assignment.state = runtime::SessionState::kDetached;
} else if (state == "reconnecting") {
assignment.state = runtime::SessionState::kReconnecting;
} else {
assignment.state = runtime::SessionState::kStarting;
}
const JsonObject* metadata = GetObject(object, "metadata");
if (metadata == nullptr) {
throw std::runtime_error("assignment metadata is required");
}
const JsonObject* resource = GetObject(*metadata, "resource");
if (resource == nullptr) {
throw std::runtime_error("assignment resource metadata is required");
}
assignment.organization_id = GetString(*resource, "organization_id").value_or("");
assignment.connection.resource_id = GetString(*resource, "id").value_or("");
assignment.connection.resource_name = GetString(*resource, "name").value_or("");
assignment.connection.host = GetString(*resource, "address").value_or("");
assignment.connection.port = 3389;
assignment.connection.username = "";
assignment.connection.password = "";
assignment.connection.domain = "";
assignment.connection.certificate_verification_mode = GetString(*resource, "certificate_verification_mode").value_or("strict");
assignment.connection.render_quality_profile = GetString(*resource, "render_quality_profile").value_or("balanced");
assignment.connection.insecure_skip_verify = config_.insecure_skip_verify;
const JsonObject* resource_meta = GetObject(*resource, "metadata");
if (resource_meta != nullptr) {
assignment.connection.host = GetString(*resource_meta, "rdp_host").value_or(assignment.connection.host);
assignment.connection.port = static_cast<uint16_t>(GetNumber(*resource_meta, "rdp_port").value_or(3389));
assignment.connection.username = GetString(*resource_meta, "username").value_or("");
assignment.connection.password = GetString(*resource_meta, "password").value_or("");
assignment.connection.domain = GetString(*resource_meta, "domain").value_or("");
assignment.connection.certificate_verification_mode =
GetString(*resource_meta, "certificate_verification_mode").value_or(assignment.connection.certificate_verification_mode);
assignment.connection.render_quality_profile =
GetString(*resource_meta, "render_quality_profile").value_or(assignment.connection.render_quality_profile);
}
const JsonObject* policy = GetObject(*metadata, "policy");
if (policy != nullptr) {
assignment.policy.detach_grace_period = std::chrono::seconds(static_cast<int>(GetNumber(*policy, "detach_grace_period_seconds").value_or(1800)));
assignment.policy.clipboard_mode = GetString(*policy, "clipboard_mode").value_or("disabled");
if (assignment.policy.clipboard_mode.empty()) {
assignment.policy.clipboard_mode = GetBool(*policy, "clipboard_enabled").value_or(false) ? "bidirectional" : "disabled";
}
assignment.policy.file_transfer_mode = GetString(*policy, "file_transfer_mode").value_or("disabled");
if (assignment.policy.file_transfer_mode.empty()) {
assignment.policy.file_transfer_mode = GetBool(*policy, "file_transfer_enabled").value_or(false) ? "client_to_server" : "disabled";
}
}
return assignment;
}
runtime::WorkerLease ControlPlane::ParseLease(const JsonObject& object) const {
runtime::WorkerLease lease{};
lease.lease_id = GetString(object, "lease_id").value_or("");
lease.worker_id = GetString(object, "worker_id").value_or("");
lease.session_id = GetString(object, "session_id").value_or("");
lease.resource_id = GetString(object, "resource_id").value_or("");
lease.control_stream = GetString(object, "control_stream").value_or("");
lease.expires_at = GetString(object, "expires_at").value_or("");
if (const JsonArray* capabilities = GetArray(object, "capabilities"); capabilities != nullptr) {
for (const auto& item : *capabilities) {
if (item.IsString()) {
lease.capabilities.push_back(item.AsString());
}
}
}
return lease;
}
std::string ControlPlane::WorkerRegistrationPayload() const {
JsonArray capabilities;
for (const auto& item : config_.capabilities) {
capabilities.emplace_back(item);
}
return common::SerializeJson(JsonObject{
{"worker_id", config_.worker_id},
{"protocol", "rdp"},
{"status", "online"},
{"capabilities", capabilities},
{"control_stream", "worker://control/" + config_.worker_id},
{"last_heartbeat_at", common::ToRfc3339(common::NowUtc())},
});
}
std::string ControlPlane::LeasePayload(const runtime::WorkerLease& lease) const {
JsonArray capabilities;
for (const auto& capability : lease.capabilities) {
capabilities.emplace_back(capability);
}
return common::SerializeJson(JsonObject{
{"lease_id", lease.lease_id},
{"worker_id", lease.worker_id},
{"protocol", "rdp"},
{"resource_id", lease.resource_id},
{"session_id", lease.session_id},
{"capabilities", capabilities},
{"control_stream", lease.control_stream},
{"expires_at", common::ToRfc3339(common::NowUtc() + std::chrono::seconds(45))},
});
}
std::string ControlPlane::EventPayload(const runtime::WorkerEvent& event) const {
JsonObject payload{
{"type", event.type},
{"session_id", event.session_id},
{"worker_id", event.worker_id},
};
JsonObject detail;
if (!event.reason.empty()) {
detail.emplace("reason", event.reason);
}
if (event.width > 0) {
detail.emplace("width", event.width);
}
if (event.height > 0) {
detail.emplace("height", event.height);
}
for (const auto& [key, value] : event.payload) {
detail[key] = value;
}
payload.emplace("payload", detail);
return common::SerializeJson(payload);
}
} // namespace rdp_worker::coordination
@@ -0,0 +1,264 @@
#include "rdp_worker/coordination/redis_client.hpp"
#include <cstring>
#include <stdexcept>
#include <string_view>
#if defined(_WIN32)
#include <winsock2.h>
#include <ws2tcpip.h>
#else
#include <arpa/inet.h>
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>
#endif
namespace rdp_worker::coordination {
namespace {
void EnsureSuccess(bool condition, const std::string& message) {
if (!condition) {
throw std::runtime_error(message);
}
}
void CloseSocket(int socket_fd) {
if (socket_fd < 0) {
return;
}
#if defined(_WIN32)
closesocket(socket_fd);
#else
close(socket_fd);
#endif
}
} // namespace
bool RedisReply::IsNull() const { return std::holds_alternative<std::nullptr_t>(value); }
bool RedisReply::IsString() const { return std::holds_alternative<std::string>(value); }
bool RedisReply::IsInteger() const { return std::holds_alternative<int64_t>(value); }
bool RedisReply::IsArray() const { return std::holds_alternative<Array>(value); }
const std::string& RedisReply::AsString() const { return std::get<std::string>(value); }
int64_t RedisReply::AsInteger() const { return std::get<int64_t>(value); }
const RedisReply::Array& RedisReply::AsArray() const { return std::get<Array>(value); }
RedisClient::RedisClient(std::string host, int port, std::string password, int db)
: host_(std::move(host)),
port_(port),
password_(std::move(password)),
db_(db),
socket_fd_(-1) {}
RedisClient::~RedisClient() {
Close();
}
void RedisClient::Connect() {
#if defined(_WIN32)
WSADATA wsa_data{};
EnsureSuccess(WSAStartup(MAKEWORD(2, 2), &wsa_data) == 0, "WSAStartup failed");
#endif
addrinfo hints{};
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
addrinfo* result = nullptr;
EnsureSuccess(getaddrinfo(host_.c_str(), std::to_string(port_).c_str(), &hints, &result) == 0, "getaddrinfo failed");
for (addrinfo* node = result; node != nullptr; node = node->ai_next) {
socket_fd_ = static_cast<int>(socket(node->ai_family, node->ai_socktype, node->ai_protocol));
if (socket_fd_ < 0) {
continue;
}
if (connect(socket_fd_, node->ai_addr, static_cast<int>(node->ai_addrlen)) == 0) {
break;
}
CloseSocket(socket_fd_);
socket_fd_ = -1;
}
freeaddrinfo(result);
EnsureSuccess(socket_fd_ >= 0, "failed to connect to Redis");
if (!password_.empty()) {
Command({"AUTH", password_});
}
if (db_ != 0) {
Command({"SELECT", std::to_string(db_)});
}
}
void RedisClient::Close() {
CloseSocket(socket_fd_);
socket_fd_ = -1;
#if defined(_WIN32)
WSACleanup();
#endif
}
RedisReply RedisClient::Command(const std::vector<std::string>& parts) {
WriteAll(EncodeCommand(parts));
return ReadReply();
}
std::optional<std::string> RedisClient::Get(const std::string& key) {
RedisReply reply = Command({"GET", key});
if (reply.IsNull()) {
return std::nullopt;
}
return reply.AsString();
}
void RedisClient::Set(const std::string& key, const std::string& value, std::chrono::seconds ttl) {
Command({"SET", key, value, "EX", std::to_string(ttl.count())});
}
void RedisClient::SAdd(const std::string& key, const std::string& value) {
Command({"SADD", key, value});
}
void RedisClient::SRem(const std::string& key, const std::string& value) {
Command({"SREM", key, value});
}
std::vector<std::string> RedisClient::SMembers(const std::string& key) {
RedisReply reply = Command({"SMEMBERS", key});
std::vector<std::string> output;
if (!reply.IsArray()) {
return output;
}
for (const auto& item : reply.AsArray()) {
if (item.IsString()) {
output.push_back(item.AsString());
}
}
return output;
}
std::optional<std::vector<std::string>> RedisClient::BLPop(const std::string& key, std::chrono::seconds timeout) {
RedisReply reply = Command({"BLPOP", key, std::to_string(timeout.count())});
if (reply.IsNull()) {
return std::nullopt;
}
std::vector<std::string> output;
for (const auto& item : reply.AsArray()) {
if (item.IsString()) {
output.push_back(item.AsString());
}
}
return output;
}
std::optional<std::string> RedisClient::LPop(const std::string& key) {
RedisReply reply = Command({"LPOP", key});
if (reply.IsNull()) {
return std::nullopt;
}
return reply.AsString();
}
int64_t RedisClient::LLen(const std::string& key) {
RedisReply reply = Command({"LLEN", key});
if (!reply.IsInteger()) {
return 0;
}
return reply.AsInteger();
}
void RedisClient::RPush(const std::string& key, const std::string& value) {
Command({"RPUSH", key, value});
}
void RedisClient::Expire(const std::string& key, std::chrono::seconds ttl) {
Command({"EXPIRE", key, std::to_string(ttl.count())});
}
void RedisClient::Delete(const std::string& key) {
Command({"DEL", key});
}
std::string RedisClient::ReadLine() {
std::string output;
char ch = '\0';
while (true) {
const int received = recv(socket_fd_, &ch, 1, 0);
EnsureSuccess(received == 1, "failed to read from Redis");
if (ch == '\r') {
char lf = '\0';
EnsureSuccess(recv(socket_fd_, &lf, 1, 0) == 1 && lf == '\n', "invalid Redis line ending");
break;
}
output.push_back(ch);
}
return output;
}
std::string RedisClient::ReadBytes(std::size_t count) {
std::string output(count, '\0');
std::size_t offset = 0;
while (offset < count) {
const int received = recv(socket_fd_, output.data() + offset, static_cast<int>(count - offset), 0);
EnsureSuccess(received > 0, "failed to read bulk Redis payload");
offset += static_cast<std::size_t>(received);
}
char suffix[2];
EnsureSuccess(recv(socket_fd_, suffix, 2, 0) == 2 && suffix[0] == '\r' && suffix[1] == '\n', "invalid Redis bulk suffix");
return output;
}
RedisReply RedisClient::ReadReply() {
const std::string line = ReadLine();
EnsureSuccess(!line.empty(), "empty Redis reply");
const char prefix = line[0];
const std::string payload = line.substr(1);
switch (prefix) {
case '+':
return RedisReply{payload};
case ':':
return RedisReply{std::stoll(payload)};
case '$': {
const long long size = std::stoll(payload);
if (size < 0) {
return RedisReply{nullptr};
}
return RedisReply{ReadBytes(static_cast<std::size_t>(size))};
}
case '*': {
const long long size = std::stoll(payload);
if (size < 0) {
return RedisReply{nullptr};
}
RedisReply::Array values;
for (long long i = 0; i < size; ++i) {
values.push_back(ReadReply());
}
return RedisReply{values};
}
case '-':
throw std::runtime_error("Redis error: " + payload);
default:
throw std::runtime_error("unknown Redis reply type");
}
}
void RedisClient::WriteAll(const std::string& data) {
std::size_t offset = 0;
while (offset < data.size()) {
const int sent = send(socket_fd_, data.data() + offset, static_cast<int>(data.size() - offset), 0);
EnsureSuccess(sent > 0, "failed to send Redis command");
offset += static_cast<std::size_t>(sent);
}
}
std::string RedisClient::EncodeCommand(const std::vector<std::string>& parts) const {
std::string output = "*" + std::to_string(parts.size()) + "\r\n";
for (const auto& part : parts) {
output += "$" + std::to_string(part.size()) + "\r\n" + part + "\r\n";
}
return output;
}
} // namespace rdp_worker::coordination