Skip to content

Commit 8be2693

Browse files
committed
feat: support transport negotiation between client and server
1 parent 89a003a commit 8be2693

File tree

7 files changed

+97
-30
lines changed

7 files changed

+97
-30
lines changed

packages/actor-core/src/client/client.ts

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import { importEventSource } from "@/common/eventsource";
1919
*/
2020
export interface ClientOptions {
2121
encoding?: Encoding;
22-
transport?: Transport;
22+
supportedTransports?: Transport[];
2323
}
2424

2525
/**
@@ -135,7 +135,7 @@ export class Client {
135135
#managerEndpointPromise: Promise<string>;
136136
//#regionPromise: Promise<Region | undefined>;
137137
#encodingKind: Encoding;
138-
#transportKind: Transport;
138+
#supportedTransports: Transport[];
139139

140140
// External imports
141141
#dynamicImportsPromise: Promise<DynamicImports>;
@@ -164,7 +164,7 @@ export class Client {
164164
//this.#regionPromise = this.#fetchRegion();
165165

166166
this.#encodingKind = opts?.encoding ?? "cbor";
167-
this.#transportKind = opts?.transport ?? "websocket";
167+
this.#supportedTransports = opts?.supportedTransports ?? ["websocket", "sse"];
168168

169169
// Import dynamic dependencies
170170
this.#dynamicImportsPromise = (async () => {
@@ -198,10 +198,14 @@ export class Client {
198198
getForId: {
199199
actorId,
200200
},
201-
},
201+
}
202202
});
203203

204-
const handle = await this.#createHandle(resJson.endpoint, opts?.parameters);
204+
const handle = await this.#createHandle(
205+
resJson.endpoint,
206+
opts?.parameters,
207+
resJson.supportedTransports
208+
);
205209
return this.#createProxy(handle) as ActorHandle<A>;
206210
}
207211

@@ -259,10 +263,14 @@ export class Client {
259263
tags,
260264
create,
261265
},
262-
},
266+
}
263267
});
264268

265-
const handle = await this.#createHandle(resJson.endpoint, opts?.parameters);
269+
const handle = await this.#createHandle(
270+
resJson.endpoint,
271+
opts?.parameters,
272+
resJson.supportedTransports
273+
);
266274
return this.#createProxy(handle) as ActorHandle<A>;
267275
}
268276

@@ -307,24 +315,30 @@ export class Client {
307315
>("POST", "/manager/actors", {
308316
query: {
309317
create,
310-
},
318+
}
311319
});
312320

313-
const handle = await this.#createHandle(resJson.endpoint, opts?.parameters);
321+
const handle = await this.#createHandle(
322+
resJson.endpoint,
323+
opts?.parameters,
324+
resJson.supportedTransports
325+
);
314326
return this.#createProxy(handle) as ActorHandle<A>;
315327
}
316328

317329
async #createHandle(
318330
endpoint: string,
319331
parameters: unknown,
332+
serverTransports: Transport[],
320333
): Promise<ActorHandleRaw> {
321334
const imports = await this.#dynamicImportsPromise;
322335

323336
const handle = new ActorHandleRaw(
324337
endpoint,
325338
parameters,
326339
this.#encodingKind,
327-
this.#transportKind,
340+
this.#supportedTransports,
341+
serverTransports,
328342
imports,
329343
);
330344
handle.__connect();

packages/actor-core/src/client/errors.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ export class MalformedResponseMessage extends ActorClientError {
2424
}
2525
}
2626

27+
export class NoSupportedTransport extends ActorClientError {
28+
constructor() {
29+
super("No supported transport available between client and server");
30+
}
31+
}
32+
2733
export class RpcError extends ActorClientError {
2834
constructor(
2935
public readonly code: string,

packages/actor-core/src/client/handle.ts

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ export class ActorHandleRaw {
9393
private readonly endpoint: string,
9494
private readonly parameters: unknown,
9595
private readonly encodingKind: Encoding,
96-
private readonly transportKind: Transport,
96+
private readonly supportedTransports: Transport[],
97+
private readonly serverTransports: Transport[],
9798
private readonly dynamicImports: DynamicImports,
9899
) {
99100
this.#keepNodeAliveInterval = setInterval(() => 60_000);
@@ -205,12 +206,13 @@ enc
205206
this.#onOpenPromise = Promise.withResolvers();
206207

207208
// Connect transport
208-
if (this.transportKind === "websocket") {
209+
const transport = this.#pickTransport();
210+
if (transport === "websocket") {
209211
this.#connectWebSocket();
210-
} else if (this.transportKind === "sse") {
212+
} else if (transport === "sse") {
211213
this.#connectSse();
212214
} else {
213-
assertUnreachable(this.transportKind);
215+
assertUnreachable(transport);
214216
}
215217

216218
// Wait for result
@@ -220,10 +222,23 @@ enc
220222
}
221223
}
222224

225+
#pickTransport(): Transport {
226+
// Choose first supported transport from server's list that client also supports
227+
const transport = this.serverTransports.find(t =>
228+
this.supportedTransports.includes(t)
229+
);
230+
231+
if (!transport) {
232+
throw new errors.NoSupportedTransport();
233+
}
234+
235+
return transport;
236+
}
237+
223238
#connectWebSocket() {
224239
const { WebSocket } = this.dynamicImports;
225240

226-
const url = this.#buildConnectionUrl();
241+
const url = this.#buildConnectionUrl("websocket");
227242

228243
const ws = new WebSocket(url);
229244
if (this.encodingKind === "cbor") {
@@ -255,7 +270,7 @@ enc
255270
#connectSse() {
256271
const { EventSource } = this.dynamicImports;
257272

258-
const url = this.#buildConnectionUrl();
273+
const url = this.#buildConnectionUrl("sse");
259274

260275
const eventSource = new EventSource(url);
261276
this.#transport = { sse: eventSource };
@@ -391,8 +406,8 @@ enc
391406
logger().warn("socket error", { event });
392407
}
393408

394-
#buildConnectionUrl(): string {
395-
let url = `${this.endpoint}/connect/${this.transportKind}?encoding=${this.encodingKind}`;
409+
#buildConnectionUrl(transport: Transport): string {
410+
let url = `${this.endpoint}/connect/${transport}?encoding=${this.encodingKind}`;
396411

397412
if (this.parameters !== undefined) {
398413
const paramsStr = JSON.stringify(this.parameters);
@@ -593,7 +608,10 @@ enc
593608
}
594609
return JSON.parse(data);
595610
} else if (this.encodingKind === "cbor") {
596-
if (this.transportKind === "sse") {
611+
if (!this.#transport) {
612+
// Do thing
613+
throw new Error("Cannot parse message when no transport defined");
614+
} else if ("sse" in this.#transport) {
597615
// Decode base64 since SSE sends raw strings
598616
if (typeof data === "string") {
599617
const binaryString = atob(data);
@@ -605,10 +623,10 @@ enc
605623
`Expected data to be a string for SSE, got ${data}.`,
606624
);
607625
}
608-
} else if (this.transportKind === "websocket") {
626+
} else if ("websocket" in this.#transport) {
609627
// Do nothing
610628
} else {
611-
assertUnreachable(this.transportKind);
629+
assertUnreachable(this.#transport);
612630
}
613631

614632
// Decode data

packages/actor-core/src/manager/protocol/mod.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { z } from "zod";
22
import { ActorQuerySchema } from "./query";
3+
import { TransportSchema } from "@/actor/protocol/message/mod";
34
export * from "./query";
45

56
export const ActorsRequestSchema = z.object({
@@ -8,6 +9,7 @@ export const ActorsRequestSchema = z.object({
89

910
export const ActorsResponseSchema = z.object({
1011
endpoint: z.string(),
12+
supportedTransports: z.array(TransportSchema),
1113
});
1214

1315
//export const RivetConfigResponseSchema = z.object({

packages/platforms/cloudflare-workers/src/manager.ts

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@ export function buildManager(env: Env): ManagerDriver {
3636
//
3737
//return res.actor;
3838

39-
return { endpoint: buildActorEndpoint(origin, query.getForId.actorId) };
39+
return {
40+
endpoint: buildActorEndpoint(origin, query.getForId.actorId),
41+
supportedTransports: ["websocket", "sse"]
42+
};
4043
}
4144
if ("getOrCreateForTags" in query) {
4245
const tags = query.getOrCreateForTags.tags;
@@ -78,7 +81,10 @@ async function getWithTags(
7881
`actor:tags:${JSON.stringify(tags)}:id`,
7982
);
8083
if (actorId) {
81-
return { endpoint: buildActorEndpoint(origin, actorId) };
84+
return {
85+
endpoint: buildActorEndpoint(origin, actorId),
86+
supportedTransports: ["websocket", "sse"]
87+
};
8288
}
8389
return undefined;
8490

@@ -131,5 +137,8 @@ async function createActor(
131137
actorId.toString(),
132138
);
133139

134-
return { endpoint: buildActorEndpoint(origin, actorId.toString()) };
140+
return {
141+
endpoint: buildActorEndpoint(origin, actorId.toString()),
142+
supportedTransports: ["websocket", "sse"]
143+
};
135144
}

packages/platforms/redis/src/manager.ts

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ export function buildManager(redis: Redis): ManagerDriver {
3737
//
3838
//return res.actor;
3939

40-
return { endpoint: buildActorEndpoint(origin, query.getForId.actorId) };
40+
return {
41+
endpoint: buildActorEndpoint(origin, query.getForId.actorId),
42+
supportedTransports: ["websocket", "sse"]
43+
};
4144
}
4245
if ("getOrCreateForTags" in query) {
4346
const tags = query.getOrCreateForTags.tags;
@@ -81,7 +84,10 @@ async function getWithTags(
8184

8285
const actorId = await redis.get(`actor_tags:${JSON.stringify(tags)}:id`);
8386
if (actorId) {
84-
return { endpoint: buildActorEndpoint(origin, actorId) };
87+
return {
88+
endpoint: buildActorEndpoint(origin, actorId),
89+
supportedTransports: ["websocket", "sse"]
90+
};
8591
}
8692
return undefined;
8793

@@ -126,5 +132,8 @@ async function createActor(
126132
[`actor_tags:${JSON.stringify(createRequest.tags)}:id`]: actorId,
127133
});
128134

129-
return { endpoint: buildActorEndpoint(origin, actorId.toString()) };
135+
return {
136+
endpoint: buildActorEndpoint(origin, actorId.toString()),
137+
supportedTransports: ["websocket", "sse"]
138+
};
130139
}

packages/platforms/rivet/src/manager.ts

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,10 @@ export function buildManager(clientConfig: RivetClientConfig): ManagerDriver {
6868
);
6969
}
7070

71-
return { endpoint: buildActorEndpoint(res.actor) };
71+
return {
72+
endpoint: buildActorEndpoint(res.actor),
73+
supportedTransports: ["websocket", "sse"]
74+
};
7275
}
7376
if ("getOrCreateForTags" in query) {
7477
const tags = query.getOrCreateForTags.tags;
@@ -137,7 +140,10 @@ async function getWithTags(
137140
actors.sort((a: RivetActor, b: RivetActor) => a.id.localeCompare(b.id));
138141
}
139142

140-
return { endpoint: buildActorEndpoint(actors[0]) };
143+
return {
144+
endpoint: buildActorEndpoint(actors[0]),
145+
supportedTransports: ["websocket", "sse"]
146+
};
141147
}
142148

143149
async function createActor(
@@ -177,7 +183,10 @@ async function createActor(
177183
req,
178184
);
179185

180-
return { endpoint: buildActorEndpoint(actor) };
186+
return {
187+
endpoint: buildActorEndpoint(actor),
188+
supportedTransports: ["websocket", "sse"]
189+
};
181190
}
182191

183192
async function getBuildWithTags(

0 commit comments

Comments
 (0)