Skip to content

Commit d68648c

Browse files
authored
Remove Origin header when forwarding (#3357)
This prevents forwarded requests, such as those from circuit breaker fallbacks, from failing in CORS checks, which require a fully populated scheme and host. Fixes gh-3350
1 parent 66aa480 commit d68648c

File tree

4 files changed

+44
-3
lines changed

4 files changed

+44
-3
lines changed

spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/support/ServerWebExchangeUtils.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,11 @@ public static Mono<Void> handle(DispatcherHandler handler, ServerWebExchange exc
431431
// remove attributes that may disrupt the forwarded request
432432
exchange.getAttributes().remove(GATEWAY_PREDICATE_PATH_CONTAINER_ATTR);
433433

434+
// CORS check is applied to the original request, but should not be applied to
435+
// internally forwarded requests.
436+
// See https://github.com/spring-cloud/spring-cloud-gateway/issues/3350.
437+
exchange = exchange.mutate().request(request -> request.headers(headers -> headers.setOrigin(null))).build();
438+
434439
return handler.handle(exchange);
435440
}
436441

spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/ForwardRoutingFilterTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.springframework.web.util.UriComponentsBuilder;
3838

3939
import static org.assertj.core.api.Assertions.assertThat;
40+
import static org.mockito.ArgumentMatchers.assertArg;
4041
import static org.mockito.Mockito.verify;
4142
import static org.mockito.Mockito.verifyNoInteractions;
4243
import static org.mockito.Mockito.verifyNoMoreInteractions;
@@ -93,9 +94,8 @@ public void shouldFilterWhenGatewayRequestUrlSchemeIsForward() {
9394
forwardRoutingFilter.filter(exchange, chain);
9495

9596
verifyNoMoreInteractions(chain);
96-
verify(dispatcherHandler).handle(exchange);
97-
98-
assertThat(exchange.getAttributes().get(GATEWAY_ALREADY_ROUTED_ATTR)).isNull();
97+
verify(dispatcherHandler).handle(
98+
assertArg(exchange -> assertThat(exchange.getAttributes().get(GATEWAY_ALREADY_ROUTED_ATTR)).isNull()));
9999
}
100100

101101
@Test

spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/factory/SpringCloudCircuitBreakerFilterFactoryTests.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,13 @@ public void filterFallbackForward() {
106106
.isOk().expectBody().json("{\"from\":\"circuitbreakerfallbackcontroller3\"}");
107107
}
108108

109+
@Test
110+
public void filterFallbackForwardWithCORS() {
111+
testClient.get().uri("/delay/3?a=b").header("Host", "www.circuitbreakerforward.org")
112+
.header("Origin", "https://cors.withcircuitbreaker.org").exchange().expectStatus().isOk().expectBody()
113+
.json("{\"from\":\"circuitbreakerfallbackcontroller3\"}");
114+
}
115+
109116
@Test
110117
public void filterStatusCodeFallback() {
111118
testClient.get().uri("/status/500").header("Host", "www.circuitbreakerstatuscode.org").exchange().expectStatus()

spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/support/ServerWebExchangeUtilsTests.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,27 @@
2323

2424
import org.assertj.core.api.Assertions;
2525
import org.junit.jupiter.api.Test;
26+
import org.mockito.Mockito;
27+
import reactor.core.publisher.Mono;
2628

2729
import org.springframework.core.io.buffer.DataBuffer;
2830
import org.springframework.core.io.buffer.DefaultDataBuffer;
31+
import org.springframework.http.HttpHeaders;
2932
import org.springframework.http.HttpMethod;
3033
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
3134
import org.springframework.mock.web.server.MockServerWebExchange;
35+
import org.springframework.web.reactive.DispatcherHandler;
3236
import org.springframework.web.reactive.function.server.HandlerStrategies;
3337
import org.springframework.web.reactive.function.server.ServerRequest;
38+
import org.springframework.web.server.ServerWebExchange;
3439

3540
import static org.assertj.core.api.Assertions.assertThat;
41+
import static org.mockito.ArgumentMatchers.any;
42+
import static org.mockito.ArgumentMatchers.assertArg;
3643
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.CACHED_REQUEST_BODY_ATTR;
44+
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_PREDICATE_PATH_CONTAINER_ATTR;
3745
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.expand;
46+
import static org.springframework.http.server.PathContainer.parsePath;
3847

3948
public class ServerWebExchangeUtilsTests {
4049

@@ -94,6 +103,26 @@ public void duplicatedCachingDataBufferHandling() {
94103
Assertions.assertThat(dataBufferBeforeCaching).isEqualTo(dataBufferAfterCached);
95104
}
96105

106+
@Test
107+
public void forwardedRequestsHaveDisruptiveAttributesAndHeadersRemoved() {
108+
DispatcherHandler handler = Mockito.mock(DispatcherHandler.class);
109+
Mockito.when(handler.handle(any(ServerWebExchange.class))).thenReturn(Mono.empty());
110+
111+
ServerWebExchange originalExchange = mockExchange(Map.of()).mutate()
112+
.request(request -> request.headers(headers -> headers.setOrigin("https://example.com"))).build();
113+
originalExchange.getAttributes().put(GATEWAY_PREDICATE_PATH_CONTAINER_ATTR, parsePath("/example/path"));
114+
115+
ServerWebExchangeUtils.handle(handler, originalExchange).block();
116+
117+
Mockito.verify(handler).handle(assertArg(exchange -> {
118+
Assertions.assertThat(exchange.getAttributes()).as("exchange attributes")
119+
.doesNotContainKey(GATEWAY_PREDICATE_PATH_CONTAINER_ATTR);
120+
121+
Assertions.assertThat(exchange.getRequest().getHeaders()).as("request headers")
122+
.doesNotContainKey(HttpHeaders.ORIGIN);
123+
}));
124+
}
125+
97126
private MockServerWebExchange mockExchange(Map<String, String> vars) {
98127
return mockExchange(HttpMethod.GET, vars);
99128
}

0 commit comments

Comments
 (0)