diff --git a/evals/args.ts b/evals/args.ts
index 453f59d47..6f64d3253 100644
--- a/evals/args.ts
+++ b/evals/args.ts
@@ -93,7 +93,7 @@ function buildUsage(detailed = false): string {
       providerDefault,
     )}) [${chalk.yellow("OPENAI")}, ${chalk.yellow(
       "ANTHROPIC",
-    )}, ${chalk.yellow("GOOGLE")}, ${chalk.yellow("TOGETHER")}, ${chalk.yellow(
+    )}, ${chalk.yellow("GOOGLE")}, ${chalk.yellow("OPENROUTER")}, ${chalk.yellow("TOGETHER")}, ${chalk.yellow(
       "GROQ",
     )}, ${chalk.yellow("CEREBRAS")}]
diff --git a/evals/index.eval.ts b/evals/index.eval.ts
index c66ad3ffb..78496dbf5 100644
--- a/evals/index.eval.ts
+++ b/evals/index.eval.ts
@@ -39,6 +39,7 @@ import { groq } from "@ai-sdk/groq";
 import { cerebras } from "@ai-sdk/cerebras";
 import { openai } from "@ai-sdk/openai";
 import { AISdkClient } from "@/examples/external_clients/aisdk";
+import { OpenRouterClient } from "@/lib/llm/OpenRouterClient";
 dotenv.config();
 
 /**
@@ -351,6 +352,17 @@ const generateFilteredTestcases = (): Testcase[] => {
         ),
       ),
     });
+  } else if (input.modelName.startsWith("x-ai/")) {
+    // Handle OpenRouter models (e.g., x-ai/grok-4) using native OpenRouterClient
+    llmClient = new OpenRouterClient({
+      logger: logger.log.bind(logger),
+      enableCaching: false,
+      cache: undefined,
+      modelName: input.modelName as AvailableModel,
+      clientOptions: {
+        apiKey: process.env.OPENROUTER_API_KEY,
+      },
+    });
   } else if (input.modelName.includes("/")) {
     llmClient = new CustomOpenAIClient({
       modelName: input.modelName as AvailableModel,
@@ -361,6 +373,10 @@ const generateFilteredTestcases = (): Testcase[] => {
         }),
       ),
     });
+  } else {
+    throw new StagehandEvalError(
+      `Unsupported model: ${input.modelName}. Please add support for this model in the evals.`,
+    );
   }
   const taskInput = await initStagehand({
     logger,
diff --git a/evals/taskConfig.ts b/evals/taskConfig.ts
index e8207403a..e9eade9c9 100644
--- a/evals/taskConfig.ts
+++ b/evals/taskConfig.ts
@@ -35,6 +35,8 @@ const ALL_EVAL_MODELS = [
   "o3",
   "o3-mini",
   "o4-mini",
+  // OPENROUTER
+  "x-ai/grok-4",
   // TOGETHER - META
   "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
   "meta-llama/Llama-3.3-70B-Instruct-Turbo",
@@ -136,6 +138,8 @@ const filterModelByProvider = (model: string, provider: string): boolean => {
     return modelLower.startsWith("claude");
   } else if (provider === "google") {
     return modelLower.startsWith("gemini");
+  } else if (provider === "openrouter") {
+    return modelLower.startsWith("x-ai/");
   } else if (provider === "together") {
     return (
       modelLower.startsWith("meta-llama") ||
diff --git a/lib/llm/LLMProvider.ts b/lib/llm/LLMProvider.ts
index fc11c5753..4df3e396f 100644
--- a/lib/llm/LLMProvider.ts
+++ b/lib/llm/LLMProvider.ts
@@ -17,6 +17,7 @@ import { GoogleClient } from "./GoogleClient";
 import { GroqClient } from "./GroqClient";
 import { LLMClient } from "./LLMClient";
 import { OpenAIClient } from "./OpenAIClient";
+import { OpenRouterClient } from "./OpenRouterClient";
 import { openai, createOpenAI } from "@ai-sdk/openai";
 import { anthropic, createAnthropic } from "@ai-sdk/anthropic";
 import { google, createGoogleGenerativeAI } from "@ai-sdk/google";
@@ -91,6 +92,7 @@ const modelToProviderMap: { [key in AvailableModel]: ModelProvider } = {
   "gemini-2.0-flash": "google",
   "gemini-2.5-flash-preview-04-17": "google",
   "gemini-2.5-pro-preview-03-25": "google",
+  "x-ai/grok-4": "openrouter",
 };
 
 function getAISDKLanguageModel(
@@ -221,6 +223,14 @@ export class LLMProvider {
           modelName: availableModel,
           clientOptions,
         });
+      case "openrouter":
+        return new OpenRouterClient({
+          logger: this.logger,
+          enableCaching: this.enableCaching,
+          cache: this.cache,
+          modelName: availableModel,
+          clientOptions,
+        });
       default:
         throw new UnsupportedModelProviderError([
           ...new Set(Object.values(modelToProviderMap)),
diff --git a/lib/llm/OpenRouterClient.ts b/lib/llm/OpenRouterClient.ts
new file mode 100644
index 000000000..708259807
--- /dev/null
+++ b/lib/llm/OpenRouterClient.ts
@@ -0,0 +1,560 @@
+import OpenAI from "openai";
+import type { ClientOptions } from "openai";
+import {
+  ChatCompletionContentPartImage,
+  ChatCompletionContentPartText,
+  ChatCompletionMessageParam,
+  ChatCompletionSystemMessageParam,
+  ChatCompletionUserMessageParam,
+  ChatCompletionAssistantMessageParam,
+} from "openai/resources/chat";
+import zodToJsonSchema from "zod-to-json-schema";
+import { LogLine } from "../../types/log";
+import { AvailableModel } from "../../types/model";
+import { LLMCache } from "../cache/LLMCache";
+import { validateZodSchema } from "../utils";
+import {
+  CreateChatCompletionOptions,
+  LLMClient,
+  LLMResponse,
+} from "./LLMClient";
+import {
+  CreateChatCompletionResponseError,
+  ZodSchemaValidationError,
+} from "@/types/stagehandErrors";
+
+export class OpenRouterClient extends LLMClient {
+  public type = "openrouter" as const;
+  private client: OpenAI;
+  private cache: LLMCache | undefined;
+  private enableCaching: boolean;
+  public clientOptions: ClientOptions;
+  public hasVision = true; // Grok 4 supports vision
+
+  constructor({
+    enableCaching = false,
+    cache,
+    modelName,
+    clientOptions,
+    userProvidedInstructions,
+  }: {
+    logger: (message: LogLine) => void;
+    enableCaching?: boolean;
+    cache?: LLMCache;
+    modelName: AvailableModel;
+    clientOptions?: ClientOptions;
+    userProvidedInstructions?: string;
+  }) {
+    super(modelName, userProvidedInstructions);
+
+    // Create OpenAI client with OpenRouter API
+    this.client = new OpenAI({
+      baseURL: "https://openrouter.ai/api/v1",
+      apiKey: clientOptions?.apiKey || process.env.OPENROUTER_API_KEY,
+      defaultHeaders: {
+        "HTTP-Referer": "https://stagehand.dev",
+        "X-Title": "Stagehand",
+      },
+      ...clientOptions,
+    });
+
+    this.cache = cache;
+    this.enableCaching = enableCaching;
+    this.modelName = modelName;
+    this.clientOptions = clientOptions;
+  }
+
+  async createChatCompletion<T = LLMResponse>({
+    options,
+    retries,
+    logger,
+  }: CreateChatCompletionOptions): Promise<T> {
+    logger({
+      category: "openrouter",
+      message: "creating chat completion",
+      level: 2,
+      auxiliary: {
+        options: {
+          value: JSON.stringify(options),
+          type: "object",
+        },
+      },
+    });
+
+    const cacheOptions = {
+      model: this.modelName,
+      messages: options.messages,
+      temperature: options.temperature,
+      top_p: options.top_p,
+      frequency_penalty: options.frequency_penalty,
+      presence_penalty: options.presence_penalty,
+      image: options.image,
+      response_model: options.response_model,
+      tools: options.tools,
+      tool_choice: options.tool_choice,
+      maxTokens: options.maxTokens,
+    };
+
+    if (this.enableCaching) {
+      const cachedResponse = await this.cache.get(
+        cacheOptions,
+        options.requestId,
+      );
+      if (cachedResponse) {
+        logger({
+          category: "llm_cache",
+          message: "LLM cache hit - returning cached response",
+          level: 1,
+          auxiliary: {
+            requestId: {
+              value: options.requestId,
+              type: "string",
+            },
+            cachedResponse: {
+              value: JSON.stringify(cachedResponse),
+              type: "object",
+            },
+          },
+        });
+        return cachedResponse;
+      } else {
+        logger({
+          category: "llm_cache",
+          message: "LLM cache miss - no cached response found",
+          level: 1,
+          auxiliary: {
+            requestId: {
+              value: options.requestId,
type: "string", + }, + }, + }); + } + } + + // Format messages for OpenRouter API (using OpenAI format) + const formattedMessages: ChatCompletionMessageParam[] = + options.messages.map((message) => { + if (Array.isArray(message.content)) { + const contentParts = message.content.map((content) => { + if ("image_url" in content) { + const imageContent: ChatCompletionContentPartImage = { + image_url: { + url: content.image_url.url, + }, + type: "image_url", + }; + return imageContent; + } else { + const textContent: ChatCompletionContentPartText = { + text: content.text, + type: "text", + }; + return textContent; + } + }); + + if (message.role === "system") { + const formattedMessage: ChatCompletionSystemMessageParam = { + ...message, + role: "system", + content: contentParts.filter( + (content): content is ChatCompletionContentPartText => + content.type === "text", + ), + }; + return formattedMessage; + } else if (message.role === "user") { + const formattedMessage: ChatCompletionUserMessageParam = { + ...message, + role: "user", + content: contentParts, + }; + return formattedMessage; + } else { + const formattedMessage: ChatCompletionAssistantMessageParam = { + ...message, + role: "assistant", + content: contentParts.filter( + (content): content is ChatCompletionContentPartText => + content.type === "text", + ), + }; + return formattedMessage; + } + } + + if (message.role === "system") { + const formattedMessage: ChatCompletionSystemMessageParam = { + role: "system", + content: message.content, + }; + return formattedMessage; + } else if (message.role === "assistant") { + const formattedMessage: ChatCompletionAssistantMessageParam = { + role: "assistant", + content: message.content, + }; + return formattedMessage; + } else { + const formattedMessage: ChatCompletionUserMessageParam = { + role: "user", + content: message.content, + }; + return formattedMessage; + } + }); + + // Add image if provided + if (options.image) { + const base64Image = options.image.buffer.toString("base64"); + const imageMessage = { + role: "user" as const, + content: [ + { + type: "text" as const, + text: options.image.description || "Please analyze this image.", + }, + { + type: "image_url" as const, + image_url: { + url: `data:image/png;base64,${base64Image}`, + }, + }, + ], + }; + formattedMessages.push(imageMessage); + } + + // Format tools if provided (only user-defined tools, not response models) + const tools = options.tools?.map((tool) => ({ + type: "function" as const, + function: { + name: tool.name, + description: tool.description, + parameters: { + type: "object", + properties: tool.parameters.properties, + required: tool.parameters.required, + }, + }, + })); + + // Handle response format for structured outputs + // OpenRouter only supports response_format for certain models (OpenAI, Nitro, etc.) + // Since we can't reliably detect which models support it, always use instruction-based approach + if (options.response_model) { + logger({ + category: "openrouter", + message: "Using instruction-based approach for response model", + level: 1, + auxiliary: { + requestId: { + value: options.requestId, + type: "string", + }, + }, + }); + + const parsedSchema = JSON.stringify( + zodToJsonSchema(options.response_model.schema), + ); + formattedMessages.push({ + role: "user", + content: `Respond with ONLY valid JSON that matches this exact schema:\n${parsedSchema}\n\nIMPORTANT: Your response must be valid JSON with no additional text, explanations, or markdown formatting. 
+
+    try {
+      // Use OpenAI client with OpenRouter API
+      const apiResponse = await this.client.chat.completions.create({
+        model: this.modelName,
+        messages: formattedMessages,
+        temperature: options.temperature ?? 0.7,
+        max_tokens: options.maxTokens,
+        tools: tools,
+        tool_choice:
+          tools && tools.length > 0 ? options.tool_choice || "auto" : undefined,
+        top_p: options.top_p,
+        frequency_penalty: options.frequency_penalty,
+        presence_penalty: options.presence_penalty,
+      });
+
+      // Format the response to match the expected LLMResponse format
+      const response: LLMResponse = {
+        id: apiResponse.id,
+        object: "chat.completion",
+        created: apiResponse.created,
+        model: this.modelName,
+        choices: [
+          {
+            index: 0,
+            message: {
+              role: "assistant",
+              content: apiResponse.choices[0]?.message?.content || null,
+              tool_calls: apiResponse.choices[0]?.message?.tool_calls || [],
+            },
+            finish_reason: apiResponse.choices[0]?.finish_reason || "stop",
+          },
+        ],
+        usage: {
+          prompt_tokens: apiResponse.usage?.prompt_tokens || 0,
+          completion_tokens: apiResponse.usage?.completion_tokens || 0,
+          total_tokens: apiResponse.usage?.total_tokens || 0,
+        },
+      };
+
+      logger({
+        category: "openrouter",
+        message: "OpenRouter chat completion finished",
+        level: 1,
+        auxiliary: {
+          response: {
+            value: JSON.stringify(response),
+            type: "object",
+          },
+          requestId: {
+            value: options.requestId,
+            type: "string",
+          },
+        },
+      });
+
+      // Handle response_model extraction
+      if (options.response_model) {
+        let extractedData;
+
+        // Check if we have content in the message
+        if (response.choices[0].message.content) {
+          extractedData = response.choices[0].message.content;
+        }
+        // Check if the model returned a tool call instead (some models do this)
+        else if (
+          response.choices[0].message.tool_calls &&
+          response.choices[0].message.tool_calls.length > 0
+        ) {
+          // Try to extract from tool calls as a fallback
+          const toolCall = response.choices[0].message.tool_calls[0];
+          extractedData = toolCall.function.arguments;
+        } else {
+          logger({
+            category: "openrouter",
+            message: "No content or tool calls found in OpenRouter response",
+            level: 0,
+            auxiliary: {
+              response: {
+                value: JSON.stringify(response),
+                type: "object",
+              },
+              requestId: {
+                value: options.requestId,
+                type: "string",
+              },
+            },
+          });
+
+          if (retries > 0) {
+            return this.createChatCompletion({
+              options,
+              retries: retries - 1,
+              logger,
+            });
+          }
+
+          throw new CreateChatCompletionResponseError(
+            "No content or tool calls found in OpenRouter response for response_model",
+          );
+        }
+
+        let parsedData;
+        try {
+          parsedData = JSON.parse(extractedData);
+        } catch (parseError) {
+          // Try to extract JSON from the response if it's wrapped in text or markdown
+          let cleanedData = extractedData;
+
+          // Remove markdown code blocks
+          cleanedData = cleanedData
+            .replace(/```json\n?/g, "")
+            .replace(/```\n?/g, "");
+
+          // Try to find JSON object in the text
+          const jsonMatch = cleanedData.match(/\{[\s\S]*\}/);
+          if (jsonMatch) {
+            try {
+              parsedData = JSON.parse(jsonMatch[0]);
+            } catch (secondParseError) {
+              logger({
+                category: "openrouter",
+                message: "Failed to parse cleaned response as JSON",
+                level: 0,
+                auxiliary: {
+                  originalData: {
+                    value: extractedData,
+                    type: "string",
+                  },
+                  cleanedData: {
+                    value: jsonMatch[0],
+                    type: "string",
+                  },
+                  parseError: {
+                    value: secondParseError.message,
+                    type: "string",
+                  },
+                  requestId: {
+                    value: options.requestId,
type: "string", + }, + }, + }); + + if (retries > 0) { + return this.createChatCompletion({ + options, + retries: retries - 1, + logger, + }); + } + + throw new CreateChatCompletionResponseError( + `Failed to parse OpenRouter response as JSON: ${secondParseError.message}`, + ); + } + } else { + logger({ + category: "openrouter", + message: "No JSON found in response", + level: 0, + auxiliary: { + extractedData: { + value: extractedData, + type: "string", + }, + parseError: { + value: parseError.message, + type: "string", + }, + requestId: { + value: options.requestId, + type: "string", + }, + }, + }); + + if (retries > 0) { + return this.createChatCompletion({ + options, + retries: retries - 1, + logger, + }); + } + + throw new CreateChatCompletionResponseError( + `No JSON found in OpenRouter response: ${parseError.message}`, + ); + } + } + + try { + validateZodSchema(options.response_model.schema, parsedData); + } catch (e) { + logger({ + category: "openrouter", + message: "Response failed Zod schema validation", + level: 0, + auxiliary: { + parsedData: { + value: JSON.stringify(parsedData), + type: "object", + }, + requestId: { + value: options.requestId, + type: "string", + }, + }, + }); + if (retries > 0) { + return this.createChatCompletion({ + options, + retries: retries - 1, + logger, + }); + } + + if (e instanceof ZodSchemaValidationError) { + logger({ + category: "openrouter", + message: `Error during OpenRouter chat completion: ${e.message}`, + level: 0, + auxiliary: { + errorDetails: { + value: `Message: ${e.message}${e.stack ? "\nStack: " + e.stack : ""}`, + type: "string", + }, + requestId: { value: options.requestId, type: "string" }, + }, + }); + throw new CreateChatCompletionResponseError(e.message); + } + throw e; + } + + const result = { + data: parsedData, + usage: response.usage, + } as T; + + if (this.enableCaching) { + await this.cache.set(cacheOptions, result, options.requestId); + } + + return result; + } + + if (this.enableCaching) { + await this.cache.set(cacheOptions, response, options.requestId); + } + + return response as T; + } catch (error) { + logger({ + category: "openrouter", + message: `OpenRouter request failed: ${error.message}`, + level: 0, + auxiliary: { + error: { + value: error.message, + type: "string", + }, + requestId: { + value: options.requestId, + type: "string", + }, + }, + }); + + if (retries > 0) { + logger({ + category: "openrouter", + message: `retrying OpenRouter request, ${retries} attempts remaining`, + level: 1, + auxiliary: { + requestId: { + value: options.requestId, + type: "string", + }, + }, + }); + return this.createChatCompletion({ + options, + retries: retries - 1, + logger, + }); + } + + throw new CreateChatCompletionResponseError( + `OpenRouter request failed: ${error.message}`, + ); + } + } +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index b81dabcd6..044dac7f8 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -204,6 +204,9 @@ importers: specifier: workspace:* version: link:.. 
     devDependencies:
+      jszip:
+        specifier: ^3.10.1
+        version: 3.10.1
       tsx:
         specifier: ^4.10.5
         version: 4.19.4
@@ -3153,6 +3156,9 @@ packages:
     resolution: {integrity: sha512-hsBTNUqQTDwkWtcdYI2i06Y/nUBEsNEDJKjWdigLvegy8kDuJAS8uRlpkkcQpyEXL0Z/pjDy5HBmMjRCJ2gq+g==}
     engines: {node: '>= 4'}
 
+  immediate@3.0.6:
+    resolution: {integrity: sha512-XXOFtyqDjNDAQxVfYxuF7g9Il/IbWmmlQg2MYKOH8ExIT1qg6xc4zyS3HaEEATgs1btfzxq15ciUiY7gjSXRGQ==}
+
   immer@9.0.21:
     resolution: {integrity: sha512-bc4NBHqOqSfRW7POMkHd51LvClaeMXpm8dx0e8oE2GORbq5aRK7Bxl4FyzVLdGtLmvLKL7BTDBG5ACQm4HWjTA==}
 
@@ -3464,6 +3470,9 @@ packages:
     resolution: {integrity: sha512-p/nXbhSEcu3pZRdkW1OfJhpsVtW1gd4Wa1fnQc9YLiTfAjn0312eMKimbdIQzuZl9aa9xUGaRlP9T/CJE/ditQ==}
    engines: {node: '>=0.10.0'}
 
+  jszip@3.10.1:
+    resolution: {integrity: sha512-xXDvecyTpGLrqFrvkrUSoxxfJI5AH7U8zxxtVclpsUtMCq4JQ290LY8AW5c7Ggnr/Y/oK+bQMbqK2qmtk3pN4g==}
+
   jwa@2.0.0:
     resolution: {integrity: sha512-jrZ2Qx916EA+fq9cEAeCROWPTfCwi1IVHqT2tapuqLEVVDKFDENFw1oL+MwrTvH6msKxsd1YTDVw6uKEcsrLEA==}
 
@@ -3504,6 +3513,9 @@ packages:
     resolution: {integrity: sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==}
     engines: {node: '>= 0.8.0'}
 
+  lie@3.3.0:
+    resolution: {integrity: sha512-UaiMJzeWRlEujzAuw5LokY1L5ecNQYZKfmyZ9L7wDHb/p5etKaxXhohBcrw0EYby+G/NA52vRSN4N39dxHAIwQ==}
+
   lilconfig@3.1.3:
     resolution: {integrity: sha512-/vlFKAoH5Cgt3Ie+JLhRbwOsCQePABiU3tJ1egGvyQ+33R/vcwM2Zl2QR/LzjsBeItPt3oSVXapn+m4nQDvpzw==}
     engines: {node: '>=14'}
 
@@ -3870,6 +3882,7 @@ packages:
   multer@1.4.5-lts.2:
     resolution: {integrity: sha512-VzGiVigcG9zUAoCNU+xShztrlr1auZOlurXynNvO9GiWD1/mTBbUljOKY+qMeazBqXgRnjzeEgJI/wyjJUHg9A==}
     engines: {node: '>= 6.0.0'}
+    deprecated: Multer 1.x is impacted by a number of vulnerabilities, which have been patched in 2.x. You should upgrade to the latest 2.x version.
 
   mustache@4.2.0:
     resolution: {integrity: sha512-71ippSywq5Yb7/tVYyGbkBggbU8H3u5Rz56fH60jGFgr8uHwxs+aSKeqmluIVzM0m0kB7xQjKS6qPfd0b2ZoqQ==}
 
@@ -4109,6 +4122,9 @@ packages:
   package-manager-detector@0.2.11:
     resolution: {integrity: sha512-BEnLolu+yuz22S56CU1SUKq3XC3PkwD5wv4ikR4MfGvnRVcmzXR9DwSlW2fEamyTPyXHomBJRzgapeuBvRNzJQ==}
 
+  pako@1.0.11:
+    resolution: {integrity: sha512-4hLB8Py4zZce5s4yd9XzopqwVv/yGNhV1Bl8NTmCq1763HeK2+EwVTv+leGeL13Dnh2wfbqowVPXCIO0z4taYw==}
+
   parent-module@1.0.1:
     resolution: {integrity: sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==}
     engines: {node: '>=6'}
 
@@ -4560,6 +4576,9 @@ packages:
     resolution: {integrity: sha512-RJRdvCo6IAnPdsvP/7m6bsQqNnn1FCBX5ZNtFL98MmFF/4xAIJTIg1YbHW5DC2W5SKZanrC6i4HsJqlajw/dZw==}
     engines: {node: '>= 0.4'}
 
+  setimmediate@1.0.5:
+    resolution: {integrity: sha512-MATJdZp8sLqDl/68LfQmbP8zKPLQNV6BIZoIgrscFDQ+RsvK/BxeDQOgyxKKoh0y/8h3BqVFnCqQ/gd+reiIXA==}
+
   setprototypeof@1.2.0:
     resolution: {integrity: sha512-E5LDX7Wrp85Kil5bhZv46j8jOeboKq5JMmYM3gVGdGH8xFpPWXUMsNrlODCrkoxMEeNi/XZIwuRvY4XNwYMJpw==}
 
@@ -8825,6 +8844,8 @@ snapshots:
 
   ignore@5.3.2: {}
 
+  immediate@3.0.6: {}
+
   immer@9.0.21: {}
 
   import-fresh@3.3.1:
@@ -9111,6 +9132,13 @@ snapshots:
 
   jsonpointer@5.0.1: {}
 
+  jszip@3.10.1:
+    dependencies:
+      lie: 3.3.0
+      pako: 1.0.11
+      readable-stream: 2.3.8
+      setimmediate: 1.0.5
+
   jwa@2.0.0:
     dependencies:
       buffer-equal-constant-time: 1.0.1
@@ -9157,6 +9185,10 @@ snapshots:
       prelude-ls: 1.2.1
       type-check: 0.4.0
 
+  lie@3.3.0:
+    dependencies:
+      immediate: 3.0.6
+
   lilconfig@3.1.3: {}
 
   linear-sum-assignment@1.0.7:
@@ -10068,6 +10100,8 @@ snapshots:
     dependencies:
      quansync: 0.2.10
 
+  pako@1.0.11: {}
+
   parent-module@1.0.1:
     dependencies:
       callsites: 3.1.0
@@ -10694,6 +10728,8 @@ snapshots:
       es-errors: 1.3.0
       es-object-atoms: 1.1.1
 
+  setimmediate@1.0.5: {}
+
   setprototypeof@1.2.0: {}
 
   sharp@0.33.5:
diff --git a/types/model.ts b/types/model.ts
index bdd9324b2..d4fe819b9 100644
--- a/types/model.ts
+++ b/types/model.ts
@@ -32,6 +32,7 @@ export const AvailableModelSchema = z.enum([
   "gemini-2.0-flash",
   "gemini-2.5-flash-preview-04-17",
   "gemini-2.5-pro-preview-03-25",
+  "x-ai/grok-4",
 ]);
 
 export type AvailableModel = z.infer<typeof AvailableModelSchema> | string;
@@ -42,6 +43,7 @@ export type ModelProvider =
   | "cerebras"
   | "groq"
   | "google"
+  | "openrouter"
   | "aisdk";
 
 export type ClientOptions = OpenAIClientOptions | AnthropicClientOptions;