Skip to content

Commit 142a9a9

Browse files
authored
[oidc proxy] Move definitions of functions to source files (#8724)
1 parent 3047938 commit 142a9a9

27 files changed

+965
-747
lines changed

ydb/mvp/oidc_proxy/bin/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
int main(int argc, char **argv) {
44
try {
5-
return NMVP::TMVP(argc, argv).Run();
5+
return NMVP::NOIDC::TMVP(argc, argv).Run();
66
} catch (const yexception& e) {
77
Cerr << "Caught exception: " << e.what() << Endl;
88
return 1;

ydb/mvp/oidc_proxy/mvp.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@
2626
#include "mvp.h"
2727
#include "oidc_client.h"
2828

29-
using namespace NMVP;
29+
NActors::IActor* CreateMemProfiler();
30+
31+
namespace NMVP {
32+
namespace NOIDC {
3033

3134
namespace {
3235

@@ -42,7 +45,7 @@ TString AddSchemeToUserToken(const TString& token, const TString& scheme) {
4245
const ui16 TMVP::DefaultHttpPort = 8788;
4346
const ui16 TMVP::DefaultHttpsPort = 8789;
4447

45-
const TString& NMVP::GetEServiceName(NActors::NLog::EComponent component) {
48+
const TString& GetEServiceName(NActors::NLog::EComponent component) {
4649
static const TString loggerName("LOGGER");
4750
static const TString mvpName("MVP");
4851
static const TString grpcName("GRPC");
@@ -66,8 +69,6 @@ void TMVP::OnTerminate(int) {
6669
AtomicSet(Quit, true);
6770
}
6871

69-
NActors::IActor* CreateMemProfiler();
70-
7172
int TMVP::Init() {
7273
ActorSystem.Start();
7374

@@ -415,3 +416,6 @@ THolder<NActors::TActorSystemSetup> TMVP::BuildActorSystemSetup(int argc, char**
415416
}
416417

417418
TAtomic TMVP::Quit = false;
419+
420+
} // NOIDC
421+
} // NMVP

ydb/mvp/oidc_proxy/mvp.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
#include <library/cpp/deprecated/atomic/atomic.h>
1111
#include <util/system/rwlock.h>
1212
#include <contrib/libs/yaml-cpp/include/yaml-cpp/yaml.h>
13-
#include "openid_connect.h"
13+
#include "oidc_settings.h"
1414

1515
namespace NMVP {
16+
namespace NOIDC {
1617

1718
const TString& GetEServiceName(NActors::NLog::EComponent component);
1819

@@ -71,4 +72,5 @@ class TMVP {
7172
int Shutdown();
7273
};
7374

75+
} // namespace NOIDC
7476
} // namespace NMVP

ydb/mvp/oidc_proxy/oidc_client.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,24 @@
22
#include "oidc_protected_page_handler.h"
33
#include "oidc_session_create_handler.h"
44

5+
namespace NMVP {
6+
namespace NOIDC {
7+
58
void InitOIDC(NActors::TActorSystem& actorSystem,
69
const NActors::TActorId& httpProxyId,
710
const TOpenIdConnectSettings& settings) {
811
actorSystem.Send(httpProxyId, new NHttp::TEvHttpProxy::TEvRegisterHandler(
912
"/auth/callback",
10-
actorSystem.Register(new NMVP::TSessionCreateHandler(httpProxyId, settings))
13+
actorSystem.Register(new TSessionCreateHandler(httpProxyId, settings))
1114
)
1215
);
1316

1417
actorSystem.Send(httpProxyId, new NHttp::TEvHttpProxy::TEvRegisterHandler(
1518
"/",
16-
actorSystem.Register(new NMVP::TProtectedPageHandler(httpProxyId, settings))
19+
actorSystem.Register(new TProtectedPageHandler(httpProxyId, settings))
1720
)
1821
);
1922
}
23+
24+
} // NOIDC
25+
} // NMVP

ydb/mvp/oidc_proxy/oidc_client.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
11
#pragma once
2+
namespace NActors {
23

3-
#include <ydb/mvp/core/core_ydb.h>
4-
#include "openid_connect.h"
4+
class TActorSystem;
5+
struct TActorId;
6+
7+
} // NActors
8+
namespace NMVP {
9+
namespace NOIDC {
10+
11+
struct TOpenIdConnectSettings;
512

613
void InitOIDC(NActors::TActorSystem& actorSystem, const NActors::TActorId& httpProxyId, const TOpenIdConnectSettings& settings);
14+
15+
} // NOIDC
16+
} // NMVP
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
#include <ydb/library/actors/core/actor.h>
2+
#include <ydb/library/actors/http/http.h>
3+
#include <ydb/mvp/core/mvp_log.h>
4+
#include <ydb/core/util/wildcard.h>
5+
#include "openid_connect.h"
6+
#include "oidc_protected_page.h"
7+
8+
namespace NMVP {
9+
namespace NOIDC {
10+
11+
THandlerSessionServiceCheck::THandlerSessionServiceCheck(const NActors::TActorId& sender,
12+
const NHttp::THttpIncomingRequestPtr& request,
13+
const NActors::TActorId& httpProxyId,
14+
const TOpenIdConnectSettings& settings)
15+
: Sender(sender)
16+
, Request(request)
17+
, HttpProxyId(httpProxyId)
18+
, Settings(settings)
19+
, ProtectedPageUrl(Request->URL.SubStr(1))
20+
{}
21+
22+
void THandlerSessionServiceCheck::Bootstrap(const NActors::TActorContext& ctx) {
23+
if (!CheckRequestedHost()) {
24+
ctx.Send(Sender, new NHttp::TEvHttpProxy::TEvHttpOutgoingResponse(CreateResponseForbiddenHost()));
25+
Die(ctx);
26+
return;
27+
}
28+
NHttp::THeaders headers(Request->Headers);
29+
IsAjaxRequest = DetectAjaxRequest(headers);
30+
TStringBuf authHeader = headers.Get(AUTH_HEADER_NAME);
31+
if (Request->Method == "OPTIONS" || IsAuthorizedRequest(authHeader)) {
32+
ForwardUserRequest(TString(authHeader), ctx);
33+
} else {
34+
StartOidcProcess(ctx);
35+
}
36+
}
37+
38+
void THandlerSessionServiceCheck::HandleProxy(NHttp::TEvHttpProxy::TEvHttpIncomingResponse::TPtr event, const NActors::TActorContext& ctx) {
39+
NHttp::THttpOutgoingResponsePtr httpResponse;
40+
if (event->Get()->Response != nullptr) {
41+
NHttp::THttpIncomingResponsePtr response = event->Get()->Response;
42+
LOG_DEBUG_S(ctx, EService::MVP, "Incoming response for protected resource: " << response->Status);
43+
if (NeedSendSecureHttpRequest(response)) {
44+
SendSecureHttpRequest(response, ctx);
45+
return;
46+
}
47+
NHttp::THeadersBuilder headers = GetResponseHeaders(response);
48+
TStringBuf contentType = headers.Get("Content-Type").NextTok(';');
49+
if (contentType == "text/html") {
50+
TString newBody = FixReferenceInHtml(response->Body, response->GetRequest()->Host);
51+
httpResponse = Request->CreateResponse( response->Status, response->Message, headers, newBody);
52+
} else {
53+
httpResponse = Request->CreateResponse( response->Status, response->Message, headers, response->Body);
54+
}
55+
} else {
56+
static constexpr size_t MAX_LOGGED_SIZE = 1024;
57+
LOG_DEBUG_S(ctx, EService::MVP, "Can not process request to protected resource:\n" << event->Get()->Request->GetRawData().substr(0, MAX_LOGGED_SIZE));
58+
httpResponse = CreateResponseForNotExistingResponseFromProtectedResource(event->Get()->GetError());
59+
}
60+
ctx.Send(Sender, new NHttp::TEvHttpProxy::TEvHttpOutgoingResponse(httpResponse));
61+
Die(ctx);
62+
}
63+
64+
bool THandlerSessionServiceCheck::CheckRequestedHost() {
65+
size_t pos = ProtectedPageUrl.find('/');
66+
if (pos == TString::npos) {
67+
return false;
68+
}
69+
TStringBuf scheme, host, uri;
70+
if (!NHttp::CrackURL(ProtectedPageUrl, scheme, host, uri)) {
71+
return false;
72+
}
73+
if (!scheme.empty() && (scheme != "http" && scheme != "https")) {
74+
return false;
75+
}
76+
RequestedPageScheme = scheme;
77+
auto it = std::find_if(Settings.AllowedProxyHosts.cbegin(), Settings.AllowedProxyHosts.cend(), [&host] (const TString& wildcard) {
78+
return NKikimr::IsMatchesWildcard(host, wildcard);
79+
});
80+
return it != Settings.AllowedProxyHosts.cend();
81+
}
82+
83+
bool THandlerSessionServiceCheck::IsAuthorizedRequest(TStringBuf authHeader) {
84+
if (authHeader.empty()) {
85+
return false;
86+
}
87+
return to_lower(ToString(authHeader)).StartsWith(IAM_TOKEN_SCHEME_LOWER);
88+
}
89+
90+
void THandlerSessionServiceCheck::ForwardUserRequest(TStringBuf authHeader, const NActors::TActorContext& ctx, bool secure) {
91+
LOG_DEBUG_S(ctx, EService::MVP, "Forward user request bypass OIDC");
92+
NHttp::THttpOutgoingRequestPtr httpRequest = NHttp::THttpOutgoingRequest::CreateRequest(Request->Method, ProtectedPageUrl);
93+
ForwardRequestHeaders(httpRequest);
94+
if (!authHeader.empty()) {
95+
httpRequest->Set(AUTH_HEADER_NAME, authHeader);
96+
}
97+
if (Request->HaveBody()) {
98+
httpRequest->SetBody(Request->Body);
99+
}
100+
if (RequestedPageScheme.empty()) {
101+
httpRequest->Secure = secure;
102+
}
103+
ctx.Send(HttpProxyId, new NHttp::TEvHttpProxy::TEvHttpOutgoingRequest(httpRequest));
104+
}
105+
106+
TString THandlerSessionServiceCheck::FixReferenceInHtml(TStringBuf html, TStringBuf host, TStringBuf findStr) {
107+
TStringBuilder result;
108+
size_t n = html.find(findStr);
109+
if (n == TStringBuf::npos) {
110+
return TString(html);
111+
}
112+
size_t len = findStr.length() + 1;
113+
size_t pos = 0;
114+
while (n != TStringBuf::npos) {
115+
result << html.SubStr(pos, n + len - pos);
116+
if (html[n + len] == '/') {
117+
result << "/" << host;
118+
if (html[n + len + 1] == '\'' || html[n + len + 1] == '\"') {
119+
result << "/internal";
120+
n++;
121+
}
122+
}
123+
pos = n + len;
124+
n = html.find(findStr, pos);
125+
}
126+
result << html.SubStr(pos);
127+
return result;
128+
}
129+
130+
TString THandlerSessionServiceCheck::FixReferenceInHtml(TStringBuf html, TStringBuf host) {
131+
TStringBuf findString = "href=";
132+
auto result = FixReferenceInHtml(html, host, findString);
133+
findString = "src=";
134+
return FixReferenceInHtml(result, host, findString);
135+
}
136+
137+
void THandlerSessionServiceCheck::ForwardRequestHeaders(NHttp::THttpOutgoingRequestPtr& request) const {
138+
static const TVector<TStringBuf> HEADERS_WHITE_LIST = {
139+
"Connection",
140+
"Accept-Language",
141+
"Cache-Control",
142+
"Sec-Fetch-Dest",
143+
"Sec-Fetch-Mode",
144+
"Sec-Fetch-Site",
145+
"Sec-Fetch-User",
146+
"Upgrade-Insecure-Requests",
147+
"Content-Type",
148+
"Origin"
149+
};
150+
NHttp::THeadersBuilder headers(Request->Headers);
151+
for (const auto& header : HEADERS_WHITE_LIST) {
152+
if (headers.Has(header)) {
153+
request->Set(header, headers.Get(header));
154+
}
155+
}
156+
request->Set("Accept-Encoding", "deflate");
157+
}
158+
159+
NHttp::THeadersBuilder THandlerSessionServiceCheck::GetResponseHeaders(const NHttp::THttpIncomingResponsePtr& response) {
160+
static const TVector<TStringBuf> HEADERS_WHITE_LIST = {
161+
"Content-Type",
162+
"Connection",
163+
"X-Worker-Name",
164+
"Set-Cookie",
165+
"Access-Control-Allow-Origin",
166+
"Access-Control-Allow-Credentials",
167+
"Access-Control-Allow-Headers",
168+
"Access-Control-Allow-Methods"
169+
};
170+
NHttp::THeadersBuilder headers(response->Headers);
171+
NHttp::THeadersBuilder resultHeaders;
172+
for (const auto& header : HEADERS_WHITE_LIST) {
173+
if (headers.Has(header)) {
174+
resultHeaders.Set(header, headers.Get(header));
175+
}
176+
}
177+
static const TString LOCATION_HEADER_NAME = "Location";
178+
if (headers.Has(LOCATION_HEADER_NAME)) {
179+
resultHeaders.Set(LOCATION_HEADER_NAME, GetFixedLocationHeader(headers.Get(LOCATION_HEADER_NAME)));
180+
}
181+
return resultHeaders;
182+
}
183+
184+
void THandlerSessionServiceCheck::SendSecureHttpRequest(const NHttp::THttpIncomingResponsePtr& response, const NActors::TActorContext& ctx) {
185+
NHttp::THttpOutgoingRequestPtr request = response->GetRequest();
186+
LOG_DEBUG_S(ctx, EService::MVP, "Try to send request to HTTPS port");
187+
NHttp::THeadersBuilder headers {request->Headers};
188+
ForwardUserRequest(headers.Get(AUTH_HEADER_NAME), ctx, true);
189+
}
190+
191+
TString THandlerSessionServiceCheck::GetFixedLocationHeader(TStringBuf location) {
192+
TStringBuf scheme, host, uri;
193+
NHttp::CrackURL(ProtectedPageUrl, scheme, host, uri);
194+
if (location.StartsWith("//")) {
195+
return TStringBuilder() << '/' << (scheme.empty() ? "" : TString(scheme) + "://") << location.SubStr(2);
196+
} else if (location.StartsWith('/')) {
197+
return TStringBuilder() << '/'
198+
<< (scheme.empty() ? "" : TString(scheme) + "://")
199+
<< host << location;
200+
} else {
201+
TStringBuf locScheme, locHost, locUri;
202+
NHttp::CrackURL(location, locScheme, locHost, locUri);
203+
if (!locScheme.empty()) {
204+
return TStringBuilder() << '/' << location;
205+
}
206+
}
207+
return TString(location);
208+
}
209+
210+
NHttp::THttpOutgoingResponsePtr THandlerSessionServiceCheck::CreateResponseForbiddenHost() {
211+
NHttp::THeadersBuilder headers;
212+
headers.Set("Content-Type", "text/html");
213+
SetCORS(Request, &headers);
214+
215+
TStringBuf scheme, host, uri;
216+
NHttp::CrackURL(ProtectedPageUrl, scheme, host, uri);
217+
TStringBuilder html;
218+
html << "<html><head><title>403 Forbidden</title></head><body bgcolor=\"white\"><center><h1>";
219+
html << "403 Forbidden host: " << host;
220+
html << "</h1></center></body></html>";
221+
222+
return Request->CreateResponse("403", "Forbidden", headers, html);
223+
}
224+
225+
NHttp::THttpOutgoingResponsePtr THandlerSessionServiceCheck::CreateResponseForNotExistingResponseFromProtectedResource(const TString& errorMessage) {
226+
NHttp::THeadersBuilder headers;
227+
headers.Set("Content-Type", "text/html");
228+
SetCORS(Request, &headers);
229+
230+
TStringBuilder html;
231+
html << "<html><head><title>400 Bad Request</title></head><body bgcolor=\"white\"><center><h1>";
232+
html << "400 Bad Request. Can not process request to protected resource: " << errorMessage;
233+
html << "</h1></center></body></html>";
234+
return Request->CreateResponse("400", "Bad Request", headers, html);
235+
}
236+
237+
} // NOIDC
238+
} // NMVP

0 commit comments

Comments
 (0)