diff --git a/packages/actor-core/fixtures/driver-test-suite/conn-params.ts b/packages/actor-core/fixtures/driver-test-suite/conn-params.ts new file mode 100644 index 000000000..06de727ff --- /dev/null +++ b/packages/actor-core/fixtures/driver-test-suite/conn-params.ts @@ -0,0 +1,33 @@ +import { actor, setup } from "actor-core"; + +const counterWithParams = actor({ + state: { count: 0, initializers: [] as string[] }, + createConnState: (c, { params }: { params: { name?: string } }) => { + return { + name: params?.name || "anonymous", + }; + }, + onConnect: (c, conn) => { + // Record connection name + c.state.initializers.push(conn.state.name); + }, + actions: { + increment: (c, x: number) => { + c.state.count += x; + c.broadcast("newCount", { + count: c.state.count, + by: c.conn.state.name, + }); + return c.state.count; + }, + getInitializers: (c) => { + return c.state.initializers; + }, + }, +}); + +export const app = setup({ + actors: { counter: counterWithParams }, +}); + +export type App = typeof app; diff --git a/packages/actor-core/fixtures/driver-test-suite/lifecycle.ts b/packages/actor-core/fixtures/driver-test-suite/lifecycle.ts new file mode 100644 index 000000000..3d0f22c42 --- /dev/null +++ b/packages/actor-core/fixtures/driver-test-suite/lifecycle.ts @@ -0,0 +1,38 @@ +import { actor, setup } from "actor-core"; + +const lifecycleActor = actor({ + state: { + count: 0, + events: [] as string[], + }, + createConnState: () => ({ joinTime: Date.now() }), + onStart: (c) => { + c.state.events.push("onStart"); + }, + onBeforeConnect: (c, { params }: { params: any }) => { + c.state.events.push("onBeforeConnect"); + // Could throw here to reject connection + }, + onConnect: (c) => { + c.state.events.push("onConnect"); + }, + onDisconnect: (c) => { + c.state.events.push("onDisconnect"); + }, + actions: { + getEvents: (c) => { + return c.state.events; + }, + increment: (c, x: number) => { + c.state.count += x; + return c.state.count; + }, + }, +}); + +export const app = setup({ + actors: { counter: lifecycleActor }, +}); + +export type App = typeof app; + diff --git a/packages/actor-core/src/actor/errors.ts b/packages/actor-core/src/actor/errors.ts index 943e96bd9..214e29c0e 100644 --- a/packages/actor-core/src/actor/errors.ts +++ b/packages/actor-core/src/actor/errors.ts @@ -212,18 +212,6 @@ export class UserError extends ActorError { } } -// Proxy-related errors - -export class MissingRequiredParameters extends ActorError { - constructor(missingParams: string[]) { - super( - "missing_required_parameters", - `Missing required parameters: ${missingParams.join(", ")}`, - { public: true } - ); - } -} - export class InvalidQueryJSON extends ActorError { constructor(error?: unknown) { super( @@ -234,11 +222,11 @@ export class InvalidQueryJSON extends ActorError { } } -export class InvalidQueryFormat extends ActorError { +export class InvalidRequest extends ActorError { constructor(error?: unknown) { super( - "invalid_query_format", - `Invalid query format: ${error}`, + "invalid_request", + `Invalid request: ${error}`, { public: true, cause: error } ); } @@ -280,12 +268,6 @@ export class InvalidRpcRequest extends ActorError { } } -export class InvalidRequest extends ActorError { - constructor(message: string) { - super("invalid_request", message, { public: true }); - } -} - export class InvalidParams extends ActorError { constructor(message: string) { super("invalid_params", message, { public: true }); diff --git a/packages/actor-core/src/actor/instance.ts b/packages/actor-core/src/actor/instance.ts index bdfece0d9..1e23a7c81 100644 --- a/packages/actor-core/src/actor/instance.ts +++ b/packages/actor-core/src/actor/instance.ts @@ -720,7 +720,8 @@ export class ActorInstance { new CachedSerializer({ b: { i: { - ci: `${conn.id}`, + ai: this.id, + ci: conn.id, ct: conn._token, }, }, diff --git a/packages/actor-core/src/actor/protocol/message/mod.ts b/packages/actor-core/src/actor/protocol/message/mod.ts index bcaba3c49..07ef1303f 100644 --- a/packages/actor-core/src/actor/protocol/message/mod.ts +++ b/packages/actor-core/src/actor/protocol/message/mod.ts @@ -15,6 +15,7 @@ import { } from "@/actor/protocol/serde"; import { deconstructError } from "@/common/utils"; import { Actions } from "@/actor/config"; +import invariant from "invariant"; export const TransportSchema = z.enum(["websocket", "sse"]); @@ -91,7 +92,9 @@ export async function processMessage( let rpcName: string | undefined; try { - if ("rr" in message.b) { + if ("i" in message.b) { + invariant(false, "should not be notified of init event"); + } else if ("rr" in message.b) { // RPC request if (handler.onExecuteRpc === undefined) { diff --git a/packages/actor-core/src/actor/protocol/message/to-client.ts b/packages/actor-core/src/actor/protocol/message/to-client.ts index 5547d100b..cf7da6151 100644 --- a/packages/actor-core/src/actor/protocol/message/to-client.ts +++ b/packages/actor-core/src/actor/protocol/message/to-client.ts @@ -2,6 +2,8 @@ import { z } from "zod"; // Only called for SSE because we don't need this for WebSockets export const InitSchema = z.object({ + // Actor ID + ai: z.string(), // Connection ID ci: z.string(), // Connection token diff --git a/packages/actor-core/src/actor/protocol/message/to-server.ts b/packages/actor-core/src/actor/protocol/message/to-server.ts index 6ee5b093b..56372b7ac 100644 --- a/packages/actor-core/src/actor/protocol/message/to-server.ts +++ b/packages/actor-core/src/actor/protocol/message/to-server.ts @@ -1,5 +1,10 @@ import { z } from "zod"; +const InitSchema = z.object({ + // Conn Params + p: z.unknown({}).optional(), +}); + const RpcRequestSchema = z.object({ // ID i: z.number().int(), @@ -19,6 +24,7 @@ const SubscriptionRequestSchema = z.object({ export const ToServerSchema = z.object({ // Body b: z.union([ + z.object({ i: InitSchema }), z.object({ rr: RpcRequestSchema }), z.object({ sr: SubscriptionRequestSchema }), ]), diff --git a/packages/actor-core/src/actor/router-endpoints.ts b/packages/actor-core/src/actor/router-endpoints.ts index f54320f21..dfd53eeae 100644 --- a/packages/actor-core/src/actor/router-endpoints.ts +++ b/packages/actor-core/src/actor/router-endpoints.ts @@ -18,7 +18,6 @@ import { assertUnreachable } from "./utils"; import { deconstructError, stringifyError } from "@/common/utils"; import type { AppConfig } from "@/app/config"; import type { DriverConfig } from "@/driver-helpers/config"; -import { ToClient } from "./protocol/message/to-client"; import invariant from "invariant"; export interface ConnectWebSocketOpts { @@ -89,52 +88,88 @@ export function handleWebSocketConnect( actorId: string, ) { return async () => { - const encoding = getRequestEncoding(context.req); + const encoding = getRequestEncoding(context.req, true); - const parameters = getRequestConnParams( - context.req, - appConfig, - driverConfig, - ); + let sharedWs: WSContext | undefined = undefined; - // Continue with normal connection setup - const wsHandler = await handler({ - req: context.req, - encoding, - params: parameters, - actorId, - }); + // Setup promise for the init message since all other behavior depends on this + const { + promise: onInitPromise, + resolve: onInitResolve, + reject: onInitReject, + } = Promise.withResolvers(); + + let didTimeOut = false; + let didInit = false; + + // Add timeout waiting for init + const initTimeout = setTimeout(() => { + logger().warn("timed out waiting for init"); - const { promise: onOpenPromise, resolve: onOpenResolve } = - Promise.withResolvers(); + sharedWs?.close(1001, "timed out waiting for init message"); + didTimeOut = true; + onInitReject("init timed out"); + }, appConfig.webSocketInitTimeout); return { onOpen: async (_evt: any, ws: WSContext) => { - try { - // TODO: maybe timeout this! - await wsHandler.onOpen(ws); - onOpenResolve(undefined); - } catch (error) { - deconstructError(error, logger(), { wsEvent: "open" }); - onOpenResolve(undefined); - ws.close(1011, "internal error"); - } + sharedWs = ws; + + logger().debug("websocket open"); + + // Close WS immediately if init timed out. This indicates a long delay at the protocol level in sending the init message. + if (didTimeOut) ws.close(1001, "timed out waiting for init message"); }, onMessage: async (evt: { data: any }, ws: WSContext) => { try { - invariant(encoding, "encoding should be defined"); - - await onOpenPromise; - - logger().debug("received message"); - const value = evt.data.valueOf() as InputData; const message = await parseMessage(value, { encoding: encoding, maxIncomingMessageSize: appConfig.maxIncomingMessageSize, }); - await wsHandler.onMessage(message); + if ("i" in message.b) { + // Handle init message + // + // Parameters must go over the init message instead of a query parameter so it receives full E2EE + + logger().debug("received init ws message"); + + invariant( + !didInit, + "should not have already received init message", + ); + didInit = true; + clearTimeout(initTimeout); + + try { + // Create connection handler + const wsHandler = await handler({ + req: context.req, + encoding, + params: message.b.i.p, + actorId, + }); + + // Notify socket open + // TODO: Add timeout to this + await wsHandler.onOpen(ws); + + // Allow all other events to proceed + onInitResolve(wsHandler); + } catch (error) { + deconstructError(error, logger(), { wsEvent: "open" }); + onInitReject(error); + ws.close(1011, "internal error"); + } + } else { + // Handle all other messages + + logger().debug("received regular ws message"); + + const wsHandler = await onInitPromise; + await wsHandler.onMessage(message); + } } catch (error) { const { code } = deconstructError(error, logger(), { wsEvent: "message", @@ -150,36 +185,33 @@ export function handleWebSocketConnect( }, ws: WSContext, ) => { - try { - await onOpenPromise; - - // HACK: Close socket in order to fix bug with Cloudflare Durable Objects leaving WS in closing state - // https://github.com/cloudflare/workerd/issues/2569 - ws.close(1000, "hack_force_close"); - - if (event.wasClean) { - logger().info("websocket closed", { - code: event.code, - reason: event.reason, - wasClean: event.wasClean, - }); - } else { - logger().warn("websocket closed", { - code: event.code, - reason: event.reason, - wasClean: event.wasClean, - }); - } + if (event.wasClean) { + logger().info("websocket closed", { + code: event.code, + reason: event.reason, + wasClean: event.wasClean, + }); + } else { + logger().warn("websocket closed", { + code: event.code, + reason: event.reason, + wasClean: event.wasClean, + }); + } + // HACK: Close socket in order to fix bug with Cloudflare Durable Objects leaving WS in closing state + // https://github.com/cloudflare/workerd/issues/2569 + ws.close(1000, "hack_force_close"); + + try { + const wsHandler = await onInitPromise; await wsHandler.onClose(); } catch (error) { deconstructError(error, logger(), { wsEvent: "close" }); } }, - onError: async (error: unknown) => { + onError: async (_error: unknown) => { try { - await onOpenPromise; - // Actors don't need to know about this, since it's abstracted away logger().warn("websocket error"); } catch (error) { @@ -200,7 +232,7 @@ export async function handleSseConnect( handler: (opts: ConnectSseOpts) => Promise, actorId: string, ) { - const encoding = getRequestEncoding(c.req); + const encoding = getRequestEncoding(c.req, false); const parameters = getRequestConnParams(c.req, appConfig, driverConfig); const sseHandler = await handler({ @@ -246,7 +278,7 @@ export async function handleRpc( actorId: string, ) { try { - const encoding = getRequestEncoding(c.req); + const encoding = getRequestEncoding(c.req, false); const parameters = getRequestConnParams(c.req, appConfig, driverConfig); logger().debug("handling rpc", { rpcName, encoding }); @@ -343,7 +375,7 @@ export async function handleConnectionMessage( actorId: string, ) { try { - const encoding = getRequestEncoding(c.req); + const encoding = getRequestEncoding(c.req, false); // Validate incoming request let message: messageToServer.ToServer; @@ -398,8 +430,13 @@ export async function handleConnectionMessage( } // Helper to get the connection encoding from a request -export function getRequestEncoding(req: HonoRequest): Encoding { - const encodingParam = req.query("encoding"); +export function getRequestEncoding( + req: HonoRequest, + useQuery: boolean, +): Encoding { + const encodingParam = useQuery + ? req.query("encoding") + : req.header(HEADER_ENCODING); if (!encodingParam) { return "json"; } @@ -412,13 +449,55 @@ export function getRequestEncoding(req: HonoRequest): Encoding { return result.data; } +export function getRequestQuery(c: HonoContext, useQuery: boolean): unknown { + // Get query parameters for actor lookup + const queryParam = useQuery + ? c.req.query("query") + : c.req.header(HEADER_ACTOR_QUERY); + if (!queryParam) { + logger().error("missing query parameter"); + throw new errors.InvalidRequest("missing query"); + } + + // Parse the query JSON and validate with schema + try { + const parsed = JSON.parse(queryParam); + return parsed; + } catch (error) { + logger().error("invalid query json", { error }); + throw new errors.InvalidQueryJSON(error); + } +} + +export const HEADER_ACTOR_QUERY = "X-AC-Query"; + +export const HEADER_ENCODING = "X-AC-Encoding"; + +// IMPORTANT: Params must be in headers or in an E2EE part of the request (i.e. NOT the URL or query string) in order to ensure that tokens can be securely passed in params. +export const HEADER_CONN_PARAMS = "X-AC-Conn-Params"; + +export const HEADER_ACTOR_ID = "X-AC-Actor"; + +export const HEADER_CONN_ID = "X-AC-Conn"; + +export const HEADER_CONN_TOKEN = "X-AC-Conn-Token"; + +export const ALL_HEADERS = [ + HEADER_ACTOR_QUERY, + HEADER_ENCODING, + HEADER_CONN_PARAMS, + HEADER_ACTOR_ID, + HEADER_CONN_ID, + HEADER_CONN_TOKEN, +]; + // Helper to get connection parameters for the request export function getRequestConnParams( req: HonoRequest, appConfig: AppConfig, driverConfig: DriverConfig, ): unknown { - const paramsParam = req.query("params"); + const paramsParam = req.header(HEADER_CONN_PARAMS); if (!paramsParam) { return null; } diff --git a/packages/actor-core/src/actor/router.ts b/packages/actor-core/src/actor/router.ts index 9ab9521be..4df8f4a33 100644 --- a/packages/actor-core/src/actor/router.ts +++ b/packages/actor-core/src/actor/router.ts @@ -27,6 +27,9 @@ import { handleSseConnect, handleRpc, handleConnectionMessage, + HEADER_CONN_TOKEN, + HEADER_CONN_ID, + ALL_HEADERS, } from "./router-endpoints"; export type { @@ -68,6 +71,8 @@ export function createActorRouter( // //This is only relevant if the actor is exposed directly publicly if (appConfig.cors) { + const corsConfig = appConfig.cors; + app.use("*", async (c, next) => { const path = c.req.path; @@ -76,7 +81,10 @@ export function createActorRouter( return next(); } - return cors(appConfig.cors)(c, next); + return cors({ + ...corsConfig, + allowHeaders: [...(appConfig.cors?.allowHeaders ?? []), ...ALL_HEADERS], + })(c, next); }); } @@ -146,12 +154,12 @@ export function createActorRouter( ); }); - app.post("/connections/:conn/message", async (c) => { + app.post("/connections/message", async (c) => { if (!handlers.onConnMessage) { throw new Error("onConnMessage handler is required"); } - const connId = c.req.param("conn"); - const connToken = c.req.query("connectionToken"); + const connId = c.req.header(HEADER_CONN_ID); + const connToken = c.req.header(HEADER_CONN_TOKEN); const actorId = await handler.getActorId(); if (!connId || !connToken) { throw new Error("Missing required parameters"); diff --git a/packages/actor-core/src/app/config.ts b/packages/actor-core/src/app/config.ts index 701f0d958..88d392bb4 100644 --- a/packages/actor-core/src/app/config.ts +++ b/packages/actor-core/src/app/config.ts @@ -63,6 +63,9 @@ export const AppConfigSchema = z.object({ maxIncomingMessageSize: z.number().optional().default(65_536), + /** How long to wait for the WebSocket to send an init message before closing it. */ + webSocketInitTimeout: z.number().optional().default(5_000), + /** Peer configuration for coordinated topology. */ actorPeer: ActorPeerConfigSchema.optional().default({}), diff --git a/packages/actor-core/src/client/actor-common.ts b/packages/actor-core/src/client/actor-common.ts index 613cab85a..b397c4fbb 100644 --- a/packages/actor-core/src/client/actor-common.ts +++ b/packages/actor-core/src/client/actor-common.ts @@ -5,6 +5,7 @@ import type { ActorQuery } from "@/manager/protocol/query"; import { logger } from "./log"; import * as errors from "./errors"; import { sendHttpRequest } from "./utils"; +import { HEADER_ACTOR_QUERY, HEADER_ENCODING } from "@/actor/router-endpoints"; /** * RPC function returned by Actor connections and handles. @@ -49,18 +50,17 @@ export async function resolveActorId( ): Promise { logger().debug("resolving actor ID", { query: actorQuery }); - // Construct the URL using the current actor query - const queryParam = encodeURIComponent(JSON.stringify(actorQuery)); - const url = `${endpoint}/actors/resolve?encoding=${encodingKind}&query=${queryParam}`; - - // Use the shared HTTP request utility with integrated serialization try { const result = await sendHttpRequest< Record, protoHttpResolve.ResolveResponse >({ - url, + url: `${endpoint}/actors/resolve`, method: "POST", + headers: { + [HEADER_ENCODING]: encodingKind, + [HEADER_ACTOR_QUERY]: JSON.stringify(actorQuery), + }, body: {}, encoding: encodingKind, }); diff --git a/packages/actor-core/src/client/actor-conn.ts b/packages/actor-core/src/client/actor-conn.ts index f089dede8..0bb3be1ab 100644 --- a/packages/actor-core/src/client/actor-conn.ts +++ b/packages/actor-core/src/client/actor-conn.ts @@ -16,6 +16,15 @@ import { ACTOR_CONNS_SYMBOL, type ClientRaw, TRANSPORT_SYMBOL } from "./client"; import * as errors from "./errors"; import { logger } from "./log"; import { type WebSocketMessage as ConnMessage, messageLength } from "./utils"; +import { + HEADER_ACTOR_ID, + HEADER_ACTOR_QUERY, + HEADER_CONN_ID, + HEADER_CONN_TOKEN, + HEADER_ENCODING, + HEADER_CONN_PARAMS, +} from "@/actor/router-endpoints"; +import type { EventSource } from "eventsource"; // Re-export the type with the original name to maintain compatibility type ActorDefinitionRpcs = @@ -74,6 +83,7 @@ export class ActorConnRaw { #connecting = false; // These will only be set on SSE driver + #actorId?: string; #connectionId?: string; #connectionToken?: string; @@ -258,7 +268,11 @@ enc #connectWebSocket() { const { WebSocket } = this.#dynamicImports; - const url = this.#buildConnUrl("websocket"); + const actorQueryStr = encodeURIComponent(JSON.stringify(this.actorQuery)); + const endpoint = this.endpoint + .replace(/^http:/, "ws:") + .replace(/^https:/, "wss:"); + const url = `${endpoint}/actors/connect/websocket?encoding=${this.encodingKind}&query=${actorQueryStr}`; logger().debug("connecting to websocket", { url }); const ws = new WebSocket(url); @@ -275,7 +289,16 @@ enc this.#transport = { websocket: ws }; ws.onopen = () => { logger().debug("websocket open"); - // #handleOnOpen is called on "i" event + + // Set init message + this.#sendMessage( + { + b: { i: { p: this.params } }, + }, + { ephemeral: true }, + ); + + // #handleOnOpen is called on "i" event from the server }; ws.onmessage = async (ev) => { this.#handleOnMessage(ev); @@ -291,10 +314,25 @@ enc #connectSse() { const { EventSource } = this.#dynamicImports; - const url = this.#buildConnUrl("sse"); + const url = `${this.endpoint}/actors/connect/sse`; logger().debug("connecting to sse", { url }); - const eventSource = new EventSource(url); + const eventSource = new EventSource(url, { + fetch: (input, init) => { + return fetch(input, { + ...init, + headers: { + ...init?.headers, + "User-Agent": httpUserAgent(), + [HEADER_ENCODING]: this.encodingKind, + [HEADER_ACTOR_QUERY]: JSON.stringify(this.actorQuery), + ...(this.params !== undefined + ? { [HEADER_CONN_PARAMS]: JSON.stringify(this.params) } + : {}), + }, + }); + }, + }); this.#transport = { sse: eventSource }; eventSource.onopen = () => { logger().debug("eventsource open"); @@ -357,9 +395,11 @@ enc if ("i" in response.b) { // This is only called for SSE + this.#actorId = response.b.i.ai; this.#connectionId = response.b.i.ci; this.#connectionToken = response.b.i.ct; logger().trace("received init message", { + actorId: this.#actorId, connectionId: this.#connectionId, }); this.#handleOnOpen(); @@ -477,34 +517,6 @@ enc logger().warn("socket error", { event }); } - #buildConnUrl(transport: Transport): string { - // Get the manager endpoint from the endpoint provided - const actorQueryStr = encodeURIComponent(JSON.stringify(this.actorQuery)); - - logger().debug("building conn url", { - transport, - }); - - let url = `${this.endpoint}/actors/connect/${transport}?encoding=${this.encodingKind}&query=${actorQueryStr}`; - - if (this.params !== undefined) { - const paramsStr = JSON.stringify(this.params); - - // TODO: This is an imprecise count since it doesn't count the full URL length & URI encoding expansion in the URL size - if (paramsStr.length > MAX_CONN_PARAMS_SIZE) { - throw new errors.ConnParamsTooLong(); - } - - url += `¶ms=${encodeURIComponent(paramsStr)}`; - } - - if (transport === "websocket") { - url = url.replace(/^http:/, "ws:").replace(/^https:/, "wss:"); - } - - return url; - } - #takeRpcInFlight(id: number): RpcInFlight { const inFlight = this.#rpcInFlight.get(id); if (!inFlight) { @@ -674,21 +686,20 @@ enc async #sendHttpMessage(message: wsToServer.ToServer, opts?: SendOpts) { try { - if (!this.#connectionId || !this.#connectionToken) + if (!this.#actorId || !this.#connectionId || !this.#connectionToken) throw new errors.InternalError("Missing connection ID or token."); - // Get the manager endpoint from the endpoint provided - const actorQueryStr = encodeURIComponent(JSON.stringify(this.actorQuery)); - - const url = `${this.endpoint}/actors/connections/${this.#connectionId}/message?encoding=${this.encodingKind}&connectionToken=${encodeURIComponent(this.#connectionToken)}&query=${actorQueryStr}`; - // TODO: Implement ordered messages, this is not guaranteed order. Needs to use an index in order to ensure we can pipeline requests efficiently. // TODO: Validate that we're using HTTP/3 whenever possible for pipelining requests const messageSerialized = this.#serialize(message); - const res = await fetch(url, { + const res = await fetch(`${this.endpoint}/actors/message`, { method: "POST", headers: { "User-Agent": httpUserAgent(), + [HEADER_ENCODING]: this.encodingKind, + [HEADER_ACTOR_ID]: this.#actorId, + [HEADER_CONN_ID]: this.#connectionId, + [HEADER_CONN_TOKEN]: this.#connectionToken, }, body: messageSerialized, }); diff --git a/packages/actor-core/src/client/actor-handle.ts b/packages/actor-core/src/client/actor-handle.ts index 96f9c75f0..70b7aec4f 100644 --- a/packages/actor-core/src/client/actor-handle.ts +++ b/packages/actor-core/src/client/actor-handle.ts @@ -9,6 +9,11 @@ import { logger } from "./log"; import { sendHttpRequest } from "./utils"; import invariant from "invariant"; import { assertUnreachable } from "@/actor/utils"; +import { + HEADER_ACTOR_QUERY, + HEADER_CONN_PARAMS, + HEADER_ENCODING, +} from "@/actor/router-endpoints"; /** * Provides underlying functions for stateless {@link ActorHandle} for RPC calls. @@ -73,16 +78,16 @@ export class ActorHandleRaw { query: this.#actorQuery, }); - // Build query parameters - let baseUrl = `${this.#endpoint}/actors/rpc/${encodeURIComponent(name)}?encoding=${this.#encodingKind}&query=${encodeURIComponent(JSON.stringify(this.#actorQuery))}`; - if (this.params !== undefined) { - baseUrl += `¶ms=${encodeURIComponent(JSON.stringify(this.params))}`; - } - - // Use the shared HTTP request utility with integrated serialization const responseData = await sendHttpRequest({ - url: baseUrl, + url: `${this.#endpoint}/actors/rpc/${encodeURIComponent(name)}`, method: "POST", + headers: { + [HEADER_ENCODING]: this.#encodingKind, + [HEADER_ACTOR_QUERY]: JSON.stringify(this.#actorQuery), + ...(this.params !== undefined + ? { [HEADER_CONN_PARAMS]: JSON.stringify(this.params) } + : {}), + }, body: { a: args } satisfies RpcRequest, encoding: this.#encodingKind, }); diff --git a/packages/actor-core/src/client/utils.ts b/packages/actor-core/src/client/utils.ts index ca3b660a3..70a40e34b 100644 --- a/packages/actor-core/src/client/utils.ts +++ b/packages/actor-core/src/client/utils.ts @@ -28,6 +28,7 @@ export function messageLength(message: WebSocketMessage): number { export interface HttpRequestOpts { method: string; url: string; + headers: Record; body?: Body; encoding: Encoding; skipParseResponse?: boolean; @@ -64,12 +65,13 @@ export async function sendHttpRequest< response = await fetch(opts.url, { method: opts.method, headers: { - "User-Agent": httpUserAgent(), + ...opts.headers, ...(contentType ? { "Content-Type": contentType, } : {}), + "User-Agent": httpUserAgent(), }, body: bodyData, }); diff --git a/packages/actor-core/src/common/eventsource.ts b/packages/actor-core/src/common/eventsource.ts index 76c365dc4..8329c46f3 100644 --- a/packages/actor-core/src/common/eventsource.ts +++ b/packages/actor-core/src/common/eventsource.ts @@ -1,8 +1,12 @@ import { logger } from "@/client/log"; +import type { EventSource } from "eventsource"; // Global singleton promise that will be reused for subsequent calls let eventSourcePromise: Promise | null = null; +/** + * Import `eventsource` from the custom `eventsource` library. We need a custom implemnetation since we need to attach our own custom headers to the request. + **/ export async function importEventSource(): Promise { // Return existing promise if we already started loading if (eventSourcePromise !== null) { @@ -13,27 +17,21 @@ export async function importEventSource(): Promise { eventSourcePromise = (async () => { let _EventSource: typeof EventSource; - if (typeof EventSource !== "undefined") { - // Browser environment - _EventSource = EventSource; - logger().debug("using native eventsource"); - } else { - // Node.js environment - try { - const es = await import("eventsource"); - _EventSource = es.EventSource; - logger().debug("using eventsource from npm"); - } catch (err) { - // EventSource not available - _EventSource = class MockEventSource { - constructor() { - throw new Error( - 'EventSource support requires installing the "eventsource" peer dependency.', - ); - } - } as unknown as typeof EventSource; - logger().debug("using mock eventsource"); - } + // Node.js environment + try { + const es = await import("eventsource"); + _EventSource = es.EventSource; + logger().debug("using eventsource from npm"); + } catch (err) { + // EventSource not available + _EventSource = class MockEventSource { + constructor() { + throw new Error( + 'EventSource support requires installing the "eventsource" peer dependency.', + ); + } + } as unknown as typeof EventSource; + logger().debug("using mock eventsource"); } return _EventSource; @@ -41,3 +39,42 @@ export async function importEventSource(): Promise { return eventSourcePromise; } + +//export async function importEventSource(): Promise { +// // Return existing promise if we already started loading +// if (eventSourcePromise !== null) { +// return eventSourcePromise; +// } +// +// // Create and store the promise +// eventSourcePromise = (async () => { +// let _EventSource: typeof EventSource; +// +// if (typeof EventSource !== "undefined") { +// // Browser environment +// _EventSource = EventSource; +// logger().debug("using native eventsource"); +// } else { +// // Node.js environment +// try { +// const es = await import("eventsource"); +// _EventSource = es.EventSource; +// logger().debug("using eventsource from npm"); +// } catch (err) { +// // EventSource not available +// _EventSource = class MockEventSource { +// constructor() { +// throw new Error( +// 'EventSource support requires installing the "eventsource" peer dependency.', +// ); +// } +// } as unknown as typeof EventSource; +// logger().debug("using mock eventsource"); +// } +// } +// +// return _EventSource; +// })(); +// +// return eventSourcePromise; +//} diff --git a/packages/actor-core/src/common/router.ts b/packages/actor-core/src/common/router.ts index 66f66a749..7de56b18f 100644 --- a/packages/actor-core/src/common/router.ts +++ b/packages/actor-core/src/common/router.ts @@ -44,7 +44,7 @@ export function handleRouteError(error: unknown, c: HonoContext) { }, ); - const encoding = getRequestEncoding(c.req); + const encoding = getRequestEncoding(c.req, false); const output = serialize( { c: code, diff --git a/packages/actor-core/src/driver-test-suite/mod.ts b/packages/actor-core/src/driver-test-suite/mod.ts index 4c5469475..593355c86 100644 --- a/packages/actor-core/src/driver-test-suite/mod.ts +++ b/packages/actor-core/src/driver-test-suite/mod.ts @@ -5,7 +5,7 @@ import { DriverConfig, ManagerDriver, } from "@/driver-helpers/mod"; -import { runActorDriverTests, waitFor } from "./tests/actor-driver"; +import { runActorDriverTests } from "./tests/actor-driver"; import { runManagerDriverTests } from "./tests/manager-driver"; import { describe } from "vitest"; import { @@ -18,6 +18,7 @@ import invariant from "invariant"; import { bundleRequire } from "bundle-require"; import { getPort } from "@/test/mod"; import { Transport } from "@/client/mod"; +import { runActorConnTests } from "./tests/actor-conn"; export interface DriverTestConfig { /** Deploys an app and returns the connection endpoint. */ @@ -31,10 +32,8 @@ export interface DriverTestConfig { /** Cloudflare Workers has some bugs with cleanup. */ HACK_skipCleanupNet?: boolean; -} -export interface DriverTestConfigWithTransport extends DriverTestConfig { - transport: Transport; + transport?: Transport; } export interface DriverDeployOutput { @@ -46,13 +45,12 @@ export interface DriverDeployOutput { /** Runs all Vitest tests against the provided drivers. */ export function runDriverTests(driverTestConfig: DriverTestConfig) { + runActorDriverTests(driverTestConfig); + runManagerDriverTests(driverTestConfig); + for (const transport of ["websocket", "sse"] as Transport[]) { - describe(`driver tests (${transport})`, () => { - runActorDriverTests({ - ...driverTestConfig, - transport, - }); - runManagerDriverTests({ + describe(`actor connection (${transport})`, () => { + runActorConnTests({ ...driverTestConfig, transport, }); @@ -60,13 +58,6 @@ export function runDriverTests(driverTestConfig: DriverTestConfig) { } } -/** - * Re-export the waitFor helper for use in other tests. - * This function handles waiting in tests, using either real timers or mocked timers - * based on the driverTestConfig.useRealTimers setting. - */ -export { waitFor }; - /** * Helper function to adapt the drivers to the Node.js runtime for tests. * diff --git a/packages/actor-core/src/driver-test-suite/test-apps.ts b/packages/actor-core/src/driver-test-suite/test-apps.ts index 59308db71..69309f0a1 100644 --- a/packages/actor-core/src/driver-test-suite/test-apps.ts +++ b/packages/actor-core/src/driver-test-suite/test-apps.ts @@ -2,6 +2,8 @@ import { resolve } from "node:path"; export type { App as CounterApp } from "../../fixtures/driver-test-suite/counter"; export type { App as ScheduledApp } from "../../fixtures/driver-test-suite/scheduled"; +export type { App as ConnParamsApp } from "../../fixtures/driver-test-suite/conn-params"; +export type { App as LifecycleApp } from "../../fixtures/driver-test-suite/lifecycle"; export const COUNTER_APP_PATH = resolve( __dirname, @@ -11,3 +13,11 @@ export const SCHEDULED_APP_PATH = resolve( __dirname, "../../fixtures/driver-test-suite/scheduled.ts", ); +export const CONN_PARAMS_APP_PATH = resolve( + __dirname, + "../../fixtures/driver-test-suite/conn-params.ts", +); +export const LIFECYCLE_APP_PATH = resolve( + __dirname, + "../../fixtures/driver-test-suite/lifecycle.ts", +); \ No newline at end of file diff --git a/packages/actor-core/src/driver-test-suite/tests/actor-conn.ts b/packages/actor-core/src/driver-test-suite/tests/actor-conn.ts new file mode 100644 index 000000000..89c529a10 --- /dev/null +++ b/packages/actor-core/src/driver-test-suite/tests/actor-conn.ts @@ -0,0 +1,261 @@ +import { describe, test, expect } from "vitest"; +import type { DriverTestConfig } from "../mod"; +import { setupDriverTest } from "../utils"; +import { + COUNTER_APP_PATH, + CONN_PARAMS_APP_PATH, + LIFECYCLE_APP_PATH, + type CounterApp, + type ConnParamsApp, + type LifecycleApp, +} from "../test-apps"; + +export function runActorConnTests(driverTestConfig: DriverTestConfig) { + describe("Actor Connection Tests", () => { + describe("Connection Methods", () => { + test("should connect using .get().connect()", async (c) => { + const { client } = await setupDriverTest( + c, + driverTestConfig, + COUNTER_APP_PATH, + ); + + // Create actor + await client.counter.create(["test-get"]); + + // Get a handle and connect + const handle = client.counter.get(["test-get"]); + const connection = handle.connect(); + + // Verify connection by performing an action + const count = await connection.increment(5); + expect(count).toBe(5); + + // Clean up + await connection.dispose(); + }); + + test("should connect using .getForId().connect()", async (c) => { + const { client } = await setupDriverTest( + c, + driverTestConfig, + COUNTER_APP_PATH, + ); + + // Create an actor first to get its ID + const handle = client.counter.getOrCreate(["test-get-for-id"]); + await handle.increment(3); + const actorId = await handle.resolve(); + + // Get a new handle using the actor ID and connect + const idHandle = client.counter.getForId(actorId); + const connection = idHandle.connect(); + + // Verify connection works and state is preserved + const count = await connection.getCount(); + expect(count).toBe(3); + + // Clean up + await connection.dispose(); + }); + + test("should connect using .getOrCreate().connect()", async (c) => { + const { client } = await setupDriverTest( + c, + driverTestConfig, + COUNTER_APP_PATH, + ); + + // Get or create actor and connect + const handle = client.counter.getOrCreate(["test-get-or-create"]); + const connection = handle.connect(); + + // Verify connection works + const count = await connection.increment(7); + expect(count).toBe(7); + + // Clean up + await connection.dispose(); + }); + + test("should connect using (await create()).connect()", async (c) => { + const { client } = await setupDriverTest( + c, + driverTestConfig, + COUNTER_APP_PATH, + ); + + // Create actor and connect + const handle = await client.counter.create(["test-create"]); + const connection = handle.connect(); + + // Verify connection works + const count = await connection.increment(9); + expect(count).toBe(9); + + // Clean up + await connection.dispose(); + }); + }); + + describe("Event Communication", () => { + test("should receive events via broadcast", async (c) => { + const { client } = await setupDriverTest( + c, + driverTestConfig, + COUNTER_APP_PATH, + ); + + // Create actor and connect + const handle = client.counter.getOrCreate(["test-broadcast"]); + const connection = handle.connect(); + + // Set up event listener + const receivedEvents: number[] = []; + connection.on("newCount", (count: number) => { + receivedEvents.push(count); + }); + + // Trigger broadcast events + await connection.increment(5); + await connection.increment(3); + + // Verify events were received + expect(receivedEvents).toContain(5); + expect(receivedEvents).toContain(8); + + // Clean up + await connection.dispose(); + }); + + test("should handle one-time events with once()", async (c) => { + const { client } = await setupDriverTest( + c, + driverTestConfig, + COUNTER_APP_PATH, + ); + + // Create actor and connect + const handle = client.counter.getOrCreate(["test-once"]); + const connection = handle.connect(); + + // Set up one-time event listener + const receivedEvents: number[] = []; + connection.once("newCount", (count: number) => { + receivedEvents.push(count); + }); + + // Trigger multiple events, but should only receive the first one + await connection.increment(5); + await connection.increment(3); + + // Verify only the first event was received + expect(receivedEvents).toEqual([5]); + expect(receivedEvents).not.toContain(8); + + // Clean up + await connection.dispose(); + }); + + test("should unsubscribe from events", async (c) => { + const { client } = await setupDriverTest( + c, + driverTestConfig, + COUNTER_APP_PATH, + ); + + // Create actor and connect + const handle = client.counter.getOrCreate(["test-unsubscribe"]); + const connection = handle.connect(); + + // Set up event listener with unsubscribe + const receivedEvents: number[] = []; + const unsubscribe = connection.on("newCount", (count: number) => { + receivedEvents.push(count); + }); + + // Trigger first event + await connection.increment(5); + + // Unsubscribe + unsubscribe(); + + // Trigger second event, should not be received + await connection.increment(3); + + // Verify only the first event was received + expect(receivedEvents).toEqual([5]); + expect(receivedEvents).not.toContain(8); + + // Clean up + await connection.dispose(); + }); + }); + + describe("Connection Parameters", () => { + test("should pass connection parameters", async (c) => { + const { client } = await setupDriverTest( + c, + driverTestConfig, + CONN_PARAMS_APP_PATH, + ); + + // Create two connections with different params + const handle1 = client.counter.getOrCreate(["test-params"], { + params: { name: "user1" }, + }); + const handle2 = client.counter.getOrCreate(["test-params"], { + params: { name: "user2" }, + }); + + const conn1 = handle1.connect(); + const conn2 = handle2.connect(); + + // Get initializers to verify connection params were used + const initializers = await conn1.getInitializers(); + + // Verify both connection names were recorded + expect(initializers).toContain("user1"); + expect(initializers).toContain("user2"); + + // Clean up + await conn1.dispose(); + await conn2.dispose(); + }); + }); + + describe("Lifecycle Hooks", () => { + test("should trigger lifecycle hooks", async (c) => { + const { client } = await setupDriverTest( + c, + driverTestConfig, + LIFECYCLE_APP_PATH, + ); + + // Create and connect + const handle = client.counter.getOrCreate(["test-lifecycle"]); + const connection = handle.connect(); + + // Verify lifecycle events were triggered + const events = await connection.getEvents(); + + // Check lifecycle hooks were called in the correct order + expect(events).toContain("onStart"); + expect(events).toContain("onBeforeConnect"); + expect(events).toContain("onConnect"); + + // Disconnect should trigger onDisconnect + await connection.dispose(); + + // Reconnect to check if onDisconnect was called + const newConnection = handle.connect(); + + const finalEvents = await newConnection.getEvents(); + expect(finalEvents).toContain("onDisconnect"); + + // Clean up + await newConnection.dispose(); + }); + }); + }); +} diff --git a/packages/actor-core/src/driver-test-suite/tests/actor-driver.ts b/packages/actor-core/src/driver-test-suite/tests/actor-driver.ts index b3f982cf8..e1f5c6560 100644 --- a/packages/actor-core/src/driver-test-suite/tests/actor-driver.ts +++ b/packages/actor-core/src/driver-test-suite/tests/actor-driver.ts @@ -1,6 +1,6 @@ import { describe, test, expect, vi } from "vitest"; -import type { DriverTestConfig, DriverTestConfigWithTransport } from "../mod"; -import { setupDriverTest } from "../utils"; +import type { DriverTestConfig} from "../mod"; +import { setupDriverTest, waitFor } from "../utils"; import { COUNTER_APP_PATH, SCHEDULED_APP_PATH, @@ -8,23 +8,8 @@ import { type ScheduledApp, } from "../test-apps"; -/** - * Waits for the specified time, using either real setTimeout or vi.advanceTimersByTime - * based on the driverTestConfig. - */ -export async function waitFor( - driverTestConfig: DriverTestConfig, - ms: number, -): Promise { - if (driverTestConfig.useRealTimers) { - return new Promise((resolve) => setTimeout(resolve, ms)); - } else { - vi.advanceTimersByTime(ms); - return Promise.resolve(); - } -} export function runActorDriverTests( - driverTestConfig: DriverTestConfigWithTransport, + driverTestConfig: DriverTestConfig ) { describe("Actor Driver Tests", () => { describe("State Persistence", () => { diff --git a/packages/actor-core/src/driver-test-suite/tests/manager-driver.ts b/packages/actor-core/src/driver-test-suite/tests/manager-driver.ts index c559d75bb..832bef2a2 100644 --- a/packages/actor-core/src/driver-test-suite/tests/manager-driver.ts +++ b/packages/actor-core/src/driver-test-suite/tests/manager-driver.ts @@ -1,12 +1,10 @@ import { describe, test, expect, vi } from "vitest"; -import type { DriverTestConfigWithTransport } from "../mod"; import { setupDriverTest } from "../utils"; import { ActorError } from "@/client/mod"; import { COUNTER_APP_PATH, type CounterApp } from "../test-apps"; +import { DriverTestConfig } from "../mod"; -export function runManagerDriverTests( - driverTestConfig: DriverTestConfigWithTransport, -) { +export function runManagerDriverTests(driverTestConfig: DriverTestConfig) { describe("Manager Driver Tests", () => { describe("Client Connection Methods", () => { test("connect() - finds or creates an actor", async (c) => { @@ -72,16 +70,13 @@ export function runManagerDriverTests( const nonexistentId = `nonexistent-${crypto.randomUUID()}`; // Should fail when actor doesn't exist - let counter1Error: ActorError; - const counter1 = client.counter.get([nonexistentId]).connect(); - counter1.onError((e) => { - counter1Error = e; - }); - await vi.waitFor( - () => expect(counter1Error).toBeInstanceOf(ActorError), - 500, - ); - await counter1.dispose(); + try { + await client.counter.get([nonexistentId]).resolve(); + expect.fail("did not error for get"); + } catch (err) { + expect(err).toBeInstanceOf(ActorError); + expect((err as ActorError).code).toBe("actor_not_found"); + } // Create the actor const createdCounter = client.counter.getOrCreate(nonexistentId); diff --git a/packages/actor-core/src/driver-test-suite/utils.ts b/packages/actor-core/src/driver-test-suite/utils.ts index 7f1062268..f0c455665 100644 --- a/packages/actor-core/src/driver-test-suite/utils.ts +++ b/packages/actor-core/src/driver-test-suite/utils.ts @@ -1,12 +1,12 @@ import type { ActorCoreApp } from "@/mod"; import { type TestContext, vi } from "vitest"; import { createClient, type Client } from "@/client/mod"; -import type { DriverTestConfigWithTransport } from "./mod"; +import type { DriverTestConfig, } from "./mod"; // Must use `TestContext` since global hooks do not work when running concurrently export async function setupDriverTest>( c: TestContext, - driverTestConfig: DriverTestConfigWithTransport, + driverTestConfig: DriverTestConfig, appPath: string, ): Promise<{ client: Client; @@ -31,3 +31,15 @@ export async function setupDriverTest>( client, }; } + +export async function waitFor( + driverTestConfig: DriverTestConfig, + ms: number, +): Promise { + if (driverTestConfig.useRealTimers) { + return new Promise((resolve) => setTimeout(resolve, ms)); + } else { + vi.advanceTimersByTime(ms); + return Promise.resolve(); + } +} diff --git a/packages/actor-core/src/manager/protocol/query.ts b/packages/actor-core/src/manager/protocol/query.ts index 6e3907a52..df67d9629 100644 --- a/packages/actor-core/src/manager/protocol/query.ts +++ b/packages/actor-core/src/manager/protocol/query.ts @@ -1,6 +1,14 @@ -import { ActorKeySchema, type ActorKey } from "@/common//utils"; +import { ActorKeySchema } from "@/common//utils"; import { z } from "zod"; import { EncodingSchema } from "@/actor/protocol/serde"; +import { + HEADER_ACTOR_ID, + HEADER_CONN_ID, + HEADER_CONN_PARAMS, + HEADER_CONN_TOKEN, + HEADER_ENCODING, + HEADER_ACTOR_QUERY, +} from "@/actor/router-endpoints"; export const CreateRequestSchema = z.object({ name: z.string(), @@ -36,16 +44,32 @@ export const ActorQuerySchema = z.union([ }), ]); -export const ConnectQuerySchema = z.object({ - query: ActorQuerySchema, - encoding: EncodingSchema, - params: z.string().optional(), +export const ConnectRequestSchema = z.object({ + query: ActorQuerySchema.describe(HEADER_ACTOR_QUERY), + encoding: EncodingSchema.describe(HEADER_ENCODING), + connParams: z.string().optional().describe(HEADER_CONN_PARAMS), +}); + +export const ConnectWebSocketRequestSchema = z.object({ + query: ActorQuerySchema.describe("query"), + encoding: EncodingSchema.describe("encoding"), +}); + +export const ConnMessageRequestSchema = z.object({ + actorId: z.string().describe(HEADER_ACTOR_ID), + connId: z.string().describe(HEADER_CONN_ID), + encoding: EncodingSchema.describe(HEADER_ENCODING), + connToken: z.string().describe(HEADER_CONN_TOKEN), +}); + +export const ResolveRequestSchema = z.object({ + query: ActorQuerySchema.describe(HEADER_ACTOR_QUERY), }); export type ActorQuery = z.infer; export type GetForKeyRequest = z.infer; export type GetOrCreateRequest = z.infer; -export type ConnectQuery = z.infer; +export type ConnectQuery = z.infer; /** * Interface representing a request to create an actor. */ diff --git a/packages/actor-core/src/manager/router.ts b/packages/actor-core/src/manager/router.ts index 727e511bf..8ba275b3d 100644 --- a/packages/actor-core/src/manager/router.ts +++ b/packages/actor-core/src/manager/router.ts @@ -9,6 +9,14 @@ import { handleRpc, handleSseConnect, handleWebSocketConnect, + HEADER_ACTOR_ID, + HEADER_CONN_ID, + HEADER_CONN_PARAMS, + HEADER_CONN_TOKEN, + HEADER_ENCODING, + HEADER_ACTOR_QUERY, + ALL_HEADERS, + getRequestQuery, } from "@/actor/router-endpoints"; import { assertUnreachable } from "@/actor/utils"; import type { AppConfig } from "@/app/config"; @@ -30,7 +38,12 @@ import type { WSContext } from "hono/ws"; import invariant from "invariant"; import type { ManagerDriver } from "./driver"; import { logger } from "./log"; -import { ConnectQuerySchema } from "./protocol/query"; +import { + ConnectRequestSchema, + ConnectWebSocketRequestSchema, + ConnMessageRequestSchema, + ResolveRequestSchema, +} from "./protocol/query"; import type { ActorQuery } from "./protocol/query"; type ProxyMode = @@ -84,15 +97,20 @@ export function createManagerRouter( app.use("*", loggerMiddleware(logger())); if (appConfig.cors) { + const corsConfig = appConfig.cors; + app.use("*", async (c, next) => { const path = c.req.path; // Don't apply to WebSocket routes - if (path === "/actors/connect/websocket") { + if (path === "/actors/connect/websocket" || path === "/inspect") { return next(); } - return cors(appConfig.cors)(c, next); + return cors({ + ...corsConfig, + allowHeaders: [...(appConfig.cors?.allowHeaders ?? []), ...ALL_HEADERS], + })(c, next); }); } @@ -108,27 +126,21 @@ export function createManagerRouter( // Resolve actor ID from query app.post("/actors/resolve", async (c) => { - const encoding = getRequestEncoding(c.req); + const encoding = getRequestEncoding(c.req, false); logger().debug("resolve request encoding", { encoding }); - // Get query parameters for actor lookup - const queryParam = c.req.query("query"); - if (!queryParam) { - logger().error("missing query parameter for resolve"); - throw new errors.MissingRequiredParameters(["query"]); - } - - // Parse the query JSON and validate with schema - let parsedQuery: ActorQuery; - try { - parsedQuery = JSON.parse(queryParam as string); - } catch (error) { - logger().error("invalid query json for resolve", { error }); - throw new errors.InvalidQueryJSON(error); + const params = ResolveRequestSchema.safeParse({ + query: getRequestQuery(c, false), + }); + if (!params.success) { + logger().error("invalid connection parameters", { + error: params.error, + }); + throw new errors.InvalidRequest(params.error); } // Get the actor ID and meta - const { actorId, meta } = await queryActor(c, parsedQuery, driver); + const { actorId, meta } = await queryActor(c, params.data.query, driver); logger().debug("resolved actor", { actorId, meta }); invariant(actorId, "Missing actor ID"); @@ -145,19 +157,20 @@ export function createManagerRouter( let encoding: Encoding | undefined; try { - encoding = getRequestEncoding(c.req); - logger().debug("websocket connection request received", { encoding }); + logger().debug("websocket connection request received"); - const params = ConnectQuerySchema.safeParse({ - query: parseQuery(c), + // We can't use the standard headers with WebSockets + // + // All other information will be sent over the socket itself, since that data needs to be E2EE + const params = ConnectWebSocketRequestSchema.safeParse({ + query: getRequestQuery(c, true), encoding: c.req.query("encoding"), - params: c.req.query("params"), }); if (!params.success) { logger().error("invalid connection parameters", { error: params.error, }); - throw new errors.InvalidQueryFormat(params.error); + throw new errors.InvalidRequest(params.error); } // Get the actor ID and meta @@ -185,13 +198,9 @@ export function createManagerRouter( })(c, noopNext()); } else if ("custom" in handler.proxyMode) { logger().debug("using custom proxy mode for websocket connection"); - let pathname = `/connect/websocket?encoding=${params.data.encoding}`; - if (params.data.params) { - pathname += `¶ms=${params.data.params}`; - } return await handler.proxyMode.custom.onProxyWebSocket( c, - pathname, + `/connect/websocket?encoding=${params.data.encoding}`, actorId, meta, ); @@ -247,20 +256,20 @@ export function createManagerRouter( app.get("/actors/connect/sse", async (c) => { let encoding: Encoding | undefined; try { - encoding = getRequestEncoding(c.req); + encoding = getRequestEncoding(c.req, false); logger().debug("sse connection request received", { encoding }); - const params = ConnectQuerySchema.safeParse({ - query: parseQuery(c), - encoding: c.req.query("encoding"), - params: c.req.query("params"), + const params = ConnectRequestSchema.safeParse({ + query: getRequestQuery(c, false), + encoding: c.req.header(HEADER_ENCODING), + params: c.req.header(HEADER_CONN_PARAMS), }); if (!params.success) { logger().error("invalid connection parameters", { error: params.error, }); - throw new errors.InvalidQueryFormat(params.error); + throw new errors.InvalidRequest(params.error); } const query = params.data.query; @@ -284,11 +293,11 @@ export function createManagerRouter( } else if ("custom" in handler.proxyMode) { logger().debug("using custom proxy mode for sse connection"); const url = new URL("http://actor/connect/sse"); - url.searchParams.set("encoding", params.data.encoding); - if (params.data.params) { - url.searchParams.set("params", params.data.params); - } const proxyRequest = new Request(url, c.req.raw); + proxyRequest.headers.set(HEADER_ENCODING, params.data.encoding); + if (params.data.connParams) { + proxyRequest.headers.set(HEADER_CONN_PARAMS, params.data.connParams); + } return await handler.proxyMode.custom.onProxyRequest( c, proxyRequest, @@ -357,24 +366,21 @@ export function createManagerRouter( const rpcName = c.req.param("rpc"); logger().debug("rpc call received", { rpcName }); - // Get query parameters for actor lookup - const queryParam = c.req.query("query"); - if (!queryParam) { - logger().error("missing query parameter for rpc"); - throw new errors.MissingRequiredParameters(["query"]); - } + const params = ConnectRequestSchema.safeParse({ + query: getRequestQuery(c, false), + encoding: c.req.header(HEADER_ENCODING), + params: c.req.header(HEADER_CONN_PARAMS), + }); - // Parse the query JSON and validate with schema - let parsedQuery: ActorQuery; - try { - parsedQuery = JSON.parse(queryParam as string); - } catch (error) { - logger().error("invalid query json for rpc", { error }); - throw new errors.InvalidQueryJSON(error); + if (!params.success) { + logger().error("invalid connection parameters", { + error: params.error, + }); + throw new errors.InvalidRequest(params.error); } // Get the actor ID and meta - const { actorId, meta } = await queryActor(c, parsedQuery, driver); + const { actorId, meta } = await queryActor(c, params.data.query, driver); logger().debug("found actor for rpc", { actorId, meta }); invariant(actorId, "Missing actor ID"); @@ -416,41 +422,22 @@ export function createManagerRouter( }); // Proxy connection messages to actor - app.post("/actors/connections/:conn/message", async (c) => { + app.post("/actors/message", async (c) => { logger().debug("connection message request received"); try { - const connId = c.req.param("conn"); - const connToken = c.req.query("connectionToken"); - const encoding = c.req.query("encoding"); - - // Get query parameters for actor lookup - const queryParam = c.req.query("query"); - if (!queryParam) { - throw new errors.MissingRequiredParameters(["query"]); - } - - // Check other required parameters - const missingParams: string[] = []; - if (!connToken) missingParams.push("connectionToken"); - if (!encoding) missingParams.push("encoding"); - - if (missingParams.length > 0) { - throw new errors.MissingRequiredParameters(missingParams); - } - - // Parse the query JSON and validate with schema - let parsedQuery: ActorQuery; - try { - parsedQuery = JSON.parse(queryParam as string); - } catch (error) { - logger().error("invalid query json", { error }); - throw new errors.InvalidQueryJSON(error); + const params = ConnMessageRequestSchema.safeParse({ + actorId: c.req.header(HEADER_ACTOR_ID), + connId: c.req.header(HEADER_CONN_ID), + encoding: c.req.header(HEADER_ENCODING), + connToken: c.req.header(HEADER_CONN_TOKEN), + }); + if (!params.success) { + logger().error("invalid connection parameters", { + error: params.error, + }); + throw new errors.InvalidRequest(params.error); } - - // Get the actor ID and meta - const { actorId, meta } = await queryActor(c, parsedQuery, driver); - invariant(actorId, "Missing actor ID"); - logger().debug("connection message to actor", { connId, actorId, meta }); + const { actorId, connId, encoding, connToken } = params.data; // Handle based on mode if ("inline" in handler.proxyMode) { @@ -467,14 +454,16 @@ export function createManagerRouter( } else if ("custom" in handler.proxyMode) { logger().debug("using custom proxy mode for connection message"); const url = new URL(`http://actor/connections/${connId}/message`); - url.searchParams.set("connectionToken", connToken!); - url.searchParams.set("encoding", encoding!); + const proxyRequest = new Request(url, c.req.raw); + proxyRequest.headers.set(HEADER_ENCODING, encoding); + proxyRequest.headers.set(HEADER_CONN_ID, connId); + proxyRequest.headers.set(HEADER_CONN_TOKEN, connToken); + return await handler.proxyMode.custom.onProxyRequest( c, proxyRequest, actorId, - meta, ); } else { assertUnreachable(handler.proxyMode); @@ -571,7 +560,7 @@ export async function queryActor( meta: createOutput.meta, }; } else { - throw new errors.InvalidQueryFormat("Invalid query format"); + throw new errors.InvalidRequest("Invalid query format"); } logger().debug("actor query result", { @@ -585,21 +574,3 @@ export async function queryActor( function noopNext(): Next { return async () => {}; } - -function parseQuery(c: HonoContext): unknown { - // Get query parameters for actor lookup - const queryParam = c.req.query("query"); - if (!queryParam) { - logger().error("missing query parameter for rpc"); - throw new errors.MissingRequiredParameters(["query"]); - } - - // Parse the query JSON and validate with schema - try { - const parsed = JSON.parse(queryParam as string); - return parsed; - } catch (error) { - logger().error("invalid query json for rpc", { error }); - throw new errors.InvalidQueryJSON(error); - } -} diff --git a/packages/actor-core/tests/basic.test.ts b/packages/actor-core/tests/basic.test.ts deleted file mode 100644 index 56c81de36..000000000 --- a/packages/actor-core/tests/basic.test.ts +++ /dev/null @@ -1,80 +0,0 @@ -import { actor, setup } from "@/mod"; -import { test, expect } from "vitest"; -import { setupTest } from "@/test/mod"; - -test("basic actor setup", async (c) => { - const counter = actor({ - state: { count: 0 }, - actions: { - increment: (c, x: number) => { - c.state.count += x; - c.broadcast("newCount", c.state.count); - return c.state.count; - }, - }, - }); - - const app = setup({ - actors: { counter }, - }); - - const { client } = await setupTest(c, app); - - const counterInstance = client.counter.getOrCreate(); - await counterInstance.increment(1); -}); - -test("actorhandle.resolve resolves actor ID", async (c) => { - const testActor = actor({ - state: { value: "" }, - actions: { - getValue: (c) => c.state.value, - }, - }); - - const app = setup({ - actors: { testActor }, - }); - - const { client } = await setupTest(c, app); - - // Get a handle to the actor using a key - const handle = client.testActor.getOrCreate("test-key"); - - // Resolve should work without errors and return void - await handle.resolve(); - - // After resolving, we should be able to call an action - const value = await handle.getValue(); - expect(value).toBeDefined(); -}); - -test("client.create creates a new actor", async (c) => { - const testActor = actor({ - state: { createdVia: "" }, - actions: { - setCreationMethod: (c, method: string) => { - c.state.createdVia = method; - return c.state.createdVia; - }, - getCreationMethod: (c) => c.state.createdVia, - }, - }); - - const app = setup({ - actors: { testActor }, - }); - - const { client } = await setupTest(c, app); - - // Create a new actor using client.create - const handle = await client.testActor.create("created-actor"); - - // Set some state to confirm it works - const result = await handle.setCreationMethod("client.create"); - expect(result).toBe("client.create"); - - // Verify we can retrieve the state - const method = await handle.getCreationMethod(); - expect(method).toBe("client.create"); -}); diff --git a/packages/actor-core/tsconfig.json b/packages/actor-core/tsconfig.json index 42a144203..2c7bec6f5 100644 --- a/packages/actor-core/tsconfig.json +++ b/packages/actor-core/tsconfig.json @@ -8,5 +8,5 @@ "actor-core": ["./src/mod.ts"] } }, - "include": ["src/**/*", "tests/**/*"] + "include": ["src/**/*", "tests/**/*", "fixtures/driver-test-suite/**/*"] } diff --git a/vitest.base.ts b/vitest.base.ts index c419adc59..f69bd00ea 100644 --- a/vitest.base.ts +++ b/vitest.base.ts @@ -6,8 +6,6 @@ export default { sequence: { concurrent: true, }, - // Increase timeout for proxy tests - testTimeout: 15_000, env: { // Enable logging _LOG_LEVEL: "DEBUG",