Skip to content

Commit 7f5c319

Browse files
ugravejzheaux
authored andcommitted
Add relyingPartyRegistrationId to AbstractSaml2AuthenticationRequest
Closes gh-11195
1 parent 8e34b4c commit 7f5c319

File tree

10 files changed

+154
-28
lines changed

10 files changed

+154
-28
lines changed

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixin.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ class Saml2PostAuthenticationRequestMixin {
4747
@JsonCreator
4848
Saml2PostAuthenticationRequestMixin(@JsonProperty("samlRequest") String samlRequest,
4949
@JsonProperty("relayState") String relayState,
50-
@JsonProperty("authenticationRequestUri") String authenticationRequestUri) {
50+
@JsonProperty("authenticationRequestUri") String authenticationRequestUri,
51+
@JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId) {
5152
}
5253

5354
}

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2RedirectAuthenticationRequestMixin.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ class Saml2RedirectAuthenticationRequestMixin {
4848
Saml2RedirectAuthenticationRequestMixin(@JsonProperty("samlRequest") String samlRequest,
4949
@JsonProperty("sigAlg") String sigAlg, @JsonProperty("signature") String signature,
5050
@JsonProperty("relayState") String relayState,
51-
@JsonProperty("authenticationRequestUri") String authenticationRequestUri) {
51+
@JsonProperty("authenticationRequestUri") String authenticationRequestUri,
52+
@JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId) {
5253
}
5354

5455
}

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/AbstractSaml2AuthenticationRequest.java

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.nio.charset.Charset;
2121

2222
import org.springframework.security.core.SpringSecurityCoreVersion;
23+
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
2324
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
2425
import org.springframework.util.Assert;
2526

@@ -46,20 +47,26 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
4647

4748
private final String authenticationRequestUri;
4849

50+
private final String relyingPartyRegistrationId;
51+
4952
/**
5053
* Mandatory constructor for the {@link AbstractSaml2AuthenticationRequest}
5154
* @param samlRequest - the SAMLRequest XML data, SAML encoded, cannot be empty or
5255
* null
5356
* @param relayState - RelayState value that accompanies the request, may be null
5457
* @param authenticationRequestUri - The authenticationRequestUri, a URL, where to
5558
* send the XML message, cannot be empty or null
59+
* @param relyingPartyRegistrationId the registration id of the relying party, may be
60+
* null
5661
*/
57-
AbstractSaml2AuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri) {
62+
AbstractSaml2AuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri,
63+
String relyingPartyRegistrationId) {
5864
Assert.hasText(samlRequest, "samlRequest cannot be null or empty");
5965
Assert.hasText(authenticationRequestUri, "authenticationRequestUri cannot be null or empty");
6066
this.authenticationRequestUri = authenticationRequestUri;
6167
this.samlRequest = samlRequest;
6268
this.relayState = relayState;
69+
this.relyingPartyRegistrationId = relyingPartyRegistrationId;
6370
}
6471

6572
/**
@@ -89,6 +96,16 @@ public String getAuthenticationRequestUri() {
8996
return this.authenticationRequestUri;
9097
}
9198

99+
/**
100+
* The identifier for the {@link RelyingPartyRegistration} associated with this
101+
* request
102+
* @return the {@link RelyingPartyRegistration} id
103+
* @since 5.8
104+
*/
105+
public String getRelyingPartyRegistrationId() {
106+
return this.relyingPartyRegistrationId;
107+
}
108+
92109
/**
93110
* Returns the binding this AuthNRequest will be sent and encoded with. If
94111
* {@link Saml2MessageBinding#REDIRECT} is used, the DEFLATE encoding will be
@@ -108,9 +125,24 @@ public static class Builder<T extends Builder<T>> {
108125

109126
String relayState;
110127

128+
String relyingPartyRegistrationId;
129+
130+
/**
131+
* @deprecated Use {@link #Builder(RelyingPartyRegistration)} instead
132+
*/
133+
@Deprecated
111134
protected Builder() {
112135
}
113136

