
Commit c135e93

evalstate and nsarrazin authored
Anthropic Tool Support (#1594)

* support anthropic PDF beta
* upstream merge, remove commented out console log line
* Fixing type errors. The anthropic API does not yet include a "DocumentBlock" for supporting PDFs, so an extended type has been added to the endpoint.
* changed document processor to async (matching image processor)
* use the beta api types rather than custom extension
* rudimentary tool testing
* interim commit (tool re-passing, file handling)
* remove merge error
* tidy up, isolate beta classes to utils
* anthropic tool calling support
* improve handling of directlyAnswer tool
* fix streaming
* slight tidy up to tools flow handling
* fix: don't pass tools in final generation, instead deduce tools from tool results

---------

Co-authored-by: Nathan Sarrazin <sarrazin.nathan@gmail.com>
1 parent 18e264a commit c135e93
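
For context, the flow this commit wires into the Anthropic endpoint is the standard Messages API tool-use round trip: declare tools as JSON Schema, watch for a `tool_use` stop reason, then answer with `tool_result` blocks that reference the original `tool_use` ids. A minimal standalone sketch with `@anthropic-ai/sdk` (the `get_weather` tool, model id, and values are illustrative and not part of this commit):

```ts
import Anthropic from "@anthropic-ai/sdk";

const anthropic = new Anthropic();

// A single illustrative tool, declared as JSON Schema.
const tools: Anthropic.Messages.Tool[] = [
	{
		name: "get_weather",
		description: "Get the current weather for a city",
		input_schema: {
			type: "object",
			properties: { city: { type: "string" } },
			required: ["city"],
		},
	},
];

// 1. Let the model decide whether to call the tool.
const first = await anthropic.messages.create({
	model: "claude-3-5-sonnet-20241022",
	max_tokens: 1024,
	tools,
	messages: [{ role: "user", content: "What's the weather in Paris?" }],
});

// 2. If it stopped with `tool_use`, run the tool and send the output back as a
//    `tool_result` block referencing the original tool_use id. `tools` is sent
//    again, since requests containing tool_use/tool_result blocks expect it.
if (first.stop_reason === "tool_use") {
	const toolUse = first.content.find((block) => block.type === "tool_use");
	if (toolUse && toolUse.type === "tool_use") {
		const second = await anthropic.messages.create({
			model: "claude-3-5-sonnet-20241022",
			max_tokens: 1024,
			tools,
			messages: [
				{ role: "user", content: "What's the weather in Paris?" },
				{ role: "assistant", content: first.content },
				{
					role: "user",
					content: [{ type: "tool_result", tool_use_id: toolUse.id, content: "18°C and clear" }],
				},
			],
		});
		console.log(second.content); // final answer grounded in the tool output
	}
}
```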

5 files changed: +196 −37 lines


src/lib/server/endpoints/anthropic/endpointAnthropic.ts

Lines changed: 128 additions & 18 deletions
```diff
@@ -3,9 +3,19 @@ import type { Endpoint } from "../endpoints";
 import { env } from "$env/dynamic/private";
 import type { TextGenerationStreamOutput } from "@huggingface/inference";
 import { createImageProcessorOptionsValidator } from "../images";
-import { endpointMessagesToAnthropicMessages } from "./utils";
+import { endpointMessagesToAnthropicMessages, addToolResults } from "./utils";
 import { createDocumentProcessorOptionsValidator } from "../document";
+import type {
+	Tool,
+	ToolCall,
+	ToolInput,
+	ToolInputFile,
+	ToolInputFixed,
+	ToolInputOptional,
+} from "$lib/types/Tool";
+import type Anthropic from "@anthropic-ai/sdk";
 import type { MessageParam } from "@anthropic-ai/sdk/resources/messages.mjs";
+import directlyAnswer from "$lib/server/tools/directlyAnswer";
 
 export const endpointAnthropicParametersSchema = z.object({
 	weight: z.number().int().positive().default(1),
@@ -52,23 +62,41 @@ export async function endpointAnthropic(
 		defaultQuery,
 	});
 
-	return async ({ messages, preprompt, generateSettings }) => {
+	return async ({
+		messages,
+		preprompt,
+		generateSettings,
+		conversationId,
+		tools = [],
+		toolResults = [],
+	}) => {
 		let system = preprompt;
 		if (messages?.[0]?.from === "system") {
 			system = messages[0].content;
 		}
 
 		let tokenId = 0;
+		if (tools.length === 0 && toolResults.length > 0) {
+			const toolNames = new Set(toolResults.map((tool) => tool.call.name));
+			tools = Array.from(toolNames).map((name) => ({
+				name,
+				description: "",
+				inputs: [],
+			})) as unknown as Tool[];
+		}
 
 		const parameters = { ...model.parameters, ...generateSettings };
 
 		return (async function* () {
 			const stream = anthropic.messages.stream({
 				model: model.id ?? model.name,
-				messages: (await endpointMessagesToAnthropicMessages(
-					messages,
-					multimodal
-				)) as MessageParam[],
+				tools: createAnthropicTools(tools),
+				tool_choice:
+					tools.length > 0 ? { type: "auto", disable_parallel_tool_use: false } : undefined,
+				messages: addToolResults(
+					await endpointMessagesToAnthropicMessages(messages, multimodal, conversationId),
+					toolResults
+				) as MessageParam[],
 				max_tokens: parameters?.max_new_tokens,
 				temperature: parameters?.temperature,
 				top_p: parameters?.top_p,
@@ -79,21 +107,40 @@ export async function endpointAnthropic(
 			while (true) {
 				const result = await Promise.race([stream.emitted("text"), stream.emitted("end")]);
 
-				// Stream end
 				if (result === undefined) {
-					yield {
-						token: {
-							id: tokenId++,
-							text: "",
-							logprob: 0,
-							special: true,
-						},
-						generated_text: await stream.finalText(),
-						details: null,
-					} satisfies TextGenerationStreamOutput;
+					if ("tool_use" === stream.receivedMessages[0].stop_reason) {
+						// this should really create a new "Assistant" message with the tool id in it.
+						const toolCalls: ToolCall[] = stream.receivedMessages[0].content
+							.filter(
+								(block): block is Anthropic.Messages.ContentBlock & { type: "tool_use" } =>
+									block.type === "tool_use"
+							)
+							.map((block) => ({
+								name: block.name,
+								parameters: block.input as Record<string, string | number | boolean>,
+								id: block.id,
+							}));
+
+						yield {
+							token: { id: tokenId, text: "", logprob: 0, special: false, toolCalls },
+							generated_text: null,
+							details: null,
+						};
+					} else {
+						yield {
+							token: {
+								id: tokenId++,
+								text: "",
+								logprob: 0,
+								special: true,
+							},
+							generated_text: await stream.finalText(),
+							details: null,
+						} satisfies TextGenerationStreamOutput;
+					}
+
 					return;
 				}
-
 				// Text delta
 				yield {
 					token: {
@@ -109,3 +156,66 @@ export async function endpointAnthropic(
 		})();
 	};
 }
+
+function createAnthropicTools(tools: Tool[]): Anthropic.Messages.Tool[] {
+	return tools
+		.filter((tool) => tool.name !== directlyAnswer.name)
+		.map((tool) => {
+			const properties = tool.inputs.reduce((acc, input) => {
+				acc[input.name] = convertToolInputToJSONSchema(input);
+				return acc;
+			}, {} as Record<string, unknown>);
+
+			const required = tool.inputs
+				.filter((input) => input.paramType === "required")
+				.map((input) => input.name);
+
+			return {
+				name: tool.name,
+				description: tool.description,
+				input_schema: {
+					type: "object",
+					properties,
+					required: required.length > 0 ? required : undefined,
+				},
+			};
+		});
+}
+
+function convertToolInputToJSONSchema(input: ToolInput): Record<string, unknown> {
+	const baseSchema: Record<string, unknown> = {};
+	if ("description" in input) {
+		baseSchema["description"] = input.description || "";
+	}
+	switch (input.paramType) {
+		case "optional":
+			baseSchema["default"] = (input as ToolInputOptional).default;
+			break;
+		case "fixed":
+			baseSchema["const"] = (input as ToolInputFixed).value;
+			break;
+	}
+
+	if (input.type === "file") {
+		baseSchema["type"] = "string";
+		baseSchema["format"] = "binary";
+		baseSchema["mimeTypes"] = (input as ToolInputFile).mimeTypes;
+	} else {
+		switch (input.type) {
+			case "str":
+				baseSchema["type"] = "string";
+				break;
+			case "int":
+				baseSchema["type"] = "integer";
+				break;
+			case "float":
+				baseSchema["type"] = "number";
+				break;
+			case "bool":
+				baseSchema["type"] = "boolean";
+				break;
		}
+	}
+
+	return baseSchema;
+}
```
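
To make the new schema conversion concrete, here is a hand-derived sketch of the Anthropic tool entry that `createAnthropicTools` would build for a hypothetical chat-ui tool (the tool name, inputs, and descriptions are made up; only the fields read by `convertToolInputToJSONSchema` are shown):

```ts
// Hypothetical tool inputs as chat-ui might define them.
const inputs = [
	{ name: "query", type: "str", paramType: "required", description: "Search query" },
	{ name: "num_results", type: "int", paramType: "optional", default: 5, description: "How many results" },
	{ name: "format", type: "str", paramType: "fixed", value: "json" },
];

// Expected entry in the `tools` array sent to Anthropic, per the code above:
const expected = {
	name: "web_search",
	description: "Search the web and return result snippets",
	input_schema: {
		type: "object",
		properties: {
			query: { description: "Search query", type: "string" },
			num_results: { description: "How many results", default: 5, type: "integer" },
			format: { const: "json", type: "string" },
		},
		required: ["query"], // only inputs with paramType "required"
	},
};
```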

src/lib/server/endpoints/anthropic/utils.ts

Lines changed: 56 additions & 12 deletions
```diff
@@ -7,12 +7,16 @@ import type {
 	BetaMessageParam,
 	BetaBase64PDFBlock,
 } from "@anthropic-ai/sdk/resources/beta/messages/messages.mjs";
+import type { ToolResult } from "$lib/types/Tool";
+import { downloadFile } from "$lib/server/files/downloadFile";
+import type { ObjectId } from "mongodb";
 
 export async function fileToImageBlock(
 	file: MessageFile,
 	opts: ImageProcessorOptions<"image/png" | "image/jpeg" | "image/webp">
 ): Promise<BetaImageBlockParam> {
 	const processor = makeImageProcessor(opts);
+
 	const { image, mime } = await processor(file);
 
 	return {
@@ -48,7 +52,8 @@ export async function endpointMessagesToAnthropicMessages(
 	multimodal: {
 		image: ImageProcessorOptions<"image/png" | "image/jpeg" | "image/webp">;
 		document?: FileProcessorOptions<"application/pdf">;
-	}
+	},
+	conversationId?: ObjectId | undefined
 ): Promise<BetaMessageParam[]> {
 	return await Promise.all(
 		messages
@@ -57,20 +62,59 @@ export async function endpointMessagesToAnthropicMessages(
 			return {
 				role: message.from,
 				content: [
-					...(await Promise.all(
-						(message.files ?? []).map(async (file) => {
-							if (file.mime.startsWith("image/")) {
-								return fileToImageBlock(file, multimodal.image);
-							} else if (file.mime === "application/pdf" && multimodal.document) {
-								return fileToDocumentBlock(file, multimodal.document);
-							} else {
-								throw new Error(`Unsupported file type: ${file.mime}`);
-							}
-						})
-					)),
+					...(message.from === "user"
+						? await Promise.all(
+								(message.files ?? []).map(async (file) => {
+									if (file.type === "hash" && conversationId) {
+										file = await downloadFile(file.value, conversationId);
+									}
+
+									if (file.mime.startsWith("image/")) {
+										return fileToImageBlock(file, multimodal.image);
+									} else if (file.mime === "application/pdf" && multimodal.document) {
+										return fileToDocumentBlock(file, multimodal.document);
+									} else {
+										throw new Error(`Unsupported file type: ${file.mime}`);
+									}
+								})
+							)
+						: []),
 					{ type: "text", text: message.content },
 				],
 			};
 		})
 	);
 }
+
+export function addToolResults(
+	messages: BetaMessageParam[],
+	toolResults: ToolResult[]
+): BetaMessageParam[] {
+	const id = crypto.randomUUID();
+	if (toolResults.length === 0) {
+		return messages;
+	}
+	return [
+		...messages,
+		{
+			role: "assistant",
+			content: toolResults.map((result, index) => ({
+				type: "tool_use",
+				id: `tool_${index}_${id}`,
+				name: result.call.name,
+				input: result.call.parameters,
+			})),
+		},
+		{
+			role: "user",
+			content: toolResults.map((result, index) => ({
+				type: "tool_result",
+				tool_use_id: `tool_${index}_${id}`,
+				is_error: result.status === "error",
+				content: JSON.stringify(
+					result.status === "error" ? result.message : "outputs" in result ? result.outputs : ""
+				),
+			})),
+		},
+	];
+}
```
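
The Anthropic API expects every `tool_result` block to reference an earlier `tool_use` id, so `addToolResults` retrofits that pairing onto tool results produced outside the Anthropic call. A sketch of the two messages it appends for a single successful result (names and values are illustrative):

```ts
// Given a ToolResult roughly like:
//   { status: "success", call: { name: "websearch", parameters: { query: "chat-ui" } }, outputs: [...] }
// addToolResults appends an assistant/user pair along these lines:
const appended = [
	{
		role: "assistant",
		content: [
			{
				type: "tool_use",
				id: "tool_0_<uuid>", // `tool_${index}_${crypto.randomUUID()}`
				name: "websearch",
				input: { query: "chat-ui" },
			},
		],
	},
	{
		role: "user",
		content: [
			{
				type: "tool_result",
				tool_use_id: "tool_0_<uuid>", // must match the synthetic tool_use id above
				is_error: false,
				content: JSON.stringify([{ title: "huggingface/chat-ui", link: "..." }]),
			},
		],
	},
];
```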

src/lib/server/textGeneration/generate.ts

Lines changed: 4 additions & 2 deletions
```diff
@@ -1,4 +1,4 @@
-import type { ToolResult } from "$lib/types/Tool";
+import type { ToolResult, Tool } from "$lib/types/Tool";
 import {
 	MessageReasoningUpdateType,
 	MessageUpdateType,
@@ -16,7 +16,8 @@ type GenerateContext = Omit<TextGenerationContext, "messages"> & { messages: End
 export async function* generate(
 	{ model, endpoint, conv, messages, assistant, isContinue, promptedAt }: GenerateContext,
 	toolResults: ToolResult[],
-	preprompt?: string
+	preprompt?: string,
+	tools?: Tool[]
 ): AsyncIterable<MessageUpdate> {
 	// reasoning mode is false by default
 	let reasoning = false;
@@ -43,6 +44,7 @@ export async function* generate(
 		preprompt,
 		continueMessage: isContinue,
 		generateSettings: assistant?.generateSettings,
+		tools,
 		toolResults,
 		isMultimodal: model.multimodal,
 		conversationId: conv._id,
```
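
A minimal sketch of the updated call site, assuming the caller already holds `ctx`, `toolResults`, `preprompt`, and the optional `tools` list (variable names and the downstream handler are illustrative; the `index.ts` change below is the real caller):

```ts
// `tools` is now threaded through generate() to the endpoint alongside toolResults
// (inside a generator such as textGenerationWithoutTitle, each update would be re-yielded).
for await (const update of generate(ctx, toolResults, preprompt, tools)) {
	handleMessageUpdate(update); // hypothetical downstream handler for MessageUpdate events
}
```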

src/lib/server/textGeneration/index.ts

Lines changed: 7 additions & 4 deletions
```diff
@@ -20,6 +20,7 @@ import { mergeAsyncGenerators } from "$lib/utils/mergeAsyncGenerators";
 import type { TextGenerationContext } from "./types";
 import type { ToolResult } from "$lib/types/Tool";
 import { toolHasName } from "../tools/utils";
+import directlyAnswer from "../tools/directlyAnswer";
 
 async function* keepAlive(done: AbortSignal): AsyncGenerator<MessageUpdate, undefined, undefined> {
 	while (!done.aborted) {
@@ -73,11 +74,13 @@ async function* textGenerationWithoutTitle(
 	}
 
 	let toolResults: ToolResult[] = [];
+	let tools = model.tools ? await getTools(toolsPreference, ctx.assistant) : undefined;
 
-	if (model.tools) {
-		const tools = await getTools(toolsPreference, ctx.assistant);
-		const toolCallsRequired = tools.some((tool) => !toolHasName("directly_answer", tool));
-		if (toolCallsRequired) toolResults = yield* runTools(ctx, tools, preprompt);
+	if (tools) {
+		const toolCallsRequired = tools.some((tool) => !toolHasName(directlyAnswer.name, tool));
+		if (toolCallsRequired) {
+			toolResults = yield* runTools(ctx, tools, preprompt);
+		} else tools = undefined;
 	}
 
 	const processedMessages = await preprocessMessages(messages, webSearchResult, convId);
```

src/lib/server/textGeneration/tools.ts

Lines changed: 1 addition & 1 deletion
```diff
@@ -213,7 +213,7 @@ export async function* runTools(
 	}
 
 	// if we dont see a tool call in the first 25 chars, something is going wrong and we abort
-	if (rawText.length > 25 && !(rawText.includes("```json") || rawText.includes("{"))) {
+	if (rawText.length > 100 && !(rawText.includes("```json") || rawText.includes("{"))) {
 		return [];
 	}
 
```
