Skip to content

Commit 7b4dd3f

Browse files
Implement request identity (#297)
* Add RequestIdentityVerifier interface, to be used to implement request identity. * Add RestateRequestIdentityVerifier implementation. * Fix case with unsigned signature scheme. * Fix description of module * Better name for the factory method * Add test and remove prefixing with ASN1, this seems not needed.
1 parent 9718851 commit 7b4dd3f

File tree

16 files changed

+315
-11
lines changed

16 files changed

+315
-11
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH
2+
//
3+
// This file is part of the Restate Java SDK,
4+
// which is released under the MIT license.
5+
//
6+
// You can find a copy of the license in file LICENSE in the root
7+
// directory of this repository or package, or at
8+
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
9+
package dev.restate.sdk.auth;
10+
11+
import org.jspecify.annotations.Nullable;
12+
13+
/** Interface to verify requests. */
14+
public interface RequestIdentityVerifier {
15+
16+
/** Abstraction for headers map. */
17+
@FunctionalInterface
18+
interface Headers {
19+
@Nullable String get(String key);
20+
}
21+
22+
/**
23+
* @throws Exception if the request cannot be verified
24+
*/
25+
void verifyRequest(Headers headers) throws Exception;
26+
}

sdk-core/src/main/java/dev/restate/sdk/core/Entries.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,8 @@ public Result<Collection<String>> parseCompletionResult(CompletionMessage actual
217217
} catch (InvalidProtocolBufferException e) {
218218
throw new ProtocolException(
219219
"Cannot parse get state keys completion",
220-
e,
221-
ProtocolException.PROTOCOL_VIOLATION_CODE);
220+
ProtocolException.PROTOCOL_VIOLATION_CODE,
221+
e);
222222
}
223223
return Result.success(
224224
stateKeys.getKeysList().stream()

sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,8 @@ void onStartMessage(MessageLite msg) {
209209
this.fail(
210210
new ProtocolException(
211211
"Expected at least one entry with Input, got " + this.entriesToReplay + " entries",
212-
null,
213-
TerminalException.INTERNAL_SERVER_ERROR_CODE));
212+
TerminalException.INTERNAL_SERVER_ERROR_CODE,
213+
null));
214214
return;
215215
}
216216

sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
public class ProtocolException extends RuntimeException {
1616

17+
static final int UNAUTHORIZED_CODE = 401;
1718
static final int NOT_FOUND_CODE = 404;
1819
static final int JOURNAL_MISMATCH_CODE = 570;
1920
static final int PROTOCOL_VIOLATION_CODE = 571;
@@ -28,10 +29,10 @@ private ProtocolException(String message) {
2829
}
2930

3031
private ProtocolException(String message, int code) {
31-
this(message, null, code);
32+
this(message, code, null);
3233
}
3334

34-
public ProtocolException(String message, Throwable cause, int code) {
35+
public ProtocolException(String message, int code, Throwable cause) {
3536
super(message, cause);
3637
this.code = code;
3738
}
@@ -77,7 +78,11 @@ static ProtocolException methodNotFound(String serviceName, String handlerName)
7778
static ProtocolException invalidSideEffectCall() {
7879
return new ProtocolException(
7980
"A syscall was invoked from within a side effect closure.",
80-
null,
81-
TerminalException.INTERNAL_SERVER_ERROR_CODE);
81+
TerminalException.INTERNAL_SERVER_ERROR_CODE,
82+
null);
83+
}
84+
85+
static ProtocolException unauthorized(Throwable e) {
86+
return new ProtocolException("Unauthorized", UNAUTHORIZED_CODE, e);
8287
}
8388
}

sdk-core/src/main/java/dev/restate/sdk/core/RestateEndpoint.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
99
package dev.restate.sdk.core;
1010

