Skip to content

Commit ee3f2e1

Browse files
committed
Do not defer error emission during result consumption.
Consuming results through Result.getRowsUpdated() or Result.map() now emits error signals directly without accumulating these. This behavior prevents consuming results if an error token is present and arrangements that adopt one of the resulting publishers into a Mono (via Mono.from(…)) no longer causes a false impression of a successful call when consuming the first signal only without awaiting a completion signal. [resolves #180] Signed-off-by: Mark Paluch <mpaluch@vmware.com>
1 parent bb525d9 commit ee3f2e1

File tree

3 files changed

+55
-70
lines changed

3 files changed

+55
-70
lines changed

src/main/java/io/r2dbc/mssql/DefaultMssqlResult.java

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import io.r2dbc.mssql.message.token.ReturnValue;
2727
import io.r2dbc.mssql.message.token.RowToken;
2828
import io.r2dbc.mssql.util.Assert;
29-
import io.r2dbc.spi.R2dbcException;
3029
import io.r2dbc.spi.Readable;
3130
import io.r2dbc.spi.Result;
3231
import io.r2dbc.spi.Row;
@@ -66,8 +65,6 @@ final class DefaultMssqlResult implements MssqlResult {
6665

6766
private volatile MssqlRowMetadata rowMetadata;
6867

69-
private volatile RuntimeException throwable;
70-
7168
private DefaultMssqlResult(String sql, ConnectionContext context, Codecs codecs, Flux<io.r2dbc.mssql.message.Message> messages, boolean expectReturnValues) {
7269

7370
this.sql = sql;
@@ -124,24 +121,11 @@ public Mono<Integer> getRowsUpdated() {
124121

125122
if (message instanceof ErrorToken) {
126123

127-
R2dbcException mssqlException = ExceptionFactory.createException((ErrorToken) message, this.sql);
128-
129-
Throwable exception = this.throwable;
130-
if (exception != null) {
131-
exception.addSuppressed(mssqlException);
132-
} else {
133-
this.throwable = mssqlException;
134-
}
135-
124+
sink.error(ExceptionFactory.createException((ErrorToken) message, this.sql));
136125
return;
137126
}
138127

139128
ReferenceCountUtil.release(message);
140-
}).doOnComplete(() -> {
141-
RuntimeException exception = this.throwable;
142-
if (exception != null) {
143-
throw exception;
144-
}
145129
}).reduce(Long::sum).map(Long::intValue);
146130
}
147131

@@ -242,16 +226,7 @@ private <T> Flux<T> doMap(boolean rows, boolean outparameters, Function<? super
242226
}
243227

244228
if (message instanceof ErrorToken) {
245-
246-
R2dbcException mssqlException = ExceptionFactory.createException((ErrorToken) message, this.sql);
247-
248-
Throwable exception = this.throwable;
249-
if (exception != null) {
250-
exception.addSuppressed(mssqlException);
251-
} else {
252-
this.throwable = mssqlException;
253-
}
254-
229+
sink.error(ExceptionFactory.createException((ErrorToken) message, this.sql));
255230
return;
256231
}
257232

@@ -266,12 +241,7 @@ private <T> Flux<T> doMap(boolean rows, boolean outparameters, Function<? super
266241
mapped = mapped.concatWith(mappedReturnValues);
267242
}
268243

269-
return mapped.doOnComplete(() -> {
270-
RuntimeException exception = this.throwable;
271-
if (exception != null) {
272-
throw exception;
273-
}
274-
});
244+
return mapped;
275245
}
276246

277247
@Override

src/main/java/io/r2dbc/mssql/MssqlSegmentResult.java

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,6 @@ final class MssqlSegmentResult implements MssqlResult {
6868

6969
private final Flux<Segment> segments;
7070

71-
private volatile RuntimeException throwable;
72-
7371
private MssqlSegmentResult(String sql, ConnectionContext context, Codecs codecs, Flux<Segment> segments) {
7472

7573
this.sql = sql;
@@ -200,16 +198,11 @@ public Mono<Integer> getRowsUpdated() {
200198
}
201199

202200
if (isError(segment)) {
203-
handleError((Message) segment);
201+
sink.error(((Message) segment).exception());
204202
return;
205203
}
206204

207205
ReferenceCountUtil.release(segment);
208-
}).doOnComplete(() -> {
209-
RuntimeException exception = this.throwable;
210-
if (exception != null) {
211-
throw exception;
212-
}
213206
}).reduce(Long::sum).map(Long::intValue);
214207
}
215208

