Skip to content

Commit 7aabbbb

Browse files
Update the entry ack for side effects to use the new mechanism
1 parent 2ebdbe0 commit 7aabbbb

File tree

5 files changed

+56
-28
lines changed

5 files changed

+56
-28
lines changed

sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -135,15 +135,10 @@ public void onNext(InvocationFlow.InvocationInput invocationInput) {
135135
// We check the instance rather than the state, because the user code might still be
136136
// replaying, but the network layer is already past it and is receiving completions from the
137137
// runtime.
138-
Protocol.CompletionMessage completionMessage = (Protocol.CompletionMessage) msg;
139-
140-
// If ack, give it to side effect publisher
141-
if (completionMessage.getResultCase()
142-
== Protocol.CompletionMessage.ResultCase.RESULT_NOT_SET) {
143-
this.sideEffectAckStateMachine.tryHandleSideEffectAck(completionMessage.getEntryIndex());
144-
} else {
145-
this.readyResultStateMachine.offerCompletion((Protocol.CompletionMessage) msg);
146-
}
138+
this.readyResultStateMachine.offerCompletion((Protocol.CompletionMessage) msg);
139+
} else if (msg instanceof Protocol.EntryAckMessage) {
140+
this.sideEffectAckStateMachine.tryHandleSideEffectAck(
141+
((Protocol.EntryAckMessage) msg).getEntryIndex());
147142
} else {
148143
this.incomingEntriesStateMachine.offer(msg);
149144
}

sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414

