diff --git a/.changeset/major-goats-follow.md b/.changeset/major-goats-follow.md new file mode 100644 index 000000000..d4781ce52 --- /dev/null +++ b/.changeset/major-goats-follow.md @@ -0,0 +1,5 @@ +--- +"@browserbasehq/stagehand": minor +--- + +Support overriding clientOptions for CUA Client diff --git a/lib/agent/AgentProvider.ts b/lib/agent/AgentProvider.ts index cc8110fa4..105d32f83 100644 --- a/lib/agent/AgentProvider.ts +++ b/lib/agent/AgentProvider.ts @@ -7,6 +7,7 @@ import { UnsupportedModelError, UnsupportedModelProviderError, } from "@/types/stagehandErrors"; +import { ClientOptions } from "@/types/model"; // Map model names to their provider types const modelToAgentProviderMap: Record = { @@ -32,7 +33,7 @@ export class AgentProvider { getClient( modelName: string, - clientOptions?: Record, + clientOptions?: ClientOptions & Record, userProvidedInstructions?: string, ): AgentClient { const type = AgentProvider.getAgentProvider(modelName); diff --git a/lib/agent/AnthropicCUAClient.ts b/lib/agent/AnthropicCUAClient.ts index fbbc97d82..d7027de43 100644 --- a/lib/agent/AnthropicCUAClient.ts +++ b/lib/agent/AnthropicCUAClient.ts @@ -1,4 +1,6 @@ -import Anthropic from "@anthropic-ai/sdk"; +import Anthropic, { + type ClientOptions as AnthropicClientOptions, +} from "@anthropic-ai/sdk"; import { LogLine } from "@/types/log"; import { AgentAction, @@ -35,7 +37,7 @@ export class AnthropicCUAClient extends AgentClient { type: AgentType, modelName: string, userProvidedInstructions?: string, - clientOptions?: Record, + clientOptions?: AnthropicClientOptions & Record, ) { super(type, modelName, userProvidedInstructions); @@ -53,13 +55,7 @@ export class AnthropicCUAClient extends AgentClient { } // Store client options for reference - this.clientOptions = { - apiKey: this.apiKey, - }; - - if (this.baseURL) { - this.clientOptions.baseUrl = this.baseURL; - } + this.clientOptions = clientOptions; // Initialize the Anthropic client this.client = new Anthropic(this.clientOptions); diff --git a/lib/agent/OpenAICUAClient.ts b/lib/agent/OpenAICUAClient.ts index 6a494300b..7e2df4596 100644 --- a/lib/agent/OpenAICUAClient.ts +++ b/lib/agent/OpenAICUAClient.ts @@ -1,4 +1,4 @@ -import OpenAI from "openai"; +import OpenAI, { type ClientOptions as OpenAIClientOptions } from "openai"; import { LogLine } from "../../types/log"; import { AgentAction, @@ -34,7 +34,7 @@ export class OpenAICUAClient extends AgentClient { type: AgentType, modelName: string, userProvidedInstructions?: string, - clientOptions?: Record, + clientOptions?: OpenAIClientOptions & Record, ) { super(type, modelName, userProvidedInstructions); @@ -43,6 +43,7 @@ export class OpenAICUAClient extends AgentClient { (clientOptions?.apiKey as string) || process.env.OPENAI_API_KEY || ""; this.organization = (clientOptions?.organization as string) || process.env.OPENAI_ORG; + this.baseURL = (clientOptions?.baseURL as string) || undefined; // Get environment if specified if ( @@ -53,9 +54,11 @@ export class OpenAICUAClient extends AgentClient { } // Store client options for reference - this.clientOptions = { - apiKey: this.apiKey, - }; + this.clientOptions = clientOptions; + + if (this.baseURL) { + this.clientOptions.baseURL = this.baseURL; + } // Initialize the OpenAI client this.client = new OpenAI(this.clientOptions); diff --git a/types/agent.ts b/types/agent.ts index 8cc062012..cd5f830e3 100644 --- a/types/agent.ts +++ b/types/agent.ts @@ -1,4 +1,5 @@ import { LogLine } from "./log"; +import { ClientOptions } from "@/types/model"; export interface AgentAction { type: string; @@ -44,7 +45,7 @@ export interface AgentExecutionOptions { export interface AgentHandlerOptions { modelName: string; - clientOptions?: Record; + clientOptions?: ClientOptions & Record; userProvidedInstructions?: string; agentType: AgentType; } diff --git a/types/stagehand.ts b/types/stagehand.ts index 814ab0886..2233ac41f 100644 --- a/types/stagehand.ts +++ b/types/stagehand.ts @@ -255,7 +255,7 @@ export interface AgentConfig { /** * Additional options to pass to the agent client */ - options?: Record; + options?: ClientOptions & Record; } export enum StagehandFunctionName {