@@ -235,7 +228,7 @@ public <T> Flux<T> map(Function<? super Readable, ? extends T> mappingFunction)
235228
private <T> Flux<T> doMap(boolean rows, boolean outparameters, Function<? super Readable, ? extends T> mappingFunction) {
236229

237230
return this.segments
238-
.<T>handle((segment, sink) -> {
231+
.handle((segment, sink) -> {
239232

240233
if (rows && segment instanceof RowSegment) {
241234

@@ -264,16 +257,11 @@ private <T> Flux<T> doMap(boolean rows, boolean outparameters, Function<? super
264257
}
265258

266259
if (isError(segment)) {
267-
handleError((Message) segment);
260+
sink.error(((Message) segment).exception());
268261
return;
269262
}
270263

271264
ReferenceCountUtil.release(segment);
272-
}).doOnComplete(() -> {
273-
RuntimeException exception = this.throwable;
274-
if (exception != null) {
275-
throw exception;
276-
}
277265
});
278266
}
279267

@@ -319,17 +307,6 @@ public <T> Flux<T> flatMap(Function<Segment, ? extends Publisher<? extends T>> m
319307
});
320308
}
321309

322-
private void handleError(Message segment) {
323-
R2dbcException mssqlException = segment.exception();
324-
325-
Throwable exception = this.throwable;
326-
if (exception != null) {
327-
exception.addSuppressed(mssqlException);
328-
} else {
329-
this.throwable = mssqlException;
330-
}
331-
}
332-
333310
private boolean isError(Segment segment) {
334311
return segment instanceof MssqlMessage && ((MssqlMessage) segment).isError();
335312
}

src/test/java/io/r2dbc/mssql/MssqlResultUnitTests.java

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,13 @@
2121
import io.r2dbc.mssql.message.Message;
2222
import io.r2dbc.mssql.message.token.DoneToken;
2323
import io.r2dbc.mssql.message.token.ErrorToken;
24-
import org.junit.jupiter.api.Test;
24+
import org.junit.jupiter.params.ParameterizedTest;
25+
import org.junit.jupiter.params.provider.MethodSource;
2526
import reactor.core.publisher.Flux;
2627
import reactor.test.StepVerifier;
2728

28-
import java.util.Iterator;
29-
import java.util.stream.Stream;
30-
31-
import static org.assertj.core.api.Assertions.assertThat;
29+
import java.util.Arrays;
30+
import java.util.List;
3231

3332
/**
3433
* Unit tests for {@link DefaultMssqlResult}.
@@ -37,20 +36,59 @@
3736
*/
3837
class MssqlResultUnitTests {
3938

40-
@Test
41-
void shouldDeferErrorSignal() {
39+
@ParameterizedTest
40+
@MethodSource("factories")
41+
void shouldEmitErrorSignalInOrder(ResultFactory factory) {
4242

4343
ErrorToken error = new ErrorToken(0, 0, Byte.MIN_VALUE, Byte.MIN_VALUE, "foo", "", "", 0);
4444
DoneToken done = DoneToken.create(0);
45-
Iterator<Message> iterator = Stream.of(error, done).map(Message.class::cast).iterator();
4645

47-
MssqlResult result = MssqlSegmentResult.toResult("", new ConnectionContext(), new DefaultCodecs(), Flux.fromIterable(() -> iterator), false);
46+
MssqlResult countThenError = factory.create(Flux.just(done, error));
4847

49-
result.getRowsUpdated()
48+
countThenError.getRowsUpdated()
5049
.as(StepVerifier::create)
5150
.expectError()
5251
.verify();
5352

54-
assertThat(iterator.hasNext()).isFalse();
53+
MssqlResult errorThenCount = factory.create(Flux.just(error, done));
54+
55+
errorThenCount.getRowsUpdated()
56+
.as(StepVerifier::create)
57+
.expectError()
58+
.verify();
5559
}
60+
61+
static List<ResultFactory> factories() {
62+
63+
return Arrays.asList(new ResultFactory() {
64+
65+
@Override
66+
MssqlResult create(Flux<Message> messages) {
67+
return DefaultMssqlResult.toResult("", new ConnectionContext(), new DefaultCodecs(), messages, false);
68+
}
69+
70+
@Override
71+
public String toString() {
72+
return "DefaultMssqlResult";
73+
}
74+
}, new ResultFactory() {
75+
76+
@Override
77+
MssqlResult create(Flux<Message> messages) {
78+
return MssqlSegmentResult.toResult("", new ConnectionContext(), new DefaultCodecs(), messages, false);
79+
}
80+
81+
@Override
82+
public String toString() {
83+
return "MssqlSegmentResult";
84+
}
85+
});
86+
}
87+
88+
static abstract class ResultFactory {
89+
90+
abstract MssqlResult create(Flux<Message> messages);
91+
92+
}
93+
5694
}

0 commit comments

Comments
 (0)