Skip to content

Commit a5391b6

Browse files
robotmrvjgrandja
authored andcommitted
Fix NPE in RequestContextSubscriber
RequestContextSubscriber could cause NPE if Mono/Flux.subscribe() was invoked outside of Web Context. In addition it replaced source Context with its own without respect to old data. Now Request Context Data is Propagated within holder class and it is added to existing reactor Context if Holder is not empty. Fixes gh-7228
1 parent 57f3c76 commit a5391b6

File tree

2 files changed

+167
-32
lines changed

2 files changed

+167
-32
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

Lines changed: 77 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.springframework.http.HttpHeaders;
2323
import org.springframework.http.HttpMethod;
2424
import org.springframework.http.MediaType;
25+
import org.springframework.lang.Nullable;
2526
import org.springframework.security.core.Authentication;
2627
import org.springframework.security.core.GrantedAuthority;
2728
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
@@ -103,6 +104,7 @@
103104
* </ul>
104105
*
105106
* @author Rob Winch
107+
* @author Roman Matiushchenko
106108
* @since 5.1
107109
*/
108110
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
@@ -146,7 +148,7 @@ public ServletOAuth2AuthorizedClientExchangeFilterFunction(
146148

147149
@Override
148150
public void afterPropertiesSet() throws Exception {
149-
Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.lift((s, sub) -> createRequestContextSubscriber(sub)));
151+
Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.liftPublisher((s, sub) -> createRequestContextSubscriberIfNecessary(sub)));
150152
}
151153

152154
@Override
@@ -319,14 +321,22 @@ private Mono<ClientRequest> mergeRequestAttributesFromContext(ClientRequest requ
319321
}
320322

321323
private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) {
322-
if (ctx.hasKey(HTTP_SERVLET_REQUEST_ATTR_NAME)) {
323-
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ctx.get(HTTP_SERVLET_REQUEST_ATTR_NAME));
324-
}
325-
if (ctx.hasKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
326-
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ctx.get(HTTP_SERVLET_RESPONSE_ATTR_NAME));
327-
}
328-
if (ctx.hasKey(AUTHENTICATION_ATTR_NAME)) {
329-
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, ctx.get(AUTHENTICATION_ATTR_NAME));
324+
RequestContextDataHolder holder = RequestContextSubscriber.getRequestContext(ctx);
325+
if (holder != null) {
326+
HttpServletRequest request = holder.getRequest();
327+
if (request != null) {
328+
attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
329+
}
330+
331+
HttpServletResponse response = holder.getResponse();
332+
if (response != null) {
333+
attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
334+
}
335+
336+
Authentication authentication = holder.getAuthentication();
337+
if (authentication != null) {
338+
attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication);
339+
}
330340
}
331341
populateDefaultOAuth2AuthorizedClient(attrs);
332342
}
@@ -488,7 +498,7 @@ private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient autho
488498
.build();
489499
}
490500

491-
private <T> CoreSubscriber<T> createRequestContextSubscriber(CoreSubscriber<T> delegate) {
501+
<T> CoreSubscriber<T> createRequestContextSubscriberIfNecessary(CoreSubscriber<T> delegate) {
492502
HttpServletRequest request = null;
493503
HttpServletResponse response = null;
494504
ServletRequestAttributes requestAttributes =
@@ -498,6 +508,10 @@ private <T> CoreSubscriber<T> createRequestContextSubscriber(CoreSubscriber<T> d
498508
response = requestAttributes.getResponse();
499509
}
500510
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
511+
if (authentication == null && request == null && response == null) {
512+
//do not need to create RequestContextSubscriber with empty data
513+
return delegate;
514+
}
501515
return new RequestContextSubscriber<>(delegate, request, response, authentication);
502516
}
503517

@@ -575,34 +589,37 @@ private UnsupportedOperationException unsupported() {
575589
}
576590
}
577591

578-
private static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
579-
private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME");
592+
static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
593+
static final String REQUEST_CONTEXT_DATA_HOLDER =
594+
RequestContextSubscriber.class.getName().concat(".REQUEST_CONTEXT_DATA_HOLDER");
580595
private final CoreSubscriber<T> delegate;
581-
private final HttpServletRequest request;
582-
private final HttpServletResponse response;
583-
private final Authentication authentication;
596+
private final Context context;
584597

