18
18
19
19
import java .lang .reflect .Field ;
20
20
import java .lang .reflect .Modifier ;
21
+ import java .util .LinkedHashMap ;
21
22
import java .util .List ;
23
+ import java .util .Map ;
22
24
23
25
import javax .servlet .DispatcherType ;
26
+ import javax .servlet .Servlet ;
27
+ import javax .servlet .ServletContext ;
28
+ import javax .servlet .ServletRegistration ;
24
29
30
+ import org .jetbrains .annotations .NotNull ;
25
31
import org .junit .jupiter .api .BeforeEach ;
26
32
import org .junit .jupiter .api .Test ;
27
33
34
40
import org .springframework .security .web .util .matcher .DispatcherTypeRequestMatcher ;
35
41
import org .springframework .security .web .util .matcher .RegexRequestMatcher ;
36
42
import org .springframework .security .web .util .matcher .RequestMatcher ;
43
+ import org .springframework .web .context .WebApplicationContext ;
44
+ import org .springframework .web .servlet .DispatcherServlet ;
37
45
38
46
import static org .assertj .core .api .Assertions .assertThat ;
39
47
import static org .assertj .core .api .Assertions .assertThatExceptionOfType ;
@@ -56,12 +64,17 @@ public <O> O postProcess(O object) {
56
64
57
65
private TestRequestMatcherRegistry matcherRegistry ;
58
66
67
+ private WebApplicationContext context ;
68
+
59
69
@ BeforeEach
60
70
public void setUp () {
61
71
this .matcherRegistry = new TestRequestMatcherRegistry ();
62
- ApplicationContext context = mock (ApplicationContext .class );
63
- given (context .getBean (ObjectPostProcessor .class )).willReturn (NO_OP_OBJECT_POST_PROCESSOR );
64
- this .matcherRegistry .setApplicationContext (context );
72
+ this .context = mock (WebApplicationContext .class );
73
+ ServletContext servletContext = new MockServletContext ();
74
+ servletContext .addServlet ("dispatcherServlet" , DispatcherServlet .class );
75
+ given (this .context .getBean (ObjectPostProcessor .class )).willReturn (NO_OP_OBJECT_POST_PROCESSOR );
76
+ given (this .context .getServletContext ()).willReturn (servletContext );
77
+ this .matcherRegistry .setApplicationContext (this .context );
65
78
}
66
79
67
80
@ Test
@@ -184,6 +197,32 @@ public void requestMatchersWhenMvcPresentInClassPathAndMvcIntrospectorBeanNotAva
184
197
"Please ensure Spring Security & Spring MVC are configured in a shared ApplicationContext" );
185
198
}
186
199
200
+ @ Test
201
+ public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType () {
202
+ MockServletContext servletContext = new MockServletContext ();
203
+ given (this .context .getServletContext ()).willReturn (servletContext );
204
+ List <RequestMatcher > requestMatchers = this .matcherRegistry .requestMatchers ("/**" );
205
+ assertThat (requestMatchers ).isNotEmpty ();
206
+ assertThat (requestMatchers ).hasSize (1 );
207
+ assertThat (requestMatchers .get (0 )).isExactlyInstanceOf (AntPathRequestMatcher .class );
208
+ servletContext .addServlet ("servletOne" , Servlet .class );
209
+ servletContext .addServlet ("servletTwo" , Servlet .class );
210
+ requestMatchers = this .matcherRegistry .requestMatchers ("/**" );
211
+ assertThat (requestMatchers ).isNotEmpty ();
212
+ assertThat (requestMatchers ).hasSize (1 );
213
+ assertThat (requestMatchers .get (0 )).isExactlyInstanceOf (AntPathRequestMatcher .class );
214
+ }
215
+
216
+ @ Test
217
+ public void requestMatchersWhenAmbiguousServletsThenException () {
218
+ MockServletContext servletContext = new MockServletContext ();
219
+ given (this .context .getServletContext ()).willReturn (servletContext );
220
+ servletContext .addServlet ("dispatcherServlet" , DispatcherServlet .class );
221
+ servletContext .addServlet ("servletTwo" , Servlet .class );
222
+ assertThatExceptionOfType (IllegalArgumentException .class )
223
+ .isThrownBy (() -> this .matcherRegistry .requestMatchers ("/**" ));
224
+ }
225
+
187
226
private void mockMvcIntrospector (boolean isPresent ) {
188
227
ApplicationContext context = this .matcherRegistry .getApplicationContext ();
189
228
given (context .containsBean ("mvcHandlerMappingIntrospector" )).willReturn (isPresent );
@@ -217,4 +256,25 @@ protected List<RequestMatcher> chainRequestMatchers(List<RequestMatcher> request
217
256
218
257
}
219
258
259
+ private static class MockServletContext extends org .springframework .mock .web .MockServletContext {
260
+
261
+ private final Map <String , ServletRegistration > registrations = new LinkedHashMap <>();
262
+
263
+ @ NotNull
264
+ @ Override
265
+ public ServletRegistration .Dynamic addServlet (@ NotNull String servletName , Class <? extends Servlet > clazz ) {
266
+ ServletRegistration .Dynamic dynamic = mock (ServletRegistration .Dynamic .class );
267
+ given (dynamic .getClassName ()).willReturn (clazz .getName ());
268
+ this .registrations .put (servletName , dynamic );
269
+ return dynamic ;
270
+ }
271
+
272
+ @ NotNull
273
+ @ Override
274
+ public Map <String , ? extends ServletRegistration > getServletRegistrations () {
275
+ return this .registrations ;
276
+ }
277
+
278
+ }
279
+
220
280
}
0 commit comments