Skip to content

Commit ffdb397

Browse files
Save the SecurityContext when switching user
Closes gh-12504
1 parent 225dc59 commit ffdb397

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed

web/src/main/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilter.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@
5858
import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler;
5959
import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler;
6060
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
61+
import org.springframework.security.web.context.RequestAttributeSecurityContextRepository;
62+
import org.springframework.security.web.context.SecurityContextRepository;
6163
import org.springframework.security.web.util.UrlUtils;
6264
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
6365
import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -142,6 +144,8 @@ public class SwitchUserFilter extends GenericFilterBean implements ApplicationEv
142144

143145
private AuthenticationFailureHandler failureHandler;
144146

147+
private SecurityContextRepository securityContextRepository = new RequestAttributeSecurityContextRepository();
148+
145149
@Override
146150
public void afterPropertiesSet() {
147151
Assert.notNull(this.userDetailsService, "userDetailsService must be specified");
@@ -179,6 +183,7 @@ private void doFilter(HttpServletRequest request, HttpServletResponse response,
179183
context.setAuthentication(targetUser);
180184
SecurityContextHolder.setContext(context);
181185
this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", targetUser));
186+
this.securityContextRepository.saveContext(context, request, response);
182187
// redirect to target url
183188
this.successHandler.onAuthenticationSuccess(request, response, targetUser);
184189
}
@@ -196,6 +201,7 @@ private void doFilter(HttpServletRequest request, HttpServletResponse response,
196201
context.setAuthentication(originalUser);
197202
SecurityContextHolder.setContext(context);
198203
this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", originalUser));
204+
this.securityContextRepository.saveContext(context, request, response);
199205
// redirect to target url
200206
this.successHandler.onAuthenticationSuccess(request, response, originalUser);
201207
return;
@@ -510,6 +516,19 @@ public void setSwitchAuthorityRole(String switchAuthorityRole) {
510516
this.switchAuthorityRole = switchAuthorityRole;
511517
}
512518

519+
/**
520+
* Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
521+
* switch user success. The default is
522+
* {@link RequestAttributeSecurityContextRepository}.
523+
* @param securityContextRepository the {@link SecurityContextRepository} to use.
524+
* Cannot be null.
525+
* @since 5.7.7
526+
*/
527+
public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
528+
Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
529+
this.securityContextRepository = securityContextRepository;
530+
}
531+
513532
private static RequestMatcher createMatcher(String pattern) {
514533
return new AntPathRequestMatcher(pattern, "POST", true, new UrlPathHelper());
515534
}

web/src/test/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilterTests.java

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,18 @@
1616

1717
package org.springframework.security.web.authentication.switchuser;
1818

19+
import java.io.IOException;
1920
import java.util.ArrayList;
2021
import java.util.List;
2122

2223
import javax.servlet.FilterChain;
24+
import javax.servlet.ServletException;
2325

2426
import org.junit.jupiter.api.AfterEach;
2527
import org.junit.jupiter.api.BeforeEach;
2628
import org.junit.jupiter.api.Test;
2729

30+
import org.springframework.mock.web.MockFilterChain;
2831
import org.springframework.mock.web.MockHttpServletRequest;
2932
import org.springframework.mock.web.MockHttpServletResponse;
3033
import org.springframework.security.authentication.AccountExpiredException;
@@ -44,11 +47,15 @@
4447
import org.springframework.security.util.FieldUtils;
4548
import org.springframework.security.web.DefaultRedirectStrategy;
4649
import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler;
50+
import org.springframework.security.web.context.RequestAttributeSecurityContextRepository;
51+
import org.springframework.security.web.context.SecurityContextRepository;
4752
import org.springframework.security.web.util.matcher.AnyRequestMatcher;
53+
import org.springframework.test.util.ReflectionTestUtils;
4854

4955
import static org.assertj.core.api.Assertions.assertThat;
5056
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
5157
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
58+
import static org.mockito.ArgumentMatchers.any;
5259
import static org.mockito.Mockito.mock;
5360
import static org.mockito.Mockito.never;
5461
import static org.mockito.Mockito.verify;
@@ -483,6 +490,59 @@ public void setSwitchFailureUrlWhenValidThenNoException() {
483490
filter.setSwitchFailureUrl("/foo");
484491
}
485492

493+
@Test
494+
void filterWhenDefaultSecurityContextRepositoryThenRequestAttributeRepository() {
495+
SwitchUserFilter switchUserFilter = new SwitchUserFilter();
496+
assertThat(ReflectionTestUtils.getField(switchUserFilter, "securityContextRepository"))
497+
.isInstanceOf(RequestAttributeSecurityContextRepository.class);
498+
}
499+
500+
@Test
501+
void doFilterWhenSwitchUserThenSaveSecurityContext() throws ServletException, IOException {
502+
SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
503+
MockHttpServletRequest request = new MockHttpServletRequest();
504+
MockHttpServletResponse response = new MockHttpServletResponse();
505+
MockFilterChain filterChain = new MockFilterChain();
506+
request.setParameter(SwitchUserFilter.SPRING_SECURITY_SWITCH_USERNAME_KEY, "jacklord");
507+
request.setRequestURI("/login/impersonate");
508+
SwitchUserFilter filter = new SwitchUserFilter();
509+
filter.setSecurityContextRepository(securityContextRepository);
510+
filter.setUserDetailsService(new MockUserDetailsService());
511+
filter.setTargetUrl("/target");
512+
filter.afterPropertiesSet();
513+
514+
filter.doFilter(request, response, filterChain);
515+
516+
verify(securityContextRepository).saveContext(any(), any(), any());
517+
}
518+
519+
@Test
520+
void doFilterWhenExitUserThenSaveSecurityContext() throws ServletException, IOException {
521+
UsernamePasswordAuthenticationToken source = UsernamePasswordAuthenticationToken.authenticated("dano",
522+
"hawaii50", ROLES_12);
523+
// set current user (Admin)
524+
List<GrantedAuthority> adminAuths = new ArrayList<>(ROLES_12);
525+
adminAuths.add(new SwitchUserGrantedAuthority("PREVIOUS_ADMINISTRATOR", source));
526+
UsernamePasswordAuthenticationToken admin = UsernamePasswordAuthenticationToken.authenticated("jacklord",
527+
"hawaii50", adminAuths);
528+
SecurityContextHolder.getContext().setAuthentication(admin);
529+
SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
530+
MockHttpServletRequest request = new MockHttpServletRequest();
531+
MockHttpServletResponse response = new MockHttpServletResponse();
532+
MockFilterChain filterChain = new MockFilterChain();
533+
request.setParameter(SwitchUserFilter.SPRING_SECURITY_SWITCH_USERNAME_KEY, "jacklord");
534+
request.setRequestURI("/logout/impersonate");
535+
SwitchUserFilter filter = new SwitchUserFilter();
536+
filter.setSecurityContextRepository(securityContextRepository);
537+
filter.setUserDetailsService(new MockUserDetailsService());
538+
filter.setTargetUrl("/target");
539+
filter.afterPropertiesSet();
540+
541+
filter.doFilter(request, response, filterChain);
542+
543+
verify(securityContextRepository).saveContext(any(), any(), any());
544+
}
545+
486546
private class MockUserDetailsService implements UserDetailsService {
487547

488548
private String password = "hawaii50";

0 commit comments

Comments
 (0)