|
1 | 1 | /*
|
2 |
| - * Copyright 2019 the original author or authors. |
| 2 | + * Copyright 2019-2021 the original author or authors. |
3 | 3 | *
|
4 | 4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | 5 | * you may not use this file except in compliance with the License.
|
|
19 | 19 | import io.rsocket.Payload;
|
20 | 20 | import io.rsocket.RSocket;
|
21 | 21 | import io.rsocket.metadata.WellKnownMimeType;
|
| 22 | +import io.rsocket.util.ByteBufPayload; |
| 23 | +import io.rsocket.util.DefaultPayload; |
22 | 24 | import io.rsocket.util.RSocketProxy;
|
23 | 25 | import org.junit.Test;
|
24 | 26 | import org.junit.runner.RunWith;
|
|
28 | 30 | import org.mockito.runners.MockitoJUnitRunner;
|
29 | 31 | import org.mockito.stubbing.Answer;
|
30 | 32 | import org.reactivestreams.Publisher;
|
| 33 | +import org.reactivestreams.Subscription; |
31 | 34 | import org.springframework.http.MediaType;
|
| 35 | +import org.springframework.security.access.AccessDeniedException; |
32 | 36 | import org.springframework.security.authentication.TestingAuthenticationToken;
|
33 | 37 | import org.springframework.security.core.Authentication;
|
34 | 38 | import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
|
41 | 45 | import org.springframework.security.rsocket.core.PayloadInterceptorRSocket;
|
42 | 46 | import org.springframework.util.MimeType;
|
43 | 47 | import org.springframework.util.MimeTypeUtils;
|
| 48 | +import reactor.util.context.Context; |
| 49 | +import reactor.core.CoreSubscriber; |
44 | 50 | import reactor.core.publisher.Flux;
|
45 | 51 | import reactor.core.publisher.Mono;
|
46 | 52 | import reactor.test.StepVerifier;
|
|
50 | 56 | import java.util.Arrays;
|
51 | 57 | import java.util.Collections;
|
52 | 58 | import java.util.List;
|
| 59 | +import java.util.concurrent.Executors; |
| 60 | +import java.util.concurrent.ExecutorService; |
53 | 61 |
|
54 | 62 | import static org.assertj.core.api.Assertions.*;
|
55 | 63 | import static org.mockito.ArgumentMatchers.any;
|
56 | 64 | import static org.mockito.ArgumentMatchers.eq;
|
| 65 | +import static org.mockito.Mockito.times; |
57 | 66 | import static org.mockito.Mockito.verify;
|
58 | 67 | import static org.mockito.Mockito.verifyZeroInteractions;
|
59 | 68 | import static org.mockito.Mockito.when;
|
@@ -315,6 +324,57 @@ public void requestChannelWhenInterceptorCompletesThenDelegateSubscribed() {
|
315 | 324 | verify(this.delegate).requestChannel(any());
|
316 | 325 | }
|
317 | 326 |
|
| 327 | + // gh-9345 |
| 328 | + @Test |
| 329 | + public void requestChannelWhenInterceptorCompletesThenAllPayloadsRetained() { |
| 330 | + ExecutorService executors = Executors.newSingleThreadExecutor(); |
| 331 | + Payload payload = ByteBufPayload.create("data"); |
| 332 | + Payload payloadTwo = ByteBufPayload.create("moredata"); |
| 333 | + Payload payloadThree = ByteBufPayload.create("stillmoredata"); |
| 334 | + Context ctx = Context.empty(); |
| 335 | + Flux<Payload> payloads = this.payloadResult.flux(); |
| 336 | + when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty()) |
| 337 | + .thenReturn(Mono.error(() -> new AccessDeniedException("Access Denied"))); |
| 338 | + when(this.delegate.requestChannel(any())).thenAnswer((invocation) -> { |
| 339 | + Flux<Payload> input = invocation.getArgument(0); |
| 340 | + return Flux.from(input).switchOnFirst((signal, innerFlux) -> innerFlux.map(Payload::getDataUtf8) |
| 341 | + .transform((data) -> Flux.<String>create((emitter) -> { |
| 342 | + Runnable run = () -> data.subscribe(new CoreSubscriber<String>() { |
| 343 | + @Override |
| 344 | + public void onSubscribe(Subscription s) { |
| 345 | + s.request(3); |
| 346 | + } |
| 347 | + |
| 348 | + @Override |
| 349 | + public void onNext(String s) { |
| 350 | + emitter.next(s); |
| 351 | + } |
| 352 | + |
| 353 | + @Override |
| 354 | + public void onError(Throwable t) { |
| 355 | + emitter.error(t); |
| 356 | + } |
| 357 | + |
| 358 | + @Override |
| 359 | + public void onComplete() { |
| 360 | + emitter.complete(); |
| 361 | + } |
| 362 | + }); |
| 363 | + executors.execute(run); |
| 364 | + })).map(DefaultPayload::create)); |
| 365 | + }); |
| 366 | + PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate, |
| 367 | + Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType, ctx); |
| 368 | + StepVerifier.create(interceptor.requestChannel(payloads).doOnDiscard(Payload.class, Payload::release)) |
| 369 | + .then(() -> this.payloadResult.assertSubscribers()) |
| 370 | + .then(() -> this.payloadResult.emit(payload, payloadTwo, payloadThree)) |
| 371 | + .assertNext((next) -> assertThat(next.getDataUtf8()).isEqualTo(payload.getDataUtf8())) |
| 372 | + .verifyError(AccessDeniedException.class); |
| 373 | + verify(this.interceptor, times(2)).intercept(this.exchange.capture(), any()); |
| 374 | + assertThat(this.exchange.getValue().getPayload()).isEqualTo(payloadTwo); |
| 375 | + verify(this.delegate).requestChannel(any()); |
| 376 | + } |
| 377 | + |
318 | 378 | @Test
|
319 | 379 | public void requestChannelWhenInterceptorErrorsThenDelegateNotSubscribed() {
|
320 | 380 | RuntimeException expected = new RuntimeException("Oops");
|
|
0 commit comments