585-
private RequestContextSubscriber(CoreSubscriber<T> delegate,
586-
HttpServletRequest request,
587-
HttpServletResponse response,
588-
Authentication authentication) {
598+
RequestContextSubscriber(CoreSubscriber<T> delegate,
599+
HttpServletRequest request,
600+
HttpServletResponse response,
601+
Authentication authentication) {
589602
this.delegate = delegate;
590-
this.request = request;
591-
this.response = response;
592-
this.authentication = authentication;
603+
604+
Context parentContext = this.delegate.currentContext();
605+
Context context;
606+
if (parentContext.hasKey(REQUEST_CONTEXT_DATA_HOLDER)) {
607+
context = parentContext;
608+
} else {
609+
context = parentContext.put(REQUEST_CONTEXT_DATA_HOLDER, new RequestContextDataHolder(request, response, authentication));
610+
}
611+
612+
this.context = context;
613+
}
614+
615+
@Nullable
616+
private static RequestContextDataHolder getRequestContext(Context ctx) {
617+
return ctx.getOrDefault(REQUEST_CONTEXT_DATA_HOLDER, null);
593618
}
594619

595620
@Override
596621
public Context currentContext() {
597-
Context context = this.delegate.currentContext();
598-
if (context.hasKey(CONTEXT_DEFAULTED_ATTR_NAME)) {
599-
return context;
600-
}
601-
return Context.of(
602-
CONTEXT_DEFAULTED_ATTR_NAME, Boolean.TRUE,
603-
HTTP_SERVLET_REQUEST_ATTR_NAME, this.request,
604-
HTTP_SERVLET_RESPONSE_ATTR_NAME, this.response,
605-
AUTHENTICATION_ATTR_NAME, this.authentication);
622+
return this.context;
606623
}
607624

608625
@Override
@@ -625,4 +642,33 @@ public void onComplete() {
625642
this.delegate.onComplete();
626643
}
627644
}
645+
646+
static class RequestContextDataHolder {
647+
private final HttpServletRequest request;
648+
private final HttpServletResponse response;
649+
private final Authentication authentication;
650+
651+
RequestContextDataHolder(@Nullable HttpServletRequest request,
652+
@Nullable HttpServletResponse response,
653+
@Nullable Authentication authentication) {
654+
this.request = request;
655+
this.response = response;
656+
this.authentication = authentication;
657+
}
658+
659+
@Nullable
660+
private HttpServletRequest getRequest() {
661+
return this.request;
662+
}
663+
664+
@Nullable
665+
private HttpServletResponse getResponse() {
666+
return this.response;
667+
}
668+
669+
@Nullable
670+
private Authentication getAuthentication() {
671+
return this.authentication;
672+
}
673+
}
628674
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,10 @@
6161
import org.springframework.web.reactive.function.BodyInserter;
6262
import org.springframework.web.reactive.function.client.ClientRequest;
6363
import org.springframework.web.reactive.function.client.WebClient;
64+
import reactor.core.CoreSubscriber;
65+
import reactor.core.publisher.BaseSubscriber;
6466
import reactor.core.publisher.Mono;
67+
import reactor.util.context.Context;
6568

6669
import java.net.URI;
6770
import java.time.Duration;
@@ -74,6 +77,7 @@
7477
import java.util.function.Consumer;
7578

7679
import static org.assertj.core.api.Assertions.assertThat;
80+
import static org.assertj.core.api.Assertions.assertThatCode;
7781
import static org.mockito.ArgumentMatchers.any;
7882
import static org.mockito.ArgumentMatchers.eq;
7983
import static org.mockito.Mockito.*;
@@ -124,9 +128,10 @@ public void setup() {
124128
}
125129

126130
@After
127-
public void cleanup() {
131+
public void cleanup() throws Exception {
128132
SecurityContextHolder.clearContext();
129133
RequestContextHolder.resetRequestAttributes();
134+
this.function.destroy();
130135
}
131136

132137
@Test
@@ -636,6 +641,90 @@ public void filterWhenRequestAttributesNotSetAndHooksInitHooksResetThenDefaultsN
636641
assertThat(getBody(request)).isEmpty();
637642
}
638643

