diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java index 913a8f1211e..2b429174b74 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java @@ -87,6 +87,7 @@ import org.springframework.security.web.authentication.session.SessionAuthenticationException; import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; +import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.savedrequest.RequestCache; import org.springframework.security.web.util.matcher.AndRequestMatcher; @@ -176,6 +177,8 @@ public final class OAuth2LoginConfigurer> private OAuth2AuthorizedClientRepository authorizedClientRepository; + private SecurityContextRepository securityContextRepository; + /** * Sets the repository of client registrations. * @param clientRegistrationRepository the repository of client registrations @@ -229,6 +232,12 @@ public OAuth2LoginConfigurer loginProcessingUrl(String loginProcessingUrl) { return this; } + @Override + public OAuth2LoginConfigurer securityContextRepository(SecurityContextRepository securityContextRepository) { + this.securityContextRepository = securityContextRepository; + return this; + } + /** * Sets the registry for managing the OIDC client-provider session link * @param oidcSessionRegistry the {@link OidcSessionRegistry} to use @@ -347,6 +356,9 @@ public void init(B http) throws Exception { OAuth2LoginAuthenticationFilter authenticationFilter = new OAuth2LoginAuthenticationFilter( this.getClientRegistrationRepository(), this.getAuthorizedClientRepository(), this.loginProcessingUrl); authenticationFilter.setSecurityContextHolderStrategy(getSecurityContextHolderStrategy()); + if (this.securityContextRepository != null) { + authenticationFilter.setSecurityContextRepository(this.securityContextRepository); + } this.setAuthenticationFilter(authenticationFilter); super.loginProcessingUrl(this.loginProcessingUrl); if (this.loginPage != null) { diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java index dfe6fea28fd..85de4dbd78a 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java @@ -101,6 +101,7 @@ import org.springframework.security.web.authentication.HttpStatusEntryPoint; import org.springframework.security.web.context.HttpRequestResponseHolder; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; +import org.springframework.security.web.context.NullSecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.session.HttpSessionDestroyedEvent; import org.springframework.security.web.util.matcher.RequestHeaderRequestMatcher; @@ -110,6 +111,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatNoException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.BDDMockito.given; @@ -696,6 +698,12 @@ public void oidcLoginWhenOAuth2ClientBeansConfiguredThenNotShared() throws Excep verifyNoInteractions(clientRegistrationRepository, authorizedClientRepository); } + // gh-16623 + @Test + public void oauth2LoginConfigSecurityContextRepository() { + assertThatNoException().isThrownBy(() -> loadConfig(OAuth2LoginConfigSecurityContextRepository.class)); + } + private void loadConfig(Class... configs) { AnnotationConfigWebApplicationContext applicationContext = new AnnotationConfigWebApplicationContext(); applicationContext.register(configs); @@ -944,6 +952,24 @@ SecurityFilterChain filterChain(HttpSecurity http) throws Exception { } + @Configuration + @EnableWebSecurity + static class OAuth2LoginConfigSecurityContextRepository extends CommonSecurityFilterChainConfig { + + @Bean + SecurityFilterChain filterChain(HttpSecurity http) throws Exception { + // @formatter:off + http + .oauth2Login((login) -> login + .clientRegistrationRepository( + new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION)) + .securityContextRepository(new NullSecurityContextRepository())); + // @formatter:on + return super.configureFilterChain(http); + } + + } + @Configuration @EnableWebSecurity static class OAuth2LoginConfigCustomAuthorizationRequestResolver extends CommonSecurityFilterChainConfig {