1515
public class MessageHeader {
1616

17-
private static final short DONE_FLAG = 0x0001;
18-
private static final short REQUIRES_ACK_FLAG = 0x0001;
17+
static final short DONE_FLAG = 0x0001;
18+
static final int REQUIRES_ACK_FLAG = 0x8000;
1919

2020
private final MessageType type;
21-
private final short flags;
21+
private final int flags;
2222
private final int length;
2323

24-
public MessageHeader(MessageType type, short flags, int length) {
24+
public MessageHeader(MessageType type, int flags, int length) {
2525
this.type = type;
2626
this.flags = flags;
2727
this.length = length;
@@ -57,15 +57,15 @@ public static MessageHeader parse(long encoded) throws ProtocolException {
5757

5858
public static MessageHeader fromMessage(MessageLite msg) {
5959
if (msg instanceof Protocol.SuspensionMessage) {
60-
return new MessageHeader(MessageType.SuspensionMessage, (short) 0, msg.getSerializedSize());
60+
return new MessageHeader(MessageType.SuspensionMessage, 0, msg.getSerializedSize());
6161
} else if (msg instanceof Protocol.ErrorMessage) {
62-
return new MessageHeader(MessageType.ErrorMessage, (short) 0, msg.getSerializedSize());
62+
return new MessageHeader(MessageType.ErrorMessage, 0, msg.getSerializedSize());
63+
} else if (msg instanceof Protocol.EntryAckMessage) {
64+
return new MessageHeader(MessageType.EntryAckMessage, 0, msg.getSerializedSize());
6365
} else if (msg instanceof Protocol.PollInputStreamEntryMessage) {
64-
return new MessageHeader(
65-
MessageType.PollInputStreamEntryMessage, (short) 0, msg.getSerializedSize());
66+
return new MessageHeader(MessageType.PollInputStreamEntryMessage, 0, msg.getSerializedSize());
6667
} else if (msg instanceof Protocol.OutputStreamEntryMessage) {
67-
return new MessageHeader(
68-
MessageType.OutputStreamEntryMessage, (short) 0, msg.getSerializedSize());
68+
return new MessageHeader(MessageType.OutputStreamEntryMessage, 0, msg.getSerializedSize());
6969
} else if (msg instanceof Protocol.GetStateEntryMessage) {
7070
return new MessageHeader(
7171
MessageType.GetStateEntryMessage,
@@ -75,11 +75,9 @@ public static MessageHeader fromMessage(MessageLite msg) {
7575
: 0,
7676
msg.getSerializedSize());
7777
} else if (msg instanceof Protocol.SetStateEntryMessage) {
78-
return new MessageHeader(
79-
MessageType.SetStateEntryMessage, (short) 0, msg.getSerializedSize());
78+
return new MessageHeader(MessageType.SetStateEntryMessage, 0, msg.getSerializedSize());
8079
} else if (msg instanceof Protocol.ClearStateEntryMessage) {
81-
return new MessageHeader(
82-
MessageType.ClearStateEntryMessage, (short) 0, msg.getSerializedSize());
80+
return new MessageHeader(MessageType.ClearStateEntryMessage, 0, msg.getSerializedSize());
8381
} else if (msg instanceof Protocol.SleepEntryMessage) {
8482
return new MessageHeader(
8583
MessageType.SleepEntryMessage,
@@ -95,7 +93,7 @@ public static MessageHeader fromMessage(MessageLite msg) {
9593
msg.getSerializedSize());
9694
} else if (msg instanceof Protocol.BackgroundInvokeEntryMessage) {
9795
return new MessageHeader(
98-
MessageType.BackgroundInvokeEntryMessage, (short) 0, msg.getSerializedSize());
96+
MessageType.BackgroundInvokeEntryMessage, 0, msg.getSerializedSize());
9997
} else if (msg instanceof Protocol.AwakeableEntryMessage) {
10098
return new MessageHeader(
10199
MessageType.AwakeableEntryMessage,
@@ -106,10 +104,10 @@ public static MessageHeader fromMessage(MessageLite msg) {
106104
msg.getSerializedSize());
107105
} else if (msg instanceof Protocol.CompleteAwakeableEntryMessage) {
108106
return new MessageHeader(
109-
MessageType.CompleteAwakeableEntryMessage, (short) 0, msg.getSerializedSize());
107+
MessageType.CompleteAwakeableEntryMessage, 0, msg.getSerializedSize());
110108
} else if (msg instanceof Java.CombinatorAwaitableEntryMessage) {
111109
return new MessageHeader(
112-
MessageType.CombinatorAwaitableEntryMessage, (short) 0, msg.getSerializedSize());
110+
MessageType.CombinatorAwaitableEntryMessage, 0, msg.getSerializedSize());
113111
} else if (msg instanceof Java.SideEffectEntryMessage) {
114112
return new MessageHeader(
115113
MessageType.SideEffectEntryMessage, REQUIRES_ACK_FLAG, msg.getSerializedSize());

sdk-core/src/main/java/dev/restate/sdk/core/MessageType.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ public enum MessageType {
1818
CompletionMessage,
1919
SuspensionMessage,
2020
ErrorMessage,
21+
EntryAckMessage,
2122

2223
// IO
2324
PollInputStreamEntryMessage,
@@ -43,6 +44,7 @@ public enum MessageType {
4344
public static final short COMPLETION_MESSAGE_TYPE = 0x0001;
4445
public static final short SUSPENSION_MESSAGE_TYPE = 0x0002;
4546
public static final short ERROR_MESSAGE_TYPE = 0x0003;
47+
public static final short ENTRY_ACK_MESSAGE_TYPE = 0x0004;
4648
public static final short POLL_INPUT_STREAM_ENTRY_MESSAGE_TYPE = 0x0400;
4749
public static final short OUTPUT_STREAM_ENTRY_MESSAGE_TYPE = 0x0401;
4850
public static final short GET_STATE_ENTRY_MESSAGE_TYPE = 0x0800;
@@ -66,6 +68,8 @@ public Parser<? extends MessageLite> messageParser() {
6668
return Protocol.SuspensionMessage.parser();
6769
case ErrorMessage:
6870
return Protocol.ErrorMessage.parser();
71+
case EntryAckMessage:
72+
return Protocol.EntryAckMessage.parser();
6973
case PollInputStreamEntryMessage:
7074
return Protocol.PollInputStreamEntryMessage.parser();
7175
case OutputStreamEntryMessage:
@@ -104,6 +108,8 @@ public short encode() {
104108
return SUSPENSION_MESSAGE_TYPE;
105109
case ErrorMessage:
106110
return ERROR_MESSAGE_TYPE;
111+
case EntryAckMessage:
112+
return ENTRY_ACK_MESSAGE_TYPE;
107113
case PollInputStreamEntryMessage:
108114
return POLL_INPUT_STREAM_ENTRY_MESSAGE_TYPE;
109115
case OutputStreamEntryMessage:
@@ -142,6 +148,8 @@ public static MessageType decode(short value) throws ProtocolException {
142148
return SuspensionMessage;
143149
case ERROR_MESSAGE_TYPE:
144150
return ErrorMessage;
151+
case ENTRY_ACK_MESSAGE_TYPE:
152+
return EntryAckMessage;
145153
case POLL_INPUT_STREAM_ENTRY_MESSAGE_TYPE:
146154
return PollInputStreamEntryMessage;
147155
case OUTPUT_STREAM_ENTRY_MESSAGE_TYPE:
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH
2+
//
3+
// This file is part of the Restate Java SDK,
4+
// which is released under the MIT license.
5+
//
6+
// You can find a copy of the license in file LICENSE in the root
7+
// directory of this repository or package, or at
8+
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
9+
package dev.restate.sdk.core;
10+
11+
import static org.assertj.core.api.Assertions.assertThat;
12+
13+
import org.junit.jupiter.api.Test;
14+
15+
public class MessageHeaderTest {
16+
17+
@Test
18+
void requiresAckFlag() {
19+
assertThat(
20+
new MessageHeader(
21+
MessageType.InvokeEntryMessage,
22+
MessageHeader.DONE_FLAG | MessageHeader.REQUIRES_ACK_FLAG,
23+
2)
24+
.encode())
25+
.isEqualTo(0x0C01_8001_0000_0002L);
26+
}
27+
}

sdk-core/src/test/java/dev/restate/sdk/core/ProtoUtils.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ public static Protocol.CompletionMessage completionMessage(int index, Throwable
8484
.build();
8585
}
8686

87-
public static Protocol.CompletionMessage ackMessage(int index) {
88-
return Protocol.CompletionMessage.newBuilder().setEntryIndex(index).build();
87+
public static Protocol.EntryAckMessage ackMessage(int index) {
88+
return Protocol.EntryAckMessage.newBuilder().setEntryIndex(index).build();
8989
}
9090

9191
public static Protocol.SuspensionMessage suspensionMessage(Integer... indexes) {

0 commit comments

Comments
 (0)