diff --git a/examples/external_clients/aisdk.ts b/examples/external_clients/aisdk.ts
index 1d72d984f..491cb277d 100644
--- a/examples/external_clients/aisdk.ts
+++ b/examples/external_clients/aisdk.ts
@@ -8,10 +8,18 @@ import {
   generateText,
   ImagePart,
   LanguageModel,
+  streamText,
   TextPart,
 } from "ai";
 import { CreateChatCompletionOptions, LLMClient, AvailableModel } from "@/dist";
 import { ChatCompletion } from "openai/resources";
+import {
+  GenerateObjectOptions,
+  GenerateTextOptions,
+  ObjectResponse,
+  StreamingTextResponse,
+  TextResponse,
+} from "@/lib";
 
 export class AISdkClient extends LLMClient {
   public type = "aisdk" as const;
@@ -119,4 +127,61 @@ export class AISdkClient extends LLMClient {
       },
     } as T;
   }
+
+  async streamText<T = StreamingTextResponse>({
+    prompt,
+    options = {},
+  }: GenerateTextOptions): Promise<T> {
+    const tools: Record<string, CoreTool> = {};
+    if (options.tools) {
+      for (const rawTool of options.tools) {
+        tools[rawTool.name] = {
+          description: rawTool.description,
+          parameters: rawTool.parameters,
+        };
+      }
+    }
+
+    const response = await streamText({
+      model: this.model,
+      prompt: prompt,
+      tools,
+    });
+    return response as T;
+  }
+
+  async generateText<T = TextResponse>({
+    prompt,
+    options = {},
+  }: GenerateTextOptions): Promise<T> {
+    const tools: Record<string, CoreTool> = {};
+    if (options.tools) {
+      for (const rawTool of options.tools) {
+        tools[rawTool.name] = {
+          description: rawTool.description,
+          parameters: rawTool.parameters,
+        };
+      }
+    }
+
+    const response = await generateText({
+      model: this.model,
+      prompt: prompt,
+      tools,
+    });
+    return response as T;
+  }
+  async generateObject<T = ObjectResponse>({
+    prompt,
+    schema,
+    options = {},
+  }: GenerateObjectOptions): Promise<T> {
+    const response = await generateObject({
+      model: this.model,
+      prompt: prompt,
+      schema: schema,
+      ...options,
+    });
+    return response as T;
+  }
 }
diff --git a/examples/external_clients/customOpenAI.ts b/examples/external_clients/customOpenAI.ts
index 6a6d70b3f..37d913ee5 100644
--- a/examples/external_clients/customOpenAI.ts
+++ b/examples/external_clients/customOpenAI.ts
@@ -14,12 +14,23 @@ import type {
   ChatCompletionContentPartImage,
   ChatCompletionContentPartText,
   ChatCompletionCreateParamsNonStreaming,
+  ChatCompletionCreateParamsStreaming,
   ChatCompletionMessageParam,
   ChatCompletionSystemMessageParam,
   ChatCompletionUserMessageParam,
 } from "openai/resources/chat/completions";
 import { z } from "zod";
 import { CreateChatCompletionResponseError } from "@/types/stagehandErrors";
+import {
+  GenerateObjectOptions,
+  GenerateTextOptions,
+  LLMObjectResponse,
+  LLMResponse,
+  ObjectResponse,
+  StreamingChatResponse,
+  StreamingTextResponse,
+  TextResponse,
+} from "@/lib";
 
 function validateZodSchema(schema: z.ZodTypeAny, data: unknown) {
   try {
@@ -220,21 +231,395 @@ export class CustomOpenAIClient extends LLMClient {
       return {
         data: parsedData,
-        usage: {
-          prompt_tokens: response.usage?.prompt_tokens ?? 0,
-          completion_tokens: response.usage?.completion_tokens ?? 0,
-          total_tokens: response.usage?.total_tokens ?? 0,
-        },
+        response: response,
       } as T;
     }
 
-    return {
-      data: response.choices[0].message.content,
-      usage: {
-        prompt_tokens: response.usage?.prompt_tokens ?? 0,
-        completion_tokens: response.usage?.completion_tokens ?? 0,
-        total_tokens: response.usage?.total_tokens ?? 
0, + return response as T; + } + + async createChatCompletionStream({ + options, + logger, + retries = 3, + }: CreateChatCompletionOptions): Promise { + const { image, requestId, ...optionsWithoutImageAndRequestId } = options; + + // TODO: Implement vision support + if (image) { + console.warn( + "Image provided. Vision is not currently supported for openai", + ); + } + + logger({ + category: "openai", + message: "creating chat completion stream", + level: 1, + auxiliary: { + options: { + value: JSON.stringify({ + ...optionsWithoutImageAndRequestId, + requestId, + }), + type: "object", + }, + modelName: { + value: this.modelName, + type: "string", + }, }, - } as T; + }); + + if (options.image) { + console.warn( + "Image provided. Vision is not currently supported for openai", + ); + } + + let responseFormat = undefined; + if (options.response_model) { + responseFormat = zodResponseFormat( + options.response_model.schema, + options.response_model.name, + ); + } + + /* eslint-disable */ + // Remove unsupported options + const { response_model, ...openaiOptions } = { + ...optionsWithoutImageAndRequestId, + model: this.modelName, + }; + + const formattedMessages: ChatCompletionMessageParam[] = + options.messages.map((message) => { + if (Array.isArray(message.content)) { + const contentParts = message.content.map((content) => { + if ("image_url" in content) { + const imageContent: ChatCompletionContentPartImage = { + image_url: { + url: content.image_url.url, + }, + type: "image_url", + }; + return imageContent; + } else { + const textContent: ChatCompletionContentPartText = { + text: content.text, + type: "text", + }; + return textContent; + } + }); + + if (message.role === "system") { + const formattedMessage: ChatCompletionSystemMessageParam = { + ...message, + role: "system", + content: contentParts.filter( + (content): content is ChatCompletionContentPartText => + content.type === "text", + ), + }; + return formattedMessage; + } else if (message.role === "user") { + const formattedMessage: ChatCompletionUserMessageParam = { + ...message, + role: "user", + content: contentParts, + }; + return formattedMessage; + } else { + const formattedMessage: ChatCompletionAssistantMessageParam = { + ...message, + role: "assistant", + content: contentParts.filter( + (content): content is ChatCompletionContentPartText => + content.type === "text", + ), + }; + return formattedMessage; + } + } + + const formattedMessage: ChatCompletionUserMessageParam = { + role: "user", + content: message.content, + }; + + return formattedMessage; + }); + + const body: ChatCompletionCreateParamsStreaming = { + ...openaiOptions, + model: this.modelName, + messages: formattedMessages, + response_format: responseFormat, + stream: true, + tools: options.tools?.map((tool) => ({ + function: { + name: tool.name, + description: tool.description, + parameters: tool.parameters, + }, + type: "function", + })), + }; + + const response = await this.client.chat.completions.create(body); + return response as T; + } + + async streamText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { logger = () => {}, retries = 3, ...chatOptions } = options; + + // Create a unique request ID if not provided + const requestId = options.requestId || Date.now().toString(); + + logger({ + category: "openai", + message: "Initiating text streaming", + level: 2, + auxiliary: { + options: { + value: JSON.stringify({ + prompt, + requestId, + }), + type: "object", + }, + modelName: { + value: 
this.modelName, + type: "string", + }, + }, + }); + + try { + // Create a chat completion with the prompt as a user message + const response = (await this.createChatCompletionStream({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId, + }, + logger, + retries, + })) as StreamingChatResponse; + + // Restructure the response to return a stream of text + const textStream = new ReadableStream({ + async start(controller) { + try { + for await (const chunk of response) { + const content = chunk.choices[0]?.delta?.content; + if (content !== undefined) { + controller.enqueue(content); + } + } + controller.close(); + } catch (error) { + controller.error(error); + } + }, + }); + + logger({ + category: "openai", + message: "text streaming response", + level: 2, + auxiliary: { + response: { + value: JSON.stringify(textStream), + type: "object", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return { textStream: textStream } as T; + } catch (error) { + logger({ + category: "openai", + message: "Text streaming failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } + + async generateText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { logger = () => {}, retries = 3, ...chatOptions } = options; + + // Create chat completion with single user message + const res = await (this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + // Generate unique request ID if not provided + requestId: options.requestId || Date.now().toString(), + }, + logger, + retries, + }) as Promise); + // Validate response and extract generated text + if (res.choices && res.choices.length > 0) { + return { + text: res.choices[0].message.content, + finishReason: res.choices[0].finish_reason, + usage: res.usage, + response: res, + } as T; + } else { + throw new CreateChatCompletionResponseError("No choices in response"); + } + } + async generateObject({ + prompt, + schema, + options = {}, + }: GenerateObjectOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, + retries = 3, + requestId = Date.now().toString(), + ...chatOptions + } = options; + + try { + // Log the generation attempt + logger({ + category: "openai", + message: "Initiating object generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Create chat completion with the provided prompt + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + response_model: { + name: "object", + schema: schema, + }, + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMObjectResponse; + + // Validate response structure + if (!response.data || response.data.length === 0) { + throw new CreateChatCompletionResponseError( + "API response contains no valid choices", + ); + } + + // Extract and validate the generated text + const generatedObject = response.data; + if (generatedObject === null || generatedObject === undefined) { + throw new CreateChatCompletionResponseError( + "Generated text content is empty", + ); + } + + // Construct the final response + const objResponse 
= { + object: generatedObject, + finishReason: response.response.choices[0].finish_reason, + usage: response.response.usage, + ...response, + } as T; + + // Log successful generation + logger({ + category: "openai", + message: "Text generation successful", + level: 2, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return objResponse; + } catch (error) { + // Log the error + logger({ + category: "openai", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } } } diff --git a/examples/external_clients/langchain.ts b/examples/external_clients/langchain.ts index 1d071a63b..5faefc34e 100644 --- a/examples/external_clients/langchain.ts +++ b/examples/external_clients/langchain.ts @@ -8,6 +8,17 @@ import { SystemMessage, } from "@langchain/core/messages"; import { ChatCompletion } from "openai/resources"; +import { + CreateChatCompletionResponseError, + GenerateObjectOptions, + GenerateTextOptions, + LLMObjectResponse, + LLMResponse, + ObjectResponse, + StreamingChatResponse, + StreamingTextResponse, + TextResponse, +} from "@/lib"; export class LangchainClient extends LLMClient { public type = "langchainClient" as const; @@ -84,4 +95,241 @@ export class LangchainClient extends LLMClient { }, } as T; } + + async createChatCompletionStream({ + options, + logger, + retries = 3, + }: CreateChatCompletionOptions): Promise { + console.log(logger, retries); + const formattedMessages: BaseMessageLike[] = options.messages.map( + (message) => { + if (Array.isArray(message.content)) { + if (message.role === "system") { + return new SystemMessage( + message.content + .map((c) => ("text" in c ? c.text : "")) + .join("\n"), + ); + } + + const content = message.content.map((content) => + "image_url" in content + ? { type: "image", image: content.image_url.url } + : { type: "text", text: content.text }, + ); + + if (message.role === "user") return new HumanMessage({ content }); + + const textOnlyParts = content.map((part) => ({ + type: "text" as const, + text: part.type === "image" ? 
"[Image]" : part.text, + })); + + return new AIMessage({ content: textOnlyParts }); + } + + return { + role: message.role, + content: message.content, + }; + }, + ); + const modelWithTools = this.model.bindTools(options.tools); + const response = await modelWithTools._streamIterator(formattedMessages); + return response as T; + } + + async streamText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { logger = () => {}, retries = 3, ...chatOptions } = options; + + // Create a unique request ID if not provided + const requestId = options.requestId || Date.now().toString(); + + // Create a chat completion with the prompt as a user message + const response = (await this.createChatCompletionStream({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId, + }, + logger, + retries, + })) as StreamingChatResponse; + + // Restructure the response to return a stream of text + const textStream = new ReadableStream({ + async start(controller) { + try { + for await (const chunk of response) { + const content = chunk.choices[0]?.delta?.content; + if (content !== undefined) { + controller.enqueue(content); + } + } + controller.close(); + } catch (error) { + controller.error(error); + } + }, + }); + + return { textStream: textStream } as T; + } + + async generateText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { logger = () => {}, retries = 3, ...chatOptions } = options; + + // Create chat completion with single user message + const res = await (this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + // Generate unique request ID if not provided + requestId: options.requestId || Date.now().toString(), + }, + logger, + retries, + }) as Promise); + + // Validate and extract response + if (res.choices && res.choices.length > 0) { + return { + ...res, + text: res.choices[0].message.content, + } as T; + } else { + throw new CreateChatCompletionResponseError("No choices in response"); + } + } + async generateObject({ + prompt, + schema, + options = {}, + }: GenerateObjectOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, + retries = 3, + requestId = Date.now().toString(), + ...chatOptions + } = options; + + try { + // Log the generation attempt + logger({ + category: "anthropic", + message: "Initiating object generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Create chat completion with the provided prompt + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + response_model: { + name: "object", + schema: schema, + }, + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMObjectResponse; + + // Validate response structure + if (!response.data || response.data.length === 0) { + throw new CreateChatCompletionResponseError( + "API response contains no valid choices", + ); + } + + // Extract and validate the generated text + const generatedObject = response.data; + if (generatedObject === null || generatedObject === undefined) { + throw new CreateChatCompletionResponseError( + "Generated text content is empty", + ); + } + + // Construct the final response + const objResponse = { + ...response, + object: generatedObject, + } as 
T; + + // Log successful generation + logger({ + category: "anthropic", + message: "Text generation successful", + level: 2, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return objResponse; + } catch (error) { + // Log the error + logger({ + category: "anthropic", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } } diff --git a/examples/llm_usage_wordle.ts b/examples/llm_usage_wordle.ts new file mode 100644 index 000000000..73681cfb6 --- /dev/null +++ b/examples/llm_usage_wordle.ts @@ -0,0 +1,43 @@ +import { Stagehand } from "@/dist"; +import StagehandConfig from "@/stagehand.config"; +import { z } from "zod"; + +async function example() { + const stagehand = new Stagehand({ + ...StagehandConfig, + }); + const prompt = + "you are playing wordle. Return the 5-letter word that would be the best guess"; + await stagehand.init(); + console.log("---Generating Text---"); + const responseText = await stagehand.llmClient.generateText({ + prompt: prompt, + }); + console.log(responseText); + + console.log("---Generating Object---"); + const responseObj = await stagehand.llmClient.generateObject({ + prompt: prompt, + schema: z.object({ + guess: z + .string() + .describe("The 5-letter word that would be the best guess"), + }), + }); + console.log(responseObj); + + console.log("---Streaming Text---"); + const { textStream } = await stagehand.llmClient.streamText({ + prompt: prompt, + }); + + for await (const textPart of textStream) { + process.stdout.write(textPart); + } + + await stagehand.close(); +} + +(async () => { + await example(); +})(); diff --git a/lib/llm/AnthropicClient.ts b/lib/llm/AnthropicClient.ts index 27d8f4c7c..749ba2151 100644 --- a/lib/llm/AnthropicClient.ts +++ b/lib/llm/AnthropicClient.ts @@ -11,8 +11,15 @@ import { AnthropicJsonSchemaObject, AvailableModel } from "../../types/model"; import { LLMCache } from "../cache/LLMCache"; import { CreateChatCompletionOptions, + GenerateObjectOptions, + GenerateTextOptions, LLMClient, + LLMObjectResponse, LLMResponse, + ObjectResponse, + StreamingChatResponse, + StreamingTextResponse, + TextResponse, } from "./LLMClient"; import { CreateChatCompletionResponseError } from "@/types/stagehandErrors"; @@ -289,7 +296,6 @@ export class AnthropicClient extends LLMClient { ], usage: usageData, }; - logger({ category: "anthropic", message: "transformed response", @@ -314,12 +320,12 @@ export class AnthropicClient extends LLMClient { const finalParsedResponse = { data: result, usage: usageData, + ...response, } as unknown as T; if (this.enableCaching) { this.cache.set(cacheOptions, finalParsedResponse, options.requestId); } - return finalParsedResponse; } else { if (!retries || retries < 5) { @@ -373,6 +379,577 @@ export class AnthropicClient extends LLMClient { // so we can safely cast here to T, which defaults to AnthropicTransformedResponse return transformedResponse as T; } + + async createChatCompletionStream({ + options, + retries, + logger, + }: CreateChatCompletionOptions): Promise { + const optionsWithoutImage = { ...options }; + delete optionsWithoutImage.image; + + logger({ + category: "anthropic", + message: "creating chat completion stream", + level: 2, + auxiliary: { + options: { + value: JSON.stringify(optionsWithoutImage), + 
type: "object", + }, + }, + }); + + // Try to get cached response + const cacheOptions = { + model: this.modelName, + messages: options.messages, + temperature: options.temperature, + image: options.image, + response_model: options.response_model, + tools: options.tools, + retries: retries, + }; + + if (this.enableCaching) { + const cachedResponse = await this.cache.get( + cacheOptions, + options.requestId, + ); + if (cachedResponse) { + logger({ + category: "llm_cache", + message: "LLM cache hit - returning cached response", + level: 1, + auxiliary: { + cachedResponse: { + value: JSON.stringify(cachedResponse), + type: "object", + }, + requestId: { + value: options.requestId, + type: "string", + }, + cacheOptions: { + value: JSON.stringify(cacheOptions), + type: "object", + }, + }, + }); + return cachedResponse as T; + } else { + logger({ + category: "llm_cache", + message: "LLM cache miss - no cached response found", + level: 1, + auxiliary: { + cacheOptions: { + value: JSON.stringify(cacheOptions), + type: "object", + }, + requestId: { + value: options.requestId, + type: "string", + }, + }, + }); + } + } + + const systemMessage = options.messages.find((msg) => { + if (msg.role === "system") { + if (typeof msg.content === "string") { + return true; + } else if (Array.isArray(msg.content)) { + return msg.content.every((content) => content.type !== "image_url"); + } + } + return false; + }); + + const userMessages = options.messages.filter( + (msg) => msg.role !== "system", + ); + + const formattedMessages: MessageParam[] = userMessages.map((msg) => { + if (typeof msg.content === "string") { + return { + role: msg.role as "user" | "assistant", // ensure its not checking for system types + content: msg.content, + }; + } else { + return { + role: msg.role as "user" | "assistant", + content: msg.content.map((content) => { + if ("image_url" in content) { + const formattedContent: ImageBlockParam = { + type: "image", + source: { + type: "base64", + media_type: "image/jpeg", + data: content.image_url.url, + }, + }; + + return formattedContent; + } else { + return { type: "text", text: content.text }; + } + }), + }; + } + }); + + if (options.image) { + const screenshotMessage: MessageParam = { + role: "user", + content: [ + { + type: "image", + source: { + type: "base64", + media_type: "image/jpeg", + data: options.image.buffer.toString("base64"), + }, + }, + ], + }; + if ( + options.image.description && + Array.isArray(screenshotMessage.content) + ) { + screenshotMessage.content.push({ + type: "text", + text: options.image.description, + }); + } + + formattedMessages.push(screenshotMessage); + } + + let anthropicTools: Tool[] = options.tools?.map((tool) => { + return { + name: tool.name, + description: tool.description, + input_schema: { + type: "object", + properties: tool.parameters.properties, + required: tool.parameters.required, + }, + }; + }); + + let toolDefinition: Tool | undefined; + + // Check if a response model is provided + if (options.response_model) { + const jsonSchema = zodToJsonSchema(options.response_model.schema); + const { properties: schemaProperties, required: schemaRequired } = + extractSchemaProperties(jsonSchema); + + toolDefinition = { + name: "print_extracted_data", + description: "Prints the extracted data based on the provided schema.", + input_schema: { + type: "object", + properties: schemaProperties, + required: schemaRequired, + }, + }; + } + + // Add the tool definition to the tools array if it exists + if (toolDefinition) { + anthropicTools = anthropicTools ?? 
[]; + anthropicTools.push(toolDefinition); + } + + // Create the chat completion stream with the provided messages + const response = await this.client.messages.create({ + model: this.modelName, + max_tokens: options.maxTokens || 8192, + messages: formattedMessages, + tools: anthropicTools, + system: systemMessage + ? (systemMessage.content as string | TextBlockParam[]) + : undefined, + temperature: options.temperature, + stream: true, + }); + + // TODO: Transform response stream to preferred format + // TODO: Response model validation + // TODO: Enable caching + + // Temporarily restructure the response to match the expected format + return new ReadableStream({ + async start(controller) { + try { + for await (const chunk of response) { + if ( + chunk.type === "content_block_delta" && + chunk.delta.type === "text_delta" + ) { + controller.enqueue(chunk.delta.text); + } + } + controller.close(); + } catch (error) { + controller.error(error); + } + }, + }) as T; + } + + async streamText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { logger = () => {}, retries = 3, ...chatOptions } = options; + + // Create a unique request ID if not provided + const requestId = options.requestId || Date.now().toString(); + + // Log the generation attempt + logger({ + category: "anthropic", + message: "Initiating text streaming", + level: 2, + auxiliary: { + options: { + value: JSON.stringify({ + prompt, + requestId, + }), + type: "object", + }, + modelName: { + value: this.modelName, + type: "string", + }, + }, + }); + + try { + // Create a chat completion stream with the prompt as a user message + const response = (await this.createChatCompletionStream({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId, + }, + logger, + retries, + })) as StreamingChatResponse; + + logger({ + category: "anthropic", + message: "Text streaming response", + level: 2, + auxiliary: { + response: { + value: JSON.stringify(response), + type: "object", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return { + textStream: response, + } as T; + } catch (error) { + logger({ + category: "anthropic", + message: "Text streaming failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } + + async generateText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, + retries = 3, + requestId = Date.now().toString(), + ...chatOptions + } = options; + + // Log the generation attempt + logger({ + category: "anthropic", + message: "Initiating text generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + try { + // Create chat completion with the provided prompt + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMResponse; + + // Validate response structure + if ( + !response.choices || + response.choices.length === 0 || + response.choices[0].message.content == undefined || + response.choices[0].message.content == null + ) { + logger({ + category: "anthropic", + message: "Text generation 
failed", + level: 0, + auxiliary: { + error: { + value: "API response contains no valid choices", + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + throw new CreateChatCompletionResponseError( + "API response contains no valid choices", + ); + } + + // Extract and validate the generated text + const generatedText = response.choices[0].message.content; + + // Construct the final response + const textResponse = { + text: generatedText, + finishReason: response.choices[0].finish_reason, + usage: response.usage, + response: response, + // reasoning: response.reasoning, + // sources: response.sources + } as T; + + // Log successful generation + logger({ + category: "anthropic", + message: "Text generation successful", + level: 2, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + responseLength: { + value: generatedText.length.toString(), + type: "string", + }, + }, + }); + + return textResponse; + } catch (error) { + // Log the error + logger({ + category: "anthropic", + message: "Text generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } + + async generateObject({ + prompt, + schema, + options = {}, + }: GenerateObjectOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, + retries = 3, + requestId = Date.now().toString(), + ...chatOptions + } = options; + + // Log the generation attempt + logger({ + category: "anthropic", + message: "Initiating object generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + try { + // Create chat completion with the provided prompt + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + response_model: { + name: "object", + schema: schema, + }, + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMObjectResponse; + + // Validate response structure + if ( + !response.data || + response.data.length === 0 || + response.data === undefined + ) { + logger({ + category: "anthropic", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: "API response contains no valid choices", + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + throw new CreateChatCompletionResponseError( + "API response contains no valid choices", + ); + } + + // Extract and validate the generated text + const generatedObject = response.data; + + // Construct the final response + const objResponse = { + object: generatedObject, + // finishReason: response.stop_reason, + // usage: response.response.usage, + response: response, + } as T; + + // Log successful generation + logger({ + category: "anthropic", + message: "Object generation successful", + level: 2, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return objResponse; + } catch (error) { + // Log the error + logger({ + category: "anthropic", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Re-throw 
the error to be handled by the caller + throw error; + } + } } const extractSchemaProperties = (jsonSchema: AnthropicJsonSchemaObject) => { diff --git a/lib/llm/CerebrasClient.ts b/lib/llm/CerebrasClient.ts index 4d5a0daca..419787d7f 100644 --- a/lib/llm/CerebrasClient.ts +++ b/lib/llm/CerebrasClient.ts @@ -7,8 +7,15 @@ import { LLMCache } from "../cache/LLMCache"; import { ChatMessage, CreateChatCompletionOptions, + GenerateObjectOptions, + GenerateTextOptions, LLMClient, + LLMObjectResponse, LLMResponse, + ObjectResponse, + StreamingChatResponse, + StreamingTextResponse, + TextResponse, } from "./LLMClient"; import { CreateChatCompletionResponseError } from "@/types/stagehandErrors"; @@ -243,7 +250,7 @@ export class CerebrasClient extends LLMClient { if (this.enableCaching) { this.cache.set(cacheOptions, result, options.requestId); } - return result as T; + return { data: result, response: response } as T; } catch (e) { // If JSON parse fails, the model might be returning a different format logger({ @@ -271,7 +278,7 @@ export class CerebrasClient extends LLMClient { if (this.enableCaching) { this.cache.set(cacheOptions, result, options.requestId); } - return result as T; + return { data: result, response: response } as T; } } catch (e) { logger({ @@ -324,4 +331,554 @@ export class CerebrasClient extends LLMClient { throw error; } } + + async createChatCompletionStream({ + options, + logger, + retries = 3, + }: CreateChatCompletionOptions): Promise { + const optionsWithoutImage = { ...options }; + delete optionsWithoutImage.image; + + logger({ + category: "cerebras", + message: "creating chat completion stream", + level: 2, + auxiliary: { + options: { + value: JSON.stringify(optionsWithoutImage), + type: "object", + }, + }, + }); + + // Try to get cached response + const cacheOptions = { + model: this.modelName.split("cerebras-")[1], + messages: options.messages, + temperature: options.temperature, + response_model: options.response_model, + tools: options.tools, + retries: retries, + }; + + if (this.enableCaching) { + const cachedResponse = await this.cache.get( + cacheOptions, + options.requestId, + ); + if (cachedResponse) { + logger({ + category: "llm_cache", + message: "LLM cache hit - returning cached response", + level: 1, + auxiliary: { + cachedResponse: { + value: JSON.stringify(cachedResponse), + type: "object", + }, + requestId: { + value: options.requestId, + type: "string", + }, + cacheOptions: { + value: JSON.stringify(cacheOptions), + type: "object", + }, + }, + }); + return cachedResponse as T; + } + } + + // Format messages for Cerebras API (using OpenAI format) + const formattedMessages = options.messages.map((msg: ChatMessage) => { + const baseMessage = { + content: + typeof msg.content === "string" + ? msg.content + : Array.isArray(msg.content) && + msg.content.length > 0 && + "text" in msg.content[0] + ? 
msg.content[0].text + : "", + }; + + // Cerebras only supports system, user, and assistant roles + if (msg.role === "system") { + return { ...baseMessage, role: "system" as const }; + } else if (msg.role === "assistant") { + return { ...baseMessage, role: "assistant" as const }; + } else { + // Default to user for any other role + return { ...baseMessage, role: "user" as const }; + } + }); + + // Format tools if provided + let tools = options.tools?.map((tool) => ({ + type: "function" as const, + function: { + name: tool.name, + description: tool.description, + parameters: { + type: "object", + properties: tool.parameters.properties, + required: tool.parameters.required, + }, + }, + })); + + // Add response model as a tool if provided + if (options.response_model) { + const jsonSchema = zodToJsonSchema(options.response_model.schema) as { + properties?: Record; + required?: string[]; + }; + const schemaProperties = jsonSchema.properties || {}; + const schemaRequired = jsonSchema.required || []; + + const responseTool = { + type: "function" as const, + function: { + name: "print_extracted_data", + description: + "Prints the extracted data based on the provided schema.", + parameters: { + type: "object", + properties: schemaProperties, + required: schemaRequired, + }, + }, + }; + + tools = tools ? [...tools, responseTool] : [responseTool]; + } + + const apiResponse = await this.client.chat.completions.create({ + model: this.modelName.split("cerebras-")[1], + messages: [ + ...formattedMessages, + // Add explicit instruction to return JSON if we have a response model + ...(options.response_model + ? [ + { + role: "system" as const, + content: `IMPORTANT: Your response must be valid JSON that matches this schema: ${JSON.stringify(options.response_model.schema)}`, + }, + ] + : []), + ], + temperature: options.temperature || 0.7, + max_tokens: options.maxTokens, + tools: tools, + tool_choice: options.tool_choice || "auto", + }); + + // TODO: transform response to required format + // TODO: Validate response model + // TODO: Enable caching + + return apiResponse as T; + } + + async streamText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { logger = () => {}, retries = 3, ...chatOptions } = options; + + // Create a unique request ID if not provided + const requestId = options.requestId || Date.now().toString(); + + logger({ + category: "cerebras", + message: "Initiating text streaming", + level: 2, + auxiliary: { + options: { + value: JSON.stringify({ + prompt, + requestId, + }), + type: "object", + }, + modelName: { + value: this.modelName, + type: "string", + }, + }, + }); + + try { + // Create a chat completion with the prompt as a user message + const response = (await this.createChatCompletionStream({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId, + }, + logger, + retries, + })) as StreamingChatResponse; + + // Restructure the response to return a stream of text + const textStream = new ReadableStream({ + async start(controller) { + try { + for await (const chunk of response) { + const content = chunk.choices[0]?.delta?.content; + if (content !== undefined) { + controller.enqueue(content); + } + } + controller.close(); + } catch (error) { + controller.error(error); + } + }, + }); + + // Log successful streaming + logger({ + category: "cerebras", + message: "text streaming response", + level: 2, + auxiliary: { + response: { + value: JSON.stringify(textStream), + type: 
"object", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return { textStream: textStream } as T; + } catch (error) { + // Log the error with detailed information + logger({ + category: "cerebras", + message: "Text streaming failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } + + async generateText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, + retries = 3, + requestId = Date.now().toString(), + ...chatOptions + } = options; + + // Log the generation attempt + logger({ + category: "cerebras", + message: "Initiating text generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + model: { + value: this.modelName, + type: "string", + }, + options: { + value: JSON.stringify(chatOptions), + type: "object", + }, + }, + }); + + try { + // Create chat completion with the provided prompt + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMResponse; + + // Validate response structure + if ( + !response.choices || + response.choices.length === 0 || + response.choices[0].message.content === null || + response.choices[0].message.content === undefined + ) { + logger({ + category: "cerebras", + message: "Text generation failed", + level: 0, + auxiliary: { + error: { + value: "API response contains no valid choices", + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + throw new CreateChatCompletionResponseError( + "API response contains no valid choices", + ); + } + + // Extract and validate the generated text + const generatedContent = response.choices[0].message.content; + + // Construct the final response with additional metadata + const textResponse = { + ...response, + text: generatedContent, + modelName: this.modelName.split("cerebras-")[1], // Clean model name + timestamp: Date.now(), + metadata: { + provider: "cerebras", + originalPrompt: prompt, + requestId, + temperature: chatOptions.temperature || 0.7, + }, + } as T; + + // Log successful generation with detailed metrics + logger({ + category: "cerebras", + message: "Text generation successful", + level: 2, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + responseLength: { + value: generatedContent.length.toString(), + type: "string", + }, + usage: { + value: JSON.stringify(response.usage), + type: "object", + }, + finishReason: { + value: response.choices[0].finish_reason || "unknown", + type: "string", + }, + }, + }); + + return textResponse; + } catch (error) { + // Log the error with detailed information + logger({ + category: "cerebras", + message: "Text generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + errorType: { + value: error.constructor.name, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + model: { + value: this.modelName, + type: "string", + }, + }, + }); + + // If it's a known error type, throw it directly + if (error instanceof CreateChatCompletionResponseError) { + throw error; + } + + // 
Otherwise, wrap it in our custom error type with context + throw new CreateChatCompletionResponseError( + `Cerebras text generation failed: ${error.message}`, + ); + } + } + async generateObject({ + prompt, + schema, + options = {}, + }: GenerateObjectOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, + retries = 3, + requestId = Date.now().toString(), + ...chatOptions + } = options; + + // Log the generation attempt + logger({ + category: "cerebras", + message: "Initiating object generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + try { + // Create chat completion with the provided prompt + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + response_model: { + name: "object", + schema: schema, + }, + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMObjectResponse; + + // Validate response structure + if ( + !response.data || + response.data.length === 0 || + response.data === undefined + ) { + logger({ + category: "cerebras", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: "API response contains no valid choices", + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + throw new CreateChatCompletionResponseError( + "API response contains no valid choices", + ); + } + + // Extract and validate the generated text + const generatedObject = response.data; + + // Construct the final response + const objResponse = { + object: generatedObject, + finishReason: response.response.choices[0].finish_reason, + usage: response.response.usage, + ...response, + } as T; + + // Log successful generation + logger({ + category: "cerebras", + message: "Object generation successful", + level: 2, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return objResponse; + } catch (error) { + // Log the error + logger({ + category: "cerebras", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } } diff --git a/lib/llm/GoogleClient.ts b/lib/llm/GoogleClient.ts index 968ce9e22..14155a9fe 100644 --- a/lib/llm/GoogleClient.ts +++ b/lib/llm/GoogleClient.ts @@ -22,6 +22,13 @@ import { LLMClient, LLMResponse, AnnotatedScreenshotText, + TextResponse, + GenerateTextOptions, + LLMObjectResponse, + GenerateObjectOptions, + ObjectResponse, + StreamingChatResponse, + StreamingTextResponse, } from "./LLMClient"; import { CreateChatCompletionResponseError, @@ -476,6 +483,7 @@ export class GoogleClient extends LLMClient { const extractionResult = { data: parsedData, usage: llmResponse.usage, + response: llmResponse, }; if (this.enableCaching) { @@ -530,4 +538,430 @@ export class GoogleClient extends LLMClient { ); } } + + async createChatCompletionStream({ + options, + logger, + retries = 3, + }: CreateChatCompletionOptions): Promise { + const { + image, + requestId, + response_model, + tools, + temperature, + top_p, + maxTokens, + } = options; + + logger({ + category: "google", + message: "creating chat completion stream", + level: 2, + auxiliary: { + options: { + value: JSON.stringify(options), + type: "object", + 
}, + }, + }); + + const cacheKeyOptions = { + model: this.modelName, + messages: options.messages, + temperature: temperature, + top_p: top_p, + // frequency_penalty and presence_penalty are not directly supported in Gemini API + image: image + ? { description: image.description, bufferLength: image.buffer.length } + : undefined, // Use buffer length for caching key stability + response_model: response_model + ? { + name: response_model.name, + schema: JSON.stringify(zodToJsonSchema(response_model.schema)), + } + : undefined, + tools: tools, + maxTokens: maxTokens, + retries: retries, + }; + + if (this.enableCaching) { + const cachedResponse = await this.cache.get( + cacheKeyOptions, + requestId, + ); + if (cachedResponse) { + logger({ + category: "llm_cache", + message: "LLM cache hit - returning cached response", + level: 1, + auxiliary: { requestId: { value: requestId, type: "string" } }, + }); + return cachedResponse; + } else { + logger({ + category: "llm_cache", + message: "LLM cache miss - proceeding with API call", + level: 1, + auxiliary: { requestId: { value: requestId, type: "string" } }, + }); + } + } + + const formattedMessages = this.formatMessages(options.messages, image); + const formattedTools = this.formatTools(tools); + + const generationConfig = { + maxOutputTokens: maxTokens, + temperature: temperature, + topP: top_p, + responseMimeType: response_model ? "application/json" : undefined, + }; + + // Handle JSON mode instructions + if (response_model) { + // Prepend instructions for JSON output if needed (similar to o1 handling) + const schemaString = JSON.stringify( + zodToJsonSchema(response_model.schema), + ); + formattedMessages.push({ + role: "user", + parts: [ + { + text: `Please respond ONLY with a valid JSON object that strictly adheres to the following JSON schema. Do not include any other text, explanations, or markdown formatting like \`\`\`json ... \`\`\`. 
Just the JSON object.\n\nSchema:\n${schemaString}`, + }, + ], + }); + formattedMessages.push({ role: "model", parts: [{ text: "{" }] }); // Prime the model + } + + // Construct the full request object + const requestPayload = { + model: this.modelName, + contents: formattedMessages, + config: { + ...generationConfig, + safetySettings: safetySettings, + tools: formattedTools, + }, + }; + + // Log the full payload safely + try { + logger({ + category: "google", + message: "Full request payload", + level: 2, + auxiliary: { + requestId: { value: requestId, type: "string" }, + fullPayload: { + value: JSON.stringify(requestPayload), + type: "object", + }, + }, + }); + } catch (e) { + logger({ + category: "google", + message: "Failed to stringify full request payload for logging", + level: 0, + auxiliary: { + requestId: { value: requestId, type: "string" }, + error: { value: e.message, type: "string" }, + }, + }); + } + + const result = + await this.client.models.generateContentStream(requestPayload); + + // TODO: transform response to required format + // TODO: Validate response model + // TODO: Enable caching + + return new ReadableStream({ + async start(controller) { + try { + for await (const chunk of result) { + controller.enqueue(chunk.candidates[0].content.parts[0].text); + } + controller.close(); + } catch (error) { + controller.error(error); + } + }, + }) as T; + } + + async streamText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { logger = () => {}, retries = 3, ...chatOptions } = options; + + // Create a unique request ID if not provided + const requestId = options.requestId || Date.now().toString(); + + logger({ + category: "google", + message: "Initiating text streaming", + level: 2, + auxiliary: { + options: { + value: JSON.stringify({ + prompt, + requestId, + }), + type: "object", + }, + modelName: { + value: this.modelName, + type: "string", + }, + }, + }); + + try { + // Create a chat completion with the prompt as a user message + const response = (await this.createChatCompletionStream({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId, + }, + logger, + retries, + })) as StreamingChatResponse; + + logger({ + category: "google", + message: "text streaming response", + level: 2, + auxiliary: { + response: { + value: JSON.stringify(response), + type: "object", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return { textStream: response } as T; + } catch (error) { + logger({ + category: "google", + message: "Text streaming failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } + + async generateText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, // Default no-op logger + retries = 3, // Default retry attempts + ...chatOptions // All other chat-specific options + } = options; + + // Create a chat completion with a single user message + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId: options.requestId || Date.now().toString(), // Ensure unique request ID + }, + logger, + retries, + })) as LLMResponse; + + // Validate and extract the generated text from the 
response + if (response.choices && response.choices.length > 0) { + return { + text: response.choices[0].message.content, + finishReason: response.choices[0].finish_reason, + usage: response.usage, + response: response, + } as T; + } else { + logger({ + category: "google", + message: "text generation failed", + level: 0, + auxiliary: { + error: { + value: "No choices available in the response", + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + throw new CreateChatCompletionResponseError( + "No choices available in the response", + ); + } + } + async generateObject({ + prompt, + schema, + options = {}, + }: GenerateObjectOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, + retries = 3, + requestId = Date.now().toString(), + ...chatOptions + } = options; + + // Log the generation attempt + logger({ + category: "google", + message: "Initiating object generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + try { + // Create chat completion with the provided prompt + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + response_model: { + name: "object", + schema: schema, + }, + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMObjectResponse; + // Validate response structure + if ( + !response.data || + response.data.length === 0 || + response.data === undefined + ) { + logger({ + category: "google", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: "API response contains no valid choices", + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + throw new CreateChatCompletionResponseError( + "API response contains no valid choices", + ); + } + + // Extract and validate the generated text + const generatedObject = response.data; + + // Construct the final response + const objResponse = { + object: generatedObject, + finishReason: response.response.choices[0].finish_reason, + usage: response.usage, + response: response, + } as T; + + // Log successful generation + logger({ + category: "google", + message: "Object generation successful", + level: 2, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return objResponse; + } catch (error) { + // Log the error + logger({ + category: "google", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } } diff --git a/lib/llm/GroqClient.ts b/lib/llm/GroqClient.ts index fe91d06ba..1ee47ad20 100644 --- a/lib/llm/GroqClient.ts +++ b/lib/llm/GroqClient.ts @@ -7,8 +7,15 @@ import { LLMCache } from "../cache/LLMCache"; import { ChatMessage, CreateChatCompletionOptions, + GenerateObjectOptions, + GenerateTextOptions, LLMClient, + LLMObjectResponse, LLMResponse, + ObjectResponse, + StreamingChatResponse, + StreamingTextResponse, + TextResponse, } from "./LLMClient"; import { CreateChatCompletionResponseError } from "@/types/stagehandErrors"; @@ -243,7 +250,7 @@ export class GroqClient extends LLMClient { if (this.enableCaching) { this.cache.set(cacheOptions, result, options.requestId); } - return result as T; + return { 
data: result, response: response } as T; } catch (e) { // If JSON parse fails, the model might be returning a different format logger({ @@ -271,7 +278,7 @@ export class GroqClient extends LLMClient { if (this.enableCaching) { this.cache.set(cacheOptions, result, options.requestId); } - return result as T; + return { data: result, response: response } as T; } } catch (e) { logger({ @@ -324,4 +331,537 @@ export class GroqClient extends LLMClient { throw error; } } + + async createChatCompletionStream({ + options, + logger, + retries = 3, + }: CreateChatCompletionOptions): Promise { + const optionsWithoutImage = { ...options }; + delete optionsWithoutImage.image; + + logger({ + category: "groq", + message: "creating chat completion stream", + level: 2, + auxiliary: { + options: { + value: JSON.stringify(optionsWithoutImage), + type: "object", + }, + }, + }); + + // Try to get cached response + const cacheOptions = { + model: this.modelName.split("groq-")[1], + messages: options.messages, + temperature: options.temperature, + response_model: options.response_model, + tools: options.tools, + retries: retries, + }; + + if (this.enableCaching) { + const cachedResponse = await this.cache.get( + cacheOptions, + options.requestId, + ); + if (cachedResponse) { + logger({ + category: "llm_cache", + message: "LLM cache hit - returning cached response", + level: 1, + auxiliary: { + cachedResponse: { + value: JSON.stringify(cachedResponse), + type: "object", + }, + requestId: { + value: options.requestId, + type: "string", + }, + cacheOptions: { + value: JSON.stringify(cacheOptions), + type: "object", + }, + }, + }); + return cachedResponse as T; + } + } + + // Format messages for Groq API (using OpenAI format) + const formattedMessages = options.messages.map((msg: ChatMessage) => { + const baseMessage = { + content: + typeof msg.content === "string" + ? msg.content + : Array.isArray(msg.content) && + msg.content.length > 0 && + "text" in msg.content[0] + ? msg.content[0].text + : "", + }; + + // Groq supports system, user, and assistant roles + if (msg.role === "system") { + return { ...baseMessage, role: "system" as const }; + } else if (msg.role === "assistant") { + return { ...baseMessage, role: "assistant" as const }; + } else { + // Default to user for any other role + return { ...baseMessage, role: "user" as const }; + } + }); + + // Format tools if provided + let tools = options.tools?.map((tool) => ({ + type: "function" as const, + function: { + name: tool.name, + description: tool.description, + parameters: { + type: "object", + properties: tool.parameters.properties, + required: tool.parameters.required, + }, + }, + })); + + // Add response model as a tool if provided + if (options.response_model) { + const jsonSchema = zodToJsonSchema(options.response_model.schema) as { + properties?: Record; + required?: string[]; + }; + const schemaProperties = jsonSchema.properties || {}; + const schemaRequired = jsonSchema.required || []; + + const responseTool = { + type: "function" as const, + function: { + name: "print_extracted_data", + description: + "Prints the extracted data based on the provided schema.", + parameters: { + type: "object", + properties: schemaProperties, + required: schemaRequired, + }, + }, + }; + + tools = tools ? 
[...tools, responseTool] : [responseTool]; + } + + // Use OpenAI client with Groq API + const apiResponse = await this.client.chat.completions.create({ + model: this.modelName.split("groq-")[1], + messages: [ + ...formattedMessages, + // Add explicit instruction to return JSON if we have a response model + ...(options.response_model + ? [ + { + role: "system" as const, + content: `IMPORTANT: Your response must be valid JSON that matches this schema: ${JSON.stringify(options.response_model.schema)}`, + }, + ] + : []), + ], + temperature: options.temperature || 0.7, + max_tokens: options.maxTokens, + tools: tools, + tool_choice: options.tool_choice || "auto", + stream: true, + }); + + // TODO: transform response to required format + // TODO: Validate response model + // TODO: Enable caching + + return apiResponse as T; + } + + async streamText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { logger = () => {}, retries = 3, ...chatOptions } = options; + + // Create a unique request ID if not provided + const requestId = options.requestId || Date.now().toString(); + + logger({ + category: "groq", + message: "Initiating text streaming", + level: 2, + auxiliary: { + options: { + value: JSON.stringify({ + prompt, + requestId, + }), + type: "object", + }, + modelName: { + value: this.modelName, + type: "string", + }, + }, + }); + + try { + // Create a chat completion with the prompt as a user message + const response = (await this.createChatCompletionStream({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId, + }, + logger, + retries, + })) as StreamingChatResponse; + + // Restructure the response to return a stream of text + const textStream = new ReadableStream({ + async start(controller) { + try { + for await (const chunk of response) { + const content = chunk.choices[0]?.delta?.content; + if (content !== undefined) { + controller.enqueue(content); + } + } + controller.close(); + } catch (error) { + controller.error(error); + } + }, + }); + + logger({ + category: "groq", + message: "text streaming response", + level: 2, + auxiliary: { + response: { + value: JSON.stringify(textStream), + type: "object", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return { textStream: textStream } as T; + } catch (error) { + logger({ + category: "groq", + message: "Text streaming failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } + + async generateText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, + retries = 3, + requestId = Date.now().toString(), + ...chatOptions + } = options; + + // Log the generation attempt + logger({ + category: "groq", + message: "Initiating text generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + model: { + value: this.modelName, + type: "string", + }, + }, + }); + + try { + // Create chat completion with the provided prompt + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMResponse; + + // Validate response structure + if ( 
+ !response.choices || + response.choices.length === 0 || + response.choices[0].message.content == null || + response.choices[0].message.content === undefined + ) { + logger({ + category: "groq", + message: "Text generation failed", + level: 0, + auxiliary: { + error: { + value: "API response contains no valid choices", + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + throw new CreateChatCompletionResponseError( + "API response contains no valid choices", + ); + } + + // Extract and validate the generated text + const generatedContent = response.choices[0].message.content; + + // Construct the final response + const textResponse = { + text: generatedContent, + finishReason: response.choices[0].finish_reason, + usage: response.usage, + response: response, + // reasoning: response.reasoning, + // sources: response.sources + } as T; + + // Log successful generation + logger({ + category: "groq", + message: "Text generation successful", + level: 2, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + responseLength: { + value: generatedContent.length.toString(), + type: "string", + }, + usage: { + value: JSON.stringify(response.usage), + type: "object", + }, + }, + }); + + return textResponse; + } catch (error) { + // Log the error with detailed information + logger({ + category: "groq", + message: "Text generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + model: { + value: this.modelName, + type: "string", + }, + }, + }); + + // If it's a known error type, throw it directly + if (error instanceof CreateChatCompletionResponseError) { + throw error; + } + + // Otherwise, wrap it in our custom error type + throw new CreateChatCompletionResponseError( + `Failed to generate text: ${error.message}`, + ); + } + } + async generateObject({ + prompt, + schema, + options = {}, + }: GenerateObjectOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, + retries = 3, + requestId = Date.now().toString(), + ...chatOptions + } = options; + + // Log the generation attempt + logger({ + category: "groq", + message: "Initiating object generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + try { + // Create chat completion with the provided prompt + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + response_model: { + name: "object", + schema: schema, + }, + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMObjectResponse; + // Validate response structure + if ( + !response.data || + response.data.length === 0 || + response.data === undefined + ) { + logger({ + category: "groq", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: "API response contains no valid choices", + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + throw new CreateChatCompletionResponseError( + "API response contains no valid choices", + ); + } + + // Extract and validate the generated text + const generatedObject = response.data; + + // Construct the final response + const objResponse = { + object: generatedObject, + finishReason: response.response.choices[0].finish_reason, + usage: response.response.usage, + 
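+        // keep the remaining fields of the raw client response (data, response, etc.) reachable by the caller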
...response, + } as T; + + // Log successful generation + logger({ + category: "groq", + message: "Object generation successful", + level: 2, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return objResponse; + } catch (error) { + // Log the error + logger({ + category: "groq", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } } diff --git a/lib/llm/LLMClient.ts b/lib/llm/LLMClient.ts index 71690c387..31cb83fb1 100644 --- a/lib/llm/LLMClient.ts +++ b/lib/llm/LLMClient.ts @@ -51,33 +51,102 @@ export interface ChatCompletionOptions { requestId: string; } -export type LLMResponse = { +// Base response type for common fields +export interface BaseResponse { + id: string; + object: string; + created: number; + model: string; +} + +// Tool call type +export interface ToolCall { + id: string; + type: string; + function: { + name: string; + arguments: string; + }; +} + +// Message type +export interface LLMMessage { + role: string; + content: string | null; + tool_calls?: ToolCall[]; +} + +// Choice type +export interface LLMChoice { + index: number; + message: LLMMessage; + finish_reason: string; +} + +// Usage metrics +export interface UsageMetrics { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; +} + +// Main LLM Response type +export interface LLMResponse extends BaseResponse { + choices: LLMChoice[]; + usage: UsageMetrics; +} + +// Stream text response type +export interface StreamingTextResponse { + textStream: AsyncIterable; +} + +// Streaming chat chunk response type +export interface StreamingChatResponseChunk { id: string; object: string; created: number; model: string; choices: { index: number; - message: { - role: string; - content: string | null; - tool_calls: { - id: string; - type: string; - function: { - name: string; - arguments: string; - }; - }[]; + delta: { + content?: string; + role?: string; + function_call?: { + name?: string; + arguments?: string; + }; }; - finish_reason: string; + finish_reason: string | null; }[]; - usage: { - prompt_tokens: number; - completion_tokens: number; - total_tokens: number; - }; -}; +} + +// Streaming chat response type +export type StreamingChatResponse = AsyncIterable; + +// Main LLM Response type +export interface LLMObjectResponse extends BaseResponse { + data: Record; + usage: UsageMetrics; + response: LLMResponse; +} + +// Text Response type that can include LLM properties +export interface TextResponse extends BaseResponse { + text: string; + finishReason: string; + usage: UsageMetrics; + response: LLMResponse; + // reasoning: string; + // sources: string[]; +} + +// Object Response type that can include LLM properties +export interface ObjectResponse extends BaseResponse { + object: string; + choices?: LLMChoice[]; + usage?: UsageMetrics; +} export interface CreateChatCompletionOptions { options: ChatCompletionOptions; @@ -85,6 +154,23 @@ export interface CreateChatCompletionOptions { retries?: number; } +export interface GenerateTextOptions { + prompt: string; + options?: Partial> & { + logger?: (message: LogLine) => void; + retries?: number; + }; +} + +export interface GenerateObjectOptions { + prompt: string; + schema: ZodType; + options?: Partial> & { + logger?: (message: LogLine) => void; + 
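+    // retries controls how many times the client re-attempts the underlying chat completion; it is not forwarded to the provider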
retries?: number; + }; +} + export abstract class LLMClient { public type: "openai" | "anthropic" | "cerebras" | "groq" | (string & {}); public modelName: AvailableModel | (string & {}); @@ -102,4 +188,20 @@ export abstract class LLMClient { usage?: LLMResponse["usage"]; }, >(options: CreateChatCompletionOptions): Promise; + + abstract streamText( + input: GenerateTextOptions, + ): Promise; + + abstract generateText< + T = TextResponse & { + usage?: TextResponse["usage"]; + }, + >(input: GenerateTextOptions): Promise; + + abstract generateObject< + T = ObjectResponse & { + usage?: ObjectResponse["usage"]; + }, + >(input: GenerateObjectOptions): Promise; } diff --git a/lib/llm/OpenAIClient.ts b/lib/llm/OpenAIClient.ts index 5086fa41d..00a5d55b8 100644 --- a/lib/llm/OpenAIClient.ts +++ b/lib/llm/OpenAIClient.ts @@ -5,6 +5,7 @@ import { ChatCompletionContentPartImage, ChatCompletionContentPartText, ChatCompletionCreateParamsNonStreaming, + ChatCompletionCreateParamsStreaming, ChatCompletionMessageParam, ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam, @@ -18,8 +19,15 @@ import { ChatCompletionOptions, ChatMessage, CreateChatCompletionOptions, + GenerateObjectOptions, + GenerateTextOptions, LLMClient, + LLMObjectResponse, LLMResponse, + ObjectResponse, + StreamingChatResponse, + StreamingTextResponse, + TextResponse, } from "./LLMClient"; import { CreateChatCompletionResponseError, @@ -437,6 +445,7 @@ export class OpenAIClient extends LLMClient { return { data: parsedData, usage: response.usage, + response: response, } as T; } @@ -467,4 +476,640 @@ export class OpenAIClient extends LLMClient { // so we can safely cast here to T, which defaults to ChatCompletion return response as T; } + + async createChatCompletionStream({ + options: optionsInitial, + logger, + retries = 3, + }: CreateChatCompletionOptions): Promise { + let options: Partial = optionsInitial; + + // O1 models do not support most of the options. So we override them. + // For schema and tools, we add them as user messages. + // let isToolsOverridedForO1 = false; + if (this.modelName.startsWith("o1") || this.modelName.startsWith("o3")) { + /* eslint-disable */ + // Remove unsupported options + let { + tool_choice, + top_p, + frequency_penalty, + presence_penalty, + temperature, + } = options; + ({ + tool_choice, + top_p, + frequency_penalty, + presence_penalty, + temperature, + ...options + } = options); + /* eslint-enable */ + // Remove unsupported options + options.messages = options.messages.map((message) => ({ + ...message, + role: "user", + })); + if (options.tools && options.response_model) { + throw new StagehandError( + "Cannot use both tool and response_model for o1 models", + ); + } + + if (options.tools) { + // Remove unsupported options + let { tools } = options; + ({ tools, ...options } = options); + // isToolsOverridedForO1 = true; + options.messages.push({ + role: "user", + content: `You have the following tools available to you:\n${JSON.stringify( + tools, + )} + + Respond with the following zod schema format to use a method: { + "name": "", + "arguments": + } + + Do not include any other text or formattings like \`\`\` in your response. 
Just the JSON object.`, + }); + } + } + if ( + options.temperature && + (this.modelName.startsWith("o1") || this.modelName.startsWith("o3")) + ) { + throw new StagehandError("Temperature is not supported for o1 models"); + } + + const { image, requestId, ...optionsWithoutImageAndRequestId } = options; + + logger({ + category: "openai", + message: "creating chat completion stream", + level: 2, + auxiliary: { + options: { + value: JSON.stringify(optionsWithoutImageAndRequestId), + type: "object", + }, + }, + }); + + const cacheOptions = { + model: this.modelName, + messages: options.messages, + temperature: options.temperature, + top_p: options.top_p, + frequency_penalty: options.frequency_penalty, + presence_penalty: options.presence_penalty, + image: image, + response_model: options.response_model, + }; + + if (this.enableCaching) { + const cachedResponse = await this.cache.get(cacheOptions, requestId); + if (cachedResponse) { + logger({ + category: "llm_cache", + message: "LLM cache hit - returning cached response", + level: 1, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + cachedResponse: { + value: JSON.stringify(cachedResponse), + type: "object", + }, + }, + }); + return cachedResponse; + } else { + logger({ + category: "llm_cache", + message: "LLM cache miss - no cached response found", + level: 1, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + }, + }); + } + } + + if (options.image) { + const screenshotMessage: ChatMessage = { + role: "user", + content: [ + { + type: "image_url", + image_url: { + url: `data:image/jpeg;base64,${options.image.buffer.toString("base64")}`, + }, + }, + ...(options.image.description + ? [{ type: "text", text: options.image.description }] + : []), + ], + }; + + options.messages.push(screenshotMessage); + } + + let responseFormat = undefined; + if (options.response_model) { + // For O1 models, we need to add the schema as a user message. + if (this.modelName.startsWith("o1") || this.modelName.startsWith("o3")) { + try { + const parsedSchema = JSON.stringify( + zodToJsonSchema(options.response_model.schema), + ); + options.messages.push({ + role: "user", + content: `Respond in this zod schema format:\n${parsedSchema}\n + + Do not include any other text, formatting or markdown in your output. Do not include \`\`\` or \`\`\`json in your response. 
Only the JSON object itself.`, + }); + } catch (error) { + logger({ + category: "openai", + message: "Failed to parse response model schema", + level: 0, + }); + + if (retries > 0) { + // as-casting to account for o1 models not supporting all options + return this.createChatCompletion({ + options: options as ChatCompletionOptions, + logger, + retries: retries - 1, + }); + } + + throw error; + } + } else { + responseFormat = zodResponseFormat( + options.response_model.schema, + options.response_model.name, + ); + } + } + + /* eslint-disable */ + // Remove unsupported options + const { response_model, ...openAiOptions } = { + ...optionsWithoutImageAndRequestId, + model: this.modelName, + }; + /* eslint-enable */ + + const formattedMessages: ChatCompletionMessageParam[] = + options.messages.map((message) => { + if (Array.isArray(message.content)) { + const contentParts = message.content.map((content) => { + if ("image_url" in content) { + const imageContent: ChatCompletionContentPartImage = { + image_url: { + url: content.image_url.url, + }, + type: "image_url", + }; + return imageContent; + } else { + const textContent: ChatCompletionContentPartText = { + text: content.text, + type: "text", + }; + return textContent; + } + }); + + if (message.role === "system") { + const formattedMessage: ChatCompletionSystemMessageParam = { + ...message, + role: "system", + content: contentParts.filter( + (content): content is ChatCompletionContentPartText => + content.type === "text", + ), + }; + return formattedMessage; + } else if (message.role === "user") { + const formattedMessage: ChatCompletionUserMessageParam = { + ...message, + role: "user", + content: contentParts, + }; + return formattedMessage; + } else { + const formattedMessage: ChatCompletionAssistantMessageParam = { + ...message, + role: "assistant", + content: contentParts.filter( + (content): content is ChatCompletionContentPartText => + content.type === "text", + ), + }; + return formattedMessage; + } + } + + const formattedMessage: ChatCompletionUserMessageParam = { + role: "user", + content: message.content, + }; + + return formattedMessage; + }); + + const body: ChatCompletionCreateParamsStreaming = { + ...openAiOptions, + model: this.modelName, + messages: formattedMessages, + response_format: responseFormat, + stream: true, + tools: options.tools?.map((tool) => ({ + function: { + name: tool.name, + description: tool.description, + parameters: tool.parameters, + }, + type: "function", + })), + }; + const response = await this.client.chat.completions.create(body); + + // TODO: O1 models - parse the tool call response manually and add it to the response + // TODO: Response model validation + // TODO: Caching + + return response as T; + } + + async streamText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { logger = () => {}, retries = 3, ...chatOptions } = options; + + // Create a unique request ID if not provided + const requestId = options.requestId || Date.now().toString(); + + logger({ + category: "openai", + message: "Initiating text streaming", + level: 2, + auxiliary: { + options: { + value: JSON.stringify({ + prompt, + requestId, + }), + type: "object", + }, + modelName: { + value: this.modelName, + type: "string", + }, + }, + }); + + try { + // Create a chat completion with the prompt as a user message + const response = (await this.createChatCompletionStream({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + 
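+          // requestId sits after the spread so the id logged above is the one actually sent downstream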
requestId, + }, + logger, + retries, + })) as StreamingChatResponse; + + // Restructure the response to return a stream of text + const textStream = new ReadableStream({ + async start(controller) { + try { + for await (const chunk of response) { + const content = chunk.choices[0]?.delta?.content; + if (content !== undefined) { + controller.enqueue(content); + } + } + controller.close(); + } catch (error) { + controller.error(error); + } + }, + }); + + logger({ + category: "openai", + message: "text streaming response", + level: 2, + auxiliary: { + response: { + value: JSON.stringify(textStream), + type: "object", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return { + textStream: textStream, + } as T; + } catch (error) { + logger({ + category: "openai", + message: "Text streaming failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } + + async generateText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { logger = () => {}, retries = 3, ...chatOptions } = options; + + // Generate a unique request ID if not provided + const requestId = options.requestId || Date.now().toString(); + + // Log the generation attempt + logger({ + category: "openai", + message: "Initiating text generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + try { + // Create a chat completion with the prompt as a user message + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMResponse; + + // Validate and extract the generated text from the response + if (response.choices && response.choices.length > 0) { + // Log successful generation + logger({ + category: "openai", + message: "Text generation successful", + level: 2, + auxiliary: { + response: { + value: JSON.stringify(response.choices[0].message.content), + type: "object", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return { + text: response.choices[0].message.content, + finishReason: response.choices[0].finish_reason, + usage: response.usage, + response: response, + // reasoning: response.reasoning, + // sources: response.sources + } as T; + } + + // Log the error if a logger is provided + logger({ + category: "openai", + message: "Text generation failed", + level: 0, + auxiliary: { + error: { + value: "No valid choices found in API response", + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + // Throw error if no valid response was generated + throw new CreateChatCompletionResponseError( + "No valid choices found in API response", + ); + } catch (error) { + // Log the error if a logger is provided + logger({ + category: "openai", + message: "Text generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } + + async generateObject({ + prompt, + schema, + options = {}, + }: GenerateObjectOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, + retries = 3, 
+ requestId = Date.now().toString(), + ...chatOptions + } = options; + + // Log the generation attempt + logger({ + category: "openai", + message: "Initiating object generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + try { + // Create chat completion with the provided prompt + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + response_model: { + name: "object", + schema: schema, + }, + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMObjectResponse; + // Validate response structure + if ( + !response.data || + response.data.length === 0 || + response.data == undefined + ) { + logger({ + category: "openai", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: "API response contains no valid choices", + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + throw new CreateChatCompletionResponseError( + "API response contains no valid choices", + ); + } + + // Extract and validate the generated text + const generatedObject = response.data; + + // Construct the final response + const objResponse = { + object: generatedObject, + finishReason: response.response.choices[0].finish_reason, + usage: response.usage, + response: response, + // reasoning: response.reasoning, + // sources: response.sources + } as T; + + // Log successful generation + logger({ + category: "openai", + message: "Object generation successful", + level: 2, + auxiliary: { + response: { + value: JSON.stringify(objResponse), + type: "object", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return objResponse; + } catch (error) { + // Log the error + logger({ + category: "openai", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } } diff --git a/package-lock.json b/package-lock.json index 53cd79748..7f95e2d76 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@browserbasehq/stagehand", - "version": "2.0.0", + "version": "2.1.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@browserbasehq/stagehand", - "version": "2.0.0", + "version": "2.1.0", "license": "MIT", "dependencies": { "@anthropic-ai/sdk": "0.39.0", @@ -16,7 +16,7 @@ "pino": "^9.6.0", "pino-pretty": "^13.0.0", "ws": "^8.18.0", - "zod-to-json-schema": "^3.23.5" + "zod-to-json-schema": "^3.24.5" }, "devDependencies": { "@ai-sdk/anthropic": "^1.2.6", @@ -10441,9 +10441,9 @@ } }, "node_modules/zod-to-json-schema": { - "version": "3.24.3", - "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.24.3.tgz", - "integrity": "sha512-HIAfWdYIt1sssHfYZFCXp4rU1w2r8hVVXYIlmoa0r0gABLs5di3RCqPU5DDROogVz1pAdYBaz7HK5n9pSUNs3A==", + "version": "3.24.5", + "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.24.5.tgz", + "integrity": "sha512-/AuWwMP+YqiPbsJx5D6TfgRTc4kTLjsh5SOcd4bLsfUg2RcEXrFMJl1DGgdHy2aCfsIA/cr/1JM0xcB2GZji8g==", "license": "ISC", "peerDependencies": { "zod": "^3.24.1" diff --git a/package.json b/package.json index 
fba455fe2..007920e4e 100644 --- a/package.json +++ b/package.json @@ -21,6 +21,7 @@ "form-filling-sensible-openai": "npm run build && tsx examples/form_filling_sensible_openai.ts", "google-enter": "npm run build && tsx examples/google_enter.ts", "try-wordle": "npm run build && tsx examples/try_wordle.ts", + "llm-usage-wordle": "npm run build && tsx examples/llm_usage_wordle.ts", "format": "prettier --write .", "prettier": "prettier --check .", "prettier:fix": "prettier --write .", @@ -95,7 +96,7 @@ "pino": "^9.6.0", "pino-pretty": "^13.0.0", "ws": "^8.18.0", - "zod-to-json-schema": "^3.23.5" + "zod-to-json-schema": "^3.24.5" }, "directories": { "doc": "docs",
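
The diff above only adds the new client methods and wiring; nothing in it shows them being called. As a rough, illustrative sketch (not part of this change), the new generateText, streamText and generateObject methods on LLMClient might be exercised along the lines of the llm-usage-wordle example script registered above. The import paths, the Wordle-flavoured prompts, and the assumption that the returned textStream can be consumed with for await (Node 18+ async iteration over a web ReadableStream) are assumptions of this sketch, not something the diff itself guarantees.

// Illustrative usage sketch — assumes a concrete LLMClient (OpenAIClient, the Groq client,
// or the AI SDK wrapper from the examples) has already been constructed elsewhere.
import { z } from "zod";
import type { LLMClient } from "@/dist";

export async function demoLlmUsage(llmClient: LLMClient) {
  // generateText: one-shot completion; the default return type exposes text, finishReason and usage.
  const textResult = await llmClient.generateText({
    prompt: "Suggest a strong five-letter opening word for Wordle.",
  });
  console.log(textResult.text, textResult.usage);

  // streamText: textStream yields content deltas as they arrive from the provider.
  const { textStream } = await llmClient.streamText({
    prompt: "Explain the rules of Wordle in two sentences.",
  });
  for await (const chunk of textStream) {
    process.stdout.write(chunk);
  }

  // generateObject: the zod schema is forwarded as the response model and the parsed
  // result comes back on the `object` field.
  const guessSchema = z.object({
    guess: z.string().length(5),
    reasoning: z.string(),
  });
  const objectResult = await llmClient.generateObject({
    prompt: "Return your next Wordle guess and a one-sentence justification as JSON.",
    schema: guessSchema,
  });
  console.log(objectResult.object);
}

Whether textStream is a web ReadableStream (as in the OpenAI and Groq implementations in this diff) or another stream type, the for await loop works as long as the runtime exposes async iteration on it; otherwise a reader-based loop would be needed.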