Skip to content

Commit e06cc00

Browse files
committed
This commit updates the Java SDK to allow all completable journal entries to have a failure variant. As part of this change, we also added tests to ensure the correct behavior. This fixes #187.
1 parent 24dd269 commit e06cc00

File tree

7 files changed

+125
-19
lines changed

7 files changed

+125
-19
lines changed

sdk-core/src/main/java/dev/restate/sdk/core/Entries.java

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,26 @@ public void trace(PollInputStreamEntryMessage expected, Span span) {
6464

6565
@Override
6666
public boolean hasResult(PollInputStreamEntryMessage actual) {
67-
return true;
67+
return actual.getResultCase() != PollInputStreamEntryMessage.ResultCase.RESULT_NOT_SET;
6868
}
6969

7070
@Override
7171
public ReadyResultInternal<R> parseEntryResult(PollInputStreamEntryMessage actual) {
72-
return valueParser.apply(actual.getValue());
72+
if (actual.getResultCase() == PollInputStreamEntryMessage.ResultCase.VALUE) {
73+
return valueParser.apply(actual.getValue());
74+
} else if (actual.getResultCase() == PollInputStreamEntryMessage.ResultCase.FAILURE) {
75+
return ReadyResults.failure(Util.toRestateException(actual.getFailure()));
76+
} else {
77+
throw new IllegalStateException("PollInputEntry has not been completed.");
78+
}
7379
}
7480

7581
@Override
7682
public ReadyResultInternal<R> parseCompletionResult(CompletionMessage actual) {
7783
if (actual.getResultCase() == CompletionMessage.ResultCase.VALUE) {
7884
return valueParser.apply(actual.getValue());
85+
} else if (actual.getResultCase() == CompletionMessage.ResultCase.FAILURE) {
86+
return ReadyResults.failure(Util.toRestateException(actual.getFailure()));
7987
}
8088
return super.parseCompletionResult(actual);
8189
}
@@ -126,17 +134,23 @@ void checkEntryHeader(GetStateEntryMessage expected, MessageLite actual)
126134
public ReadyResultInternal<ByteString> parseEntryResult(GetStateEntryMessage actual) {
127135
if (actual.getResultCase() == GetStateEntryMessage.ResultCase.VALUE) {
128136
return ReadyResults.success(actual.getValue());
137+
} else if (actual.getResultCase() == GetStateEntryMessage.ResultCase.FAILURE) {
138+
return ReadyResults.failure(Util.toRestateException(actual.getFailure()));
139+
} else if (actual.getResultCase() == GetStateEntryMessage.ResultCase.EMPTY) {
140+
return ReadyResults.empty();
141+
} else {
142+
throw new IllegalStateException("GetStateEntry has not been completed.");
129143
}
130-
return ReadyResults.empty();
131144
}
132145

133146
@Override
134147
public ReadyResultInternal<ByteString> parseCompletionResult(CompletionMessage actual) {
135148
if (actual.getResultCase() == CompletionMessage.ResultCase.VALUE) {
136149
return ReadyResults.success(actual.getValue());
137-
}
138-
if (actual.getResultCase() == CompletionMessage.ResultCase.EMPTY) {
150+
} else if (actual.getResultCase() == CompletionMessage.ResultCase.EMPTY) {
139151
return ReadyResults.empty();
152+
} else if (actual.getResultCase() == CompletionMessage.ResultCase.FAILURE) {
153+
return ReadyResults.failure(Util.toRestateException(actual.getFailure()));
140154
}
141155
return super.parseCompletionResult(actual);
142156
}
@@ -239,18 +253,26 @@ void trace(SleepEntryMessage expected, Span span) {
239253

240254
@Override
241255
public boolean hasResult(SleepEntryMessage actual) {
242-
return actual.hasResult();
256+
return actual.getResultCase() != Protocol.SleepEntryMessage.ResultCase.RESULT_NOT_SET;
243257
}
244258

245259
@Override
246260
public ReadyResultInternal<Void> parseEntryResult(SleepEntryMessage actual) {
247-
return ReadyResults.empty();
261+
if (actual.getResultCase() == SleepEntryMessage.ResultCase.FAILURE) {
262+
return ReadyResults.failure(Util.toRestateException(actual.getFailure()));
263+
} else if (actual.getResultCase() == SleepEntryMessage.ResultCase.EMPTY) {
264+
return ReadyResults.empty();
265+
} else {
266+
throw new IllegalStateException("SleepEntry has not been completed.");
267+
}
248268
}
249269

250270
@Override
251271
public ReadyResultInternal<Void> parseCompletionResult(CompletionMessage actual) {
252272
if (actual.getResultCase() == CompletionMessage.ResultCase.EMPTY) {
253273
return ReadyResults.empty();
274+
} else if (actual.getResultCase() == CompletionMessage.ResultCase.FAILURE) {
275+
return ReadyResults.failure(Util.toRestateException(actual.getFailure()));
254276
}
255277
return super.parseCompletionResult(actual);
256278
}

sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,13 @@ public static MessageHeader fromMessage(MessageLite msg) {
6363
} else if (msg instanceof Protocol.EntryAckMessage) {
6464
return new MessageHeader(MessageType.EntryAckMessage, 0, msg.getSerializedSize());
6565
} else if (msg instanceof Protocol.PollInputStreamEntryMessage) {
66-
return new MessageHeader(MessageType.PollInputStreamEntryMessage, 0, msg.getSerializedSize());
66+
return new MessageHeader(
67+
MessageType.PollInputStreamEntryMessage,
68+
((Protocol.PollInputStreamEntryMessage) msg).getResultCase()
69+
!= Protocol.PollInputStreamEntryMessage.ResultCase.RESULT_NOT_SET
70+
? DONE_FLAG
71+
: 0,
72+
msg.getSerializedSize());
6773
} else if (msg instanceof Protocol.OutputStreamEntryMessage) {
6874
return new MessageHeader(MessageType.OutputStreamEntryMessage, 0, msg.getSerializedSize());
6975
} else if (msg instanceof Protocol.GetStateEntryMessage) {
@@ -81,7 +87,10 @@ public static MessageHeader fromMessage(MessageLite msg) {
8187
} else if (msg instanceof Protocol.SleepEntryMessage) {
8288
return new MessageHeader(
8389
MessageType.SleepEntryMessage,
84-
((Protocol.SleepEntryMessage) msg).hasResult() ? DONE_FLAG : 0,
90+
((Protocol.SleepEntryMessage) msg).getResultCase()
91+
!= Protocol.SleepEntryMessage.ResultCase.RESULT_NOT_SET
92+
? DONE_FLAG
93+
: 0,
8594
msg.getSerializedSize());
8695
} else if (msg instanceof Protocol.InvokeEntryMessage) {
8796
return new MessageHeader(

sdk-core/src/main/java/dev/restate/sdk/core/RestateServerCall.java

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import com.google.protobuf.MessageLite;
1212
import dev.restate.sdk.common.TerminalException;
13+
import dev.restate.sdk.common.syscalls.ReadyResult;
1314
import dev.restate.sdk.common.syscalls.SyscallCallback;
1415
import io.grpc.Metadata;
1516
import io.grpc.MethodDescriptor;
@@ -160,11 +161,21 @@ private void pollInput() {
160161
() -> {
161162
Objects.requireNonNull(listener);
162163

163-
// PollInput can only be result
164-
MessageLite message = deferredValue.toReadyResult().getResult();
165-
166-
LOG.trace("Read input message:\n{}", message);
167-
listener.invoke(message);
164+
final ReadyResult<MessageLite> pollInputReadyResult =
165+
deferredValue.toReadyResult();
166+
167+
if (pollInputReadyResult.isSuccess()) {
168+
final MessageLite message = pollInputReadyResult.getResult();
169+
LOG.trace("Read input message:\n{}", message);
170+
listener.invoke(message);
171+
} else {
172+
final TerminalException failure = pollInputReadyResult.getFailure();
173+
this.close(
174+
Status.UNKNOWN
175+
.withDescription(failure.getMessage())
176+
.withCause(failure),
177+
new Metadata());
178+
}
168179
},
169180
this::onError)),
170181
this::onError));

sdk-core/src/test/java/dev/restate/sdk/core/OnlyInputAndOutputTestSuite.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import static dev.restate.sdk.core.TestDefinitions.TestDefinition;
1313
import static dev.restate.sdk.core.TestDefinitions.testInvocation;
1414

15+
import dev.restate.sdk.common.TerminalException;
1516
import dev.restate.sdk.core.TestDefinitions.TestSuite;
1617
import dev.restate.sdk.core.testservices.GreeterGrpc;
1718
import dev.restate.sdk.core.testservices.GreetingRequest;
@@ -30,7 +31,12 @@ public Stream<TestDefinition> definitions() {
3031
.withInput(
3132
startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Francesco")))
3233
.expectingOutput(
33-
outputMessage(
34-
GreetingResponse.newBuilder().setMessage("Hello Francesco").build())));
34+
outputMessage(GreetingResponse.newBuilder().setMessage("Hello Francesco").build())),
35+
testInvocation(this::noSyscallsGreeter, GreeterGrpc.getGreetMethod())
36+
.withInput(
37+
startMessage(1),
38+
inputMessage(new TerminalException(TerminalException.Code.CANCELLED)))
39+
.expectingOutput(
40+
outputMessage(new TerminalException(TerminalException.Code.CANCELLED))));
3541
}
3642
}

