Skip to content

Commit 1e49dee

Browse files
authored
fix chat completion not working for custom llm clients (#618)
* fix chat completion not working for custom llm clients * throw stagehandevalerror
1 parent 7b0b996 commit 1e49dee

File tree

8 files changed

+305
-54
lines changed

8 files changed

+305
-54
lines changed

evals/evals.config.json

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,14 @@
8585
"name": "wichita",
8686
"categories": ["combination", "regression_dom_extract"]
8787
},
88-
88+
{
89+
"name": "hn_aisdk",
90+
"categories": ["combination", "regression_dom_extract"]
91+
},
92+
{
93+
"name": "hn_langchain",
94+
"categories": ["combination", "regression_dom_extract"]
95+
},
8996
{
9097
"name": "apple",
9198
"categories": ["experimental"]

evals/initStagehand.ts

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,15 @@
1111
*/
1212

1313
import { enableCaching, env } from "./env";
14-
import { AvailableModel, ConstructorParams, LogLine, Stagehand } from "@/dist";
14+
import {
15+
AvailableModel,
16+
ConstructorParams,
17+
LLMClient,
18+
LogLine,
19+
Stagehand,
20+
} from "@/dist";
1521
import { EvalLogger } from "./logger";
22+
import { StagehandEvalError } from "@/types/stagehandErrors";
1623

1724
/**
1825
* StagehandConfig:
@@ -51,26 +58,37 @@ const StagehandConfig = {
5158
* - initResponse: Any response data returned by Stagehand initialization
5259
*/
5360
export const initStagehand = async ({
61+
llmClient,
5462
modelName,
5563
domSettleTimeoutMs,
5664
logger,
5765
configOverrides,
5866
actTimeoutMs,
5967
}: {
60-
modelName: AvailableModel;
68+
llmClient?: LLMClient;
69+
modelName?: AvailableModel;
6170
domSettleTimeoutMs?: number;
6271
logger: EvalLogger;
6372
configOverrides?: Partial<ConstructorParams>;
6473
actTimeoutMs?: number;
6574
}) => {
75+
if (llmClient && modelName) {
76+
throw new StagehandEvalError("Cannot provide both llmClient and modelName");
77+
}
78+
79+
if (!llmClient && !modelName) {
80+
throw new StagehandEvalError("Must provide either llmClient or modelName");
81+
}
82+
6683
let chosenApiKey: string | undefined = process.env.OPENAI_API_KEY;
67-
if (modelName.startsWith("claude")) {
84+
if (modelName?.startsWith("claude")) {
6885
chosenApiKey = process.env.ANTHROPIC_API_KEY;
6986
}
7087

7188
const config = {
7289
...StagehandConfig,
7390
modelName,
91+
llmClient,
7492
...(domSettleTimeoutMs && { domSettleTimeoutMs }),
7593
modelClientOptions: {
7694
apiKey: chosenApiKey,

evals/tasks/hn_aisdk.ts

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import { EvalFunction } from "@/types/evals";
2+
import { initStagehand } from "@/evals/initStagehand";
3+
import { z } from "zod";
4+
import { openai } from "@ai-sdk/openai/dist";
5+
import { AISdkClient } from "@/examples/external_clients/aisdk";
6+
7+
/**
 * Eval: drives Stagehand against Hacker News using the AI SDK custom LLM
 * client (`AISdkClient` wrapping `openai("gpt-4o-mini")`).
 *
 * Checks performed, in order:
 *  1. `page.extract` returns the top story title, which must equal the text
 *     of the anchor found via a fixed XPath (both with any trailing
 *     " (url)" suffix removed).
 *  2. `page.act("Click on the 'new' tab")` must leave the page on
 *     https://news.ycombinator.com/newest.
 *
 * The Stagehand session is closed in a `finally` block so it is released on
 * every exit path: success, each failed check, and thrown errors. Previously
 * only the success path closed the session.
 *
 * @param logger - eval logger used to report failures and to supply `logs`.
 * @returns an object with `_success` plus `debugUrl`, `sessionUrl` and `logs`;
 *   failed checks also include `error` and the expected/actual values.
 */
export const hn_aisdk: EvalFunction = async ({ logger }) => {
  const { stagehand, initResponse } = await initStagehand({
    logger,
    llmClient: new AISdkClient({
      model: openai("gpt-4o-mini"),
    }),
  });

  const { debugUrl, sessionUrl } = initResponse;
  const expectedUrl = "https://news.ycombinator.com/newest";

  try {
    await stagehand.page.goto("https://news.ycombinator.com");

    let { story } = await stagehand.page.extract({
      schema: z.object({
        story: z.string().describe("the title of the top story on the page"),
      }),
    });
    // remove the (url) part of the story title
    story = story.split(" (")[0];

    const expectedStoryElement = await stagehand.page.$(
      "xpath=/html/body/center/table/tbody/tr[3]/td/table/tbody/tr[1]/td[3]/span/a",
    );
    // remove the (url) part of the story title
    const expectedStory = (await expectedStoryElement?.textContent())?.split(
      " (",
    )?.[0];

    if (!expectedStory) {
      logger.error({
        message: "Could not find expected story element",
        level: 0,
      });
      return {
        _success: false,
        error: "Could not find expected story element",
        debugUrl,
        sessionUrl,
        logs: logger.getLogs(),
      };
    }

    if (story !== expectedStory) {
      logger.error({
        message: "Extracted story does not match expected story",
        level: 0,
        auxiliary: {
          expected: {
            value: expectedStory,
            type: "string",
          },
          actual: {
            value: story,
            type: "string",
          },
        },
      });
      return {
        _success: false,
        error: "Extracted story does not match expected story",
        expectedStory,
        actualStory: story,
        debugUrl,
        sessionUrl,
        logs: logger.getLogs(),
      };
    }

    await stagehand.page.act("Click on the 'new' tab");

    // Read the URL once so the logged value and the returned value agree.
    const actualUrl = stagehand.page.url();
    if (actualUrl !== expectedUrl) {
      logger.error({
        message: "Page did not navigate to the 'new' tab",
        level: 0,
        auxiliary: {
          expected: {
            value: expectedUrl,
            type: "string",
          },
          actual: {
            value: actualUrl,
            type: "string",
          },
        },
      });
      return {
        _success: false,
        error: "Page did not navigate to the 'new' tab",
        expectedUrl,
        actualUrl,
        debugUrl,
        sessionUrl,
        logs: logger.getLogs(),
      };
    }

    return {
      _success: true,
      expectedStory,
      actualStory: story,
      debugUrl,
      sessionUrl,
      logs: logger.getLogs(),
    };
  } finally {
    // Runs after the return value (including `logs`) has been computed.
    await stagehand.close();
  }
};

evals/tasks/hn_langchain.ts

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import { EvalFunction } from "@/types/evals";
2+
import { initStagehand } from "@/evals/initStagehand";
3+
import { z } from "zod";
4+
import { LangchainClient } from "@/examples/external_clients/langchain";
5+
import { ChatOpenAI } from "@langchain/openai";
6+
7+
/**
 * Eval: drives Stagehand against Hacker News using the LangChain custom LLM
 * client (`LangchainClient` wrapping `ChatOpenAI` with model "gpt-4o").
 *
 * Checks performed, in order:
 *  1. `page.extract` returns the top story title, which must equal the text
 *     of the anchor found via a fixed XPath (both with any trailing
 *     " (url)" suffix removed).
 *  2. `page.act("Click on the 'new' tab")` must leave the page on
 *     https://news.ycombinator.com/newest.
 *
 * The Stagehand session is closed in a `finally` block so it is released on
 * every exit path: success, each failed check, and thrown errors. Previously
 * only the success path closed the session.
 *
 * @param logger - eval logger used to report failures and to supply `logs`.
 * @returns an object with `_success` plus `debugUrl`, `sessionUrl` and `logs`;
 *   failed checks also include `error` and the expected/actual values.
 */
export const hn_langchain: EvalFunction = async ({ logger }) => {
  const { stagehand, initResponse } = await initStagehand({
    logger,
    llmClient: new LangchainClient(
      new ChatOpenAI({
        model: "gpt-4o",
      }),
    ),
  });

  const { debugUrl, sessionUrl } = initResponse;
  const expectedUrl = "https://news.ycombinator.com/newest";

  try {
    await stagehand.page.goto("https://news.ycombinator.com");

    let { story } = await stagehand.page.extract({
      schema: z.object({
        story: z.string().describe("the title of the top story on the page"),
      }),
    });
    // remove the (url) part of the story title
    story = story.split(" (")[0];

    const expectedStoryElement = await stagehand.page.$(
      "xpath=/html/body/center/table/tbody/tr[3]/td/table/tbody/tr[1]/td[3]/span/a",
    );
    // remove the (url) part of the story title
    const expectedStory = (await expectedStoryElement?.textContent())?.split(
      " (",
    )?.[0];

    if (!expectedStory) {
      logger.error({
        message: "Could not find expected story element",
        level: 0,
      });
      return {
        _success: false,
        error: "Could not find expected story element",
        debugUrl,
        sessionUrl,
        logs: logger.getLogs(),
      };
    }

    if (story !== expectedStory) {
      logger.error({
        message: "Extracted story does not match expected story",
        level: 0,
        auxiliary: {
          expected: {
            value: expectedStory,
            type: "string",
          },
          actual: {
            value: story,
            type: "string",
          },
        },
      });
      return {
        _success: false,
        error: "Extracted story does not match expected story",
        expectedStory,
        actualStory: story,
        debugUrl,
        sessionUrl,
        logs: logger.getLogs(),
      };
    }

    await stagehand.page.act("Click on the 'new' tab");

    // Read the URL once so the logged value and the returned value agree.
    const actualUrl = stagehand.page.url();
    if (actualUrl !== expectedUrl) {
      logger.error({
        message: "Page did not navigate to the 'new' tab",
        level: 0,
        auxiliary: {
          expected: {
            value: expectedUrl,
            type: "string",
          },
          actual: {
            value: actualUrl,
            type: "string",
          },
        },
      });
      return {
        _success: false,
        error: "Page did not navigate to the 'new' tab",
        expectedUrl,
        actualUrl,
        debugUrl,
        sessionUrl,
        logs: logger.getLogs(),
      };
    }

    return {
      _success: true,
      expectedStory,
      actualStory: story,
      debugUrl,
      sessionUrl,
      logs: logger.getLogs(),
    };
  } finally {
    // Runs after the return value (including `logs`) has been computed.
    await stagehand.close();
  }
};

examples/ai_sdk_example.ts

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,29 @@
1-
import { google } from "@ai-sdk/google";
2-
import { z } from "zod";
1+
import { openai } from "@ai-sdk/openai";
32
import { Stagehand } from "@/dist";
43
import { AISdkClient } from "./external_clients/aisdk";
54
import StagehandConfig from "@/stagehand.config";
5+
import { z } from "zod";
66

77
async function example() {
88
const stagehand = new Stagehand({
99
...StagehandConfig,
1010
llmClient: new AISdkClient({
11-
model: google("gemini-1.5-flash-latest"),
11+
model: openai("gpt-4o"),
1212
}),
1313
});
1414

1515
await stagehand.init();
1616
await stagehand.page.goto("https://news.ycombinator.com");
1717

18-
const headlines = await stagehand.page.extract({
19-
instruction: "Extract only 3 stories from the Hacker News homepage.",
18+
const { story } = await stagehand.page.extract({
2019
schema: z.object({
21-
stories: z
22-
.array(
23-
z.object({
24-
title: z.string(),
25-
url: z.string(),
26-
points: z.number(),
27-
}),
28-
)
29-
.length(3),
20+
story: z.string().describe("the top story on the page"),
3021
}),
3122
});
3223

33-
console.log(headlines);
24+
console.log("The top story is:", story);
25+
26+
await stagehand.page.act("click the first story");
3427

3528
await stagehand.close();
3629
}

examples/external_clients/aisdk.ts

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ import {
1010
LanguageModel,
1111
TextPart,
1212
} from "ai";
13-
import { ChatCompletion } from "openai/resources/chat/completions";
1413
import { CreateChatCompletionOptions, LLMClient, AvailableModel } from "@/dist";
14+
import { ChatCompletion } from "openai/resources";
1515

1616
export class AISdkClient extends LLMClient {
1717
public type = "aisdk" as const;
@@ -85,7 +85,14 @@ export class AISdkClient extends LLMClient {
8585
schema: options.response_model.schema,
8686
});
8787

88-
return response.object;
88+
return {
89+
data: response.object,
90+
usage: {
91+
prompt_tokens: response.usage.promptTokens ?? 0,
92+
completion_tokens: response.usage.completionTokens ?? 0,
93+
total_tokens: response.usage.totalTokens ?? 0,
94+
},
95+
} as T;
8996
}
9097

9198
const tools: Record<string, CoreTool> = {};
@@ -103,6 +110,13 @@ export class AISdkClient extends LLMClient {
103110
tools,
104111
});
105112

106-
return response as T;
113+
return {
114+
data: response.text,
115+
usage: {
116+
prompt_tokens: response.usage.promptTokens ?? 0,
117+
completion_tokens: response.usage.completionTokens ?? 0,
118+
total_tokens: response.usage.totalTokens ?? 0,
119+
},
120+
} as T;
107121
}
108122
}

0 commit comments

Comments
 (0)