Skip to content

Commit 7f8ba57

Browse files
committed
Allow selecting several tools
1 parent 9f272aa commit 7f8ba57

File tree

4 files changed

+29
-42
lines changed

4 files changed

+29
-42
lines changed

src/chat-handler.ts

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,12 @@ export class ChatHandler extends AbstractChatModel {
120120
/**
121121
* Get/set a tool, which will build an agent.
122122
*/
123-
get tool(): Tool | null {
124-
return this._tool;
123+
get tools(): Tool[] {
124+
return this._tools;
125125
}
126-
set tool(value: Tool | null) {
127-
this._tool = value;
128-
this._providerRegistry.buildAgent(this._tool);
126+
set tools(value: Tool[]) {
127+
this._tools = value;
128+
this._providerRegistry.buildAgent(this._tools);
129129
}
130130

131131
/**
@@ -338,7 +338,7 @@ export class ChatHandler extends AbstractChatModel {
338338
private _defaultErrorMessage = 'AI provider not configured';
339339
private _controller: AbortController | null = null;
340340
private _useTool: boolean = false;
341-
private _tool: Tool | null = null;
341+
private _tools: Tool[] = [];
342342
private _toolRegistry?: IToolRegistry;
343343
private _useToolChanged = new Signal<ChatHandler, boolean>(this);
344344
}
@@ -382,11 +382,11 @@ export namespace ChatHandler {
382382
/**
383383
* Getter/setter of the tool to use.
384384
*/
385-
get tool(): Tool | null {
386-
return (this._model as ChatHandler).tool;
385+
get tools(): Tool[] {
386+
return (this._model as ChatHandler).tools;
387387
}
388-
set tool(value: Tool | null) {
389-
(this._model as ChatHandler).tool = value;
388+
set tools(value: Tool[]) {
389+
(this._model as ChatHandler).tools = value;
390390
}
391391
}
392392

src/components/tool-select.tsx

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ export function toolSelect(
2424
const toolRegistry = chatContext.toolsRegistry;
2525

2626
const [useTool, setUseTool] = useState<boolean>(chatContext.useTool);
27-
const [selectedTool, setSelectedTool] = useState<Tool | null>(null);
27+
const [selectedTools, setSelectedTools] = useState<Tool[]>([]);
2828
const [tools, setTools] = useState<Tool[]>(toolRegistry?.tools || []);
2929
const [menuAnchorEl, setMenuAnchorEl] = useState<HTMLElement | null>(null);
3030
const [menuOpen, setMenuOpen] = useState(false);
@@ -38,13 +38,17 @@ export function toolSelect(
3838
setMenuOpen(false);
3939
}, []);
4040

41-
const onClick = useCallback(
42-
(tool: Tool | null) => {
43-
setSelectedTool(tool);
44-
chatContext.tool = tool;
45-
},
46-
[props.model]
47-
);
41+
const onClick = (tool: Tool) => {
42+
const currentTools = [...selectedTools];
43+
const index = currentTools.indexOf(tool);
44+
if (index !== -1) {
45+
currentTools.splice(index, 1);
46+
} else {
47+
currentTools.push(tool);
48+
}
49+
setSelectedTools(currentTools);
50+
chatContext.tools = currentTools;
51+
};
4852

4953
useEffect(() => {
5054
const updateTools = () => setTools(toolRegistry?.tools || []);
@@ -83,7 +87,7 @@ export function toolSelect(
8387
}
8488
}}
8589
sx={
86-
selectedTool === null
90+
selectedTools.length === 0
8791
? { backgroundColor: 'var(--jp-layout-color3)' }
8892
: {}
8993
}
@@ -109,23 +113,6 @@ export function toolSelect(
109113
}
110114
}}
111115
>
112-
<Tooltip title={'Do not use a tool'}>
113-
<MenuItem
114-
className={SELECT_ITEM_CLASS}
115-
onClick={e => {
116-
onClick(null);
117-
// prevent sending second message with no selection
118-
e.stopPropagation();
119-
}}
120-
>
121-
{selectedTool === null ? (
122-
<checkIcon.react className={'lm-Menu-itemIcon'} />
123-
) : (
124-
<div className={'lm-Menu-itemIcon'} />
125-
)}
126-
<Typography display="block">No tool</Typography>
127-
</MenuItem>
128-
</Tooltip>
129116
{tools.map(tool => (
130117
<Tooltip title={tool.description}>
131118
<MenuItem
@@ -136,7 +123,7 @@ export function toolSelect(
136123
e.stopPropagation();
137124
}}
138125
>
139-
{selectedTool === tool ? (
126+
{selectedTools.includes(tool) ? (
140127
<checkIcon.react className={'lm-Menu-itemIcon'} />
141128
) : (
142129
<div className={'lm-Menu-itemIcon'} />

src/provider.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -393,19 +393,19 @@ export class AIProviderRegistry implements IAIProviderRegistry {
393393
/**
394394
* Build an agent with a given tool.
395395
*/
396-
buildAgent(tool: Tool | null) {
397-
if (tool !== null) {
396+
buildAgent(tools: Tool[]) {
397+
if (tools.length) {
398398
const chatModel = Private.getChatModel();
399399
if (chatModel === null) {
400400
Private.setAgent(null);
401401
return;
402402
}
403-
chatModel.bindTools?.([tool], { tool_choice: tool.name });
403+
chatModel.bindTools?.(tools);
404404
Private.setChatModel(chatModel);
405405
Private.setAgent(
406406
createReactAgent({
407407
llm: chatModel,
408-
tools: [tool]
408+
tools
409409
})
410410
);
411411
} else {

src/tokens.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ export interface IAIProviderRegistry {
144144
/**
145145
* Build an agent with a given tool.
146146
*/
147-
buildAgent(tool: Tool | null): void;
147+
buildAgent(tools: Tool[]): void;
148148
/**
149149
* A signal emitting when the provider or its settings has changed.
150150
*/

0 commit comments

Comments
 (0)