Initial project snapshot
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,240 @@
|
||||
#include "rdp_worker/dataplane/token_validator.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <chrono>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <string_view>
|
||||
#include <utility>
|
||||
|
||||
#include <openssl/bio.h>
|
||||
#include <openssl/evp.h>
|
||||
#include <openssl/pem.h>
|
||||
|
||||
#include "rdp_worker/common/json.hpp"
|
||||
|
||||
namespace rdp_worker::dataplane {
|
||||
|
||||
namespace {
|
||||
|
||||
std::vector<std::string> SplitJwt(const std::string& token) {
|
||||
std::vector<std::string> parts;
|
||||
std::size_t start = 0;
|
||||
while (true) {
|
||||
const std::size_t dot = token.find('.', start);
|
||||
parts.push_back(token.substr(start, dot == std::string::npos ? std::string::npos : dot - start));
|
||||
if (dot == std::string::npos) {
|
||||
break;
|
||||
}
|
||||
start = dot + 1;
|
||||
}
|
||||
return parts;
|
||||
}
|
||||
|
||||
std::vector<std::uint8_t> Base64UrlDecode(const std::string& input) {
|
||||
static constexpr unsigned char kInvalid = 255;
|
||||
std::array<unsigned char, 256> table{};
|
||||
table.fill(kInvalid);
|
||||
const std::string alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
|
||||
for (std::size_t i = 0; i < alphabet.size(); ++i) {
|
||||
table[static_cast<unsigned char>(alphabet[i])] = static_cast<unsigned char>(i);
|
||||
}
|
||||
|
||||
std::vector<std::uint8_t> out;
|
||||
int value = 0;
|
||||
int value_bits = -8;
|
||||
for (unsigned char ch : input) {
|
||||
if (ch == '=') {
|
||||
break;
|
||||
}
|
||||
if (table[ch] == kInvalid) {
|
||||
throw std::runtime_error("invalid base64url token segment");
|
||||
}
|
||||
value = (value << 6) + table[ch];
|
||||
value_bits += 6;
|
||||
if (value_bits >= 0) {
|
||||
out.push_back(static_cast<std::uint8_t>((value >> value_bits) & 0xFF));
|
||||
value_bits -= 8;
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
std::string DecodeStringSegment(const std::string& input) {
|
||||
const auto bytes = Base64UrlDecode(input);
|
||||
return std::string(reinterpret_cast<const char*>(bytes.data()), bytes.size());
|
||||
}
|
||||
|
||||
struct EvpKeyDeleter {
|
||||
void operator()(EVP_PKEY* key) const {
|
||||
EVP_PKEY_free(key);
|
||||
}
|
||||
};
|
||||
|
||||
struct BioDeleter {
|
||||
void operator()(BIO* bio) const {
|
||||
BIO_free(bio);
|
||||
}
|
||||
};
|
||||
|
||||
struct EvpMdCtxDeleter {
|
||||
void operator()(EVP_MD_CTX* context) const {
|
||||
EVP_MD_CTX_free(context);
|
||||
}
|
||||
};
|
||||
|
||||
using EvpKeyPtr = std::unique_ptr<EVP_PKEY, EvpKeyDeleter>;
|
||||
using BioPtr = std::unique_ptr<BIO, BioDeleter>;
|
||||
using EvpMdCtxPtr = std::unique_ptr<EVP_MD_CTX, EvpMdCtxDeleter>;
|
||||
|
||||
bool VerifyRs256(const std::string& public_key_pem, const std::string& signing_input, const std::vector<std::uint8_t>& signature) {
|
||||
BioPtr bio(BIO_new_mem_buf(public_key_pem.data(), static_cast<int>(public_key_pem.size())));
|
||||
if (!bio) {
|
||||
throw std::runtime_error("public_key_bio_unavailable");
|
||||
}
|
||||
EvpKeyPtr key(PEM_read_bio_PUBKEY(bio.get(), nullptr, nullptr, nullptr));
|
||||
if (!key) {
|
||||
throw std::runtime_error("public_key_parse_failed");
|
||||
}
|
||||
EvpMdCtxPtr context(EVP_MD_CTX_new());
|
||||
if (!context) {
|
||||
throw std::runtime_error("signature_context_unavailable");
|
||||
}
|
||||
if (EVP_DigestVerifyInit(context.get(), nullptr, EVP_sha256(), nullptr, key.get()) != 1) {
|
||||
throw std::runtime_error("signature_verify_init_failed");
|
||||
}
|
||||
if (EVP_DigestVerifyUpdate(context.get(), signing_input.data(), signing_input.size()) != 1) {
|
||||
throw std::runtime_error("signature_verify_update_failed");
|
||||
}
|
||||
return EVP_DigestVerifyFinal(context.get(), signature.data(), signature.size()) == 1;
|
||||
}
|
||||
|
||||
std::int64_t UnixNow() {
|
||||
return std::chrono::duration_cast<std::chrono::seconds>(
|
||||
std::chrono::system_clock::now().time_since_epoch())
|
||||
.count();
|
||||
}
|
||||
|
||||
bool IsKnownChannel(const std::string& channel) {
|
||||
return channel == "control" ||
|
||||
channel == "input" ||
|
||||
channel == "render" ||
|
||||
channel == "clipboard" ||
|
||||
channel == "file_upload" ||
|
||||
channel == "file_download" ||
|
||||
channel == "telemetry";
|
||||
}
|
||||
|
||||
bool ArrayContainsString(const rdp_worker::common::JsonArray* array, const std::string& expected) {
|
||||
if (array == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return std::any_of(array->begin(), array->end(), [&](const auto& item) {
|
||||
return item.IsString() && item.AsString() == expected;
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<std::string> ParseAllowedChannels(const rdp_worker::common::JsonObject& payload) {
|
||||
const auto* array = rdp_worker::common::GetArray(payload, "allowed_channels");
|
||||
if (array == nullptr || array->empty()) {
|
||||
throw std::runtime_error("allowed_channels is required");
|
||||
}
|
||||
std::vector<std::string> channels;
|
||||
for (const auto& item : *array) {
|
||||
if (!item.IsString() || !IsKnownChannel(item.AsString())) {
|
||||
throw std::runtime_error("allowed_channels contains unsupported channel");
|
||||
}
|
||||
channels.push_back(item.AsString());
|
||||
}
|
||||
return channels;
|
||||
}
|
||||
|
||||
std::string RequiredString(const rdp_worker::common::JsonObject& object, const std::string& key) {
|
||||
const auto value = rdp_worker::common::GetString(object, key);
|
||||
if (!value.has_value() || value->empty()) {
|
||||
throw std::runtime_error(key + " is required");
|
||||
}
|
||||
return *value;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
DataPlaneTokenValidator::DataPlaneTokenValidator(std::string public_key_pem, std::string expected_worker_id)
|
||||
: public_key_pem_(std::move(public_key_pem)),
|
||||
expected_worker_id_(std::move(expected_worker_id)) {}
|
||||
|
||||
TokenValidationResult DataPlaneTokenValidator::Validate(const std::string& token) const {
|
||||
TokenValidationResult result{};
|
||||
try {
|
||||
if (public_key_pem_.empty()) {
|
||||
result.reason = "token_public_key_not_configured";
|
||||
return result;
|
||||
}
|
||||
const auto parts = SplitJwt(token);
|
||||
if (parts.size() != 3 || parts[0].empty() || parts[1].empty() || parts[2].empty()) {
|
||||
result.reason = "malformed_token";
|
||||
return result;
|
||||
}
|
||||
|
||||
const auto header = rdp_worker::common::ParseJson(DecodeStringSegment(parts[0])).AsObject();
|
||||
if (rdp_worker::common::GetString(header, "alg").value_or("") != "RS256" ||
|
||||
rdp_worker::common::GetString(header, "typ").value_or("JWT") != "JWT") {
|
||||
result.reason = "unsupported_token_header";
|
||||
return result;
|
||||
}
|
||||
|
||||
const std::string signing_input = parts[0] + "." + parts[1];
|
||||
const auto actual_signature = Base64UrlDecode(parts[2]);
|
||||
if (!VerifyRs256(public_key_pem_, signing_input, actual_signature)) {
|
||||
result.reason = "invalid_signature";
|
||||
return result;
|
||||
}
|
||||
|
||||
const auto payload = rdp_worker::common::ParseJson(DecodeStringSegment(parts[1])).AsObject();
|
||||
DataPlaneTokenClaims claims{};
|
||||
claims.session_id = RequiredString(payload, "session_id");
|
||||
claims.attachment_id = RequiredString(payload, "attachment_id");
|
||||
claims.user_id = RequiredString(payload, "user_id");
|
||||
claims.organization_id = RequiredString(payload, "organization_id");
|
||||
claims.worker_id = RequiredString(payload, "worker_id");
|
||||
claims.resource_id = RequiredString(payload, "resource_id");
|
||||
claims.jti = RequiredString(payload, "jti");
|
||||
claims.allowed_channels = ParseAllowedChannels(payload);
|
||||
claims.expires_at_unix = static_cast<std::int64_t>(rdp_worker::common::GetNumber(payload, "exp").value_or(0));
|
||||
|
||||
const auto now = UnixNow();
|
||||
if (claims.expires_at_unix <= now) {
|
||||
result.reason = "token_expired";
|
||||
return result;
|
||||
}
|
||||
const auto not_before = static_cast<std::int64_t>(rdp_worker::common::GetNumber(payload, "nbf").value_or(0));
|
||||
if (not_before > 0 && not_before > now) {
|
||||
result.reason = "token_not_yet_valid";
|
||||
return result;
|
||||
}
|
||||
if (!expected_worker_id_.empty() && claims.worker_id != expected_worker_id_) {
|
||||
result.reason = "wrong_worker";
|
||||
return result;
|
||||
}
|
||||
if (!ArrayContainsString(rdp_worker::common::GetArray(payload, "aud"), "rap-data-plane") ||
|
||||
!ArrayContainsString(rdp_worker::common::GetArray(payload, "aud"), "worker:" + claims.worker_id)) {
|
||||
result.reason = "invalid_audience";
|
||||
return result;
|
||||
}
|
||||
if (std::find(claims.allowed_channels.begin(), claims.allowed_channels.end(), "control") == claims.allowed_channels.end()) {
|
||||
result.reason = "control_channel_not_allowed";
|
||||
return result;
|
||||
}
|
||||
|
||||
result.ok = true;
|
||||
result.claims = std::move(claims);
|
||||
return result;
|
||||
} catch (const std::exception& error) {
|
||||
result.reason = error.what();
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace rdp_worker::dataplane
|
||||
Reference in New Issue
Block a user