137+
/**
138+
* Creates a new Builder with relying party registration
139+
* @param registration the registration of the relying party.
140+
* @sine 5.8
141+
*/
142+
protected Builder(RelyingPartyRegistration registration) {
143+
this.relyingPartyRegistrationId = registration.getRegistrationId();
144+
}
145+
114146
/**
115147
* Casting the return as the generic subtype, when returning itself
116148
* @return this object

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@
3030
*/
3131
public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationRequest {
3232

33-
Saml2PostAuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri) {
34-
super(samlRequest, relayState, authenticationRequestUri);
33+
Saml2PostAuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri,
34+
String relyingPartyRegistrationId) {
35+
super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId);
3536
}
3637

3738
/**
@@ -50,23 +51,25 @@ public Saml2MessageBinding getBinding() {
5051
*/
5152
public static Builder withRelyingPartyRegistration(RelyingPartyRegistration registration) {
5253
String location = registration.getAssertingPartyDetails().getSingleSignOnServiceLocation();
53-
return new Builder().authenticationRequestUri(location);
54+
return new Builder(registration).authenticationRequestUri(location);
5455
}
5556

5657
/**
5758
* Builder class for a {@link Saml2PostAuthenticationRequest} object.
5859
*/
5960
public static final class Builder extends AbstractSaml2AuthenticationRequest.Builder<Builder> {
6061

61-
private Builder() {
62+
private Builder(RelyingPartyRegistration registration) {
63+
super(registration);
6264
}
6365

6466
/**
6567
* Constructs an immutable {@link Saml2PostAuthenticationRequest} object.
6668
* @return an immutable {@link Saml2PostAuthenticationRequest} object.
6769
*/
6870
public Saml2PostAuthenticationRequest build() {
69-
return new Saml2PostAuthenticationRequest(this.samlRequest, this.relayState, this.authenticationRequestUri);
71+
return new Saml2PostAuthenticationRequest(this.samlRequest, this.relayState, this.authenticationRequestUri,
72+
this.relyingPartyRegistrationId);
7073
}
7174

7275
}

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
3535
private final String signature;
3636

3737
private Saml2RedirectAuthenticationRequest(String samlRequest, String sigAlg, String signature, String relayState,
38-
String authenticationRequestUri) {
39-
super(samlRequest, relayState, authenticationRequestUri);
38+
String authenticationRequestUri, String relyingPartyRegistrationId) {
39+
super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId);
4040
this.sigAlg = sigAlg;
4141
this.signature = signature;
4242
}
@@ -74,7 +74,7 @@ public Saml2MessageBinding getBinding() {
7474
*/
7575
public static Builder withRelyingPartyRegistration(RelyingPartyRegistration registration) {
7676
String location = registration.getAssertingPartyDetails().getSingleSignOnServiceLocation();
77-
return new Builder().authenticationRequestUri(location);
77+
return new Builder(registration).authenticationRequestUri(location);
7878
}
7979

8080
/**
@@ -86,7 +86,8 @@ public static final class Builder extends AbstractSaml2AuthenticationRequest.Bui
8686

8787
private String signature;
8888

89-
private Builder() {
89+
private Builder(RelyingPartyRegistration registration) {
90+
super(registration);
9091
}
9192

9293
/**
@@ -115,7 +116,7 @@ public Builder signature(String signature) {
115116
*/
116117
public Saml2RedirectAuthenticationRequest build() {
117118
return new Saml2RedirectAuthenticationRequest(this.samlRequest, this.sigAlg, this.signature,
118-
this.relayState, this.authenticationRequestUri);
119+
this.relayState, this.authenticationRequestUri, this.relyingPartyRegistrationId);
119120
}
120121

121122
}

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import org.apache.commons.codec.CodecPolicy;
2727
import org.apache.commons.codec.binary.Base64;
2828

29-
import org.springframework.core.convert.converter.Converter;
3029
import org.springframework.http.HttpMethod;
3130
import org.springframework.security.saml2.core.Saml2Error;
3231
import org.springframework.security.saml2.core.Saml2ErrorCodes;
@@ -50,25 +49,29 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
5049

5150
private static Base64 BASE64 = new Base64(0, new byte[] { '\n' }, false, CodecPolicy.STRICT);
5251

53-
private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver;
52+
private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
5453

