Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 588aa95

Browse files
fix: handle preflight requests (#2175)
* fix: handle options preflight * fix: better handler * fix: better comparision for two url paths --------- Co-authored-by: sangjanai <sang@jan.ai>
1 parent b0bf02b commit 588aa95

File tree

3 files changed

+158
-32
lines changed

3 files changed

+158
-32
lines changed

engine/main.cc

Lines changed: 83 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ void RunServer(std::optional<std::string> host, std::optional<int> port,
6464
bool ignore_cout) {
6565
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
6666
auto signal_handler = +[](int sig) -> void {
67-
std::cout << "\rCaught interrupt signal:" << sig << ", shutting down\n";;
67+
std::cout << "\rCaught interrupt signal:" << sig << ", shutting down\n";
6868
shutdown_signal = true;
6969
};
7070
signal(SIGINT, signal_handler);
@@ -288,54 +288,105 @@ void RunServer(std::optional<std::string> host, std::optional<int> port,
288288
return false;
289289
};
290290

291+
auto handle_cors = [config_service](const drogon::HttpRequestPtr& req,
292+
const drogon::HttpResponsePtr& resp) {
293+
const std::string& origin = req->getHeader("Origin");
294+
CTL_INF("Origin: " << origin);
295+
296+
auto allowed_origins =
297+
config_service->GetApiServerConfiguration()->allowed_origins;
298+
299+
auto is_contains_asterisk =
300+
std::find(allowed_origins.begin(), allowed_origins.end(), "*");
301+
if (is_contains_asterisk != allowed_origins.end()) {
302+
resp->addHeader("Access-Control-Allow-Origin", "*");
303+
resp->addHeader("Access-Control-Allow-Methods", "*");
304+
return;
305+
}
306+
307+
// Check if the origin is in our allowed list
308+
auto it = std::find(allowed_origins.begin(), allowed_origins.end(), origin);
309+
if (it != allowed_origins.end()) {
310+
resp->addHeader("Access-Control-Allow-Origin", origin);
311+
} else if (allowed_origins.empty()) {
312+
resp->addHeader("Access-Control-Allow-Origin", "*");
313+
}
314+
resp->addHeader("Access-Control-Allow-Methods", "*");
315+
};
316+
291317
drogon::app().registerPreRoutingAdvice(
292-
[&validate_api_key](
318+
[&validate_api_key, &handle_cors](
293319
const drogon::HttpRequestPtr& req,
294-
std::function<void(const drogon::HttpResponsePtr&)>&& cb,
295-
drogon::AdviceChainCallback&& ccb) {
320+
std::function<void(const drogon::HttpResponsePtr&)>&& stop,
321+
drogon::AdviceChainCallback&& pass) {
322+
// Handle OPTIONS preflight requests
323+
if (req->method() == drogon::HttpMethod::Options) {
324+
auto resp = HttpResponse::newHttpResponse();
325+
auto handlers = drogon::app().getHandlersInfo();
326+
bool has_ep = [req, &handlers]() {
327+
for (auto const& h : handlers) {
328+
if (string_utils::AreUrlPathsEqual(req->path(), std::get<0>(h)))
329+
return true;
330+
}
331+
return false;
332+
}();
333+
if (!has_ep) {
334+
resp->setStatusCode(drogon::HttpStatusCode::k404NotFound);
335+
stop(resp);
336+
return;
337+
}
338+
339+
handle_cors(req, resp);
340+
std::string supported_methods = [req, &handlers]() {
341+
std::string methods;
342+
for (auto const& h : handlers) {
343+
if (string_utils::AreUrlPathsEqual(req->path(), std::get<0>(h))) {
344+
auto m = drogon::to_string_view(std::get<1>(h));
345+
if (methods.find(m) == std::string::npos) {
346+
methods += drogon::to_string_view(std::get<1>(h));
347+
methods += ", ";
348+
}
349+
}
350+
}
351+
if (methods.size() < 2)
352+
return std::string();
353+
return methods.substr(0, methods.size() - 2);
354+
}();
355+
356+
// Add more info to header
357+
resp->addHeader("Access-Control-Allow-Methods", supported_methods);
358+
{
359+
const auto& val = req->getHeader("Access-Control-Request-Headers");
360+
if (!val.empty())
361+
resp->addHeader("Access-Control-Allow-Headers", val);
362+
}
363+
// Set Access-Control-Max-Age
364+
resp->addHeader("Access-Control-Max-Age",
365+
"600"); // Cache for 10 minutes
366+
stop(resp);
367+
return;
368+
}
369+
296370
if (!validate_api_key(req)) {
297371
Json::Value ret;
298372
ret["message"] = "Invalid API Key";
299373
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
300374
resp->setStatusCode(drogon::k401Unauthorized);
301-
cb(resp);
375+
stop(resp);
302376
return;
303377
}
304-
ccb();
378+
pass();
305379
});
306380

307381
// CORS
308382
drogon::app().registerPostHandlingAdvice(
309-
[config_service](const drogon::HttpRequestPtr& req,
310-
const drogon::HttpResponsePtr& resp) {
383+
[config_service, &handle_cors](const drogon::HttpRequestPtr& req,
384+
const drogon::HttpResponsePtr& resp) {
311385
if (!config_service->GetApiServerConfiguration()->cors) {
312386
CTL_INF("CORS is disabled!");
313387
return;
314388
}
315-
316-
const std::string& origin = req->getHeader("Origin");
317-
CTL_INF("Origin: " << origin);
318-
319-
auto allowed_origins =
320-
config_service->GetApiServerConfiguration()->allowed_origins;
321-
322-
auto is_contains_asterisk =
323-
std::find(allowed_origins.begin(), allowed_origins.end(), "*");
324-
if (is_contains_asterisk != allowed_origins.end()) {
325-
resp->addHeader("Access-Control-Allow-Origin", "*");
326-
resp->addHeader("Access-Control-Allow-Methods", "*");
327-
return;
328-
}
329-
330-
// Check if the origin is in our allowed list
331-
auto it =
332-
std::find(allowed_origins.begin(), allowed_origins.end(), origin);
333-
if (it != allowed_origins.end()) {
334-
resp->addHeader("Access-Control-Allow-Origin", origin);
335-
} else if (allowed_origins.empty()) {
336-
resp->addHeader("Access-Control-Allow-Origin", "*");
337-
}
338-
resp->addHeader("Access-Control-Allow-Methods", "*");
389+
handle_cors(req, resp);
339390
});
340391

341392
// ssl

engine/test/components/test_string_utils.cc

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,3 +289,50 @@ TEST_F(StringUtilsTestSuite, LargeInputPerformance) {
289289
}
290290

291291

292+
TEST_F(StringUtilsTestSuite, UrlPaths_SimilarStrings) {
293+
std::string str1 = "/v1/threads/{1}/messages/{2}";
294+
std::string str2 = "/v1/threads/xxx/messages/yyy";
295+
EXPECT_TRUE( AreUrlPathsEqual(str1, str2));
296+
}
297+
298+
TEST_F(StringUtilsTestSuite, UrlPaths_DifferentPaths) {
299+
std::string str1 = "/v1/threads/{1}/messages/{2}";
300+
std::string str2 = "/v1/threads/xxx/messages/yyy/extra";
301+
EXPECT_FALSE(AreUrlPathsEqual(str1, str2));
302+
}
303+
304+
TEST_F(StringUtilsTestSuite, UrlPaths_DifferentPlaceholderCounts) {
305+
std::string str1 = "/v1/threads/{1}/messages/{2}";
306+
std::string str2 = "/v1/threads/{1}/messages/{2}/{3}";
307+
EXPECT_FALSE(AreUrlPathsEqual(str1, str2));
308+
}
309+
310+
TEST_F(StringUtilsTestSuite, UrlPaths_NoPlaceholders) {
311+
std::string str1 = "/v1/threads/1/messages/2";
312+
std::string str2 = "/v1/threads/xxx/messages/yyy";
313+
EXPECT_FALSE(AreUrlPathsEqual(str1, str2));
314+
}
315+
316+
TEST_F(StringUtilsTestSuite, UrlPaths_EmptyStrings) {
317+
std::string str1 = "";
318+
std::string str2 = "";
319+
EXPECT_TRUE(AreUrlPathsEqual(str1, str2));
320+
}
321+
322+
TEST_F(StringUtilsTestSuite, UrlPaths_SinglePlaceholder) {
323+
std::string str1 = "/v1/threads/{1}";
324+
std::string str2 = "/v1/threads/xxx";
325+
EXPECT_TRUE(AreUrlPathsEqual(str1, str2));
326+
}
327+
328+
TEST_F(StringUtilsTestSuite, UrlPaths_MultiplePlaceholdersSameFormat) {
329+
std::string str1 = "/v1/threads/{1}/messages/{2}/comments/{3}";
330+
std::string str2 = "/v1/threads/xxx/messages/yyy/comments/zzz";
331+
EXPECT_TRUE(AreUrlPathsEqual(str1, str2));
332+
}
333+
334+
TEST_F(StringUtilsTestSuite, UrlPaths_NonPlaceholderDifferences) {
335+
std::string str1 = "/v1/threads/{1}/messages/{2}";
336+
std::string str2 = "/v2/threads/xxx/messages/yyy";
337+
EXPECT_FALSE(AreUrlPathsEqual(str1, str2));
338+
}

engine/utils/string_utils.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <cctype>
55
#include <chrono>
66
#include <iomanip>
7+
#include <regex>
78
#include <sstream>
89
#include <string>
910
#include <vector>
@@ -200,4 +201,31 @@ inline std::string EscapeJson(const std::string& s) {
200201
}
201202
return o.str();
202203
}
204+
205+
// Add a method to compares two url paths
206+
inline bool AreUrlPathsEqual(const std::string& path1,
207+
const std::string& path2) {
208+
auto has_placeholder = [](const std::string& s) {
209+
if (s.empty())
210+
return false;
211+
return s.find_first_of('{') < s.find_last_of('}');
212+
};
213+
std::vector<std::string> parts1 = SplitBy(path1, "/");
214+
std::vector<std::string> parts2 = SplitBy(path2, "/");
215+
216+
// Check if both strings have the same number of parts
217+
if (parts1.size() != parts2.size()) {
218+
return false;
219+
}
220+
221+
for (size_t i = 0; i < parts1.size(); ++i) {
222+
if (has_placeholder(parts1[i]) || has_placeholder(parts2[i]))
223+
continue;
224+
if (parts1[i] != parts2[i]) {
225+
return false;
226+
}
227+
}
228+
229+
return true;
230+
}
203231
} // namespace string_utils

0 commit comments

Comments
 (0)