From 14440b0787761fa6695833035b293d49e3cf1b40 Mon Sep 17 00:00:00 2001 From: Hitesh Agarwal Date: Mon, 14 Apr 2025 22:32:07 +0800 Subject: [PATCH 1/5] feat: added generateText for llmClient --- examples/external_clients/aisdk.ts | 29 ++++ examples/external_clients/customOpenAI.ts | 37 +++++- examples/external_clients/langchain.ts | 41 ++++++ examples/llm_usage_wordle.ts | 21 +++ lib/llm/AnthropicClient.ts | 114 ++++++++++++++++ lib/llm/CerebrasClient.ts | 153 ++++++++++++++++++++++ lib/llm/GoogleClient.ts | 42 ++++++ lib/llm/GroqClient.ts | 135 +++++++++++++++++++ lib/llm/LLMClient.ts | 82 ++++++++---- lib/llm/OpenAIClient.ts | 81 ++++++++++++ package.json | 1 + 11 files changed, 713 insertions(+), 23 deletions(-) create mode 100644 examples/llm_usage_wordle.ts diff --git a/examples/external_clients/aisdk.ts b/examples/external_clients/aisdk.ts index 1d72d984f..7e2b99966 100644 --- a/examples/external_clients/aisdk.ts +++ b/examples/external_clients/aisdk.ts @@ -12,6 +12,7 @@ import { } from "ai"; import { CreateChatCompletionOptions, LLMClient, AvailableModel } from "@/dist"; import { ChatCompletion } from "openai/resources"; +import { GenerateTextOptions, TextResponse } from "@/lib"; export class AISdkClient extends LLMClient { public type = "aisdk" as const; @@ -119,4 +120,32 @@ export class AISdkClient extends LLMClient { }, } as T; } + async generateText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + const tools: Record = {}; + if (options.tools) { + for (const rawTool of options.tools) { + tools[rawTool.name] = { + description: rawTool.description, + parameters: rawTool.parameters, + }; + } + } + + const response = await generateText({ + model: this.model, + prompt: prompt, + tools, + }); + return { + text: response.text, + usage: { + prompt_tokens: response.usage.promptTokens ?? 0, + completion_tokens: response.usage.completionTokens ?? 0, + total_tokens: response.usage.totalTokens ?? 0, + }, + } as T; + } } diff --git a/examples/external_clients/customOpenAI.ts b/examples/external_clients/customOpenAI.ts index 6a6d70b3f..d6e4b0954 100644 --- a/examples/external_clients/customOpenAI.ts +++ b/examples/external_clients/customOpenAI.ts @@ -20,6 +20,7 @@ import type { } from "openai/resources/chat/completions"; import { z } from "zod"; import { CreateChatCompletionResponseError } from "@/types/stagehandErrors"; +import { GenerateTextOptions, LLMResponse, TextResponse } from "@/lib"; function validateZodSchema(schema: z.ZodTypeAny, data: unknown) { try { @@ -229,7 +230,7 @@ export class CustomOpenAIClient extends LLMClient { } return { - data: response.choices[0].message.content, + choices: response.choices, usage: { prompt_tokens: response.usage?.prompt_tokens ?? 0, completion_tokens: response.usage?.completion_tokens ?? 
0, @@ -237,4 +238,38 @@ export class CustomOpenAIClient extends LLMClient { }, } as T; } + + async generateText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { logger = () => {}, retries = 3, ...chatOptions } = options; + + // Create chat completion with single user message + const res = await (this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + // Generate unique request ID if not provided + requestId: options.requestId || Date.now().toString(), + }, + logger, + retries, + }) as Promise); + // Validate response and extract generated text + if (res.choices && res.choices.length > 0) { + return { + ...res, + text: res.choices[0].message.content, + } as T; + } else { + throw new CreateChatCompletionResponseError("No choices in response"); + } + } } diff --git a/examples/external_clients/langchain.ts b/examples/external_clients/langchain.ts index 1d071a63b..63525699e 100644 --- a/examples/external_clients/langchain.ts +++ b/examples/external_clients/langchain.ts @@ -8,6 +8,12 @@ import { SystemMessage, } from "@langchain/core/messages"; import { ChatCompletion } from "openai/resources"; +import { + CreateChatCompletionResponseError, + GenerateTextOptions, + LLMResponse, + TextResponse, +} from "@/lib"; export class LangchainClient extends LLMClient { public type = "langchainClient" as const; @@ -84,4 +90,39 @@ export class LangchainClient extends LLMClient { }, } as T; } + + async generateText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { logger = () => {}, retries = 3, ...chatOptions } = options; + + // Create chat completion with single user message + const res = await (this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + // Generate unique request ID if not provided + requestId: options.requestId || Date.now().toString(), + }, + logger, + retries, + }) as Promise); + + // Validate and extract response + if (res.choices && res.choices.length > 0) { + return { + ...res, + text: res.choices[0].message.content, + } as T; + } else { + throw new CreateChatCompletionResponseError("No choices in response"); + } + } } diff --git a/examples/llm_usage_wordle.ts b/examples/llm_usage_wordle.ts new file mode 100644 index 000000000..d686edca0 --- /dev/null +++ b/examples/llm_usage_wordle.ts @@ -0,0 +1,21 @@ +import { Stagehand } from "@/dist"; +import StagehandConfig from "@/stagehand.config"; + +async function example() { + const stagehand = new Stagehand({ + ...StagehandConfig, + }); + + await stagehand.init(); + + const { text } = await stagehand.llmClient.generateText({ + prompt: + "you are playing wordle. 
Return the 5-letter word that would be the best guess", + }); + console.log(text); + await stagehand.close(); +} + +(async () => { + await example(); +})(); diff --git a/lib/llm/AnthropicClient.ts b/lib/llm/AnthropicClient.ts index 27d8f4c7c..7db3a3b12 100644 --- a/lib/llm/AnthropicClient.ts +++ b/lib/llm/AnthropicClient.ts @@ -11,8 +11,10 @@ import { AnthropicJsonSchemaObject, AvailableModel } from "../../types/model"; import { LLMCache } from "../cache/LLMCache"; import { CreateChatCompletionOptions, + GenerateTextOptions, LLMClient, LLMResponse, + TextResponse, } from "./LLMClient"; import { CreateChatCompletionResponseError } from "@/types/stagehandErrors"; @@ -373,6 +375,118 @@ export class AnthropicClient extends LLMClient { // so we can safely cast here to T, which defaults to AnthropicTransformedResponse return transformedResponse as T; } + + async generateText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, + retries = 3, + requestId = Date.now().toString(), + ...chatOptions + } = options; + + try { + // Log the generation attempt + logger({ + category: "anthropic", + message: "Initiating text generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Create chat completion with the provided prompt + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMResponse; + + // Validate response structure + if (!response.choices || response.choices.length === 0) { + throw new CreateChatCompletionResponseError( + "API response contains no valid choices", + ); + } + + // Extract and validate the generated text + const generatedText = response.choices[0].message.content; + if (generatedText === null || generatedText === undefined) { + throw new CreateChatCompletionResponseError( + "Generated text content is empty", + ); + } + + // Construct the final response + const textResponse = { + ...response, + text: generatedText, + } as T; + + // Log successful generation + logger({ + category: "anthropic", + message: "Text generation successful", + level: 2, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + responseLength: { + value: generatedText.length.toString(), + type: "string", + }, + }, + }); + + return textResponse; + } catch (error) { + // Log the error + logger({ + category: "anthropic", + message: "Text generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } } const extractSchemaProperties = (jsonSchema: AnthropicJsonSchemaObject) => { diff --git a/lib/llm/CerebrasClient.ts b/lib/llm/CerebrasClient.ts index 4d5a0daca..0c336bd87 100644 --- a/lib/llm/CerebrasClient.ts +++ b/lib/llm/CerebrasClient.ts @@ -7,8 +7,10 @@ import { LLMCache } from "../cache/LLMCache"; import { ChatMessage, CreateChatCompletionOptions, + GenerateTextOptions, LLMClient, LLMResponse, + TextResponse, } from "./LLMClient"; import { CreateChatCompletionResponseError } from "@/types/stagehandErrors"; @@ -324,4 +326,155 @@ export class CerebrasClient extends LLMClient { throw error; } } + + async generateText({ + prompt, + 
options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, + retries = 3, + requestId = Date.now().toString(), + ...chatOptions + } = options; + + try { + // Log the generation attempt + logger({ + category: "cerebras", + message: "Initiating text generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + model: { + value: this.modelName, + type: "string", + }, + options: { + value: JSON.stringify(chatOptions), + type: "object", + }, + }, + }); + + // Create chat completion with the provided prompt + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMResponse; + + // Validate response structure + if (!response.choices || response.choices.length === 0) { + throw new CreateChatCompletionResponseError( + "API response contains no valid choices", + ); + } + + // Extract and validate the generated text + const generatedContent = response.choices[0].message.content; + if (generatedContent === null || generatedContent === undefined) { + throw new CreateChatCompletionResponseError( + "Generated text content is empty", + ); + } + + // Construct the final response with additional metadata + const textResponse = { + ...response, + text: generatedContent, + modelName: this.modelName.split("cerebras-")[1], // Clean model name + timestamp: Date.now(), + metadata: { + provider: "cerebras", + originalPrompt: prompt, + requestId, + temperature: chatOptions.temperature || 0.7, + }, + } as T; + + // Log successful generation with detailed metrics + logger({ + category: "cerebras", + message: "Text generation successful", + level: 2, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + responseLength: { + value: generatedContent.length.toString(), + type: "string", + }, + usage: { + value: JSON.stringify(response.usage), + type: "object", + }, + finishReason: { + value: response.choices[0].finish_reason || "unknown", + type: "string", + }, + }, + }); + + return textResponse; + } catch (error) { + // Log the error with detailed information + logger({ + category: "cerebras", + message: "Text generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + errorType: { + value: error.constructor.name, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + model: { + value: this.modelName, + type: "string", + }, + }, + }); + + // If it's a known error type, throw it directly + if (error instanceof CreateChatCompletionResponseError) { + throw error; + } + + // Otherwise, wrap it in our custom error type with context + throw new CreateChatCompletionResponseError( + `Cerebras text generation failed: ${error.message}`, + ); + } + } } diff --git a/lib/llm/GoogleClient.ts b/lib/llm/GoogleClient.ts index 968ce9e22..aa8e38675 100644 --- a/lib/llm/GoogleClient.ts +++ b/lib/llm/GoogleClient.ts @@ -22,6 +22,8 @@ import { LLMClient, LLMResponse, AnnotatedScreenshotText, + TextResponse, + GenerateTextOptions, } from "./LLMClient"; import { CreateChatCompletionResponseError, @@ -530,4 +532,44 @@ export class GoogleClient extends LLMClient { ); } } + + async generateText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { + logger 
= () => {}, // Default no-op logger + retries = 3, // Default retry attempts + ...chatOptions // All other chat-specific options + } = options; + + // Create a chat completion with a single user message + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId: options.requestId || Date.now().toString(), // Ensure unique request ID + }, + logger, + retries, + })) as LLMResponse; + + // Validate and extract the generated text from the response + if (response.choices && response.choices.length > 0) { + return { + ...response, + text: response.choices[0].message.content, + } as T; + } else { + throw new CreateChatCompletionResponseError( + "No choices available in the response", + ); + } + } } diff --git a/lib/llm/GroqClient.ts b/lib/llm/GroqClient.ts index fe91d06ba..deb9773b6 100644 --- a/lib/llm/GroqClient.ts +++ b/lib/llm/GroqClient.ts @@ -7,8 +7,10 @@ import { LLMCache } from "../cache/LLMCache"; import { ChatMessage, CreateChatCompletionOptions, + GenerateTextOptions, LLMClient, LLMResponse, + TextResponse, } from "./LLMClient"; import { CreateChatCompletionResponseError } from "@/types/stagehandErrors"; @@ -324,4 +326,137 @@ export class GroqClient extends LLMClient { throw error; } } + + async generateText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, + retries = 3, + requestId = Date.now().toString(), + ...chatOptions + } = options; + + try { + // Log the generation attempt + logger({ + category: "groq", + message: "Initiating text generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + model: { + value: this.modelName, + type: "string", + }, + }, + }); + + // Create chat completion with the provided prompt + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMResponse; + + // Validate response structure + if (!response.choices || response.choices.length === 0) { + throw new CreateChatCompletionResponseError( + "API response contains no valid choices", + ); + } + + // Extract and validate the generated text + const generatedContent = response.choices[0].message.content; + if (generatedContent === null || generatedContent === undefined) { + throw new CreateChatCompletionResponseError( + "Generated text content is empty", + ); + } + + // Construct the final response + const textResponse = { + ...response, + text: generatedContent, + modelName: this.modelName, + timestamp: Date.now(), + } as T; + + // Log successful generation + logger({ + category: "groq", + message: "Text generation successful", + level: 2, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + responseLength: { + value: generatedContent.length.toString(), + type: "string", + }, + usage: { + value: JSON.stringify(response.usage), + type: "object", + }, + }, + }); + + return textResponse; + } catch (error) { + // Log the error with detailed information + logger({ + category: "groq", + message: "Text generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + model: { + value: this.modelName, + type: "string", + }, + }, + 
}); + + // If it's a known error type, throw it directly + if (error instanceof CreateChatCompletionResponseError) { + throw error; + } + + // Otherwise, wrap it in our custom error type + throw new CreateChatCompletionResponseError( + `Failed to generate text: ${error.message}`, + ); + } + } } diff --git a/lib/llm/LLMClient.ts b/lib/llm/LLMClient.ts index 71690c387..3256ac8d8 100644 --- a/lib/llm/LLMClient.ts +++ b/lib/llm/LLMClient.ts @@ -51,33 +51,57 @@ export interface ChatCompletionOptions { requestId: string; } -export type LLMResponse = { +// Base response type for common fields +export interface BaseResponse { id: string; object: string; created: number; model: string; - choices: { - index: number; - message: { - role: string; - content: string | null; - tool_calls: { - id: string; - type: string; - function: { - name: string; - arguments: string; - }; - }[]; - }; - finish_reason: string; - }[]; - usage: { - prompt_tokens: number; - completion_tokens: number; - total_tokens: number; +} + +// Tool call type +export interface ToolCall { + id: string; + type: string; + function: { + name: string; + arguments: string; }; -}; +} + +// Message type +export interface LLMMessage { + role: string; + content: string | null; + tool_calls?: ToolCall[]; +} + +// Choice type +export interface LLMChoice { + index: number; + message: LLMMessage; + finish_reason: string; +} + +// Usage metrics +export interface UsageMetrics { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; +} + +// Main LLM Response type +export interface LLMResponse extends BaseResponse { + choices: LLMChoice[]; + usage: UsageMetrics; +} + +// Text Response type that can include LLM properties +export interface TextResponse extends BaseResponse { + text: string; + choices?: LLMChoice[]; + usage?: UsageMetrics; +} export interface CreateChatCompletionOptions { options: ChatCompletionOptions; @@ -85,6 +109,14 @@ export interface CreateChatCompletionOptions { retries?: number; } +export interface GenerateTextOptions { + prompt: string; + options?: Partial> & { + logger?: (message: LogLine) => void; + retries?: number; + }; +} + export abstract class LLMClient { public type: "openai" | "anthropic" | "cerebras" | "groq" | (string & {}); public modelName: AvailableModel | (string & {}); @@ -102,4 +134,10 @@ export abstract class LLMClient { usage?: LLMResponse["usage"]; }, >(options: CreateChatCompletionOptions): Promise; + + abstract generateText< + T = TextResponse & { + usage?: TextResponse["usage"]; + }, + >(input: GenerateTextOptions): Promise; } diff --git a/lib/llm/OpenAIClient.ts b/lib/llm/OpenAIClient.ts index 5086fa41d..492a20d0e 100644 --- a/lib/llm/OpenAIClient.ts +++ b/lib/llm/OpenAIClient.ts @@ -18,8 +18,10 @@ import { ChatCompletionOptions, ChatMessage, CreateChatCompletionOptions, + GenerateTextOptions, LLMClient, LLMResponse, + TextResponse, } from "./LLMClient"; import { CreateChatCompletionResponseError, @@ -467,4 +469,83 @@ export class OpenAIClient extends LLMClient { // so we can safely cast here to T, which defaults to ChatCompletion return response as T; } + + async generateText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { logger = () => {}, retries = 3, ...chatOptions } = options; + + // Generate a unique request ID if not provided + const requestId = options.requestId || Date.now().toString(); + + try { + // Log the generation attempt + logger({ + category: "openai", + message: "Initiating text generation", + 
level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Create a chat completion with the prompt as a user message + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMResponse; + + // Validate and extract the generated text from the response + if (response.choices && response.choices.length > 0) { + return { + ...response, + text: response.choices[0].message.content, + } as T; + } + + // Throw error if no valid response was generated + throw new CreateChatCompletionResponseError( + "No valid choices found in API response", + ); + } catch (error) { + // Log the error if a logger is provided + logger({ + category: "openai", + message: "Text generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } } diff --git a/package.json b/package.json index fba455fe2..af2c20c94 100644 --- a/package.json +++ b/package.json @@ -21,6 +21,7 @@ "form-filling-sensible-openai": "npm run build && tsx examples/form_filling_sensible_openai.ts", "google-enter": "npm run build && tsx examples/google_enter.ts", "try-wordle": "npm run build && tsx examples/try_wordle.ts", + "llm-usage-wordle": "npm run build && tsx examples/llm_usage_wordle.ts", "format": "prettier --write .", "prettier": "prettier --check .", "prettier:fix": "prettier --write .", From 32006720c2efaebe37e21f406a53f445fe8b2218 Mon Sep 17 00:00:00 2001 From: Hitesh Agarwal Date: Wed, 16 Apr 2025 04:06:33 +0800 Subject: [PATCH 2/5] feat: add generateObject --- examples/external_clients/aisdk.ts | 28 ++++- examples/external_clients/customOpenAI.ts | 121 +++++++++++++++++++++- examples/external_clients/langchain.ts | 115 ++++++++++++++++++++ examples/llm_usage_wordle.ts | 12 +++ lib/llm/AnthropicClient.ts | 116 +++++++++++++++++++++ lib/llm/CerebrasClient.ts | 114 ++++++++++++++++++++ lib/llm/GoogleClient.ts | 114 ++++++++++++++++++++ lib/llm/GroqClient.ts | 118 ++++++++++++++++++++- lib/llm/LLMClient.ts | 28 +++++ lib/llm/OpenAIClient.ts | 115 ++++++++++++++++++++ package-lock.json | 12 +-- package.json | 2 +- 12 files changed, 884 insertions(+), 11 deletions(-) diff --git a/examples/external_clients/aisdk.ts b/examples/external_clients/aisdk.ts index 7e2b99966..9a73e63c3 100644 --- a/examples/external_clients/aisdk.ts +++ b/examples/external_clients/aisdk.ts @@ -12,7 +12,12 @@ import { } from "ai"; import { CreateChatCompletionOptions, LLMClient, AvailableModel } from "@/dist"; import { ChatCompletion } from "openai/resources"; -import { GenerateTextOptions, TextResponse } from "@/lib"; +import { + GenerateObjectOptions, + GenerateTextOptions, + ObjectResponse, + TextResponse, +} from "@/lib"; export class AISdkClient extends LLMClient { public type = "aisdk" as const; @@ -148,4 +153,25 @@ export class AISdkClient extends LLMClient { }, } as T; } + async generateObject({ + prompt, + schema, + options = {}, + }: GenerateObjectOptions): Promise { + console.log(options); + const response = await generateObject({ + model: this.model, + prompt: prompt, + schema: schema, + ...options, + }); + return { + object: response.object, + usage: { + prompt_tokens: response.usage.promptTokens ?? 
0, + completion_tokens: response.usage.completionTokens ?? 0, + total_tokens: response.usage.totalTokens ?? 0, + }, + } as T; + } } diff --git a/examples/external_clients/customOpenAI.ts b/examples/external_clients/customOpenAI.ts index d6e4b0954..ee017ba46 100644 --- a/examples/external_clients/customOpenAI.ts +++ b/examples/external_clients/customOpenAI.ts @@ -20,7 +20,14 @@ import type { } from "openai/resources/chat/completions"; import { z } from "zod"; import { CreateChatCompletionResponseError } from "@/types/stagehandErrors"; -import { GenerateTextOptions, LLMResponse, TextResponse } from "@/lib"; +import { + GenerateObjectOptions, + GenerateTextOptions, + LLMObjectResponse, + LLMResponse, + ObjectResponse, + TextResponse, +} from "@/lib"; function validateZodSchema(schema: z.ZodTypeAny, data: unknown) { try { @@ -272,4 +279,116 @@ export class CustomOpenAIClient extends LLMClient { throw new CreateChatCompletionResponseError("No choices in response"); } } + async generateObject({ + prompt, + schema, + options = {}, + }: GenerateObjectOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, + retries = 3, + requestId = Date.now().toString(), + ...chatOptions + } = options; + + try { + // Log the generation attempt + logger({ + category: "anthropic", + message: "Initiating object generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Create chat completion with the provided prompt + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + response_model: { + name: "object", + schema: schema, + }, + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMObjectResponse; + + // Validate response structure + if (!response.data || response.data.length === 0) { + throw new CreateChatCompletionResponseError( + "API response contains no valid choices", + ); + } + + // Extract and validate the generated text + const generatedObject = response.data; + if (generatedObject === null || generatedObject === undefined) { + throw new CreateChatCompletionResponseError( + "Generated text content is empty", + ); + } + + // Construct the final response + const objResponse = { + ...response, + object: generatedObject, + } as T; + + // Log successful generation + logger({ + category: "anthropic", + message: "Text generation successful", + level: 2, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return objResponse; + } catch (error) { + // Log the error + logger({ + category: "anthropic", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } } diff --git a/examples/external_clients/langchain.ts b/examples/external_clients/langchain.ts index 63525699e..227d99fae 100644 --- a/examples/external_clients/langchain.ts +++ b/examples/external_clients/langchain.ts @@ -10,8 +10,11 @@ import { import { ChatCompletion } from "openai/resources"; import { CreateChatCompletionResponseError, + GenerateObjectOptions, GenerateTextOptions, + LLMObjectResponse, LLMResponse, + ObjectResponse, TextResponse, } from "@/lib"; @@ -125,4 +128,116 @@ export class LangchainClient extends LLMClient { 
throw new CreateChatCompletionResponseError("No choices in response"); } } + async generateObject({ + prompt, + schema, + options = {}, + }: GenerateObjectOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, + retries = 3, + requestId = Date.now().toString(), + ...chatOptions + } = options; + + try { + // Log the generation attempt + logger({ + category: "anthropic", + message: "Initiating object generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Create chat completion with the provided prompt + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + response_model: { + name: "object", + schema: schema, + }, + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMObjectResponse; + + // Validate response structure + if (!response.data || response.data.length === 0) { + throw new CreateChatCompletionResponseError( + "API response contains no valid choices", + ); + } + + // Extract and validate the generated text + const generatedObject = response.data; + if (generatedObject === null || generatedObject === undefined) { + throw new CreateChatCompletionResponseError( + "Generated text content is empty", + ); + } + + // Construct the final response + const objResponse = { + ...response, + object: generatedObject, + } as T; + + // Log successful generation + logger({ + category: "anthropic", + message: "Text generation successful", + level: 2, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return objResponse; + } catch (error) { + // Log the error + logger({ + category: "anthropic", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } } diff --git a/examples/llm_usage_wordle.ts b/examples/llm_usage_wordle.ts index d686edca0..caee50544 100644 --- a/examples/llm_usage_wordle.ts +++ b/examples/llm_usage_wordle.ts @@ -1,5 +1,6 @@ import { Stagehand } from "@/dist"; import StagehandConfig from "@/stagehand.config"; +import { z } from "zod"; async function example() { const stagehand = new Stagehand({ @@ -13,6 +14,17 @@ async function example() { "you are playing wordle. Return the 5-letter word that would be the best guess", }); console.log(text); + const { object } = await stagehand.llmClient.generateObject({ + prompt: + "you are playing wordle. 
Return the 5-letter word that would be the best guess", + schema: z.object({ + guess: z + .string() + .length(5) + .describe("The 5-letter word that would be the best guess"), + }), + }); + console.log(object); await stagehand.close(); } diff --git a/lib/llm/AnthropicClient.ts b/lib/llm/AnthropicClient.ts index 7db3a3b12..9fa034089 100644 --- a/lib/llm/AnthropicClient.ts +++ b/lib/llm/AnthropicClient.ts @@ -11,9 +11,12 @@ import { AnthropicJsonSchemaObject, AvailableModel } from "../../types/model"; import { LLMCache } from "../cache/LLMCache"; import { CreateChatCompletionOptions, + GenerateObjectOptions, GenerateTextOptions, LLMClient, + LLMObjectResponse, LLMResponse, + ObjectResponse, TextResponse, } from "./LLMClient"; import { CreateChatCompletionResponseError } from "@/types/stagehandErrors"; @@ -487,6 +490,119 @@ export class AnthropicClient extends LLMClient { throw error; } } + + async generateObject({ + prompt, + schema, + options = {}, + }: GenerateObjectOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, + retries = 3, + requestId = Date.now().toString(), + ...chatOptions + } = options; + + try { + // Log the generation attempt + logger({ + category: "anthropic", + message: "Initiating object generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Create chat completion with the provided prompt + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + response_model: { + name: "object", + schema: schema, + }, + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMObjectResponse; + + // Validate response structure + if (!response.data || response.data.length === 0) { + throw new CreateChatCompletionResponseError( + "API response contains no valid choices", + ); + } + + // Extract and validate the generated text + const generatedObject = response.data; + if (generatedObject === null || generatedObject === undefined) { + throw new CreateChatCompletionResponseError( + "Generated text content is empty", + ); + } + + // Construct the final response + const objResponse = { + ...response, + object: generatedObject, + } as T; + + // Log successful generation + logger({ + category: "anthropic", + message: "Text generation successful", + level: 2, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return objResponse; + } catch (error) { + // Log the error + logger({ + category: "anthropic", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } } const extractSchemaProperties = (jsonSchema: AnthropicJsonSchemaObject) => { diff --git a/lib/llm/CerebrasClient.ts b/lib/llm/CerebrasClient.ts index 0c336bd87..cf637c3dc 100644 --- a/lib/llm/CerebrasClient.ts +++ b/lib/llm/CerebrasClient.ts @@ -7,9 +7,12 @@ import { LLMCache } from "../cache/LLMCache"; import { ChatMessage, CreateChatCompletionOptions, + GenerateObjectOptions, GenerateTextOptions, LLMClient, + LLMObjectResponse, LLMResponse, + ObjectResponse, TextResponse, } from "./LLMClient"; import { CreateChatCompletionResponseError } from "@/types/stagehandErrors"; @@ -477,4 +480,115 @@ export class 
CerebrasClient extends LLMClient { ); } } + async generateObject({ + prompt, + schema, + options = {}, + }: GenerateObjectOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, + retries = 3, + requestId = Date.now().toString(), + ...chatOptions + } = options; + + try { + // Log the generation attempt + logger({ + category: "anthropic", + message: "Initiating object generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Create chat completion with the provided prompt + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + response_model: { + name: "object", + schema: schema, + }, + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMObjectResponse; + // Validate response structure + if (!response.data || response.data.length === 0) { + throw new CreateChatCompletionResponseError( + "API response contains no valid choices", + ); + } + + // Extract and validate the generated text + const generatedObject = response.data; + if (generatedObject === null || generatedObject === undefined) { + throw new CreateChatCompletionResponseError( + "Generated text content is empty", + ); + } + + // Construct the final response + const objResponse = { + ...response, + object: generatedObject, + } as T; + + // Log successful generation + logger({ + category: "anthropic", + message: "Text generation successful", + level: 2, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return objResponse; + } catch (error) { + // Log the error + logger({ + category: "anthropic", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } } diff --git a/lib/llm/GoogleClient.ts b/lib/llm/GoogleClient.ts index aa8e38675..71188e556 100644 --- a/lib/llm/GoogleClient.ts +++ b/lib/llm/GoogleClient.ts @@ -24,6 +24,9 @@ import { AnnotatedScreenshotText, TextResponse, GenerateTextOptions, + LLMObjectResponse, + GenerateObjectOptions, + ObjectResponse, } from "./LLMClient"; import { CreateChatCompletionResponseError, @@ -572,4 +575,115 @@ export class GoogleClient extends LLMClient { ); } } + async generateObject({ + prompt, + schema, + options = {}, + }: GenerateObjectOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, + retries = 3, + requestId = Date.now().toString(), + ...chatOptions + } = options; + + try { + // Log the generation attempt + logger({ + category: "anthropic", + message: "Initiating object generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Create chat completion with the provided prompt + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + response_model: { + name: "object", + schema: schema, + }, + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMObjectResponse; + // Validate response structure + if (!response.data || response.data.length === 0) { + throw new CreateChatCompletionResponseError( + "API response contains no valid choices", + 
); + } + + // Extract and validate the generated text + const generatedObject = response.data; + if (generatedObject === null || generatedObject === undefined) { + throw new CreateChatCompletionResponseError( + "Generated text content is empty", + ); + } + + // Construct the final response + const objResponse = { + ...response, + object: generatedObject, + } as T; + + // Log successful generation + logger({ + category: "anthropic", + message: "Text generation successful", + level: 2, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return objResponse; + } catch (error) { + // Log the error + logger({ + category: "anthropic", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } } diff --git a/lib/llm/GroqClient.ts b/lib/llm/GroqClient.ts index deb9773b6..3f02c3c47 100644 --- a/lib/llm/GroqClient.ts +++ b/lib/llm/GroqClient.ts @@ -7,9 +7,12 @@ import { LLMCache } from "../cache/LLMCache"; import { ChatMessage, CreateChatCompletionOptions, + GenerateObjectOptions, GenerateTextOptions, LLMClient, + LLMObjectResponse, LLMResponse, + ObjectResponse, TextResponse, } from "./LLMClient"; import { CreateChatCompletionResponseError } from "@/types/stagehandErrors"; @@ -245,7 +248,7 @@ export class GroqClient extends LLMClient { if (this.enableCaching) { this.cache.set(cacheOptions, result, options.requestId); } - return result as T; + return { data: result } as T; } catch (e) { // If JSON parse fails, the model might be returning a different format logger({ @@ -273,7 +276,7 @@ export class GroqClient extends LLMClient { if (this.enableCaching) { this.cache.set(cacheOptions, result, options.requestId); } - return result as T; + return { data: result } as T; } } catch (e) { logger({ @@ -459,4 +462,115 @@ export class GroqClient extends LLMClient { ); } } + async generateObject({ + prompt, + schema, + options = {}, + }: GenerateObjectOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, + retries = 3, + requestId = Date.now().toString(), + ...chatOptions + } = options; + + try { + // Log the generation attempt + logger({ + category: "anthropic", + message: "Initiating object generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Create chat completion with the provided prompt + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + response_model: { + name: "object", + schema: schema, + }, + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMObjectResponse; + // Validate response structure + if (!response.data || response.data.length === 0) { + throw new CreateChatCompletionResponseError( + "API response contains no valid choices", + ); + } + + // Extract and validate the generated text + const generatedObject = response.data; + if (generatedObject === null || generatedObject === undefined) { + throw new CreateChatCompletionResponseError( + "Generated text content is empty", + ); + } + + // Construct the final response + const objResponse = { + ...response, + object: generatedObject, + } as T; + + // Log successful generation + logger({ + category: "anthropic", + 
message: "Text generation successful", + level: 2, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return objResponse; + } catch (error) { + // Log the error + logger({ + category: "anthropic", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } } diff --git a/lib/llm/LLMClient.ts b/lib/llm/LLMClient.ts index 3256ac8d8..d04caff1d 100644 --- a/lib/llm/LLMClient.ts +++ b/lib/llm/LLMClient.ts @@ -96,6 +96,12 @@ export interface LLMResponse extends BaseResponse { usage: UsageMetrics; } +// Main LLM Response type +export interface LLMObjectResponse extends BaseResponse { + data: Record; + usage: UsageMetrics; +} + // Text Response type that can include LLM properties export interface TextResponse extends BaseResponse { text: string; @@ -103,6 +109,13 @@ export interface TextResponse extends BaseResponse { usage?: UsageMetrics; } +// Object Response type that can include LLM properties +export interface ObjectResponse extends BaseResponse { + object: string; + choices?: LLMChoice[]; + usage?: UsageMetrics; +} + export interface CreateChatCompletionOptions { options: ChatCompletionOptions; logger: (message: LogLine) => void; @@ -117,6 +130,15 @@ export interface GenerateTextOptions { }; } +export interface GenerateObjectOptions { + prompt: string; + schema: ZodType; + options?: Partial> & { + logger?: (message: LogLine) => void; + retries?: number; + }; +} + export abstract class LLMClient { public type: "openai" | "anthropic" | "cerebras" | "groq" | (string & {}); public modelName: AvailableModel | (string & {}); @@ -140,4 +162,10 @@ export abstract class LLMClient { usage?: TextResponse["usage"]; }, >(input: GenerateTextOptions): Promise; + + abstract generateObject< + T = ObjectResponse & { + usage?: ObjectResponse["usage"]; + }, + >(input: GenerateObjectOptions): Promise; } diff --git a/lib/llm/OpenAIClient.ts b/lib/llm/OpenAIClient.ts index 492a20d0e..2958ea9e5 100644 --- a/lib/llm/OpenAIClient.ts +++ b/lib/llm/OpenAIClient.ts @@ -18,9 +18,12 @@ import { ChatCompletionOptions, ChatMessage, CreateChatCompletionOptions, + GenerateObjectOptions, GenerateTextOptions, LLMClient, + LLMObjectResponse, LLMResponse, + ObjectResponse, TextResponse, } from "./LLMClient"; import { @@ -544,6 +547,118 @@ export class OpenAIClient extends LLMClient { }, }); + // Re-throw the error to be handled by the caller + throw error; + } + } + async generateObject({ + prompt, + schema, + options = {}, + }: GenerateObjectOptions): Promise { + // Destructure options with defaults + const { + logger = () => {}, + retries = 3, + requestId = Date.now().toString(), + ...chatOptions + } = options; + + try { + // Log the generation attempt + logger({ + category: "anthropic", + message: "Initiating object generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + // Create chat completion with the provided prompt + const response = (await this.createChatCompletion({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + response_model: { + name: "object", + schema: schema, + }, + ...chatOptions, + requestId, + }, + logger, + retries, + })) as LLMObjectResponse; + + // Validate response structure 
+ if (!response.data || response.data.length === 0) { + throw new CreateChatCompletionResponseError( + "API response contains no valid choices", + ); + } + + // Extract and validate the generated text + const generatedObject = response.data; + if (generatedObject === null || generatedObject === undefined) { + throw new CreateChatCompletionResponseError( + "Generated text content is empty", + ); + } + + // Construct the final response + const objResponse = { + ...response, + object: generatedObject, + } as T; + + // Log successful generation + logger({ + category: "anthropic", + message: "Text generation successful", + level: 2, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return objResponse; + } catch (error) { + // Log the error + logger({ + category: "anthropic", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + // Re-throw the error to be handled by the caller throw error; } diff --git a/package-lock.json b/package-lock.json index 53cd79748..7f95e2d76 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@browserbasehq/stagehand", - "version": "2.0.0", + "version": "2.1.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@browserbasehq/stagehand", - "version": "2.0.0", + "version": "2.1.0", "license": "MIT", "dependencies": { "@anthropic-ai/sdk": "0.39.0", @@ -16,7 +16,7 @@ "pino": "^9.6.0", "pino-pretty": "^13.0.0", "ws": "^8.18.0", - "zod-to-json-schema": "^3.23.5" + "zod-to-json-schema": "^3.24.5" }, "devDependencies": { "@ai-sdk/anthropic": "^1.2.6", @@ -10441,9 +10441,9 @@ } }, "node_modules/zod-to-json-schema": { - "version": "3.24.3", - "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.24.3.tgz", - "integrity": "sha512-HIAfWdYIt1sssHfYZFCXp4rU1w2r8hVVXYIlmoa0r0gABLs5di3RCqPU5DDROogVz1pAdYBaz7HK5n9pSUNs3A==", + "version": "3.24.5", + "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.24.5.tgz", + "integrity": "sha512-/AuWwMP+YqiPbsJx5D6TfgRTc4kTLjsh5SOcd4bLsfUg2RcEXrFMJl1DGgdHy2aCfsIA/cr/1JM0xcB2GZji8g==", "license": "ISC", "peerDependencies": { "zod": "^3.24.1" diff --git a/package.json b/package.json index af2c20c94..007920e4e 100644 --- a/package.json +++ b/package.json @@ -96,7 +96,7 @@ "pino": "^9.6.0", "pino-pretty": "^13.0.0", "ws": "^8.18.0", - "zod-to-json-schema": "^3.23.5" + "zod-to-json-schema": "^3.24.5" }, "directories": { "doc": "docs", From d678c253d8cdd5698ce13be81d0a138f14a6b3c0 Mon Sep 17 00:00:00 2001 From: Hitesh Agarwal Date: Thu, 17 Apr 2025 14:27:25 +0800 Subject: [PATCH 3/5] feat: add streamText --- examples/external_clients/aisdk.ts | 25 ++ examples/external_clients/customOpenAI.ts | 238 +++++++++++++ examples/external_clients/langchain.ts | 92 +++++ examples/llm_usage_wordle.ts | 21 +- lib/llm/AnthropicClient.ts | 308 +++++++++++++++++ lib/llm/CerebrasClient.ts | 252 ++++++++++++++ lib/llm/GoogleClient.ts | 252 ++++++++++++++ lib/llm/GroqClient.ts | 252 ++++++++++++++ lib/llm/LLMClient.ts | 32 ++ lib/llm/OpenAIClient.ts | 390 +++++++++++++++++++++- 10 files changed, 1844 insertions(+), 18 deletions(-) diff --git a/examples/external_clients/aisdk.ts b/examples/external_clients/aisdk.ts index 9a73e63c3..8efa246c9 100644 --- a/examples/external_clients/aisdk.ts +++ b/examples/external_clients/aisdk.ts @@ -8,6 
+8,7 @@ import { generateText, ImagePart, LanguageModel, + streamText, TextPart, } from "ai"; import { CreateChatCompletionOptions, LLMClient, AvailableModel } from "@/dist"; @@ -16,6 +17,7 @@ import { GenerateObjectOptions, GenerateTextOptions, ObjectResponse, + StreamingTextResponse, TextResponse, } from "@/lib"; @@ -125,6 +127,29 @@ export class AISdkClient extends LLMClient { }, } as T; } + + async streamText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + const tools: Record = {}; + if (options.tools) { + for (const rawTool of options.tools) { + tools[rawTool.name] = { + description: rawTool.description, + parameters: rawTool.parameters, + }; + } + } + + const response = await streamText({ + model: this.model, + prompt: prompt, + tools, + }); + return response as T; + } + async generateText({ prompt, options = {}, diff --git a/examples/external_clients/customOpenAI.ts b/examples/external_clients/customOpenAI.ts index ee017ba46..b5af6ad6a 100644 --- a/examples/external_clients/customOpenAI.ts +++ b/examples/external_clients/customOpenAI.ts @@ -14,6 +14,7 @@ import type { ChatCompletionContentPartImage, ChatCompletionContentPartText, ChatCompletionCreateParamsNonStreaming, + ChatCompletionCreateParamsStreaming, ChatCompletionMessageParam, ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam, @@ -26,6 +27,8 @@ import { LLMObjectResponse, LLMResponse, ObjectResponse, + StreamingChatResponse, + StreamingTextResponse, TextResponse, } from "@/lib"; @@ -246,6 +249,241 @@ export class CustomOpenAIClient extends LLMClient { } as T; } + async createChatCompletionStream({ + options, + logger, + retries = 3, + }: CreateChatCompletionOptions): Promise { + const { image, requestId, ...optionsWithoutImageAndRequestId } = options; + + // TODO: Implement vision support + if (image) { + console.warn( + "Image provided. Vision is not currently supported for openai", + ); + } + + logger({ + category: "openai", + message: "creating chat completion stream", + level: 1, + auxiliary: { + options: { + value: JSON.stringify({ + ...optionsWithoutImageAndRequestId, + requestId, + }), + type: "object", + }, + modelName: { + value: this.modelName, + type: "string", + }, + }, + }); + + if (options.image) { + console.warn( + "Image provided. 
Vision is not currently supported for openai", + ); + } + + let responseFormat = undefined; + if (options.response_model) { + responseFormat = zodResponseFormat( + options.response_model.schema, + options.response_model.name, + ); + } + + /* eslint-disable */ + // Remove unsupported options + const { response_model, ...openaiOptions } = { + ...optionsWithoutImageAndRequestId, + model: this.modelName, + }; + + const formattedMessages: ChatCompletionMessageParam[] = + options.messages.map((message) => { + if (Array.isArray(message.content)) { + const contentParts = message.content.map((content) => { + if ("image_url" in content) { + const imageContent: ChatCompletionContentPartImage = { + image_url: { + url: content.image_url.url, + }, + type: "image_url", + }; + return imageContent; + } else { + const textContent: ChatCompletionContentPartText = { + text: content.text, + type: "text", + }; + return textContent; + } + }); + + if (message.role === "system") { + const formattedMessage: ChatCompletionSystemMessageParam = { + ...message, + role: "system", + content: contentParts.filter( + (content): content is ChatCompletionContentPartText => + content.type === "text", + ), + }; + return formattedMessage; + } else if (message.role === "user") { + const formattedMessage: ChatCompletionUserMessageParam = { + ...message, + role: "user", + content: contentParts, + }; + return formattedMessage; + } else { + const formattedMessage: ChatCompletionAssistantMessageParam = { + ...message, + role: "assistant", + content: contentParts.filter( + (content): content is ChatCompletionContentPartText => + content.type === "text", + ), + }; + return formattedMessage; + } + } + + const formattedMessage: ChatCompletionUserMessageParam = { + role: "user", + content: message.content, + }; + + return formattedMessage; + }); + + const body: ChatCompletionCreateParamsStreaming = { + ...openaiOptions, + model: this.modelName, + messages: formattedMessages, + response_format: responseFormat, + stream: true, + tools: options.tools?.map((tool) => ({ + function: { + name: tool.name, + description: tool.description, + parameters: tool.parameters, + }, + type: "function", + })), + }; + + const response = await this.client.chat.completions.create(body); + return response as T; + } + + async streamText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { logger = () => {}, retries = 3, ...chatOptions } = options; + + // Create a unique request ID if not provided + const requestId = options.requestId || Date.now().toString(); + + logger({ + category: "openai", + message: "Initiating text streaming", + level: 2, + auxiliary: { + options: { + value: JSON.stringify({ + prompt, + requestId, + }), + type: "object", + }, + modelName: { + value: this.modelName, + type: "string", + }, + }, + }); + + try { + // Create a chat completion with the prompt as a user message + const response = (await this.createChatCompletionStream({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId, + }, + logger, + retries, + })) as StreamingChatResponse; + + // Restructure the response to return a stream of text + const textStream = new ReadableStream({ + async start(controller) { + try { + for await (const chunk of response) { + const content = chunk.choices[0]?.delta?.content; + if (content !== undefined) { + controller.enqueue(content); + } + } + controller.close(); + } catch (error) { + controller.error(error); + } + }, + }); + + logger({ 
+ category: "openai", + message: "text streaming response", + level: 2, + auxiliary: { + response: { + value: JSON.stringify(textStream), + type: "object", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return { textStream: textStream } as T; + } catch (error) { + logger({ + category: "openai", + message: "Text streaming failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } + async generateText({ prompt, options = {}, diff --git a/examples/external_clients/langchain.ts b/examples/external_clients/langchain.ts index 227d99fae..5faefc34e 100644 --- a/examples/external_clients/langchain.ts +++ b/examples/external_clients/langchain.ts @@ -15,6 +15,8 @@ import { LLMObjectResponse, LLMResponse, ObjectResponse, + StreamingChatResponse, + StreamingTextResponse, TextResponse, } from "@/lib"; @@ -94,6 +96,96 @@ export class LangchainClient extends LLMClient { } as T; } + async createChatCompletionStream({ + options, + logger, + retries = 3, + }: CreateChatCompletionOptions): Promise { + console.log(logger, retries); + const formattedMessages: BaseMessageLike[] = options.messages.map( + (message) => { + if (Array.isArray(message.content)) { + if (message.role === "system") { + return new SystemMessage( + message.content + .map((c) => ("text" in c ? c.text : "")) + .join("\n"), + ); + } + + const content = message.content.map((content) => + "image_url" in content + ? { type: "image", image: content.image_url.url } + : { type: "text", text: content.text }, + ); + + if (message.role === "user") return new HumanMessage({ content }); + + const textOnlyParts = content.map((part) => ({ + type: "text" as const, + text: part.type === "image" ? 
"[Image]" : part.text, + })); + + return new AIMessage({ content: textOnlyParts }); + } + + return { + role: message.role, + content: message.content, + }; + }, + ); + const modelWithTools = this.model.bindTools(options.tools); + const response = await modelWithTools._streamIterator(formattedMessages); + return response as T; + } + + async streamText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { logger = () => {}, retries = 3, ...chatOptions } = options; + + // Create a unique request ID if not provided + const requestId = options.requestId || Date.now().toString(); + + // Create a chat completion with the prompt as a user message + const response = (await this.createChatCompletionStream({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId, + }, + logger, + retries, + })) as StreamingChatResponse; + + // Restructure the response to return a stream of text + const textStream = new ReadableStream({ + async start(controller) { + try { + for await (const chunk of response) { + const content = chunk.choices[0]?.delta?.content; + if (content !== undefined) { + controller.enqueue(content); + } + } + controller.close(); + } catch (error) { + controller.error(error); + } + }, + }); + + return { textStream: textStream } as T; + } + async generateText({ prompt, options = {}, diff --git a/examples/llm_usage_wordle.ts b/examples/llm_usage_wordle.ts index caee50544..61d261171 100644 --- a/examples/llm_usage_wordle.ts +++ b/examples/llm_usage_wordle.ts @@ -1,6 +1,6 @@ import { Stagehand } from "@/dist"; import StagehandConfig from "@/stagehand.config"; -import { z } from "zod"; +// import { z } from "zod"; async function example() { const stagehand = new Stagehand({ @@ -9,22 +9,15 @@ async function example() { await stagehand.init(); - const { text } = await stagehand.llmClient.generateText({ + const { textStream } = await stagehand.llmClient.streamText({ prompt: "you are playing wordle. Return the 5-letter word that would be the best guess", }); - console.log(text); - const { object } = await stagehand.llmClient.generateObject({ - prompt: - "you are playing wordle. 
Return the 5-letter word that would be the best guess", - schema: z.object({ - guess: z - .string() - .length(5) - .describe("The 5-letter word that would be the best guess"), - }), - }); - console.log(object); + + for await (const textPart of textStream) { + process.stdout.write(textPart); + } + await stagehand.close(); } diff --git a/lib/llm/AnthropicClient.ts b/lib/llm/AnthropicClient.ts index 9fa034089..16aaeef3f 100644 --- a/lib/llm/AnthropicClient.ts +++ b/lib/llm/AnthropicClient.ts @@ -17,6 +17,8 @@ import { LLMObjectResponse, LLMResponse, ObjectResponse, + StreamingChatResponse, + StreamingTextResponse, TextResponse, } from "./LLMClient"; import { CreateChatCompletionResponseError } from "@/types/stagehandErrors"; @@ -379,6 +381,312 @@ export class AnthropicClient extends LLMClient { return transformedResponse as T; } + async createChatCompletionStream({ + options, + retries, + logger, + }: CreateChatCompletionOptions): Promise { + console.log(options, logger, retries); + const optionsWithoutImage = { ...options }; + delete optionsWithoutImage.image; + + logger({ + category: "anthropic", + message: "creating chat completion stream", + level: 2, + auxiliary: { + options: { + value: JSON.stringify(optionsWithoutImage), + type: "object", + }, + }, + }); + + // Try to get cached response + const cacheOptions = { + model: this.modelName, + messages: options.messages, + temperature: options.temperature, + image: options.image, + response_model: options.response_model, + tools: options.tools, + retries: retries, + }; + + if (this.enableCaching) { + const cachedResponse = await this.cache.get( + cacheOptions, + options.requestId, + ); + if (cachedResponse) { + logger({ + category: "llm_cache", + message: "LLM cache hit - returning cached response", + level: 1, + auxiliary: { + cachedResponse: { + value: JSON.stringify(cachedResponse), + type: "object", + }, + requestId: { + value: options.requestId, + type: "string", + }, + cacheOptions: { + value: JSON.stringify(cacheOptions), + type: "object", + }, + }, + }); + return cachedResponse as T; + } else { + logger({ + category: "llm_cache", + message: "LLM cache miss - no cached response found", + level: 1, + auxiliary: { + cacheOptions: { + value: JSON.stringify(cacheOptions), + type: "object", + }, + requestId: { + value: options.requestId, + type: "string", + }, + }, + }); + } + } + + const systemMessage = options.messages.find((msg) => { + if (msg.role === "system") { + if (typeof msg.content === "string") { + return true; + } else if (Array.isArray(msg.content)) { + return msg.content.every((content) => content.type !== "image_url"); + } + } + return false; + }); + + const userMessages = options.messages.filter( + (msg) => msg.role !== "system", + ); + + const formattedMessages: MessageParam[] = userMessages.map((msg) => { + if (typeof msg.content === "string") { + return { + role: msg.role as "user" | "assistant", // ensure its not checking for system types + content: msg.content, + }; + } else { + return { + role: msg.role as "user" | "assistant", + content: msg.content.map((content) => { + if ("image_url" in content) { + const formattedContent: ImageBlockParam = { + type: "image", + source: { + type: "base64", + media_type: "image/jpeg", + data: content.image_url.url, + }, + }; + + return formattedContent; + } else { + return { type: "text", text: content.text }; + } + }), + }; + } + }); + + if (options.image) { + const screenshotMessage: MessageParam = { + role: "user", + content: [ + { + type: "image", + source: { + type: "base64", + 
media_type: "image/jpeg", + data: options.image.buffer.toString("base64"), + }, + }, + ], + }; + if ( + options.image.description && + Array.isArray(screenshotMessage.content) + ) { + screenshotMessage.content.push({ + type: "text", + text: options.image.description, + }); + } + + formattedMessages.push(screenshotMessage); + } + + let anthropicTools: Tool[] = options.tools?.map((tool) => { + return { + name: tool.name, + description: tool.description, + input_schema: { + type: "object", + properties: tool.parameters.properties, + required: tool.parameters.required, + }, + }; + }); + + let toolDefinition: Tool | undefined; + + // Check if a response model is provided + if (options.response_model) { + const jsonSchema = zodToJsonSchema(options.response_model.schema); + const { properties: schemaProperties, required: schemaRequired } = + extractSchemaProperties(jsonSchema); + + toolDefinition = { + name: "print_extracted_data", + description: "Prints the extracted data based on the provided schema.", + input_schema: { + type: "object", + properties: schemaProperties, + required: schemaRequired, + }, + }; + } + + // Add the tool definition to the tools array if it exists + if (toolDefinition) { + anthropicTools = anthropicTools ?? []; + anthropicTools.push(toolDefinition); + } + + // Create the chat completion stream with the provided messages + const response = await this.client.messages.create({ + model: this.modelName, + max_tokens: options.maxTokens || 8192, + messages: formattedMessages, + tools: anthropicTools, + system: systemMessage + ? (systemMessage.content as string | TextBlockParam[]) + : undefined, + temperature: options.temperature, + stream: true, + }); + + // Restructure the response to match the expected format + return new ReadableStream({ + async start(controller) { + try { + for await (const chunk of response) { + if ( + chunk.type === "content_block_delta" && + chunk.delta.type === "text_delta" + ) { + controller.enqueue(chunk.delta.text); + } + } + controller.close(); + } catch (error) { + controller.error(error); + } + }, + }) as T; + } + + async streamText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { logger = () => {}, retries = 3, ...chatOptions } = options; + + // Create a unique request ID if not provided + const requestId = options.requestId || Date.now().toString(); + + // Log the generation attempt + logger({ + category: "anthropic", + message: "Initiating text streaming", + level: 2, + auxiliary: { + options: { + value: JSON.stringify({ + prompt, + requestId, + }), + type: "object", + }, + modelName: { + value: this.modelName, + type: "string", + }, + }, + }); + + try { + // Create a chat completion stream with the prompt as a user message + const response = (await this.createChatCompletionStream({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId, + }, + logger, + retries, + })) as StreamingChatResponse; + + logger({ + category: "anthropic", + message: "text streaming response", + level: 2, + auxiliary: { + response: { + value: JSON.stringify(response), + type: "object", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return { + textStream: response, + } as T; + } catch (error) { + logger({ + category: "anthropic", + message: "Text streaming failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + // 
Re-throw the error to be handled by the caller + throw error; + } + } + async generateText({ prompt, options = {}, diff --git a/lib/llm/CerebrasClient.ts b/lib/llm/CerebrasClient.ts index cf637c3dc..305c48ec7 100644 --- a/lib/llm/CerebrasClient.ts +++ b/lib/llm/CerebrasClient.ts @@ -13,6 +13,8 @@ import { LLMObjectResponse, LLMResponse, ObjectResponse, + StreamingChatResponse, + StreamingTextResponse, TextResponse, } from "./LLMClient"; import { CreateChatCompletionResponseError } from "@/types/stagehandErrors"; @@ -330,6 +332,256 @@ export class CerebrasClient extends LLMClient { } } + async createChatCompletionStream({ + options, + logger, + retries = 3, + }: CreateChatCompletionOptions): Promise { + const optionsWithoutImage = { ...options }; + delete optionsWithoutImage.image; + + logger({ + category: "cerebras", + message: "creating chat completion stream", + level: 2, + auxiliary: { + options: { + value: JSON.stringify(optionsWithoutImage), + type: "object", + }, + }, + }); + + // Try to get cached response + const cacheOptions = { + model: this.modelName.split("cerebras-")[1], + messages: options.messages, + temperature: options.temperature, + response_model: options.response_model, + tools: options.tools, + retries: retries, + }; + + if (this.enableCaching) { + const cachedResponse = await this.cache.get( + cacheOptions, + options.requestId, + ); + if (cachedResponse) { + logger({ + category: "llm_cache", + message: "LLM cache hit - returning cached response", + level: 1, + auxiliary: { + cachedResponse: { + value: JSON.stringify(cachedResponse), + type: "object", + }, + requestId: { + value: options.requestId, + type: "string", + }, + cacheOptions: { + value: JSON.stringify(cacheOptions), + type: "object", + }, + }, + }); + return cachedResponse as T; + } + } + + // Format messages for Cerebras API (using OpenAI format) + const formattedMessages = options.messages.map((msg: ChatMessage) => { + const baseMessage = { + content: + typeof msg.content === "string" + ? msg.content + : Array.isArray(msg.content) && + msg.content.length > 0 && + "text" in msg.content[0] + ? msg.content[0].text + : "", + }; + + // Cerebras only supports system, user, and assistant roles + if (msg.role === "system") { + return { ...baseMessage, role: "system" as const }; + } else if (msg.role === "assistant") { + return { ...baseMessage, role: "assistant" as const }; + } else { + // Default to user for any other role + return { ...baseMessage, role: "user" as const }; + } + }); + + // Format tools if provided + let tools = options.tools?.map((tool) => ({ + type: "function" as const, + function: { + name: tool.name, + description: tool.description, + parameters: { + type: "object", + properties: tool.parameters.properties, + required: tool.parameters.required, + }, + }, + })); + + // Add response model as a tool if provided + if (options.response_model) { + const jsonSchema = zodToJsonSchema(options.response_model.schema) as { + properties?: Record; + required?: string[]; + }; + const schemaProperties = jsonSchema.properties || {}; + const schemaRequired = jsonSchema.required || []; + + const responseTool = { + type: "function" as const, + function: { + name: "print_extracted_data", + description: + "Prints the extracted data based on the provided schema.", + parameters: { + type: "object", + properties: schemaProperties, + required: schemaRequired, + }, + }, + }; + + tools = tools ? 
[...tools, responseTool] : [responseTool];
+    }
+
+    const apiResponse = await this.client.chat.completions.create({
+      model: this.modelName.split("cerebras-")[1],
+      messages: [
+        ...formattedMessages,
+        // Add explicit instruction to return JSON if we have a response model
+        ...(options.response_model
+          ? [
+              {
+                role: "system" as const,
+                content: `IMPORTANT: Your response must be valid JSON that matches this schema: ${JSON.stringify(options.response_model.schema)}`,
+              },
+            ]
+          : []),
+      ],
+      temperature: options.temperature || 0.7,
+      max_tokens: options.maxTokens,
+      tools: tools,
+      tool_choice: options.tool_choice || "auto",
+      stream: true,
+    });
+
+    return apiResponse as T;
+  }
+
+  async streamText({
+    prompt,
+    options = {},
+  }: GenerateTextOptions): Promise {
+    // Destructure options with defaults
+    const { logger = () => {}, retries = 3, ...chatOptions } = options;
+
+    // Create a unique request ID if not provided
+    const requestId = options.requestId || Date.now().toString();
+
+    logger({
+      category: "cerebras",
+      message: "Initiating text streaming",
+      level: 2,
+      auxiliary: {
+        options: {
+          value: JSON.stringify({
+            prompt,
+            requestId,
+          }),
+          type: "object",
+        },
+        modelName: {
+          value: this.modelName,
+          type: "string",
+        },
+      },
+    });
+
+    try {
+      // Create a chat completion with the prompt as a user message
+      const response = (await this.createChatCompletionStream({
+        options: {
+          messages: [
+            {
+              role: "user",
+              content: prompt,
+            },
+          ],
+          ...chatOptions,
+          requestId,
+        },
+        logger,
+        retries,
+      })) as StreamingChatResponse;
+
+      // Restructure the response to return a stream of text
+      const textStream = new ReadableStream({
+        async start(controller) {
+          try {
+            for await (const chunk of response) {
+              const content = chunk.choices[0]?.delta?.content;
+              if (content !== undefined) {
+                controller.enqueue(content);
+              }
+            }
+            controller.close();
+          } catch (error) {
+            controller.error(error);
+          }
+        },
+      });
+
+      // Log successful streaming
+      logger({
+        category: "cerebras",
+        message: "text streaming response",
+        level: 2,
+        auxiliary: {
+          response: {
+            value: JSON.stringify(textStream),
+            type: "object",
+          },
+          requestId: {
+            value: requestId,
+            type: "string",
+          },
+        },
+      });
+
+      return { textStream: textStream } as T;
+    } catch (error) {
+      // Log the error with detailed information
+      logger({
+        category: "cerebras",
+        message: "Text streaming failed",
+        level: 0,
+        auxiliary: {
+          error: {
+            value: error.message,
+            type: "string",
+          },
+          prompt: {
+            value: prompt,
+            type: "string",
+          },
+        },
+      });
+
+      // Re-throw the error to be handled by the caller
+      throw error;
+    }
+  }
+
   async generateText({
     prompt,
     options = {},
diff --git a/lib/llm/GoogleClient.ts b/lib/llm/GoogleClient.ts
index 71188e556..0f26edf78 100644
--- a/lib/llm/GoogleClient.ts
+++ b/lib/llm/GoogleClient.ts
@@ -27,6 +27,8 @@ import {
   LLMObjectResponse,
   GenerateObjectOptions,
   ObjectResponse,
+  StreamingChatResponse,
+  StreamingTextResponse,
 } from "./LLMClient";
 import {
   CreateChatCompletionResponseError,
@@ -536,6 +538,256 @@ export class GoogleClient extends LLMClient {
   }
 
+  async createChatCompletionStream({
+    options,
+    logger,
+    retries = 3,
+  }: CreateChatCompletionOptions): Promise {
+    const {
+      image,
+      requestId,
+      response_model,
+      tools,
+      temperature,
+      top_p,
+      maxTokens,
+    } = options;
+    console.log(retries);
+
+    logger({
+      category: "google",
+      message: "creating chat completion stream",
+      level: 2,
+      auxiliary: {
+        options: {
+          value: JSON.stringify(options),
+          type: "object",
+        },
+      },
+    });
+
+    const cacheKeyOptions 
= { + model: this.modelName, + messages: options.messages, + temperature: temperature, + top_p: top_p, + // frequency_penalty and presence_penalty are not directly supported in Gemini API + image: image + ? { description: image.description, bufferLength: image.buffer.length } + : undefined, // Use buffer length for caching key stability + response_model: response_model + ? { + name: response_model.name, + schema: JSON.stringify(zodToJsonSchema(response_model.schema)), + } + : undefined, + tools: tools, + maxTokens: maxTokens, + }; + + if (this.enableCaching) { + const cachedResponse = await this.cache.get( + cacheKeyOptions, + requestId, + ); + if (cachedResponse) { + logger({ + category: "llm_cache", + message: "LLM cache hit - returning cached response", + level: 1, + auxiliary: { requestId: { value: requestId, type: "string" } }, + }); + return cachedResponse; + } else { + logger({ + category: "llm_cache", + message: "LLM cache miss - proceeding with API call", + level: 1, + auxiliary: { requestId: { value: requestId, type: "string" } }, + }); + } + } + + const formattedMessages = this.formatMessages(options.messages, image); + const formattedTools = this.formatTools(tools); + + const generationConfig = { + maxOutputTokens: maxTokens, + temperature: temperature, + topP: top_p, + responseMimeType: response_model ? "application/json" : undefined, + }; + + // Handle JSON mode instructions + if (response_model) { + // Prepend instructions for JSON output if needed (similar to o1 handling) + const schemaString = JSON.stringify( + zodToJsonSchema(response_model.schema), + ); + formattedMessages.push({ + role: "user", + parts: [ + { + text: `Please respond ONLY with a valid JSON object that strictly adheres to the following JSON schema. Do not include any other text, explanations, or markdown formatting like \`\`\`json ... \`\`\`. Just the JSON object.\n\nSchema:\n${schemaString}`, + }, + ], + }); + formattedMessages.push({ role: "model", parts: [{ text: "{" }] }); // Prime the model + } + + logger({ + category: "google", + message: "creating chat completion", + level: 2, + auxiliary: { + modelName: { value: this.modelName, type: "string" }, + requestId: { value: requestId, type: "string" }, + requestPayloadSummary: { + value: `Model: ${this.modelName}, Messages: ${formattedMessages.length}, Config Keys: ${Object.keys(generationConfig).join(", ")}, Tools: ${formattedTools ? 
formattedTools.length : 0}, Safety Categories: ${safetySettings.map((s) => s.category).join(", ")}`, + type: "string", + }, + }, + }); + + // Construct the full request object + const requestPayload = { + model: this.modelName, + contents: formattedMessages, + config: { + ...generationConfig, + safetySettings: safetySettings, + tools: formattedTools, + }, + }; + + // Log the full payload safely + try { + logger({ + category: "google", + message: "Full request payload", + level: 2, + auxiliary: { + requestId: { value: requestId, type: "string" }, + fullPayload: { + value: JSON.stringify(requestPayload), + type: "object", + }, + }, + }); + } catch (e) { + logger({ + category: "google", + message: "Failed to stringify full request payload for logging", + level: 0, + auxiliary: { + requestId: { value: requestId, type: "string" }, + error: { value: e.message, type: "string" }, + }, + }); + } + + const result = + await this.client.models.generateContentStream(requestPayload); + + return new ReadableStream({ + async start(controller) { + try { + for await (const chunk of result) { + controller.enqueue(chunk.candidates[0].content.parts[0].text); + } + controller.close(); + } catch (error) { + controller.error(error); + } + }, + }) as T; + } + + async streamText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { logger = () => {}, retries = 3, ...chatOptions } = options; + + // Create a unique request ID if not provided + const requestId = options.requestId || Date.now().toString(); + + logger({ + category: "google", + message: "Initiating text streaming", + level: 2, + auxiliary: { + options: { + value: JSON.stringify({ + prompt, + requestId, + }), + type: "object", + }, + modelName: { + value: this.modelName, + type: "string", + }, + }, + }); + + try { + // Create a chat completion with the prompt as a user message + const response = (await this.createChatCompletionStream({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId, + }, + logger, + retries, + })) as StreamingChatResponse; + + logger({ + category: "google", + message: "text streaming response", + level: 2, + auxiliary: { + response: { + value: JSON.stringify(response), + type: "object", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return { textStream: response } as T; + } catch (error) { + logger({ + category: "google", + message: "Text streaming failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } + async generateText({ prompt, options = {}, diff --git a/lib/llm/GroqClient.ts b/lib/llm/GroqClient.ts index 3f02c3c47..cd011f967 100644 --- a/lib/llm/GroqClient.ts +++ b/lib/llm/GroqClient.ts @@ -13,6 +13,8 @@ import { LLMObjectResponse, LLMResponse, ObjectResponse, + StreamingChatResponse, + StreamingTextResponse, TextResponse, } from "./LLMClient"; import { CreateChatCompletionResponseError } from "@/types/stagehandErrors"; @@ -330,6 +332,256 @@ export class GroqClient extends LLMClient { } } + async createChatCompletionStream({ + options, + logger, + retries = 3, + }: CreateChatCompletionOptions): Promise { + const optionsWithoutImage = { ...options }; + delete optionsWithoutImage.image; + + logger({ + category: "groq", + message: "creating chat completion stream", + level: 2, + auxiliary: { + options: { + value: 
JSON.stringify(optionsWithoutImage), + type: "object", + }, + }, + }); + + // Try to get cached response + const cacheOptions = { + model: this.modelName.split("groq-")[1], + messages: options.messages, + temperature: options.temperature, + response_model: options.response_model, + tools: options.tools, + retries: retries, + }; + + if (this.enableCaching) { + const cachedResponse = await this.cache.get( + cacheOptions, + options.requestId, + ); + if (cachedResponse) { + logger({ + category: "llm_cache", + message: "LLM cache hit - returning cached response", + level: 1, + auxiliary: { + cachedResponse: { + value: JSON.stringify(cachedResponse), + type: "object", + }, + requestId: { + value: options.requestId, + type: "string", + }, + cacheOptions: { + value: JSON.stringify(cacheOptions), + type: "object", + }, + }, + }); + return cachedResponse as T; + } + } + + // Format messages for Groq API (using OpenAI format) + const formattedMessages = options.messages.map((msg: ChatMessage) => { + const baseMessage = { + content: + typeof msg.content === "string" + ? msg.content + : Array.isArray(msg.content) && + msg.content.length > 0 && + "text" in msg.content[0] + ? msg.content[0].text + : "", + }; + + // Groq supports system, user, and assistant roles + if (msg.role === "system") { + return { ...baseMessage, role: "system" as const }; + } else if (msg.role === "assistant") { + return { ...baseMessage, role: "assistant" as const }; + } else { + // Default to user for any other role + return { ...baseMessage, role: "user" as const }; + } + }); + + // Format tools if provided + let tools = options.tools?.map((tool) => ({ + type: "function" as const, + function: { + name: tool.name, + description: tool.description, + parameters: { + type: "object", + properties: tool.parameters.properties, + required: tool.parameters.required, + }, + }, + })); + + // Add response model as a tool if provided + if (options.response_model) { + const jsonSchema = zodToJsonSchema(options.response_model.schema) as { + properties?: Record; + required?: string[]; + }; + const schemaProperties = jsonSchema.properties || {}; + const schemaRequired = jsonSchema.required || []; + + const responseTool = { + type: "function" as const, + function: { + name: "print_extracted_data", + description: + "Prints the extracted data based on the provided schema.", + parameters: { + type: "object", + properties: schemaProperties, + required: schemaRequired, + }, + }, + }; + + tools = tools ? [...tools, responseTool] : [responseTool]; + } + + // Use OpenAI client with Groq API + const apiResponse = await this.client.chat.completions.create({ + model: this.modelName.split("groq-")[1], + messages: [ + ...formattedMessages, + // Add explicit instruction to return JSON if we have a response model + ...(options.response_model + ? 
[ + { + role: "system" as const, + content: `IMPORTANT: Your response must be valid JSON that matches this schema: ${JSON.stringify(options.response_model.schema)}`, + }, + ] + : []), + ], + temperature: options.temperature || 0.7, + max_tokens: options.maxTokens, + tools: tools, + tool_choice: options.tool_choice || "auto", + stream: true, + }); + + return apiResponse as T; + } + + async streamText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { logger = () => {}, retries = 3, ...chatOptions } = options; + + // Create a unique request ID if not provided + const requestId = options.requestId || Date.now().toString(); + + logger({ + category: "groq", + message: "Initiating text streaming", + level: 2, + auxiliary: { + options: { + value: JSON.stringify({ + prompt, + requestId, + }), + type: "object", + }, + modelName: { + value: this.modelName, + type: "string", + }, + }, + }); + + try { + // Create a chat completion with the prompt as a user message + const response = (await this.createChatCompletionStream({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId, + }, + logger, + retries, + })) as StreamingChatResponse; + + // Restructure the response to return a stream of text + const textStream = new ReadableStream({ + async start(controller) { + try { + for await (const chunk of response) { + const content = chunk.choices[0]?.delta?.content; + if (content !== undefined) { + controller.enqueue(content); + } + } + controller.close(); + } catch (error) { + controller.error(error); + } + }, + }); + + logger({ + category: "groq", + message: "text streaming response", + level: 2, + auxiliary: { + response: { + value: JSON.stringify(textStream), + type: "object", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return { textStream: textStream } as T; + } catch (error) { + logger({ + category: "groq", + message: "Text streaming failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } + async generateText({ prompt, options = {}, diff --git a/lib/llm/LLMClient.ts b/lib/llm/LLMClient.ts index d04caff1d..10f000233 100644 --- a/lib/llm/LLMClient.ts +++ b/lib/llm/LLMClient.ts @@ -96,6 +96,34 @@ export interface LLMResponse extends BaseResponse { usage: UsageMetrics; } +// Stream text response type +export interface StreamingTextResponse { + textStream: AsyncIterable; +} + +// Streaming chat chunk response type +export interface StreamingChatResponseChunk { + id: string; + object: string; + created: number; + model: string; + choices: { + index: number; + delta: { + content?: string; + role?: string; + function_call?: { + name?: string; + arguments?: string; + }; + }; + finish_reason: string | null; + }[]; +} + +// Streaming chat response type +export type StreamingChatResponse = AsyncIterable; + // Main LLM Response type export interface LLMObjectResponse extends BaseResponse { data: Record; @@ -157,6 +185,10 @@ export abstract class LLMClient { }, >(options: CreateChatCompletionOptions): Promise; + abstract streamText( + input: GenerateTextOptions, + ): Promise; + abstract generateText< T = TextResponse & { usage?: TextResponse["usage"]; diff --git a/lib/llm/OpenAIClient.ts b/lib/llm/OpenAIClient.ts index 2958ea9e5..4bd49124e 100644 --- a/lib/llm/OpenAIClient.ts +++ 
b/lib/llm/OpenAIClient.ts @@ -5,6 +5,7 @@ import { ChatCompletionContentPartImage, ChatCompletionContentPartText, ChatCompletionCreateParamsNonStreaming, + ChatCompletionCreateParamsStreaming, ChatCompletionMessageParam, ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam, @@ -24,6 +25,8 @@ import { LLMObjectResponse, LLMResponse, ObjectResponse, + StreamingChatResponse, + StreamingTextResponse, TextResponse, } from "./LLMClient"; import { @@ -473,6 +476,381 @@ export class OpenAIClient extends LLMClient { return response as T; } + async createChatCompletionStream({ + options: optionsInitial, + logger, + retries = 3, + }: CreateChatCompletionOptions): Promise { + let options: Partial = optionsInitial; + + // O1 models do not support most of the options. So we override them. + // For schema and tools, we add them as user messages. + // let isToolsOverridedForO1 = false; + if (this.modelName.startsWith("o1") || this.modelName.startsWith("o3")) { + /* eslint-disable */ + // Remove unsupported options + let { + tool_choice, + top_p, + frequency_penalty, + presence_penalty, + temperature, + } = options; + ({ + tool_choice, + top_p, + frequency_penalty, + presence_penalty, + temperature, + ...options + } = options); + /* eslint-enable */ + // Remove unsupported options + options.messages = options.messages.map((message) => ({ + ...message, + role: "user", + })); + if (options.tools && options.response_model) { + throw new StagehandError( + "Cannot use both tool and response_model for o1 models", + ); + } + + if (options.tools) { + // Remove unsupported options + let { tools } = options; + ({ tools, ...options } = options); + // isToolsOverridedForO1 = true; + options.messages.push({ + role: "user", + content: `You have the following tools available to you:\n${JSON.stringify( + tools, + )} + + Respond with the following zod schema format to use a method: { + "name": "", + "arguments": + } + + Do not include any other text or formattings like \`\`\` in your response. 
Just the JSON object.`, + }); + } + } + if ( + options.temperature && + (this.modelName.startsWith("o1") || this.modelName.startsWith("o3")) + ) { + throw new StagehandError("Temperature is not supported for o1 models"); + } + + const { image, requestId, ...optionsWithoutImageAndRequestId } = options; + + logger({ + category: "openai", + message: "creating chat completion stream", + level: 2, + auxiliary: { + options: { + value: JSON.stringify(optionsWithoutImageAndRequestId), + type: "object", + }, + }, + }); + + const cacheOptions = { + model: this.modelName, + messages: options.messages, + temperature: options.temperature, + top_p: options.top_p, + frequency_penalty: options.frequency_penalty, + presence_penalty: options.presence_penalty, + image: image, + response_model: options.response_model, + }; + + if (this.enableCaching) { + const cachedResponse = await this.cache.get(cacheOptions, requestId); + if (cachedResponse) { + logger({ + category: "llm_cache", + message: "LLM cache hit - returning cached response", + level: 1, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + cachedResponse: { + value: JSON.stringify(cachedResponse), + type: "object", + }, + }, + }); + return cachedResponse; + } else { + logger({ + category: "llm_cache", + message: "LLM cache miss - no cached response found", + level: 1, + auxiliary: { + requestId: { + value: requestId, + type: "string", + }, + }, + }); + } + } + + if (options.image) { + const screenshotMessage: ChatMessage = { + role: "user", + content: [ + { + type: "image_url", + image_url: { + url: `data:image/jpeg;base64,${options.image.buffer.toString("base64")}`, + }, + }, + ...(options.image.description + ? [{ type: "text", text: options.image.description }] + : []), + ], + }; + + options.messages.push(screenshotMessage); + } + + let responseFormat = undefined; + if (options.response_model) { + // For O1 models, we need to add the schema as a user message. + if (this.modelName.startsWith("o1") || this.modelName.startsWith("o3")) { + try { + const parsedSchema = JSON.stringify( + zodToJsonSchema(options.response_model.schema), + ); + options.messages.push({ + role: "user", + content: `Respond in this zod schema format:\n${parsedSchema}\n + + Do not include any other text, formatting or markdown in your output. Do not include \`\`\` or \`\`\`json in your response. 
Only the JSON object itself.`, + }); + } catch (error) { + logger({ + category: "openai", + message: "Failed to parse response model schema", + level: 0, + }); + + if (retries > 0) { + // as-casting to account for o1 models not supporting all options + return this.createChatCompletion({ + options: options as ChatCompletionOptions, + logger, + retries: retries - 1, + }); + } + + throw error; + } + } else { + responseFormat = zodResponseFormat( + options.response_model.schema, + options.response_model.name, + ); + } + } + + /* eslint-disable */ + // Remove unsupported options + const { response_model, ...openAiOptions } = { + ...optionsWithoutImageAndRequestId, + model: this.modelName, + }; + /* eslint-enable */ + + const formattedMessages: ChatCompletionMessageParam[] = + options.messages.map((message) => { + if (Array.isArray(message.content)) { + const contentParts = message.content.map((content) => { + if ("image_url" in content) { + const imageContent: ChatCompletionContentPartImage = { + image_url: { + url: content.image_url.url, + }, + type: "image_url", + }; + return imageContent; + } else { + const textContent: ChatCompletionContentPartText = { + text: content.text, + type: "text", + }; + return textContent; + } + }); + + if (message.role === "system") { + const formattedMessage: ChatCompletionSystemMessageParam = { + ...message, + role: "system", + content: contentParts.filter( + (content): content is ChatCompletionContentPartText => + content.type === "text", + ), + }; + return formattedMessage; + } else if (message.role === "user") { + const formattedMessage: ChatCompletionUserMessageParam = { + ...message, + role: "user", + content: contentParts, + }; + return formattedMessage; + } else { + const formattedMessage: ChatCompletionAssistantMessageParam = { + ...message, + role: "assistant", + content: contentParts.filter( + (content): content is ChatCompletionContentPartText => + content.type === "text", + ), + }; + return formattedMessage; + } + } + + const formattedMessage: ChatCompletionUserMessageParam = { + role: "user", + content: message.content, + }; + + return formattedMessage; + }); + + const body: ChatCompletionCreateParamsStreaming = { + ...openAiOptions, + model: this.modelName, + messages: formattedMessages, + response_format: responseFormat, + stream: true, + tools: options.tools?.map((tool) => ({ + function: { + name: tool.name, + description: tool.description, + parameters: tool.parameters, + }, + type: "function", + })), + }; + const response = await this.client.chat.completions.create(body); + return response as T; + } + + async streamText({ + prompt, + options = {}, + }: GenerateTextOptions): Promise { + // Destructure options with defaults + const { logger = () => {}, retries = 3, ...chatOptions } = options; + + // Create a unique request ID if not provided + const requestId = options.requestId || Date.now().toString(); + + logger({ + category: "openai", + message: "Initiating text streaming", + level: 2, + auxiliary: { + options: { + value: JSON.stringify({ + prompt, + requestId, + }), + type: "object", + }, + modelName: { + value: this.modelName, + type: "string", + }, + }, + }); + + try { + // Create a chat completion with the prompt as a user message + const response = (await this.createChatCompletionStream({ + options: { + messages: [ + { + role: "user", + content: prompt, + }, + ], + ...chatOptions, + requestId, + }, + logger, + retries, + })) as StreamingChatResponse; + + // Restructure the response to return a stream of text + const textStream = new 
ReadableStream({ + async start(controller) { + try { + for await (const chunk of response) { + const content = chunk.choices[0]?.delta?.content; + if (content !== undefined) { + controller.enqueue(content); + } + } + controller.close(); + } catch (error) { + controller.error(error); + } + }, + }); + + logger({ + category: "openai", + message: "text streaming response", + level: 2, + auxiliary: { + response: { + value: JSON.stringify(textStream), + type: "object", + }, + requestId: { + value: requestId, + type: "string", + }, + }, + }); + + return { + textStream: textStream, + } as T; + } catch (error) { + logger({ + category: "openai", + message: "Text streaming failed", + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + + // Re-throw the error to be handled by the caller + throw error; + } + } + async generateText({ prompt, options = {}, @@ -567,7 +945,7 @@ export class OpenAIClient extends LLMClient { try { // Log the generation attempt logger({ - category: "anthropic", + category: "openai", message: "Initiating object generation", level: 2, auxiliary: { @@ -625,10 +1003,14 @@ export class OpenAIClient extends LLMClient { // Log successful generation logger({ - category: "anthropic", - message: "Text generation successful", + category: "openai", + message: "Object generation successful", level: 2, auxiliary: { + response: { + value: JSON.stringify(generatedObject), + type: "object", + }, requestId: { value: requestId, type: "string", @@ -640,7 +1022,7 @@ export class OpenAIClient extends LLMClient { } catch (error) { // Log the error logger({ - category: "anthropic", + category: "openai", message: "Object generation failed", level: 0, auxiliary: { From d28ec13afac3859614ba0a56288fc892d52eae8d Mon Sep 17 00:00:00 2001 From: Hitesh Agarwal Date: Thu, 17 Apr 2025 17:26:42 +0800 Subject: [PATCH 4/5] fix: logging and refactor --- examples/external_clients/aisdk.ts | 1 - examples/llm_usage_wordle.ts | 25 ++++- lib/llm/AnthropicClient.ts | 124 +++++++++++++++---------- lib/llm/CerebrasClient.ts | 142 ++++++++++++++++++----------- lib/llm/GoogleClient.ts | 95 +++++++++++-------- lib/llm/GroqClient.ts | 133 ++++++++++++++++----------- lib/llm/OpenAIClient.ts | 132 +++++++++++++++++++-------- 7 files changed, 416 insertions(+), 236 deletions(-) diff --git a/examples/external_clients/aisdk.ts b/examples/external_clients/aisdk.ts index 8efa246c9..a20bc9558 100644 --- a/examples/external_clients/aisdk.ts +++ b/examples/external_clients/aisdk.ts @@ -183,7 +183,6 @@ export class AISdkClient extends LLMClient { schema, options = {}, }: GenerateObjectOptions): Promise { - console.log(options); const response = await generateObject({ model: this.model, prompt: prompt, diff --git a/examples/llm_usage_wordle.ts b/examples/llm_usage_wordle.ts index 61d261171..666eda7dc 100644 --- a/examples/llm_usage_wordle.ts +++ b/examples/llm_usage_wordle.ts @@ -1,17 +1,34 @@ import { Stagehand } from "@/dist"; import StagehandConfig from "@/stagehand.config"; -// import { z } from "zod"; +import { z } from "zod"; async function example() { const stagehand = new Stagehand({ ...StagehandConfig, }); - + const prompt = + "you are playing wordle. 
Return the 5-letter word that would be the best guess"; await stagehand.init(); + console.log("---Generating Text---"); + const { text } = await stagehand.llmClient.generateText({ + prompt: prompt, + }); + console.log(text); + + console.log("---Generating Object---"); + const { object } = await stagehand.llmClient.generateObject({ + prompt: prompt, + schema: z.object({ + guess: z + .string() + .describe("The 5-letter word that would be the best guess"), + }), + }); + console.log(object); + console.log("---Streaming Text---"); const { textStream } = await stagehand.llmClient.streamText({ - prompt: - "you are playing wordle. Return the 5-letter word that would be the best guess", + prompt: prompt, }); for await (const textPart of textStream) { diff --git a/lib/llm/AnthropicClient.ts b/lib/llm/AnthropicClient.ts index 16aaeef3f..6bd9327c3 100644 --- a/lib/llm/AnthropicClient.ts +++ b/lib/llm/AnthropicClient.ts @@ -386,7 +386,6 @@ export class AnthropicClient extends LLMClient { retries, logger, }: CreateChatCompletionOptions): Promise { - console.log(options, logger, retries); const optionsWithoutImage = { ...options }; delete optionsWithoutImage.image; @@ -648,7 +647,7 @@ export class AnthropicClient extends LLMClient { logger({ category: "anthropic", - message: "text streaming response", + message: "Text streaming response", level: 2, auxiliary: { response: { @@ -699,24 +698,24 @@ export class AnthropicClient extends LLMClient { ...chatOptions } = options; - try { - // Log the generation attempt - logger({ - category: "anthropic", - message: "Initiating text generation", - level: 2, - auxiliary: { - prompt: { - value: prompt, - type: "string", - }, - requestId: { - value: requestId, - type: "string", - }, + // Log the generation attempt + logger({ + category: "anthropic", + message: "Initiating text generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", }, - }); + requestId: { + value: requestId, + type: "string", + }, + }, + }); + try { // Create chat completion with the provided prompt const response = (await this.createChatCompletion({ options: { @@ -734,7 +733,28 @@ export class AnthropicClient extends LLMClient { })) as LLMResponse; // Validate response structure - if (!response.choices || response.choices.length === 0) { + if ( + !response.choices || + response.choices.length === 0 || + response.choices[0].message.content == undefined || + response.choices[0].message.content == null + ) { + logger({ + category: "anthropic", + message: "Text generation failed", + level: 0, + auxiliary: { + error: { + value: "API response contains no valid choices", + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + throw new CreateChatCompletionResponseError( "API response contains no valid choices", ); @@ -742,11 +762,6 @@ export class AnthropicClient extends LLMClient { // Extract and validate the generated text const generatedText = response.choices[0].message.content; - if (generatedText === null || generatedText === undefined) { - throw new CreateChatCompletionResponseError( - "Generated text content is empty", - ); - } // Construct the final response const textResponse = { @@ -812,24 +827,24 @@ export class AnthropicClient extends LLMClient { ...chatOptions } = options; - try { - // Log the generation attempt - logger({ - category: "anthropic", - message: "Initiating object generation", - level: 2, - auxiliary: { - prompt: { - value: prompt, - type: "string", - }, - requestId: { - value: requestId, - type: "string", - }, + // 
Log the generation attempt + logger({ + category: "anthropic", + message: "Initiating object generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", }, - }); + requestId: { + value: requestId, + type: "string", + }, + }, + }); + try { // Create chat completion with the provided prompt const response = (await this.createChatCompletion({ options: { @@ -851,7 +866,27 @@ export class AnthropicClient extends LLMClient { })) as LLMObjectResponse; // Validate response structure - if (!response.data || response.data.length === 0) { + if ( + !response.data || + response.data.length === 0 || + response.data === undefined + ) { + logger({ + category: "anthropic", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: "API response contains no valid choices", + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + throw new CreateChatCompletionResponseError( "API response contains no valid choices", ); @@ -859,11 +894,6 @@ export class AnthropicClient extends LLMClient { // Extract and validate the generated text const generatedObject = response.data; - if (generatedObject === null || generatedObject === undefined) { - throw new CreateChatCompletionResponseError( - "Generated text content is empty", - ); - } // Construct the final response const objResponse = { @@ -874,7 +904,7 @@ export class AnthropicClient extends LLMClient { // Log successful generation logger({ category: "anthropic", - message: "Text generation successful", + message: "Object generation successful", level: 2, auxiliary: { requestId: { diff --git a/lib/llm/CerebrasClient.ts b/lib/llm/CerebrasClient.ts index 305c48ec7..ebd2345f6 100644 --- a/lib/llm/CerebrasClient.ts +++ b/lib/llm/CerebrasClient.ts @@ -594,32 +594,32 @@ export class CerebrasClient extends LLMClient { ...chatOptions } = options; - try { - // Log the generation attempt - logger({ - category: "cerebras", - message: "Initiating text generation", - level: 2, - auxiliary: { - prompt: { - value: prompt, - type: "string", - }, - requestId: { - value: requestId, - type: "string", - }, - model: { - value: this.modelName, - type: "string", - }, - options: { - value: JSON.stringify(chatOptions), - type: "object", - }, + // Log the generation attempt + logger({ + category: "cerebras", + message: "Initiating text generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", }, - }); + requestId: { + value: requestId, + type: "string", + }, + model: { + value: this.modelName, + type: "string", + }, + options: { + value: JSON.stringify(chatOptions), + type: "object", + }, + }, + }); + try { // Create chat completion with the provided prompt const response = (await this.createChatCompletion({ options: { @@ -637,7 +637,28 @@ export class CerebrasClient extends LLMClient { })) as LLMResponse; // Validate response structure - if (!response.choices || response.choices.length === 0) { + if ( + !response.choices || + response.choices.length === 0 || + response.choices[0].message.content === null || + response.choices[0].message.content === undefined + ) { + logger({ + category: "cerebras", + message: "Text generation failed", + level: 0, + auxiliary: { + error: { + value: "API response contains no valid choices", + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + throw new CreateChatCompletionResponseError( "API response contains no valid choices", ); @@ -645,11 +666,6 @@ export class CerebrasClient extends LLMClient { // Extract and 
validate the generated text const generatedContent = response.choices[0].message.content; - if (generatedContent === null || generatedContent === undefined) { - throw new CreateChatCompletionResponseError( - "Generated text content is empty", - ); - } // Construct the final response with additional metadata const textResponse = { @@ -745,24 +761,24 @@ export class CerebrasClient extends LLMClient { ...chatOptions } = options; - try { - // Log the generation attempt - logger({ - category: "anthropic", - message: "Initiating object generation", - level: 2, - auxiliary: { - prompt: { - value: prompt, - type: "string", - }, - requestId: { - value: requestId, - type: "string", - }, + // Log the generation attempt + logger({ + category: "cerebras", + message: "Initiating object generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", }, - }); + requestId: { + value: requestId, + type: "string", + }, + }, + }); + try { // Create chat completion with the provided prompt const response = (await this.createChatCompletion({ options: { @@ -782,8 +798,29 @@ export class CerebrasClient extends LLMClient { logger, retries, })) as LLMObjectResponse; + // Validate response structure - if (!response.data || response.data.length === 0) { + if ( + !response.data || + response.data.length === 0 || + response.data === undefined + ) { + logger({ + category: "cerebras", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: "API response contains no valid choices", + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + throw new CreateChatCompletionResponseError( "API response contains no valid choices", ); @@ -791,11 +828,6 @@ export class CerebrasClient extends LLMClient { // Extract and validate the generated text const generatedObject = response.data; - if (generatedObject === null || generatedObject === undefined) { - throw new CreateChatCompletionResponseError( - "Generated text content is empty", - ); - } // Construct the final response const objResponse = { @@ -805,8 +837,8 @@ export class CerebrasClient extends LLMClient { // Log successful generation logger({ - category: "anthropic", - message: "Text generation successful", + category: "cerebras", + message: "Object generation successful", level: 2, auxiliary: { requestId: { @@ -820,7 +852,7 @@ export class CerebrasClient extends LLMClient { } catch (error) { // Log the error logger({ - category: "anthropic", + category: "cerebras", message: "Object generation failed", level: 0, auxiliary: { diff --git a/lib/llm/GoogleClient.ts b/lib/llm/GoogleClient.ts index 0f26edf78..9ded4cc76 100644 --- a/lib/llm/GoogleClient.ts +++ b/lib/llm/GoogleClient.ts @@ -635,20 +635,6 @@ export class GoogleClient extends LLMClient { formattedMessages.push({ role: "model", parts: [{ text: "{" }] }); // Prime the model } - logger({ - category: "google", - message: "creating chat completion", - level: 2, - auxiliary: { - modelName: { value: this.modelName, type: "string" }, - requestId: { value: requestId, type: "string" }, - requestPayloadSummary: { - value: `Model: ${this.modelName}, Messages: ${formattedMessages.length}, Config Keys: ${Object.keys(generationConfig).join(", ")}, Tools: ${formattedTools ? 
formattedTools.length : 0}, Safety Categories: ${safetySettings.map((s) => s.category).join(", ")}`, - type: "string", - }, - }, - }); - // Construct the full request object const requestPayload = { model: this.modelName, @@ -822,6 +808,22 @@ export class GoogleClient extends LLMClient { text: response.choices[0].message.content, } as T; } else { + logger({ + category: "google", + message: "text generation failed", + level: 0, + auxiliary: { + error: { + value: "No choices available in the response", + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + throw new CreateChatCompletionResponseError( "No choices available in the response", ); @@ -840,24 +842,24 @@ export class GoogleClient extends LLMClient { ...chatOptions } = options; - try { - // Log the generation attempt - logger({ - category: "anthropic", - message: "Initiating object generation", - level: 2, - auxiliary: { - prompt: { - value: prompt, - type: "string", - }, - requestId: { - value: requestId, - type: "string", - }, + // Log the generation attempt + logger({ + category: "google", + message: "Initiating object generation", + level: 2, + auxiliary: { + prompt: { + value: prompt, + type: "string", }, - }); + requestId: { + value: requestId, + type: "string", + }, + }, + }); + try { // Create chat completion with the provided prompt const response = (await this.createChatCompletion({ options: { @@ -878,7 +880,27 @@ export class GoogleClient extends LLMClient { retries, })) as LLMObjectResponse; // Validate response structure - if (!response.data || response.data.length === 0) { + if ( + !response.data || + response.data.length === 0 || + response.data === undefined + ) { + logger({ + category: "google", + message: "Object generation failed", + level: 0, + auxiliary: { + error: { + value: "API response contains no valid choices", + type: "string", + }, + prompt: { + value: prompt, + type: "string", + }, + }, + }); + throw new CreateChatCompletionResponseError( "API response contains no valid choices", ); @@ -886,11 +908,6 @@ export class GoogleClient extends LLMClient { // Extract and validate the generated text const generatedObject = response.data; - if (generatedObject === null || generatedObject === undefined) { - throw new CreateChatCompletionResponseError( - "Generated text content is empty", - ); - } // Construct the final response const objResponse = { @@ -900,8 +917,8 @@ export class GoogleClient extends LLMClient { // Log successful generation logger({ - category: "anthropic", - message: "Text generation successful", + category: "google", + message: "Object generation successful", level: 2, auxiliary: { requestId: { @@ -915,7 +932,7 @@ export class GoogleClient extends LLMClient { } catch (error) { // Log the error logger({ - category: "anthropic", + category: "google", message: "Object generation failed", level: 0, auxiliary: { diff --git a/lib/llm/GroqClient.ts b/lib/llm/GroqClient.ts index cd011f967..a1a9fdb31 100644 --- a/lib/llm/GroqClient.ts +++ b/lib/llm/GroqClient.ts @@ -594,28 +594,28 @@ export class GroqClient extends LLMClient { ...chatOptions } = options; - try { - // Log the generation attempt - logger({ - category: "groq", - message: "Initiating text generation", - level: 2, - auxiliary: { - prompt: { - value: prompt, - type: "string", - }, - requestId: { - value: requestId, - type: "string", - }, - model: { - value: this.modelName, - type: "string", - }, + // Log the generation attempt + logger({ + category: "groq", + message: "Initiating text generation", + level: 2, 
+        auxiliary: {
+          prompt: {
+            value: prompt,
+            type: "string",
           },
-      });
+          requestId: {
+            value: requestId,
+            type: "string",
+          },
+          model: {
+            value: this.modelName,
+            type: "string",
+          },
+        },
+      });
+    try {
       // Create chat completion with the provided prompt
       const response = (await this.createChatCompletion({
         options: {
@@ -633,7 +633,28 @@ export class GroqClient extends LLMClient {
       })) as LLMResponse;
       // Validate response structure
-      if (!response.choices || response.choices.length === 0) {
+      if (
+        !response.choices ||
+        response.choices.length === 0 ||
+        response.choices[0].message.content == null ||
+        response.choices[0].message.content === undefined
+      ) {
+        logger({
+          category: "groq",
+          message: "Text generation failed",
+          level: 0,
+          auxiliary: {
+            error: {
+              value: "API response contains no valid choices",
+              type: "string",
+            },
+            prompt: {
+              value: prompt,
+              type: "string",
+            },
+          },
+        });
+
         throw new CreateChatCompletionResponseError(
           "API response contains no valid choices",
         );
@@ -641,11 +662,6 @@ export class GroqClient extends LLMClient {
       // Extract and validate the generated text
       const generatedContent = response.choices[0].message.content;
-      if (generatedContent === null || generatedContent === undefined) {
-        throw new CreateChatCompletionResponseError(
-          "Generated text content is empty",
-        );
-      }
       // Construct the final response
       const textResponse = {
@@ -727,24 +743,24 @@ export class GroqClient extends LLMClient {
       ...chatOptions
     } = options;
-    try {
-      // Log the generation attempt
-      logger({
-        category: "anthropic",
-        message: "Initiating object generation",
-        level: 2,
-        auxiliary: {
-          prompt: {
-            value: prompt,
-            type: "string",
-          },
-          requestId: {
-            value: requestId,
-            type: "string",
-          },
+    // Log the generation attempt
+    logger({
+      category: "groq",
+      message: "Initiating object generation",
+      level: 2,
+      auxiliary: {
+        prompt: {
+          value: prompt,
+          type: "string",
         },
-      });
+        requestId: {
+          value: requestId,
+          type: "string",
+        },
+      },
+    });
+    try {
       // Create chat completion with the provided prompt
       const response = (await this.createChatCompletion({
         options: {
@@ -765,7 +781,27 @@ export class GroqClient extends LLMClient {
         retries,
       })) as LLMObjectResponse;
       // Validate response structure
-      if (!response.data || response.data.length === 0) {
+      if (
+        !response.data ||
+        response.data.length === 0 ||
+        response.data === undefined
+      ) {
+        logger({
+          category: "groq",
+          message: "Object generation failed",
+          level: 0,
+          auxiliary: {
+            error: {
+              value: "API response contains no valid choices",
+              type: "string",
+            },
+            prompt: {
+              value: prompt,
+              type: "string",
+            },
+          },
+        });
+
         throw new CreateChatCompletionResponseError(
           "API response contains no valid choices",
         );
@@ -773,11 +809,6 @@ export class GroqClient extends LLMClient {
       // Extract and validate the generated text
       const generatedObject = response.data;
-      if (generatedObject === null || generatedObject === undefined) {
-        throw new CreateChatCompletionResponseError(
-          "Generated text content is empty",
-        );
-      }
       // Construct the final response
       const objResponse = {
@@ -787,8 +818,8 @@ export class GroqClient extends LLMClient {
       // Log successful generation
       logger({
-        category: "anthropic",
-        message: "Text generation successful",
+        category: "groq",
+        message: "Object generation successful",
         level: 2,
         auxiliary: {
           requestId: {
@@ -802,7 +833,7 @@ export class GroqClient extends LLMClient {
     } catch (error) {
       // Log the error
       logger({
-        category: "anthropic",
+        category: "groq",
         message: "Object generation failed",
         level: 0,
         auxiliary: {
diff --git a/lib/llm/OpenAIClient.ts b/lib/llm/OpenAIClient.ts
index 4bd49124e..39094a98b 100644
--- a/lib/llm/OpenAIClient.ts
+++ b/lib/llm/OpenAIClient.ts
@@ -861,24 +861,24 @@ export class OpenAIClient extends LLMClient {
     // Generate a unique request ID if not provided
     const requestId = options.requestId || Date.now().toString();
-    try {
-      // Log the generation attempt
-      logger({
-        category: "openai",
-        message: "Initiating text generation",
-        level: 2,
-        auxiliary: {
-          prompt: {
-            value: prompt,
-            type: "string",
-          },
-          requestId: {
-            value: requestId,
-            type: "string",
-          },
+    // Log the generation attempt
+    logger({
+      category: "openai",
+      message: "Initiating text generation",
+      level: 2,
+      auxiliary: {
+        prompt: {
+          value: prompt,
+          type: "string",
         },
-      });
+        requestId: {
+          value: requestId,
+          type: "string",
+        },
+      },
+    });
+    try {
       // Create a chat completion with the prompt as a user message
       const response = (await this.createChatCompletion({
         options: {
@@ -897,12 +897,46 @@ export class OpenAIClient extends LLMClient {
       // Validate and extract the generated text from the response
       if (response.choices && response.choices.length > 0) {
+        // Log successful generation
+        logger({
+          category: "openai",
+          message: "Text generation successful",
+          level: 2,
+          auxiliary: {
+            response: {
+              value: JSON.stringify(response.choices[0].message.content),
+              type: "object",
+            },
+            requestId: {
+              value: requestId,
+              type: "string",
+            },
+          },
+        });
+
         return {
           ...response,
           text: response.choices[0].message.content,
         } as T;
       }
+      // Log the error if a logger is provided
+      logger({
+        category: "openai",
+        message: "Text generation failed",
+        level: 0,
+        auxiliary: {
+          error: {
+            value: "No valid choices found in API response",
+            type: "string",
+          },
+          prompt: {
+            value: prompt,
+            type: "string",
+          },
+        },
+      });
+
       // Throw error if no valid response was generated
       throw new CreateChatCompletionResponseError(
         "No valid choices found in API response",
       );
@@ -929,6 +963,7 @@ export class OpenAIClient extends LLMClient {
       throw error;
     }
   }
+
   async generateObject({
     prompt,
     schema,
@@ -942,24 +977,24 @@ export class OpenAIClient extends LLMClient {
       ...chatOptions
     } = options;
-    try {
-      // Log the generation attempt
-      logger({
-        category: "openai",
-        message: "Initiating object generation",
-        level: 2,
-        auxiliary: {
-          prompt: {
-            value: prompt,
-            type: "string",
-          },
-          requestId: {
-            value: requestId,
-            type: "string",
-          },
+    // Log the generation attempt
+    logger({
+      category: "openai",
+      message: "Initiating object generation",
+      level: 2,
+      auxiliary: {
+        prompt: {
+          value: prompt,
+          type: "string",
        },
-      });
+        requestId: {
+          value: requestId,
+          type: "string",
+        },
+      },
+    });
+    try {
       // Create chat completion with the provided prompt
       const response = (await this.createChatCompletion({
         options: {
@@ -981,7 +1016,31 @@ export class OpenAIClient extends LLMClient {
       })) as LLMObjectResponse;
       // Validate response structure
-      if (!response.data || response.data.length === 0) {
+      if (
+        !response.data ||
+        response.data.length === 0 ||
+        response.data == undefined
+      ) {
+        logger({
+          category: "openai",
+          message: "Object generation failed",
+          level: 0,
+          auxiliary: {
+            error: {
+              value: "API response contains no valid choices",
+              type: "string",
+            },
+            prompt: {
+              value: prompt,
+              type: "string",
+            },
+            requestId: {
+              value: requestId,
+              type: "string",
+            },
+          },
+        });
+
         throw new CreateChatCompletionResponseError(
           "API response contains no valid choices",
         );
@@ -989,11 +1048,6 @@ export class OpenAIClient extends LLMClient {
       // Extract and validate the generated text
       const generatedObject = response.data;
-      if (generatedObject === null || generatedObject === undefined) {
-        throw new CreateChatCompletionResponseError(
-          "Generated text content is empty",
-        );
-      }
       // Construct the final response
       const objResponse = {
@@ -1008,7 +1062,7 @@ export class OpenAIClient extends LLMClient {
         level: 2,
         auxiliary: {
           response: {
-            value: JSON.stringify(generatedObject),
+            value: JSON.stringify(objResponse),
             type: "object",
           },
           requestId: {

From 00fe3958ad29fac0d1b6b088090e8448ae6dcc33 Mon Sep 17 00:00:00 2001
From: Hitesh Agarwal
Date: Fri, 18 Apr 2025 01:22:37 +0800
Subject: [PATCH 5/5] refactor

---
 examples/external_clients/aisdk.ts        | 18 ++------------
 examples/external_clients/customOpenAI.ts | 29 +++++++++--------------
 examples/llm_usage_wordle.ts              |  8 +++----
 lib/llm/AnthropicClient.ts                | 19 +++++++++++----
 lib/llm/CerebrasClient.ts                 | 12 +++++++---
 lib/llm/GoogleClient.ts                   | 15 +++++++++---
 lib/llm/GroqClient.ts                     | 20 +++++++++++-----
 lib/llm/LLMClient.ts                      |  8 +++++--
 lib/llm/OpenAIClient.ts                   | 19 ++++++++++++---
 9 files changed, 88 insertions(+), 60 deletions(-)

diff --git a/examples/external_clients/aisdk.ts b/examples/external_clients/aisdk.ts
index a20bc9558..491cb277d 100644
--- a/examples/external_clients/aisdk.ts
+++ b/examples/external_clients/aisdk.ts
@@ -169,14 +169,7 @@ export class AISdkClient extends LLMClient {
       prompt: prompt,
       tools,
     });
-    return {
-      text: response.text,
-      usage: {
-        prompt_tokens: response.usage.promptTokens ?? 0,
-        completion_tokens: response.usage.completionTokens ?? 0,
-        total_tokens: response.usage.totalTokens ?? 0,
-      },
-    } as T;
+    return response as T;
   }
   async generateObject({
     prompt,
     schema,
     options = {},
   }: GenerateObjectOptions): Promise {
     const response = await generateObject({
       model: this.model,
       prompt: prompt,
       schema: schema,
       ...options,
     });
-    return {
-      object: response.object,
-      usage: {
-        prompt_tokens: response.usage.promptTokens ?? 0,
-        completion_tokens: response.usage.completionTokens ?? 0,
-        total_tokens: response.usage.totalTokens ?? 0,
-      },
-    } as T;
+    return response as T;
   }
 }
diff --git a/examples/external_clients/customOpenAI.ts b/examples/external_clients/customOpenAI.ts
index b5af6ad6a..37d913ee5 100644
--- a/examples/external_clients/customOpenAI.ts
+++ b/examples/external_clients/customOpenAI.ts
@@ -231,22 +231,11 @@ export class CustomOpenAIClient extends LLMClient {
       return {
         data: parsedData,
-        usage: {
-          prompt_tokens: response.usage?.prompt_tokens ?? 0,
-          completion_tokens: response.usage?.completion_tokens ?? 0,
-          total_tokens: response.usage?.total_tokens ?? 0,
-        },
+        response: response,
       } as T;
     }
-    return {
-      choices: response.choices,
-      usage: {
-        prompt_tokens: response.usage?.prompt_tokens ?? 0,
-        completion_tokens: response.usage?.completion_tokens ?? 0,
-        total_tokens: response.usage?.total_tokens ?? 0,
-      },
-    } as T;
+    return response as T;
   }
   async createChatCompletionStream({
@@ -510,8 +499,10 @@ export class CustomOpenAIClient extends LLMClient {
     // Validate response and extract generated text
     if (res.choices && res.choices.length > 0) {
       return {
-        ...res,
         text: res.choices[0].message.content,
+        finishReason: res.choices[0].finish_reason,
+        usage: res.usage,
+        response: res,
       } as T;
     } else {
       throw new CreateChatCompletionResponseError("No choices in response");
     }
@@ -533,7 +524,7 @@ export class CustomOpenAIClient extends LLMClient {
     try {
       // Log the generation attempt
       logger({
-        category: "anthropic",
+        category: "openai",
         message: "Initiating object generation",
         level: 2,
         auxiliary: {
@@ -585,13 +576,15 @@ export class CustomOpenAIClient extends LLMClient {
       // Construct the final response
       const objResponse = {
-        ...response,
         object: generatedObject,
+        finishReason: response.response.choices[0].finish_reason,
+        usage: response.response.usage,
+        ...response,
       } as T;
       // Log successful generation
       logger({
-        category: "anthropic",
+        category: "openai",
         message: "Text generation successful",
         level: 2,
         auxiliary: {
@@ -606,7 +599,7 @@ export class CustomOpenAIClient extends LLMClient {
     } catch (error) {
       // Log the error
       logger({
-        category: "anthropic",
+        category: "openai",
         message: "Object generation failed",
         level: 0,
         auxiliary: {
diff --git a/examples/llm_usage_wordle.ts b/examples/llm_usage_wordle.ts
index 666eda7dc..73681cfb6 100644
--- a/examples/llm_usage_wordle.ts
+++ b/examples/llm_usage_wordle.ts
@@ -10,13 +10,13 @@ async function example() {
     "you are playing wordle. Return the 5-letter word that would be the best guess";
   await stagehand.init();
   console.log("---Generating Text---");
-  const { text } = await stagehand.llmClient.generateText({
+  const responseText = await stagehand.llmClient.generateText({
     prompt: prompt,
   });
-  console.log(text);
+  console.log(responseText);
   console.log("---Generating Object---");
-  const { object } = await stagehand.llmClient.generateObject({
+  const responseObj = await stagehand.llmClient.generateObject({
     prompt: prompt,
     schema: z.object({
       guess: z
         .string()
         .length(5)
         .describe("The 5-letter word that would be the best guess"),
     }),
   });
-  console.log(object);
+  console.log(responseObj);
   console.log("---Streaming Text---");
   const { textStream } = await stagehand.llmClient.streamText({
diff --git a/lib/llm/AnthropicClient.ts b/lib/llm/AnthropicClient.ts
index 6bd9327c3..749ba2151 100644
--- a/lib/llm/AnthropicClient.ts
+++ b/lib/llm/AnthropicClient.ts
@@ -296,7 +296,6 @@ export class AnthropicClient extends LLMClient {
         ],
         usage: usageData,
       };
-
       logger({
         category: "anthropic",
         message: "transformed response",
@@ -321,12 +320,12 @@ export class AnthropicClient extends LLMClient {
       const finalParsedResponse = {
         data: result,
         usage: usageData,
+        ...response,
       } as unknown as T;
       if (this.enableCaching) {
         this.cache.set(cacheOptions, finalParsedResponse, options.requestId);
       }
-
       return finalParsedResponse;
     } else {
       if (!retries || retries < 5) {
@@ -578,7 +577,11 @@ export class AnthropicClient extends LLMClient {
       stream: true,
     });
-    // Restructure the response to match the expected format
+    // TODO: Transform response stream to preferred format
+    // TODO: Response model validation
+    // TODO: Enable caching
+
+    // Temporarily restructure the response to match the expected format
     return new ReadableStream({
       async start(controller) {
         try {
@@ -765,8 +768,12 @@ export class AnthropicClient extends LLMClient {
       // Construct the final response
       const textResponse = {
-        ...response,
         text: generatedText,
+        finishReason: response.choices[0].finish_reason,
+        usage: response.usage,
+        response: response,
+        // reasoning: response.reasoning,
+        // sources: response.sources
       } as T;
       // Log successful generation
@@ -897,8 +904,10 @@ export class AnthropicClient extends LLMClient {
       // Construct the final response
       const objResponse = {
-        ...response,
         object: generatedObject,
+        // finishReason: response.stop_reason,
+        // usage: response.response.usage,
+        response: response,
       } as T;
       // Log successful generation
diff --git a/lib/llm/CerebrasClient.ts b/lib/llm/CerebrasClient.ts
index ebd2345f6..419787d7f 100644
--- a/lib/llm/CerebrasClient.ts
+++ b/lib/llm/CerebrasClient.ts
@@ -250,7 +250,7 @@ export class CerebrasClient extends LLMClient {
         if (this.enableCaching) {
           this.cache.set(cacheOptions, result, options.requestId);
         }
-        return result as T;
+        return { data: result, response: response } as T;
       } catch (e) {
         // If JSON parse fails, the model might be returning a different format
         logger({
@@ -278,7 +278,7 @@ export class CerebrasClient extends LLMClient {
         if (this.enableCaching) {
           this.cache.set(cacheOptions, result, options.requestId);
         }
-        return result as T;
+        return { data: result, response: response } as T;
       }
     } catch (e) {
       logger({
@@ -475,6 +475,10 @@ export class CerebrasClient extends LLMClient {
       tool_choice: options.tool_choice || "auto",
     });
+    // TODO: transform response to required format
+    // TODO: Validate response model
+    // TODO: Enable caching
+
     return apiResponse as T;
   }
@@ -831,8 +835,10 @@ export class CerebrasClient extends LLMClient {
       // Construct the final response
       const objResponse = {
-        ...response,
         object: generatedObject,
+        finishReason: response.response.choices[0].finish_reason,
+        usage: response.response.usage,
+        ...response,
       } as T;
       // Log successful generation
diff --git a/lib/llm/GoogleClient.ts b/lib/llm/GoogleClient.ts
index 9ded4cc76..14155a9fe 100644
--- a/lib/llm/GoogleClient.ts
+++ b/lib/llm/GoogleClient.ts
@@ -483,6 +483,7 @@ export class GoogleClient extends LLMClient {
       const extractionResult = {
         data: parsedData,
         usage: llmResponse.usage,
+        response: llmResponse,
       };
       if (this.enableCaching) {
@@ -552,7 +553,6 @@ export class GoogleClient extends LLMClient {
       top_p,
       maxTokens,
     } = options;
-    console.log(retries);
     logger({
       category: "google",
@@ -583,6 +583,7 @@ export class GoogleClient extends LLMClient {
         : undefined,
       tools: tools,
       maxTokens: maxTokens,
+      retries: retries,
     };
     if (this.enableCaching) {
@@ -675,6 +676,10 @@ export class GoogleClient extends LLMClient {
       const result = await this.client.models.generateContentStream(requestPayload);
+      // TODO: transform response to required format
+      // TODO: Validate response model
+      // TODO: Enable caching
+
       return new ReadableStream({
         async start(controller) {
           try {
@@ -804,8 +809,10 @@ export class GoogleClient extends LLMClient {
       // Validate and extract the generated text from the response
       if (response.choices && response.choices.length > 0) {
         return {
-          ...response,
           text: response.choices[0].message.content,
+          finishReason: response.choices[0].finish_reason,
+          usage: response.usage,
+          response: response,
         } as T;
       } else {
         logger({
@@ -911,8 +918,10 @@ export class GoogleClient extends LLMClient {
       // Construct the final response
       const objResponse = {
-        ...response,
         object: generatedObject,
+        finishReason: response.response.choices[0].finish_reason,
+        usage: response.usage,
+        response: response,
       } as T;
       // Log successful generation
diff --git a/lib/llm/GroqClient.ts b/lib/llm/GroqClient.ts
index a1a9fdb31..1ee47ad20 100644
--- a/lib/llm/GroqClient.ts
+++ b/lib/llm/GroqClient.ts
@@ -250,7 +250,7 @@ export class GroqClient extends LLMClient {
         if (this.enableCaching) {
           this.cache.set(cacheOptions, result, options.requestId);
         }
-        return { data: result } as T;
+        return { data: result, response: response } as T;
       } catch (e) {
         // If JSON parse fails, the model might be returning a different format
         logger({
@@ -278,7 +278,7 @@ export class GroqClient extends LLMClient {
         if (this.enableCaching) {
           this.cache.set(cacheOptions, result, options.requestId);
         }
-        return { data: result } as T;
+        return { data: result, response: response } as T;
       }
     } catch (e) {
       logger({
@@ -477,6 +477,10 @@ export class GroqClient extends LLMClient {
       stream: true,
     });
+    // TODO: transform response to required format
+    // TODO: Validate response model
+    // TODO: Enable caching
+
     return apiResponse as T;
   }
@@ -665,10 +669,12 @@ export class GroqClient extends LLMClient {
       // Construct the final response
       const textResponse = {
-        ...response,
         text: generatedContent,
-        modelName: this.modelName,
-        timestamp: Date.now(),
+        finishReason: response.choices[0].finish_reason,
+        usage: response.usage,
+        response: response,
+        // reasoning: response.reasoning,
+        // sources: response.sources
       } as T;
       // Log successful generation
@@ -812,8 +818,10 @@ export class GroqClient extends LLMClient {
       // Construct the final response
       const objResponse = {
-        ...response,
         object: generatedObject,
+        finishReason: response.response.choices[0].finish_reason,
+        usage: response.response.usage,
+        ...response,
       } as T;
       // Log successful generation
diff --git a/lib/llm/LLMClient.ts b/lib/llm/LLMClient.ts
index 10f000233..31cb83fb1 100644
--- a/lib/llm/LLMClient.ts
+++ b/lib/llm/LLMClient.ts
@@ -128,13 +128,17 @@ export type StreamingChatResponse = AsyncIterable;
 export interface LLMObjectResponse extends BaseResponse {
   data: Record;
   usage: UsageMetrics;
+  response: LLMResponse;
 }
 // Text Response type that can include LLM properties
 export interface TextResponse extends BaseResponse {
   text: string;
-  choices?: LLMChoice[];
-  usage?: UsageMetrics;
+  finishReason: string;
+  usage: UsageMetrics;
+  response: LLMResponse;
+  // reasoning: string;
+  // sources: string[];
 }
 // Object Response type that can include LLM properties
diff --git a/lib/llm/OpenAIClient.ts b/lib/llm/OpenAIClient.ts
index 39094a98b..00a5d55b8 100644
--- a/lib/llm/OpenAIClient.ts
+++ b/lib/llm/OpenAIClient.ts
@@ -445,6 +445,7 @@ export class OpenAIClient extends LLMClient {
       return {
         data: parsedData,
         usage: response.usage,
+        response: response,
       } as T;
     }
@@ -744,6 +745,11 @@ export class OpenAIClient extends LLMClient {
       })),
     };
     const response = await this.client.chat.completions.create(body);
+
+    // TODO: O1 models - parse the tool call response manually and add it to the response
+    // TODO: Response model validation
+    // TODO: Caching
+
     return response as T;
   }
@@ -915,8 +921,12 @@ export class OpenAIClient extends LLMClient {
       });
       return {
-        ...response,
         text: response.choices[0].message.content,
+        finishReason: response.choices[0].finish_reason,
+        usage: response.usage,
+        response: response,
+        // reasoning: response.reasoning,
+        // sources: response.sources
       } as T;
     }
@@ -1014,7 +1024,6 @@ export class OpenAIClient extends LLMClient {
       logger,
       retries,
     })) as LLMObjectResponse;
-
     // Validate response structure
     if (
       !response.data ||
@@ -1051,8 +1060,12 @@ export class OpenAIClient extends LLMClient {
       // Construct the final response
       const objResponse = {
-        ...response,
         object: generatedObject,
+        finishReason: response.response.choices[0].finish_reason,
+        usage: response.usage,
+        response: response,
+        // reasoning: response.reasoning,
+        // sources: response.sources
       } as T;
       // Log successful generation