Skip to content

Commit 5073dc0

Browse files
authored
Better support for meta annotations by using AnnotatedElementUtils.findMergedAnnotation (#255)
*AnnotatedElementUtils.findMergedAnnotation instead of AnnotationUtils.findAnnotation
1 parent 1d67679 commit 5073dc0

File tree

4 files changed

+151
-93
lines changed

4 files changed

+151
-93
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,34 @@
11
package no.nav.security.token.support.spring.validation.interceptor;
22

3+
import static org.springframework.core.annotation.AnnotatedElementUtils.findMergedAnnotation;
4+
35
import java.lang.annotation.Annotation;
6+
import java.lang.reflect.AnnotatedElement;
47
import java.lang.reflect.Method;
58
import java.util.List;
69
import java.util.Objects;
710
import java.util.Optional;
811

9-
import org.springframework.core.annotation.AnnotationUtils;
10-
1112
import no.nav.security.token.support.core.context.TokenValidationContextHolder;
1213
import no.nav.security.token.support.core.validation.JwtTokenAnnotationHandler;
1314

14-
public class SpringJwtTokenAnnotationHandler extends JwtTokenAnnotationHandler {
15-
15+
public final class SpringJwtTokenAnnotationHandler extends JwtTokenAnnotationHandler {
1616

17-
public SpringJwtTokenAnnotationHandler(TokenValidationContextHolder tokenValidationContextHolder) {
18-
super(tokenValidationContextHolder);
17+
public SpringJwtTokenAnnotationHandler(TokenValidationContextHolder holder) {
18+
super(holder);
1919
}
2020

2121
@Override
22-
protected Annotation getAnnotation(Method method, List<Class<? extends Annotation>> types) {
23-
return Optional.ofNullable(scanAnnotation(method, types))
24-
.orElseGet(() -> scanAnnotation(method.getDeclaringClass(), types));
25-
}
26-
27-
private static Annotation scanAnnotation(Method m, List<Class<? extends Annotation>> types) {
28-
return types.stream()
29-
.map(t -> AnnotationUtils.findAnnotation(m, t))
30-
.filter(Objects::nonNull)
31-
.findFirst()
32-
.orElse(null);
22+
protected Annotation getAnnotation(Method m, List<Class<? extends Annotation>> types) {
23+
return Optional.ofNullable(findAnnotation(m, types))
24+
.orElseGet(() -> findAnnotation(m.getDeclaringClass(), types));
3325
}
3426

35-
private static Annotation scanAnnotation(Class<?> clazz, List<Class<? extends Annotation>> types) {
27+
private static Annotation findAnnotation(AnnotatedElement e, List<Class<? extends Annotation>> types) {
3628
return types.stream()
37-
.map(t -> AnnotationUtils.findAnnotation(clazz, t))
29+
.map(t -> findMergedAnnotation(e, t))
3830
.filter(Objects::nonNull)
3931
.findFirst()
4032
.orElse(null);
4133
}
42-
4334
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package no.nav.security.token.support.spring.integrationtest;
2+
3+
import static java.lang.annotation.ElementType.TYPE;
4+
import static java.lang.annotation.RetentionPolicy.RUNTIME;
5+
import static org.springframework.http.MediaType.APPLICATION_JSON_VALUE;
6+
7+
import java.lang.annotation.Documented;
8+
import java.lang.annotation.Retention;
9+
import java.lang.annotation.Target;
10+
11+
import org.springframework.core.annotation.AliasFor;
12+
import org.springframework.web.bind.annotation.RequestMapping;
13+
import org.springframework.web.bind.annotation.RestController;
14+
15+
import no.nav.security.token.support.core.api.ProtectedWithClaims;
16+
17+
@RestController
18+
@Documented
19+
@ProtectedWithClaims(issuer = "knownissuer")
20+
@Target(TYPE)
21+
@Retention(RUNTIME)
22+
@RequestMapping
23+
public @interface MetaProtected {
24+
@AliasFor(annotation = RequestMapping.class, attribute = "value")
25+
String[] value() default {};
26+
27+
@AliasFor(annotation = ProtectedWithClaims.class, attribute = "claimMap")
28+
String[] claimMap() default "acr=Level4";
29+
30+
@AliasFor(annotation = RequestMapping.class, attribute = "produces")
31+
String[] produces() default APPLICATION_JSON_VALUE;
32+
33+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package no.nav.security.token.support.spring.integrationtest;
2+
3+
import org.springframework.web.bind.annotation.GetMapping;
4+
5+
@MetaProtected(MetaProtectedRestController.METAPROTECTED)
6+
public class MetaProtectedRestController {
7+
static final String METAPROTECTED = "/metaprotected";
8+
9+
@GetMapping
10+
public String metaProtectedWithClaimsMethod() {
11+
return "protected with some required claims";
12+
}
13+
14+
}

token-validation-spring/src/test/java/no/nav/security/token/support/spring/integrationtest/ProtectedRestControllerIntegrationTest.java

Lines changed: 93 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,25 @@
11
package no.nav.security.token.support.spring.integrationtest;
22

3-
import com.nimbusds.jwt.JWT;
4-
import com.nimbusds.jwt.JWTClaimsSet;
5-
import com.nimbusds.jwt.PlainJWT;
6-
import com.nimbusds.jwt.SignedJWT;
7-
import com.nimbusds.oauth2.sdk.TokenRequest;
8-
import io.restassured.module.mockmvc.RestAssuredMockMvc;
9-
import no.nav.security.mock.oauth2.MockOAuth2Server;
10-
import no.nav.security.mock.oauth2.token.OAuth2TokenCallback;
11-
import no.nav.security.token.support.test.JwkGenerator;
3+
import static io.restassured.module.mockmvc.RestAssuredMockMvc.given;
4+
import static no.nav.security.token.support.spring.integrationtest.MetaProtectedRestController.METAPROTECTED;
5+
import static no.nav.security.token.support.spring.integrationtest.ProtectedRestController.PROTECTED;
6+
import static no.nav.security.token.support.spring.integrationtest.ProtectedRestController.PROTECTED_WITH_CLAIMS;
7+
import static no.nav.security.token.support.spring.integrationtest.ProtectedRestController.PROTECTED_WITH_CLAIMS2;
8+
import static no.nav.security.token.support.spring.integrationtest.ProtectedRestController.PROTECTED_WITH_CLAIMS_ANY_CLAIMS;
9+
import static no.nav.security.token.support.spring.integrationtest.ProtectedRestController.UNPROTECTED;
10+
import static no.nav.security.token.support.test.JwtTokenGenerator.ACR;
11+
import static no.nav.security.token.support.test.JwtTokenGenerator.AUD;
12+
import static no.nav.security.token.support.test.JwtTokenGenerator.createSignedJWT;
13+
14+
import java.util.Collection;
15+
import java.util.Date;
16+
import java.util.Map;
17+
import java.util.Optional;
18+
import java.util.UUID;
19+
import java.util.concurrent.TimeUnit;
20+
21+
import javax.servlet.Filter;
22+
1223
import org.jetbrains.annotations.NotNull;
1324
import org.junit.jupiter.api.BeforeEach;
1425
import org.junit.jupiter.api.Test;
@@ -21,16 +32,19 @@
2132
import org.springframework.test.web.servlet.setup.MockMvcConfigurer;
2233
import org.springframework.web.context.WebApplicationContext;
2334

24-
import javax.servlet.Filter;
25-
import java.util.*;
26-
import java.util.concurrent.TimeUnit;
35+
import com.nimbusds.jwt.JWT;
36+
import com.nimbusds.jwt.JWTClaimsSet;
37+
import com.nimbusds.jwt.PlainJWT;
38+
import com.nimbusds.jwt.SignedJWT;
39+
import com.nimbusds.oauth2.sdk.TokenRequest;
2740

28-
import static io.restassured.module.mockmvc.RestAssuredMockMvc.given;
29-
import static no.nav.security.token.support.spring.integrationtest.ProtectedRestController.*;
30-
import static no.nav.security.token.support.test.JwtTokenGenerator.*;
41+
import io.restassured.module.mockmvc.RestAssuredMockMvc;
42+
import no.nav.security.mock.oauth2.MockOAuth2Server;
43+
import no.nav.security.mock.oauth2.token.OAuth2TokenCallback;
44+
import no.nav.security.token.support.test.JwkGenerator;
3145

3246
@SpringBootTest
33-
@ContextConfiguration(classes = {ProtectedApplication.class, ProtectedApplicationConfig.class})
47+
@ContextConfiguration(classes = { ProtectedApplication.class, ProtectedApplicationConfig.class })
3448
@ActiveProfiles("test")
3549
class ProtectedRestControllerIntegrationTest {
3650

@@ -56,21 +70,21 @@ public void afterConfigurerAdded(ConfigurableMockMvcBuilder<?> builder) {
5670
@Test
5771
void unprotectedMethod() {
5872
given()
59-
.when()
60-
.get(UNPROTECTED)
61-
.then()
62-
.log().ifValidationFails()
63-
.statusCode(HttpStatus.OK.value());
73+
.when()
74+
.get(UNPROTECTED)
75+
.then()
76+
.log().ifValidationFails()
77+
.statusCode(HttpStatus.OK.value());
6478
}
6579

6680
@Test
6781
void noTokenInRequest() {
6882
given()
69-
.when()
70-
.get(PROTECTED)
71-
.then()
72-
.log().ifValidationFails()
73-
.statusCode(HttpStatus.UNAUTHORIZED.value());
83+
.when()
84+
.get(PROTECTED)
85+
.then()
86+
.log().ifValidationFails()
87+
.statusCode(HttpStatus.UNAUTHORIZED.value());
7488

7589
}
7690

@@ -100,8 +114,8 @@ void signedTokenInRequestUnknownAudience() {
100114
@Test
101115
void signedTokenInRequestProtectedWithClaimsMethodMissingRequiredClaims() {
102116
JWTClaimsSet jwtClaimsSet = defaultJwtClaimsSetBuilder()
103-
.claim("importantclaim", "vip")
104-
.build();
117+
.claim("importantclaim", "vip")
118+
.build();
105119
expectStatusCode(PROTECTED_WITH_CLAIMS, issueToken("knownissuer", jwtClaimsSet).serialize(), HttpStatus.UNAUTHORIZED);
106120
}
107121

@@ -118,18 +132,24 @@ void signedTokenInRequestProtectedMethodShouldBeOk() {
118132
expectStatusCode(PROTECTED, jwt.serialize(), HttpStatus.OK);
119133
}
120134

135+
@Test
136+
void signedTokenInRequestProtectedMetaMethodShouldBeOk() {
137+
JWT jwt = issueToken("knownissuer", jwtClaimsSetKnownIssuer());
138+
expectStatusCode(METAPROTECTED, jwt.serialize(), HttpStatus.OK);
139+
}
140+
121141
@Test
122142
void signedTokenInRequestProtectedWithClaimsMethodShouldBeOk() {
123143
JWTClaimsSet jwtClaimsSet = defaultJwtClaimsSetBuilder()
124-
.claim("importantclaim", "vip")
125-
.claim("acr", "Level4")
126-
.build();
144+
.claim("importantclaim", "vip")
145+
.claim("acr", "Level4")
146+
.build();
127147

128148
expectStatusCode(PROTECTED_WITH_CLAIMS, issueToken("knownissuer", jwtClaimsSet).serialize(), HttpStatus.OK);
129149

130150
JWTClaimsSet jwtClaimsSet2 = defaultJwtClaimsSetBuilder()
131-
.claim("claim1", "1")
132-
.build();
151+
.claim("claim1", "1")
152+
.build();
133153

134154
expectStatusCode(PROTECTED_WITH_CLAIMS_ANY_CLAIMS, issueToken("knownissuer", jwtClaimsSet2).serialize(), HttpStatus.OK);
135155
}
@@ -138,12 +158,12 @@ void signedTokenInRequestProtectedWithClaimsMethodShouldBeOk() {
138158
void signedTokenInRequestWithoutSubAndAudClaimsShouldBeOk() {
139159
Date now = new Date();
140160
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder()
141-
.jwtID(UUID.randomUUID().toString())
142-
.claim("auth_time", now)
143-
.notBeforeTime(now)
144-
.issueTime(now)
145-
.expirationTime(new Date(now.getTime() + TimeUnit.MINUTES.toMillis(1)))
146-
.build();
161+
.jwtID(UUID.randomUUID().toString())
162+
.claim("auth_time", now)
163+
.notBeforeTime(now)
164+
.issueTime(now)
165+
.expirationTime(new Date(now.getTime() + TimeUnit.MINUTES.toMillis(1)))
166+
.build();
147167

148168
expectStatusCode(PROTECTED_WITH_CLAIMS2, issueToken("knownissuer2", jwtClaimsSet).serialize(), HttpStatus.OK);
149169
}
@@ -152,36 +172,36 @@ void signedTokenInRequestWithoutSubAndAudClaimsShouldBeOk() {
152172
void signedTokenInRequestWithoutSubAndAudClaimsShouldBeNotBeOk() {
153173
Date now = new Date();
154174
JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder()
155-
.jwtID(UUID.randomUUID().toString())
156-
.claim("auth_time", now)
157-
.notBeforeTime(now)
158-
.issueTime(now)
159-
.expirationTime(new Date(now.getTime() + TimeUnit.MINUTES.toMillis(1)))
160-
.build();
175+
.jwtID(UUID.randomUUID().toString())
176+
.claim("auth_time", now)
177+
.notBeforeTime(now)
178+
.issueTime(now)
179+
.expirationTime(new Date(now.getTime() + TimeUnit.MINUTES.toMillis(1)))
180+
.build();
161181

162182
expectStatusCode(PROTECTED_WITH_CLAIMS, issueToken("knownissuer", jwtClaimsSet).serialize(), HttpStatus.UNAUTHORIZED);
163183
}
164184

165185
private static void expectStatusCode(String uri, String token, HttpStatus httpStatus) {
166186
given()
167-
.header("Authorization", "Bearer " + token)
168-
.when()
169-
.get(uri)
170-
.then()
171-
.log().ifValidationFails()
172-
.statusCode(httpStatus.value());
187+
.header("Authorization", "Bearer " + token)
188+
.when()
189+
.get(uri)
190+
.then()
191+
.log().ifValidationFails()
192+
.statusCode(httpStatus.value());
173193
}
174194

175195
private static JWTClaimsSet.Builder defaultJwtClaimsSetBuilder() {
176196
Date now = new Date();
177197
return new JWTClaimsSet.Builder()
178-
.subject("testsub")
179-
.audience(AUD)
180-
.jwtID(UUID.randomUUID().toString())
181-
.claim("auth_time", now)
182-
.notBeforeTime(now)
183-
.issueTime(now)
184-
.expirationTime(new Date(now.getTime() + TimeUnit.MINUTES.toMillis(1)));
198+
.subject("testsub")
199+
.audience(AUD)
200+
.jwtID(UUID.randomUUID().toString())
201+
.claim("auth_time", now)
202+
.notBeforeTime(now)
203+
.issueTime(now)
204+
.expirationTime(new Date(now.getTime() + TimeUnit.MINUTES.toMillis(1)));
185205
}
186206

187207
private static JWTClaimsSet jwtClaimsSetKnownIssuer() {
@@ -193,19 +213,19 @@ private static JWTClaimsSet jwtClaimsSet(String audience) {
193213
}
194214

195215
public static JWTClaimsSet buildClaimSet(String subject, String audience, String authLevel,
196-
long expiry) {
216+
long expiry) {
197217
Date now = new Date();
198218
return new JWTClaimsSet.Builder()
199-
.subject(subject)
200-
.audience(audience)
201-
.jwtID(UUID.randomUUID().toString())
202-
.claim("acr", authLevel)
203-
.claim("ver", "1.0")
204-
.claim("nonce", "myNonce")
205-
.claim("auth_time", now)
206-
.notBeforeTime(now)
207-
.issueTime(now)
208-
.expirationTime(new Date(now.getTime() + expiry)).build();
219+
.subject(subject)
220+
.audience(audience)
221+
.jwtID(UUID.randomUUID().toString())
222+
.claim("acr", authLevel)
223+
.claim("ver", "1.0")
224+
.claim("nonce", "myNonce")
225+
.claim("auth_time", now)
226+
.notBeforeTime(now)
227+
.issueTime(now)
228+
.expirationTime(new Date(now.getTime() + expiry)).build();
209229
}
210230

211231
private SignedJWT issueToken(String issuerId, JWTClaimsSet jwtClaimsSet) {
@@ -229,10 +249,10 @@ public String issuerId() {
229249
@Override
230250
public String audience(@NotNull TokenRequest tokenRequest) {
231251
return Optional.ofNullable(jwtClaimsSet.getAudience())
232-
.stream()
233-
.flatMap(a -> a.stream())
234-
.findFirst()
235-
.orElse(null);
252+
.stream()
253+
.flatMap(a -> a.stream())
254+
.findFirst()
255+
.orElse(null);
236256
}
237257

238258
@NotNull

0 commit comments

Comments
 (0)