Skip to content

Commit 8a74cae

Browse files
Add ctx.stateKeys() (#219)
1 parent dd79547 commit 8a74cae

File tree

16 files changed

+232
-5
lines changed

16 files changed

+232
-5
lines changed

sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,25 @@ internal class ContextImpl internal constructor(private val syscalls: Syscalls)
4747
return key.serde().deserializeWrappingException(syscalls, readyResult.value!!)!!
4848
}
4949

50+
override suspend fun stateKeys(): Collection<String> {
51+
val deferred: Deferred<Collection<String>> =
52+
suspendCancellableCoroutine { cont: CancellableContinuation<Deferred<Collection<String>>> ->
53+
syscalls.getKeys(completingContinuation(cont))
54+
}
55+
56+
if (!deferred.isCompleted) {
57+
suspendCancellableCoroutine { cont: CancellableContinuation<Unit> ->
58+
syscalls.resolveDeferred(deferred, completingUnitContinuation(cont))
59+
}
60+
}
61+
62+
val readyResult = deferred.toResult()!!
63+
if (!readyResult.isSuccess) {
64+
throw readyResult.failure!!
65+
}
66+
return readyResult.value!!
67+
}
68+
5069
override suspend fun <T : Any> set(key: StateKey<T>, value: T) {
5170
val serializedValue = key.serde().serializeWrappingException(syscalls, value)!!
5271
return suspendCancellableCoroutine { cont: CancellableContinuation<Unit> ->

sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,13 @@ sealed interface KeyedContext : UnkeyedContext {
222222
*/
223223
suspend fun <T : Any> get(key: StateKey<T>): T?
224224

225+
/**
226+
* Gets all the known state keys for this service instance.
227+
*
228+
* @return the immutable collection of known state keys.
229+
*/
230+
suspend fun stateKeys(): Collection<String>
231+
225232
/**
226233
* Sets the given value under the given key, serializing the value using the registered
227234
* [dev.restate.sdk.core.serde.Serde] in the interceptor.

sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/EagerStateTest.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,14 @@ class EagerStateTest : EagerStateTestSuite() {
9090
override fun getClearAllAndGet(): BindableService {
9191
return GetClearAllAndGet()
9292
}
93+
94+
private class ListKeys : GreeterRestateKt.GreeterRestateKtImplBase() {
95+
override suspend fun greet(context: KeyedContext, request: GreetingRequest): GreetingResponse {
96+
return greetingResponse { message = context.stateKeys().joinToString(separator = ",") }
97+
}
98+
}
99+
100+
override fun listKeys(): BindableService {
101+
return ListKeys()
102+
}
93103
}

sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import dev.restate.sdk.common.syscalls.Syscalls;
1818
import io.grpc.MethodDescriptor;
1919
import java.time.Duration;
20+
import java.util.Collection;
2021
import java.util.Map;
2122
import java.util.Optional;
2223
import java.util.concurrent.CompletableFuture;
@@ -43,6 +44,17 @@ public <T> Optional<T> get(StateKey<T> key) {
4344
.map(bs -> Util.deserializeWrappingException(syscalls, key.serde(), bs));
4445
}
4546

47+
@Override
48+
public Collection<String> stateKeys() {
49+
Deferred<Collection<String>> deferred = Util.blockOnSyscall(syscalls::getKeys);
50+
51+
if (!deferred.isCompleted()) {
52+
Util.<Void>blockOnSyscall(cb -> syscalls.resolveDeferred(deferred, cb));
53+
}
54+
55+
return Util.unwrapResult(deferred.toResult());
56+
}
57+
4658
@Override
4759
public void clear(StateKey<?> key) {
4860
Util.<Void>blockOnSyscall(cb -> syscalls.clear(key.name(), cb));

sdk-api/src/main/java/dev/restate/sdk/KeyedContext.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import dev.restate.sdk.common.*;
1212
import dev.restate.sdk.common.syscalls.Syscalls;
13+
import java.util.Collection;
1314
import java.util.Optional;
1415
import javax.annotation.Nonnull;
1516
import javax.annotation.concurrent.NotThreadSafe;
@@ -34,6 +35,13 @@ public interface KeyedContext extends UnkeyedContext {
3435
*/
3536
<T> Optional<T> get(StateKey<T> key);
3637

38+
/**
39+
* Gets all the known state keys for this service instance.
40+
*
41+
* @return the immutable collection of known state keys.
42+
*/
43+
Collection<String> stateKeys();
44+
3745
/**
3846
* Clears the state stored under key.
3947
*

sdk-api/src/test/java/dev/restate/sdk/EagerStateTest.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212

1313
import dev.restate.sdk.common.CoreSerdes;
1414
import dev.restate.sdk.common.StateKey;
15+
import dev.restate.sdk.common.TerminalException;
1516
import dev.restate.sdk.core.EagerStateTestSuite;
1617
import dev.restate.sdk.core.testservices.GreeterGrpc;
18+
import dev.restate.sdk.core.testservices.GreeterRestate;
1719
import dev.restate.sdk.core.testservices.GreetingRequest;
1820
import dev.restate.sdk.core.testservices.GreetingResponse;
1921
import io.grpc.BindableService;
@@ -119,4 +121,19 @@ public void greet(GreetingRequest request, StreamObserver<GreetingResponse> resp
119121
protected BindableService getClearAllAndGet() {
120122
return new GetClearAllAndGet();
121123
}
124+
125+
private static class ListKeys extends GreeterRestate.GreeterRestateImplBase {
126+
@Override
127+
public GreetingResponse greet(KeyedContext context, GreetingRequest request)
128+
throws TerminalException {
129+
return GreetingResponse.newBuilder()
130+
.setMessage(String.join(",", context.stateKeys()))
131+
.build();
132+
}
133+
}
134+
135+
@Override
136+
protected BindableService listKeys() {
137+
return new ListKeys();
138+
}
122139
}

sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import io.grpc.Context;
1515
import io.grpc.MethodDescriptor;
1616
import java.time.Duration;
17+
import java.util.*;
1718
import java.util.List;
1819
import java.util.Map;
1920
import java.util.Objects;
@@ -61,6 +62,8 @@ static Syscalls current() {
6162

6263
void get(String name, SyscallCallback<Deferred<ByteString>> callback);
6364

65+
void getKeys(SyscallCallback<Deferred<Collection<String>>> callback);
66+
6467
void clear(String name, SyscallCallback<Void> callback);
6568

6669
void clearAll(SyscallCallback<Void> callback);

sdk-core/src/main/java/dev/restate/sdk/core/Entries.java

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,16 @@
1010

1111
import com.google.protobuf.ByteString;
1212
import com.google.protobuf.Empty;
13+
import com.google.protobuf.InvalidProtocolBufferException;
1314
import com.google.protobuf.MessageLite;
1415
import dev.restate.generated.service.protocol.Protocol;
1516
import dev.restate.generated.service.protocol.Protocol.*;
1617
import dev.restate.sdk.common.syscalls.Result;
1718
import io.opentelemetry.api.common.Attributes;
1819
import io.opentelemetry.api.trace.Span;
20+
import java.util.Collection;
1921
import java.util.function.Function;
22+
import java.util.stream.Collectors;
2023

2124
final class Entries {
2225
static final String AWAKEABLE_IDENTIFIER_PREFIX = "prom_1";
@@ -183,6 +186,78 @@ void updateUserStateStorageWithCompletion(
183186
}
184187
}
185188

189+
static final class GetStateKeysEntry
190+
extends CompletableJournalEntry<GetStateKeysEntryMessage, Collection<String>> {
191+
192+
static final GetStateKeysEntry INSTANCE = new GetStateKeysEntry();
193+
194+
private GetStateKeysEntry() {}
195+
196+
@Override
197+
void trace(GetStateKeysEntryMessage expected, Span span) {
198+
span.addEvent("GetStateKeys");
199+
}
200+
201+
@Override
202+
public boolean hasResult(GetStateKeysEntryMessage actual) {
203+
return actual.getResultCase() != GetStateKeysEntryMessage.ResultCase.RESULT_NOT_SET;
204+
}
205+
206+
@Override
207+
void checkEntryHeader(GetStateKeysEntryMessage expected, MessageLite actual)
208+
throws ProtocolException {
209+
if (!(actual instanceof GetStateKeysEntryMessage)) {
210+
throw ProtocolException.entryDoesNotMatch(expected, actual);
211+
}
212+
}
213+
214+
@Override
215+
public Result<Collection<String>> parseEntryResult(GetStateKeysEntryMessage actual) {
216+
if (actual.getResultCase() == GetStateKeysEntryMessage.ResultCase.VALUE) {
217+
return Result.success(
218+
actual.getValue().getKeysList().stream()
219+
.map(ByteString::toStringUtf8)
220+
.collect(Collectors.toUnmodifiableList()));
221+
} else if (actual.getResultCase() == GetStateKeysEntryMessage.ResultCase.FAILURE) {
222+
return Result.failure(Util.toRestateException(actual.getFailure()));
223+
} else {
224+
throw new IllegalStateException("GetStateKeysEntryMessage has not been completed.");
225+
}
226+
}
227+
228+
@Override
229+
public Result<Collection<String>> parseCompletionResult(CompletionMessage actual) {
230+
if (actual.getResultCase() == CompletionMessage.ResultCase.VALUE) {
231+
GetStateKeysEntryMessage.StateKeys stateKeys;
232+
try {
233+
stateKeys = GetStateKeysEntryMessage.StateKeys.parseFrom(actual.getValue());
234+
} catch (InvalidProtocolBufferException e) {
235+
throw new ProtocolException(
236+
"Cannot parse get state keys completion", e, ProtocolException.PROTOCOL_VIOLATION);
237+
}
238+
return Result.success(
239+
stateKeys.getKeysList().stream()
240+
.map(ByteString::toStringUtf8)
241+
.collect(Collectors.toUnmodifiableList()));
242+
} else if (actual.getResultCase() == CompletionMessage.ResultCase.FAILURE) {
243+
return Result.failure(Util.toRestateException(actual.getFailure()));
244+
}
245+
return super.parseCompletionResult(actual);
246+
}
247+
248+
@Override
249+
GetStateKeysEntryMessage tryCompleteWithUserStateStorage(
250+
GetStateKeysEntryMessage expected, UserStateStore userStateStore) {
251+
if (userStateStore.isComplete()) {
252+
return expected.toBuilder()
253+
.setValue(
254+
GetStateKeysEntryMessage.StateKeys.newBuilder().addAllKeys(userStateStore.keys()))
255+
.build();
256+
}
257+
return expected;
258+
}
259+
}
260+
186261
static final class ClearStateEntry extends JournalEntry<ClearStateEntryMessage> {
187262

188263
static final ClearStateEntry INSTANCE = new ClearStateEntry();

sdk-core/src/main/java/dev/restate/sdk/core/ExecutorSwitchingSyscalls.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import dev.restate.sdk.common.syscalls.SyscallCallback;
1818
import io.grpc.MethodDescriptor;
1919
import java.time.Duration;
20+
import java.util.Collection;
2021
import java.util.Map;
2122
import java.util.concurrent.Executor;
2223

@@ -50,6 +51,11 @@ public void get(String name, SyscallCallback<Deferred<ByteString>> callback) {
5051
syscallsExecutor.execute(() -> syscalls.get(name, callback));
5152
}
5253

54+
@Override
55+
public void getKeys(SyscallCallback<Deferred<Collection<String>>> callback) {
56+
syscallsExecutor.execute(() -> syscalls.getKeys(callback));
57+
}
58+
5359
@Override
5460
public void clear(String name, SyscallCallback<Void> callback) {
5561
syscallsExecutor.execute(() -> syscalls.clear(name, callback));

sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,14 @@ public static MessageHeader fromMessage(MessageLite msg) {
8787
return new MessageHeader(MessageType.ClearStateEntryMessage, 0, msg.getSerializedSize());
8888
} else if (msg instanceof Protocol.ClearAllStateEntryMessage) {
8989
return new MessageHeader(MessageType.ClearAllStateEntryMessage, 0, msg.getSerializedSize());
90+
} else if (msg instanceof Protocol.GetStateKeysEntryMessage) {
91+
return new MessageHeader(
92+
MessageType.GetStateKeysEntryMessage,
93+
((Protocol.GetStateKeysEntryMessage) msg).getResultCase()
94+
!= Protocol.GetStateKeysEntryMessage.ResultCase.RESULT_NOT_SET
95+
? DONE_FLAG
96+
: 0,
97+
msg.getSerializedSize());
9098
} else if (msg instanceof Protocol.SleepEntryMessage) {
9199
return new MessageHeader(
92100
MessageType.SleepEntryMessage,

0 commit comments

Comments
 (0)