Skip to content

Commit d7ff874

Browse files
Add on closed invocation stream hook to make sure we correctly shutdown pending coroutine context in case of a failure. (#466)
1 parent 624fd2f commit d7ff874

File tree

5 files changed

+84
-41
lines changed

5 files changed

+84
-41
lines changed

sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/HandlerRunner.kt

Lines changed: 39 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import dev.restate.serde.Serde
1515
import dev.restate.serde.SerdeFactory
1616
import io.opentelemetry.extension.kotlin.asContextElement
1717
import java.util.concurrent.CompletableFuture
18+
import java.util.concurrent.atomic.AtomicReference
1819
import kotlin.coroutines.CoroutineContext
1920
import kotlinx.coroutines.CoroutineScope
2021
import kotlinx.coroutines.Dispatchers
@@ -98,6 +99,7 @@ internal constructor(
9899
handlerContext: HandlerContext,
99100
requestSerde: Serde<REQ>,
100101
responseSerde: Serde<RES>,
102+
onClosedInvocationStreamHook: AtomicReference<Runnable>
101103
): CompletableFuture<Slice> {
102104
val ctx: Context = ContextImpl(handlerContext, contextSerdeFactory)
103105

@@ -109,42 +111,44 @@ internal constructor(
109111
handlerContext.request().otelContext()!!.asContextElement())
110112

111113
val completableFuture = CompletableFuture<Slice>()
112-
113-
scope.launch {
114-
val serializedResult: Slice
115-
116-
try {
117-
// Parse input
118-
val req: REQ
119-
try {
120-
req = requestSerde.deserialize(handlerContext.request().body)
121-
} catch (e: Throwable) {
122-
LOG.warn("Error deserializing request", e)
123-
completableFuture.completeExceptionally(
124-
throw TerminalException(
125-
TerminalException.BAD_REQUEST_CODE, "Cannot deserialize request: " + e.message))
126-
return@launch
127-
}
128-
129-
// Execute user code
130-
@Suppress("UNCHECKED_CAST") val res: RES = runner(ctx as CTX, req)
131-
132-
// Serialize output
133-
try {
134-
serializedResult = responseSerde.serialize(res)
135-
} catch (e: Throwable) {
136-
LOG.warn("Error when serializing response", e)
137-
completableFuture.completeExceptionally(e)
138-
return@launch
114+
val job =
115+
scope.launch {
116+
val serializedResult: Slice
117+
118+
try {
119+
// Parse input
120+
val req: REQ
121+
try {
122+
req = requestSerde.deserialize(handlerContext.request().body)
123+
} catch (e: Throwable) {
124+
LOG.warn("Error deserializing request", e)
125+
completableFuture.completeExceptionally(
126+
throw TerminalException(
127+
TerminalException.BAD_REQUEST_CODE,
128+
"Cannot deserialize request: " + e.message))
129+
return@launch
130+
}
131+
132+
// Execute user code
133+
@Suppress("UNCHECKED_CAST") val res: RES = runner(ctx as CTX, req)
134+
135+
// Serialize output
136+
try {
137+
serializedResult = responseSerde.serialize(res)
138+
} catch (e: Throwable) {
139+
LOG.warn("Error when serializing response", e)
140+
completableFuture.completeExceptionally(e)
141+
return@launch
142+
}
143+
} catch (e: Throwable) {
144+
completableFuture.completeExceptionally(e)
145+
return@launch
146+
}
147+
148+
// Complete callback
149+
completableFuture.complete(serializedResult)
139150
}
140-
} catch (e: Throwable) {
141-
completableFuture.completeExceptionally(e)
142-
return@launch
143-
}
144-
145-
// Complete callback
146-
completableFuture.complete(serializedResult)
147-
}
151+
onClosedInvocationStreamHook.set { job.cancel() }
148152

149153
return completableFuture
150154
}

sdk-api/src/main/java/dev/restate/sdk/HandlerRunner.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.concurrent.CompletableFuture;
2222
import java.util.concurrent.Executor;
2323
import java.util.concurrent.Executors;
24+
import java.util.concurrent.atomic.AtomicReference;
2425
import org.apache.logging.log4j.LogManager;
2526
import org.apache.logging.log4j.Logger;
2627
import org.jspecify.annotations.Nullable;
@@ -48,7 +49,10 @@ public class HandlerRunner<REQ, RES>
4849

4950
@Override
5051
public CompletableFuture<Slice> run(
51-
HandlerContext handlerContext, Serde<REQ> requestSerde, Serde<RES> responseSerde) {
52+
HandlerContext handlerContext,
53+
Serde<REQ> requestSerde,
54+
Serde<RES> responseSerde,
55+
AtomicReference<Runnable> onClosedInvocationStreamHook) {
5256
CompletableFuture<Slice> returnFuture = new CompletableFuture<>();
5357

5458
// Wrap the executor for setting/unsetting the thread local

sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerContext.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,6 @@ public interface HandlerContext {
4848

4949
CompletableFuture<Void> set(String name, Slice value);
5050

51-
// ----- Syscalls
52-
5351
CompletableFuture<AsyncResult<Void>> timer(Duration duration, @Nullable String name);
5452

5553
record CallResult(

sdk-common/src/main/java/dev/restate/sdk/endpoint/definition/HandlerRunner.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import dev.restate.common.Slice;
1212
import dev.restate.serde.Serde;
1313
import java.util.concurrent.CompletableFuture;
14+
import java.util.concurrent.atomic.AtomicReference;
1415

1516
public interface HandlerRunner<REQ, RES> {
1617
/**
@@ -27,5 +28,8 @@ public interface HandlerRunner<REQ, RES> {
2728
interface Options {}
2829

2930
CompletableFuture<Slice> run(
30-
HandlerContext handlerContext, Serde<REQ> requestSerde, Serde<RES> responseSerde);
31+
HandlerContext handlerContext,
32+
Serde<REQ> requestSerde,
33+
Serde<RES> responseSerde,
34+
AtomicReference<Runnable> onClosedInvocationStreamHook);
3135
}

sdk-core/src/main/java/dev/restate/sdk/core/RequestProcessorImpl.java

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import java.util.concurrent.CompletableFuture;
1717
import java.util.concurrent.Executor;
1818
import java.util.concurrent.Flow;
19+
import java.util.concurrent.atomic.AtomicReference;
1920
import org.apache.logging.log4j.LogManager;
2021
import org.apache.logging.log4j.Logger;
2122
import org.jspecify.annotations.Nullable;
@@ -30,6 +31,7 @@ final class RequestProcessorImpl implements RequestProcessor {
3031
private final Context otelContext;
3132
private final EndpointRequestHandler.LoggingContextSetter loggingContextSetter;
3233
private final Executor syscallsExecutor;
34+
private final AtomicReference<Runnable> onHandlerTaskCancellation;
3335

3436
@SuppressWarnings("unchecked")
3537
public RequestProcessorImpl(
@@ -45,14 +47,44 @@ public RequestProcessorImpl(
4547
this.loggingContextSetter = loggingContextSetter;
4648
this.handlerDefinition = (HandlerDefinition<Object, Object>) handlerDefinition;
4749
this.syscallsExecutor = syscallExecutor;
50+
this.onHandlerTaskCancellation = new AtomicReference<>();
4851
}
4952

5053
// Flow methods implementation
5154

5255
@Override
5356
public void subscribe(Flow.Subscriber<? super Slice> subscriber) {
5457
LOG.trace("Start processing invocation");
55-
this.stateMachine.subscribe(subscriber);
58+
this.stateMachine.subscribe(
59+
new Flow.Subscriber<>() {
60+
@Override
61+
public void onSubscribe(Flow.Subscription subscription) {
62+
subscriber.onSubscribe(subscription);
63+
}
64+
65+
@Override
66+
public void onNext(Slice slice) {
67+
subscriber.onNext(slice);
68+
}
69+
70+
@Override
71+
public void onError(Throwable throwable) {
72+
Runnable cancelTask = onHandlerTaskCancellation.get();
73+
if (cancelTask != null) {
74+
cancelTask.run();
75+
}
76+
subscriber.onError(throwable);
77+
}
78+
79+
@Override
80+
public void onComplete() {
81+
Runnable cancelTask = onHandlerTaskCancellation.get();
82+
if (cancelTask != null) {
83+
cancelTask.run();
84+
}
85+
subscriber.onComplete();
86+
}
87+
});
5688
stateMachine
5789
.waitForReady()
5890
.thenCompose(v -> this.onReady())
@@ -119,7 +151,8 @@ private CompletableFuture<Void> onReady() {
119151
.run(
120152
contextInternal,
121153
handlerDefinition.getRequestSerde(),
122-
handlerDefinition.getResponseSerde());
154+
handlerDefinition.getResponseSerde(),
155+
onHandlerTaskCancellation);
123156

124157
return userCodeFuture.handle(
125158
(slice, t) -> {

0 commit comments

Comments
 (0)