diff --git a/.changeset/curvy-dogs-share.md b/.changeset/curvy-dogs-share.md new file mode 100644 index 0000000000..a0071042aa --- /dev/null +++ b/.changeset/curvy-dogs-share.md @@ -0,0 +1,5 @@ +--- +"@trigger.dev/sdk": patch +--- + +When you create a Waitpoint token using `wait.createToken()` you get a URL back that can be used to complete it by making an HTTP POST request. diff --git a/apps/webapp/app/components/runs/v3/RunIcon.tsx b/apps/webapp/app/components/runs/v3/RunIcon.tsx index 8a1924b3ef..fd277997af 100644 --- a/apps/webapp/app/components/runs/v3/RunIcon.tsx +++ b/apps/webapp/app/components/runs/v3/RunIcon.tsx @@ -19,6 +19,7 @@ import { FunctionIcon } from "~/assets/icons/FunctionIcon"; import { TriggerIcon } from "~/assets/icons/TriggerIcon"; import { PythonLogoIcon } from "~/assets/icons/PythonLogoIcon"; import { TraceIcon } from "~/assets/icons/TraceIcon"; +import { WaitpointTokenIcon } from "~/assets/icons/WaitpointTokenIcon"; type TaskIconProps = { name: string | undefined; @@ -75,6 +76,10 @@ export function RunIcon({ name, className, spanName }: TaskIconProps) { return ; case "python": return ; + case "wait-token": + return ; + case "function": + return ; //log levels case "debug": case "log": diff --git a/apps/webapp/app/components/runs/v3/WaitpointDetails.tsx b/apps/webapp/app/components/runs/v3/WaitpointDetails.tsx index 95331c6e37..b5842b5ec0 100644 --- a/apps/webapp/app/components/runs/v3/WaitpointDetails.tsx +++ b/apps/webapp/app/components/runs/v3/WaitpointDetails.tsx @@ -11,6 +11,7 @@ import { v3WaitpointTokenPath, v3WaitpointTokensPath } from "~/utils/pathBuilder import { PacketDisplay } from "./PacketDisplay"; import { WaitpointStatusCombo } from "./WaitpointStatus"; import { RunTag } from "./RunTag"; +import { ClipboardField } from "~/components/primitives/ClipboardField"; export function WaitpointDetailTable({ waitpoint, @@ -50,6 +51,14 @@ export function WaitpointDetailTable({ )} + {waitpoint.type === "MANUAL" && ( + + Callback URL + + + + + )} Idempotency key diff --git a/apps/webapp/app/presenters/v3/ApiWaitpointTokenListPresenter.server.ts b/apps/webapp/app/presenters/v3/ApiWaitpointListPresenter.server.ts similarity index 82% rename from apps/webapp/app/presenters/v3/ApiWaitpointTokenListPresenter.server.ts rename to apps/webapp/app/presenters/v3/ApiWaitpointListPresenter.server.ts index b4a181dfb7..6390f637f3 100644 --- a/apps/webapp/app/presenters/v3/ApiWaitpointTokenListPresenter.server.ts +++ b/apps/webapp/app/presenters/v3/ApiWaitpointListPresenter.server.ts @@ -1,16 +1,12 @@ -import { RuntimeEnvironmentType, WaitpointTokenStatus } from "@trigger.dev/core/v3"; +import { type RuntimeEnvironmentType, WaitpointTokenStatus } from "@trigger.dev/core/v3"; +import { type RunEngineVersion, type WaitpointResolver } from "@trigger.dev/database"; import { z } from "zod"; -import { BasePresenter } from "./basePresenter.server"; import { CoercedDate } from "~/utils/zod"; -import { AuthenticatedEnvironment } from "@internal/run-engine"; -import { - WaitpointTokenListOptions, - WaitpointTokenListPresenter, -} from "./WaitpointTokenListPresenter.server"; import { ServiceValidationError } from "~/v3/services/baseService.server"; -import { RunEngineVersion } from "@trigger.dev/database"; +import { BasePresenter } from "./basePresenter.server"; +import { type WaitpointListOptions, WaitpointListPresenter } from "./WaitpointListPresenter.server"; -export const ApiWaitpointTokenListSearchParams = z.object({ +export const ApiWaitpointListSearchParams = z.object({ "page[size]": z.coerce.number().int().positive().min(1).max(100).optional(), "page[after]": z.string().optional(), "page[before]": z.string().optional(), @@ -61,9 +57,9 @@ export const ApiWaitpointTokenListSearchParams = z.object({ "filter[createdAt][to]": CoercedDate, }); -type ApiWaitpointTokenListSearchParams = z.infer; +type ApiWaitpointListSearchParams = z.infer; -export class ApiWaitpointTokenListPresenter extends BasePresenter { +export class ApiWaitpointListPresenter extends BasePresenter { public async call( environment: { id: string; @@ -72,11 +68,12 @@ export class ApiWaitpointTokenListPresenter extends BasePresenter { id: string; engine: RunEngineVersion; }; + apiKey: string; }, - searchParams: ApiWaitpointTokenListSearchParams + searchParams: ApiWaitpointListSearchParams ) { return this.trace("call", async (span) => { - const options: WaitpointTokenListOptions = { + const options: WaitpointListOptions = { environment, }; @@ -118,7 +115,7 @@ export class ApiWaitpointTokenListPresenter extends BasePresenter { options.to = searchParams["filter[createdAt][to]"].getTime(); } - const presenter = new WaitpointTokenListPresenter(); + const presenter = new WaitpointListPresenter(); const result = await presenter.call(options); if (!result.success) { diff --git a/apps/webapp/app/presenters/v3/ApiWaitpointPresenter.server.ts b/apps/webapp/app/presenters/v3/ApiWaitpointPresenter.server.ts index b443568c14..1cec530cf0 100644 --- a/apps/webapp/app/presenters/v3/ApiWaitpointPresenter.server.ts +++ b/apps/webapp/app/presenters/v3/ApiWaitpointPresenter.server.ts @@ -2,8 +2,8 @@ import { logger, type RuntimeEnvironmentType } from "@trigger.dev/core/v3"; import { type RunEngineVersion } from "@trigger.dev/database"; import { ServiceValidationError } from "~/v3/services/baseService.server"; import { BasePresenter } from "./basePresenter.server"; -import { WaitpointPresenter } from "./WaitpointPresenter.server"; -import { waitpointStatusToApiStatus } from "./WaitpointTokenListPresenter.server"; +import { waitpointStatusToApiStatus } from "./WaitpointListPresenter.server"; +import { generateHttpCallbackUrl } from "~/services/httpCallback.server"; export class ApiWaitpointPresenter extends BasePresenter { public async call( @@ -14,6 +14,7 @@ export class ApiWaitpointPresenter extends BasePresenter { id: string; engine: RunEngineVersion; }; + apiKey: string; }, waitpointId: string ) { @@ -24,6 +25,7 @@ export class ApiWaitpointPresenter extends BasePresenter { environmentId: environment.id, }, select: { + id: true, friendlyId: true, type: true, status: true, @@ -62,6 +64,7 @@ export class ApiWaitpointPresenter extends BasePresenter { return { id: waitpoint.friendlyId, type: waitpoint.type, + url: generateHttpCallbackUrl(waitpoint.id, environment.apiKey), status: waitpointStatusToApiStatus(waitpoint.status, waitpoint.outputIsError), idempotencyKey: waitpoint.idempotencyKey, userProvidedIdempotencyKey: waitpoint.userProvidedIdempotencyKey, diff --git a/apps/webapp/app/presenters/v3/WaitpointTokenListPresenter.server.ts b/apps/webapp/app/presenters/v3/WaitpointListPresenter.server.ts similarity index 95% rename from apps/webapp/app/presenters/v3/WaitpointTokenListPresenter.server.ts rename to apps/webapp/app/presenters/v3/WaitpointListPresenter.server.ts index ff2578e07d..018c83f6ca 100644 --- a/apps/webapp/app/presenters/v3/WaitpointTokenListPresenter.server.ts +++ b/apps/webapp/app/presenters/v3/WaitpointListPresenter.server.ts @@ -1,6 +1,7 @@ import parse from "parse-duration"; import { Prisma, + type WaitpointResolver, type RunEngineVersion, type RuntimeEnvironmentType, type WaitpointStatus, @@ -11,10 +12,11 @@ import { BasePresenter } from "./basePresenter.server"; import { type WaitpointSearchParams } from "~/components/runs/v3/WaitpointTokenFilters"; import { determineEngineVersion } from "~/v3/engineVersion.server"; import { type WaitpointTokenStatus, type WaitpointTokenItem } from "@trigger.dev/core/v3"; +import { generateHttpCallbackUrl } from "~/services/httpCallback.server"; const DEFAULT_PAGE_SIZE = 25; -export type WaitpointTokenListOptions = { +export type WaitpointListOptions = { environment: { id: string; type: RuntimeEnvironmentType; @@ -22,6 +24,7 @@ export type WaitpointTokenListOptions = { id: string; engine: RunEngineVersion; }; + apiKey: string; }; // filters id?: string; @@ -63,7 +66,7 @@ type Result = filters: undefined; }; -export class WaitpointTokenListPresenter extends BasePresenter { +export class WaitpointListPresenter extends BasePresenter { public async call({ environment, id, @@ -76,7 +79,7 @@ export class WaitpointTokenListPresenter extends BasePresenter { direction = "forward", cursor, pageSize = DEFAULT_PAGE_SIZE, - }: WaitpointTokenListOptions): Promise { + }: WaitpointListOptions): Promise { const engineVersion = await determineEngineVersion({ environment }); if (engineVersion === "V1") { return { @@ -165,8 +168,8 @@ export class WaitpointTokenListPresenter extends BasePresenter { ${sqlDatabaseSchema}."Waitpoint" w WHERE w."environmentId" = ${environment.id} - AND w.type = 'MANUAL' - -- cursor + AND w.type = 'MANUAL' + -- cursor ${ cursor ? direction === "forward" @@ -263,6 +266,7 @@ export class WaitpointTokenListPresenter extends BasePresenter { success: true, tokens: tokensToReturn.map((token) => ({ id: token.friendlyId, + url: generateHttpCallbackUrl(token.id, environment.apiKey), status: waitpointStatusToApiStatus(token.status, token.outputIsError), completedAt: token.completedAt ?? undefined, timeoutAt: token.completedAfter ?? undefined, diff --git a/apps/webapp/app/presenters/v3/WaitpointPresenter.server.ts b/apps/webapp/app/presenters/v3/WaitpointPresenter.server.ts index f005f5a2dc..0262dcea84 100644 --- a/apps/webapp/app/presenters/v3/WaitpointPresenter.server.ts +++ b/apps/webapp/app/presenters/v3/WaitpointPresenter.server.ts @@ -1,8 +1,9 @@ import { isWaitpointOutputTimeout, prettyPrintPacket } from "@trigger.dev/core/v3"; +import { generateHttpCallbackUrl } from "~/services/httpCallback.server"; import { logger } from "~/services/logger.server"; import { BasePresenter } from "./basePresenter.server"; import { type RunListItem, RunListPresenter } from "./RunListPresenter.server"; -import { waitpointStatusToApiStatus } from "./WaitpointTokenListPresenter.server"; +import { waitpointStatusToApiStatus } from "./WaitpointListPresenter.server"; export type WaitpointDetail = NonNullable>>; @@ -22,6 +23,7 @@ export class WaitpointPresenter extends BasePresenter { environmentId, }, select: { + id: true, friendlyId: true, type: true, status: true, @@ -42,6 +44,11 @@ export class WaitpointPresenter extends BasePresenter { take: 5, }, tags: true, + environment: { + select: { + apiKey: true, + }, + }, }, }); @@ -83,6 +90,7 @@ export class WaitpointPresenter extends BasePresenter { return { id: waitpoint.friendlyId, type: waitpoint.type, + url: generateHttpCallbackUrl(waitpoint.id, waitpoint.environment.apiKey), status: waitpointStatusToApiStatus(waitpoint.status, waitpoint.outputIsError), idempotencyKey: waitpoint.idempotencyKey, userProvidedIdempotencyKey: waitpoint.userProvidedIdempotencyKey, diff --git a/apps/webapp/app/routes/_app.orgs.$organizationSlug.projects.$projectParam.env.$envParam.waitpoints.tokens/route.tsx b/apps/webapp/app/routes/_app.orgs.$organizationSlug.projects.$projectParam.env.$envParam.waitpoints.tokens/route.tsx index cea5332692..7e32f08244 100644 --- a/apps/webapp/app/routes/_app.orgs.$organizationSlug.projects.$projectParam.env.$envParam.waitpoints.tokens/route.tsx +++ b/apps/webapp/app/routes/_app.orgs.$organizationSlug.projects.$projectParam.env.$envParam.waitpoints.tokens/route.tsx @@ -7,6 +7,7 @@ import { NoWaitpointTokens } from "~/components/BlankStatePanels"; import { MainCenteredContainer, PageBody, PageContainer } from "~/components/layout/AppLayout"; import { ListPagination } from "~/components/ListPagination"; import { LinkButton } from "~/components/primitives/Buttons"; +import { ClipboardField } from "~/components/primitives/ClipboardField"; import { CopyableText } from "~/components/primitives/CopyableText"; import { DateTime } from "~/components/primitives/DateTime"; import { NavBar, PageAccessories, PageTitle } from "~/components/primitives/PageHeader"; @@ -36,7 +37,7 @@ import { useOrganization } from "~/hooks/useOrganizations"; import { useProject } from "~/hooks/useProject"; import { findProjectBySlug } from "~/models/project.server"; import { findEnvironmentBySlug } from "~/models/runtimeEnvironment.server"; -import { WaitpointTokenListPresenter } from "~/presenters/v3/WaitpointTokenListPresenter.server"; +import { WaitpointListPresenter } from "~/presenters/v3/WaitpointListPresenter.server"; import { requireUserId } from "~/services/session.server"; import { docsPath, EnvironmentParamSchema, v3WaitpointTokenPath } from "~/utils/pathBuilder"; @@ -84,7 +85,7 @@ export const loader = async ({ request, params }: LoaderFunctionArgs) => { } try { - const presenter = new WaitpointTokenListPresenter(); + const presenter = new WaitpointListPresenter(); const result = await presenter.call({ environment, ...searchParams, @@ -143,6 +144,7 @@ export default function Page() { Created ID + Callback URL Status Completed Idempotency Key @@ -178,6 +180,9 @@ export default function Page() { + + + diff --git a/apps/webapp/app/routes/api.v1.waitpoints.tokens.$waitpointFriendlyId.callback.$hash.ts b/apps/webapp/app/routes/api.v1.waitpoints.tokens.$waitpointFriendlyId.callback.$hash.ts new file mode 100644 index 0000000000..0dc5c26111 --- /dev/null +++ b/apps/webapp/app/routes/api.v1.waitpoints.tokens.$waitpointFriendlyId.callback.$hash.ts @@ -0,0 +1,91 @@ +import { type ActionFunctionArgs, json } from "@remix-run/server-runtime"; +import { + type CompleteWaitpointTokenResponseBody, + conditionallyExportPacket, + stringifyIO, +} from "@trigger.dev/core/v3"; +import { WaitpointId } from "@trigger.dev/core/v3/isomorphic"; +import { z } from "zod"; +import { $replica } from "~/db.server"; +import { env } from "~/env.server"; +import { verifyHttpCallbackHash } from "~/services/httpCallback.server"; +import { logger } from "~/services/logger.server"; +import { engine } from "~/v3/runEngine.server"; + +const paramsSchema = z.object({ + waitpointFriendlyId: z.string(), + hash: z.string(), +}); + +export async function action({ request, params }: ActionFunctionArgs) { + if (request.method.toUpperCase() !== "POST") { + return json({ error: "Method not allowed" }, { status: 405, headers: { Allow: "POST" } }); + } + + const contentLength = request.headers.get("content-length"); + if (!contentLength) { + return json({ error: "Content-Length header is required" }, { status: 411 }); + } + + if (parseInt(contentLength) > env.TASK_PAYLOAD_MAXIMUM_SIZE) { + return json({ error: "Request body too large" }, { status: 413 }); + } + + const { waitpointFriendlyId, hash } = paramsSchema.parse(params); + const waitpointId = WaitpointId.toId(waitpointFriendlyId); + + try { + const waitpoint = await $replica.waitpoint.findFirst({ + where: { + id: waitpointId, + }, + include: { + environment: { + select: { + apiKey: true, + }, + }, + }, + }); + + if (!waitpoint) { + return json({ error: "Waitpoint not found" }, { status: 404 }); + } + + if (!verifyHttpCallbackHash(waitpoint.id, hash, waitpoint.environment.apiKey)) { + return json({ error: "Invalid URL, hash doesn't match" }, { status: 401 }); + } + + if (waitpoint.status === "COMPLETED") { + return json({ + success: true, + }); + } + + // If the request body is not valid JSON, return an empty object + const body = await request.json().catch(() => ({})); + + const stringifiedData = await stringifyIO(body); + const finalData = await conditionallyExportPacket( + stringifiedData, + `${waitpointId}/waitpoint/http-callback` + ); + + const result = await engine.completeWaitpoint({ + id: waitpointId, + output: finalData.data + ? { type: finalData.dataType, value: finalData.data, isError: false } + : undefined, + }); + + return json( + { + success: true, + }, + { status: 200 } + ); + } catch (error) { + logger.error("Failed to complete HTTP callback", { error }); + throw json({ error: "Failed to complete HTTP callback" }, { status: 500 }); + } +} diff --git a/apps/webapp/app/routes/api.v1.waitpoints.tokens.ts b/apps/webapp/app/routes/api.v1.waitpoints.tokens.ts index 7e49de8f40..4542236d48 100644 --- a/apps/webapp/app/routes/api.v1.waitpoints.tokens.ts +++ b/apps/webapp/app/routes/api.v1.waitpoints.tokens.ts @@ -6,14 +6,15 @@ import { import { WaitpointId } from "@trigger.dev/core/v3/isomorphic"; import { createWaitpointTag, MAX_TAGS_PER_WAITPOINT } from "~/models/waitpointTag.server"; import { - ApiWaitpointTokenListPresenter, - ApiWaitpointTokenListSearchParams, -} from "~/presenters/v3/ApiWaitpointTokenListPresenter.server"; + ApiWaitpointListPresenter, + ApiWaitpointListSearchParams, +} from "~/presenters/v3/ApiWaitpointListPresenter.server"; +import { type AuthenticatedEnvironment } from "~/services/apiAuth.server"; +import { generateHttpCallbackUrl } from "~/services/httpCallback.server"; import { createActionApiRoute, createLoaderApiRoute, } from "~/services/routeBuilders/apiBuilder.server"; -import { AuthenticatedEnvironment } from "~/services/apiAuth.server"; import { parseDelay } from "~/utils/delays"; import { resolveIdempotencyKeyTTL } from "~/utils/idempotencyKeys.server"; import { engine } from "~/v3/runEngine.server"; @@ -21,11 +22,11 @@ import { ServiceValidationError } from "~/v3/services/baseService.server"; export const loader = createLoaderApiRoute( { - searchParams: ApiWaitpointTokenListSearchParams, + searchParams: ApiWaitpointListSearchParams, findResource: async () => 1, // This is a dummy function, we don't need to find a resource }, async ({ searchParams, authentication }) => { - const presenter = new ApiWaitpointTokenListPresenter(); + const presenter = new ApiWaitpointListPresenter(); const result = await presenter.call(authentication.environment, searchParams); return json(result); @@ -84,6 +85,7 @@ const { action } = createActionApiRoute( { id: WaitpointId.toFriendlyId(result.waitpoint.id), isCached: result.isCached, + url: generateHttpCallbackUrl(result.waitpoint.id, authentication.environment.apiKey), }, { status: 200, headers: $responseHeaders } ); diff --git a/apps/webapp/app/services/apiRateLimit.server.ts b/apps/webapp/app/services/apiRateLimit.server.ts index 466aaa98b8..611a19fb3e 100644 --- a/apps/webapp/app/services/apiRateLimit.server.ts +++ b/apps/webapp/app/services/apiRateLimit.server.ts @@ -59,6 +59,7 @@ export const apiRateLimiter = authorizationRateLimitMiddleware({ "/api/v1/usage/ingest", "/api/v1/auth/jwt/claims", /^\/api\/v1\/runs\/[^\/]+\/attempts$/, // /api/v1/runs/$runFriendlyId/attempts + /^\/api\/v1\/waitpoints\/tokens\/[^\/]+\/callback\/[^\/]+$/, // /api/v1/waitpoints/tokens/$waitpointFriendlyId/callback/$hash ], log: { rejections: env.API_RATE_LIMIT_REJECTION_LOGS_ENABLED === "1", diff --git a/apps/webapp/app/services/httpCallback.server.ts b/apps/webapp/app/services/httpCallback.server.ts new file mode 100644 index 0000000000..b346da0c14 --- /dev/null +++ b/apps/webapp/app/services/httpCallback.server.ts @@ -0,0 +1,30 @@ +import { WaitpointId } from "@trigger.dev/core/v3/isomorphic"; +import nodeCrypto from "node:crypto"; +import { env } from "~/env.server"; + +export function generateHttpCallbackUrl(waitpointId: string, apiKey: string) { + const hash = generateHttpCallbackHash(waitpointId, apiKey); + + return `${env.API_ORIGIN ?? env.APP_ORIGIN}/api/v1/waitpoints/tokens/${WaitpointId.toFriendlyId( + waitpointId + )}/callback/${hash}`; +} + +function generateHttpCallbackHash(waitpointId: string, apiKey: string) { + const hmac = nodeCrypto.createHmac("sha256", apiKey); + hmac.update(waitpointId); + return hmac.digest("hex"); +} + +export function verifyHttpCallbackHash(waitpointId: string, hash: string, apiKey: string) { + const expectedHash = generateHttpCallbackHash(waitpointId, apiKey); + + if ( + hash.length === expectedHash.length && + nodeCrypto.timingSafeEqual(Buffer.from(hash, "hex"), Buffer.from(expectedHash, "hex")) + ) { + return true; + } + + return false; +} diff --git a/docs/wait-for-token.mdx b/docs/wait-for-token.mdx index 8b7a44438e..9bda1edfea 100644 --- a/docs/wait-for-token.mdx +++ b/docs/wait-for-token.mdx @@ -7,6 +7,8 @@ import UpgradeToV4Note from "/snippets/upgrade-to-v4-note.mdx"; Waitpoint tokens pause task runs until you complete the token. They're commonly used for approval workflows and other scenarios where you need to wait for external confirmation, such as human-in-the-loop processes. +You can complete a token using the SDK or by making a POST request to the token's URL. + ## Usage @@ -52,6 +54,29 @@ await wait.completeToken(tokenId, { }); ``` +Or you can make an HTTP POST request to the `url` it returns: + +```ts +import { wait } from "@trigger.dev/sdk"; + +const token = await wait.createToken({ + timeout: "10m", +}); + +const call = await replicate.predictions.create({ + version: "27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478", + input: { + prompt: "A painting of a cat by Andy Warhol", + }, + // pass the provided URL to Replicate's webhook, so they can "callback" + webhook: token.url, + webhook_events_filter: ["completed"], +}); + +const prediction = await wait.forToken(token).unwrap(); +// unwrap() throws a timeout error or returns the result 👆 +``` + ## wait.createToken Create a waitpoint token. @@ -85,6 +110,13 @@ The `createToken` function returns a token object with the following properties: The ID of the token. Starts with `waitpoint_`. + + The URL of the token. This is the URL you can make a POST request to in order to complete the token. + +The JSON body of the POST request will be used as the output of the token. If there's no body the output will be an empty object `{}`. + + + Whether the token is cached. Will return true if the token was created with an idempotency key and the same idempotency key was used again. @@ -270,6 +302,18 @@ The `forToken` function returns a result object with the following properties: timeout error. +### unwrap() + +We provide a handy `.unwrap()` method that will throw an error if the result is not ok. This means your happy path is a lot cleaner. + +```ts +const approval = await wait.forToken(tokenId).unwrap(); +// unwrap means an error will throw if the waitpoint times out 👆 + +// This is the actual data you sent to the token now, not a result object +console.log("Approval", approval); +``` + ### Example ```ts @@ -326,6 +370,13 @@ Each token is an object with the following properties: The ID of the token. + + The URL of the token. This is the URL you can make a POST request to in order to complete the token. + +The JSON body of the POST request will be used as the output of the token. If there's no body the output will be an empty object `{}`. + + + The status of the token. @@ -392,6 +443,13 @@ The `retrieveToken` function returns a token object with the following propertie The ID of the token. + + The URL of the token. This is the URL you can make a POST request to in order to complete the token. + +The JSON body of the POST request will be used as the output of the token. If there's no body the output will be an empty object `{}`. + + + The status of the token. diff --git a/docs/wait.mdx b/docs/wait.mdx index af80c8cf02..cfe5b2385b 100644 --- a/docs/wait.mdx +++ b/docs/wait.mdx @@ -4,14 +4,14 @@ sidebarTitle: "Overview" description: "During your run you can wait for a period of time or for something to happen." --- -import PausedExecutionFree from "/snippets/paused-execution-free.mdx" +import PausedExecutionFree from "/snippets/paused-execution-free.mdx"; -Waiting allows you to write complex tasks as a set of async code, without having to scheduled another task or poll for changes. +Waiting allows you to write complex tasks as a set of async code, without having to schedule another task or poll for changes. -| Function | What it does | -| :--------------------------------------| :---------------------------------------------------------------------------------------- | -| [wait.for()](/wait-for) | Waits for a specific period of time, e.g. 1 day. | -| [wait.until()](/wait-until) | Waits until the provided `Date`. | -| [wait.forToken()](/wait-for-token) | Pauses task runs until a token is completed. | +| Function | What it does | +| :--------------------------------- | :----------------------------------------------- | +| [wait.for()](/wait-for) | Waits for a specific period of time, e.g. 1 day. | +| [wait.until()](/wait-until) | Waits until the provided `Date`. | +| [wait.forToken()](/wait-for-token) | Pauses runs until a token is completed. | diff --git a/internal-packages/database/prisma/schema.prisma b/internal-packages/database/prisma/schema.prisma index 5f83179943..c67edbd173 100644 --- a/internal-packages/database/prisma/schema.prisma +++ b/internal-packages/database/prisma/schema.prisma @@ -2151,6 +2151,7 @@ model Waitpoint { updatedAt DateTime @updatedAt /// Denormized column that holds the raw tags + /// Denormalized column that holds the raw tags tags String[] /// Quickly find an idempotent waitpoint diff --git a/packages/core/src/v3/apiClient/index.ts b/packages/core/src/v3/apiClient/index.ts index f86fef448f..0bac97c8e5 100644 --- a/packages/core/src/v3/apiClient/index.ts +++ b/packages/core/src/v3/apiClient/index.ts @@ -16,9 +16,7 @@ import { CreateWaitpointTokenResponseBody, DeletedScheduleObject, EnvironmentVariableResponseBody, - EnvironmentVariableValue, EnvironmentVariableWithSecret, - EnvironmentVariables, ListQueueOptions, ListRunResponseItem, ListScheduleOptions, @@ -42,8 +40,10 @@ import { WaitpointRetrieveTokenResponse, WaitpointTokenItem, } from "../schemas/index.js"; +import { AsyncIterableStream } from "../streams/asyncIterableStream.js"; import { taskContext } from "../task-context-api.js"; import { AnyRunTypes, TriggerJwtOptions } from "../types/tasks.js"; +import { Prettify } from "../types/utils.js"; import { AnyZodFetchOptions, ApiPromise, @@ -63,9 +63,9 @@ import { RunShape, RunStreamCallback, RunSubscription, + SSEStreamSubscriptionFactory, TaskRunShape, runShapeStream, - SSEStreamSubscriptionFactory, } from "./runStream.js"; import { CreateEnvironmentVariableParams, @@ -76,8 +76,6 @@ import { SubscribeToRunsQueryParams, UpdateEnvironmentVariableParams, } from "./types.js"; -import { AsyncIterableStream } from "../streams/asyncIterableStream.js"; -import { Prettify } from "../types/utils.js"; export type CreateWaitpointTokenResponse = Prettify< CreateWaitpointTokenResponseBody & { diff --git a/packages/core/src/v3/schemas/api.ts b/packages/core/src/v3/schemas/api.ts index ef8ff25baf..bc73ce534b 100644 --- a/packages/core/src/v3/schemas/api.ts +++ b/packages/core/src/v3/schemas/api.ts @@ -961,6 +961,7 @@ export type CreateWaitpointTokenRequestBody = z.infer; @@ -970,6 +971,8 @@ export type WaitpointTokenStatus = z.infer; export const WaitpointTokenItem = z.object({ id: z.string(), + /** If you make a POST request to this URL, it will complete the waitpoint. */ + url: z.string(), status: WaitpointTokenStatus, completedAt: z.coerce.date().optional(), completedAfter: z.coerce.date().optional(), diff --git a/packages/trigger-sdk/src/v3/wait.ts b/packages/trigger-sdk/src/v3/wait.ts index 4768a9e7ee..1c606ea79c 100644 --- a/packages/trigger-sdk/src/v3/wait.ts +++ b/packages/trigger-sdk/src/v3/wait.ts @@ -1,29 +1,27 @@ +import { SpanStatusCode } from "@opentelemetry/api"; import { - SemanticInternalAttributes, accessoryAttributes, - runtime, apiClientManager, ApiPromise, ApiRequestOptions, + CompleteWaitpointTokenResponseBody, CreateWaitpointTokenRequestBody, + CreateWaitpointTokenResponse, CreateWaitpointTokenResponseBody, - mergeRequestOptions, - CompleteWaitpointTokenResponseBody, - WaitpointTokenTypedResult, - Prettify, - taskContext, - ListWaitpointTokensQueryParams, CursorPagePromise, - WaitpointTokenItem, flattenAttributes, + ListWaitpointTokensQueryParams, + mergeRequestOptions, + runtime, + SemanticInternalAttributes, + taskContext, WaitpointListTokenItem, - WaitpointTokenStatus, WaitpointRetrieveTokenResponse, - CreateWaitpointTokenResponse, + WaitpointTokenStatus, + WaitpointTokenTypedResult, } from "@trigger.dev/core/v3"; -import { tracer } from "./tracer.js"; import { conditionallyImportAndParsePacket } from "@trigger.dev/core/v3/utils/ioSerialization"; -import { SpanStatusCode } from "@opentelemetry/api"; +import { tracer } from "./tracer.js"; /** * This creates a waitpoint token. @@ -31,6 +29,8 @@ import { SpanStatusCode } from "@opentelemetry/api"; * * @example * + * **Manually completing a token** + * * ```ts * const token = await wait.createToken({ * idempotencyKey: `approve-document-${documentId}`, @@ -45,6 +45,30 @@ import { SpanStatusCode } from "@opentelemetry/api"; * }); * ``` * + * @example + * + * **Completing a token with a webhook** + * + * ```ts + * const token = await wait.createToken({ + * timeout: "10m", + * tags: ["replicate"], + * }); + * + * // Later, in a different part of your codebase, you can complete the waitpoint + * await replicate.predictions.create({ + * version: "27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478", + * input: { + * prompt: "A painting of a cat by Andy Warhol", + * }, + * // pass the provided URL to Replicate's webhook, so they can "callback" + * webhook: token.url, + * webhook_events_filter: ["completed"], + * }); + * + * const prediction = await wait.forToken(token).unwrap(); + * ``` + * * @param options - The options for the waitpoint token. * @param requestOptions - The request options for the waitpoint token. * @returns The waitpoint token. @@ -73,6 +97,7 @@ function createToken( onResponseBody: (body: CreateWaitpointTokenResponseBody, span) => { span.setAttribute("id", body.id); span.setAttribute("isCached", body.isCached); + span.setAttribute("url", body.url); }, }, requestOptions @@ -151,6 +176,8 @@ function listTokens( */ export type WaitpointRetrievedToken = { id: string; + /** A URL that you can make a POST request to in order to complete the waitpoint. */ + url: string; status: WaitpointTokenStatus; completedAt?: Date; timeoutAt?: Date; @@ -204,6 +231,7 @@ async function retrieveToken( }, onResponseBody: (body: WaitpointRetrieveTokenResponse, span) => { span.setAttribute("id", body.id); + span.setAttribute("url", body.url); span.setAttribute("status", body.status); if (body.completedAt) { span.setAttribute("completedAt", body.completedAt.toISOString()); @@ -244,6 +272,7 @@ async function retrieveToken( return { id: result.id, + url: result.url, status: result.status, completedAt: result.completedAt, timeoutAt: result.timeoutAt, @@ -377,6 +406,29 @@ function printWaitBelowThreshold() { ); } +class ManualWaitpointPromise extends Promise> { + constructor( + executor: ( + resolve: ( + value: WaitpointTokenTypedResult | PromiseLike> + ) => void, + reject: (reason?: any) => void + ) => void + ) { + super(executor); + } + + unwrap(): Promise { + return this.then((result) => { + if (result.ok) { + return result.output; + } else { + throw new WaitpointTimeoutError(result.error.message); + } + }); + } +} + export const wait = { for: async (options: WaitForOptions) => { const ctx = taskContext.ctx; @@ -554,9 +606,9 @@ export const wait = { * * @param token - The token to wait for. * @param options - The options for the waitpoint token. - * @returns The waitpoint token. + * @returns A promise that resolves to the result of the waitpoint. You can use `.unwrap()` to get the result and an error will throw. */ - forToken: async ( + forToken: ( /** * The token to wait for. * This can be a string token ID or an object with an `id` property. @@ -575,76 +627,84 @@ export const wait = { */ releaseConcurrency?: boolean; } - ): Promise>> => { - const ctx = taskContext.ctx; + ): ManualWaitpointPromise => { + return new ManualWaitpointPromise(async (resolve, reject) => { + try { + const ctx = taskContext.ctx; - if (!ctx) { - throw new Error("wait.forToken can only be used from inside a task.run()"); - } + if (!ctx) { + throw new Error("wait.forToken can only be used from inside a task.run()"); + } - const apiClient = apiClientManager.clientOrThrow(); + const apiClient = apiClientManager.clientOrThrow(); - const tokenId = typeof token === "string" ? token : token.id; + const tokenId = typeof token === "string" ? token : token.id; - return tracer.startActiveSpan( - `wait.forToken()`, - async (span) => { - const response = await apiClient.waitForWaitpointToken({ - runFriendlyId: ctx.run.id, - waitpointFriendlyId: tokenId, - releaseConcurrency: options?.releaseConcurrency, - }); - - if (!response.success) { - throw new Error(`Failed to wait for wait token ${tokenId}`); - } + const result = await tracer.startActiveSpan( + `wait.forToken()`, + async (span) => { + const response = await apiClient.waitForWaitpointToken({ + runFriendlyId: ctx.run.id, + waitpointFriendlyId: tokenId, + releaseConcurrency: options?.releaseConcurrency, + }); - const result = await runtime.waitUntil(tokenId); - - const data = result.output - ? await conditionallyImportAndParsePacket( - { data: result.output, dataType: result.outputType ?? "application/json" }, - apiClient - ) - : undefined; - - if (result.ok) { - return { - ok: result.ok, - output: data, - } as WaitpointTokenTypedResult; - } else { - const error = new WaitpointTimeoutError(data.message); - - span.recordException(error); - span.setStatus({ - code: SpanStatusCode.ERROR, - }); - - return { - ok: result.ok, - error, - } as WaitpointTokenTypedResult; - } - }, - { - attributes: { - [SemanticInternalAttributes.STYLE_ICON]: "wait", - [SemanticInternalAttributes.ENTITY_TYPE]: "waitpoint", - [SemanticInternalAttributes.ENTITY_ID]: tokenId, - id: tokenId, - ...accessoryAttributes({ - items: [ - { - text: tokenId, - variant: "normal", - }, - ], - style: "codepath", - }), - }, + if (!response.success) { + throw new Error(`Failed to wait for wait token ${tokenId}`); + } + + const result = await runtime.waitUntil(tokenId); + + const data = result.output + ? await conditionallyImportAndParsePacket( + { data: result.output, dataType: result.outputType ?? "application/json" }, + apiClient + ) + : undefined; + + if (result.ok) { + return { + ok: result.ok, + output: data, + } as WaitpointTokenTypedResult; + } else { + const error = new WaitpointTimeoutError(data.message); + + span.recordException(error); + span.setStatus({ + code: SpanStatusCode.ERROR, + }); + + return { + ok: result.ok, + error, + } as WaitpointTokenTypedResult; + } + }, + { + attributes: { + [SemanticInternalAttributes.STYLE_ICON]: "wait", + [SemanticInternalAttributes.ENTITY_TYPE]: "waitpoint", + [SemanticInternalAttributes.ENTITY_ID]: tokenId, + id: tokenId, + ...accessoryAttributes({ + items: [ + { + text: tokenId, + variant: "normal", + }, + ], + style: "codepath", + }), + }, + } + ); + + resolve(result); + } catch (error) { + reject(error); } - ); + }); }, }; @@ -711,8 +771,3 @@ function calculateDurationInMs(options: WaitForOptions): number { throw new Error("Invalid options"); } - -type RequestOptions = { - to: (url: string) => Promise; - timeout: WaitForOptions; -}; diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 1a42f0a2b9..cda9a2d7bc 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1930,6 +1930,15 @@ importers: '@trigger.dev/sdk': specifier: workspace:* version: link:../../packages/trigger-sdk + openai: + specifier: ^4.97.0 + version: 4.97.0(zod@3.23.8) + replicate: + specifier: ^1.0.1 + version: 1.0.1 + zod: + specifier: 3.23.8 + version: 3.23.8 devDependencies: trigger.dev: specifier: workspace:* @@ -28886,6 +28895,30 @@ packages: - encoding dev: false + /openai@4.97.0(zod@3.23.8): + resolution: {integrity: sha512-LRoiy0zvEf819ZUEJhgfV8PfsE8G5WpQi4AwA1uCV8SKvvtXQkoWUFkepD6plqyJQRghy2+AEPQ07FrJFKHZ9Q==} + hasBin: true + peerDependencies: + ws: ^8.18.0 + zod: ^3.23.8 + peerDependenciesMeta: + ws: + optional: true + zod: + optional: true + dependencies: + '@types/node': 18.19.20 + '@types/node-fetch': 2.6.12 + abort-controller: 3.0.0 + agentkeepalive: 4.5.0 + form-data-encoder: 1.7.2 + formdata-node: 4.4.1 + node-fetch: 2.6.12 + zod: 3.23.8 + transitivePeerDependencies: + - encoding + dev: false + /openapi-fetch@0.9.8: resolution: {integrity: sha512-zM6elH0EZStD/gSiNlcPrzXcVQ/pZo3BDvC6CDwRDUt1dDzxlshpmQnpD6cZaJ39THaSmwVCxxRrPKNM1hHrDg==} dependencies: @@ -30178,7 +30211,7 @@ packages: /process@0.11.10: resolution: {integrity: sha512-cdGef/drWFoydD1JsMzuFf8100nZl+GT+yacc2bEced5f9Rjk4z+WtFUTBu9PhOi9j/jfmBPu0mMEY4wIdAF8A==} engines: {node: '>= 0.6.0'} - dev: true + requiresBuild: true /progress@2.0.3: resolution: {integrity: sha512-7PiHtLll5LdnKIMw100I+8xJXR5gW2QwWYkT6iJva0bXitZKa/XMrSbdmg3r2Xnaidz9Qumd0VPaMrZlF9V9sA==} @@ -31149,7 +31182,6 @@ packages: events: 3.3.0 process: 0.11.10 string_decoder: 1.3.0 - dev: true /readdir-glob@1.1.3: resolution: {integrity: sha512-v05I2k7xN8zXvPD9N+z/uhXPaj0sUFCe2rcWZIpBsqxfP7xXFQ0tipAd/wjj1YxWyWtUS5IDJpOG82JKt2EAVA==} @@ -31526,6 +31558,13 @@ packages: engines: {node: '>=8'} dev: true + /replicate@1.0.1: + resolution: {integrity: sha512-EY+rK1YR5bKHcM9pd6WyaIbv6m2aRIvHfHDh51j/LahlHTLKemTYXF6ptif2sLa+YospupAsIoxw8Ndt5nI3vg==} + engines: {git: '>=2.11.0', node: '>=18.0.0', npm: '>=7.19.0', yarn: '>=1.7.0'} + optionalDependencies: + readable-stream: 4.5.2 + dev: false + /request@2.88.2: resolution: {integrity: sha512-MsvtOrfG9ZcrOwAW+Qi+F6HbD0CWXEh9ou77uOb7FM2WPhwT7smM833PzanhJLsgXjN89Ir6V2PczXNnMpwKhw==} engines: {node: '>= 6'} diff --git a/references/hello-world/package.json b/references/hello-world/package.json index 9512898e21..67d93c4fdc 100644 --- a/references/hello-world/package.json +++ b/references/hello-world/package.json @@ -6,7 +6,10 @@ "trigger.dev": "workspace:*" }, "dependencies": { - "@trigger.dev/sdk": "workspace:*" + "@trigger.dev/sdk": "workspace:*", + "openai": "^4.97.0", + "replicate": "^1.0.1", + "zod": "3.23.8" }, "scripts": { "dev": "trigger dev" diff --git a/references/hello-world/src/trigger/waits.ts b/references/hello-world/src/trigger/waits.ts index 675b852aa6..9a0afa642f 100644 --- a/references/hello-world/src/trigger/waits.ts +++ b/references/hello-world/src/trigger/waits.ts @@ -1,5 +1,5 @@ -import { logger, wait, task, retry, idempotencyKeys, auth } from "@trigger.dev/sdk/v3"; - +import { auth, idempotencyKeys, logger, retry, task, wait } from "@trigger.dev/sdk/v3"; +import Replicate, { Prediction } from "replicate"; type Token = { status: "approved" | "pending" | "rejected"; }; @@ -140,3 +140,49 @@ export const waitForDuration = task({ ); }, }); + +export const waitHttpCallback = task({ + id: "wait-http-callback", + retry: { + maxAttempts: 1, + }, + run: async () => { + if (process.env.REPLICATE_API_KEY) { + const replicate = new Replicate({ + auth: process.env.REPLICATE_API_KEY, + }); + + const token = await wait.createToken({ + timeout: "10m", + tags: ["replicate"], + }); + logger.log("Create result", { token }); + + const call = await replicate.predictions.create({ + version: "27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478", + input: { + prompt: "A painting of a cat by Any Warhol", + }, + // pass the provided URL to Replicate's webhook, so they can "callback" + webhook: token.url, + webhook_events_filter: ["completed"], + }); + + const prediction = await wait.forToken(token); + + if (!prediction.ok) { + throw new Error("Failed to create prediction"); + } + + logger.log("Prediction", prediction); + + const imageUrl = prediction.output.output; + logger.log("Image URL", imageUrl); + + //same again but with unwrapping + const result2 = await wait.forToken(token).unwrap(); + + logger.log("Result2", { result2 }); + } + }, +});