5554
private Function<HttpServletRequest, AbstractSaml2AuthenticationRequest> loader;
5655

56+
/**
57+
* Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for
58+
* resolving {@link RelyingPartyRegistration}s
59+
* @param relyingPartyRegistrationResolver the strategy for resolving
60+
* {@link RelyingPartyRegistration}s
61+
*/
5762
public Saml2AuthenticationTokenConverter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
5863
Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
59-
this.relyingPartyRegistrationResolver = adaptToConverter(relyingPartyRegistrationResolver);
64+
this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
6065
this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest;
6166
}
6267

63-
private static Converter<HttpServletRequest, RelyingPartyRegistration> adaptToConverter(
64-
RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
65-
Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
66-
return (request) -> relyingPartyRegistrationResolver.resolve(request, null);
67-
}
68-
6968
@Override
7069
public Saml2AuthenticationToken convert(HttpServletRequest request) {
71-
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.convert(request);
70+
AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
71+
String relyingPartyRegistrationId = (authenticationRequest != null)
72+
? authenticationRequest.getRelyingPartyRegistrationId() : null;
73+
RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.resolve(request,
74+
relyingPartyRegistrationId);
7275
if (relyingPartyRegistration == null) {
7376
return null;
7477
}
@@ -78,7 +81,6 @@ public Saml2AuthenticationToken convert(HttpServletRequest request) {
7881
}
7982
byte[] b = samlDecode(saml2Response);
8083
saml2Response = inflateIfRequired(request, b);
81-
AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
8284
return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response, authenticationRequest);
8385
}
8486

saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixinTests.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,23 @@ void shouldDeserialize() throws Exception {
5656
assertThat(authRequest.getRelayState()).isEqualTo(TestSaml2JsonPayloads.RELAY_STATE);
5757
assertThat(authRequest.getAuthenticationRequestUri())
5858
.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
59+
assertThat(authRequest.getRelyingPartyRegistrationId())
60+
.isEqualTo(TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID);
61+
}
62+
63+
@Test
64+
void shouldDeserializeWithNoRegistrationId() throws Exception {
65+
String json = TestSaml2JsonPayloads.DEFAULT_POST_AUTH_REQUEST_JSON.replace(
66+
"\"relyingPartyRegistrationId\": \"" + TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID + "\",", "");
67+
68+
Saml2PostAuthenticationRequest authRequest = this.mapper.readValue(json, Saml2PostAuthenticationRequest.class);
69+
70+
assertThat(authRequest).isNotNull();
71+
assertThat(authRequest.getSamlRequest()).isEqualTo(TestSaml2JsonPayloads.SAML_REQUEST);
72+
assertThat(authRequest.getRelayState()).isEqualTo(TestSaml2JsonPayloads.RELAY_STATE);
73+
assertThat(authRequest.getAuthenticationRequestUri())
74+
.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
75+
assertThat(authRequest.getRelyingPartyRegistrationId()).isNull();
5976
}
6077

6178
}

saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/Saml2RedirectAuthenticationRequestMixinTests.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,26 @@ void shouldDeserialize() throws Exception {
5959
.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
6060
assertThat(authRequest.getSigAlg()).isEqualTo(TestSaml2JsonPayloads.SIG_ALG);
6161
assertThat(authRequest.getSignature()).isEqualTo(TestSaml2JsonPayloads.SIGNATURE);
62+
assertThat(authRequest.getRelyingPartyRegistrationId())
63+
.isEqualTo(TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID);
64+
}
65+
66+
@Test
67+
void shouldDeserializeWithNoRegistrationId() throws Exception {
68+
String json = TestSaml2JsonPayloads.DEFAULT_REDIRECT_AUTH_REQUEST_JSON.replace(
69+
"\"relyingPartyRegistrationId\": \"" + TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID + "\",", "");
70+
71+
Saml2RedirectAuthenticationRequest authRequest = this.mapper.readValue(json,
72+
Saml2RedirectAuthenticationRequest.class);
73+
74+
assertThat(authRequest).isNotNull();
75+
assertThat(authRequest.getSamlRequest()).isEqualTo(TestSaml2JsonPayloads.SAML_REQUEST);
76+
assertThat(authRequest.getRelayState()).isEqualTo(TestSaml2JsonPayloads.RELAY_STATE);
77+
assertThat(authRequest.getAuthenticationRequestUri())
78+
.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
79+
assertThat(authRequest.getSigAlg()).isEqualTo(TestSaml2JsonPayloads.SIG_ALG);
80+
assertThat(authRequest.getSignature()).isEqualTo(TestSaml2JsonPayloads.SIGNATURE);
81+
assertThat(authRequest.getRelyingPartyRegistrationId()).isNull();
6282
}
6383

