diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3b0790412..e1c96b2c4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -25,34 +25,34 @@ jobs: steps: - uses: actions/checkout@v4 - # Setup Node.js - - name: Set up Node.js - uses: actions/setup-node@v4 - with: - node-version: '22.14' - # Note: We're not using the built-in cache here because we need to use corepack - - - name: Setup Corepack - run: corepack enable - - - id: yarn-cache-dir-path - name: Get yarn cache directory path - run: echo "dir=$(yarn config get cacheFolder)" >> $GITHUB_OUTPUT - - - name: Cache dependencies - uses: actions/cache@v3 - id: cache - with: - path: | - ${{ steps.yarn-cache-dir-path.outputs.dir }} - .turbo - key: ${{ runner.os }}-deps-${{ hashFiles('**/yarn.lock') }}-${{ github.sha }} - restore-keys: | - ${{ runner.os }}-deps-${{ hashFiles('**/yarn.lock') }}- - ${{ runner.os }}-deps- - - - name: Install dependencies - run: yarn install + # # Setup Node.js + # - name: Set up Node.js + # uses: actions/setup-node@v4 + # with: + # node-version: '22.14' + # # Note: We're not using the built-in cache here because we need to use corepack + # + # - name: Setup Corepack + # run: corepack enable + # + # - id: yarn-cache-dir-path + # name: Get yarn cache directory path + # run: echo "dir=$(yarn config get cacheFolder)" >> $GITHUB_OUTPUT + # + # - name: Cache dependencies + # uses: actions/cache@v3 + # id: cache + # with: + # path: | + # ${{ steps.yarn-cache-dir-path.outputs.dir }} + # .turbo + # key: ${{ runner.os }}-deps-${{ hashFiles('**/yarn.lock') }}-${{ github.sha }} + # restore-keys: | + # ${{ runner.os }}-deps-${{ hashFiles('**/yarn.lock') }}- + # ${{ runner.os }}-deps- + # + # - name: Install dependencies + # run: yarn install # - name: Run actor-core tests # # TODO: Add back diff --git a/CLAUDE.md b/CLAUDE.md index 7f4ec0563..b06e464ab 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -85,6 +85,9 @@ This ensures imports resolve correctly across different build environments and p - Extend from `ActorError` base class - Use `UserError` for client-safe errors - Use `InternalError` for internal errors +- Don't try to fix type issues by casting to unknown or any. If you need to do this, then stop and ask me to manually intervene. +- Write log messages in lowercase +- Instead of returning raw HTTP responses with c.json, use or write an error in packages/actor-core/src/actor/errors.ts and throw that instead. The middleware will automatically serialize the response for you. ## Project Structure diff --git a/examples/chat-room/scripts/cli.ts b/examples/chat-room/scripts/cli.ts index 954dcd343..6e7465b51 100644 --- a/examples/chat-room/scripts/cli.ts +++ b/examples/chat-room/scripts/cli.ts @@ -10,7 +10,7 @@ async function main() { // connect to chat room - now accessed via property // can still pass parameters like room - const chatRoom = await client.chatRoom.connect(room, { + const chatRoom = client.chatRoom.connect(room, { params: { room }, }); diff --git a/examples/chat-room/scripts/connect.ts b/examples/chat-room/scripts/connect.ts index 59378aef2..b8e0ab3ea 100644 --- a/examples/chat-room/scripts/connect.ts +++ b/examples/chat-room/scripts/connect.ts @@ -7,7 +7,7 @@ async function main() { const client = createClient(process.env.ENDPOINT ?? "http://localhost:6420"); // connect to chat room - now accessed via property - const chatRoom = await client.chatRoom.connect(); + const chatRoom = client.chatRoom.connect(); // call action to get existing messages const messages = await chatRoom.getHistory(); diff --git a/examples/chat-room/tests/chat-room.test.ts b/examples/chat-room/tests/chat-room.test.ts index 527e9c223..88a801efd 100644 --- a/examples/chat-room/tests/chat-room.test.ts +++ b/examples/chat-room/tests/chat-room.test.ts @@ -6,7 +6,7 @@ test("chat room should handle messages", async (test) => { const { client } = await setupTest(test, app); // Connect to chat room - const chatRoom = await client.chatRoom.connect(); + const chatRoom = client.chatRoom.connect(); // Initial history should be empty const initialMessages = await chatRoom.getHistory(); diff --git a/examples/counter/scripts/connect.ts b/examples/counter/scripts/connect.ts index 0e76fb603..b4f83b252 100644 --- a/examples/counter/scripts/connect.ts +++ b/examples/counter/scripts/connect.ts @@ -5,7 +5,7 @@ import type { App } from "../actors/app"; async function main() { const client = createClient(process.env.ENDPOINT ?? "http://localhost:6420"); - const counter = await client.counter.connect() + const counter = client.counter.connect() counter.on("newCount", (count: number) => console.log("Event:", count)); diff --git a/examples/counter/tests/counter.test.ts b/examples/counter/tests/counter.test.ts index 26259b9aa..25861b474 100644 --- a/examples/counter/tests/counter.test.ts +++ b/examples/counter/tests/counter.test.ts @@ -4,7 +4,7 @@ import { app } from "../actors/app"; test("it should count", async (test) => { const { client } = await setupTest(test, app); - const counter = await client.counter.connect(); + const counter = client.counter.connect(); // Test initial count expect(await counter.getCount()).toBe(0); diff --git a/examples/linear-coding-agent/src/server/index.ts b/examples/linear-coding-agent/src/server/index.ts index 75fc02bb4..f395d38b4 100644 --- a/examples/linear-coding-agent/src/server/index.ts +++ b/examples/linear-coding-agent/src/server/index.ts @@ -75,7 +75,7 @@ server.post('/api/webhook/linear', async (c) => { // Create or get a coding agent instance with the issue ID as a key // This ensures each issue gets its own actor instance console.log(`[SERVER] Getting actor for issue: ${issueId}`); - const actorClient = await client.codingAgent.connect(issueId); + const actorClient = client.codingAgent.connect(issueId); // Initialize the agent if needed console.log(`[SERVER] Initializing actor for issue: ${issueId}`); diff --git a/examples/resend-streaks/tests/user.test.ts b/examples/resend-streaks/tests/user.test.ts index 12b945404..69fe0c60b 100644 --- a/examples/resend-streaks/tests/user.test.ts +++ b/examples/resend-streaks/tests/user.test.ts @@ -26,7 +26,7 @@ beforeEach(() => { test("streak tracking with time zone signups", async (t) => { const { client } = await setupTest(t, app); - const actor = await client.user.connect(); + const actor = client.user.connect(); // Sign up with specific time zone const signupResult = await actor.completeSignUp( diff --git a/packages/actor-core-cli/package.json b/packages/actor-core-cli/package.json index dc64232bd..00af27cb1 100644 --- a/packages/actor-core-cli/package.json +++ b/packages/actor-core-cli/package.json @@ -38,6 +38,7 @@ "bundle-require": "^5.1.0", "chokidar": "^4.0.3", "esbuild": "^0.25.1", + "invariant": "^2.2.4", "open": "^10.1.0", "yoga-wasm-web": "0.3.3" }, @@ -46,6 +47,7 @@ "@rivet-gg/api": "^24.6.2", "@sentry/esbuild-plugin": "^3.2.0", "@sentry/node": "^9.3.0", + "@types/invariant": "^2", "@types/micromatch": "^4", "@types/react": "^18.3", "@types/semver": "^7.5.8", diff --git a/packages/actor-core-cli/src/cli.ts b/packages/actor-core-cli/src/cli.ts index ff11779e3..9284e4294 100644 --- a/packages/actor-core-cli/src/cli.ts +++ b/packages/actor-core-cli/src/cli.ts @@ -1,5 +1,5 @@ import { PACKAGE_JSON } from "./macros" with { type: "macro" }; -import { create, deploy, dev, program } from "./mod"; +import { create, deploy, dev, endpoint, program } from "./mod"; export default program .name(PACKAGE_JSON.name) @@ -8,4 +8,5 @@ export default program .addCommand(deploy) .addCommand(create) .addCommand(dev) + .addCommand(endpoint) .parse(); diff --git a/packages/actor-core-cli/src/commands/deploy.tsx b/packages/actor-core-cli/src/commands/deploy.tsx index 156051ea2..cb7265bb8 100644 --- a/packages/actor-core-cli/src/commands/deploy.tsx +++ b/packages/actor-core-cli/src/commands/deploy.tsx @@ -57,10 +57,10 @@ export const deploy = new Command() await workflow( "Deploy actors to Rivet", - async function* (ctx) { + async function*(ctx) { const { config, cli } = yield* ctx.task( "Prepare", - async function* (ctx) { + async function*(ctx) { const config = yield* validateConfigTask(ctx, cwd, appPath); const cli = yield* ctx.task("Locale rivet-cli", async (ctx) => { @@ -96,7 +96,7 @@ export const deploy = new Command() ); const { accessToken, projectName, envName, endpoint } = - yield* ctx.task("Auth with Rivet", async function* (ctx) { + yield* ctx.task("Auth with Rivet", async function*(ctx) { const { stdout } = await exec`${cli} metadata auth-status`; const isLogged = stdout === "true"; @@ -164,7 +164,7 @@ export const deploy = new Command() if (!opts.skipManager) { manager = yield* ctx.task( "Deploy ActorCore", - async function* (ctx) { + async function*(ctx) { yield fs.mkdir(path.join(cwd, ".actorcore"), { recursive: true, }); @@ -185,7 +185,7 @@ export const deploy = new Command() ); const output = - await exec`${cli} publish manager --env ${envName} --tags access=private ${entrypoint}`; + await exec`${cli} publish manager --env ${envName} --tags role=manager,framework=actor-core,framework-version=${VERSION} --unstable-minify false ${entrypoint}`; if (output.exitCode !== 0) { throw ctx.error("Failed to deploy ActorCore.", { hint: "Check the logs above for more information.", @@ -257,11 +257,32 @@ export const deploy = new Command() environment: envName, body: { region: region.id, - tags: { name: "manager", owner: "rivet" }, - buildTags: { name: "manager", current: "true" }, + tags: { + name: "manager", + role: "manager", + framework: "actor-core", + }, + buildTags: { + name: "manager", + role: "manager", + framework: "actor-core", + current: "true", + }, runtime: { environment: { RIVET_SERVICE_TOKEN: serviceToken, + ...(process.env._RIVET_MANAGER_LOG_LEVEL + ? { + _LOG_LEVEL: + process.env._RIVET_MANAGER_LOG_LEVEL, + } + : {}), + ...(process.env._RIVET_ACTOR_LOG_LEVEL + ? { + _ACTOR_LOG_LEVEL: + process.env._RIVET_ACTOR_LOG_LEVEL, + } + : {}), }, }, network: { @@ -290,10 +311,9 @@ export const deploy = new Command() config.app.config.actors, ).entries()) { yield* ctx.task( - `Deploy & upload "${actorName}" build (${idx + 1}/${ - Object.keys(config.app.config.actors).length + `Deploy & upload "${actorName}" build (${idx + 1}/${Object.keys(config.app.config.actors).length })`, - async function* (ctx) { + async function*(ctx) { yield fs.mkdir(path.join(cwd, ".actorcore"), { recursive: true, }); @@ -317,18 +337,18 @@ export const deploy = new Command() `, ); - const actorTags = { - access: "public", + const buildTags = { + role: "actor", framework: "actor-core", "framework-version": VERSION, }; - const tagsArray = Object.entries(actorTags) + const tagsArray = Object.entries(buildTags) .map(([key, value]) => `${key}=${value}`) .join(","); const output = - await exec`${cli} publish --env=${envName} --tags=${tagsArray} ${actorName} ${entrypoint}`; + await exec`${cli} publish --env=${envName} --tags=${tagsArray} --unstable-minify false ${actorName} ${entrypoint}`; if (output.exitCode !== 0) { throw ctx.error("Failed to deploy & upload actors.", { diff --git a/packages/actor-core-cli/src/commands/endpoint.tsx b/packages/actor-core-cli/src/commands/endpoint.tsx new file mode 100644 index 000000000..970e77ca5 --- /dev/null +++ b/packages/actor-core-cli/src/commands/endpoint.tsx @@ -0,0 +1,177 @@ +import * as fs from "node:fs/promises"; +import path from "node:path"; +import { Argument, Command, Option } from "commander"; +import { $ } from "execa"; +import semver from "semver"; +import which from "which"; +import { MIN_RIVET_CLI_VERSION } from "../constants"; +import { workflow } from "../workflow"; +import { RivetClient } from "@rivet-gg/api"; +import { + createActorEndpoint, + createRivetApi, + getServiceToken, +} from "../utils/rivet-api"; +import { validateConfigTask } from "../workflows/validate-config"; +import invariant from "invariant"; + +export const endpoint = new Command() + .name("endpoint") + .description( + "Get the application endpoint URL for your deployed application in Rivet.", + ) + .addArgument( + new Argument("", "The platform to get the endpoint for").choices([ + "rivet", + ]), + ) + .addOption( + new Option( + "-e, --env [env]", + "Specify environment to get the endpoint for", + ), + ) + .addOption( + new Option("--plain", "Output only the URL without any additional text"), + ) + // No actor option needed - returns the first available endpoint + .action( + async ( + platform, + opts: { + env?: string; + plain?: boolean; + }, + ) => { + const cwd = process.cwd(); + + const exec = $({ + cwd, + env: { ...process.env, npm_config_yes: "true" }, + }); + + await workflow( + "Get actor endpoint", + async function*(ctx) { + const cli = yield* ctx.task("Locate rivet-cli", async (ctx) => { + let cliLocation = process.env.RIVET_CLI_PATH || null; + + if (!cliLocation) { + cliLocation = await which("rivet-cli", { nothrow: true }); + } + + if (!cliLocation) { + cliLocation = await which("rivet", { nothrow: true }); + } + + if (cliLocation) { + // check version + const { stdout } = await exec`${cliLocation} --version`; + const semVersion = semver.coerce( + stdout.split("\n")[2].split(" ")[1].trim(), + ); + + if (semVersion) { + if (semver.gte(semVersion, MIN_RIVET_CLI_VERSION)) { + return cliLocation; + } + } + } + + return ["npx", "@rivet-gg/cli@latest"]; + }); + + const { accessToken, projectName, envName, endpoint } = + yield* ctx.task("Auth with Rivet", async function*(ctx) { + const { stdout } = await exec`${cli} metadata auth-status`; + const isLogged = stdout === "true"; + + let endpoint: string | undefined; + if (!isLogged) { + const isUsingCloud = yield* ctx.prompt( + "Are you using Rivet Cloud?", + { + type: "confirm", + }, + ); + + endpoint = "https://api.rivet.gg"; + if (!isUsingCloud) { + endpoint = yield* ctx.prompt("What is the API endpoint?", { + type: "text", + defaultValue: "http://localhost:8080", + }); + } + + await exec`${cli} login --api-endpoint=${endpoint}`; + } else { + const { stdout } = await exec`${cli} metadata api-endpoint`; + endpoint = stdout; + } + + const { stdout: accessToken } = + await exec`${cli} metadata access-token`; + + const { stdout: rawEnvs } = await exec`${cli} env ls --json`; + const envs = JSON.parse(rawEnvs); + + const envName = + opts.env ?? + (yield* ctx.prompt("Select environment", { + type: "select", + choices: envs.map( + (env: { display_name: string; name_id: string }) => ({ + label: env.display_name, + value: env.name_id, + }), + ), + })); + + const { stdout: projectName } = + await exec`${cli} metadata project-name-id`; + + return { accessToken, projectName, envName, endpoint }; + }); + + const rivet = new RivetClient({ + token: accessToken, + environment: endpoint, + }); + + yield* ctx.task("Get actor endpoint", async function*(ctx) { + const { actors } = await rivet.actor.list({ + environment: envName, + project: projectName, + includeDestroyed: false, + tagsJson: JSON.stringify({ + name: "manager", + role: "manager", + framework: "actor-core", + }), + }); + + if (actors.length === 0) { + throw ctx.error("No managers found for this project.", { + hint: "Make sure you have deployed first.", + }); + } + + const managerActor = actors[0]; + const port = managerActor.network.ports.http; + invariant(port, "http port does not exist on manager"); + invariant(port.url, "port has no url"); + + if (opts.plain) { + console.log(port.url); + } else { + yield ctx.log(`Application endpoint: ${port.url}`); + } + }); + }, + { + showLabel: !opts.plain, + quiet: opts.plain, + }, + ).render(); + }, + ); diff --git a/packages/actor-core-cli/src/mod.ts b/packages/actor-core-cli/src/mod.ts index dc61d70cb..9e3b0f8ad 100644 --- a/packages/actor-core-cli/src/mod.ts +++ b/packages/actor-core-cli/src/mod.ts @@ -2,5 +2,6 @@ import "./instrument"; export { deploy } from "./commands/deploy"; export { create, action as createAction } from "./commands/create"; export { dev } from "./commands/dev"; +export { endpoint } from "./commands/endpoint"; export { program } from "commander"; export default {}; diff --git a/packages/actor-core-cli/src/workflow.tsx b/packages/actor-core-cli/src/workflow.tsx index f402bc08e..74be8f42d 100644 --- a/packages/actor-core-cli/src/workflow.tsx +++ b/packages/actor-core-cli/src/workflow.tsx @@ -216,6 +216,7 @@ export interface Context { interface TaskOptions { showLabel?: boolean; success?: ReactNode; + quiet?: boolean; } interface RunnerToolbox { diff --git a/packages/actor-core/src/actor/errors.ts b/packages/actor-core/src/actor/errors.ts index 1246703da..fa71a5969 100644 --- a/packages/actor-core/src/actor/errors.ts +++ b/packages/actor-core/src/actor/errors.ts @@ -15,6 +15,7 @@ interface ActorErrorOptions extends ErrorOptions { export class ActorError extends Error { public public: boolean; public metadata?: unknown; + public statusCode: number = 500; constructor( public readonly code: string, @@ -24,6 +25,22 @@ export class ActorError extends Error { super(message, { cause: opts?.cause }); this.public = opts?.public ?? false; this.metadata = opts?.metadata; + + // Set status code based on error type + if (opts?.public) { + this.statusCode = 400; // Bad request for public errors + } + } + + /** + * Serialize error for HTTP response + */ + serializeForHttp() { + return { + type: this.code, + message: this.message, + metadata: this.metadata, + }; } } @@ -194,3 +211,73 @@ export class UserError extends ActorError { }); } } + +// Proxy-related errors + +export class MissingRequiredParameters extends ActorError { + constructor(missingParams: string[]) { + super( + "missing_required_parameters", + `Missing required parameters: ${missingParams.join(", ")}`, + { public: true } + ); + } +} + +export class InvalidQueryJSON extends ActorError { + constructor(error?: unknown) { + super( + "invalid_query_json", + `Invalid query JSON: ${error}`, + { public: true, cause: error } + ); + } +} + +export class InvalidQueryFormat extends ActorError { + constructor(error?: unknown) { + super( + "invalid_query_format", + `Invalid query format: ${error}`, + { public: true, cause: error } + ); + } +} + +export class ActorNotFound extends ActorError { + constructor(identifier?: string) { + super( + "actor_not_found", + identifier ? `Actor not found: ${identifier}` : "Actor not found", + { public: true } + ); + } +} + +export class ProxyError extends ActorError { + constructor(operation: string, error?: unknown) { + super( + "proxy_error", + `Error proxying ${operation}: ${error}`, + { public: true, cause: error } + ); + } +} + +export class InvalidRpcRequest extends ActorError { + constructor(message: string) { + super("invalid_rpc_request", message, { public: true }); + } +} + +export class InvalidRequest extends ActorError { + constructor(message: string) { + super("invalid_request", message, { public: true }); + } +} + +export class InvalidParams extends ActorError { + constructor(message: string) { + super("invalid_params", message, { public: true }); + } +} diff --git a/packages/actor-core/src/actor/protocol/message/to-client.ts b/packages/actor-core/src/actor/protocol/message/to-client.ts index df74e8f5f..34ddd4884 100644 --- a/packages/actor-core/src/actor/protocol/message/to-client.ts +++ b/packages/actor-core/src/actor/protocol/message/to-client.ts @@ -8,6 +8,16 @@ export const InitSchema = z.object({ ct: z.string(), }); +// Used for connection errors (both during initialization and afterwards) +export const ConnectionErrorSchema = z.object({ + // Code + c: z.string(), + // Message + m: z.string(), + // Metadata + md: z.unknown().optional(), +}); + export const RpcResponseOkSchema = z.object({ // ID i: z.number().int(), @@ -46,6 +56,7 @@ export const ToClientSchema = z.object({ // Body b: z.union([ z.object({ i: InitSchema }), + z.object({ ce: ConnectionErrorSchema }), z.object({ ro: RpcResponseOkSchema }), z.object({ re: RpcResponseErrorSchema }), z.object({ ev: ToClientEventSchema }), @@ -54,6 +65,7 @@ export const ToClientSchema = z.object({ }); export type ToClient = z.infer; +export type ConnectionError = z.infer; export type RpcResponseOk = z.infer; export type RpcResponseError = z.infer; export type ToClientEvent = z.infer; diff --git a/packages/actor-core/src/actor/router.ts b/packages/actor-core/src/actor/router.ts index 62ff90a80..2ea7fc0d1 100644 --- a/packages/actor-core/src/actor/router.ts +++ b/packages/actor-core/src/actor/router.ts @@ -1,75 +1,50 @@ -import { Hono, type HonoRequest } from "hono"; -import type { UpgradeWebSocket, WSContext } from "hono/ws"; -import * as errors from "./errors"; +import { Hono, type Context as HonoContext } from "hono"; +import type { UpgradeWebSocket } from "hono/ws"; import { logger } from "./log"; -import { type Encoding, EncodingSchema } from "@/actor/protocol/serde"; -import { parseMessage } from "@/actor/protocol/message/mod"; -import * as protoHttpRpc from "@/actor/protocol/http/rpc"; -import type * as messageToServer from "@/actor/protocol/message/to-server"; -import type { InputData } from "@/actor/protocol/serde"; -import { type SSEStreamingApi, streamSSE } from "hono/streaming"; import { cors } from "hono/cors"; -import { assertUnreachable } from "./utils"; -import { handleRouteError, handleRouteNotFound } from "@/common/router"; -import { deconstructError, stringifyError } from "@/common/utils"; +import { + handleRouteError, + handleRouteNotFound, + loggerMiddleware, +} from "@/common/router"; import type { DriverConfig } from "@/driver-helpers/config"; import type { AppConfig } from "@/app/config"; import { type ActorInspectorConnHandler, createActorInspectorRouter, } from "@/inspector/actor"; - -export interface ConnectWebSocketOpts { - req: HonoRequest; - encoding: Encoding; - params: unknown; -} - -export interface ConnectWebSocketOutput { - onOpen: (ws: WSContext) => Promise; - onMessage: (message: messageToServer.ToServer) => Promise; - onClose: () => Promise; -} - -export interface ConnectSseOpts { - req: HonoRequest; - encoding: Encoding; - params: unknown; -} - -export interface ConnectSseOutput { - onOpen: (stream: SSEStreamingApi) => Promise; - onClose: () => Promise; -} - -export interface RpcOpts { - req: HonoRequest; - params: unknown; - rpcName: string; - rpcArgs: unknown[]; -} - -export interface RpcOutput { - output: unknown; -} - -export interface ConnsMessageOpts { - req: HonoRequest; - connId: string; - connToken: string; - message: messageToServer.ToServer; -} +import invariant from "invariant"; +import { + type ConnectWebSocketOpts, + type ConnectWebSocketOutput, + type ConnectSseOpts, + type ConnectSseOutput, + type RpcOpts, + type RpcOutput, + type ConnsMessageOpts, + type ConnectionHandlers, + handleWebSocketConnect, + handleSseConnect, + handleRpc, + handleConnectionMessage, +} from "./router_endpoints"; + +export type { + ConnectWebSocketOpts, + ConnectWebSocketOutput, + ConnectSseOpts, + ConnectSseOutput, + RpcOpts, + RpcOutput, + ConnsMessageOpts, +}; export interface ActorRouterHandler { - // Pass this value directly from Hono - upgradeWebSocket?: UpgradeWebSocket; + getActorId: () => Promise; + + // Connection handlers as a required subobject + connectionHandlers: ConnectionHandlers; - onConnectWebSocket?( - opts: ConnectWebSocketOpts, - ): Promise; - onConnectSse(opts: ConnectSseOpts): Promise; - onRpc(opts: RpcOpts): Promise; - onConnMessage(opts: ConnsMessageOpts): Promise; onConnectInspector?: ActorInspectorConnHandler; } @@ -85,7 +60,13 @@ export function createActorRouter( ): Hono { const app = new Hono(); + const upgradeWebSocket = driverConfig.getUpgradeWebSocket?.(app); + + app.use("*", loggerMiddleware(logger())); + // Apply CORS middleware if configured + // + //This is only relevant if the actor is exposed directly publicly if (appConfig.cors) { app.use("*", async (c, next) => { const path = c.req.path; @@ -109,105 +90,21 @@ export function createActorRouter( return c.text("ok"); }); - if (handler.upgradeWebSocket && handler.onConnectWebSocket) { + // Use the handlers from connectionHandlers + const handlers = handler.connectionHandlers; + + if (upgradeWebSocket && handlers.onConnectWebSocket) { app.get( "/connect/websocket", - handler.upgradeWebSocket(async (c) => { - try { - if (!handler.onConnectWebSocket) - throw new Error("onConnectWebSocket is not implemented"); - - const encoding = getRequestEncoding(c.req); - const parameters = getRequestConnParams( - c.req, - appConfig, - driverConfig, - ); - - const wsHandler = await handler.onConnectWebSocket({ - req: c.req, - encoding, - params: parameters, - }); - - const { promise: onOpenPromise, resolve: onOpenResolve } = - Promise.withResolvers(); - return { - onOpen: async (_evt, ws) => { - try { - logger().debug("websocket open"); - - // Call handler - await wsHandler.onOpen(ws); - - // Resolve promise - onOpenResolve(undefined); - } catch (error) { - const { code } = deconstructError(error, logger(), { - wsEvent: "open", - }); - ws.close(1011, code); - } - }, - onMessage: async (evt, ws) => { - try { - await onOpenPromise; - - logger().debug("received message"); - - const value = evt.data.valueOf() as InputData; - const message = await parseMessage(value, { - encoding: encoding, - maxIncomingMessageSize: appConfig.maxIncomingMessageSize, - }); - - await wsHandler.onMessage(message); - } catch (error) { - const { code } = deconstructError(error, logger(), { - wsEvent: "message", - }); - ws.close(1011, code); - } - }, - onClose: async (event) => { - try { - await onOpenPromise; - - if (event.wasClean) { - logger().info("websocket closed", { - code: event.code, - reason: event.reason, - wasClean: event.wasClean, - }); - } else { - logger().warn("websocket closed", { - code: event.code, - reason: event.reason, - wasClean: event.wasClean, - }); - } - - await wsHandler.onClose(); - } catch (error) { - deconstructError(error, logger(), { wsEvent: "close" }); - } - }, - onError: async (error) => { - try { - await onOpenPromise; - - // Actors don't need to know about this, since it's abstracted - // away - logger().warn("websocket error"); - } catch (error) { - deconstructError(error, logger(), { wsEvent: "error" }); - } - }, - }; - } catch (error) { - deconstructError(error, logger(), {}); - return {}; - } + upgradeWebSocket(async (c) => { + const actorId = await handler.getActorId(); + return handleWebSocketConnect( + c as HonoContext, + appConfig, + driverConfig, + handlers.onConnectWebSocket!, + actorId, + )(); }), ); } else { @@ -220,163 +117,60 @@ export function createActorRouter( } app.get("/connect/sse", async (c) => { - const encoding = getRequestEncoding(c.req); - const parameters = getRequestConnParams(c.req, appConfig, driverConfig); - - const sseHandler = await handler.onConnectSse({ - req: c.req, - encoding, - params: parameters, - }); - - return streamSSE( + if (!handlers.onConnectSse) { + throw new Error("onConnectSse handler is required"); + } + const actorId = await handler.getActorId(); + return handleSseConnect( c, - async (stream) => { - // Create connection with validated parameters - logger().debug("sse stream open"); - - await sseHandler.onOpen(stream); - - const { promise, resolve } = Promise.withResolvers(); - - stream.onAbort(() => { - sseHandler.onClose(); - - resolve(undefined); - }); - - await promise; - }, - async (error) => { - // Actors don't need to know about this, since it's abstracted - // away - logger().warn("sse error", { error: stringifyError(error) }); - }, + appConfig, + driverConfig, + handlers.onConnectSse, + actorId, ); }); app.post("/rpc/:rpc", async (c) => { - const rpcName = c.req.param("rpc"); - try { - // TODO: Support multiple encodings - const encoding: Encoding = "json"; - const parameters = getRequestConnParams(c.req, appConfig, driverConfig); - - // Parse request body if present - const contentLength = Number(c.req.header("content-length") || "0"); - if (contentLength > appConfig.maxIncomingMessageSize) { - throw new errors.MessageTooLong(); - } - - // Parse request body according to encoding - const body = await c.req.json(); - const { data: message, success } = - protoHttpRpc.RequestSchema.safeParse(body); - if (!success) { - throw new errors.MalformedMessage("Invalid request format"); - } - const rpcArgs = message.a; - - // Callback - const { output } = await handler.onRpc({ - req: c.req, - params: parameters, - rpcName, - rpcArgs, - }); - - // Format response according to encoding - return c.json({ - o: output, - } satisfies protoHttpRpc.ResponseOk); - } catch (error) { - // Build response error information similar to WebSocket handling - - const { statusCode, code, message, metadata } = deconstructError( - error, - logger(), - { rpc: rpcName }, - ); - - return c.json( - { - c: code, - m: message, - md: metadata, - } satisfies protoHttpRpc.ResponseErr, - { status: statusCode }, - ); + if (!handlers.onRpc) { + throw new Error("onRpc handler is required"); } + const rpcName = c.req.param("rpc"); + const actorId = await handler.getActorId(); + return handleRpc( + c, + appConfig, + driverConfig, + handlers.onRpc, + rpcName, + actorId, + ); }); app.post("/connections/:conn/message", async (c) => { - try { - const encoding = getRequestEncoding(c.req); - - const connId = c.req.param("conn"); - if (!connId) { - throw new errors.ConnNotFound(connId); - } - - const connToken = c.req.query("connectionToken"); - if (!connToken) throw new errors.IncorrectConnToken(); - - // Parse request body if present - const contentLength = Number(c.req.header("content-length") || "0"); - if (contentLength > appConfig.maxIncomingMessageSize) { - throw new errors.MessageTooLong(); - } - - // Read body - let value: InputData; - if (encoding === "json") { - // Handle decoding JSON in handleMessageEvent - value = await c.req.text(); - } else if (encoding === "cbor") { - value = await c.req.arrayBuffer(); - } else { - assertUnreachable(encoding); - } - - // Parse message - const message = await parseMessage(value, { - encoding, - maxIncomingMessageSize: appConfig.maxIncomingMessageSize, - }); - - await handler.onConnMessage({ - req: c.req, - connId, - connToken, - message, - }); - - // Not data to return - return c.json({}); - } catch (error) { - // Build response error information similar to WebSocket handling - const { statusCode, code, message, metadata } = deconstructError( - error, - logger(), - {}, - ); - - return c.json( - { - c: code, - m: message, - md: metadata, - } satisfies protoHttpRpc.ResponseErr, - { status: statusCode }, - ); + if (!handlers.onConnMessage) { + throw new Error("onConnMessage handler is required"); + } + const connId = c.req.param("conn"); + const connToken = c.req.query("connectionToken"); + const actorId = await handler.getActorId(); + if (!connId || !connToken) { + throw new Error("Missing required parameters"); } + return handleConnectionMessage( + c, + appConfig, + handlers.onConnMessage, + connId, + connToken, + actorId, + ); }); if (appConfig.inspector.enabled) { app.route( "/inspect", createActorInspectorRouter( - handler.upgradeWebSocket, + upgradeWebSocket, handler.onConnectInspector, appConfig.inspector, ), @@ -388,39 +182,3 @@ export function createActorRouter( return app; } - -function getRequestEncoding(req: HonoRequest): Encoding { - const encodingRaw = req.query("encoding"); - const { data: encoding, success } = EncodingSchema.safeParse(encodingRaw); - if (!success) { - logger().warn("invalid encoding", { - encoding: encodingRaw, - }); - throw new errors.InvalidEncoding(encodingRaw); - } - - return encoding; -} - -function getRequestConnParams( - req: HonoRequest, - appConfig: AppConfig, - driverConfig: DriverConfig, -): unknown { - // Validate params size - const paramsStr = req.query("params"); - if (paramsStr && paramsStr.length > appConfig.maxConnParamLength) { - logger().warn("connection parameters too long"); - throw new errors.ConnParamsTooLong(); - } - - // Parse and validate params - try { - return typeof paramsStr === "string" ? JSON.parse(paramsStr) : undefined; - } catch (error) { - logger().warn("malformed connection parameters", { - error: stringifyError(error), - }); - throw new errors.MalformedConnParams(error); - } -} diff --git a/packages/actor-core/src/actor/router_endpoints.ts b/packages/actor-core/src/actor/router_endpoints.ts new file mode 100644 index 000000000..b3e56b84d --- /dev/null +++ b/packages/actor-core/src/actor/router_endpoints.ts @@ -0,0 +1,424 @@ +import { type HonoRequest, type Context as HonoContext } from "hono"; +import { type SSEStreamingApi, streamSSE } from "hono/streaming"; +import { type WSContext } from "hono/ws"; +import * as errors from "./errors"; +import { logger } from "./log"; +import { + type Encoding, + EncodingSchema, + serialize, + deserialize, + CachedSerializer, +} from "@/actor/protocol/serde"; +import { parseMessage } from "@/actor/protocol/message/mod"; +import * as protoHttpRpc from "@/actor/protocol/http/rpc"; +import type * as messageToServer from "@/actor/protocol/message/to-server"; +import type { InputData, OutputData } from "@/actor/protocol/serde"; +import { assertUnreachable } from "./utils"; +import { deconstructError, stringifyError } from "@/common/utils"; +import type { AppConfig } from "@/app/config"; +import type { DriverConfig } from "@/driver-helpers/config"; +import { ToClient } from "./protocol/message/to-client"; +import invariant from "invariant"; + +export interface ConnectWebSocketOpts { + req: HonoRequest; + encoding: Encoding; + params: unknown; + actorId: string; +} + +export interface ConnectWebSocketOutput { + onOpen: (ws: WSContext) => Promise; + onMessage: (message: messageToServer.ToServer) => Promise; + onClose: () => Promise; +} + +export interface ConnectSseOpts { + req: HonoRequest; + encoding: Encoding; + params: unknown; + actorId: string; +} + +export interface ConnectSseOutput { + onOpen: (stream: SSEStreamingApi) => Promise; + onClose: () => Promise; +} + +export interface RpcOpts { + req: HonoRequest; + params: unknown; + rpcName: string; + rpcArgs: unknown[]; + actorId: string; +} + +export interface RpcOutput { + output: unknown; +} + +export interface ConnsMessageOpts { + req: HonoRequest; + connId: string; + connToken: string; + message: messageToServer.ToServer; + actorId: string; +} + +/** + * Shared interface for connection handlers used by both ActorRouterHandler and ManagerRouterHandler + */ +export interface ConnectionHandlers { + onConnectWebSocket?( + opts: ConnectWebSocketOpts, + ): Promise; + onConnectSse(opts: ConnectSseOpts): Promise; + onRpc(opts: RpcOpts): Promise; + onConnMessage(opts: ConnsMessageOpts): Promise; +} + +/** + * Creates a WebSocket connection handler + */ +export function handleWebSocketConnect( + context: HonoContext, + appConfig: AppConfig, + driverConfig: DriverConfig, + handler: (opts: ConnectWebSocketOpts) => Promise, + actorId: string, +) { + return async () => { + const encoding = getRequestEncoding(context.req); + + const parameters = getRequestConnParams( + context.req, + appConfig, + driverConfig, + ); + + // Continue with normal connection setup + const wsHandler = await handler({ + req: context.req, + encoding, + params: parameters, + actorId, + }); + + const { promise: onOpenPromise, resolve: onOpenResolve } = + Promise.withResolvers(); + + return { + onOpen: async (_evt: any, ws: WSContext) => { + try { + // TODO: maybe timeout this! + await wsHandler.onOpen(ws); + onOpenResolve(undefined); + } catch (error) { + deconstructError(error, logger(), { wsEvent: "open" }); + onOpenResolve(undefined); + ws.close(1011, "internal error"); + } + }, + onMessage: async (evt: { data: any }, ws: WSContext) => { + try { + invariant(encoding, "encoding should be defined"); + + await onOpenPromise; + + logger().debug("received message"); + + const value = evt.data.valueOf() as InputData; + const message = await parseMessage(value, { + encoding: encoding, + maxIncomingMessageSize: appConfig.maxIncomingMessageSize, + }); + + await wsHandler.onMessage(message); + } catch (error) { + const { code } = deconstructError(error, logger(), { + wsEvent: "message", + }); + ws.close(1011, code); + } + }, + onClose: async ( + event: { + wasClean: boolean; + code: number; + reason: string; + }, + ws: WSContext, + ) => { + try { + await onOpenPromise; + + // HACK: Close socket in order to fix bug with Cloudflare Durable Objects leaving WS in closing state + // https://github.com/cloudflare/workerd/issues/2569 + ws.close(1000, "hack_force_close"); + + if (event.wasClean) { + logger().info("websocket closed", { + code: event.code, + reason: event.reason, + wasClean: event.wasClean, + }); + } else { + logger().warn("websocket closed", { + code: event.code, + reason: event.reason, + wasClean: event.wasClean, + }); + } + + await wsHandler.onClose(); + } catch (error) { + deconstructError(error, logger(), { wsEvent: "close" }); + } + }, + onError: async (error: unknown) => { + try { + await onOpenPromise; + + // Actors don't need to know about this, since it's abstracted away + logger().warn("websocket error"); + } catch (error) { + deconstructError(error, logger(), { wsEvent: "error" }); + } + }, + }; + }; +} + +/** + * Creates an SSE connection handler + */ +export async function handleSseConnect( + c: HonoContext, + appConfig: AppConfig, + driverConfig: DriverConfig, + handler: (opts: ConnectSseOpts) => Promise, + actorId: string, +) { + const encoding = getRequestEncoding(c.req); + const parameters = getRequestConnParams(c.req, appConfig, driverConfig); + + const sseHandler = await handler({ + req: c.req, + encoding, + params: parameters, + actorId, + }); + + return streamSSE(c, async (stream) => { + try { + await sseHandler.onOpen(stream); + c.req.raw.signal.addEventListener("abort", async () => { + try { + await sseHandler.onClose(); + } catch (error) { + logger().error("error closing sse connection", { error }); + } + }); + } catch (error) { + logger().error("error opening sse connection", { error }); + throw error; + } + }); +} + +/** + * Creates an RPC handler + */ +export async function handleRpc( + c: HonoContext, + appConfig: AppConfig, + driverConfig: DriverConfig, + handler: (opts: RpcOpts) => Promise, + rpcName: string, + actorId: string, +) { + try { + const encoding = getRequestEncoding(c.req); + const parameters = getRequestConnParams(c.req, appConfig, driverConfig); + + // Validate incoming request + let rpcArgs: unknown[]; + if (encoding === "json") { + try { + rpcArgs = await c.req.json(); + } catch (err) { + throw new errors.InvalidRpcRequest("Invalid JSON"); + } + + if (!Array.isArray(rpcArgs)) { + throw new errors.InvalidRpcRequest("RPC arguments must be an array"); + } + } else if (encoding === "cbor") { + try { + const value = await c.req.arrayBuffer(); + const uint8Array = new Uint8Array(value); + const deserialized = await deserialize( + uint8Array as unknown as InputData, + encoding, + ); + + // Validate using the RPC schema + const result = protoHttpRpc.RequestSchema.safeParse(deserialized); + if (!result.success) { + throw new errors.InvalidRpcRequest("Invalid RPC request format"); + } + + rpcArgs = result.data.a; + } catch (err) { + throw new errors.InvalidRpcRequest( + `Invalid binary format: ${stringifyError(err)}`, + ); + } + } else { + return assertUnreachable(encoding); + } + + // Invoke the RPC + const result = await handler({ + req: c.req, + params: parameters, + rpcName, + rpcArgs, + actorId, + }); + + // Encode the response + if (encoding === "json") { + return c.json(result.output as Record); + } else if (encoding === "cbor") { + // Use serialize from serde.ts instead of custom encoder + const responseData = { + o: result.output, // Use the format expected by ResponseOkSchema + }; + const serialized = serialize(responseData, encoding); + + return c.body(serialized as Uint8Array, 200, { + "Content-Type": "application/octet-stream", + }); + } else { + return assertUnreachable(encoding); + } + } catch (err) { + if (err instanceof errors.ActorError) { + return c.json({ error: err.serializeForHttp() }, 400); + } else { + logger().error("error executing rpc", { err }); + return c.json( + { + error: { + type: "internal_error", + message: "An internal error occurred", + }, + }, + 500, + ); + } + } +} + +/** + * Create a connection message handler + */ +export async function handleConnectionMessage( + c: HonoContext, + appConfig: AppConfig, + handler: (opts: ConnsMessageOpts) => Promise, + connId: string, + connToken: string, + actorId: string, +) { + try { + const encoding = getRequestEncoding(c.req); + + // Validate incoming request + let message: messageToServer.ToServer; + if (encoding === "json") { + try { + message = await c.req.json(); + } catch (err) { + throw new errors.InvalidRequest("Invalid JSON"); + } + } else if (encoding === "cbor") { + try { + const value = await c.req.arrayBuffer(); + const uint8Array = new Uint8Array(value); + message = await parseMessage(uint8Array as unknown as InputData, { + encoding, + maxIncomingMessageSize: appConfig.maxIncomingMessageSize, + }); + } catch (err) { + throw new errors.InvalidRequest( + `Invalid binary format: ${stringifyError(err)}`, + ); + } + } else { + return assertUnreachable(encoding); + } + + await handler({ + req: c.req, + connId, + connToken, + message, + actorId, + }); + + return c.json({}); + } catch (err) { + if (err instanceof errors.ActorError) { + return c.json({ error: err.serializeForHttp() }, 400); + } else { + logger().error("error processing connection message", { err }); + return c.json( + { + error: { + type: "internal_error", + message: "An internal error occurred", + }, + }, + 500, + ); + } + } +} + +// Helper to get the connection encoding from a request +export function getRequestEncoding(req: HonoRequest): Encoding { + const encodingParam = req.query("encoding"); + if (!encodingParam) { + return "json"; + } + + const result = EncodingSchema.safeParse(encodingParam); + if (!result.success) { + throw new errors.InvalidEncoding(encodingParam as string); + } + + return result.data; +} + +// Helper to get connection parameters for the request +export function getRequestConnParams( + req: HonoRequest, + appConfig: AppConfig, + driverConfig: DriverConfig, +): unknown { + const paramsParam = req.query("params"); + if (!paramsParam) { + return null; + } + + try { + return JSON.parse(paramsParam); + } catch (err) { + throw new errors.InvalidParams( + `Invalid params JSON: ${stringifyError(err)}`, + ); + } +} diff --git a/packages/actor-core/src/client/actor_conn.ts b/packages/actor-core/src/client/actor_conn.ts index 4500aaa6e..9c94a9e98 100644 --- a/packages/actor-core/src/client/actor_conn.ts +++ b/packages/actor-core/src/client/actor_conn.ts @@ -8,9 +8,13 @@ import * as cbor from "cbor-x"; import * as errors from "./errors"; import { logger } from "./log"; import { type WebSocketMessage as ConnMessage, messageLength } from "./utils"; -import { ACTOR_CONNS_SYMBOL, ClientRaw, DynamicImports } from "./client"; -import { ActorDefinition, AnyActorDefinition } from "@/actor/definition"; -import pRetry, { AbortError } from "p-retry"; +import { ACTOR_CONNS_SYMBOL, type ClientRaw } from "./client"; +import type { ActorDefinition, AnyActorDefinition } from "@/actor/definition"; +import pRetry from "p-retry"; +import { importWebSocket } from "@/common/websocket"; +import { importEventSource } from "@/common/eventsource"; +import invariant from "invariant"; +import type { ActorQuery } from "@/manager/protocol/query"; interface RpcInFlight { name: string; @@ -30,6 +34,13 @@ interface EventSubscriptions> { */ export type EventUnsubscribe = () => void; +/** + * A function that handles connection errors. + * + * @typedef {Function} ConnectionErrorCallback + */ +export type ConnectionErrorCallback = (error: errors.ConnectionError) => void; + interface SendOpts { ephemeral: boolean; } @@ -38,6 +49,11 @@ export type ConnTransport = { websocket: WebSocket } | { sse: EventSource }; export const CONNECT_SYMBOL = Symbol("connect"); +interface DynamicImports { + WebSocket: typeof WebSocket; + EventSource: typeof EventSource; +} + /** * Provides underlying functions for {@link ActorConn}. See {@link ActorConn} for using type-safe remote procedure calls. * @@ -50,7 +66,7 @@ export class ActorConnRaw { #abortController = new AbortController(); /** If attempting to connect. Helpful for knowing if in a retry loop when reconnecting. */ - #connecting: boolean = false; + #connecting = false; // These will only be set on SSE driver #connectionId?: string; @@ -64,6 +80,8 @@ export class ActorConnRaw { // biome-ignore lint/suspicious/noExplicitAny: Unknown subscription type #eventSubscriptions = new Map>>(); + #errorHandlers = new Set(); + #rpcIdCounter = 0; /** @@ -73,11 +91,17 @@ export class ActorConnRaw { */ #keepNodeAliveInterval: NodeJS.Timeout; + /** Promise used to indicate the required properties for using this class have loaded. Currently just #dynamicImports */ + #onConstructedPromise: Promise; + /** Promise used to indicate the socket has connected successfully. This will be rejected if the connection fails. */ #onOpenPromise?: PromiseWithResolvers; // TODO: ws message queue + // External imports + #dynamicImports!: DynamicImports; + /** * Do not call this directly. * @@ -94,9 +118,21 @@ export class ActorConnRaw { private readonly encodingKind: Encoding, private readonly supportedTransports: Transport[], private readonly serverTransports: Transport[], - private readonly dynamicImports: DynamicImports, + private readonly actorQuery: ActorQuery, ) { this.#keepNodeAliveInterval = setInterval(() => 60_000); + + this.#onConstructedPromise = (async () => { + // Import dynamic dependencies + const [WebSocket, EventSource] = await Promise.all([ + importWebSocket(), + importEventSource(), + ]); + this.#dynamicImports = { + WebSocket, + EventSource, + }; + })(); } /** @@ -113,36 +149,131 @@ export class ActorConnRaw { name: string, ...args: Args ): Promise { - logger().debug("action", { name, args }); + await this.#onConstructedPromise; - // TODO: Add to queue if socket is not open + logger().debug("action", { name, args }); - const rpcId = this.#rpcIdCounter; - this.#rpcIdCounter += 1; + // Check if we have an active websocket connection + if (this.#transport) { + // If we have an active connection, use the websocket RPC + const rpcId = this.#rpcIdCounter; + this.#rpcIdCounter += 1; - const { promise, resolve, reject } = - Promise.withResolvers(); - this.#rpcInFlight.set(rpcId, { name, resolve, reject }); + const { promise, resolve, reject } = + Promise.withResolvers(); + this.#rpcInFlight.set(rpcId, { name, resolve, reject }); - this.#sendMessage({ - b: { - rr: { - i: rpcId, - n: name, - a: args, + this.#sendMessage({ + b: { + rr: { + i: rpcId, + n: name, + a: args, + }, }, - }, - } satisfies wsToServer.ToServer); + } satisfies wsToServer.ToServer); - // TODO: Throw error if disconnect is called + // TODO: Throw error if disconnect is called - const { i: responseId, o: output } = await promise; - if (responseId !== rpcId) - throw new Error( - `Request ID ${rpcId} does not match response ID ${responseId}`, - ); + const { i: responseId, o: output } = await promise; + if (responseId !== rpcId) + throw new Error( + `Request ID ${rpcId} does not match response ID ${responseId}`, + ); - return output as Response; + return output as Response; + } else { + // If no websocket connection, use HTTP RPC via manager + try { + // Get the manager endpoint from the endpoint provided + const actorQueryStr = encodeURIComponent( + JSON.stringify(this.actorQuery), + ); + + const url = `${this.endpoint}/actors/rpc/${name}?query=${actorQueryStr}`; + logger().debug("http rpc: request", { + url, + name, + }); + + try { + const response = await fetch(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + a: args, + }), + }); + + logger().debug("http rpc: response", { + status: response.status, + ok: response.ok, + }); + + if (!response.ok) { + try { + const errorData = await response.json(); + logger().error("http rpc error response", { errorData }); + throw new errors.ActionError( + errorData.c || "RPC_ERROR", + errorData.m || "RPC call failed", + errorData.md, + ); + } catch (parseError) { + // If response is not JSON, get it as text and throw generic error + const errorText = await response.text(); + logger().error("http rpc: error parsing response", { + errorText, + }); + throw new errors.ActionError( + "RPC_ERROR", + `RPC call failed: ${errorText}`, + {}, + ); + } + } + + // Clone response to avoid consuming it + const responseClone = response.clone(); + const responseText = await responseClone.text(); + + // Parse response body + try { + const responseData = JSON.parse(responseText); + return responseData.o as Response; + } catch (parseError) { + logger().error("http rpc: error parsing json", { + parseError, + }); + throw new errors.ActionError( + "RPC_ERROR", + `Failed to parse response: ${parseError}`, + { responseText }, + ); + } + } catch (fetchError) { + logger().error("http rpc: fetch error", { + error: fetchError, + }); + throw new errors.ActionError( + "RPC_ERROR", + `Fetch failed: ${fetchError}`, + { cause: fetchError }, + ); + } + } catch (error) { + if (error instanceof errors.ActionError) { + throw error; + } + throw new errors.ActionError( + "RPC_ERROR", + `Failed to execute RPC ${name}: ${error}`, + { cause: error }, + ); + } + } } //async #rpcHttp = unknown[], Response = unknown>(name: string, ...args: Args): Promise { @@ -210,6 +341,8 @@ enc async #connectAndWait() { try { + await this.#onConstructedPromise; + // Create promise for open if (this.#onOpenPromise) throw new Error("#onOpenPromise already defined"); @@ -246,7 +379,7 @@ enc } #connectWebSocket() { - const { WebSocket } = this.dynamicImports; + const { WebSocket } = this.#dynamicImports; const url = this.#buildConnUrl("websocket"); @@ -279,7 +412,7 @@ enc } #connectSse() { - const { EventSource } = this.dynamicImports; + const { EventSource } = this.#dynamicImports; const url = this.#buildConnUrl("sse"); @@ -334,37 +467,67 @@ enc /** Called by the onmessage event from drivers. */ async #handleOnMessage(event: MessageEvent) { - logger().trace("received message", { - dataType: typeof event.data, + logger().trace("received message", { + dataType: typeof event.data, isBlob: event.data instanceof Blob, - isArrayBuffer: event.data instanceof ArrayBuffer + isArrayBuffer: event.data instanceof ArrayBuffer, }); const response = (await this.#parse(event.data)) as wsToClient.ToClient; - logger().trace("parsed message", { - response: JSON.stringify(response).substring(0, 100) + "..." + logger().trace("parsed message", { + response: JSON.stringify(response).substring(0, 100) + "...", }); if ("i" in response.b) { // This is only called for SSE this.#connectionId = response.b.i.ci; this.#connectionToken = response.b.i.ct; - logger().trace("received init message", { - connectionId: this.#connectionId + logger().trace("received init message", { + connectionId: this.#connectionId, }); this.#handleOnOpen(); + } else if ("ce" in response.b) { + // Connection error + const { c: code, m: message, md: metadata } = response.b.ce; + + logger().warn("actor connection error", { + code, + message, + metadata, + }); + + // Create a connection error + const connectionError = new errors.ConnectionError( + code, + message, + metadata, + ); + + // If we have an onOpenPromise, reject it with the error + if (this.#onOpenPromise) { + this.#onOpenPromise.reject(connectionError); + } + + // Reject any in-flight requests + for (const [id, inFlight] of this.#rpcInFlight.entries()) { + inFlight.reject(connectionError); + this.#rpcInFlight.delete(id); + } + + // Dispatch to error handler if registered + this.#dispatchConnectionError(connectionError); } else if ("ro" in response.b) { // RPC response OK const { i: rpcId } = response.b.ro; - logger().trace("received RPC response", { - rpcId, - outputType: typeof response.b.ro.o + logger().trace("received RPC response", { + rpcId, + outputType: typeof response.b.ro.o, }); const inFlight = this.#takeRpcInFlight(rpcId); - logger().trace("resolving RPC promise", { - rpcId, - actionName: inFlight?.name + logger().trace("resolving RPC promise", { + rpcId, + actionName: inFlight?.name, }); inFlight.resolve(response.b.ro); } else if ("re" in response.b) { @@ -384,9 +547,9 @@ enc inFlight.reject(new errors.ActionError(code, message, metadata)); } else if ("ev" in response.b) { - logger().trace("received event", { - name: response.b.ev.n, - argsCount: response.b.ev.a?.length + logger().trace("received event", { + name: response.b.ev.n, + argsCount: response.b.ev.a?.length, }); this.#dispatchEvent(response.b.ev); } else if ("er" in response.b) { @@ -453,7 +616,14 @@ enc } #buildConnUrl(transport: Transport): string { - let url = `${this.endpoint}/connect/${transport}?encoding=${this.encodingKind}`; + // Get the manager endpoint from the endpoint provided + const actorQueryStr = encodeURIComponent(JSON.stringify(this.actorQuery)); + + logger().debug("building conn url", { + transport, + }); + + let url = `${this.endpoint}/actors/connect/${transport}?encoding=${this.encodingKind}&query=${actorQueryStr}`; if (this.params !== undefined) { const paramsStr = JSON.stringify(this.params); @@ -504,6 +674,19 @@ enc } } + #dispatchConnectionError(error: errors.ConnectionError) { + // Call all registered error handlers + for (const handler of [...this.#errorHandlers]) { + try { + handler(error); + } catch (err) { + logger().error("Error in connection error handler", { + error: stringifyError(err), + }); + } + } + } + #addEventSubscription>( eventName: string, callback: (...args: Args) => void, @@ -567,13 +750,28 @@ enc return this.#addEventSubscription(eventName, callback, true); } + /** + * Subscribes to connection errors. + * + * @param {ConnectionErrorCallback} callback - The callback function to execute when a connection error occurs. + * @returns {() => void} - A function to unsubscribe from the error handler. + */ + onError(callback: ConnectionErrorCallback): () => void { + this.#errorHandlers.add(callback); + + // Return unsubscribe function + return () => { + this.#errorHandlers.delete(callback); + }; + } + #sendMessage(message: wsToServer.ToServer, opts?: SendOpts) { - let queueMessage: boolean = false; + let queueMessage = false; if (!this.#transport) { // No transport connected yet queueMessage = true; } else if ("websocket" in this.#transport) { - const { WebSocket } = this.dynamicImports; + const { WebSocket } = this.#dynamicImports; if (this.#transport.websocket.readyState === WebSocket.OPEN) { try { const messageSerialized = this.#serialize(message); @@ -594,7 +792,7 @@ enc queueMessage = true; } } else if ("sse" in this.#transport) { - const { EventSource } = this.dynamicImports; + const { EventSource } = this.#dynamicImports; if (this.#transport.sse.readyState === EventSource.OPEN) { // Spawn in background since #sendMessage cannot be async @@ -617,7 +815,10 @@ enc if (!this.#connectionId || !this.#connectionToken) throw new errors.InternalError("Missing connection ID or token."); - let url = `${this.endpoint}/connections/${this.#connectionId}/message?encoding=${this.encodingKind}&connectionToken=${encodeURIComponent(this.#connectionToken)}`; + // Get the manager endpoint from the endpoint provided + const actorQueryStr = encodeURIComponent(JSON.stringify(this.actorQuery)); + + let url = `${this.endpoint}/actors/connections/${this.#connectionId}/message?encoding=${this.encodingKind}&connectionToken=${encodeURIComponent(this.#connectionToken)}&query=${actorQueryStr}`; // TODO: Implement ordered messages, this is not guaranteed order. Needs to use an index in order to ensure we can pipeline requests efficiently. // TODO: Validate that we're using HTTP/3 whenever possible for pipelining requests @@ -712,6 +913,8 @@ enc * @returns {Promise} A promise that resolves when the socket is gracefully closed. */ async dispose(): Promise { + await this.#onConstructedPromise; + // Internally, this "disposes" the connection if (this.#disposed) { @@ -780,7 +983,7 @@ type ActorDefinitionRpcs = { * * @example * ``` - * const room = await client.connect(...etc...); + * const room = client.connect(...etc...); * // This calls the rpc named `sendMessage` on the `ChatRoom` actor. * await room.sendMessage('Hello, world!'); * ``` diff --git a/packages/actor-core/src/client/client.ts b/packages/actor-core/src/client/client.ts index 32f01fbed..bf41eb0b4 100644 --- a/packages/actor-core/src/client/client.ts +++ b/packages/actor-core/src/client/client.ts @@ -1,12 +1,6 @@ import type { Transport } from "@/actor/protocol/message/mod"; import type { Encoding } from "@/actor/protocol/serde"; -import type { ActorKey } from "@/common//utils"; -import type { - ActorsRequest, - ActorsResponse, - //RivetConfigResponse, -} from "@/manager/protocol/mod"; -import type { CreateRequest } from "@/manager/protocol/query"; +import type { ActorQuery } from "@/manager/protocol/query"; import * as errors from "./errors"; import { ActorConn, @@ -15,9 +9,7 @@ import { CONNECT_SYMBOL, } from "./actor_conn"; import { logger } from "./log"; -import { importWebSocket } from "@/common/websocket"; -import { importEventSource } from "@/common/eventsource"; -import { ActorCoreApp } from "@/mod"; +import type { ActorCoreApp } from "@/mod"; import type { AnyActorDefinition } from "@/actor/definition"; /** Extract the actor registry from the app definition. */ @@ -39,9 +31,9 @@ export interface ActorAccessor { * @template A The actor class that this connection is for. * @param {string | string[]} [key=[]] - The key to identify the actor. Can be a single string or an array of strings. * @param {GetOptions} [opts] - Options for getting the actor. - * @returns {Promise>} - A promise resolving to the actor connection. + * @returns {ActorConn} - A promise resolving to the actor connection. */ - connect(key?: string | string[], opts?: GetOptions): Promise>; + connect(key?: string | string[], opts?: GetOptions): ActorConn; /** * Creates a new actor with the name automatically injected from the property accessor, @@ -50,9 +42,9 @@ export interface ActorAccessor { * @template A The actor class that this connection is for. * @param {string | string[]} key - The key to identify the actor. Can be a single string or an array of strings. * @param {CreateOptions} [opts] - Options for creating the actor (excluding name and key). - * @returns {Promise>} - A promise resolving to the actor connection. + * @returns {ActorConn} - A promise resolving to the actor connection. */ - createAndConnect(key: string | string[], opts?: CreateOptions): Promise>; + createAndConnect(key: string | string[], opts?: CreateOptions): ActorConn; /** * Connects to an actor by its ID. @@ -60,9 +52,9 @@ export interface ActorAccessor { * @template A The actor class that this connection is for. * @param {string} actorId - The ID of the actor. * @param {GetWithIdOptions} [opts] - Options for getting the actor. - * @returns {Promise>} - A promise resolving to the actor connection. + * @returns {ActorConn} - A promise resolving to the actor connection. */ - connectForId(actorId: string, opts?: GetWithIdOptions): Promise>; + connectForId(actorId: string, opts?: GetWithIdOptions): ActorConn; } /** @@ -94,21 +86,24 @@ export interface GetWithIdOptions extends QueryOptions {} * Options for getting an actor. * @typedef {QueryOptions} GetOptions * @property {boolean} [noCreate] - Prevents creating a new actor if one does not exist. - * @property {Partial} [create] - Config used to create the actor. + * @property {string} [createInRegion] - Region to create the actor in if it doesn't exist. */ export interface GetOptions extends QueryOptions { /** Prevents creating a new actor if one does not exist. */ noCreate?: boolean; - /** Config used to create the actor. */ - create?: Partial>; + /** Region to create the actor in if it doesn't exist. */ + createInRegion?: string; } /** * Options for creating an actor. * @typedef {QueryOptions} CreateOptions - * @property {Object} - Additional options for actor creation excluding name and key that come from the key parameter. + * @property {string} [region] - The region to create the actor in. */ -export interface CreateOptions extends QueryOptions, Omit {} +export interface CreateOptions extends QueryOptions { + /** The region to create the actor in. */ + region?: string; +} /** * Represents a region to connect to. @@ -130,11 +125,6 @@ export interface Region { name: string; } -export interface DynamicImports { - WebSocket: typeof WebSocket; - EventSource: typeof EventSource; -} - export const ACTOR_CONNS_SYMBOL = Symbol("actorConns"); /** @@ -148,49 +138,25 @@ export class ClientRaw { [ACTOR_CONNS_SYMBOL] = new Set(); - #managerEndpointPromise: Promise; - //#regionPromise: Promise; + #managerEndpoint: string; #encodingKind: Encoding; #supportedTransports: Transport[]; - // External imports - #dynamicImportsPromise: Promise; - /** * Creates an instance of Client. * - * @param {string | Promise} managerEndpointPromise - The manager endpoint or a promise resolving to it. See {@link https://rivet.gg/docs/setup|Initial Setup} for instructions on getting the manager endpoint. + * @param {string} managerEndpoint - The manager endpoint. See {@link https://rivet.gg/docs/setup|Initial Setup} for instructions on getting the manager endpoint. * @param {ClientOptions} [opts] - Options for configuring the client. * @see {@link https://rivet.gg/docs/setup|Initial Setup} */ - public constructor( - managerEndpointPromise: string | Promise, - opts?: ClientOptions, - ) { - if (managerEndpointPromise instanceof Promise) { - // Save promise - this.#managerEndpointPromise = managerEndpointPromise; - } else { - // Convert to promise - this.#managerEndpointPromise = new Promise((resolve) => - resolve(managerEndpointPromise), - ); - } - - //this.#regionPromise = this.#fetchRegion(); + public constructor(managerEndpoint: string, opts?: ClientOptions) { + this.#managerEndpoint = managerEndpoint; this.#encodingKind = opts?.encoding ?? "cbor"; this.#supportedTransports = opts?.supportedTransports ?? [ "websocket", "sse", ]; - - // Import dynamic dependencies - this.#dynamicImportsPromise = (async () => { - const WebSocket = await importWebSocket(); - const EventSource = await importEventSource(); - return { WebSocket, EventSource }; - })(); } /** @@ -201,32 +167,29 @@ export class ClientRaw { * @param {GetWithIdOptions} [opts] - Options for getting the actor. * @returns {Promise>} - A promise resolving to the actor connection. */ - async connectForId( + connectForId( name: string, actorId: string, opts?: GetWithIdOptions, - ): Promise> { + ): ActorConn { logger().debug("connect to actor with id ", { name, actorId, params: opts?.params, }); - const resJson = await this.#sendManagerRequest< - ActorsRequest, - ActorsResponse - >("POST", "/manager/actors", { - query: { - getForId: { - actorId, - }, + const actorQuery = { + getForId: { + actorId, }, - }); + }; - const conn = await this.#createConn( - resJson.endpoint, + const managerEndpoint = this.#managerEndpoint; + const conn = this.#createConn( + managerEndpoint, opts?.params, - resJson.supportedTransports, + ["websocket", "sse"], + actorQuery, ); return this.#createProxy(conn) as ActorConn; } @@ -236,14 +199,14 @@ export class ClientRaw { * * @example * ``` - * const room = await client.connect( + * const room = client.connect( * 'chat-room', * // Get or create the actor for the channel `random` * 'random', * ); * * // Or using an array of strings as key - * const room = await client.connect( + * const room = client.connect( * 'chat-room', * ['user123', 'room456'], * ); @@ -258,36 +221,26 @@ export class ClientRaw { * @returns {Promise>} - A promise resolving to the actor connection. * @see {@link https://rivet.gg/docs/manage#client.connect} */ - async connect( + connect( name: string, key?: string | string[], opts?: GetOptions, - ): Promise> { + ): ActorConn { // Convert string to array of strings - const keyArray: string[] = typeof key === 'string' ? [key] : (key || []); - - // Build create config - let create: CreateRequest | undefined = undefined; - if (!opts?.noCreate) { - create = { - name, - // Fall back to key defined when querying actor - key: opts?.create?.key ?? keyArray, - ...opts?.create, - }; - } + const keyArray: string[] = typeof key === "string" ? [key] : key || []; logger().debug("connect to actor", { name, key: keyArray, parameters: opts?.params, - create, + noCreate: opts?.noCreate, + createInRegion: opts?.createInRegion, }); - let requestQuery; + let actorQuery: ActorQuery; if (opts?.noCreate) { // Use getForKey endpoint if noCreate is specified - requestQuery = { + actorQuery = { getForKey: { name, key: keyArray, @@ -295,26 +248,21 @@ export class ClientRaw { }; } else { // Use getOrCreateForKey endpoint - requestQuery = { + actorQuery = { getOrCreateForKey: { name, key: keyArray, - region: create?.region, + region: opts?.createInRegion, }, }; } - const resJson = await this.#sendManagerRequest< - ActorsRequest, - ActorsResponse - >("POST", "/manager/actors", { - query: requestQuery, - }); - - const conn = await this.#createConn( - resJson.endpoint, + const managerEndpoint = this.#managerEndpoint; + const conn = this.#createConn( + managerEndpoint, opts?.params, - resJson.supportedTransports, + ["websocket", "sse"], + actorQuery, ); return this.#createProxy(conn) as ActorConn; } @@ -330,7 +278,7 @@ export class ClientRaw { * 'doc123', * { region: 'us-east-1' } * ); - * + * * // Or with an array of strings as key * const doc = await client.createAndConnect( * 'document', @@ -348,13 +296,13 @@ export class ClientRaw { * @returns {Promise>} - A promise resolving to the actor connection. * @see {@link https://rivet.gg/docs/manage#client.createAndConnect} */ - async createAndConnect( + createAndConnect( name: string, key: string | string[], opts: CreateOptions = {}, - ): Promise> { + ): ActorConn { // Convert string to array of strings - const keyArray: string[] = typeof key === 'string' ? [key] : key; + const keyArray: string[] = typeof key === "string" ? [key] : key; // Build create config const create = { @@ -364,9 +312,6 @@ export class ClientRaw { key: keyArray, }; - // Default to the chosen region - //if (!create.region) create.region = (await this.#regionPromise)?.id; - logger().debug("create actor and connect", { name, key: keyArray, @@ -374,30 +319,26 @@ export class ClientRaw { create, }); - const resJson = await this.#sendManagerRequest< - ActorsRequest, - ActorsResponse - >("POST", "/manager/actors", { - query: { - create, - }, - }); + const actorQuery = { + create, + }; - const conn = await this.#createConn( - resJson.endpoint, + const managerEndpoint = this.#managerEndpoint; + const conn = this.#createConn( + managerEndpoint, opts?.params, - resJson.supportedTransports, + ["websocket", "sse"], + actorQuery, ); return this.#createProxy(conn) as ActorConn; } - async #createConn( + #createConn( endpoint: string, params: unknown, serverTransports: Transport[], - ): Promise { - const imports = await this.#dynamicImportsPromise; - + actorQuery: ActorQuery, + ): ActorConnRaw { const conn = new ActorConnRaw( this, endpoint, @@ -405,7 +346,7 @@ export class ClientRaw { this.#encodingKind, this.#supportedTransports, serverTransports, - imports, + actorQuery, ); this[ACTOR_CONNS_SYMBOL].add(conn); conn[CONNECT_SYMBOL](); @@ -508,7 +449,7 @@ export class ClientRaw { body?: Request, ): Promise { try { - const managerEndpoint = await this.#managerEndpointPromise; + const managerEndpoint = this.#managerEndpoint; const res = await fetch(`${managerEndpoint}${path}`, { method, headers: { @@ -565,15 +506,15 @@ export type Client> = ClientRaw & { * Creates a client with the actor accessor proxy. * * @template A The actor application type. - * @param {string | Promise} managerEndpointPromise - The manager endpoint or a promise resolving to it. + * @param {string} managerEndpoint - The manager endpoint. * @param {ClientOptions} [opts] - Options for configuring the client. * @returns {Client} - A proxied client that supports the `client.myActor.connect()` syntax. */ export function createClient>( - managerEndpointPromise: string | Promise, + managerEndpoint: string, opts?: ClientOptions, ): Client { - const client = new ClientRaw(managerEndpointPromise, opts); + const client = new ClientRaw(managerEndpoint, opts); // Create proxy for accessing actors by name return new Proxy(client, { @@ -595,27 +536,25 @@ export function createClient>( connect: ( key?: string | string[], opts?: GetOptions, - ): Promise[typeof prop]>> => { + ): ActorConn[typeof prop]> => { return target.connect[typeof prop]>( prop, key, - opts + opts, ); }, createAndConnect: ( key: string | string[], opts: CreateOptions = {}, - ): Promise[typeof prop]>> => { - return target.createAndConnect[typeof prop]>( - prop, - key, - opts - ); + ): ActorConn[typeof prop]> => { + return target.createAndConnect< + ExtractActorsFromApp[typeof prop] + >(prop, key, opts); }, connectForId: ( actorId: string, opts?: GetWithIdOptions, - ): Promise[typeof prop]>> => { + ): ActorConn[typeof prop]> => { return target.connectForId[typeof prop]>( prop, actorId, diff --git a/packages/actor-core/src/client/errors.ts b/packages/actor-core/src/client/errors.ts index d10987745..a3e19b5dd 100644 --- a/packages/actor-core/src/client/errors.ts +++ b/packages/actor-core/src/client/errors.ts @@ -39,3 +39,16 @@ export class ActionError extends ActorClientError { super(message); } } + +/** + * Error thrown when a connection error occurs. + */ +export class ConnectionError extends ActorClientError { + constructor( + public readonly code: string, + message: string, + public readonly metadata?: unknown, + ) { + super(message); + } +} diff --git a/packages/actor-core/src/client/mod.ts b/packages/actor-core/src/client/mod.ts index 9c8635adf..119be2c76 100644 --- a/packages/actor-core/src/client/mod.ts +++ b/packages/actor-core/src/client/mod.ts @@ -26,6 +26,7 @@ export { MalformedResponseMessage, NoSupportedTransport, ActionError, + ConnectionError, } from "@/client/errors"; export { AnyActorDefinition, diff --git a/packages/actor-core/src/common/eventsource.ts b/packages/actor-core/src/common/eventsource.ts index 01ba18410..76c365dc4 100644 --- a/packages/actor-core/src/common/eventsource.ts +++ b/packages/actor-core/src/common/eventsource.ts @@ -1,30 +1,43 @@ import { logger } from "@/client/log"; +// Global singleton promise that will be reused for subsequent calls +let eventSourcePromise: Promise | null = null; + export async function importEventSource(): Promise { - let _EventSource: typeof EventSource; + // Return existing promise if we already started loading + if (eventSourcePromise !== null) { + return eventSourcePromise; + } + + // Create and store the promise + eventSourcePromise = (async () => { + let _EventSource: typeof EventSource; - if (typeof EventSource !== "undefined") { - // Browser environment - _EventSource = EventSource; - logger().debug("using native eventsource"); - } else { - // Node.js environment - try { - const es = await import("eventsource"); - _EventSource = es.EventSource; - logger().debug("using eventsource from npm"); - } catch (err) { - // EventSource not available - _EventSource = class MockEventSource { - constructor() { - throw new Error( - 'EventSource support requires installing the "eventsource" peer dependency.', - ); - } - } as unknown as typeof EventSource; - logger().debug("using mock eventsource"); + if (typeof EventSource !== "undefined") { + // Browser environment + _EventSource = EventSource; + logger().debug("using native eventsource"); + } else { + // Node.js environment + try { + const es = await import("eventsource"); + _EventSource = es.EventSource; + logger().debug("using eventsource from npm"); + } catch (err) { + // EventSource not available + _EventSource = class MockEventSource { + constructor() { + throw new Error( + 'EventSource support requires installing the "eventsource" peer dependency.', + ); + } + } as unknown as typeof EventSource; + logger().debug("using mock eventsource"); + } } - } - return _EventSource; + return _EventSource; + })(); + + return eventSourcePromise; } diff --git a/packages/actor-core/src/common/log.ts b/packages/actor-core/src/common/log.ts index c03c5eaf2..d5ccb97a9 100644 --- a/packages/actor-core/src/common/log.ts +++ b/packages/actor-core/src/common/log.ts @@ -75,7 +75,14 @@ export class Logger { const loggers: Record = {}; export function getLogger(name = "default"): Logger { - const defaultLogLevelEnv = typeof process !== "undefined" ? (process.env._LOG_LEVEL as LogLevel) : undefined; + let defaultLogLevelEnv: LogLevel | undefined = undefined; + if (typeof Deno !== "undefined") { + defaultLogLevelEnv = Deno.env.get("_LOG_LEVEL") as LogLevel; + } else if (typeof process !== "undefined") { + // Do this after Deno since `process` is sometimes polyfilled + defaultLogLevelEnv = process.env._LOG_LEVEL as LogLevel; + } + const defaultLogLevel: LogLevel = defaultLogLevelEnv ?? "INFO"; if (!loggers[name]) { loggers[name] = new Logger(name, defaultLogLevel); diff --git a/packages/actor-core/src/common/router.ts b/packages/actor-core/src/common/router.ts index 724a4cc33..3283af224 100644 --- a/packages/actor-core/src/common/router.ts +++ b/packages/actor-core/src/common/router.ts @@ -1,11 +1,32 @@ -import type { Context as HonoContext } from "hono"; -import { getLogger } from "./log"; +import type { Context as HonoContext, Next } from "hono"; +import { getLogger, Logger } from "./log"; import { deconstructError } from "./utils"; export function logger() { return getLogger("router"); } +export function loggerMiddleware(logger: Logger) { + return async (c: HonoContext, next: Next) => { + const method = c.req.method; + const path = c.req.path; + const startTime = Date.now(); + + await next(); + + const duration = Date.now() - startTime; + logger.debug("http request", { + method, + path, + status: c.res.status, + dt: `${duration}ms`, + reqSize: c.req.header("content-length"), + resSize: c.res.headers.get("content-length"), + userAgent: c.req.header("user-agent"), + }); + }; +} + export function handleRouteNotFound(c: HonoContext) { return c.text("Not Found (ActorCore)", 404); } diff --git a/packages/actor-core/src/common/websocket.ts b/packages/actor-core/src/common/websocket.ts index 29ac3398c..0b36cab4a 100644 --- a/packages/actor-core/src/common/websocket.ts +++ b/packages/actor-core/src/common/websocket.ts @@ -1,30 +1,43 @@ import { logger } from "@/client/log"; +// Global singleton promise that will be reused for subsequent calls +let webSocketPromise: Promise | null = null; + export async function importWebSocket(): Promise { - let _WebSocket: typeof WebSocket; + // Return existing promise if we already started loading + if (webSocketPromise !== null) { + return webSocketPromise; + } + + // Create and store the promise + webSocketPromise = (async () => { + let _WebSocket: typeof WebSocket; - if (typeof WebSocket !== "undefined") { - // Browser environment - _WebSocket = WebSocket; - logger().debug("using native websocket"); - } else { - // Node.js environment - try { - const ws = await import("ws"); - _WebSocket = ws.default as unknown as typeof WebSocket; - logger().debug("using websocket from npm"); - } catch { - // WS not available - _WebSocket = class MockWebSocket { - constructor() { - throw new Error( - 'WebSocket support requires installing the "ws" peer dependency.', - ); - } - } as unknown as typeof WebSocket; - logger().debug("using mock websocket"); + if (typeof WebSocket !== "undefined") { + // Browser environment + _WebSocket = WebSocket; + logger().debug("using native websocket"); + } else { + // Node.js environment + try { + const ws = await import("ws"); + _WebSocket = ws.default as unknown as typeof WebSocket; + logger().debug("using websocket from npm"); + } catch { + // WS not available + _WebSocket = class MockWebSocket { + constructor() { + throw new Error( + 'WebSocket support requires installing the "ws" peer dependency.', + ); + } + } as unknown as typeof WebSocket; + logger().debug("using mock websocket"); + } } - } - return _WebSocket; + return _WebSocket; + })(); + + return webSocketPromise; } diff --git a/packages/actor-core/src/driver-helpers/mod.ts b/packages/actor-core/src/driver-helpers/mod.ts index f42e1d263..12191a98c 100644 --- a/packages/actor-core/src/driver-helpers/mod.ts +++ b/packages/actor-core/src/driver-helpers/mod.ts @@ -1,3 +1,5 @@ +import { ToServer } from "@/actor/protocol/message/to-server"; + export { type DriverConfig, DriverConfigSchema } from "./config"; export type { ActorInstance, AnyActorInstance } from "@/actor/instance"; export { diff --git a/packages/actor-core/src/manager/driver.ts b/packages/actor-core/src/manager/driver.ts index 481f538a2..8d2c46f83 100644 --- a/packages/actor-core/src/manager/driver.ts +++ b/packages/actor-core/src/manager/driver.ts @@ -11,32 +11,32 @@ export interface ManagerDriver { } export interface GetForIdInput { c?: HonoContext; - baseUrl: string; actorId: string; } export interface GetWithKeyInput { c?: HonoContext; - baseUrl: string; name: string; key: ActorKey; } export interface GetActorOutput { c?: HonoContext; - endpoint: string; + actorId: string; name: string; key: ActorKey; + meta?: unknown; } export interface CreateActorInput { c?: HonoContext; - baseUrl: string; name: string; key: ActorKey; region?: string; } export interface CreateActorOutput { - endpoint: string; + actorId: string; + meta?: unknown; } + diff --git a/packages/actor-core/src/manager/protocol/mod.ts b/packages/actor-core/src/manager/protocol/mod.ts index 2c3bceefe..5ca94297f 100644 --- a/packages/actor-core/src/manager/protocol/mod.ts +++ b/packages/actor-core/src/manager/protocol/mod.ts @@ -8,7 +8,7 @@ export const ActorsRequestSchema = z.object({ }); export const ActorsResponseSchema = z.object({ - endpoint: z.string(), + actorId: z.string(), supportedTransports: z.array(TransportSchema), }); diff --git a/packages/actor-core/src/manager/protocol/query.ts b/packages/actor-core/src/manager/protocol/query.ts index 8e49fe7ce..6e3907a52 100644 --- a/packages/actor-core/src/manager/protocol/query.ts +++ b/packages/actor-core/src/manager/protocol/query.ts @@ -1,5 +1,6 @@ import { ActorKeySchema, type ActorKey } from "@/common//utils"; import { z } from "zod"; +import { EncodingSchema } from "@/actor/protocol/serde"; export const CreateRequestSchema = z.object({ name: z.string(), @@ -35,9 +36,16 @@ export const ActorQuerySchema = z.union([ }), ]); +export const ConnectQuerySchema = z.object({ + query: ActorQuerySchema, + encoding: EncodingSchema, + params: z.string().optional(), +}); + export type ActorQuery = z.infer; export type GetForKeyRequest = z.infer; export type GetOrCreateRequest = z.infer; +export type ConnectQuery = z.infer; /** * Interface representing a request to create an actor. */ diff --git a/packages/actor-core/src/manager/router.ts b/packages/actor-core/src/manager/router.ts index 4c0ae9d00..b63f49425 100644 --- a/packages/actor-core/src/manager/router.ts +++ b/packages/actor-core/src/manager/router.ts @@ -1,9 +1,11 @@ -import { ActorsRequestSchema } from "@/manager/protocol/mod"; -import { Hono, type Context as HonoContext } from "hono"; +import { Hono, Next, type Context as HonoContext } from "hono"; import { cors } from "hono/cors"; import { logger } from "./log"; -import { assertUnreachable } from "@/common/utils"; -import { handleRouteError, handleRouteNotFound } from "@/common/router"; +import { + handleRouteError, + handleRouteNotFound, + loggerMiddleware, +} from "@/common/router"; import type { DriverConfig } from "@/driver-helpers/config"; import type { AppConfig } from "@/app/config"; import { @@ -11,10 +13,59 @@ import { type ManagerInspectorConnHandler, } from "@/inspector/manager"; import type { UpgradeWebSocket } from "hono/ws"; +import { ConnectQuerySchema } from "./protocol/query"; +import * as errors from "@/actor/errors"; +import type { ActorQuery, ConnectQuery } from "./protocol/query"; +import { assertUnreachable } from "@/actor/utils"; +import invariant from "invariant"; +import { + type ConnectionHandlers, + handleSseConnect, + handleRpc, + handleConnectionMessage, + getRequestEncoding, + handleWebSocketConnect, +} from "@/actor/router_endpoints"; +import { ManagerDriver } from "./driver"; +import { setUncaughtExceptionCaptureCallback } from "process"; +import { Encoding, serialize } from "@/actor/protocol/serde"; +import { deconstructError } from "@/common/utils"; +import { WSContext } from "hono/ws"; +import { ToClient } from "@/actor/protocol/message/to-client"; +import { upgradeWebSocket } from "hono/deno"; + +type ProxyMode = + | { + inline: { + handlers: ConnectionHandlers; + }; + } + | { + custom: { + onProxyRequest: OnProxyRequest; + onProxyWebSocket: OnProxyWebSocket; + }; + }; + +export type BuildProxyEndpoint = (c: HonoContext, actorId: string) => string; + +export type OnProxyRequest = ( + c: HonoContext, + actorRequest: Request, + actorId: string, + meta?: unknown, +) => Promise; + +export type OnProxyWebSocket = ( + c: HonoContext, + path: string, + actorId: string, + meta?: unknown, +) => Promise; type ManagerRouterHandler = { onConnectInspector?: ManagerInspectorConnHandler; - upgradeWebSocket?: UpgradeWebSocket; + proxyMode: ProxyMode; }; export function createManagerRouter( @@ -29,9 +80,21 @@ export function createManagerRouter( const driver = driverConfig.drivers.manager; const app = new Hono(); - // Apply CORS middleware if configured + const upgradeWebSocket = driverConfig.getUpgradeWebSocket?.(app); + + app.use("*", loggerMiddleware(logger())); + if (appConfig.cors) { - app.use("*", cors(appConfig.cors)); + app.use("*", async (c, next) => { + const path = c.req.path; + + // Don't apply to WebSocket routes + if (path === "/actors/connect/websocket") { + return next(); + } + + return cors(appConfig.cors)(c, next); + }); } app.get("/", (c) => { @@ -44,92 +107,320 @@ export function createManagerRouter( return c.text("ok"); }); - app.post("/manager/actors", async (c: HonoContext) => { - const { query } = ActorsRequestSchema.parse(await c.req.json()); - logger().debug("query", { query }); - - const url = new URL(c.req.url); - - // Determine base URL to build endpoints from - // - // This is used to build actor endpoints - let baseUrl = url.origin; - if (appConfig.basePath) { - const basePath = appConfig.basePath; - if (!basePath.startsWith("/")) - throw new Error("config.basePath must start with /"); - if (basePath.endsWith("/")) - throw new Error("config.basePath must not end with /"); - baseUrl += basePath; - } + app.get("/actors/connect/websocket", async (c) => { + invariant(upgradeWebSocket, "WebSockets not supported"); - // Get the actor from the manager - let actorOutput: { endpoint: string }; - if ("getForId" in query) { - const output = await driver.getForId({ - c, - baseUrl: baseUrl, - actorId: query.getForId.actorId, + let encoding: Encoding | undefined; + try { + encoding = getRequestEncoding(c.req); + logger().debug("websocket connection request received", { encoding }); + + const params = ConnectQuerySchema.safeParse({ + query: parseQuery(c), + encoding: c.req.query("encoding"), + params: c.req.query("params"), }); - if (!output) - throw new Error( - `Actor does not exist for ID: ${query.getForId.actorId}`, + if (!params.success) { + logger().error("invalid connection parameters", { + error: params.error, + }); + throw new errors.InvalidQueryFormat(params.error); + } + + // Get the actor ID and meta + const { actorId, meta } = await queryActor(c, params.data.query, driver); + logger().debug("found actor for websocket connection", { actorId, meta }); + invariant(actorId, "missing actor id"); + + if ("inline" in handler.proxyMode) { + logger().debug("using inline proxy mode for websocket connection"); + invariant( + handler.proxyMode.inline.handlers.onConnectWebSocket, + "onConnectWebSocket not provided", ); - actorOutput = output; - } else if ("getForKey" in query) { - const existingActor = await driver.getWithKey({ - c, - baseUrl: baseUrl, - name: query.getForKey.name, - key: query.getForKey.key, - }); - if (!existingActor) { - throw new Error("Actor not found with key."); + + const onConnectWebSocket = + handler.proxyMode.inline.handlers.onConnectWebSocket; + return upgradeWebSocket((c) => { + return handleWebSocketConnect( + c, + appConfig, + driverConfig, + onConnectWebSocket, + actorId, + )(); + })(c, noopNext()); + } else if ("custom" in handler.proxyMode) { + logger().debug("using custom proxy mode for websocket connection"); + let pathname = `/connect/websocket?encoding=${params.data.encoding}`; + if (params.data.params) { + pathname += `¶ms=${params.data.params}`; + } + return await handler.proxyMode.custom.onProxyWebSocket( + c, + pathname, + actorId, + meta, + ); + } else { + assertUnreachable(handler.proxyMode); } - actorOutput = existingActor; - } else if ("getOrCreateForKey" in query) { - const existingActor = await driver.getWithKey({ - c, - baseUrl: baseUrl, - name: query.getOrCreateForKey.name, - key: query.getOrCreateForKey.key, + } catch (error) { + // If we receive an error during setup, we send the error and close the socket immediately + // + // We have to return the error over WS since WebSocket clients cannot read vanilla HTTP responses + + const { code, message, metadata } = deconstructError(error, logger(), { + wsEvent: "setup", + }); + + return await upgradeWebSocket(() => ({ + onOpen: async (_evt: unknown, ws: WSContext) => { + if (encoding) { + try { + // Serialize and send the connection error + const errorMsg: ToClient = { + b: { + ce: { + c: code, + m: message, + md: metadata, + }, + }, + }; + + // Send the error message to the client + invariant(encoding, "encoding should be defined"); + const serialized = serialize(errorMsg, encoding); + ws.send(serialized); + + // Close the connection with an error code + ws.close(1011, code); + } catch (serializeError) { + logger().error("failed to send error to websocket client", { + error: serializeError, + }); + ws.close(1011, "internal error during error handling"); + } + } else { + // We don't know the encoding so we send what we can + ws.close(1011, code); + } + }, + }))(c, noopNext()); + } + }); + + // Proxy SSE connection to actor + app.get("/actors/connect/sse", async (c) => { + logger().debug("sse connection request received"); + try { + const params = ConnectQuerySchema.safeParse({ + query: parseQuery(c), + encoding: c.req.query("encoding"), + params: c.req.query("params"), }); - if (existingActor) { - // Actor exists - actorOutput = existingActor; + + if (!params.success) { + logger().error("invalid connection parameters", { + error: params.error, + }); + throw new errors.InvalidQueryFormat(params.error); + } + + const query = params.data.query; + + // Get the actor ID and meta + const { actorId, meta } = await queryActor(c, query, driver); + invariant(actorId, "Missing actor ID"); + logger().debug("sse connection to actor", { actorId, meta }); + + // Handle based on mode + if ("inline" in handler.proxyMode) { + logger().debug("using inline proxy mode for sse connection"); + // Use the shared SSE handler + return handleSseConnect( + c, + appConfig, + driverConfig, + handler.proxyMode.inline.handlers.onConnectSse, + actorId, + ); + } else if ("custom" in handler.proxyMode) { + logger().debug("using custom proxy mode for sse connection"); + const url = new URL("http://actor/connect/sse"); + url.searchParams.set("encoding", params.data.encoding); + if (params.data.params) { + url.searchParams.set("params", params.data.params); + } + const proxyRequest = new Request(url, c.req.raw); + return await handler.proxyMode.custom.onProxyRequest( + c, + proxyRequest, + actorId, + meta, + ); + } else { + assertUnreachable(handler.proxyMode); + } + } catch (error) { + logger().error("error setting up sse proxy", { error }); + + // Use ProxyError if it's not already an ActorError + if (!(error instanceof errors.ActorError)) { + throw new errors.ProxyError("SSE connection", error); } else { - // Create if needed - actorOutput = await driver.createActor({ + throw error; + } + } + }); + + // Proxy RPC calls to actor + app.post("/actors/rpc/:rpc", async (c) => { + try { + const rpcName = c.req.param("rpc"); + logger().debug("rpc call received", { rpcName }); + + // Get query parameters for actor lookup + const queryParam = c.req.query("query"); + if (!queryParam) { + logger().error("missing query parameter for rpc"); + throw new errors.MissingRequiredParameters(["query"]); + } + + // Parse the query JSON and validate with schema + let parsedQuery: ActorQuery; + try { + parsedQuery = JSON.parse(queryParam as string); + } catch (error) { + logger().error("invalid query json for rpc", { error }); + throw new errors.InvalidQueryJSON(error); + } + + // Get the actor ID and meta + const { actorId, meta } = await queryActor(c, parsedQuery, driver); + logger().debug("found actor for rpc", { actorId, meta }); + invariant(actorId, "Missing actor ID"); + + // Handle based on mode + if ("inline" in handler.proxyMode) { + logger().debug("using inline proxy mode for rpc call"); + // Use shared RPC handler with direct parameter + return handleRpc( c, - baseUrl: baseUrl, - name: query.getOrCreateForKey.name, - key: query.getOrCreateForKey.key, - region: query.getOrCreateForKey.region, - }); + appConfig, + driverConfig, + handler.proxyMode.inline.handlers.onRpc, + rpcName, + actorId, + ); + } else if ("custom" in handler.proxyMode) { + logger().debug("using custom proxy mode for rpc call"); + const url = new URL(`http://actor/rpc/${encodeURIComponent(rpcName)}`); + const proxyRequest = new Request(url, c.req.raw); + return await handler.proxyMode.custom.onProxyRequest( + c, + proxyRequest, + actorId, + meta, + ); + } else { + assertUnreachable(handler.proxyMode); + } + } catch (error) { + logger().error("error in rpc handler", { error }); + + // Use ProxyError if it's not already an ActorError + if (!(error instanceof errors.ActorError)) { + throw new errors.ProxyError("RPC call", error); + } else { + throw error; } - } else if ("create" in query) { - actorOutput = await driver.createActor({ - c, - baseUrl: baseUrl, - name: query.create.name, - key: query.create.key, - region: query.create.region, - }); - } else { - assertUnreachable(query); } + }); - return c.json({ - endpoint: actorOutput.endpoint, - supportedTransports: ["websocket", "sse"], - }); + // Proxy connection messages to actor + app.post("/actors/connections/:conn/message", async (c) => { + logger().debug("connection message request received"); + try { + const connId = c.req.param("conn"); + const connToken = c.req.query("connectionToken"); + const encoding = c.req.query("encoding"); + + // Get query parameters for actor lookup + const queryParam = c.req.query("query"); + if (!queryParam) { + throw new errors.MissingRequiredParameters(["query"]); + } + + // Check other required parameters + const missingParams: string[] = []; + if (!connToken) missingParams.push("connectionToken"); + if (!encoding) missingParams.push("encoding"); + + if (missingParams.length > 0) { + throw new errors.MissingRequiredParameters(missingParams); + } + + // Parse the query JSON and validate with schema + let parsedQuery: ActorQuery; + try { + parsedQuery = JSON.parse(queryParam as string); + } catch (error) { + logger().error("invalid query json", { error }); + throw new errors.InvalidQueryJSON(error); + } + + // Get the actor ID and meta + const { actorId, meta } = await queryActor(c, parsedQuery, driver); + invariant(actorId, "Missing actor ID"); + logger().debug("connection message to actor", { connId, actorId, meta }); + + // Handle based on mode + if ("inline" in handler.proxyMode) { + logger().debug("using inline proxy mode for connection message"); + // Use shared connection message handler with direct parameters + return handleConnectionMessage( + c, + appConfig, + handler.proxyMode.inline.handlers.onConnMessage, + connId, + connToken as string, + actorId, + ); + } else if ("custom" in handler.proxyMode) { + logger().debug("using custom proxy mode for connection message"); + const url = new URL(`http://actor/connections/${connId}/message`); + url.searchParams.set("connectionToken", connToken!); + url.searchParams.set("encoding", encoding!); + const proxyRequest = new Request(url, c.req.raw); + return await handler.proxyMode.custom.onProxyRequest( + c, + proxyRequest, + actorId, + meta, + ); + } else { + assertUnreachable(handler.proxyMode); + } + } catch (error) { + logger().error("error proxying connection message", { error }); + + // Use ProxyError if it's not already an ActorError + if (!(error instanceof errors.ActorError)) { + throw new errors.ProxyError("connection message", error); + } else { + throw error; + } + } }); if (appConfig.inspector.enabled) { + logger().debug("setting up inspector routes"); app.route( - "/manager/inspect", + "/inspect", createManagerInspectorRouter( - handler.upgradeWebSocket, + upgradeWebSocket, handler.onConnectInspector, appConfig.inspector, ), @@ -140,4 +431,100 @@ export function createManagerRouter( app.onError(handleRouteError); return app; -} \ No newline at end of file +} + +/** + * Query the manager driver to get or create an actor based on the provided query + */ +export async function queryActor( + c: HonoContext, + query: ActorQuery, + driver: ManagerDriver, +): Promise<{ actorId: string; meta?: unknown }> { + logger().debug("querying actor", { query }); + let actorOutput: { actorId: string; meta?: unknown }; + if ("getForId" in query) { + const output = await driver.getForId({ + c, + actorId: query.getForId.actorId, + }); + if (!output) throw new errors.ActorNotFound(query.getForId.actorId); + actorOutput = output; + } else if ("getForKey" in query) { + const existingActor = await driver.getWithKey({ + c, + name: query.getForKey.name, + key: query.getForKey.key, + }); + if (!existingActor) { + throw new errors.ActorNotFound( + `${query.getForKey.name}:${JSON.stringify(query.getForKey.key)}`, + ); + } + actorOutput = existingActor; + } else if ("getOrCreateForKey" in query) { + const existingActor = await driver.getWithKey({ + c, + name: query.getOrCreateForKey.name, + key: query.getOrCreateForKey.key, + }); + if (existingActor) { + // Actor exists + actorOutput = existingActor; + } else { + // Create if needed + const createOutput = await driver.createActor({ + c, + name: query.getOrCreateForKey.name, + key: query.getOrCreateForKey.key, + region: query.getOrCreateForKey.region, + }); + actorOutput = { + actorId: createOutput.actorId, + meta: createOutput.meta, + }; + } + } else if ("create" in query) { + const createOutput = await driver.createActor({ + c, + name: query.create.name, + key: query.create.key, + region: query.create.region, + }); + actorOutput = { + actorId: createOutput.actorId, + meta: createOutput.meta, + }; + } else { + throw new errors.InvalidQueryFormat("Invalid query format"); + } + + logger().debug("actor query result", { + actorId: actorOutput.actorId, + meta: actorOutput.meta, + }); + return { actorId: actorOutput.actorId, meta: actorOutput.meta }; +} + +/** Generates a `Next` handler to pass to middleware in order to be able to call arbitrary middleware. */ +function noopNext(): Next { + return async () => {}; +} + +function parseQuery(c: HonoContext): unknown { + // Get query parameters for actor lookup + const queryParam = c.req.query("query"); + if (!queryParam) { + logger().error("missing query parameter for rpc"); + throw new errors.MissingRequiredParameters(["query"]); + } + + // Parse the query JSON and validate with schema + try { + const parsed = JSON.parse(queryParam as string); + return parsed; + } catch (error) { + logger().error("invalid query json for rpc", { error }); + throw new errors.InvalidQueryJSON(error); + } +} diff --git a/packages/actor-core/src/test/driver/manager.ts b/packages/actor-core/src/test/driver/manager.ts index dcbe115ca..eaeb77755 100644 --- a/packages/actor-core/src/test/driver/manager.ts +++ b/packages/actor-core/src/test/driver/manager.ts @@ -29,7 +29,6 @@ export class TestManagerDriver implements ManagerDriver { } async getForId({ - baseUrl, actorId, }: GetForIdInput): Promise { // Validate the actor exists @@ -39,41 +38,73 @@ export class TestManagerDriver implements ManagerDriver { } return { - endpoint: buildActorEndpoint(baseUrl, actorId), + actorId, name: actor.name, key: actor.key, }; } async getWithKey({ - baseUrl, name, key, }: GetWithKeyInput): Promise { // NOTE: This is a slow implementation that checks each actor individually. // This can be optimized with an index in the future. - // Search through all actors to find a match with the same key const actor = this.#state.findActor((actor) => { - if (actor.name !== name) return false; - - // Compare key arrays - if (!actor.key || actor.key.length !== key.length) { + if (actor.name !== name) { return false; } - // Check if all elements in key are in actor.key - for (let i = 0; i < key.length; i++) { - if (key[i] !== actor.key[i]) { + // handle empty key + if (key === null || key === undefined) { + return actor.key === null || actor.key === undefined; + } + + // handle array + if (Array.isArray(key)) { + if (!Array.isArray(actor.key)) { + return false; + } + if (key.length !== actor.key.length) { + return false; + } + // Check if all elements in key are in actor.key + for (let i = 0; i < key.length; i++) { + if (key[i] !== actor.key[i]) { + return false; + } + } + return true; + } + + // Handle object + if (typeof key === "object" && !Array.isArray(key)) { + if (typeof actor.key !== "object" || Array.isArray(actor.key)) { + return false; + } + if (actor.key === null) { return false; } + + // Check if all keys in key are in actor.key + const keyObj = key as Record; + const actorKeyObj = actor.key as unknown as Record; + for (const k in keyObj) { + if (!(k in actorKeyObj) || keyObj[k] !== actorKeyObj[k]) { + return false; + } + } + return true; } - return true; + + // handle scalar + return key === actor.key; }); if (actor) { return { - endpoint: buildActorEndpoint(baseUrl, actor.id), + actorId: actor.id, name, key: actor.key, }; @@ -83,7 +114,6 @@ export class TestManagerDriver implements ManagerDriver { } async createActor({ - baseUrl, name, key, }: CreateActorInput): Promise { @@ -93,11 +123,7 @@ export class TestManagerDriver implements ManagerDriver { this.inspector.onActorsChange(this.#state.getAllActors()); return { - endpoint: buildActorEndpoint(baseUrl, actorId), + actorId, }; } } - -function buildActorEndpoint(baseUrl: string, actorId: string) { - return `${baseUrl}/actors/${actorId}`; -} \ No newline at end of file diff --git a/packages/actor-core/src/topologies/coordinate/topology.ts b/packages/actor-core/src/topologies/coordinate/topology.ts index 5dab1fc81..3b1cd0ab1 100644 --- a/packages/actor-core/src/topologies/coordinate/topology.ts +++ b/packages/actor-core/src/topologies/coordinate/topology.ts @@ -12,6 +12,16 @@ import { handleRouteError, handleRouteNotFound } from "@/common/router"; import type { DriverConfig } from "@/driver-helpers/config"; import type { AppConfig } from "@/app/config"; import { createManagerRouter } from "@/manager/router"; +import type { + ConnectWebSocketOpts, + ConnectSseOpts, + RpcOpts, + ConnsMessageOpts, + ConnectWebSocketOutput, + ConnectSseOutput, + RpcOutput, + ConnectionHandlers, +} from "@/actor/router_endpoints"; export interface GlobalState { nodeId: string; @@ -53,77 +63,71 @@ export class CoordinateTopology { const upgradeWebSocket = driverConfig.getUpgradeWebSocket?.(app); - // Build manager router - const managerRouter = createManagerRouter(appConfig, driverConfig, { - upgradeWebSocket, - onConnectInspector: () => { - throw new errors.Unsupported("inspect"); - }, - }); - - // Forward requests to actor - const actorRouter = createActorRouter(appConfig, driverConfig, { - upgradeWebSocket, - onConnectWebSocket: async (opts) => { - const actorId = opts.req.param("actorId"); - if (!actorId) throw new errors.InternalError("Missing actor ID"); + // Share connection handlers for both routers + const connectionHandlers: ConnectionHandlers = { + onConnectWebSocket: async ( + opts: ConnectWebSocketOpts, + ): Promise => { return await serveWebSocket( appConfig, driverConfig, actorDriver, CoordinateDriver, globalState, - actorId, + opts.actorId, opts, ); }, - onConnectSse: async (opts) => { - const actorId = opts.req.param("actorId"); - if (!actorId) throw new errors.InternalError("Missing actor ID"); + onConnectSse: async (opts: ConnectSseOpts): Promise => { return await serveSse( appConfig, driverConfig, actorDriver, CoordinateDriver, globalState, - actorId, + opts.actorId, opts, ); }, - onRpc: async () => { + onRpc: async (opts: RpcOpts): Promise => { // TODO: throw new errors.InternalError("UNIMPLEMENTED"); }, - onConnMessage: async ({ req, connId, connToken, message }) => { - const actorId = req.param("actorId"); - if (!actorId) throw new errors.InternalError("Missing actor ID"); - + onConnMessage: async (opts: ConnsMessageOpts): Promise => { await publishMessageToLeader( appConfig, driverConfig, CoordinateDriver, globalState, - actorId, + opts.actorId, { b: { lm: { - ai: actorId, - ci: connId, - ct: connToken, - m: message, + ai: opts.actorId, + ci: opts.connId, + ct: opts.connToken, + m: opts.message, }, }, }, - req.raw.signal, + opts.req.raw.signal, ); }, - onConnectInspector: async () => { + }; + + // Build manager router + const managerRouter = createManagerRouter(appConfig, driverConfig, { + proxyMode: { + inline: { + handlers: connectionHandlers, + }, + }, + onConnectInspector: () => { throw new errors.Unsupported("inspect"); }, }); app.route("/", managerRouter); - app.route("/actors/:actorId", actorRouter); app.notFound(handleRouteNotFound); app.onError(handleRouteError); diff --git a/packages/actor-core/src/topologies/partition/toplogy.ts b/packages/actor-core/src/topologies/partition/toplogy.ts index 39ab5e9d7..6ac0366a2 100644 --- a/packages/actor-core/src/topologies/partition/toplogy.ts +++ b/packages/actor-core/src/topologies/partition/toplogy.ts @@ -24,43 +24,62 @@ import type { ActorKey } from "@/common/utils"; import type { DriverConfig } from "@/driver-helpers/config"; import type { AppConfig } from "@/app/config"; import type { ActorInspectorConnection } from "@/inspector/actor"; -import { createManagerRouter } from "@/manager/router"; +import { + createManagerRouter, + OnProxyWebSocket, + type OnProxyRequest, +} from "@/manager/router"; import type { ManagerInspectorConnection } from "@/inspector/manager"; +import type { + ConnectWebSocketOpts, + ConnectSseOpts, + RpcOpts, + ConnsMessageOpts, + ConnectWebSocketOutput, + ConnectSseOutput, + RpcOutput, +} from "@/actor/router_endpoints"; export class PartitionTopologyManager { - router = new Hono(); + router: Hono; - constructor(appConfig: AppConfig, driverConfig: DriverConfig) { - this.router.route( - "/", - createManagerRouter(appConfig, driverConfig, { - upgradeWebSocket: driverConfig.getUpgradeWebSocket?.(this.router), - onConnectInspector: async () => { - const inspector = driverConfig.drivers?.manager?.inspector; - if (!inspector) throw new errors.Unsupported("inspector"); - - let conn: ManagerInspectorConnection | undefined; - return { - onOpen: async (ws) => { - conn = inspector.createConnection(ws); - }, - onMessage: async (message) => { - if (!conn) { - logger().warn("`conn` does not exist"); - return; - } + constructor( + appConfig: AppConfig, + driverConfig: DriverConfig, + proxyCustomConfig: { + onProxyRequest: OnProxyRequest; + onProxyWebSocket: OnProxyWebSocket; + }, + ) { + this.router = createManagerRouter(appConfig, driverConfig, { + proxyMode: { + custom: proxyCustomConfig, + }, + onConnectInspector: async () => { + const inspector = driverConfig.drivers?.manager?.inspector; + if (!inspector) throw new errors.Unsupported("inspector"); + + let conn: ManagerInspectorConnection | undefined; + return { + onOpen: async (ws) => { + conn = inspector.createConnection(ws); + }, + onMessage: async (message) => { + if (!conn) { + logger().warn("`conn` does not exist"); + return; + } - inspector.processMessage(conn, message); - }, - onClose: async () => { - if (conn) { - inspector.removeConnection(conn); - } - }, - }; - }, - }), - ); + inspector.processMessage(conn, message); + }, + onClose: async () => { + if (conn) { + inspector.removeConnection(conn); + } + }, + }; + }, + }); } } @@ -90,15 +109,16 @@ export class PartitionTopologyActor { const genericConnGlobalState = new GenericConnGlobalState(); this.#connDrivers = createGenericConnDrivers(genericConnGlobalState); - // Build actor router - const actorRouter = new Hono(); - - // This route rhas to be mounted at the root since the root router must be passed to `upgradeWebSocket` - actorRouter.route( - "/", - createActorRouter(appConfig, driverConfig, { - upgradeWebSocket: driverConfig.getUpgradeWebSocket?.(actorRouter), - onConnectWebSocket: async ({ req, encoding, params: connParams }) => { + // TODO: Store this actor router globally so we're not re-initializing it for every DO + this.router = createActorRouter(appConfig, driverConfig, { + getActorId: async () => { + if (this.#actorStartedPromise) await this.#actorStartedPromise.promise; + return this.actor.id; + }, + connectionHandlers: { + onConnectWebSocket: async ( + opts: ConnectWebSocketOpts, + ): Promise => { if (this.#actorStartedPromise) await this.#actorStartedPromise.promise; @@ -107,7 +127,7 @@ export class PartitionTopologyActor { const connId = generateConnId(); const connToken = generateConnToken(); - const connState = await actor.prepareConn(connParams, req.raw); + const connState = await actor.prepareConn(opts.params, opts.req.raw); let conn: AnyConn | undefined; return { @@ -119,11 +139,12 @@ export class PartitionTopologyActor { conn = await actor.createConn( connId, connToken, - - connParams, + opts.params, connState, CONN_DRIVER_GENERIC_WEBSOCKET, - { encoding } satisfies GenericWebSocketDriverState, + { + encoding: opts.encoding, + } satisfies GenericWebSocketDriverState, ); }, onMessage: async (message) => { @@ -145,7 +166,9 @@ export class PartitionTopologyActor { }, }; }, - onConnectSse: async ({ req, encoding, params: connParams }) => { + onConnectSse: async ( + opts: ConnectSseOpts, + ): Promise => { if (this.#actorStartedPromise) await this.#actorStartedPromise.promise; @@ -154,7 +177,7 @@ export class PartitionTopologyActor { const connId = generateConnId(); const connToken = generateConnToken(); - const connState = await actor.prepareConn(connParams, req.raw); + const connState = await actor.prepareConn(opts.params, opts.req.raw); let conn: AnyConn | undefined; return { @@ -166,10 +189,10 @@ export class PartitionTopologyActor { conn = await actor.createConn( connId, connToken, - connParams, + opts.params, connState, CONN_DRIVER_GENERIC_SSE, - { encoding } satisfies GenericSseDriverState, + { encoding: opts.encoding } satisfies GenericSseDriverState, ); }, onClose: async () => { @@ -181,7 +204,7 @@ export class PartitionTopologyActor { }, }; }, - onRpc: async ({ req, params: connParams, rpcName, rpcArgs }) => { + onRpc: async (opts: RpcOpts): Promise => { let conn: AnyConn | undefined; try { // Wait for init to finish @@ -192,11 +215,14 @@ export class PartitionTopologyActor { if (!actor) throw new Error("Actor should be defined"); // Create conn - const connState = await actor.prepareConn(connParams, req.raw); + const connState = await actor.prepareConn( + opts.params, + opts.req.raw, + ); conn = await actor.createConn( generateConnId(), generateConnToken(), - connParams, + opts.params, connState, CONN_DRIVER_GENERIC_HTTP, {} satisfies GenericHttpDriverState, @@ -204,7 +230,11 @@ export class PartitionTopologyActor { // Call RPC const ctx = new ActionContext(actor.actorContext!, conn!); - const output = await actor.executeRpc(ctx, rpcName, rpcArgs); + const output = await actor.executeRpc( + ctx, + opts.rpcName, + opts.rpcArgs, + ); return { output }; } finally { @@ -213,7 +243,7 @@ export class PartitionTopologyActor { } } }, - onConnMessage: async ({ connId, connToken, message }) => { + onConnMessage: async (opts: ConnsMessageOpts): Promise => { // Wait for init to finish if (this.#actorStartedPromise) await this.#actorStartedPromise.promise; @@ -222,50 +252,47 @@ export class PartitionTopologyActor { if (!actor) throw new Error("Actor should be defined"); // Find connection - const conn = actor.conns.get(connId); + const conn = actor.conns.get(opts.connId); if (!conn) { - throw new errors.ConnNotFound(connId); + throw new errors.ConnNotFound(opts.connId); } // Authenticate connection - if (conn._token !== connToken) { + if (conn._token !== opts.connToken) { throw new errors.IncorrectConnToken(); } // Process message - await actor.processMessage(message, conn); + await actor.processMessage(opts.message, conn); }, - onConnectInspector: async () => { - if (this.#actorStartedPromise) - await this.#actorStartedPromise.promise; - - const actor = this.#actor; - if (!actor) throw new Error("Actor should be defined"); - - let conn: ActorInspectorConnection | undefined; - return { - onOpen: async (ws) => { - conn = actor.inspector.createConnection(ws); - }, - onMessage: async (message) => { - if (!conn) { - logger().warn("`conn` does not exist"); - return; - } - - actor.inspector.processMessage(conn, message); - }, - onClose: async () => { - if (conn) { - actor.inspector.removeConnection(conn); - } - }, - }; - }, - }), - ); + }, + onConnectInspector: async () => { + if (this.#actorStartedPromise) await this.#actorStartedPromise.promise; + + const actor = this.#actor; + if (!actor) throw new Error("Actor should be defined"); + + let conn: ActorInspectorConnection | undefined; + return { + onOpen: async (ws) => { + conn = actor.inspector.createConnection(ws); + }, + onMessage: async (message) => { + if (!conn) { + logger().warn("`conn` does not exist"); + return; + } - this.router = actorRouter; + actor.inspector.processMessage(conn, message); + }, + onClose: async () => { + if (conn) { + actor.inspector.removeConnection(conn); + } + }, + }; + }, + }); } async start(id: string, name: string, key: ActorKey, region: string) { @@ -294,4 +321,4 @@ export class PartitionTopologyActor { this.#actorStartedPromise?.resolve(); this.#actorStartedPromise = undefined; } -} \ No newline at end of file +} diff --git a/packages/actor-core/src/topologies/standalone/topology.ts b/packages/actor-core/src/topologies/standalone/topology.ts index dfd835f0d..0f5a095c7 100644 --- a/packages/actor-core/src/topologies/standalone/topology.ts +++ b/packages/actor-core/src/topologies/standalone/topology.ts @@ -5,7 +5,6 @@ import { generateConnId, generateConnToken, } from "@/actor/connection"; -import { createActorRouter } from "@/actor/router"; import { logger } from "./log"; import * as errors from "@/actor/errors"; import { @@ -22,8 +21,17 @@ import { ActionContext } from "@/actor/action"; import type { DriverConfig } from "@/driver-helpers/config"; import type { AppConfig } from "@/app/config"; import { createManagerRouter } from "@/manager/router"; -import type { ActorInspectorConnection } from "@/inspector/actor"; import type { ManagerInspectorConnection } from "@/inspector/manager"; +import type { + ConnectWebSocketOpts, + ConnectWebSocketOutput, + ConnectSseOpts, + ConnectSseOutput, + ConnsMessageOpts, + RpcOpts, + RpcOutput, + ConnectionHandlers, +} from "@/actor/router_endpoints"; class ActorHandler { /** Will be undefined if not yet loaded. */ @@ -75,8 +83,6 @@ export class StandaloneTopology { // Load actor meta const actorMetadata = await this.#driverConfig.drivers.manager.getForId({ - // HACK: The endpoint doesn't matter here, so we're passing a bogon IP - baseUrl: "http://192.0.2.0", actorId, }); if (!actorMetadata) throw new Error(`No actor found for ID ${actorId}`); @@ -124,47 +130,16 @@ export class StandaloneTopology { const upgradeWebSocket = driverConfig.getUpgradeWebSocket?.(app); - // Build manager router - const managerRouter = createManagerRouter(appConfig, driverConfig, { - upgradeWebSocket, - onConnectInspector: async () => { - const inspector = driverConfig.drivers?.manager?.inspector; - if (!inspector) throw new errors.Unsupported("inspector"); - - let conn: ManagerInspectorConnection | undefined; - return { - onOpen: async (ws) => { - conn = inspector.createConnection(ws); - }, - onMessage: async (message) => { - if (!conn) { - logger().warn("`conn` does not exist"); - return; - } - - inspector.processMessage(conn, message); - }, - onClose: async () => { - if (conn) { - inspector.removeConnection(conn); - } - }, - }; - }, - }); - - // Build actor router - const actorRouter = createActorRouter(appConfig, driverConfig, { - upgradeWebSocket, - onConnectWebSocket: async ({ req, encoding, params: connParams }) => { - const actorId = req.param("actorId"); - if (!actorId) throw new errors.InternalError("Missing actor ID"); - - const { handler, actor } = await this.#getActor(actorId); + // Create shared connection handlers that will be used by both manager and actor routers + const sharedConnectionHandlers: ConnectionHandlers = { + onConnectWebSocket: async ( + opts: ConnectWebSocketOpts, + ): Promise => { + const { handler, actor } = await this.#getActor(opts.actorId); const connId = generateConnId(); const connToken = generateConnToken(); - const connState = await actor.prepareConn(connParams, req.raw); + const connState = await actor.prepareConn(opts.params, opts.req.raw); let conn: AnyConn | undefined; return { @@ -176,11 +151,10 @@ export class StandaloneTopology { conn = await actor.createConn( connId, connToken, - - connParams, + opts.params, connState, CONN_DRIVER_GENERIC_WEBSOCKET, - { encoding } satisfies GenericWebSocketDriverState, + { encoding: opts.encoding } satisfies GenericWebSocketDriverState, ); }, onMessage: async (message) => { @@ -202,15 +176,12 @@ export class StandaloneTopology { }, }; }, - onConnectSse: async ({ req, encoding, params: connParams }) => { - const actorId = req.param("actorId"); - if (!actorId) throw new errors.InternalError("Missing actor ID"); - - const { handler, actor } = await this.#getActor(actorId); + onConnectSse: async (opts: ConnectSseOpts): Promise => { + const { handler, actor } = await this.#getActor(opts.actorId); const connId = generateConnId(); const connToken = generateConnToken(); - const connState = await actor.prepareConn(connParams, req.raw); + const connState = await actor.prepareConn(opts.params, opts.req.raw); let conn: AnyConn | undefined; return { @@ -222,10 +193,10 @@ export class StandaloneTopology { conn = await actor.createConn( connId, connToken, - connParams, + opts.params, connState, CONN_DRIVER_GENERIC_SSE, - { encoding } satisfies GenericSseDriverState, + { encoding: opts.encoding } satisfies GenericSseDriverState, ); }, onClose: async () => { @@ -237,20 +208,17 @@ export class StandaloneTopology { }, }; }, - onRpc: async ({ req, params: connParams, rpcName, rpcArgs }) => { - const actorId = req.param("actorId"); - if (!actorId) throw new errors.InternalError("Missing actor ID"); - + onRpc: async (opts: RpcOpts): Promise => { let conn: AnyConn | undefined; try { - const { actor } = await this.#getActor(actorId); + const { actor } = await this.#getActor(opts.actorId); // Create conn - const connState = await actor.prepareConn(connParams, req.raw); + const connState = await actor.prepareConn(opts.params, opts.req.raw); conn = await actor.createConn( generateConnId(), generateConnToken(), - connParams, + opts.params, connState, CONN_DRIVER_GENERIC_HTTP, {} satisfies GenericHttpDriverState, @@ -258,46 +226,54 @@ export class StandaloneTopology { // Call RPC const ctx = new ActionContext(actor.actorContext!, conn); - const output = await actor.executeRpc(ctx, rpcName, rpcArgs); + const output = await actor.executeRpc( + ctx, + opts.rpcName, + opts.rpcArgs, + ); return { output }; } finally { if (conn) { - const { actor } = await this.#getActor(actorId); + const { actor } = await this.#getActor(opts.actorId); actor.__removeConn(conn); } } }, - onConnMessage: async ({ req, connId, connToken, message }) => { - const actorId = req.param("actorId"); - if (!actorId) throw new errors.InternalError("Missing actor ID"); - - const { actor } = await this.#getActor(actorId); + onConnMessage: async (opts: ConnsMessageOpts): Promise => { + const { actor } = await this.#getActor(opts.actorId); // Find connection - const conn = actor.conns.get(connId); + const conn = actor.conns.get(opts.connId); if (!conn) { - throw new errors.ConnNotFound(connId); + throw new errors.ConnNotFound(opts.connId); } // Authenticate connection - if (conn._token !== connToken) { + if (conn._token !== opts.connToken) { throw new errors.IncorrectConnToken(); } // Process message - await actor.processMessage(message, conn); + await actor.processMessage(opts.message, conn); }, - onConnectInspector: async ({ req }) => { - const actorId = req.param("actorId"); - if (!actorId) throw new errors.InternalError("Missing actor ID"); + }; - const { actor } = await this.#getActor(actorId); + // Build manager router + const managerRouter = createManagerRouter(appConfig, driverConfig, { + proxyMode: { + inline: { + handlers: sharedConnectionHandlers, + }, + }, + onConnectInspector: async () => { + const inspector = driverConfig.drivers?.manager?.inspector; + if (!inspector) throw new errors.Unsupported("inspector"); - let conn: ActorInspectorConnection | undefined; + let conn: ManagerInspectorConnection | undefined; return { onOpen: async (ws) => { - conn = actor.inspector.createConnection(ws); + conn = inspector.createConnection(ws); }, onMessage: async (message) => { if (!conn) { @@ -305,11 +281,11 @@ export class StandaloneTopology { return; } - actor.inspector.processMessage(conn, message); + inspector.processMessage(conn, message); }, onClose: async () => { if (conn) { - actor.inspector.removeConnection(conn); + inspector.removeConnection(conn); } }, }; @@ -317,9 +293,7 @@ export class StandaloneTopology { }); app.route("/", managerRouter); - // Mount the actor router - app.route("/actors/:actorId", actorRouter); this.router = app; } -} \ No newline at end of file +} diff --git a/packages/actor-core/tests/action-timeout.test.ts b/packages/actor-core/tests/action-timeout.test.ts index ac88d3647..a45f3dea3 100644 --- a/packages/actor-core/tests/action-timeout.test.ts +++ b/packages/actor-core/tests/action-timeout.test.ts @@ -40,7 +40,7 @@ describe("Action Timeout", () => { }); const { client } = await setupTest(c, app); - const instance = await client.timeoutActor.connect(); + const instance = client.timeoutActor.connect(); // The quick action should complete successfully const quickResult = await instance.quickAction(); @@ -72,7 +72,7 @@ describe("Action Timeout", () => { }); const { client } = await setupTest(c, app); - const instance = await client.defaultTimeoutActor.connect(); + const instance = client.defaultTimeoutActor.connect(); // This action should complete successfully const result = await instance.normalAction(); @@ -101,7 +101,7 @@ describe("Action Timeout", () => { }); const { client } = await setupTest(c, app); - const instance = await client.syncActor.connect(); + const instance = client.syncActor.connect(); // Synchronous action should not be affected by timeout const result = await instance.syncAction(); @@ -169,13 +169,13 @@ describe("Action Timeout", () => { const { client } = await setupTest(c, app); // The short timeout actor should fail - const shortInstance = await client.shortTimeoutActor.connect(); + const shortInstance = client.shortTimeoutActor.connect(); await expect(shortInstance.delayedAction()).rejects.toThrow( "Action timed out.", ); // The longer timeout actor should succeed - const longerInstance = await client.longerTimeoutActor.connect(); + const longerInstance = client.longerTimeoutActor.connect(); const result = await longerInstance.delayedAction(); expect(result).toBe("delayed response"); }); diff --git a/packages/actor-core/tests/action-types.test.ts b/packages/actor-core/tests/action-types.test.ts index 9f68625eb..686bc5fb7 100644 --- a/packages/actor-core/tests/action-types.test.ts +++ b/packages/actor-core/tests/action-types.test.ts @@ -31,7 +31,7 @@ describe("Action Types", () => { }); const { client } = await setupTest(c, app); - const instance = await client.syncActor.connect(); + const instance = client.syncActor.connect(); // Test increment action let result = await instance.increment(5); @@ -101,7 +101,7 @@ describe("Action Types", () => { }); const { client } = await setupTest(c, app); - const instance = await client.asyncActor.connect(); + const instance = client.asyncActor.connect(); // Test delayed increment const result = await instance.delayedIncrement(5); @@ -155,7 +155,7 @@ describe("Action Types", () => { }); const { client } = await setupTest(c, app); - const instance = await client.promiseActor.connect(); + const instance = client.promiseActor.connect(); // Test resolved promise const resolvedValue = await instance.resolvedPromise(); diff --git a/packages/actor-core/tests/basic.test.ts b/packages/actor-core/tests/basic.test.ts index cae0fdbdf..dd6544e03 100644 --- a/packages/actor-core/tests/basic.test.ts +++ b/packages/actor-core/tests/basic.test.ts @@ -20,6 +20,6 @@ test("basic actor setup", async (c) => { const { client } = await setupTest(c, app); - const counterInstance = await client.counter.connect(); + const counterInstance = client.counter.connect(); await counterInstance.increment(1); }); diff --git a/packages/actor-core/tests/vars.test.ts b/packages/actor-core/tests/vars.test.ts index 6bfe89ea0..d94c26530 100644 --- a/packages/actor-core/tests/vars.test.ts +++ b/packages/actor-core/tests/vars.test.ts @@ -25,7 +25,7 @@ describe("Actor Vars", () => { }); const { client } = await setupTest(c, app); - const instance = await client.varActor.connect(); + const instance = client.varActor.connect(); // Test accessing vars const result = await instance.getVars(); @@ -72,10 +72,10 @@ describe("Actor Vars", () => { const { client } = await setupTest(c, app); // Create two separate instances - const instance1 = await client.nestedVarActor.connect( + const instance1 = client.nestedVarActor.connect( ["instance1"] ); - const instance2 = await client.nestedVarActor.connect( + const instance2 = client.nestedVarActor.connect( ["instance2"] ); @@ -119,7 +119,7 @@ describe("Actor Vars", () => { const { client } = await setupTest(c, app); // Create an instance - const instance = await client.dynamicVarActor.connect(); + const instance = client.dynamicVarActor.connect(); // Test accessing dynamically created vars const vars = await instance.getVars(); @@ -154,10 +154,10 @@ describe("Actor Vars", () => { const { client } = await setupTest(c, app); // Create two separate instances - const instance1 = await client.uniqueVarActor.connect( + const instance1 = client.uniqueVarActor.connect( ["test1"] ); - const instance2 = await client.uniqueVarActor.connect( + const instance2 = client.uniqueVarActor.connect( ["test2"] ); @@ -204,7 +204,7 @@ describe("Actor Vars", () => { const { client } = await setupTest(c, app); // Create an instance - const instance = await client.driverCtxActor.connect(); + const instance = client.driverCtxActor.connect(); // Test accessing driver context through vars const vars = await instance.getVars(); diff --git a/packages/actor-core/tsconfig.json b/packages/actor-core/tsconfig.json index b30a98e6e..e3ebd02f1 100644 --- a/packages/actor-core/tsconfig.json +++ b/packages/actor-core/tsconfig.json @@ -1,6 +1,7 @@ { "extends": "../../tsconfig.base.json", "compilerOptions": { + "types": ["deno", "node"], "paths": { "@/*": ["./src/*"] } diff --git a/packages/actor-core/tsup.config.bundled_xvi1jgwbzx.mjs b/packages/actor-core/tsup.config.bundled_xvi1jgwbzx.mjs new file mode 100644 index 000000000..e01a2a057 --- /dev/null +++ b/packages/actor-core/tsup.config.bundled_xvi1jgwbzx.mjs @@ -0,0 +1,22 @@ +// ../../tsup.base.ts +var tsup_base_default = { + target: "node16", + platform: "node", + format: ["cjs", "esm"], + sourcemap: true, + clean: true, + dts: true, + minify: false, + // IMPORTANT: Splitting is required to fix a bug with ESM (https://github.com/egoist/tsup/issues/992#issuecomment-1763540165) + splitting: true, + skipNodeModulesBundle: true, + publicDir: true +}; + +// tsup.config.ts +import { defineConfig } from "tsup"; +var tsup_config_default = defineConfig(tsup_base_default); +export { + tsup_config_default as default +}; +//# sourceMappingURL=data:application/json;base64,ewogICJ2ZXJzaW9uIjogMywKICAic291cmNlcyI6IFsiLi4vLi4vdHN1cC5iYXNlLnRzIiwgInRzdXAuY29uZmlnLnRzIl0sCiAgInNvdXJjZXNDb250ZW50IjogWyJjb25zdCBfX2luamVjdGVkX2ZpbGVuYW1lX18gPSBcIi9Vc2Vycy9uYXRoYW4vcml2ZXQvYWN0b3ItY29yZS90c3VwLmJhc2UudHNcIjtjb25zdCBfX2luamVjdGVkX2Rpcm5hbWVfXyA9IFwiL1VzZXJzL25hdGhhbi9yaXZldC9hY3Rvci1jb3JlXCI7Y29uc3QgX19pbmplY3RlZF9pbXBvcnRfbWV0YV91cmxfXyA9IFwiZmlsZTovLy9Vc2Vycy9uYXRoYW4vcml2ZXQvYWN0b3ItY29yZS90c3VwLmJhc2UudHNcIjtpbXBvcnQgdHlwZSB7IE9wdGlvbnMgfSBmcm9tIFwidHN1cFwiO1xuXG5leHBvcnQgZGVmYXVsdCB7XG5cdHRhcmdldDogXCJub2RlMTZcIixcblx0cGxhdGZvcm06IFwibm9kZVwiLFxuXHRmb3JtYXQ6IFtcImNqc1wiLCBcImVzbVwiXSxcblx0c291cmNlbWFwOiB0cnVlLFxuXHRjbGVhbjogdHJ1ZSxcblx0ZHRzOiB0cnVlLFxuXHRtaW5pZnk6IGZhbHNlLFxuXHQvLyBJTVBPUlRBTlQ6IFNwbGl0dGluZyBpcyByZXF1aXJlZCB0byBmaXggYSBidWcgd2l0aCBFU00gKGh0dHBzOi8vZ2l0aHViLmNvbS9lZ29pc3QvdHN1cC9pc3N1ZXMvOTkyI2lzc3VlY29tbWVudC0xNzYzNTQwMTY1KVxuXHRzcGxpdHRpbmc6IHRydWUsXG5cdHNraXBOb2RlTW9kdWxlc0J1bmRsZTogdHJ1ZSxcblx0cHVibGljRGlyOiB0cnVlLFxufSBzYXRpc2ZpZXMgT3B0aW9ucztcbiIsICJjb25zdCBfX2luamVjdGVkX2ZpbGVuYW1lX18gPSBcIi9Vc2Vycy9uYXRoYW4vcml2ZXQvYWN0b3ItY29yZS9wYWNrYWdlcy9hY3Rvci1jb3JlL3RzdXAuY29uZmlnLnRzXCI7Y29uc3QgX19pbmplY3RlZF9kaXJuYW1lX18gPSBcIi9Vc2Vycy9uYXRoYW4vcml2ZXQvYWN0b3ItY29yZS9wYWNrYWdlcy9hY3Rvci1jb3JlXCI7Y29uc3QgX19pbmplY3RlZF9pbXBvcnRfbWV0YV91cmxfXyA9IFwiZmlsZTovLy9Vc2Vycy9uYXRoYW4vcml2ZXQvYWN0b3ItY29yZS9wYWNrYWdlcy9hY3Rvci1jb3JlL3RzdXAuY29uZmlnLnRzXCI7aW1wb3J0IGRlZmF1bHRDb25maWcgZnJvbSBcIi4uLy4uL3RzdXAuYmFzZS50c1wiO1xuaW1wb3J0IHsgZGVmaW5lQ29uZmlnIH0gZnJvbSBcInRzdXBcIjtcblxuZXhwb3J0IGRlZmF1bHQgZGVmaW5lQ29uZmlnKGRlZmF1bHRDb25maWcpO1xuIl0sCiAgIm1hcHBpbmdzIjogIjtBQUVBLElBQU8sb0JBQVE7QUFBQSxFQUNkLFFBQVE7QUFBQSxFQUNSLFVBQVU7QUFBQSxFQUNWLFFBQVEsQ0FBQyxPQUFPLEtBQUs7QUFBQSxFQUNyQixXQUFXO0FBQUEsRUFDWCxPQUFPO0FBQUEsRUFDUCxLQUFLO0FBQUEsRUFDTCxRQUFRO0FBQUE7QUFBQSxFQUVSLFdBQVc7QUFBQSxFQUNYLHVCQUF1QjtBQUFBLEVBQ3ZCLFdBQVc7QUFDWjs7O0FDYkEsU0FBUyxvQkFBb0I7QUFFN0IsSUFBTyxzQkFBUSxhQUFhLGlCQUFhOyIsCiAgIm5hbWVzIjogW10KfQo= diff --git a/packages/drivers/file-system/src/manager.ts b/packages/drivers/file-system/src/manager.ts index e9fe12f31..1fad92cc0 100644 --- a/packages/drivers/file-system/src/manager.ts +++ b/packages/drivers/file-system/src/manager.ts @@ -31,7 +31,6 @@ export class FileSystemManagerDriver implements ManagerDriver { } async getForId({ - baseUrl, actorId, }: GetForIdInput): Promise { // Validate the actor exists @@ -44,9 +43,10 @@ export class FileSystemManagerDriver implements ManagerDriver { const state = this.#state.loadActorState(actorId); return { - endpoint: buildActorEndpoint(baseUrl, actorId), + actorId, name: state.name, key: state.key, + meta: undefined, }; } catch (error) { logger().error("failed to read actor state", { actorId, error }); @@ -55,7 +55,6 @@ export class FileSystemManagerDriver implements ManagerDriver { } async getWithKey({ - baseUrl, name, key, }: GetWithKeyInput): Promise { @@ -65,12 +64,12 @@ export class FileSystemManagerDriver implements ManagerDriver { // Search through all actors to find a match const actor = this.#state.findActor((actor) => { if (actor.name !== name) return false; - + // If actor doesn't have a key, it's not a match if (!actor.key || actor.key.length !== key.length) { return false; } - + // Check if all elements in key are in actor.key for (let i = 0; i < key.length; i++) { if (key[i] !== actor.key[i]) { @@ -82,9 +81,10 @@ export class FileSystemManagerDriver implements ManagerDriver { if (actor) { return { - endpoint: buildActorEndpoint(baseUrl, actor.id), + actorId: actor.id, name, key: actor.key, + meta: undefined, }; } @@ -92,22 +92,18 @@ export class FileSystemManagerDriver implements ManagerDriver { } async createActor({ - baseUrl, name, key, }: CreateActorInput): Promise { const actorId = crypto.randomUUID(); await this.#state.createActor(actorId, name, key); - + // Notify inspector about actor changes this.inspector.onActorsChange(this.#state.getAllActors()); - + return { - endpoint: buildActorEndpoint(baseUrl, actorId), + actorId, + meta: undefined, }; } } - -function buildActorEndpoint(baseUrl: string, actorId: string) { - return `${baseUrl}/actors/${actorId}`; -} diff --git a/packages/drivers/memory/src/manager.ts b/packages/drivers/memory/src/manager.ts index 8750617cc..66ba1e7d2 100644 --- a/packages/drivers/memory/src/manager.ts +++ b/packages/drivers/memory/src/manager.ts @@ -29,7 +29,6 @@ export class MemoryManagerDriver implements ManagerDriver { } async getForId({ - baseUrl, actorId, }: GetForIdInput): Promise { // Validate the actor exists @@ -39,14 +38,14 @@ export class MemoryManagerDriver implements ManagerDriver { } return { - endpoint: buildActorEndpoint(baseUrl, actorId), + actorId: actor.id, name: actor.name, key: actor.key, + meta: undefined, }; } async getWithKey({ - baseUrl, name, key, }: GetWithKeyInput): Promise { @@ -56,12 +55,12 @@ export class MemoryManagerDriver implements ManagerDriver { // Search through all actors to find a match const actor = this.#state.findActor((actor) => { if (actor.name !== name) return false; - + // If actor doesn't have a key, it's not a match if (!actor.key || actor.key.length !== key.length) { return false; } - + // Check if all elements in key are in actor.key for (let i = 0; i < key.length; i++) { if (key[i] !== actor.key[i]) { @@ -73,9 +72,10 @@ export class MemoryManagerDriver implements ManagerDriver { if (actor) { return { - endpoint: buildActorEndpoint(baseUrl, actor.id), + actorId: actor.id, name, key: actor.key, + meta: undefined, }; } @@ -83,7 +83,6 @@ export class MemoryManagerDriver implements ManagerDriver { } async createActor({ - baseUrl, name, key, }: CreateActorInput): Promise { @@ -92,12 +91,6 @@ export class MemoryManagerDriver implements ManagerDriver { this.inspector.onActorsChange(this.#state.getAllActors()); - return { - endpoint: buildActorEndpoint(baseUrl, actorId), - }; + return { actorId, meta: undefined }; } } - -function buildActorEndpoint(baseUrl: string, actorId: string) { - return `${baseUrl}/actors/${actorId}`; -} diff --git a/packages/drivers/redis/src/manager.ts b/packages/drivers/redis/src/manager.ts index 927328f8c..7ba8ed060 100644 --- a/packages/drivers/redis/src/manager.ts +++ b/packages/drivers/redis/src/manager.ts @@ -53,7 +53,6 @@ export class RedisManagerDriver implements ManagerDriver { } async getForId({ - baseUrl, actorId, }: GetForIdInput): Promise { // Get metadata from Redis @@ -68,14 +67,14 @@ export class RedisManagerDriver implements ManagerDriver { const { name, key } = metadata; return { - endpoint: buildActorEndpoint(baseUrl, actorId), + actorId, name, key, + meta: undefined, }; } async getWithKey({ - baseUrl, name, key, }: GetWithKeyInput): Promise { @@ -87,11 +86,10 @@ export class RedisManagerDriver implements ManagerDriver { return undefined; } - return this.getForId({ baseUrl, actorId }); + return this.getForId({ actorId }); } async createActor({ - baseUrl, name, key, }: CreateActorInput): Promise { @@ -121,7 +119,8 @@ export class RedisManagerDriver implements ManagerDriver { ]); return { - endpoint: buildActorEndpoint(baseUrl, actorId.toString()), + actorId, + meta: undefined, }; } @@ -170,9 +169,4 @@ export class RedisManagerDriver implements ManagerDriver { .replace(/\\/g, "\\\\") // Escape backslashes first .replace(/:/g, "\\:"); // Escape colons (our delimiter) } -} - -function buildActorEndpoint(baseUrl: string, actorId: string) { - return `${baseUrl}/actors/${actorId}`; -} - +} \ No newline at end of file diff --git a/packages/drivers/redis/tests/driver-tests.test.ts b/packages/drivers/redis/tests/driver-tests.test.ts index 71f018908..94ee79a74 100644 --- a/packages/drivers/redis/tests/driver-tests.test.ts +++ b/packages/drivers/redis/tests/driver-tests.test.ts @@ -120,7 +120,6 @@ test("Valkey container starts and stops properly", async () => { host: "localhost", connectTimeout: 1000, }); - await newRedis.connect(); await newRedis.quit(); throw new Error("Valkey connection should have failed"); } catch (error) { diff --git a/packages/misc/driver-test-suite/src/tests/actor-driver.ts b/packages/misc/driver-test-suite/src/tests/actor-driver.ts index 6aec43f08..66d702b03 100644 --- a/packages/misc/driver-test-suite/src/tests/actor-driver.ts +++ b/packages/misc/driver-test-suite/src/tests/actor-driver.ts @@ -31,12 +31,12 @@ export function runActorDriverTests(driverTestConfig: DriverTestConfig) { ); // Create instance and increment - const counterInstance = await client.counter.connect(); + const counterInstance = client.counter.connect(); const initialCount = await counterInstance.increment(5); expect(initialCount).toBe(5); // Get a fresh reference to the same actor and verify state persisted - const sameInstance = await client.counter.connect(); + const sameInstance = client.counter.connect(); const persistedCount = await sameInstance.increment(3); expect(persistedCount).toBe(8); }); @@ -49,14 +49,14 @@ export function runActorDriverTests(driverTestConfig: DriverTestConfig) { ); // Create actor and set initial state - const counterInstance = await client.counter.connect(); + const counterInstance = client.counter.connect(); await counterInstance.increment(5); // Disconnect the actor await counterInstance.dispose(); // Reconnect to the same actor - const reconnectedInstance = await client.counter.connect(); + const reconnectedInstance = client.counter.connect(); const persistedCount = await reconnectedInstance.increment(0); expect(persistedCount).toBe(5); }); @@ -69,11 +69,11 @@ export function runActorDriverTests(driverTestConfig: DriverTestConfig) { ); // Create first counter with specific key - const counterA = await client.counter.connect(["counter-a"]); + const counterA = client.counter.connect(["counter-a"]); await counterA.increment(5); // Create second counter with different key - const counterB = await client.counter.connect(["counter-b"]); + const counterB = client.counter.connect(["counter-b"]); await counterB.increment(10); // Verify state is separate @@ -93,7 +93,7 @@ export function runActorDriverTests(driverTestConfig: DriverTestConfig) { ); // Create instance - const alarmInstance = await client.scheduled.connect(); + const alarmInstance = client.scheduled.connect(); // Schedule a task to run in 100ms await alarmInstance.scheduleTask(100); diff --git a/packages/misc/driver-test-suite/src/tests/manager-driver.ts b/packages/misc/driver-test-suite/src/tests/manager-driver.ts index dec88f427..1e66ce973 100644 --- a/packages/misc/driver-test-suite/src/tests/manager-driver.ts +++ b/packages/misc/driver-test-suite/src/tests/manager-driver.ts @@ -1,8 +1,9 @@ -import { describe, test, expect } from "vitest"; -import type { DriverTestConfig } from "@/mod"; +import { describe, test, expect, vi } from "vitest"; +import { waitFor, type DriverTestConfig } from "@/mod"; import { setupDriverTest } from "@/utils"; import { resolve } from "node:path"; import type { App as CounterApp } from "../../fixtures/apps/counter"; +import { ConnectionError } from "actor-core/client"; export function runManagerDriverTests(driverTestConfig: DriverTestConfig) { describe("Manager Driver Tests", () => { @@ -15,66 +16,65 @@ export function runManagerDriverTests(driverTestConfig: DriverTestConfig) { ); // Basic connect() with no parameters creates a default actor - const counterA = await client.counter.connect(); + const counterA = client.counter.connect(); await counterA.increment(5); // Get the same actor again to verify state persisted - const counterAAgain = await client.counter.connect(); + const counterAAgain = client.counter.connect(); const count = await counterAAgain.increment(0); expect(count).toBe(5); // Connect with key creates a new actor with specific parameters - const counterB = await client.counter.connect(["counter-b", "testing"]); + const counterB = client.counter.connect(["counter-b", "testing"]); await counterB.increment(10); const countB = await counterB.increment(0); expect(countB).toBe(10); }); - test("create() - always creates a new actor", async (c) => { - const { client } = await setupDriverTest( - c, - driverTestConfig, - resolve(__dirname, "../fixtures/apps/counter.ts"), - ); - - // Create with basic options - const counterA = await client.counter.createAndConnect([ - "explicit-create", - ]); - await counterA.increment(7); - - // Create with the same ID should overwrite or return a conflict - try { - // Should either create a new actor with the same ID (overwriting) - // or throw an error (if the driver prevents ID conflicts) - const counterADuplicate = await client.counter.connect(undefined, { - create: { - key: ["explicit-create"], - }, - }); - await counterADuplicate.increment(1); - - // If we get here, the driver allows ID overwrites - // Verify that state was reset or overwritten - const newCount = await counterADuplicate.increment(0); - expect(newCount).toBe(1); // Not 8 (7+1) if it's a new instance - } catch (error) { - // This is also valid behavior if the driver prevents ID conflicts - // No assertion needed - } - - // Create with full options - const counterB = await client.counter.createAndConnect([ - "full-options", - "testing", - "counter", - ]); - - await counterB.increment(3); - const countB = await counterB.increment(0); - expect(countB).toBe(3); - }); + // TODO: Add back, createAndConnect is not valid logic + //test("create() - always creates a new actor", async (c) => { + // const { client } = await setupDriverTest( + // c, + // driverTestConfig, + // resolve(__dirname, "../fixtures/apps/counter.ts"), + // ); + // + // // Create with basic options + // const counterA = await client.counter.createAndConnect([ + // "explicit-create", + // ]); + // await counterA.increment(7); + // + // // Create with the same ID should overwrite or return a conflict + // try { + // // Should either create a new actor with the same ID (overwriting) + // // or throw an error (if the driver prevents ID conflicts) + // const counterADuplicate = client.counter.createAndConnect([ + // "explicit-create", + // ]); + // await counterADuplicate.increment(1); + // + // // If we get here, the driver allows ID overwrites + // // Verify that state was reset or overwritten + // const newCount = await counterADuplicate.increment(0); + // expect(newCount).toBe(1); // Not 8 (7+1) if it's a new instance + // } catch (error) { + // // This is also valid behavior if the driver prevents ID conflicts + // // No assertion needed + // } + // + // // Create with full options + // const counterB = await client.counter.createAndConnect([ + // "full-options", + // "testing", + // "counter", + // ]); + // + // await counterB.increment(3); + // const countB = await counterB.increment(0); + // expect(countB).toBe(3); + //}); }); describe("Connection Options", () => { @@ -86,31 +86,29 @@ export function runManagerDriverTests(driverTestConfig: DriverTestConfig) { ); // Try to get a nonexistent actor with noCreate - const nonexistentId = `nonexistent-${Date.now()}`; + const nonexistentId = `nonexistent-${crypto.randomUUID()}`; // Should fail when actor doesn't exist - let error: unknown; - try { - await client.counter.connect([nonexistentId], { - noCreate: true, - }); - } catch (err) { - error = err; - } - - // Verify we got an error - expect(error).toBeTruthy(); + let counter1Error: ConnectionError; + const counter1 = client.counter.connect([nonexistentId], { + noCreate: true, + }); + counter1.onError((e) => { + counter1Error = e; + }); + await vi.waitFor( + () => expect(counter1Error).toBeInstanceOf(ConnectionError), + 500, + ); + await counter1.dispose(); // Create the actor - const counter = await client.counter.connect(undefined, { - create: { - key: [nonexistentId], - }, - }); - await counter.increment(3); + const createdCounter = client.counter.connect(nonexistentId); + await createdCounter.increment(3); + await createdCounter.dispose(); // Now noCreate should work since the actor exists - const retrievedCounter = await client.counter.connect([nonexistentId], { + const retrievedCounter = client.counter.connect(nonexistentId, { noCreate: true, }); @@ -129,7 +127,7 @@ export function runManagerDriverTests(driverTestConfig: DriverTestConfig) { // Note: In a real test we'd verify these are received by the actor, // but our simple counter actor doesn't use connection params. // This test just ensures the params are accepted by the driver. - const counter = await client.counter.connect(undefined, { + const counter = client.counter.connect(undefined, { params: { userId: "user-123", authToken: "token-abc", @@ -152,14 +150,14 @@ export function runManagerDriverTests(driverTestConfig: DriverTestConfig) { ); // Create a unique ID for this test - const uniqueId = `test-counter-${Date.now()}`; + const uniqueId = `test-counter-${crypto.randomUUID()}`; // Create actor with specific ID - const counter = await client.counter.connect([uniqueId]); + const counter = client.counter.connect([uniqueId]); await counter.increment(10); // Retrieve the same actor by ID and verify state - const retrievedCounter = await client.counter.connect([uniqueId]); + const retrievedCounter = client.counter.connect([uniqueId]); const count = await retrievedCounter.increment(0); // Get current value expect(count).toBe(10); }); @@ -172,7 +170,7 @@ export function runManagerDriverTests(driverTestConfig: DriverTestConfig) { // ); // // // Create actor with a specific region - // const counter = await client.counter.connect({ + // const counter = client.counter.connect({ // create: { // key: ["metadata-test", "testing"], // region: "test-region", @@ -183,7 +181,7 @@ export function runManagerDriverTests(driverTestConfig: DriverTestConfig) { // await counter.increment(42); // // // Retrieve by ID (since metadata is not used for retrieval) - // const retrievedCounter = await client.counter.connect(["metadata-test"]); + // const retrievedCounter = client.counter.connect(["metadata-test"]); // // // Verify it's the same instance // const count = await retrievedCounter.increment(0); @@ -192,7 +190,7 @@ export function runManagerDriverTests(driverTestConfig: DriverTestConfig) { }); describe("Key Matching", () => { - test("finds actors with equal or superset of specified keys", async (c) => { + test("matches actors only with exactly the same keys", async (c) => { const { client } = await setupDriverTest( c, driverTestConfig, @@ -200,7 +198,7 @@ export function runManagerDriverTests(driverTestConfig: DriverTestConfig) { ); // Create actor with multiple keys - const originalCounter = await client.counter.connect([ + const originalCounter = client.counter.connect([ "counter-match", "test", "us-east", @@ -208,7 +206,7 @@ export function runManagerDriverTests(driverTestConfig: DriverTestConfig) { await originalCounter.increment(10); // Should match with exact same keys - const exactMatchCounter = await client.counter.connect([ + const exactMatchCounter = client.counter.connect([ "counter-match", "test", "us-east", @@ -216,109 +214,96 @@ export function runManagerDriverTests(driverTestConfig: DriverTestConfig) { const exactMatchCount = await exactMatchCounter.increment(0); expect(exactMatchCount).toBe(10); - // Should match with subset of keys - const subsetMatchCounter = await client.counter.connect([ + // Should NOT match with subset of keys - should create new actor + const subsetMatchCounter = client.counter.connect([ "counter-match", "test", ]); const subsetMatchCount = await subsetMatchCounter.increment(0); - expect(subsetMatchCount).toBe(10); + expect(subsetMatchCount).toBe(0); // Should be a new counter with 0 - // Should match with just one key - const singleKeyCounter = await client.counter.connect([ - "counter-match", - ]); + // Should NOT match with just one key - should create new actor + const singleKeyCounter = client.counter.connect(["counter-match"]); const singleKeyCount = await singleKeyCounter.increment(0); - expect(singleKeyCount).toBe(10); + expect(singleKeyCount).toBe(0); // Should be a new counter with 0 }); - test("no keys match actors with keys", async (c) => { + test("string key matches array with single string key", async (c) => { const { client } = await setupDriverTest( c, driverTestConfig, resolve(__dirname, "../fixtures/apps/counter.ts"), ); - // Create counter with keys - const keyedCounter = await client.counter.connect([ - "counter-with-keys", - "special", - ]); - await keyedCounter.increment(15); - - // Should match when searching with no keys - const noKeysCounter = await client.counter.connect(); - const count = await noKeysCounter.increment(0); + // Create actor with string key + const stringKeyCounter = client.counter.connect("string-key-test"); + await stringKeyCounter.increment(7); - // Should have matched existing actor - expect(count).toBe(15); + // Should match with equivalent array key + const arrayKeyCounter = client.counter.connect(["string-key-test"]); + const count = await arrayKeyCounter.increment(0); + expect(count).toBe(7); }); - test("actors with keys match actors with no keys", async (c) => { + test("undefined key matches empty array key and no key", async (c) => { const { client } = await setupDriverTest( c, driverTestConfig, resolve(__dirname, "../fixtures/apps/counter.ts"), ); - // Create a counter with no keys - const noKeysCounter = await client.counter.connect(); - await noKeysCounter.increment(25); + // Create actor with undefined key + const undefinedKeyCounter = client.counter.connect(undefined); + await undefinedKeyCounter.increment(12); - // Get counter with keys - should create a new one - const keyedCounter = await client.counter.connect([ - "new-counter", - "prod", - ]); - const keyedCount = await keyedCounter.increment(0); + // Should match with empty array key + const emptyArrayKeyCounter = client.counter.connect([]); + const emptyArrayCount = await emptyArrayKeyCounter.increment(0); + expect(emptyArrayCount).toBe(12); - // Should be a new counter, not the one created above - expect(keyedCount).toBe(0); + // Should match with no key + const noKeyCounter = client.counter.connect(); + const noKeyCount = await noKeyCounter.increment(0); + expect(noKeyCount).toBe(12); }); - test("specifying different keys for connect and create results in the expected keys", async (c) => { + test("no keys does not match actors with keys", async (c) => { const { client } = await setupDriverTest( c, driverTestConfig, resolve(__dirname, "../fixtures/apps/counter.ts"), ); - // Create a counter with specific create keys - const counter = await client.counter.connect(["key-test", "test"], { - create: { - key: ["key-test", "test", "1.0"], - }, - }); - await counter.increment(5); - - // Should match when searching with original search keys - const foundWithSearchKeys = await client.counter.connect([ - "key-test", - "test", + // Create counter with keys + const keyedCounter = client.counter.connect([ + "counter-with-keys", + "special", ]); - const countWithSearchKeys = await foundWithSearchKeys.increment(0); - expect(countWithSearchKeys).toBe(5); + await keyedCounter.increment(15); - // Should also match when searching with any subset of the create keys - const foundWithExtraKeys = await client.counter.connect([ - "key-test", - "1.0", - ]); - const countWithExtraKeys = await foundWithExtraKeys.increment(0); - expect(countWithExtraKeys).toBe(5); + // Should not match when searching with no keys + const noKeysCounter = client.counter.connect(); + const count = await noKeysCounter.increment(10); + expect(count).toBe(10); + }); - // Create a new counter with just search keys but different create keys - const newCounter = await client.counter.connect(["secondary"], { - create: { - key: ["secondary", "low", "true"], - }, - }); - await newCounter.increment(10); + test("actors with keys match actors with no keys", async (c) => { + const { client } = await setupDriverTest( + c, + driverTestConfig, + resolve(__dirname, "../fixtures/apps/counter.ts"), + ); - // Should not find when searching with keys not in create keys - const notFound = await client.counter.connect(["secondary", "active"]); - const notFoundCount = await notFound.increment(0); - expect(notFoundCount).toBe(0); // New counter + // Create a counter with no keys + const noKeysCounter = client.counter.connect(); + await noKeysCounter.increment(25); + + // Get counter with keys - should create a new one + const keyedCounter = client.counter.connect(["new-counter", "prod"]); + const keyedCount = await keyedCounter.increment(0); + + // Should be a new counter, not the one created above + expect(keyedCount).toBe(0); }); }); @@ -331,9 +316,9 @@ export function runManagerDriverTests(driverTestConfig: DriverTestConfig) { // ); // // // Create multiple instances with different IDs - // const instance1 = await client.counter.connect(["multi-1"]); - // const instance2 = await client.counter.connect(["multi-2"]); - // const instance3 = await client.counter.connect(["multi-3"]); + // const instance1 = client.counter.connect(["multi-1"]); + // const instance2 = client.counter.connect(["multi-2"]); + // const instance3 = client.counter.connect(["multi-3"]); // // // Set different states // await instance1.increment(1); @@ -341,9 +326,9 @@ export function runManagerDriverTests(driverTestConfig: DriverTestConfig) { // await instance3.increment(3); // // // Retrieve all instances again - // const retrieved1 = await client.counter.connect(["multi-1"]); - // const retrieved2 = await client.counter.connect(["multi-2"]); - // const retrieved3 = await client.counter.connect(["multi-3"]); + // const retrieved1 = client.counter.connect(["multi-1"]); + // const retrieved2 = client.counter.connect(["multi-2"]); + // const retrieved3 = client.counter.connect(["multi-3"]); // // // Verify separate state // expect(await retrieved1.increment(0)).toBe(1); @@ -359,13 +344,13 @@ export function runManagerDriverTests(driverTestConfig: DriverTestConfig) { ); // Get default instance (no ID specified) - const defaultCounter = await client.counter.connect(); + const defaultCounter = client.counter.connect(); // Set state await defaultCounter.increment(5); // Get default instance again - const sameDefaultCounter = await client.counter.connect(); + const sameDefaultCounter = client.counter.connect(); // Verify state is maintained const count = await sameDefaultCounter.increment(0); diff --git a/packages/misc/driver-test-suite/vitest.config.ts b/packages/misc/driver-test-suite/vitest.config.ts index 3e3d05481..87909ac20 100644 --- a/packages/misc/driver-test-suite/vitest.config.ts +++ b/packages/misc/driver-test-suite/vitest.config.ts @@ -4,10 +4,13 @@ import { resolve } from "path"; export default defineConfig({ ...defaultConfig, + test: { + ...defaultConfig.test, + maxConcurrency: 1, + }, resolve: { alias: { "@": resolve(__dirname, "./src"), }, }, }); - diff --git a/packages/platforms/cloudflare-workers/src/actor_handler_do.ts b/packages/platforms/cloudflare-workers/src/actor_handler_do.ts index e4bf1dac7..7f465b259 100644 --- a/packages/platforms/cloudflare-workers/src/actor_handler_do.ts +++ b/packages/platforms/cloudflare-workers/src/actor_handler_do.ts @@ -101,8 +101,6 @@ export function createActorDurableObject( if (!config.drivers.actor) { config.drivers.actor = new CloudflareWorkersActorDriver(globalState); } - if (!config.getUpgradeWebSocket) - config.getUpgradeWebSocket = () => upgradeWebSocket; const actorTopology = new PartitionTopologyActor(app.config, config); // Register DO with global state diff --git a/packages/platforms/cloudflare-workers/src/handler.ts b/packages/platforms/cloudflare-workers/src/handler.ts index 417ef4470..a7b0f62f8 100644 --- a/packages/platforms/cloudflare-workers/src/handler.ts +++ b/packages/platforms/cloudflare-workers/src/handler.ts @@ -10,6 +10,7 @@ import { PartitionTopologyManager } from "actor-core/topologies/partition"; import { logger } from "./log"; import { CloudflareWorkersManagerDriver } from "./manager_driver"; import { ActorCoreApp } from "actor-core"; +import { upgradeWebSocket } from "./websocket"; /** Cloudflare Workers env */ export interface Bindings { @@ -17,7 +18,10 @@ export interface Bindings { ACTOR_DO: DurableObjectNamespace; } -export function createHandler(app: ActorCoreApp, inputConfig?: InputConfig): { +export function createHandler( + app: ActorCoreApp, + inputConfig?: InputConfig, +): { handler: ExportedHandler; ActorHandler: DurableObjectConstructor; } { @@ -48,34 +52,61 @@ export function createRouter( if (!driverConfig.drivers.manager) driverConfig.drivers.manager = new CloudflareWorkersManagerDriver(); + // Setup WebSockets + if (!driverConfig.getUpgradeWebSocket) + driverConfig.getUpgradeWebSocket = () => upgradeWebSocket; + // Create Durable Object const ActorHandler = createActorDurableObject(app, driverConfig); driverConfig.topology = driverConfig.topology ?? "partition"; if (driverConfig.topology === "partition") { - const managerTopology = new PartitionTopologyManager(app.config, driverConfig); + const managerTopology = new PartitionTopologyManager( + app.config, + driverConfig, + { + onProxyRequest: async (c, actorRequest, actorId): Promise => { + logger().debug("forwarding request to durable object", { + actorId, + method: actorRequest.method, + url: actorRequest.url, + }); - // Force the router to have access to the Cloudflare bindings - const router = managerTopology.router as unknown as Hono<{ - Bindings: Bindings; - }>; + const id = c.env.ACTOR_DO.idFromString(actorId); + const stub = c.env.ACTOR_DO.get(id); + + return await stub.fetch(actorRequest); + }, + onProxyWebSocket: async (c, path, actorId) => { + logger().debug("forwarding websocket to durable object", { + actorId, + path, + }); - // Forward requests to actor - router.all("/actors/:actorId/:path{.+}", (c) => { - const actorId = c.req.param("actorId"); - const subpath = `/${c.req.param("path")}`; - logger().debug("forwarding request", { actorId, subpath }); + // Validate upgrade + const upgradeHeader = c.req.header("Upgrade"); + if (!upgradeHeader || upgradeHeader !== "websocket") { + return new Response("Expected Upgrade: websocket", { + status: 426, + }); + } - const id = c.env.ACTOR_DO.idFromString(actorId); - const stub = c.env.ACTOR_DO.get(id); + // Update path on URL + const newUrl = new URL(`http://actor${path}`); + const actorRequest = new Request(newUrl, c.req.raw); - // Modify the path to remove the prefix - const url = new URL(c.req.url); - url.pathname = subpath; - const actorRequest = new Request(url.toString(), c.req.raw); + const id = c.env.ACTOR_DO.idFromString(actorId); + const stub = c.env.ACTOR_DO.get(id); - return stub.fetch(actorRequest); - }); + return await stub.fetch(actorRequest); + }, + }, + ); + + // Force the router to have access to the Cloudflare bindings + const router = managerTopology.router as unknown as Hono<{ + Bindings: Bindings; + }>; return { router, ActorHandler }; } else if ( diff --git a/packages/platforms/cloudflare-workers/src/manager_driver.ts b/packages/platforms/cloudflare-workers/src/manager_driver.ts index 0de27610c..8f172c93f 100644 --- a/packages/platforms/cloudflare-workers/src/manager_driver.ts +++ b/packages/platforms/cloudflare-workers/src/manager_driver.ts @@ -37,7 +37,6 @@ const KEYS = { export class CloudflareWorkersManagerDriver implements ManagerDriver { async getForId({ c, - baseUrl, actorId, }: GetForIdInput<{ Bindings: Bindings }>): Promise< GetActorOutput | undefined @@ -54,16 +53,19 @@ export class CloudflareWorkersManagerDriver implements ManagerDriver { return undefined; } + // Generate durable ID from actorId for meta + const durableId = c.env.ACTOR_DO.idFromString(actorId); + return { - endpoint: buildActorEndpoint(baseUrl, actorId), + actorId, name: actorData.name, key: actorData.key, + meta: durableId, }; } async getWithKey({ c, - baseUrl, name, key, }: GetWithKeyInput<{ Bindings: Bindings }>): Promise< @@ -99,15 +101,13 @@ export class CloudflareWorkersManagerDriver implements ManagerDriver { name, key, }); - return this.#buildActorOutput(c, baseUrl, actorId); + return this.#buildActorOutput(c, actorId); } async createActor({ c, - baseUrl, name, key, - region, }: CreateActorInput<{ Bindings: Bindings }>): Promise { if (!c) throw new Error("Missing Hono context"); const log = logger(); @@ -136,16 +136,16 @@ export class CloudflareWorkersManagerDriver implements ManagerDriver { await c.env.ACTOR_KV.put(KEYS.ACTOR.keyIndex(name, key), actorId); return { - endpoint: buildActorEndpoint(baseUrl, actorId), + actorId, name, key, + meta: durableId, }; } // Helper method to build actor output from an ID async #buildActorOutput( c: any, - baseUrl: string, actorId: string, ): Promise { const actorData = (await c.env.ACTOR_KV.get(KEYS.ACTOR.metadata(actorId), { @@ -156,14 +156,14 @@ export class CloudflareWorkersManagerDriver implements ManagerDriver { return undefined; } + // Generate durable ID for meta + const durableId = c.env.ACTOR_DO.idFromString(actorId); + return { - endpoint: buildActorEndpoint(baseUrl, actorId), + actorId, name: actorData.name, key: actorData.key, + meta: durableId, }; } -} - -function buildActorEndpoint(baseUrl: string, actorId: string) { - return `${baseUrl}/actors/${actorId}`; -} +} \ No newline at end of file diff --git a/packages/platforms/cloudflare-workers/tests/driver-tests.test.ts b/packages/platforms/cloudflare-workers/tests/driver-tests.test.ts index 93b718933..33b1e452e 100644 --- a/packages/platforms/cloudflare-workers/tests/driver-tests.test.ts +++ b/packages/platforms/cloudflare-workers/tests/driver-tests.test.ts @@ -129,6 +129,7 @@ export { handler as default, ActorHandler }; wranglerProcess.stdout?.on("data", (data) => { const output = data.toString(); + console.log(`wrangler: ${output}`); if (output.includes(`Ready on http://localhost:${port}`)) { if (!isResolved) { isResolved = true; @@ -139,7 +140,7 @@ export { handler as default, ActorHandler }; }); wranglerProcess.stderr?.on("data", (data) => { - console.error(`wrangler error: ${data}`); + console.error(`wrangler: ${data}`); }); wranglerProcess.on("error", (error) => { diff --git a/packages/platforms/rivet/.actorcore/entrypoint-counter.js b/packages/platforms/rivet/.actorcore/entrypoint-counter.js new file mode 100644 index 000000000..f6ace2d0c --- /dev/null +++ b/packages/platforms/rivet/.actorcore/entrypoint-counter.js @@ -0,0 +1,3 @@ +import { createActorHandler } from "@actor-core/rivet"; +import { app } from "../tmp/actor-core-test-fa4426da-4d36-40b3-aa42-7e8f444d32f6/app.ts"; +export default createActorHandler({ app }); \ No newline at end of file diff --git a/packages/platforms/rivet/package.json b/packages/platforms/rivet/package.json index d59a54eb8..a34d561ff 100644 --- a/packages/platforms/rivet/package.json +++ b/packages/platforms/rivet/package.json @@ -33,6 +33,7 @@ "@actor-core/driver-test-suite": "workspace:*", "@rivet-gg/actor-core": "^25.1.0", "@types/deno": "^2.0.0", + "@types/invariant": "^2", "@types/node": "^22.13.1", "actor-core": "workspace:*", "tsup": "^8.4.0", @@ -41,6 +42,7 @@ }, "dependencies": { "hono": "^4.7.0", + "invariant": "^2.2.4", "zod": "^3.24.2" } } diff --git a/packages/platforms/rivet/src/actor_driver.ts b/packages/platforms/rivet/src/actor_driver.ts index 19d02ae55..71409b6c0 100644 --- a/packages/platforms/rivet/src/actor_driver.ts +++ b/packages/platforms/rivet/src/actor_driver.ts @@ -17,8 +17,12 @@ export class RivetActorDriver implements ActorDriver { } async readPersistedData(_actorId: string): Promise { - // Use "state" as the key for persisted data - return await this.#ctx.kv.get(["actor-core", "data"]); + let data = await this.#ctx.kv.get(["actor-core", "data"]); + + // HACK: Modify to be undefined if null. This will be fixed in Actors v2. + if (data === null) data = undefined; + + return data; } async writePersistedData(_actorId: string, data: unknown): Promise { diff --git a/packages/platforms/rivet/src/actor_handler.ts b/packages/platforms/rivet/src/actor_handler.ts index 5ebba7fec..157cace2c 100644 --- a/packages/platforms/rivet/src/actor_handler.ts +++ b/packages/platforms/rivet/src/actor_handler.ts @@ -1,13 +1,14 @@ import { setupLogging } from "actor-core/log"; import type { ActorContext } from "@rivet-gg/actor-core"; -import type { ActorKey } from "actor-core"; import { upgradeWebSocket } from "hono/deno"; import { logger } from "./log"; import type { RivetHandler } from "./util"; +import { deserializeKeyFromTag } from "./util"; import { PartitionTopologyActor } from "actor-core/topologies/partition"; import { ConfigSchema, type InputConfig } from "./config"; import { RivetActorDriver } from "./actor_driver"; import { rivetRequest } from "./rivet_client"; +import invariant from "invariant"; export function createActorHandler(inputConfig: InputConfig): RivetHandler { const driverConfig = ConfigSchema.parse(inputConfig); @@ -135,17 +136,6 @@ export function createActorHandler(inputConfig: InputConfig): RivetHandler { // Helper function to extract key array from Rivet's tag format function extractKeyFromRivetTags(tags: Record): string[] { - const key: string[] = []; - - // Extract key values from tags using the numerical suffix pattern - for (let i = 0; ; i++) { - const tagKey = `key${i}`; - if (tagKey in tags) { - key.push(tags[tagKey]); - } else { - break; - } - } - - return key; -} \ No newline at end of file + invariant(typeof tags.key === "string", "key tag does not exist"); + return deserializeKeyFromTag(tags.key); +} diff --git a/packages/platforms/rivet/src/manager_driver.ts b/packages/platforms/rivet/src/manager_driver.ts index 76efabb96..8bd52aba4 100644 --- a/packages/platforms/rivet/src/manager_driver.ts +++ b/packages/platforms/rivet/src/manager_driver.ts @@ -1,9 +1,24 @@ +import { assertUnreachable } from "actor-core/utils"; +import type { + ManagerDriver, + GetForIdInput, + GetWithKeyInput, + CreateActorInput, + GetActorOutput, +} from "actor-core/driver-helpers"; +import { logger } from "./log"; +import { type RivetClientConfig, rivetRequest } from "./rivet_client"; +import { serializeKeyForTag, deserializeKeyFromTag } from "./util"; export interface ActorState { key: string[]; destroyedAt?: number; } +export interface GetActorMeta { + endpoint: string; +} + export class RivetManagerDriver implements ManagerDriver { #clientConfig: RivetClientConfig; @@ -22,8 +37,8 @@ export class RivetManagerDriver implements ManagerDriver { `/actors/${encodeURIComponent(actorId)}`, ); - // Check if actor exists, is public, and not destroyed - if ((res.actor.tags as Record).access !== "public" || res.actor.destroyedAt) { + // Check if actor exists and not destroyed + if (res.actor.destroyedAt) { return undefined; } @@ -31,11 +46,20 @@ export class RivetManagerDriver implements ManagerDriver { if (!("name" in res.actor.tags)) { throw new Error(`Actor ${res.actor.id} missing 'name' in tags.`); } + if (res.actor.tags.role !== "actor") { + throw new Error(`Actor ${res.actor.id} does not have an actor role.`); + } + if (res.actor.tags.framework !== "actor-core") { + throw new Error(`Actor ${res.actor.id} is not an ActorCore actor.`); + } return { - endpoint: buildActorEndpoint(res.actor), + actorId: res.actor.id, name: res.actor.tags.name, key: this.#extractKeyFromRivetTags(res.actor.tags), + meta: { + endpoint: buildActorEndpoint(res.actor), + } satisfies GetActorMeta, }; } catch (error) { // Handle not found or other errors @@ -49,7 +73,7 @@ export class RivetManagerDriver implements ManagerDriver { }: GetWithKeyInput): Promise { // Convert key array to Rivet's tag format const rivetTags = this.#convertKeyToRivetTags(name, key); - + // Query actors with matching tags const { actors } = await rivetRequest( this.#clientConfig, @@ -59,11 +83,6 @@ export class RivetManagerDriver implements ManagerDriver { // Filter actors to ensure they're valid const validActors = actors.filter((a: RivetActor) => { - // Verify actor is public - if ((a.tags as Record).access !== "public") { - return false; - } - // Verify all ports have hostname and port for (const portName in a.network.ports) { const port = a.network.ports[portName]; @@ -77,9 +96,10 @@ export class RivetManagerDriver implements ManagerDriver { } // For consistent results, sort by ID if multiple actors match - const actor = validActors.length > 1 - ? validActors.sort((a, b) => a.id.localeCompare(b.id))[0] - : validActors[0]; + const actor = + validActors.length > 1 + ? validActors.sort((a, b) => a.id.localeCompare(b.id))[0] + : validActors[0]; // Ensure actor has required tags if (!("name" in actor.tags)) { @@ -87,9 +107,12 @@ export class RivetManagerDriver implements ManagerDriver { } return { - endpoint: buildActorEndpoint(actor), + actorId: actor.id, name: actor.tags.name, key: this.#extractKeyFromRivetTags(actor.tags), + meta: { + endpoint: buildActorEndpoint(actor), + } satisfies GetActorMeta, }; } @@ -98,21 +121,23 @@ export class RivetManagerDriver implements ManagerDriver { key, region, }: CreateActorInput): Promise { - // Find a matching build that's public and current - const build = await this.#getBuildWithTags({ - name, - current: "true", - access: "public", - }); - - if (!build) { - throw new Error("Build not found with tags or is private"); + // Create the actor request + let actorLogLevel: string | undefined = undefined; + if (typeof Deno !== "undefined") { + actorLogLevel = Deno.env.get("_ACTOR_LOG_LEVEL"); + } else if (typeof process !== "undefined") { + // Do this after Deno since `process` is sometimes polyfilled + actorLogLevel = process.env._ACTOR_LOG_LEVEL; } - // Create the actor request const createRequest = { tags: this.#convertKeyToRivetTags(name, key), - build: build.id, + build_tags: { + name, + role: "actor", + framework: "actor-core", + current: "true", + }, region, network: { ports: { @@ -122,22 +147,33 @@ export class RivetManagerDriver implements ManagerDriver { }, }, }, + runtime: { + environment: actorLogLevel + ? { + _LOG_LEVEL: actorLogLevel, + } + : {}, + }, + lifecycle: { + durable: true, + }, }; logger().info("creating actor", { ...createRequest }); - + // Create the actor - const { actor } = await rivetRequest( - this.#clientConfig, - "POST", - "/actors", - createRequest, - ); + const { actor } = await rivetRequest< + typeof createRequest, + { actor: RivetActor } + >(this.#clientConfig, "POST", "/actors", createRequest); return { - endpoint: buildActorEndpoint(actor), + actorId: actor.id, name, key: this.#extractKeyFromRivetTags(actor.tags), + meta: { + endpoint: buildActorEndpoint(actor), + } satisfies GetActorMeta, }; } @@ -145,11 +181,12 @@ export class RivetManagerDriver implements ManagerDriver { #convertKeyToRivetTags(name: string, key: string[]): Record { return { name, - access: "public", key: serializeKeyForTag(key), + role: "actor", + framework: "actor-core", }; } - + // Helper method to extract key array from Rivet's tag-based format #extractKeyFromRivetTags(tags: Record): string[] { return deserializeKeyFromTag(tags.key); @@ -165,17 +202,14 @@ export class RivetManagerDriver implements ManagerDriver { `/builds?tags_json=${encodeURIComponent(JSON.stringify(buildTags))}`, ); - // Filter to public builds - const publicBuilds = builds.filter(b => b.tags.access === "public"); - - if (publicBuilds.length === 0) { + if (builds.length === 0) { return undefined; } - + // For consistent results, sort by ID if multiple builds match - return publicBuilds.length > 1 - ? publicBuilds.sort((a, b) => a.id.localeCompare(b.id))[0] - : publicBuilds[0]; + return builds.length > 1 + ? builds.sort((a, b) => a.id.localeCompare(b.id))[0] + : builds[0]; } } @@ -183,7 +217,7 @@ function buildActorEndpoint(actor: RivetActor): string { // Fetch port const httpPort = actor.network.ports.http; if (!httpPort) throw new Error("missing http port"); - const hostname = httpPort.hostname; + let hostname = httpPort.hostname; if (!hostname) throw new Error("missing hostname"); const port = httpPort.port; if (!port) throw new Error("missing port"); @@ -206,23 +240,13 @@ function buildActorEndpoint(actor: RivetActor): string { const path = httpPort.path ?? ""; + // HACK: Fix hostname inside of Docker Compose + if (hostname === "127.0.0.1") hostname = "rivet-guard"; + return `${isTls ? "https" : "http"}://${hostname}:${port}${path}`; } -import { assertUnreachable } from "actor-core/utils"; -import type { ActorKey } from "actor-core"; -import { - ManagerDriver, - GetForIdInput, - GetWithKeyInput, - CreateActorInput, - GetActorOutput, -} from "actor-core/driver-helpers"; -import { logger } from "./log"; -import { type RivetClientConfig, rivetRequest } from "./rivet_client"; -import { serializeKeyForTag, deserializeKeyFromTag } from "./util"; - // biome-ignore lint/suspicious/noExplicitAny: will add api types later type RivetActor = any; // biome-ignore lint/suspicious/noExplicitAny: will add api types later -type RivetBuild = any; \ No newline at end of file +type RivetBuild = any; diff --git a/packages/platforms/rivet/src/manager_handler.ts b/packages/platforms/rivet/src/manager_handler.ts index 984c3912e..0243c33b9 100644 --- a/packages/platforms/rivet/src/manager_handler.ts +++ b/packages/platforms/rivet/src/manager_handler.ts @@ -1,11 +1,15 @@ import { setupLogging } from "actor-core/log"; import type { ActorContext } from "@rivet-gg/actor-core"; import { logger } from "./log"; -import { RivetManagerDriver } from "./manager_driver"; +import { GetActorMeta, RivetManagerDriver } from "./manager_driver"; import type { RivetClientConfig } from "./rivet_client"; import type { RivetHandler } from "./util"; +import { createWebSocketProxy } from "./ws_proxy"; import { PartitionTopologyManager } from "actor-core/topologies/partition"; import { type InputConfig, ConfigSchema } from "./config"; +import { proxy } from "hono/proxy"; +import invariant from "invariant"; +import { upgradeWebSocket } from "hono/deno"; export function createManagerHandler(inputConfig: InputConfig): RivetHandler { const driverConfig = ConfigSchema.parse(inputConfig); @@ -69,11 +73,46 @@ export function createManagerHandler(inputConfig: InputConfig): RivetHandler { driverConfig.drivers.manager = new RivetManagerDriver(clientConfig); } + // Setup WebSocket upgrader + if (!driverConfig.getUpgradeWebSocket) { + driverConfig.getUpgradeWebSocket = () => upgradeWebSocket; + } + // Create manager topology driverConfig.topology = driverConfig.topology ?? "partition"; const managerTopology = new PartitionTopologyManager( driverConfig.app.config, driverConfig, + { + onProxyRequest: async (c, actorRequest, _actorId, metaRaw) => { + invariant(metaRaw, "meta not provided"); + const meta = metaRaw as GetActorMeta; + + const parsedRequestUrl = new URL(actorRequest.url); + const actorUrl = `${meta.endpoint}${parsedRequestUrl.pathname}${parsedRequestUrl.search}`; + + logger().debug("proxying request to rivet actor", { + method: actorRequest.method, + url: actorUrl, + }); + + const proxyRequest = new Request(actorUrl, actorRequest); + return await proxy(proxyRequest); + }, + onProxyWebSocket: async (c, path, actorId, metaRaw) => { + invariant(metaRaw, "meta not provided"); + const meta = metaRaw as GetActorMeta; + + const actorUrl = `${meta.endpoint}${path}`; + + logger().debug("proxying websocket to rivet actor", { + url: actorUrl, + }); + + // TODO: fix as any + return createWebSocketProxy(c, actorUrl) as any; + }, + }, ); const app = managerTopology.router; diff --git a/packages/platforms/rivet/src/util.ts b/packages/platforms/rivet/src/util.ts index 6f3cd08da..5dda225c6 100644 --- a/packages/platforms/rivet/src/util.ts +++ b/packages/platforms/rivet/src/util.ts @@ -1,4 +1,5 @@ import type { ActorContext } from "@rivet-gg/actor-core"; +import invariant from "invariant"; export interface RivetHandler { start(ctx: ActorContext): Promise; @@ -10,7 +11,7 @@ export const KEY_SEPARATOR = ","; /** * Serializes an array of key strings into a single string for storage in a Rivet tag - * + * * @param key Array of key strings to serialize * @returns A single string containing the serialized key */ @@ -19,69 +20,64 @@ export function serializeKeyForTag(key: string[]): string { if (key.length === 0) { return EMPTY_KEY; } - + // Escape each key part to handle the separator and the empty key marker - const escapedParts = key.map(part => { + const escapedParts = key.map((part) => { // First check if it matches our empty key marker if (part === EMPTY_KEY) { return `\\${EMPTY_KEY}`; } - + // Escape backslashes first, then commas let escaped = part.replace(/\\/g, "\\\\"); escaped = escaped.replace(/,/g, "\\,"); return escaped; }); - + return escapedParts.join(KEY_SEPARATOR); } /** * Deserializes a key string from a Rivet tag back into an array of key strings - * + * * @param keyString The serialized key string from a tag * @returns Array of key strings */ export function deserializeKeyFromTag(keyString: string): string[] { - // Handle empty values - if (!keyString) { - return []; - } - // Check for special empty key marker if (keyString === EMPTY_KEY) { return []; } - + // Split by unescaped commas and unescape the escaped characters const parts: string[] = []; - let currentPart = ''; + let currentPart = ""; let escaping = false; - + for (let i = 0; i < keyString.length; i++) { const char = keyString[i]; - + if (escaping) { // This is an escaped character, add it directly currentPart += char; escaping = false; - } else if (char === '\\') { + } else if (char === "\\") { // Start of an escape sequence escaping = true; } else if (char === KEY_SEPARATOR) { // This is a separator parts.push(currentPart); - currentPart = ''; + currentPart = ""; } else { // Regular character currentPart += char; } } - + // Add the last part if it exists if (currentPart || parts.length > 0) { parts.push(currentPart); } - + return parts; } diff --git a/packages/platforms/rivet/src/ws_proxy.ts b/packages/platforms/rivet/src/ws_proxy.ts new file mode 100644 index 000000000..8bdbeb157 --- /dev/null +++ b/packages/platforms/rivet/src/ws_proxy.ts @@ -0,0 +1,119 @@ +import { upgradeWebSocket } from "hono/deno"; +import { WSContext } from "hono/ws"; +import { Context } from "hono"; +import { logger } from "./log"; +import invariant from "invariant"; + +/** + * Creates a WebSocket proxy to forward connections to a target endpoint + * + * @param c Hono context + * @param targetUrl Target WebSocket URL to proxy to + * @returns Response with upgraded WebSocket + */ +export function createWebSocketProxy(c: Context, targetUrl: string) { + return upgradeWebSocket((c) => { + let targetWs: WebSocket | undefined = undefined; + const messageQueue: any[] = []; + + return { + onOpen: (_evt: any, wsContext: WSContext) => { + // Create target WebSocket connection + targetWs = new WebSocket(targetUrl); + + // Set up target websocket handlers + targetWs.onopen = () => { + invariant(targetWs, "targetWs does not exist"); + + // Process any queued messages once connected + if (messageQueue.length > 0) { + for (const data of messageQueue) { + targetWs.send(data); + } + // Clear the queue after sending + messageQueue.length = 0; + } + }; + + targetWs.onmessage = (event) => { + wsContext.send(event.data); + }; + + targetWs.onclose = (event) => { + logger().debug("target websocket closed", { + code: event.code, + reason: event.reason, + }); + + if (wsContext.readyState === WebSocket.OPEN) { + // Forward the close code and reason from target to client + wsContext.close(event.code, event.reason); + } + }; + + targetWs.onerror = (event) => { + logger().warn("target websocket error"); + + if (wsContext.readyState === WebSocket.OPEN) { + // Use standard WebSocket error code: 1006 - Abnormal Closure + // The connection was closed abnormally, e.g., without sending or receiving a Close control frame + wsContext.close(1006, "Error in target connection"); + } + }; + }, + + // Handle messages from client to target + onMessage: (evt: { data: any }, wsContext: WSContext) => { + invariant(targetWs, "targetWs not defined"); + + // If the WebSocket is OPEN, send immediately + if (targetWs.readyState === WebSocket.OPEN) { + targetWs.send(evt.data); + } + // If the WebSocket is CONNECTING, queue the message + else if (targetWs.readyState === WebSocket.CONNECTING) { + messageQueue.push(evt.data); + } + // Otherwise (CLOSING or CLOSED), ignore the message + }, + + // Handle client WebSocket close + onClose: (evt: CloseEvent, wsContext: WSContext) => { + invariant(targetWs, "targetWs not defined"); + + logger().debug("client websocket closed", { + code: evt.code, + reason: evt.reason, + }); + + // Close target if it's either CONNECTING or OPEN + // + // We're only allowed to send code 1000 from the client + if ( + targetWs.readyState === WebSocket.CONNECTING || + targetWs.readyState === WebSocket.OPEN + ) { + // We can only send code 1000 from the client + targetWs.close(1000, evt.reason || "Client closed connection"); + } + }, + + // Handle client WebSocket errors + onError: (_evt: Event, wsContext: WSContext) => { + invariant(targetWs, "targetWs not defined"); + + logger().warn("websocket proxy received error from client"); + + // Close target with specific error code for proxy errors + // + // We're only allowed to send code 1000 from the client + if ( + targetWs.readyState === WebSocket.CONNECTING || + targetWs.readyState === WebSocket.OPEN + ) { + targetWs.close(1000, "Error in client connection"); + } + }, + }; + })(c, async () => {}); +} diff --git a/packages/platforms/rivet/tests/deployment.test.ts b/packages/platforms/rivet/tests/deployment.test.ts new file mode 100644 index 000000000..a49c84b76 --- /dev/null +++ b/packages/platforms/rivet/tests/deployment.test.ts @@ -0,0 +1,72 @@ +import { describe, test, expect, beforeAll, afterAll } from "vitest"; +import fs from "node:fs/promises"; +import path from "node:path"; +import { fileURLToPath } from "node:url"; +import { deployToRivet } from "./rivet-deploy"; + +const __dirname = path.dirname(fileURLToPath(import.meta.url)); + +// Simple counter actor definition to deploy +const COUNTER_ACTOR = ` +import { actor, setup } from "actor-core"; + +const counter = actor({ + state: { count: 0 }, + actions: { + increment: (c, amount) => { + c.state.count += amount; + c.broadcast("newCount", c.state.count); + return c.state.count; + }, + getCount: (c) => { + return c.state.count; + }, + }, +}); + +export const app = setup({ + actors: { counter }, +}); + +export type App = typeof app; +`; + +describe.skip("Rivet deployment tests", () => { + let tmpDir: string; + let cleanup: () => Promise; + + // Set up test environment before all tests + beforeAll(async () => { + // Create a temporary path for the counter actor + const tempFilePath = path.join( + __dirname, + "../../../..", + "target", + "temp-counter-app.ts", + ); + + // Ensure target directory exists + await fs.mkdir(path.dirname(tempFilePath), { recursive: true }); + + // Write the counter actor file + await fs.writeFile(tempFilePath, COUNTER_ACTOR); + + // Run the deployment + const result = await deployToRivet(tempFilePath); + tmpDir = result.tmpDir; + cleanup = result.cleanup; + }); + + // Clean up after all tests + afterAll(async () => { + if (cleanup) { + await cleanup(); + } + }); + + test("deploys counter actor to Rivet and retrieves endpoint", async () => { + // This test just verifies that the deployment was successful + // The actual deployment work is done in the beforeAll hook + expect(tmpDir).toBeTruthy(); + }, 180000); // Increased timeout to 3 minutes for the full deployment +}); diff --git a/packages/platforms/rivet/tests/driver-tests.test.ts b/packages/platforms/rivet/tests/driver-tests.test.ts new file mode 100644 index 000000000..f35aef393 --- /dev/null +++ b/packages/platforms/rivet/tests/driver-tests.test.ts @@ -0,0 +1,73 @@ +import { runDriverTests } from "@actor-core/driver-test-suite"; +import { deployToRivet, RIVET_CLIENT_CONFIG } from "./rivet-deploy"; +import { type RivetClientConfig, rivetRequest } from "../src/rivet_client"; +import invariant from "invariant"; + +let alreadyDeployedManager = false; +const alreadyDeployedApps = new Set(); +let managerEndpoint: string | undefined = undefined; + +const driverTestConfig = { + useRealTimers: true, + HACK_skipCleanupNet: true, + async start(appPath: string) { + console.log("Starting test", { + alreadyDeployedManager, + alreadyDeployedApps, + managerEndpoint, + }); + + // Cleanup actors from previous tests + await deleteAllActors(RIVET_CLIENT_CONFIG, !alreadyDeployedManager); + + if (!alreadyDeployedApps.has(appPath)) { + console.log(`Starting Rivet driver tests with app: ${appPath}`); + + // Deploy to Rivet + const result = await deployToRivet(appPath, !alreadyDeployedManager); + console.log( + `Deployed to Rivet at ${result.endpoint} (manager: ${!alreadyDeployedManager})`, + ); + + // Save as deployed + managerEndpoint = result.endpoint; + alreadyDeployedApps.add(appPath); + alreadyDeployedManager = true; + } else { + console.log(`Already deployed: ${appPath}`); + } + + invariant(managerEndpoint, "missing manager endpoint"); + return { + endpoint: managerEndpoint, + async cleanup() { + await deleteAllActors(RIVET_CLIENT_CONFIG, false); + }, + }; + }, +}; + +async function deleteAllActors( + clientConfig: RivetClientConfig, + deleteManager: boolean, +) { + console.log("Listing actors to delete"); + const { actors } = await rivetRequest< + void, + { actors: { id: string; tags: Record }[] } + >(clientConfig, "GET", "/actors"); + + for (const actor of actors) { + if (!deleteManager && actor.tags.name === "manager") continue; + + console.log(`Deleting actor ${actor.id} (${JSON.stringify(actor.tags)})`); + await rivetRequest( + clientConfig, + "DELETE", + `/actors/${actor.id}`, + ); + } +} + +// Run the driver tests with our config +runDriverTests(driverTestConfig); diff --git a/packages/platforms/rivet/tests/rivet-deploy.ts b/packages/platforms/rivet/tests/rivet-deploy.ts new file mode 100644 index 000000000..77f7e62b4 --- /dev/null +++ b/packages/platforms/rivet/tests/rivet-deploy.ts @@ -0,0 +1,201 @@ +import fs from "node:fs/promises"; +import path from "node:path"; +import os from "node:os"; +import { spawn, exec } from "node:child_process"; +import crypto from "node:crypto"; +import { promisify } from "node:util"; +import invariant from "invariant"; +import type { RivetClientConfig } from "../src/rivet_client"; + +const execPromise = promisify(exec); +//const RIVET_API_ENDPOINT = "https://api.rivet.gg"; +const RIVET_API_ENDPOINT = "http://localhost:8080"; +const ENV = "default"; + +const rivetCloudToken = process.env.RIVET_CLOUD_TOKEN; +invariant(rivetCloudToken, "missing RIVET_CLOUD_TOKEN"); +export const RIVET_CLIENT_CONFIG: RivetClientConfig = { + endpoint: RIVET_API_ENDPOINT, + token: rivetCloudToken, +}; + +/** + * Deploy an app to Rivet and return the endpoint + */ +export async function deployToRivet(appPath: string, deployManager: boolean) { + console.log("=== START deployToRivet ==="); + console.log(`Deploying app from path: ${appPath}`); + + // Create a temporary directory for the test + const uuid = crypto.randomUUID(); + const appName = `actor-core-test-${uuid}`; + const tmpDir = path.join(os.tmpdir(), appName); + console.log(`Creating temp directory: ${tmpDir}`); + await fs.mkdir(tmpDir, { recursive: true }); + + // Create package.json with workspace dependencies + const packageJson = { + name: "actor-core-test", + private: true, + version: "1.0.0", + type: "module", + scripts: { + deploy: "actor-core deploy rivet app.ts --env prod", + }, + dependencies: { + "@actor-core/rivet": "workspace:*", + "@actor-core/cli": "workspace:*", + "actor-core": "workspace:*", + }, + packageManager: + "yarn@4.7.0+sha512.5a0afa1d4c1d844b3447ee3319633797bcd6385d9a44be07993ae52ff4facabccafb4af5dcd1c2f9a94ac113e5e9ff56f6130431905884414229e284e37bb7c9", + }; + console.log("Writing package.json"); + await fs.writeFile( + path.join(tmpDir, "package.json"), + JSON.stringify(packageJson, null, 2), + ); + + // Disable PnP + const yarnPnp = "nodeLinker: node-modules"; + console.log("Configuring Yarn nodeLinker"); + await fs.writeFile(path.join(tmpDir, ".yarnrc.yml"), yarnPnp); + + // Get the current workspace root path and link the workspace + const workspaceRoot = path.resolve(__dirname, "../../../.."); + console.log(`Linking workspace from: ${workspaceRoot}`); + + try { + console.log("Running yarn link command..."); + const linkOutput = await execPromise(`yarn link -A ${workspaceRoot}`, { + cwd: tmpDir, + }); + console.log("Yarn link output:", linkOutput.stdout); + } catch (error) { + console.error("Error linking workspace:", error); + throw error; + } + + // Install deps + console.log("Installing dependencies..."); + try { + const installOutput = await execPromise("yarn install", { cwd: tmpDir }); + console.log("Install output:", installOutput.stdout); + } catch (error) { + console.error("Error installing dependencies:", error); + throw error; + } + + // Create app.ts file based on the app path + const appTsContent = `export { app } from "${appPath.replace(/\.ts$/, "")}"`; + console.log(`Creating app.ts with content: ${appTsContent}`); + await fs.writeFile(path.join(tmpDir, "app.ts"), appTsContent); + + // Build and deploy to Rivet using actor-core CLI + console.log("Building and deploying to Rivet..."); + + if (!process.env._RIVET_SKIP_DEPLOY) { + // Deploy using the actor-core CLI + console.log("Spawning @actor-core/cli deploy command..."); + const deployProcess = spawn( + "npx", + [ + "@actor-core/cli", + "deploy", + "rivet", + "app.ts", + "--env", + ENV, + ...(deployManager ? [] : ["--skip-manager"]), + ], + { + cwd: tmpDir, + env: { + ...process.env, + RIVET_ENDPOINT: RIVET_API_ENDPOINT, + RIVET_CLOUD_TOKEN: rivetCloudToken, + _RIVET_MANAGER_LOG_LEVEL: "DEBUG", + _RIVET_ACTOR_LOG_LEVEL: "DEBUG", + //CI: "1", + }, + stdio: "inherit", // Stream output directly to console + }, + ); + + console.log("Waiting for deploy process to complete..."); + await new Promise((resolve, reject) => { + deployProcess.on("exit", (code) => { + console.log(`Deploy process exited with code: ${code}`); + if (code === 0) { + resolve(undefined); + } else { + reject(new Error(`Deploy process exited with code ${code}`)); + } + }); + deployProcess.on("error", (err) => { + console.error("Deploy process error:", err); + reject(err); + }); + }); + console.log("Deploy process completed successfully"); + } + + // Get the endpoint URL + console.log("Getting Rivet endpoint..."); + + // Get the endpoint using the CLI endpoint command + console.log("Spawning @actor-core/cli endpoint command..."); + const endpointProcess = spawn( + "npx", + ["@actor-core/cli", "endpoint", "rivet", "--env", ENV, "--plain"], + { + cwd: tmpDir, + env: { + ...process.env, + RIVET_ENDPOINT: RIVET_API_ENDPOINT, + RIVET_CLOUD_TOKEN: rivetCloudToken, + CI: "1", + }, + stdio: ["inherit", "pipe", "inherit"], // Capture stdout + }, + ); + + // Capture the endpoint + let endpointOutput = ""; + endpointProcess.stdout.on("data", (data) => { + const output = data.toString(); + console.log(`Endpoint output: ${output}`); + endpointOutput += output; + }); + + // Wait for endpoint command to complete + console.log("Waiting for endpoint process to complete..."); + await new Promise((resolve, reject) => { + endpointProcess.on("exit", (code) => { + console.log(`Endpoint process exited with code: ${code}`); + if (code === 0) { + resolve(undefined); + } else { + reject(new Error(`Endpoint command failed with code ${code}`)); + } + }); + endpointProcess.on("error", (err) => { + console.error("Endpoint process error:", err); + reject(err); + }); + }); + + invariant(endpointOutput, "endpoint command returned empty output"); + console.log(`Raw endpoint output: ${endpointOutput}`); + + // Look for something that looks like a URL in the string + const lines = endpointOutput.trim().split("\n"); + const endpoint = lines[lines.length - 1]; + invariant(endpoint, "endpoint not found"); + + console.log("=== END deployToRivet ==="); + + return { + endpoint, + }; +} diff --git a/packages/platforms/rivet/turbo.json b/packages/platforms/rivet/turbo.json index 95960709b..da2495302 100644 --- a/packages/platforms/rivet/turbo.json +++ b/packages/platforms/rivet/turbo.json @@ -1,4 +1,10 @@ { "$schema": "https://turbo.build/schema.json", - "extends": ["//"] + "extends": ["//"], + "tasks": { + "test": { + "dependsOn": ["^build", "check-types", "build", "@actor-core/cli#build"], + "env": ["RIVET_API_ENDPOINT", "RIVET_CLOUD_TOKEN", "_RIVET_SKIP_DEPLOY"] + } + } } diff --git a/packages/platforms/rivet/vitest.config.ts b/packages/platforms/rivet/vitest.config.ts index c7da6b38e..9dbd4e8c9 100644 --- a/packages/platforms/rivet/vitest.config.ts +++ b/packages/platforms/rivet/vitest.config.ts @@ -1,8 +1,9 @@ -import { defineConfig } from 'vitest/config'; +import { defineConfig } from "vitest/config"; export default defineConfig({ - test: { - globals: true, - environment: 'node', - }, -}); \ No newline at end of file + test: { + globals: true, + environment: "node", + testTimeout: 60_000, + }, +}); diff --git a/turbo.json b/turbo.json index 9e9446a6c..7d236df3c 100644 --- a/turbo.json +++ b/turbo.json @@ -1,5 +1,4 @@ -{ - "$schema": "https://turbo.build/schema.json", +{ "$schema": "https://turbo.build/schema.json", "tasks": { "//#fmt": { "cache": false diff --git a/vitest.base.ts b/vitest.base.ts index 5c50b5c0a..2fa6f5370 100644 --- a/vitest.base.ts +++ b/vitest.base.ts @@ -6,10 +6,10 @@ export default { sequence: { concurrent: true, }, - // Increase timeout - testTimeout: 5_000, + // Increase timeout for proxy tests + testTimeout: 15_000, env: { - // Enable loggin + // Enable logging _LOG_LEVEL: "DEBUG" } }, diff --git a/yarn.lock b/yarn.lock index 13971b43d..8cef3daf2 100644 --- a/yarn.lock +++ b/yarn.lock @@ -35,6 +35,7 @@ __metadata: "@sentry/esbuild-plugin": "npm:^3.2.0" "@sentry/node": "npm:^9.3.0" "@sentry/profiling-node": "npm:^9.3.0" + "@types/invariant": "npm:^2" "@types/micromatch": "npm:^4" "@types/react": "npm:^18.3" "@types/semver": "npm:^7.5.8" @@ -49,6 +50,7 @@ __metadata: ink-gradient: "npm:^3.0.0" ink-link: "npm:^4.1.0" ink-spinner: "npm:^5.0.0" + invariant: "npm:^2.2.4" micromatch: "npm:^4.0.8" open: "npm:^10.1.0" pkg-types: "npm:^2.0.0" @@ -216,10 +218,13 @@ __metadata: dependencies: "@actor-core/driver-test-suite": "workspace:*" "@rivet-gg/actor-core": "npm:^25.1.0" + "@rivet-gg/api": "npm:^25.4.2" "@types/deno": "npm:^2.0.0" + "@types/invariant": "npm:^2" "@types/node": "npm:^22.13.1" actor-core: "workspace:*" hono: "npm:^4.7.0" + invariant: "npm:^2.2.4" tsup: "npm:^8.4.0" typescript: "npm:^5.5.2" vitest: "npm:^3.1.1" @@ -2374,6 +2379,20 @@ __metadata: languageName: node linkType: hard +"@rivet-gg/api@npm:^25.4.2": + version: 25.4.2 + resolution: "@rivet-gg/api@npm:25.4.2" + dependencies: + form-data: "npm:^4.0.0" + js-base64: "npm:^3.7.5" + node-fetch: "npm:2" + qs: "npm:^6.11.2" + readable-stream: "npm:^4.5.2" + url-join: "npm:^5.0.0" + checksum: 10c0/eb6a25b1468b9cd8f9b548fa7cdec948d8bcc21bc1274b06507b1b519cbba739cc828974a0917ebee9ab18c92ba7fe228d8ac596b3e71c5efaf4f4f8ed12c8f1 + languageName: node + linkType: hard + "@rollup/rollup-android-arm-eabi@npm:4.39.0": version: 4.39.0 resolution: "@rollup/rollup-android-arm-eabi@npm:4.39.0"