22
22
import org .springframework .http .HttpHeaders ;
23
23
import org .springframework .http .HttpMethod ;
24
24
import org .springframework .http .MediaType ;
25
+ import org .springframework .lang .Nullable ;
25
26
import org .springframework .security .core .Authentication ;
26
27
import org .springframework .security .core .GrantedAuthority ;
27
28
import org .springframework .security .core .context .ReactiveSecurityContextHolder ;
103
104
* </ul>
104
105
*
105
106
* @author Rob Winch
107
+ * @author Roman Matiushchenko
106
108
* @since 5.1
107
109
*/
108
110
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
@@ -146,7 +148,7 @@ public ServletOAuth2AuthorizedClientExchangeFilterFunction(
146
148
147
149
@ Override
148
150
public void afterPropertiesSet () throws Exception {
149
- Hooks .onLastOperator (REQUEST_CONTEXT_OPERATOR_KEY , Operators .lift ((s , sub ) -> createRequestContextSubscriber (sub )));
151
+ Hooks .onLastOperator (REQUEST_CONTEXT_OPERATOR_KEY , Operators .liftPublisher ((s , sub ) -> createRequestContextSubscriberIfNecessary (sub )));
150
152
}
151
153
152
154
@ Override
@@ -319,14 +321,22 @@ private Mono<ClientRequest> mergeRequestAttributesFromContext(ClientRequest requ
319
321
}
320
322
321
323
private void populateRequestAttributes (Map <String , Object > attrs , Context ctx ) {
322
- if (ctx .hasKey (HTTP_SERVLET_REQUEST_ATTR_NAME )) {
323
- attrs .putIfAbsent (HTTP_SERVLET_REQUEST_ATTR_NAME , ctx .get (HTTP_SERVLET_REQUEST_ATTR_NAME ));
324
- }
325
- if (ctx .hasKey (HTTP_SERVLET_RESPONSE_ATTR_NAME )) {
326
- attrs .putIfAbsent (HTTP_SERVLET_RESPONSE_ATTR_NAME , ctx .get (HTTP_SERVLET_RESPONSE_ATTR_NAME ));
327
- }
328
- if (ctx .hasKey (AUTHENTICATION_ATTR_NAME )) {
329
- attrs .putIfAbsent (AUTHENTICATION_ATTR_NAME , ctx .get (AUTHENTICATION_ATTR_NAME ));
324
+ RequestContextDataHolder holder = RequestContextSubscriber .getRequestContext (ctx );
325
+ if (holder != null ) {
326
+ HttpServletRequest request = holder .getRequest ();
327
+ if (request != null ) {
328
+ attrs .putIfAbsent (HTTP_SERVLET_REQUEST_ATTR_NAME , request );
329
+ }
330
+
331
+ HttpServletResponse response = holder .getResponse ();
332
+ if (response != null ) {
333
+ attrs .putIfAbsent (HTTP_SERVLET_RESPONSE_ATTR_NAME , response );
334
+ }
335
+
336
+ Authentication authentication = holder .getAuthentication ();
337
+ if (authentication != null ) {
338
+ attrs .putIfAbsent (AUTHENTICATION_ATTR_NAME , authentication );
339
+ }
330
340
}
331
341
populateDefaultOAuth2AuthorizedClient (attrs );
332
342
}
@@ -488,7 +498,7 @@ private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient autho
488
498
.build ();
489
499
}
490
500
491
- private <T > CoreSubscriber <T > createRequestContextSubscriber (CoreSubscriber <T > delegate ) {
501
+ <T > CoreSubscriber <T > createRequestContextSubscriberIfNecessary (CoreSubscriber <T > delegate ) {
492
502
HttpServletRequest request = null ;
493
503
HttpServletResponse response = null ;
494
504
ServletRequestAttributes requestAttributes =
@@ -498,6 +508,10 @@ private <T> CoreSubscriber<T> createRequestContextSubscriber(CoreSubscriber<T> d
498
508
response = requestAttributes .getResponse ();
499
509
}
500
510
Authentication authentication = SecurityContextHolder .getContext ().getAuthentication ();
511
+ if (authentication == null && request == null && response == null ) {
512
+ //do not need to create RequestContextSubscriber with empty data
513
+ return delegate ;
514
+ }
501
515
return new RequestContextSubscriber <>(delegate , request , response , authentication );
502
516
}
503
517
@@ -575,34 +589,37 @@ private UnsupportedOperationException unsupported() {
575
589
}
576
590
}
577
591
578
- private static class RequestContextSubscriber <T > implements CoreSubscriber <T > {
579
- private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber .class .getName ().concat (".CONTEXT_DEFAULTED_ATTR_NAME" );
592
+ static class RequestContextSubscriber <T > implements CoreSubscriber <T > {
593
+ static final String REQUEST_CONTEXT_DATA_HOLDER =
594
+ RequestContextSubscriber .class .getName ().concat (".REQUEST_CONTEXT_DATA_HOLDER" );
580
595
private final CoreSubscriber <T > delegate ;
581
- private final HttpServletRequest request ;
582
- private final HttpServletResponse response ;
583
- private final Authentication authentication ;
596
+ private final Context context ;
584
597
585
- private RequestContextSubscriber (CoreSubscriber <T > delegate ,
586
- HttpServletRequest request ,
587
- HttpServletResponse response ,
588
- Authentication authentication ) {
598
+ RequestContextSubscriber (CoreSubscriber <T > delegate ,
599
+ HttpServletRequest request ,
600
+ HttpServletResponse response ,
601
+ Authentication authentication ) {
589
602
this .delegate = delegate ;
590
- this .request = request ;
591
- this .response = response ;
592
- this .authentication = authentication ;
603
+
604
+ Context parentContext = this .delegate .currentContext ();
605
+ Context context ;
606
+ if (parentContext .hasKey (REQUEST_CONTEXT_DATA_HOLDER )) {
607
+ context = parentContext ;
608
+ } else {
609
+ context = parentContext .put (REQUEST_CONTEXT_DATA_HOLDER , new RequestContextDataHolder (request , response , authentication ));
610
+ }
611
+
612
+ this .context = context ;
613
+ }
614
+
615
+ @ Nullable
616
+ private static RequestContextDataHolder getRequestContext (Context ctx ) {
617
+ return ctx .getOrDefault (REQUEST_CONTEXT_DATA_HOLDER , null );
593
618
}
594
619
595
620
@ Override
596
621
public Context currentContext () {
597
- Context context = this .delegate .currentContext ();
598
- if (context .hasKey (CONTEXT_DEFAULTED_ATTR_NAME )) {
599
- return context ;
600
- }
601
- return Context .of (
602
- CONTEXT_DEFAULTED_ATTR_NAME , Boolean .TRUE ,
603
- HTTP_SERVLET_REQUEST_ATTR_NAME , this .request ,
604
- HTTP_SERVLET_RESPONSE_ATTR_NAME , this .response ,
605
- AUTHENTICATION_ATTR_NAME , this .authentication );
622
+ return this .context ;
606
623
}
607
624
608
625
@ Override
@@ -625,4 +642,33 @@ public void onComplete() {
625
642
this .delegate .onComplete ();
626
643
}
627
644
}
645
+
646
+ static class RequestContextDataHolder {
647
+ private final HttpServletRequest request ;
648
+ private final HttpServletResponse response ;
649
+ private final Authentication authentication ;
650
+
651
+ RequestContextDataHolder (@ Nullable HttpServletRequest request ,
652
+ @ Nullable HttpServletResponse response ,
653
+ @ Nullable Authentication authentication ) {
654
+ this .request = request ;
655
+ this .response = response ;
656
+ this .authentication = authentication ;
657
+ }
658
+
659
+ @ Nullable
660
+ private HttpServletRequest getRequest () {
661
+ return this .request ;
662
+ }
663
+
664
+ @ Nullable
665
+ private HttpServletResponse getResponse () {
666
+ return this .response ;
667
+ }
668
+
669
+ @ Nullable
670
+ private Authentication getAuthentication () {
671
+ return this .authentication ;
672
+ }
673
+ }
628
674
}
0 commit comments