644+
// gh-7228
645+
@Test
646+
public void afterPropertiesSetWhenHooksInitAndOutsideWebSecurityContextThenShouldNotThrowException() throws Exception {
647+
this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized
648+
assertThatCode(() -> Mono.subscriberContext().block())
649+
.as("RequestContext Hook brakes application outside of web/security context")
650+
.doesNotThrowAnyException();
651+
}
652+
653+
@Test
654+
public void createRequestContextSubscriberIfNecessaryWhenOutsideWebSecurityContextThenReturnOriginalSubscriber() throws Exception {
655+
BaseSubscriber<Object> originalSubscriber = new BaseSubscriber<Object>() {};
656+
CoreSubscriber<Object> resultSubscriber = this.function.createRequestContextSubscriberIfNecessary(originalSubscriber);
657+
assertThat(resultSubscriber).isSameAs(originalSubscriber);
658+
}
659+
660+
// gh-7228
661+
@Test
662+
public void createRequestContextSubscriberWhenRequestResponseProvidedThenCreateWithParentContext() throws Exception {
663+
testRequestContextSubscriber(new MockHttpServletRequest(), new MockHttpServletResponse(), null);
664+
}
665+
666+
// gh-7228
667+
@Test
668+
public void createRequestContextSubscriberWhenAuthenticationProvidedThenCreateWithParentContext() throws Exception {
669+
testRequestContextSubscriber(null, null, this.authentication);
670+
}
671+
672+
@Test
673+
public void createRequestContextSubscriberWhenParentContextHasDataHolderThenShouldReuseParentContext() throws Exception {
674+
RequestContextDataHolder testValue = new RequestContextDataHolder(null, null, null);
675+
final Context parentContext = Context.of(RequestContextSubscriber.REQUEST_CONTEXT_DATA_HOLDER, testValue);
676+
BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
677+
@Override
678+
public Context currentContext() {
679+
return parentContext;
680+
}
681+
};
682+
683+
RequestContextSubscriber<Object> requestContextSubscriber =
684+
new RequestContextSubscriber<>(parent, null, null, authentication);
685+
686+
Context resultContext = requestContextSubscriber.currentContext();
687+
688+
assertThat(resultContext)
689+
.describedAs("parent context was replaced")
690+
.isSameAs(parentContext);
691+
}
692+
693+
private void testRequestContextSubscriber(MockHttpServletRequest servletRequest,
694+
MockHttpServletResponse servletResponse,
695+
Authentication authentication) {
696+
String testKey = "test_key";
697+
String testValue = "test_value";
698+
699+
BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
700+
@Override
701+
public Context currentContext() {
702+
return Context.of(testKey, testValue);
703+
}
704+
};
705+
706+
RequestContextSubscriber<Object> requestContextSubscriber =
707+
new RequestContextSubscriber<>(parent, servletRequest, servletResponse, authentication);
708+
709+
Context resultContext = requestContextSubscriber.currentContext();
710+
711+
assertThat(resultContext)
712+
.describedAs("result context is null")
713+
.isNotNull();
714+
715+
assertThat(resultContext.getOrEmpty(testKey))
716+
.describedAs("context is replaced")
717+
.hasValue(testValue);
718+
719+
Object dataHolder = resultContext.getOrDefault(RequestContextSubscriber.REQUEST_CONTEXT_DATA_HOLDER, null);
720+
assertThat(dataHolder)
721+
.describedAs("context is not populated with REQUEST_CONTEXT_DATA_HOLDER")
722+
.isNotNull()
723+
.hasFieldOrPropertyWithValue("request", servletRequest)
724+
.hasFieldOrPropertyWithValue("response", servletResponse)
725+
.hasFieldOrPropertyWithValue("authentication", authentication);
726+
}
727+
639728
private static String getBody(ClientRequest request) {
640729
final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
641730
messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));

0 commit comments

Comments
 (0)