sdk-core/src/test/java/dev/restate/sdk/core/ProtoUtils.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ public static Protocol.PollInputStreamEntryMessage inputMessage(MessageLiteOrBui
9898
.build();
9999
}
100100

101+
public static Protocol.PollInputStreamEntryMessage inputMessage(Throwable error) {
102+
return Protocol.PollInputStreamEntryMessage.newBuilder()
103+
.setFailure(Util.toProtocolFailure(error))
104+
.build();
105+
}
106+
101107
public static Protocol.OutputStreamEntryMessage outputMessage(MessageLiteOrBuilder value) {
102108
return Protocol.OutputStreamEntryMessage.newBuilder()
103109
.setValue(build(value).toByteString())
@@ -121,6 +127,10 @@ public static Protocol.GetStateEntryMessage.Builder getStateMessage(String key)
121127
return Protocol.GetStateEntryMessage.newBuilder().setKey(ByteString.copyFromUtf8(key));
122128
}
123129

130+
public static Protocol.GetStateEntryMessage.Builder getStateMessage(String key, Throwable error) {
131+
return getStateMessage(key).setFailure(Util.toProtocolFailure(error));
132+
}
133+
124134
public static Protocol.GetStateEntryMessage getStateEmptyMessage(String key) {
125135
return Protocol.GetStateEntryMessage.newBuilder()
126136
.setKey(ByteString.copyFromUtf8(key))

sdk-core/src/test/java/dev/restate/sdk/core/SleepTestSuite.java

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import com.google.protobuf.Empty;
1616
import com.google.protobuf.MessageLiteOrBuilder;
1717
import dev.restate.generated.service.protocol.Protocol;
18+
import dev.restate.sdk.common.TerminalException;
1819
import dev.restate.sdk.core.testservices.GreeterGrpc;
1920
import dev.restate.sdk.core.testservices.GreetingRequest;
2021
import dev.restate.sdk.core.testservices.GreetingResponse;
@@ -55,7 +56,7 @@ public Stream<TestDefinitions.TestDefinition> definitions() {
5556
inputMessage(GreetingRequest.newBuilder().setName("Till")),
5657
Protocol.SleepEntryMessage.newBuilder()
5758
.setWakeUpTime(Instant.now().toEpochMilli())
58-
.setResult(Empty.getDefaultInstance())
59+
.setEmpty(Empty.getDefaultInstance())
5960
.build())
6061
.expectingOutput(
6162
outputMessage(GreetingResponse.newBuilder().setMessage("Hello").build()))
@@ -81,13 +82,37 @@ public Stream<TestDefinitions.TestDefinition> definitions() {
8182
(i % 3 == 0)
8283
? Protocol.SleepEntryMessage.newBuilder()
8384
.setWakeUpTime(Instant.now().toEpochMilli())
84-
.setResult(Empty.getDefaultInstance())
85+
.setEmpty(Empty.getDefaultInstance())
8586
.build()
8687
: Protocol.SleepEntryMessage.newBuilder()
8788
.setWakeUpTime(Instant.now().toEpochMilli())
8889
.build()))
8990
.toArray(MessageLiteOrBuilder[]::new))
9091
.expectingOutput(suspensionMessage(1, 2, 4, 5, 7, 8, 10))
91-
.named("Sleep 1000 ms sleep completed"));
92+
.named("Sleep 1000 ms sleep completed"),
93+
testInvocation(this::sleepGreeter, GreeterGrpc.getGreetMethod())
94+
.withInput(
95+
startMessage(2),
96+
inputMessage(GreetingRequest.newBuilder().setName("Till")),
97+
Protocol.SleepEntryMessage.newBuilder()
98+
.setWakeUpTime(Instant.now().toEpochMilli())
99+
.setFailure(
100+
Util.toProtocolFailure(TerminalException.Code.CANCELLED, "canceled"))
101+
.build())
102+
.expectingOutput(outputMessage(TerminalException.Code.CANCELLED, "canceled"))
103+
.named("Failed sleep"),
104+
testInvocation(this::sleepGreeter, GreeterGrpc.getGreetMethod())
105+
.withInput(
106+
startMessage(1),
107+
inputMessage(GreetingRequest.newBuilder().setName("Till")),
108+
completionMessage(
109+
1, new TerminalException(TerminalException.Code.CANCELLED, "canceled")))
110+
.assertingOutput(
111+
messageLites -> {
112+
assertThat(messageLites.get(0)).isInstanceOf(Protocol.SleepEntryMessage.class);
113+
assertThat(messageLites.get(1))
114+
.isEqualTo(outputMessage(TerminalException.Code.CANCELLED, "canceled"));
115+
})
116+
.named("Failing sleep"));
92117
}
93118
}

