8000 Improve endpoint handling. by ObiWahn · Pull Request #11236 · arangodb/arangodb · GitHub
[go: up one dir, main page]

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

8000
Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions arangosh/Shell/V8ClientConnection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -510,9 +510,10 @@ static void ClientConnection_reconnect(v8::FunctionCallbackInfo<v8::Value> const
}

V8SecurityFeature& v8security = v8connection->server().getFeature<V8SecurityFeature>();
if (!v8security.isAllowedToConnectToEndpoint(isolate, endpoint)) {
if (!v8security.isAllowedToConnectToEndpoint(isolate, endpoint, endpoint)) {
TRI_V8_THROW_EXCEPTION_MESSAGE(TRI_ERROR_FORBIDDEN,
"not allowed to connect to this endpoint");
std::string("not allowed to connect to this endpoint") +
endpoint);
}

client->setEndpoint(endpoint);
Expand Down Expand Up @@ -1933,7 +1934,7 @@ v8::Local<v8::Value> V8ClientConnection::requestDataRaw(
TRI_V8_ASCII_STRING(isolate, "body"),
bufObj).FromMaybe(false);
}

for (auto const& it : response->header.meta()) {
headers->Set(context,
TRI_V8_STD_STRING(isolate, it.first),
Expand All @@ -1947,7 +1948,7 @@ v8::Local<v8::Value> V8ClientConnection::requestDataRaw(
headers->Set(context,
TRI_V8_STD_STRING(isolate, StaticStrings::ContentLength),
TRI_V8_STD_STRING(isolate, std::to_string(responseBody.size()))).FromMaybe(false);

}

result->Set(context,
Expand Down
41 changes: 28 additions & 13 deletions lib/ApplicationFeatures/V8SecurityFeature.cpp
10BC0
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,29 @@ void convertToSingleExpression(std::unordered_set<std::string> const& values,
targetRegex = ss.str();
}

bool checkBlackAndWhitelist(std::string const& value, bool hasWhitelist,
struct checkBlackWhiteResult {
bool result;
bool white;
bool black;
};

checkBlackWhiteResult checkBlackAndWhitelist(std::string const& value, bool hasWhitelist,
std::regex const& whitelist, bool hasBlacklist,
std::regex const& blacklist) {
if (!hasWhitelist && !hasBlacklist) {
return true;
return {true, false, false};
}

if (!hasBlacklist) {
// only have a whitelist
return std::regex_search(value, whitelist);
bool white = std::regex_search(value, whitelist);
return {white, white, false};
}

if (!hasWhitelist) {
// only have a blacklist
return !std::regex_search(value, blacklist);
bool black = std::regex_search(value, blacklist);
return {!black, false, black};
}

std::smatch white_result{};
Expand All @@ -145,17 +153,18 @@ bool checkBlackAndWhitelist(std::string const& value, bool hasWhitelist,

if (white && !black) {
// we only have a whitelist hit => allow
return true;
return {true, white, black};
} else if (!white && black) {
// we only have a blacklist hit => deny
return false;
return {false, white, black};
} else if (!white && !black) {
// we have neither a whitelist nor a blacklist hit => deny
return false;
return {false, white, black};
}

// longer match or blacklist wins
return white_result[0].length() > black_result[0].length();
bool white_longer_black = white_result[0].length() > black_result[0].length();
return {white_longer_black, white_longer_black, !white_longer_black};
}
} // namespace

Expand Down Expand Up @@ -410,19 +419,20 @@ bool V8SecurityFeature::shouldExposeStartupOption(v8::Isolate* isolate,
return checkBlackAndWhitelist(name, !_startupOptionsWhitelist.empty(),
_startupOptionsWhitelistRegex,
!_startupOptionsBlacklist.empty(),
_startupOptionsBlacklistRegex);
_startupOptionsBlacklistRegex).result;
}

bool V8SecurityFeature::shouldExposeEnvironmentVariable(v8::Isolate* isolate,
std::string const& name) const {
return checkBlackAndWhitelist(name, !_environmentVariablesWhitelist.empty(),
_environmentVariablesWhitelistRegex,
!_environmentVariablesBlacklist.empty(),
_environmentVariablesBlacklistRegex);
_environmentVariablesBlacklistRegex).result;
}

bool V8SecurityFeature::isAllowedToConnectToEndpoint(v8::Isolate* isolate,
std::string const& name) const {
std::string const& endpoint,
std::string const& url) const {
TRI_GET_GLOBALS();
TRI_ASSERT(v8g != nullptr);
if (v8g->_securityContext.isInternal()) {
Expand All @@ -431,8 +441,13 @@ bool V8SecurityFeature::isAllowedToConnectToEndpoint(v8::Isolate* isolate,
return true;
}

return checkBlackAndWhitelist(name, !_endpointsWhitelist.empty(), _endpointsWhitelistRegex,
auto endpointResult = checkBlackAndWhitelist(endpoint, !_endpointsWhitelist.empty(), _endpointsWhitelistRegex,
!_endpointsBlacklist.empty(), _endpointsBlacklistRegex);

auto urlResult = checkBlackAndWhitelist(url, !_endpointsWhitelist.empty(), _endpointsWhitelistRegex,
!_endpointsBlacklist.empty(), _endpointsBlacklistRegex);

return endpointResult.result || ( urlResult.result && !endpointResult.black);
}

bool V8SecurityFeature::isAllowedToAccessPath(v8::Isolate* isolate, std::string const& path,
Expand Down Expand Up @@ -478,5 +493,5 @@ bool V8SecurityFeature::isAllowedToAccessPath(v8::Isolate* isolate, char const*
}

return checkBlackAndWhitelist(path, !_filesWhitelist.empty(), _filesWhitelistRegex,
false, _filesWhitelistRegex /*passed to match the signature but not used*/);
false, _filesWhitelistRegex /*passed to match the signature but not used*/).result;
}
2 changes: 1 addition & 1 deletion lib/ApplicationFeatures/V8SecurityFeature.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class V8SecurityFeature final : public application_features::ApplicationFeature
/// accessible via the JS_Download (internal.download) function in JavaScript
/// actions the endpoint is passed in via protocol (e.g. tcp://, ssl://,
/// unix://) and port number (if applicable)
bool isAllowedToConnectToEndpoint(v8::Isolate* isolate, std::string const& endpoint) const;
bool isAllowedToConnectToEndpoint(v8::Isolate* isolate, std::string const& endpoint, std::string const& url) const;

/// @brief tests if the path (or path component) shall be accessible for the
/// calling JavaScript code
Expand Down
204 changes: 109 additions & 95 deletions lib/V8/v8-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,105 @@ static std::string GetEndpointFromUrl(std::string const& url) {
/// `process-utils.js` depends on simple http client error messages.
/// this needs to be adjusted if this is ever changed!
////////////////////////////////////////////////////////////////////////////////
namespace {

auto getEndpoint(v8::Isolate* isolate, std::vector<std::string> const& endpoints,
std::string& url, std::string& lastEndpoint)
-> std::tuple<std::string, std::string, std::string> {
// returns endpoint, relative, error
std::string relative;
std::string endpoint;
if (url.substr(0, 7) == "http://") {
endpoint = GetEndpointFromUrl(url).substr(7);
relative = url.substr(7 + endpoint.length());

if (relative.empty() || relative[0] != '/') {
relative = "/" + relative;
}
if (endpoint.find(':') == std::string::npos) {
endpoint.append(":80");
}
endpoint = "tcp://" + endpoint;
} else if (url.substr(0, 8) == "https://") {
endpoint = GetEndpointFromUrl(url).substr(8);
relative = url.substr(8 + endpoint.length());

if (relative.empty() || relative[0] != '/') {
relative = "/" + relative;
}
if (endpoint.find(':') == std::string::npos) {
endpoint.append(":443");
}
endpoint = "ssl://" + endpoint;
} else if (url.substr(0, 5) == "h2://") {
endpoint = GetEndpointFromUrl(url).substr(5);
relative = url.substr(5 + endpoint.length());

if (relative.empty() || relative[0] != '/') {
relative = "/" + relative;
}
if (endpoint.find(':') == std::string::npos) {
endpoint.append(":80");
}
endpoint = "tcp://" + endpoint;
} else if (url.substr(0, 6) == "h2s://") {
endpoint = GetEndpointFromUrl(url).substr(6);
relative = url.substr(6 + endpoint.length());

if (relative.empty() || relative[0] != '/') {
relative = "/" + relative;
}
if (endpoint.find(':') == std::string::npos) {
endpoint.append(":443");
}
endpoint = "ssl://" + endpoint;
} else if (url.substr(0, 6) == "srv://") {
size_t found = url.find('/', 6);

relative = "/";
if (found != std::string::npos) {
relative.append(url.substr(found + 1));
endpoint = url.substr(6, found - 6);
} else {
endpoint = url.substr(6);
}
endpoint = "srv://" + endpoint;
} else if (url.substr(0, 7) == "unix://") {
// Can only have arrived here if endpoints is non empty
if (endpoints.empty()) {
return {"", "", std::move("unsupported URL specified")};
}
endpoint = endpoints 292D .front();
relative = url.substr(endpoint.size());
} else if (!url.empty() && url[0] == '/') {
size_t found;
// relative URL. prefix it with last endpoint
relative = url;
url = lastEndpoint + url;
endpoint = lastEndpoint;
if (endpoint.substr(0, 5) == "http:") {
endpoint = endpoint.substr(5);
found = endpoint.find(":");
if (found == std::string::npos) {
endpoint = endpoint + ":80";
}
endpoint = "tcp:" + endpoint;
} else if (endpoint.substr(0, 6) == "https:") {
endpoint = endpoint.substr(6);
found = endpoint.find(":");
if (found == std::string::npos) {
endpoint = endpoint + ":443";
}
endpoint = "ssl:" + endpoint;
}
} else {
std::string msg("unsupported URL specified: '");
msg.append(url).append("'");
return {"", "", std::move(msg)};
}
return {std::move(endpoint), std::move(relative), ""};
}
} // namespace

void JS_Download(v8::FunctionCallbackInfo<v8::Value> const& args) {
TRI_V8_TRY_CATCH_BEGIN(isolate);
Expand All @@ -679,7 +778,8 @@ void JS_Download(v8::FunctionCallbackInfo<v8::Value> const& args) {
TRI_V8_THROW_EXCEPTION_USAGE(signature);
}

std::string url = TRI_ObjectToString(isolate, args[0]);
std::string const inputUrl = TRI_ObjectToString(isolate, args[0]);
std::string url = inputUrl;
std::vector<std::string> endpoints;

bool isLocalUrl = false;
Expand Down Expand Up @@ -892,104 +992,18 @@ void JS_Download(v8::FunctionCallbackInfo<v8::Value> const& args) {
int numRedirects = 0;

while (numRedirects < maxRedirects) {
std::string endpoint;
std::string relative;

if (url.substr(0, 7) == "http://") {
endpoint = GetEndpointFromUrl(url).substr(7);
relative = url.substr(7 + endpoint.length());

if (relative.empty() || relative[0] != '/') {
relative = "/" + relative;
}
if (endpoint.find(':') == std::string::npos) {
endpoint.append(":80");
}
endpoint = "tcp://" + endpoint;
} else if (url.substr(0, 8) == "https://") {
endpoint = GetEndpointFromUrl(url).substr(8);
relative = url.substr(8 + endpoint.length());

if (relative.empty() || relative[0] != '/') {
relative = "/" + relative;
}
if (endpoint.find(':') == std::string::npos) {
endpoint.append(":443");
}
endpoint = "ssl://" + endpoint;
} else if (url.substr(0, 5) == "h2://") {
endpoint = GetEndpointFromUrl(url).substr(5);
relative = url.substr(5 + endpoint.length());

if (relative.empty() || relative[0] != '/') {
relative = "/" + relative;
}
if (endpoint.find(':') == std::string::npos) {
endpoint.append(":80");
}
endpoint = "tcp://" + endpoint;
} else if (url.substr(0, 6) == "h2s://") {
endpoint = GetEndpointFromUrl(url).substr(6);
relative = url.substr(6 + endpoint.length());

if (relative.empty() || relative[0] != '/') {
relative = "/" + relative;
}
if (endpoint.find(':') == std::string::npos) {
endpoint.append(":80");
}
endpoint = "tcp://" + endpoint;
} else if (url.substr(0, 6) == "srv://") {
size_t found = url.find('/', 6);

relative = "/";
if (found != std::string::npos) {
relative.append(url.substr(found + 1));
endpoint = url.substr(6, found - 6);
} else {
endpoint = url.substr(6);
}
endpoint = "srv://" + endpoint;
} else if (url.substr(0, 7) == "unix://") {
// Can only have arrived here if endpoints is non empty
if (endpoints.empty()) {
TRI_V8_THROW_SYNTAX_ERROR("unsupported URL specified");
}
endpoint = endpoints.front();
relative = url.substr(endpoint.size());
} else if (!url.empty() && url[0] == '/') {
size_t found;
// relative URL. prefix it with last endpoint
relative = url;
url = lastEndpoint + url;
endpoint = lastEndpoint;
if (endpoint.substr(0, 5) == "http:") {
endpoint = endpoint.substr(5);
found = endpoint.find(":");
if (found == std::string::npos) {
endpoint = endpoint + ":80";
}
endpoint = "tcp:" + endpoint;
} else if (endpoint.substr(0, 6) == "https:") {
endpoint = endpoint.substr(6);
found = endpoint.find(":");
if (found == std::string::npos) {
endpoint = endpoint + ":443";
}
endpoint = "ssl:" + endpoint;
}
} else {
std::string msg("unsupported URL specified: '");
msg.append(url).append("'");
TRI_V8_THROW_ERROR(msg.c_str());
auto [endpoint, relative, error] = getEndpoint(isolate, endpoints, url, lastEndpoint);
if(!error.empty()) {
TRI_V8_THROW_SYNTAX_ERROR(error.c_str());
}

LOG_TOPIC("d6bdb", TRACE, arangodb::Logger::FIXME)
<< "downloading file. endpoint: " << endpoint << ", relative URL: " << url;
<< "downloading file. endpoint: " << endpoint << ", relative URL: " << url;

if (!isLocalUrl && !v8security.isAllowedToConnectToEndpoint(isolate, endpoint)) {
if (!isLocalUrl && !v8security.isAllowedToConnectToEndpoint(isolate, endpoint, inputUrl)) {
TRI_V8_THROW_EXCEPTION_MESSAGE(TRI_ERROR_FORBIDDEN,
"not allowed to connect to this endpoint");
"not allowed to connect to this URL: " + inputUrl);
}

std::unique_ptr<Endpoint> ep(Endpoint::clientFactory(endpoint));
Expand All @@ -1000,8 +1014,8 @@ void JS_Download(v8::FunctionCallbackInfo<v8::Value> const& args) {
}

std::unique_ptr<GeneralClientConnection> connection(
GeneralClientConnection::factory(v8g->_server, ep.get(), timeout,
timeout, 3, sslProtocol));
GeneralClientConnection::factory(v8g->_server, ep.get(), timeout,
timeout, 3, sslProtocol));

if (connection == nullptr) {
TRI_V8_THROW_EXCEPTION_MEMORY();
Expand Down
0