11+
import dev.restate.sdk.auth.RequestIdentityVerifier;
1112
import dev.restate.sdk.common.BindableServiceFactory;
1213
import dev.restate.sdk.common.syscalls.HandlerDefinition;
1314
import dev.restate.sdk.common.syscalls.ServiceDefinition;
@@ -32,14 +33,17 @@ public class RestateEndpoint {
3233

3334
private final Map<String, ServiceAndOptions<?>> services;
3435
private final Tracer tracer;
36+
private final RequestIdentityVerifier requestIdentityVerifier;
3537
private final DeploymentManifest deploymentManifest;
3638

3739
private RestateEndpoint(
3840
DeploymentManifestSchema.ProtocolMode protocolMode,
3941
Map<String, ServiceAndOptions<?>> services,
40-
Tracer tracer) {
42+
Tracer tracer,
43+
RequestIdentityVerifier requestIdentityVerifier) {
4144
this.services = services;
4245
this.tracer = tracer;
46+
this.requestIdentityVerifier = requestIdentityVerifier;
4347
this.deploymentManifest =
4448
new DeploymentManifest(protocolMode, services.values().stream().map(c -> c.service));
4549

@@ -49,6 +53,7 @@ private RestateEndpoint(
4953
public ResolvedEndpointHandler resolve(
5054
String componentName,
5155
String handlerName,
56+
RequestIdentityVerifier.Headers headers,
5257
io.opentelemetry.context.Context otelContext,
5358
LoggingContextSetter loggingContextSetter,
5459
@Nullable Executor syscallExecutor)
@@ -65,6 +70,15 @@ public ResolvedEndpointHandler resolve(
6570
throw ProtocolException.methodNotFound(componentName, handlerName);
6671
}
6772

73+
// Verify request
74+
if (requestIdentityVerifier != null) {
75+
try {
76+
requestIdentityVerifier.verifyRequest(headers);
77+
} catch (Exception e) {
78+
throw ProtocolException.unauthorized(e);
79+
}
80+
}
81+
6882
// Generate the span
6983
Span span =
7084
tracer
@@ -108,6 +122,7 @@ public static class Builder {
108122

109123
private final List<ServiceAndOptions<?>> services = new ArrayList<>();
110124
private final DeploymentManifestSchema.ProtocolMode protocolMode;
125+
private RequestIdentityVerifier requestIdentityVerifier;
111126
private Tracer tracer = OpenTelemetry.noop().getTracer("NOOP");
112127

113128
public Builder(DeploymentManifestSchema.ProtocolMode protocolMode) {
@@ -124,12 +139,18 @@ public Builder withTracer(Tracer tracer) {
124139
return this;
125140
}
126141

142+
public Builder withRequestIdentityVerifier(RequestIdentityVerifier requestIdentityVerifier) {
143+
this.requestIdentityVerifier = requestIdentityVerifier;
144+
return this;
145+
}
146+
127147
public RestateEndpoint build() {
128148
return new RestateEndpoint(
129149
this.protocolMode,
130150
this.services.stream()
131151
.collect(Collectors.toMap(c -> c.service.getServiceName(), Function.identity())),
132-
tracer);
152+
tracer,
153+
requestIdentityVerifier);
133154
}
134155
}
135156

sdk-core/src/test/java/dev/restate/sdk/core/MockMultiThreaded.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ public void executeTest(TestDefinitions.TestDefinition definition) {
5656
server.resolve(
5757
serviceDefinition.get(0).getServiceName(),
5858
definition.getMethod(),
59+
k -> null,
5960
io.opentelemetry.context.Context.current(),
6061
RestateEndpoint.LoggingContextSetter.THREAD_LOCAL_INSTANCE,
6162
syscallsExecutor);

sdk-core/src/test/java/dev/restate/sdk/core/MockSingleThread.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ public void executeTest(TestDefinition definition) {
5454
server.resolve(
5555
serviceDefinition.get(0).getServiceName(),
5656
definition.getMethod(),
57+
k -> null,
5758
io.opentelemetry.context.Context.current(),
5859
RestateEndpoint.LoggingContextSetter.THREAD_LOCAL_INSTANCE,
5960
null);

sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RequestHttpServerHandler.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ public void handle(HttpServerRequest request) {
116116
restateEndpoint.resolve(
117117
serviceName,
118118
handlerName,
119+
request::getHeader,
119120
otelContext,
120121
ContextualData::put,
121122
currentContextExecutor(vertxCurrentContext));

sdk-http-vertx/src/main/java/dev/restate/sdk/http/vertx/RestateHttpEndpointBuilder.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
99
package dev.restate.sdk.http.vertx;
1010

11+
import dev.restate.sdk.auth.RequestIdentityVerifier;
1112
import dev.restate.sdk.common.BindableService;
1213
import dev.restate.sdk.common.syscalls.ServiceDefinition;
1314
import dev.restate.sdk.core.RestateEndpoint;
@@ -103,7 +104,7 @@ public <O> RestateHttpEndpointBuilder bind(BindableService<O> service, O options
103104
}
104105

105106
/**
106-
* Add a {@link OpenTelemetry} implementation for tracing and metrics.
107+
* Set the {@link OpenTelemetry} implementation for tracing and metrics.
107108
*
108109
* @see OpenTelemetry
109110
*/
@@ -112,6 +113,18 @@ public RestateHttpEndpointBuilder withOpenTelemetry(OpenTelemetry openTelemetry)
112113
return this;
113114
}
114115

116+
/**
117+
* Set the request identity verifier for this endpoint.
118+
*
119+
* <p>For the Restate implementation to use with Restate Cloud, check the module {@code
120+
* sdk-request-identity}.
121+
*/
122+
public RestateHttpEndpointBuilder withRequestIdentityVerifier(
123+
RequestIdentityVerifier requestIdentityVerifier) {
124+
this.endpointBuilder.withRequestIdentityVerifier(requestIdentityVerifier);
125+
return this;
126+
}
127+
115128
/** Build and listen on the specified port. */
116129
public void buildAndListen(int port) {
117130
build().listen(port).onComplete(RestateHttpEndpointBuilder::handleStart);

sdk-lambda/src/main/java/dev/restate/sdk/lambda/RestateLambdaEndpoint.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ private APIGatewayProxyResponseEvent handleInvoke(APIGatewayProxyRequestEvent in
122122
this.restateEndpoint.resolve(
123123
serviceName,
124124
handlerName,
125+
input.getHeaders()::get,
125126
otelContext,
126127
RestateEndpoint.LoggingContextSetter.THREAD_LOCAL_INSTANCE,
127128
null);

sdk-lambda/src/main/java/dev/restate/sdk/lambda/RestateLambdaEndpointBuilder.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
99
package dev.restate.sdk.lambda;
1010

11+
import dev.restate.sdk.auth.RequestIdentityVerifier;
1112
import dev.restate.sdk.common.BindableService;
1213
import dev.restate.sdk.common.syscalls.ServiceDefinition;
1314
import dev.restate.sdk.core.RestateEndpoint;
@@ -49,6 +50,18 @@ public RestateLambdaEndpointBuilder withOpenTelemetry(OpenTelemetry openTelemetr
4950
return this;
5051
}
5152

53+
/**
54+
* Set the request identity verifier for this endpoint.
55+
*
56+
* <p>For the Restate implementation to use with Restate Cloud, check the module {@code
57+
* sdk-request-identity}.
58+
*/
59+
public RestateLambdaEndpointBuilder withRequestIdentityVerifier(
60+
RequestIdentityVerifier requestIdentityVerifier) {
61+
this.restateEndpoint.withRequestIdentityVerifier(requestIdentityVerifier);
62+
return this;
63+
}
64+
5265
/** Build the {@link RestateLambdaEndpoint} serving the Restate service endpoint. */
5366
public RestateLambdaEndpoint build() {
5467
return new RestateLambdaEndpoint(this.restateEndpoint.build(), this.openTelemetry);

sdk-request-identity/build.gradle.kts

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
plugins {
2+
`java-library`
3+
`library-publishing-conventions`
4+
}
5+
6+
description = "Restate SDK request identity implementation"
7+
8+
dependencies {
9+
compileOnly(coreLibs.jspecify)
10+
11+
implementation(project(":sdk-common"))
12+
13+
// Dependencies for signing request tokens
14+
implementation(coreLibs.jwt)
15+
implementation(coreLibs.tink)
16+
17+
testImplementation(testingLibs.junit.jupiter)
18+
testImplementation(testingLibs.assertj)
19+
}
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH
2+
//
3+
// This file is part of the Restate Java SDK,
4+
// which is released under the MIT license.
5+
//
6+
// You can find a copy of the license in file LICENSE in the root
7+
// directory of this repository or package, or at
8+
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
9+
package dev.restate.sdk.auth.signing;
10+
11+
import java.util.Arrays;
12+
13+
// Copied and adapted from
14+
// https://github.com/bitcoinj/bitcoinj/blob/7df957e4c6817036c096283c5f0dcb7e4d60c982/core/src/main/java/org/bitcoinj/base/Base58.java#L50
15+
// License Apache 2.0
16+
// Copyright 2011 Google Inc.
17+
// Copyright 2018 Andreas Schildbach
18+
19+
class Base58 {
20+
public static final char[] ALPHABET =
21+
"123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz".toCharArray();
22+
private static final int[] INDEXES = new int[128];
23+
24+
static {
25+
Arrays.fill(INDEXES, -1);
26+
for (int i = 0; i < ALPHABET.length; i++) {
27+
INDEXES[ALPHABET[i]] = i;
28+
}
29+
}
30+
31+
/**
32+
* Decodes the given base58 string into the original data bytes.
33+
*
34+
* @param input the base58-encoded string to decode
35+
* @return the decoded data bytes
36+
*/
37+
public static byte[] decode(String input) {
38+
if (input.isEmpty()) {
39+
return new byte[0];
40+
}
41+
// Convert the base58-encoded ASCII chars to a base58 byte sequence (base58 digits).
42+
byte[] input58 = new byte[input.length()];
43+
for (int i = 0; i < input.length(); ++i) {
44+
char c = input.charAt(i);
45+
int digit = c < 128 ? INDEXES[c] : -1;
46+
if (digit < 0) {
47+
throw new IllegalArgumentException(
48+
String.format("Invalid character in Base58: 0x%04x", (int) c));
49+
}
50+
input58[i] = (byte) digit;
51+
}
52+
// Count leading zeros.
53+
int zeros = 0;
54+
while (zeros < input58.length && input58[zeros] == 0) {
55+
++zeros;
56+
}
57+
// Convert base-58 digits to base-256 digits.
58+
byte[] decoded = new byte[input.length()];
59+
int outputStart = decoded.length;
60+
for (int inputStart = zeros; inputStart < input58.length; ) {
61+
decoded[--outputStart] = divmod(input58, inputStart, 58, 256);
62+
if (input58[inputStart] == 0) {
63+
++inputStart; // optimization - skip leading zeros
64+
}
65+
}
66+
// Ignore extra leading zeroes that were added during the calculation.
67+
while (outputStart < decoded.length && decoded[outputStart] == 0) {
68+
++outputStart;
69+
}
70+
// Return decoded data (including original number of leading zeros).
71+
return Arrays.copyOfRange(decoded, outputStart - zeros, decoded.length);
72+
}
73+
74+
/**
75+
* Divides a number, represented as an array of bytes each containing a single digit in the
76+
* specified base, by the given divisor. The given number is modified in-place to contain the
77+
* quotient, and the return value is the remainder.
78+
*
79+
* @param number the number to divide
80+
* @param firstDigit the index within the array of the first non-zero digit (this is used for
81+
* optimization by skipping the leading zeros)
82+
* @param base the base in which the number's digits are represented (up to 256)
83+
* @param divisor the number to divide by (up to 256)
84+
* @return the remainder of the division operation
85+
*/
86+
private static byte divmod(byte[] number, int firstDigit, int base, int divisor) {
87+
// this is just long division which accounts for the base of the input digits
88+
int remainder = 0;
89+
for (int i = firstDigit; i < number.length; i++) {
90+
int digit = (int) number[i] & 0xFF;
91+
int temp = remainder * base + digit;
92+
number[i] = (byte) (temp / divisor);
93+
remainder = temp % divisor;
94+
}
95+
return (byte) remainder;
96+
}
97+
}

0 commit comments

Comments
 (0)