6484
}

saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/TestSaml2JsonPayloads.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ static DefaultSaml2AuthenticatedPrincipal createDefaultPrincipal() {
9494
static final String SAML_REQUEST = "samlRequestValue";
9595
static final String RELAY_STATE = "relayStateValue";
9696
static final String AUTHENTICATION_REQUEST_URI = "authenticationRequestUriValue";
97+
static final String RELYINGPARTY_REGISTRATION_ID = "registrationIdValue";
9798
static final String SIG_ALG = "sigAlgValue";
9899
static final String SIGNATURE = "signatureValue";
99100

@@ -103,6 +104,7 @@ static DefaultSaml2AuthenticatedPrincipal createDefaultPrincipal() {
103104
+ " \"samlRequest\": \"" + SAML_REQUEST + "\","
104105
+ " \"relayState\": \"" + RELAY_STATE + "\","
105106
+ " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\","
107+
+ " \"relyingPartyRegistrationId\": \"" + RELYINGPARTY_REGISTRATION_ID + "\","
106108
+ " \"sigAlg\": \"" + SIG_ALG + "\","
107109
+ " \"signature\": \"" + SIGNATURE + "\""
108110
+ "}";
@@ -113,14 +115,14 @@ static DefaultSaml2AuthenticatedPrincipal createDefaultPrincipal() {
113115
+ " \"@class\": \"org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest\","
114116
+ " \"samlRequest\": \"" + SAML_REQUEST + "\","
115117
+ " \"relayState\": \"" + RELAY_STATE + "\","
118+
+ " \"relyingPartyRegistrationId\": \"" + RELYINGPARTY_REGISTRATION_ID + "\","
116119
+ " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\""
117120
+ "}";
118121
// @formatter:on
119122

120123
static final String ID = "idValue";
121124
static final String LOCATION = "locationValue";
122125
static final String BINDNG = "REDIRECT";
123-
static final String RELYINGPARTY_REGISTRATION_ID = "registrationIdValue";
124126
static final String ADDITIONAL_PARAM = "additionalParamValue";
125127

126128
// @formatter:off
@@ -140,14 +142,17 @@ static DefaultSaml2AuthenticatedPrincipal createDefaultPrincipal() {
140142
// @formatter:on
141143

142144
static Saml2PostAuthenticationRequest createDefaultSaml2PostAuthenticationRequest() {
143-
return Saml2PostAuthenticationRequest.withRelyingPartyRegistration(TestRelyingPartyRegistrations.full()
144-
.assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI))
145-
.build()).samlRequest(SAML_REQUEST).relayState(RELAY_STATE).build();
145+
return Saml2PostAuthenticationRequest.withRelyingPartyRegistration(
146+
TestRelyingPartyRegistrations.full().registrationId(RELYINGPARTY_REGISTRATION_ID)
147+
.assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI))
148+
.build())
149+
.samlRequest(SAML_REQUEST).relayState(RELAY_STATE).build();
146150
}
147151

148152
static Saml2RedirectAuthenticationRequest createDefaultSaml2RedirectAuthenticationRequest() {
149153
return Saml2RedirectAuthenticationRequest
150154
.withRelyingPartyRegistration(TestRelyingPartyRegistrations.full()
155+
.registrationId(RELYINGPARTY_REGISTRATION_ID)
151156
.assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI))
152157
.build())
153158
.samlRequest(SAML_REQUEST).relayState(RELAY_STATE).sigAlg(SIG_ALG).signature(SIGNATURE).build();

0 commit comments

Comments
 (0)