Skip to content

Commit 5ccb859

Browse files
Send original update request back in accept/reject response (#2074)
1 parent 82d5a88 commit 5ccb859

File tree

2 files changed

+48
-20
lines changed

2 files changed

+48
-20
lines changed

temporal-sdk/src/main/java/io/temporal/internal/statemachines/UpdateProtocolStateMachine.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ enum State {
7979
private String requestMsgId;
8080
private long requestSeqID;
8181
private Meta meta;
82+
private Optional<Request> originalRequest = Optional.empty();
8283
private String messageId;
8384

8485
public static final StateMachineDefinition<State, ExplicitEvent, UpdateProtocolStateMachine>
@@ -175,7 +176,8 @@ void triggerUpdate() {
175176
requestMsgId = this.currentMessage.getId();
176177
requestSeqID = this.currentMessage.getEventId();
177178
try {
178-
meta = this.currentMessage.getBody().unpack(Request.class).getMeta();
179+
originalRequest = Optional.of(this.currentMessage.getBody().unpack(Request.class));
180+
meta = originalRequest.get().getMeta();
179181
} catch (InvalidProtocolBufferException e) {
180182
throw new IllegalArgumentException("Current message not an update:" + this.currentMessage);
181183
}
@@ -199,8 +201,10 @@ public void accept() {
199201
Acceptance.newBuilder()
200202
.setAcceptedRequestMessageId(requestMsgId)
201203
.setAcceptedRequestSequencingEventId(requestSeqID)
204+
.setAcceptedRequest(originalRequest.get())
202205
.build();
203-
206+
// Clear the original request to allow GC to reclaim the memory.
207+
originalRequest = Optional.empty();
204208
messageId = requestMsgId + "/accept";
205209
sendHandle.apply(
206210
Message.newBuilder()
@@ -217,6 +221,7 @@ public void reject(Failure failure) {
217221
.setRejectedRequestMessageId(requestMsgId)
218222
.setRejectedRequestSequencingEventId(requestSeqID)
219223
.setFailure(failure)
224+
.setRejectedRequest(originalRequest.get())
220225
.build();
221226

222227
String messageId = requestMsgId + "/reject";

temporal-sdk/src/test/java/io/temporal/internal/statemachines/UpdateProtocolStateMachineTest.java

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import static io.temporal.internal.statemachines.MutableSideEffectStateMachine.*;
2424
import static io.temporal.internal.statemachines.SideEffectStateMachine.SIDE_EFFECT_MARKER_NAME;
2525
import static org.junit.Assert.*;
26-
import static org.junit.Assert.assertEquals;
2726

2827
import com.google.protobuf.Any;
2928
import com.google.protobuf.InvalidProtocolBufferException;
@@ -34,10 +33,7 @@
3433
import io.temporal.api.enums.v1.EventType;
3534
import io.temporal.api.history.v1.*;
3635
import io.temporal.api.protocol.v1.Message;
37-
import io.temporal.api.update.v1.Input;
38-
import io.temporal.api.update.v1.Meta;
39-
import io.temporal.api.update.v1.Outcome;
40-
import io.temporal.api.update.v1.Request;
36+
import io.temporal.api.update.v1.*;
4137
import io.temporal.common.converter.DataConverter;
4238
import io.temporal.common.converter.DefaultDataConverter;
4339
import io.temporal.internal.common.ProtobufTimeUtils;
@@ -84,7 +80,7 @@ public static void generateCoverage() {
8480
}
8581

8682
@Test
87-
public void testUpdateAccept() {
83+
public void testUpdateAccept() throws InvalidProtocolBufferException {
8884
class TestUpdateListener extends TestEntityManagerListenerBase {
8985

9086
@Override
@@ -173,8 +169,31 @@ protected void update(UpdateMessage message, AsyncWorkflowBuilder<Void> builder)
173169
{
174170
TestEntityManagerListenerBase listener = new TestUpdateListener();
175171
stateMachines = newStateMachines(listener);
176-
List<Command> commands = h.handleWorkflowTaskTakeCommands(stateMachines, 0);
177-
assertEquals(0, commands.size());
172+
Request request =
173+
Request.newBuilder()
174+
.setInput(
175+
Input.newBuilder()
176+
.setName("updateName")
177+
.setArgs(converter.toPayloads("arg").get()))
178+
.build();
179+
stateMachines.setMessages(
180+
Collections.unmodifiableList(
181+
Arrays.asList(
182+
new Message[] {
183+
Message.newBuilder()
184+
.setProtocolInstanceId("protocol_id")
185+
.setId("id")
186+
.setEventId(0)
187+
.setBody(Any.pack(request))
188+
.build(),
189+
})));
190+
List<Command> commands = h.handleWorkflowTaskTakeCommands(stateMachines, 1);
191+
assertEquals(2, commands.size());
192+
List<Message> messages = stateMachines.takeMessages();
193+
assertEquals(1, messages.size());
194+
Acceptance acceptance = messages.get(0).getBody().unpack(Acceptance.class);
195+
assertNotNull(acceptance);
196+
assertEquals(request, acceptance.getAcceptedRequest());
178197
}
179198
{
180199
TestEntityManagerListenerBase listener = new TestUpdateListener();
@@ -369,7 +388,7 @@ protected void update(UpdateMessage message, AsyncWorkflowBuilder<Void> builder)
369388
}
370389

371390
@Test
372-
public void testUpdateRejected() {
391+
public void testUpdateRejected() throws InvalidProtocolBufferException {
373392
class TestUpdateListener extends TestEntityManagerListenerBase {
374393

375394
@Override
@@ -404,14 +423,13 @@ protected void update(UpdateMessage message, AsyncWorkflowBuilder<Void> builder)
404423
// Full replay
405424
TestEntityManagerListenerBase listener = new TestUpdateListener();
406425
stateMachines = newStateMachines(listener);
407-
Any messageBody =
408-
Any.pack(
409-
Request.newBuilder()
410-
.setInput(
411-
Input.newBuilder()
412-
.setName("updateName")
413-
.setArgs(converter.toPayloads("arg").get()))
414-
.build());
426+
Request request =
427+
Request.newBuilder()
428+
.setInput(
429+
Input.newBuilder()
430+
.setName("updateName")
431+
.setArgs(converter.toPayloads("arg").get()))
432+
.build();
415433
stateMachines.setMessages(
416434
Collections.unmodifiableList(
417435
Arrays.asList(
@@ -420,11 +438,16 @@ protected void update(UpdateMessage message, AsyncWorkflowBuilder<Void> builder)
420438
.setProtocolInstanceId("protocol_id")
421439
.setId("id")
422440
.setEventId(0)
423-
.setBody(messageBody)
441+
.setBody(Any.pack(request))
424442
.build(),
425443
})));
426444
List<Command> commands = h.handleWorkflowTaskTakeCommands(stateMachines, 1);
427445
assertEquals(0, commands.size());
446+
List<Message> messages = stateMachines.takeMessages();
447+
assertEquals(1, messages.size());
448+
Rejection rejection = messages.get(0).getBody().unpack(Rejection.class);
449+
assertNotNull(rejection);
450+
assertEquals(request, rejection.getRejectedRequest());
428451
}
429452
}
430453

0 commit comments

Comments
 (0)