Skip to content

Commit b189e03

Browse files
committed
PayloadInterceptorRSocket retains all payloads
Flux#skip discards its corresponding elements, meaning that they aren't intended for reuse. When using RSocket's ByteBufPayloads, this means that the bytes are releaseed back into RSocket's pool. Since the downstream request may still need the skipped payload, we should construct the publisher in a different way so as to avoid the preemptive release. Deferring Spring JavaFormat to clarify what changed. Closes gh-9345
1 parent 6cafa48 commit b189e03

File tree

2 files changed

+68
-5
lines changed

2 files changed

+68
-5
lines changed

rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocket.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2019 the original author or authors.
2+
* Copyright 2019-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -104,15 +104,18 @@ public Flux<Payload> requestChannel(Publisher<Payload> payloads) {
104104
return intercept(PayloadExchangeType.REQUEST_CHANNEL, firstPayload)
105105
.flatMapMany(context ->
106106
innerFlux
107-
.skip(1)
108-
.flatMap(p -> intercept(PayloadExchangeType.PAYLOAD, p).thenReturn(p))
109-
.transform(securedPayloads -> Flux.concat(Flux.just(firstPayload), securedPayloads))
107+
.index()
108+
.concatMap(tuple -> justOrIntercept(tuple.getT1(), tuple.getT2()))
110109
.transform(securedPayloads -> this.source.requestChannel(securedPayloads))
111110
.subscriberContext(context)
112111
);
113112
});
114113
}
115114

115+
private Mono<Payload> justOrIntercept(Long index, Payload payload) {
116+
return (index == 0) ? Mono.just(payload) : intercept(PayloadExchangeType.PAYLOAD, payload).thenReturn(payload);
117+
}
118+
116119
@Override
117120
public Mono<Void> metadataPush(Payload payload) {
118121
return intercept(PayloadExchangeType.METADATA_PUSH, payload)

rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadInterceptorRSocketTests.java

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2019 the original author or authors.
2+
* Copyright 2019-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -19,6 +19,8 @@
1919
import io.rsocket.Payload;
2020
import io.rsocket.RSocket;
2121
import io.rsocket.metadata.WellKnownMimeType;
22+
import io.rsocket.util.ByteBufPayload;
23+
import io.rsocket.util.DefaultPayload;
2224
import io.rsocket.util.RSocketProxy;
2325
import org.junit.Test;
2426
import org.junit.runner.RunWith;
@@ -28,7 +30,9 @@
2830
import org.mockito.runners.MockitoJUnitRunner;
2931
import org.mockito.stubbing.Answer;
3032
import org.reactivestreams.Publisher;
33+
import org.reactivestreams.Subscription;
3134
import org.springframework.http.MediaType;
35+
import org.springframework.security.access.AccessDeniedException;
3236
import org.springframework.security.authentication.TestingAuthenticationToken;
3337
import org.springframework.security.core.Authentication;
3438
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
@@ -41,6 +45,8 @@
4145
import org.springframework.security.rsocket.core.PayloadInterceptorRSocket;
4246
import org.springframework.util.MimeType;
4347
import org.springframework.util.MimeTypeUtils;
48+
import reactor.util.context.Context;
49+
import reactor.core.CoreSubscriber;
4450
import reactor.core.publisher.Flux;
4551
import reactor.core.publisher.Mono;
4652
import reactor.test.StepVerifier;
@@ -50,10 +56,13 @@
5056
import java.util.Arrays;
5157
import java.util.Collections;
5258
import java.util.List;
59+
import java.util.concurrent.Executors;
60+
import java.util.concurrent.ExecutorService;
5361

5462
import static org.assertj.core.api.Assertions.*;
5563
import static org.mockito.ArgumentMatchers.any;
5664
import static org.mockito.ArgumentMatchers.eq;
65+
import static org.mockito.Mockito.times;
5766
import static org.mockito.Mockito.verify;
5867
import static org.mockito.Mockito.verifyZeroInteractions;
5968
import static org.mockito.Mockito.when;
@@ -315,6 +324,57 @@ public void requestChannelWhenInterceptorCompletesThenDelegateSubscribed() {
315324
verify(this.delegate).requestChannel(any());
316325
}
317326

327+
// gh-9345
328+
@Test
329+
public void requestChannelWhenInterceptorCompletesThenAllPayloadsRetained() {
330+
ExecutorService executors = Executors.newSingleThreadExecutor();
331+
Payload payload = ByteBufPayload.create("data");
332+
Payload payloadTwo = ByteBufPayload.create("moredata");
333+
Payload payloadThree = ByteBufPayload.create("stillmoredata");
334+
Context ctx = Context.empty();
335+
Flux<Payload> payloads = this.payloadResult.flux();
336+
when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty())
337+
.thenReturn(Mono.error(() -> new AccessDeniedException("Access Denied")));
338+
when(this.delegate.requestChannel(any())).thenAnswer((invocation) -> {
339+
Flux<Payload> input = invocation.getArgument(0);
340+
return Flux.from(input).switchOnFirst((signal, innerFlux) -> innerFlux.map(Payload::getDataUtf8)
341+
.transform((data) -> Flux.<String>create((emitter) -> {
342+
Runnable run = () -> data.subscribe(new CoreSubscriber<String>() {
343+
@Override
344+
public void onSubscribe(Subscription s) {
345+
s.request(3);
346+
}
347+
348+
@Override
349+
public void onNext(String s) {
350+
emitter.next(s);
351+
}
352+
353+
@Override
354+
public void onError(Throwable t) {
355+
emitter.error(t);
356+
}
357+
358+
@Override
359+
public void onComplete() {
360+
emitter.complete();
361+
}
362+
});
363+
executors.execute(run);
364+
})).map(DefaultPayload::create));
365+
});
366+
PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate,
367+
Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType, ctx);
368+
StepVerifier.create(interceptor.requestChannel(payloads).doOnDiscard(Payload.class, Payload::release))
369+
.then(() -> this.payloadResult.assertSubscribers())
370+
.then(() -> this.payloadResult.emit(payload, payloadTwo, payloadThree))
371+
.assertNext((next) -> assertThat(next.getDataUtf8()).isEqualTo(payload.getDataUtf8()))
372+
.verifyError(AccessDeniedException.class);
373+
verify(this.interceptor, times(2)).intercept(this.exchange.capture(), any());
374+
assertThat(this.exchange.getValue().getPayload()).isEqualTo(payloadTwo);
375+
verify(this.delegate).requestChannel(any());
376+
}
377+
318378
@Test
319379
public void requestChannelWhenInterceptorErrorsThenDelegateNotSubscribed() {
320380
RuntimeException expected = new RuntimeException("Oops");

0 commit comments

Comments
 (0)