Initial project snapshot

This commit is contained in:
2026-04-28 22:29:50 +03:00
commit 8ba0561f4f
365 changed files with 91832 additions and 0 deletions
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,240 @@
#include "rdp_worker/dataplane/token_validator.hpp"
#include <algorithm>
#include <array>
#include <chrono>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string_view>
#include <utility>
#include <openssl/bio.h>
#include <openssl/evp.h>
#include <openssl/pem.h>
#include "rdp_worker/common/json.hpp"
namespace rdp_worker::dataplane {
namespace {
std::vector<std::string> SplitJwt(const std::string& token) {
std::vector<std::string> parts;
std::size_t start = 0;
while (true) {
const std::size_t dot = token.find('.', start);
parts.push_back(token.substr(start, dot == std::string::npos ? std::string::npos : dot - start));
if (dot == std::string::npos) {
break;
}
start = dot + 1;
}
return parts;
}
std::vector<std::uint8_t> Base64UrlDecode(const std::string& input) {
static constexpr unsigned char kInvalid = 255;
std::array<unsigned char, 256> table{};
table.fill(kInvalid);
const std::string alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
for (std::size_t i = 0; i < alphabet.size(); ++i) {
table[static_cast<unsigned char>(alphabet[i])] = static_cast<unsigned char>(i);
}
std::vector<std::uint8_t> out;
int value = 0;
int value_bits = -8;
for (unsigned char ch : input) {
if (ch == '=') {
break;
}
if (table[ch] == kInvalid) {
throw std::runtime_error("invalid base64url token segment");
}
value = (value << 6) + table[ch];
value_bits += 6;
if (value_bits >= 0) {
out.push_back(static_cast<std::uint8_t>((value >> value_bits) & 0xFF));
value_bits -= 8;
}
}
return out;
}
std::string DecodeStringSegment(const std::string& input) {
const auto bytes = Base64UrlDecode(input);
return std::string(reinterpret_cast<const char*>(bytes.data()), bytes.size());
}
struct EvpKeyDeleter {
void operator()(EVP_PKEY* key) const {
EVP_PKEY_free(key);
}
};
struct BioDeleter {
void operator()(BIO* bio) const {
BIO_free(bio);
}
};
struct EvpMdCtxDeleter {
void operator()(EVP_MD_CTX* context) const {
EVP_MD_CTX_free(context);
}
};
using EvpKeyPtr = std::unique_ptr<EVP_PKEY, EvpKeyDeleter>;
using BioPtr = std::unique_ptr<BIO, BioDeleter>;
using EvpMdCtxPtr = std::unique_ptr<EVP_MD_CTX, EvpMdCtxDeleter>;
bool VerifyRs256(const std::string& public_key_pem, const std::string& signing_input, const std::vector<std::uint8_t>& signature) {
BioPtr bio(BIO_new_mem_buf(public_key_pem.data(), static_cast<int>(public_key_pem.size())));
if (!bio) {
throw std::runtime_error("public_key_bio_unavailable");
}
EvpKeyPtr key(PEM_read_bio_PUBKEY(bio.get(), nullptr, nullptr, nullptr));
if (!key) {
throw std::runtime_error("public_key_parse_failed");
}
EvpMdCtxPtr context(EVP_MD_CTX_new());
if (!context) {
throw std::runtime_error("signature_context_unavailable");
}
if (EVP_DigestVerifyInit(context.get(), nullptr, EVP_sha256(), nullptr, key.get()) != 1) {
throw std::runtime_error("signature_verify_init_failed");
}
if (EVP_DigestVerifyUpdate(context.get(), signing_input.data(), signing_input.size()) != 1) {
throw std::runtime_error("signature_verify_update_failed");
}
return EVP_DigestVerifyFinal(context.get(), signature.data(), signature.size()) == 1;
}
std::int64_t UnixNow() {
return std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();
}
bool IsKnownChannel(const std::string& channel) {
return channel == "control" ||
channel == "input" ||
channel == "render" ||
channel == "clipboard" ||
channel == "file_upload" ||
channel == "file_download" ||
channel == "telemetry";
}
bool ArrayContainsString(const rdp_worker::common::JsonArray* array, const std::string& expected) {
if (array == nullptr) {
return false;
}
return std::any_of(array->begin(), array->end(), [&](const auto& item) {
return item.IsString() && item.AsString() == expected;
});
}
std::vector<std::string> ParseAllowedChannels(const rdp_worker::common::JsonObject& payload) {
const auto* array = rdp_worker::common::GetArray(payload, "allowed_channels");
if (array == nullptr || array->empty()) {
throw std::runtime_error("allowed_channels is required");
}
std::vector<std::string> channels;
for (const auto& item : *array) {
if (!item.IsString() || !IsKnownChannel(item.AsString())) {
throw std::runtime_error("allowed_channels contains unsupported channel");
}
channels.push_back(item.AsString());
}
return channels;
}
std::string RequiredString(const rdp_worker::common::JsonObject& object, const std::string& key) {
const auto value = rdp_worker::common::GetString(object, key);
if (!value.has_value() || value->empty()) {
throw std::runtime_error(key + " is required");
}
return *value;
}
} // namespace
DataPlaneTokenValidator::DataPlaneTokenValidator(std::string public_key_pem, std::string expected_worker_id)
: public_key_pem_(std::move(public_key_pem)),
expected_worker_id_(std::move(expected_worker_id)) {}
TokenValidationResult DataPlaneTokenValidator::Validate(const std::string& token) const {
TokenValidationResult result{};
try {
if (public_key_pem_.empty()) {
result.reason = "token_public_key_not_configured";
return result;
}
const auto parts = SplitJwt(token);
if (parts.size() != 3 || parts[0].empty() || parts[1].empty() || parts[2].empty()) {
result.reason = "malformed_token";
return result;
}
const auto header = rdp_worker::common::ParseJson(DecodeStringSegment(parts[0])).AsObject();
if (rdp_worker::common::GetString(header, "alg").value_or("") != "RS256" ||
rdp_worker::common::GetString(header, "typ").value_or("JWT") != "JWT") {
result.reason = "unsupported_token_header";
return result;
}
const std::string signing_input = parts[0] + "." + parts[1];
const auto actual_signature = Base64UrlDecode(parts[2]);
if (!VerifyRs256(public_key_pem_, signing_input, actual_signature)) {
result.reason = "invalid_signature";
return result;
}
const auto payload = rdp_worker::common::ParseJson(DecodeStringSegment(parts[1])).AsObject();
DataPlaneTokenClaims claims{};
claims.session_id = RequiredString(payload, "session_id");
claims.attachment_id = RequiredString(payload, "attachment_id");
claims.user_id = RequiredString(payload, "user_id");
claims.organization_id = RequiredString(payload, "organization_id");
claims.worker_id = RequiredString(payload, "worker_id");
claims.resource_id = RequiredString(payload, "resource_id");
claims.jti = RequiredString(payload, "jti");
claims.allowed_channels = ParseAllowedChannels(payload);
claims.expires_at_unix = static_cast<std::int64_t>(rdp_worker::common::GetNumber(payload, "exp").value_or(0));
const auto now = UnixNow();
if (claims.expires_at_unix <= now) {
result.reason = "token_expired";
return result;
}
const auto not_before = static_cast<std::int64_t>(rdp_worker::common::GetNumber(payload, "nbf").value_or(0));
if (not_before > 0 && not_before > now) {
result.reason = "token_not_yet_valid";
return result;
}
if (!expected_worker_id_.empty() && claims.worker_id != expected_worker_id_) {
result.reason = "wrong_worker";
return result;
}
if (!ArrayContainsString(rdp_worker::common::GetArray(payload, "aud"), "rap-data-plane") ||
!ArrayContainsString(rdp_worker::common::GetArray(payload, "aud"), "worker:" + claims.worker_id)) {
result.reason = "invalid_audience";
return result;
}
if (std::find(claims.allowed_channels.begin(), claims.allowed_channels.end(), "control") == claims.allowed_channels.end()) {
result.reason = "control_channel_not_allowed";
return result;
}
result.ok = true;
result.claims = std::move(claims);
return result;
} catch (const std::exception& error) {
result.reason = error.what();
return result;
}
}
} // namespace rdp_worker::dataplane