sdk-core/src/test/java/dev/restate/sdk/core/StateTestSuite.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010

1111
import static dev.restate.sdk.core.ProtoUtils.*;
1212
import static dev.restate.sdk.core.TestDefinitions.testInvocation;
13+
import static org.assertj.core.api.Assertions.assertThat;
1314

1415
import com.google.protobuf.Empty;
16+
import dev.restate.generated.service.protocol.Protocol;
17+
import dev.restate.sdk.common.TerminalException;
1518
import dev.restate.sdk.core.testservices.GreeterGrpc;
1619
import dev.restate.sdk.core.testservices.GreetingRequest;
1720
import dev.restate.sdk.core.testservices.GreetingResponse;
@@ -76,6 +79,26 @@ public Stream<TestDefinitions.TestDefinition> definitions() {
7679
getStateMessage("STATE"),
7780
outputMessage(GreetingResponse.newBuilder().setMessage("Hello Francesco")))
7881
.named("Without GetStateEntry and completed with later CompletionFrame"),
82+
testInvocation(this::getState, GreeterGrpc.getGreetMethod())
83+
.withInput(
84+
startMessage(2),
85+
inputMessage(GreetingRequest.newBuilder().setName("Till")),
86+
getStateMessage("STATE", new TerminalException(TerminalException.Code.CANCELLED)))
87+
.expectingOutput(outputMessage(new TerminalException(TerminalException.Code.CANCELLED)))
88+
.named("Failed GetStateEntry"),
89+
testInvocation(this::getState, GreeterGrpc.getGreetMethod())
90+
.withInput(
91+
startMessage(1),
92+
inputMessage(GreetingRequest.newBuilder().setName("Till")),
93+
completionMessage(1, new TerminalException(TerminalException.Code.CANCELLED)))
94+
.assertingOutput(
95+
messageLites -> {
96+
assertThat(messageLites.get(0)).isInstanceOf(Protocol.GetStateEntryMessage.class);
97+
assertThat(messageLites.get(1))
98+
.isEqualTo(
99+
outputMessage(new TerminalException(TerminalException.Code.CANCELLED)));
100+
})
101+
.named("Failing GetStateEntry"),
79102
testInvocation(this::getAndSetState, GreeterGrpc.getGreetMethod())
80103
.withInput(
81104
startMessage(3),

0 commit comments

Comments
 (0)