Skip to content

Commit b0b7a28

Browse files
authored
OIDC proxy: Handle response with bad request from protected resource (#8105)
1 parent fd24b0a commit b0b7a28

File tree

4 files changed

+58
-14
lines changed

4 files changed

+58
-14
lines changed

ydb/mvp/oidc_proxy/oidc_protected_page.h

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,9 @@ class THandlerSessionServiceCheck : public NActors::TActorBootstrapped<THandlerS
7373
if (event->Get()->Response != nullptr) {
7474
NHttp::THttpIncomingResponsePtr response = event->Get()->Response;
7575
LOG_DEBUG_S(ctx, EService::MVP, "Incoming response for protected resource: " << response->Status);
76-
if ((response->Status == "400" || response->Status.empty()) && RequestedPageScheme.empty()) {
77-
NHttp::THttpOutgoingRequestPtr request = response->GetRequest();
78-
if (!request->Secure) {
79-
LOG_DEBUG_S(ctx, EService::MVP, "Try to send request to HTTPS port");
80-
NHttp::THeadersBuilder headers {request->Headers};
81-
ForwardUserRequest(headers.Get(AUTH_HEADER_NAME), ctx, true);
82-
return;
83-
}
76+
if (NeedSendSecureHttpRequest(response)) {
77+
SendSecureHttpRequest(response, ctx);
78+
return;
8479
}
8580
NHttp::THeadersBuilder headers = GetResponseHeaders(response);
8681
TStringBuf contentType = headers.Get("Content-Type").NextTok(';');
@@ -117,7 +112,7 @@ class THandlerSessionServiceCheck : public NActors::TActorBootstrapped<THandlerS
117112
return it != Settings.AllowedProxyHosts.cend();
118113
}
119114

120-
bool IsAuthorizedRequest(TStringBuf authHeader) {
115+
static bool IsAuthorizedRequest(TStringBuf authHeader) {
121116
if (authHeader.empty()) {
122117
return false;
123118
}
@@ -142,7 +137,9 @@ class THandlerSessionServiceCheck : public NActors::TActorBootstrapped<THandlerS
142137
ctx.Send(HttpProxyId, new NHttp::TEvHttpProxy::TEvHttpOutgoingRequest(httpRequest));
143138
}
144139

145-
TString FixReferenceInHtml(TStringBuf html, TStringBuf host, TStringBuf findStr) {
140+
virtual bool NeedSendSecureHttpRequest(const NHttp::THttpIncomingResponsePtr& response) const = 0;
141+
142+
static TString FixReferenceInHtml(TStringBuf html, TStringBuf host, TStringBuf findStr) {
146143
TStringBuilder result;
147144
size_t n = html.find(findStr);
148145
if (n == TStringBuf::npos) {
@@ -166,14 +163,14 @@ class THandlerSessionServiceCheck : public NActors::TActorBootstrapped<THandlerS
166163
return result;
167164
}
168165

169-
TString FixReferenceInHtml(TStringBuf html, TStringBuf host) {
166+
static TString FixReferenceInHtml(TStringBuf html, TStringBuf host) {
170167
TStringBuf findString = "href=";
171168
auto result = FixReferenceInHtml(html, host, findString);
172169
findString = "src=";
173170
return FixReferenceInHtml(result, host, findString);
174171
}
175172

176-
void ForwardRequestHeaders(NHttp::THttpOutgoingRequestPtr& request) {
173+
void ForwardRequestHeaders(NHttp::THttpOutgoingRequestPtr& request) const {
177174
static const TVector<TStringBuf> HEADERS_WHITE_LIST = {
178175
"Connection",
179176
"Accept-Language",
@@ -195,7 +192,7 @@ class THandlerSessionServiceCheck : public NActors::TActorBootstrapped<THandlerS
195192
request->Set("Accept-Encoding", "deflate");
196193
}
197194

198-
NHttp::THeadersBuilder GetResponseHeaders(const NHttp::THttpIncomingResponsePtr& response) {
195+
static NHttp::THeadersBuilder GetResponseHeaders(const NHttp::THttpIncomingResponsePtr& response) {
199196
static const TVector<TStringBuf> HEADERS_WHITE_LIST = {
200197
"Content-Type",
201198
"Connection",
@@ -216,6 +213,14 @@ class THandlerSessionServiceCheck : public NActors::TActorBootstrapped<THandlerS
216213
}
217214
return resultHeaders;
218215
}
216+
217+
private:
218+
void SendSecureHttpRequest(const NHttp::THttpIncomingResponsePtr& response, const NActors::TActorContext& ctx) {
219+
NHttp::THttpOutgoingRequestPtr request = response->GetRequest();
220+
LOG_DEBUG_S(ctx, EService::MVP, "Try to send request to HTTPS port");
221+
NHttp::THeadersBuilder headers {request->Headers};
222+
ForwardUserRequest(headers.Get(AUTH_HEADER_NAME), ctx, true);
223+
}
219224
};
220225

221226
} // NMVP

ydb/mvp/oidc_proxy/oidc_protected_page_nebius.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,13 @@ class THandlerSessionServiceCheckNebius : public THandlerSessionServiceCheck {
121121
THandlerSessionServiceCheck::ForwardUserRequest(authHeader, ctx, secure);
122122
Become(&THandlerSessionServiceCheckNebius::StateWork);
123123
}
124+
125+
bool NeedSendSecureHttpRequest(const NHttp::THttpIncomingResponsePtr& response) const override {
126+
if ((response->Status == "400" || response->Status.empty()) && RequestedPageScheme.empty()) {
127+
return !response->GetRequest()->Secure;
128+
}
129+
return false;
130+
}
124131
};
125132

126133
} // NMVP

ydb/mvp/oidc_proxy/oidc_protected_page_yandex.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,20 @@ class THandlerSessionServiceCheckYandex : public THandlerSessionServiceCheck {
8080
meta.Timeout = TDuration::Seconds(10);
8181
connection->DoRequest(request, std::move(responseCb), &yandex::cloud::priv::oauth::v1::SessionService::Stub::AsyncCheck, meta);
8282
}
83+
84+
bool NeedSendSecureHttpRequest(const NHttp::THttpIncomingResponsePtr& response) const override {
85+
if ((response->Status == "400" || response->Status.empty()) && RequestedPageScheme.empty()) {
86+
NHttp::THttpOutgoingRequestPtr request = response->GetRequest();
87+
if (!request->Secure) {
88+
static const TStringBuf bodyContent = "The plain HTTP request was sent to HTTPS port";
89+
NHttp::THeadersBuilder headers(response->Headers);
90+
TStringBuf contentType = headers.Get("Content-Type").NextTok(';');
91+
TStringBuf body = response->Body;
92+
return contentType == "text/html" && body.find(bodyContent) != TStringBuf::npos;
93+
}
94+
}
95+
return false;
96+
}
8397
};
8498

8599
} // NMVP

ydb/mvp/oidc_proxy/oidc_proxy_ut.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,25 @@ Y_UNIT_TEST_SUITE(Mvp) {
231231

232232
auto outgoingResponseEv = runtime.GrabEdgeEvent<NHttp::TEvHttpProxy::TEvHttpOutgoingResponse>(handle);
233233
UNIT_ASSERT_STRINGS_EQUAL(outgoingResponseEv->Response->Status, "200");
234-
UNIT_ASSERT_STRINGS_EQUAL(outgoingResponseEv->Response->Body, "this is test");
234+
UNIT_ASSERT_STRINGS_EQUAL(outgoingResponseEv->Response->Body, okResponseBody);
235+
236+
runtime.Send(new IEventHandle(target, edge, new NHttp::TEvHttpProxy::TEvHttpIncomingRequest(incomingRequest)));
237+
outgoingRequestEv = runtime.GrabEdgeEvent<NHttp::TEvHttpProxy::TEvHttpOutgoingRequest>(handle);
238+
UNIT_ASSERT_STRINGS_EQUAL(outgoingRequestEv->Request->Host, allowedProxyHost);
239+
UNIT_ASSERT_STRINGS_EQUAL(outgoingRequestEv->Request->URL, "/counters");
240+
UNIT_ASSERT_STRING_CONTAINS(outgoingRequestEv->Request->Headers, "Authorization: Bearer protected_page_iam_token");
241+
UNIT_ASSERT_EQUAL(outgoingRequestEv->Request->Secure, false);
242+
incomingResponse = new NHttp::THttpIncomingResponse(outgoingRequestEv->Request);
243+
const TString errorJsonResponseBody {"{\"status\":\"400\", \"message\":\"Table does not exist\"}"};
244+
EatWholeString(incomingResponse, "HTTP/1.1 400 Bad Request\r\n"
245+
"Connection: close\r\n"
246+
"Content-Type: application/json; charset=utf-8\r\n"
247+
"Content-Length: " + ToString(errorJsonResponseBody.size()) + "\r\n\r\n" + errorJsonResponseBody);
248+
runtime.Send(new IEventHandle(handle->Sender, edge, new NHttp::TEvHttpProxy::TEvHttpIncomingResponse(outgoingRequestEv->Request, incomingResponse)));
249+
250+
outgoingResponseEv = runtime.GrabEdgeEvent<NHttp::TEvHttpProxy::TEvHttpOutgoingResponse>(handle);
251+
UNIT_ASSERT_STRINGS_EQUAL(outgoingResponseEv->Response->Status, "400");
252+
UNIT_ASSERT_STRINGS_EQUAL(outgoingResponseEv->Response->Body, errorJsonResponseBody);
235253
}
236254

237255

0 commit comments

Comments
 (0)