1
1
/*
2
- * Copyright 2012-2017 the original author or authors.
2
+ * Copyright 2012-2020 the original author or authors.
3
3
*
4
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
5
* you may not use this file except in compliance with the License.
16
16
17
17
package org .springframework .security .web .firewall ;
18
18
19
- import javax .servlet .http .HttpServletRequest ;
20
- import javax .servlet .http .HttpServletResponse ;
21
19
import java .util .Arrays ;
22
20
import java .util .Collection ;
23
21
import java .util .Collections ;
24
22
import java .util .HashSet ;
25
23
import java .util .List ;
26
24
import java .util .Set ;
25
+ import javax .servlet .http .HttpServletRequest ;
26
+ import javax .servlet .http .HttpServletResponse ;
27
27
28
28
/**
29
29
* <p>
59
59
* Rejects URLs that contain a URL encoded percent. See
60
60
* {@link #setAllowUrlEncodedPercent(boolean)}
61
61
* </li>
62
+ * <li>
63
+ * Rejects hosts that are not allowed. See
64
+ * {@link #setAllowedHostnames(Collection)}
65
+ * </li>
62
66
* </ul>
63
67
*
64
68
* @see DefaultHttpFirewall
65
69
* @author Rob Winch
70
+ * @author Eddú Meléndez
66
71
* @since 4.2.4
67
72
*/
68
73
public class StrictHttpFirewall implements HttpFirewall {
@@ -82,6 +87,8 @@ public class StrictHttpFirewall implements HttpFirewall {
82
87
83
88
private Set <String > decodedUrlBlacklist = new HashSet <String >();
84
89
90
+ private Collection <String > allowedHostnames ;
91
+
85
92
public StrictHttpFirewall () {
86
93
urlBlacklistsAddAll (FORBIDDEN_SEMICOLON );
87
94
urlBlacklistsAddAll (FORBIDDEN_FORWARDSLASH );
@@ -230,6 +237,13 @@ public void setAllowUrlEncodedPercent(boolean allowUrlEncodedPercent) {
230
237
}
231
238
}
232
239
240
+ public void setAllowedHostnames (Collection <String > allowedHostnames ) {
241
+ if (allowedHostnames == null ) {
242
+ throw new IllegalArgumentException ("allowedHostnames cannot be null" );
243
+ }
244
+ this .allowedHostnames = allowedHostnames ;
245
+ }
246
+
233
247
private void urlBlacklistsAddAll (Collection <String > values ) {
234
248
this .encodedUrlBlacklist .addAll (values );
235
249
this .decodedUrlBlacklist .addAll (values );
@@ -243,6 +257,7 @@ private void urlBlacklistsRemoveAll(Collection<String> values) {
243
257
@ Override
244
258
public FirewalledRequest getFirewalledRequest (HttpServletRequest request ) throws RequestRejectedException {
245
259
rejectedBlacklistedUrls (request );
260
+ rejectedUntrustedHosts (request );
246
261
247
262
if (!isNormalized (request )) {
248
263
throw new RequestRejectedException ("The request was rejected because the URL was not normalized." );
@@ -272,6 +287,19 @@ private void rejectedBlacklistedUrls(HttpServletRequest request) {
272
287
}
273
288
}
274
289
290
+ private void rejectedUntrustedHosts (HttpServletRequest request ) {
291
+ String serverName = request .getServerName ();
292
+ if (serverName == null ) {
293
+ return ;
294
+ }
295
+ if (this .allowedHostnames == null ) {
296
+ return ;
297
+ }
298
+ if (!this .allowedHostnames .contains (serverName )) {
299
+ throw new RequestRejectedException ("The request was rejected because the domain " + serverName + " is untrusted." );
300
+ }
301
+ }
302
+
275
303
@ Override
276
304
public HttpServletResponse getFirewalledResponse (HttpServletResponse response ) {
277
305
return new FirewalledResponse (response );
0 commit comments