Skip to content

Commit f2f7b5a

Browse files
committed
Fix: tool call
1 parent c33b445 commit f2f7b5a

File tree

3 files changed

+67
-19
lines changed

3 files changed

+67
-19
lines changed

src/services/core/ai-service-provider.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ export interface CompletionOptions {
2727
signal?: AbortSignal; // AbortSignal for cancellation
2828
mcpTools?: string[]; // IDs of MCP servers to use as tools
2929
tools?: Record<string, unknown>; // Pre-configured tools for the AI
30+
toolChoice?: Record<string, unknown>; // Pre-configured tool Choice for the AI
3031
}
3132

3233
/**
@@ -79,7 +80,7 @@ export interface AiServiceProvider {
7980
getChatCompletion(
8081
messages: Message[],
8182
options: CompletionOptions,
82-
streamController: StreamControlHandler
83+
streamController: StreamControlHandler,
8384
): Promise<Message>;
8485

8586
/**

src/services/providers/common-provider-service.ts

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { generateText, streamText, Provider, type LanguageModelUsage, type ToolSet } from 'ai';
1+
import { generateText, streamText, Provider, type LanguageModelUsage, type ToolSet, tool } from 'ai';
22
import { Message, MessageRole } from '../../types/chat';
33
import { AiServiceProvider, CompletionOptions } from '../core/ai-service-provider';
44
import { SettingsService } from '../settings-service';
@@ -7,7 +7,8 @@ import { v4 as uuidv4 } from 'uuid';
77
import { MessageHelper } from '../message-helper';
88
import { AIServiceCapability, mapModelCapabilities } from '../../types/capabilities';
99
import { ModelSettings } from '../../types/settings';
10-
import { LanguageModelV1, ToolChoice } from 'ai';
10+
import { LanguageModelV1 } from 'ai';
11+
import { z } from 'zod';
1112

1213
// Define an interface for tool results to fix the 'never' type errors
1314
interface ToolResult {
@@ -159,23 +160,64 @@ export class CommonProviderHelper implements AiServiceProvider {
159160
modelInstance: LanguageModelV1,
160161
messages: Message[],
161162
options: CompletionOptions,
162-
streamController: StreamControlHandler,
163-
tools: ToolSet | undefined = undefined,
164-
toolChoice: ToolChoice<ToolSet> | undefined = undefined
163+
streamController: StreamControlHandler
165164
): Promise<Message> {
166165
try {
167166
const formattedMessages = await MessageHelper.MessagesContentToOpenAIFormat(messages);
168167

169168
console.log('formattedMessages: ', formattedMessages);
170169

170+
// Build ToolSet & ToolChoice for getChatCompletionByModel API
171+
const rawTools = options.tools;
172+
173+
// Convert raw tools to AI SDK format
174+
const formattedTools: ToolSet = {};
175+
176+
if (rawTools && typeof rawTools === 'object') {
177+
for (const [toolName, toolConfig] of Object.entries(rawTools)) {
178+
if (toolConfig && typeof toolConfig === 'object') {
179+
// Special case for image generation
180+
if (toolName === 'generate_image') {
181+
formattedTools[toolName] = tool({
182+
description: 'Generate an image from a text prompt',
183+
parameters: z.object({
184+
prompt: z.string().describe('The text prompt to generate an image from'),
185+
size: z.string().optional().describe('The size of the image to generate'),
186+
style: z.enum(['vivid', 'natural']).optional().describe('The style of the image to generate')
187+
}),
188+
execute: async (args) => {
189+
// Execute is handled later in the tool call handler
190+
return (toolConfig as ToolWithExecute).execute(args);
191+
}
192+
});
193+
} else {
194+
// For other tools, try to extract description and parameters
195+
const toolWithExecute = toolConfig as ToolWithExecute;
196+
const description = (toolConfig as {description?: string}).description || `Execute ${toolName} tool`;
197+
198+
// Create a fallback schema if not provided
199+
const parameters = z.object({}).catchall(z.unknown());
200+
201+
formattedTools[toolName] = tool({
202+
description,
203+
parameters,
204+
execute: async (args) => {
205+
if (typeof toolWithExecute.execute === 'function') {
206+
return toolWithExecute.execute(args);
207+
}
208+
throw new Error(`Tool ${toolName} does not have an execute function`);
209+
}
210+
});
211+
}
212+
}
213+
}
214+
}
215+
171216
let fullText = '';
172217

173218
if (options.stream) {
174219
console.log(`Streaming ${options.provider}/${options.model} response`);
175220

176-
// Prepare tools for AI SDK format if they exist
177-
const toolsForStream = tools || options.tools as unknown as ToolSet;
178-
179221
const result = streamText({
180222
model: modelInstance,
181223
abortSignal: streamController.getAbortSignal(),
@@ -185,8 +227,7 @@ export class CommonProviderHelper implements AiServiceProvider {
185227
topP: options.top_p,
186228
frequencyPenalty: options.frequency_penalty,
187229
presencePenalty: options.presence_penalty,
188-
tools: toolsForStream,
189-
toolChoice: toolChoice,
230+
tools: Object.keys(formattedTools).length > 0 ? formattedTools : undefined,
190231
toolCallStreaming: true,
191232
onFinish: (result: { usage: LanguageModelUsage }) => {
192233
console.log('OpenAI streaming chat completion finished');
@@ -266,9 +307,9 @@ export class CommonProviderHelper implements AiServiceProvider {
266307
error instanceof Error ? error : new Error('Unknown error in image generation')
267308
);
268309
}
269-
} else if (options.tools) {
310+
} else if (rawTools) {
270311
// Use a safer way to check for and execute tools
271-
const toolsMap = options.tools as Record<string, unknown>;
312+
const toolsMap = rawTools as Record<string, unknown>;
272313
const tool = toolsMap[toolName] as ToolWithExecute | undefined;
273314

274315
if (tool && typeof tool.execute === 'function') {
@@ -300,9 +341,6 @@ export class CommonProviderHelper implements AiServiceProvider {
300341
else {
301342
console.log(`Generating ${options.provider}/${options.model} response`);
302343

303-
// Prepare tools for AI SDK format if they exist
304-
const toolsForGenerate = tools || options.tools as unknown as ToolSet;
305-
306344
const { text, usage, toolResults } = await generateText({
307345
model: modelInstance,
308346
messages: formattedMessages,
@@ -311,8 +349,7 @@ export class CommonProviderHelper implements AiServiceProvider {
311349
topP: options.top_p,
312350
frequencyPenalty: options.frequency_penalty,
313351
presencePenalty: options.presence_penalty,
314-
tools: toolsForGenerate,
315-
toolChoice: toolChoice,
352+
tools: Object.keys(formattedTools).length > 0 ? formattedTools : undefined,
316353
maxSteps: 3, // Allow multiple steps for tool calls
317354
});
318355

src/services/providers/openai-service.ts

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,17 @@ export class OpenAIService implements AiServiceProvider {
152152

153153
options.stream = false;
154154

155-
return CommonProviderHelper.getChatCompletionByModel(modelInstance, messages, options, streamController, tools, toolChoice);
155+
options.tools = {
156+
...options.tools,
157+
tools
158+
}
159+
160+
options.toolChoice = {
161+
...options.toolChoice,
162+
toolChoice
163+
}
164+
165+
return CommonProviderHelper.getChatCompletionByModel(modelInstance, messages, options, streamController);
156166
}
157167

158168
/**

0 commit comments

Comments
 (0)