diff --git a/.eslintignore b/.eslintignore index 6f8fb52..834d381 100644 --- a/.eslintignore +++ b/.eslintignore @@ -5,3 +5,4 @@ coverage tests **/__tests__ ui-tests +webpack.config.js diff --git a/package.json b/package.json index 4180e95..77ae803 100644 --- a/package.json +++ b/package.json @@ -69,6 +69,7 @@ "@langchain/community": "^0.3.48", "@langchain/core": "^0.3.62", "@langchain/google-genai": "^0.2.14", + "@langchain/langgraph": "^0.3.5", "@langchain/mistralai": "^0.2.1", "@langchain/ollama": "^0.2.3", "@langchain/openai": "^0.5.16", @@ -127,6 +128,7 @@ "jupyterlab": { "extension": true, "outputDir": "jupyterlite_ai/labextension", - "schemaDir": "schema" + "schemaDir": "schema", + "webpackConfig": "./webpack.config.js" } } diff --git a/schema/provider-registry.json b/schema/provider-registry.json index ad5cfe9..a58e244 100644 --- a/schema/provider-registry.json +++ b/schema/provider-registry.json @@ -11,6 +11,12 @@ "description": "Whether to use or not the secrets manager. If not, secrets will be stored in the browser (local storage)", "default": true }, + "AllowToolsUsage": { + "type": "boolean", + "title": "Allow tools usage", + "description": "Whether tools are available in chat or not", + "default": true + }, "UniqueProvider": { "type": "boolean", "title": "Use the same provider for chat and completer", diff --git a/src/chat-handler.ts b/src/chat-handler.ts index 4d10fe5..a3d0faf 100644 --- a/src/chat-handler.ts +++ b/src/chat-handler.ts @@ -13,10 +13,12 @@ import { IChatMessage, IChatModel, IInputModel, - INewMessage + INewMessage, + IUser } from '@jupyter/chat'; import { AIMessage, + BaseMessage, HumanMessage, mergeMessageRuns, SystemMessage @@ -25,7 +27,7 @@ import { UUID } from '@lumino/coreutils'; import { DEFAULT_CHAT_SYSTEM_PROMPT } from './default-prompts'; import { jupyternautLiteIcon } from './icons'; -import { IAIProviderRegistry } from './tokens'; +import { IAIProviderRegistry, IToolRegistry } from './tokens'; import { AIChatModel } from './types/ai-model'; /** @@ -56,13 +58,38 @@ export class ChatHandler extends AbstractChatModel { constructor(options: ChatHandler.IOptions) { super(options); this._providerRegistry = options.providerRegistry; + this._toolRegistry = options.toolRegistry; this._providerRegistry.providerChanged.connect(() => { this._errorMessage = this._providerRegistry.chatError; }); } - get provider(): AIChatModel | null { + /** + * The provider registry. + */ + get providerRegistry(): IAIProviderRegistry { + return this._providerRegistry; + } + + /** + * Get the tool registry. + */ + get toolRegistry(): IToolRegistry | undefined { + return this._toolRegistry; + } + + /** + * Get the agent from the provider registry. + */ + get agent(): AIChatModel | null { + return this._providerRegistry.currentAgent; + } + + /** + * Get the chat model from the provider registry. + */ + get chatModel(): AIChatModel | null { return this._providerRegistry.currentChatModel; } @@ -84,12 +111,15 @@ export class ChatHandler extends AbstractChatModel { } /** - * Get/set the system prompt for the chat. + * Get the system prompt for the chat. */ get systemPrompt(): string { - return ( - this._providerRegistry.chatSystemPrompt ?? DEFAULT_CHAT_SYSTEM_PROMPT - ); + let prompt = + this._providerRegistry.chatSystemPrompt ?? DEFAULT_CHAT_SYSTEM_PROMPT; + if (this.agent !== null) { + prompt = prompt.concat('\nPlease use the tool that is provided'); + } + return prompt; } async sendMessage(message: INewMessage): Promise { @@ -110,7 +140,9 @@ export class ChatHandler extends AbstractChatModel { }; this.messageAdded(msg); - if (this._providerRegistry.currentChatModel === null) { + const chatModel = this.chatModel; + + if (chatModel === null) { const errorMsg: IChatMessage = { id: UUID.uuid4(), body: `**${this._errorMessage ? this._errorMessage : this._defaultErrorMessage}**`, @@ -137,7 +169,39 @@ export class ChatHandler extends AbstractChatModel { const sender = { username: this._personaName, avatar_url: AI_AVATAR }; this.updateWriters([{ user: sender }]); - // create an empty message to be filled by the AI provider + if (this.agent !== null) { + return this._sendAgentMessage(this.agent, messages, sender); + } + + return this._sentChatMessage(chatModel, messages, sender); + } + + async getHistory(): Promise { + return this._history; + } + + dispose(): void { + super.dispose(); + } + + messageAdded(message: IChatMessage): void { + super.messageAdded(message); + } + + stopStreaming(): void { + this._controller?.abort(); + } + + createChatContext(): IChatContext { + return new ChatHandler.ChatContext({ model: this }); + } + + private async _sentChatMessage( + chatModel: AIChatModel, + messages: BaseMessage[], + sender: IUser + ): Promise { + // Create an empty message to be filled by the AI provider const botMsg: IChatMessage = { id: UUID.uuid4(), body: '', @@ -145,15 +209,12 @@ export class ChatHandler extends AbstractChatModel { time: Private.getTimestampMs(), type: 'msg' }; - let content = ''; - this._controller = new AbortController(); try { - for await (const chunk of await this._providerRegistry.currentChatModel.stream( - messages, - { signal: this._controller.signal } - )) { + for await (const chunk of await chatModel.stream(messages, { + signal: this._controller.signal + })) { content += chunk.content ?? chunk; botMsg.body = content; this.messageAdded(botMsg); @@ -177,24 +238,72 @@ export class ChatHandler extends AbstractChatModel { } } - async getHistory(): Promise { - return this._history; - } - - dispose(): void { - super.dispose(); - } - - messageAdded(message: IChatMessage): void { - super.messageAdded(message); - } - - stopStreaming(): void { - this._controller?.abort(); - } - - createChatContext(): IChatContext { - return new ChatHandler.ChatContext({ model: this }); + private async _sendAgentMessage( + agent: AIChatModel, + messages: BaseMessage[], + sender: IUser + ): Promise { + this._controller = new AbortController(); + try { + for await (const chunk of await agent.stream( + { messages }, + { + streamMode: 'updates', + signal: this._controller.signal + } + )) { + if ((chunk as any).agent) { + messages = (chunk as any).agent.messages; + messages.forEach(message => { + const contents: string[] = []; + if (typeof message.content === 'string') { + contents.push(message.content); + } else if (Array.isArray(message.content)) { + message.content.forEach(content => { + if (content.type === 'text') { + contents.push(content.text); + } + }); + } + contents.forEach(content => { + this.messageAdded({ + id: UUID.uuid4(), + body: content, + sender, + time: Private.getTimestampMs(), + type: 'msg' + }); + }); + }); + } else if ((chunk as any).tools) { + messages = (chunk as any).tools.messages; + messages.forEach(message => { + this.messageAdded({ + id: UUID.uuid4(), + body: message.content as string, + sender: { username: `Tool "${message.name}"` }, + time: Private.getTimestampMs(), + type: 'msg' + }); + }); + } + } + return true; + } catch (reason) { + const error = this._providerRegistry.formatErrorMessage(reason); + const errorMsg: IChatMessage = { + id: UUID.uuid4(), + body: `**${error}**`, + sender: { username: 'ERROR' }, + time: Private.getTimestampMs(), + type: 'msg' + }; + this.messageAdded(errorMsg); + return false; + } finally { + this.updateWriters([]); + this._controller = null; + } } private _providerRegistry: IAIProviderRegistry; @@ -203,6 +312,7 @@ export class ChatHandler extends AbstractChatModel { private _history: IChatHistory = { messages: [] }; private _defaultErrorMessage = 'AI provider not configured'; private _controller: AbortController | null = null; + private _toolRegistry?: IToolRegistry; } export namespace ChatHandler { @@ -211,13 +321,28 @@ export namespace ChatHandler { */ export interface IOptions extends IChatModel.IOptions { providerRegistry: IAIProviderRegistry; + toolRegistry?: IToolRegistry; } /** - * The minimal chat context. + * The chat context. */ export class ChatContext extends AbstractChatContext { users = []; + + /** + * The provider registry. + */ + get providerRegistry(): IAIProviderRegistry { + return (this._model as ChatHandler).providerRegistry; + } + + /** + * The tool registry. + */ + get toolsRegistry(): IToolRegistry | undefined { + return (this._model as ChatHandler).toolRegistry; + } } /** diff --git a/src/components/index.ts b/src/components/index.ts new file mode 100644 index 0000000..8a1ebb9 --- /dev/null +++ b/src/components/index.ts @@ -0,0 +1,7 @@ +/* + * Copyright (c) Jupyter Development Team. + * Distributed under the terms of the Modified BSD License. + */ + +export * from './stop-button'; +export * from './tool-select'; diff --git a/src/components/tool-select.tsx b/src/components/tool-select.tsx new file mode 100644 index 0000000..ab4d54b --- /dev/null +++ b/src/components/tool-select.tsx @@ -0,0 +1,160 @@ +/* + * Copyright (c) Jupyter Development Team. + * Distributed under the terms of the Modified BSD License. + */ + +import { InputToolbarRegistry, TooltippedButton } from '@jupyter/chat'; +import { checkIcon } from '@jupyterlab/ui-components'; +import BuildIcon from '@mui/icons-material/Build'; +import { Menu, MenuItem, Tooltip, Typography } from '@mui/material'; +import React, { useCallback, useEffect, useState } from 'react'; + +import { ChatHandler } from '../chat-handler'; +import { IAIProviderRegistry, Tool } from '../tokens'; + +const SELECT_ITEM_CLASS = 'jp-AIToolSelect-item'; + +/** + * The tool select component. + */ +export function toolSelect( + props: InputToolbarRegistry.IToolbarItemProps +): JSX.Element { + const chatContext = props.model.chatContext as ChatHandler.ChatContext; + const toolRegistry = chatContext.toolsRegistry; + const providerRegistry = chatContext.providerRegistry; + + const [allowTools, setAllowTools] = useState(true); + const [agentAvailable, setAgentAvailable] = useState(); + const [selectedTools, setSelectedTools] = useState([]); + const [tools, setTools] = useState(toolRegistry?.tools || []); + const [menuAnchorEl, setMenuAnchorEl] = useState(null); + const [menuOpen, setMenuOpen] = useState(false); + + const openMenu = useCallback((el: HTMLElement | null) => { + setMenuAnchorEl(el); + setMenuOpen(true); + }, []); + + const closeMenu = useCallback(() => { + setMenuOpen(false); + }, []); + + const onClick = (tool: Tool) => { + const currentTools = [...selectedTools]; + const index = currentTools.indexOf(tool); + if (index !== -1) { + currentTools.splice(index, 1); + } else { + currentTools.push(tool); + } + setSelectedTools(currentTools); + if (!providerRegistry.setTools(currentTools)) { + setSelectedTools([]); + } + }; + + useEffect(() => { + const updateTools = () => setTools(toolRegistry?.tools || []); + toolRegistry?.toolsChanged.connect(updateTools); + return () => { + toolRegistry?.toolsChanged.disconnect(updateTools); + }; + }, [toolRegistry]); + + useEffect(() => { + const updateAllowTools = (_: IAIProviderRegistry, value: boolean) => + setAllowTools(value); + + const updateAgentAvailable = () => + setAgentAvailable(providerRegistry.isAgentAvailable()); + + providerRegistry.allowToolsChanged.connect(updateAllowTools); + providerRegistry.providerChanged.connect(updateAgentAvailable); + + setAllowTools(providerRegistry.allowTools); + setAgentAvailable(providerRegistry.isAgentAvailable()); + return () => { + providerRegistry.allowToolsChanged.disconnect(updateAllowTools); + providerRegistry.providerChanged.disconnect(updateAgentAvailable); + }; + }, [providerRegistry]); + + return allowTools && tools.length ? ( + <> + { + openMenu(e.currentTarget); + }} + disabled={!agentAvailable} + tooltip={ + agentAvailable === undefined + ? 'The provider is not set' + : agentAvailable + ? 'Tools' + : 'The provider or model cannot use tools' + } + buttonProps={{ + variant: 'contained', + onKeyDown: e => { + if (e.key !== 'Enter' && e.key !== ' ') { + return; + } + openMenu(e.currentTarget); + // stopping propagation of this event prevents the prompt from being + // sent when the dropdown button is selected and clicked via 'Enter'. + e.stopPropagation(); + } + }} + sx={ + selectedTools.length === 0 + ? { backgroundColor: 'var(--jp-layout-color3)' } + : {} + } + > + + + + {tools.map(tool => ( + + { + onClick(tool); + // prevent sending second message with no selection + e.stopPropagation(); + }} + > + {selectedTools.includes(tool) ? ( + + ) : ( +
+ )} + {tool.name} + + + ))} +
+ + ) : ( + <> + ); +} diff --git a/src/index.ts b/src/index.ts index 06e870d..142fc51 100644 --- a/src/index.ts +++ b/src/index.ts @@ -23,11 +23,13 @@ import { ISecretsManager, SecretsManager } from 'jupyter-secrets-manager'; import { ChatHandler, welcomeMessage } from './chat-handler'; import { CompletionProvider } from './completion-provider'; +import { stopItem, toolSelect } from './components'; import { defaultProviderPlugins } from './default-providers'; import { AIProviderRegistry } from './provider'; import { aiSettingsRenderer, textArea } from './settings'; -import { IAIProviderRegistry, PLUGIN_IDS } from './tokens'; -import { stopItem } from './components/stop-button'; +import { IAIProviderRegistry, IToolRegistry, PLUGIN_IDS } from './tokens'; +import { ToolsRegistry } from './tool-registry'; +import { createNotebook } from './tools/create-notebook'; const chatCommandRegistryPlugin: JupyterFrontEndPlugin = { id: PLUGIN_IDS.chatCommandRegistry, @@ -50,7 +52,8 @@ const chatPlugin: JupyterFrontEndPlugin = { INotebookTracker, ISettingRegistry, IThemeManager, - ILayoutRestorer + ILayoutRestorer, + IToolRegistry ], activate: async ( app: JupyterFrontEnd, @@ -60,7 +63,8 @@ const chatPlugin: JupyterFrontEndPlugin = { notebookTracker: INotebookTracker | null, settingsRegistry: ISettingRegistry | null, themeManager: IThemeManager | null, - restorer: ILayoutRestorer | null + restorer: ILayoutRestorer | null, + toolRegistry?: IToolRegistry ) => { let activeCellManager: IActiveCellManager | null = null; if (notebookTracker) { @@ -72,7 +76,8 @@ const chatPlugin: JupyterFrontEndPlugin = { const chatHandler = new ChatHandler({ providerRegistry, - activeCellManager + activeCellManager, + toolRegistry }); let sendWithShiftEnter = false; @@ -113,6 +118,9 @@ const chatPlugin: JupyterFrontEndPlugin = { const stopButton = stopItem(() => chatHandler.stopStreaming()); inputToolbarRegistry.addItem('stop', stopButton); + // Add the tool select item. + inputToolbarRegistry.addItem('tools', { element: toolSelect, position: 1 }); + chatHandler.writersChanged.connect((_, writers) => { if ( writers.filter( @@ -197,15 +205,23 @@ const providerRegistryPlugin: JupyterFrontEndPlugin = }) ); + let allowToolsUsage = true; + settingRegistry .load(providerRegistryPlugin.id) .then(settings => { if (!secretsManager) { delete settings.schema.properties?.['UseSecretsManager']; } - const updateProvider = () => { + + const loadSetting = (setting: ISettingRegistry.ISettings) => { + // Allowing usage of tools in the chat. + allowToolsUsage = + (setting.get('AllowToolsUsage').composite as boolean) ?? false; + providerRegistry.allowTools = allowToolsUsage; + // Get the Ai provider settings. - const providerSettings = settings.get('AIproviders') + const providerSettings = setting.get('AIproviders') .composite as ReadonlyPartialJSONObject; // Update completer provider. @@ -227,8 +243,8 @@ const providerRegistryPlugin: JupyterFrontEndPlugin = } }; - settings.changed.connect(() => updateProvider()); - updateProvider(); + settings.changed.connect(loadSetting); + loadSetting(settings); }) .catch(reason => { console.error( @@ -295,12 +311,24 @@ const systemPromptsPlugin: JupyterFrontEndPlugin = { } }; +const toolRegistryPlugin: JupyterFrontEndPlugin = { + id: PLUGIN_IDS.toolRegistry, + autoStart: true, + provides: IToolRegistry, + activate: (app: JupyterFrontEnd): IToolRegistry => { + const registry = new ToolsRegistry(); + registry.add(createNotebook(app.commands)); + return registry; + } +}; + export default [ providerRegistryPlugin, chatCommandRegistryPlugin, chatPlugin, completerPlugin, systemPromptsPlugin, + toolRegistryPlugin, ...defaultProviderPlugins ]; diff --git a/src/provider.ts b/src/provider.ts index 7e6b372..84fa4f5 100644 --- a/src/provider.ts +++ b/src/provider.ts @@ -4,6 +4,8 @@ import { IInlineCompletionContext } from '@jupyterlab/completer'; import { BaseLanguageModel } from '@langchain/core/language_models/base'; +import { CompiledStateGraph } from '@langchain/langgraph'; +import { createReactAgent } from '@langchain/langgraph/prebuilt'; import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import { ReadonlyPartialJSONObject } from '@lumino/coreutils'; import { Debouncer } from '@lumino/polling'; @@ -18,7 +20,8 @@ import { IAIProviderRegistry, IDict, ModelRole, - PLUGIN_IDS + PLUGIN_IDS, + Tool } from './tokens'; import { AIChatModel, AICompleter } from './types/ai-model'; @@ -31,6 +34,7 @@ export class AIProviderRegistry implements IAIProviderRegistry { */ constructor(options: AIProviderRegistry.IOptions) { this._secretsManager = options.secretsManager || null; + this._allowTools = true; Private.setToken(options.token); this._notifications = { @@ -130,6 +134,19 @@ export class AIProviderRegistry implements IAIProviderRegistry { }; } + /** + * Get the current agent. + */ + get currentAgent(): AIChatModel | null { + const agent = Private.getAgent(); + if (agent === null) { + return null; + } + return { + stream: (input: any, options?: any) => agent.stream(input, options) + }; + } + /** * Getter/setter for the chat system prompt. */ @@ -143,6 +160,16 @@ export class AIProviderRegistry implements IAIProviderRegistry { this._chatPrompt = value; } + /** + * Check if we can add tools to the chat model to build an agent. + */ + isAgentAvailable(): boolean | undefined { + if (Private.getChatModel() === null) { + return; + } + return Private.getChatModel()?.bindTools !== undefined; + } + /** * Get the settings schema of a given provider. */ @@ -327,6 +354,11 @@ export class AIProviderRegistry implements IAIProviderRegistry { ...fullSettings }) ); + if (this.isAgentAvailable() && this._allowTools) { + if (this._tools.length) { + this._buildAgent(); + } + } } catch (e: any) { this.chatError = e.message; Private.setChatModel(null); @@ -338,6 +370,38 @@ export class AIProviderRegistry implements IAIProviderRegistry { this._providerChanged.emit('chat'); } + /** + * Allowing the usage of tools from settings. + */ + get allowTools(): boolean { + return this._allowTools; + } + set allowTools(value: boolean) { + if (this._allowTools !== value) { + this._allowTools = value; + this._allowToolsChanged.emit(value); + } + } + + /** + * Set the tools to use with the chat. + */ + setTools(tools: Tool[]): boolean { + if (!this.isAgentAvailable()) { + this._tools = []; + return false; + } + this._tools = tools; + return this._buildAgent(); + } + + /** + * A signal triggered when the setting on tool usage has changed. + */ + get allowToolsChanged(): ISignal { + return this._allowToolsChanged; + } + /** * A signal emitting when the provider or its settings has changed. */ @@ -372,6 +436,32 @@ export class AIProviderRegistry implements IAIProviderRegistry { return fullSettings; } + /** + * Build an agent with given tools. + */ + private _buildAgent(): boolean { + console.log('Build Agent'); + if (this._tools.length) { + const chatModel = Private.getChatModel(); + if (chatModel === null || chatModel.bindTools === undefined) { + Private.setAgent(null); + this._tools = []; + return false; + } + chatModel.bindTools?.(this._tools); + Private.setChatModel(chatModel); + Private.setAgent( + createReactAgent({ + llm: chatModel, + tools: this._tools + }) + ); + } else { + Private.setAgent(null); + } + return true; + } + private _secretsManager: ISecretsManager | null; private _providerChanged = new Signal(this); private _chatError: string = ''; @@ -387,6 +477,9 @@ export class AIProviderRegistry implements IAIProviderRegistry { }; private _chatPrompt: string = ''; private _completerPrompt: string = ''; + private _allowTools: boolean; + private _allowToolsChanged = new Signal(this); + private _tools: Tool[] = []; } export namespace AIProviderRegistry { @@ -511,4 +604,15 @@ namespace Private { export function getCompleter(): IBaseCompleter | null { return completer; } + + /** + * The agent getter and setter. + */ + let agent: CompiledStateGraph | null = null; + export function setAgent(value: CompiledStateGraph | null): void { + agent = value; + } + export function getAgent(): CompiledStateGraph | null { + return agent; + } } diff --git a/src/tokens.ts b/src/tokens.ts index 367bfe5..9d02124 100644 --- a/src/tokens.ts +++ b/src/tokens.ts @@ -1,4 +1,5 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models'; +import { StructuredToolInterface } from '@langchain/core/tools'; import { ReadonlyPartialJSONObject, Token } from '@lumino/coreutils'; import { ISignal } from '@lumino/signaling'; import { JSONSchema7 } from 'json-schema'; @@ -12,7 +13,8 @@ export const PLUGIN_IDS = { completer: '@jupyterlite/ai:completer', providerRegistry: '@jupyterlite/ai:provider-registry', settingsConnector: '@jupyterlite/ai:settings-connector', - systemPrompts: '@jupyterlite/ai:system-prompts' + systemPrompts: '@jupyterlite/ai:system-prompts', + toolRegistry: '@jupyterlite/ai:tool-registry' }; export type ModelRole = 'chat' | 'completer'; @@ -90,7 +92,7 @@ export interface IAIProviderRegistry { /** * Get the current completer of the completion provider. */ - currentCompleter: AICompleter | null; + readonly currentCompleter: AICompleter | null; /** * Getter/setter for the completer system prompt. */ @@ -98,11 +100,19 @@ export interface IAIProviderRegistry { /** * Get the current llm chat model. */ - currentChatModel: AIChatModel | null; + readonly currentChatModel: AIChatModel | null; + /** + * Get the current agent. + */ + readonly currentAgent: AIChatModel | null; /** * Getter/setter for the chat system prompt. */ chatSystemPrompt: string; + /** + * Check if tools can be added to the chat model, to build an agent. + */ + isAgentAvailable(): boolean | undefined; /** * Get the settings schema of a given provider. */ @@ -135,6 +145,18 @@ export interface IAIProviderRegistry { * @param options - An object with the name and the settings of the provider to use. */ setChatProvider(settings: ReadonlyPartialJSONObject): void; + /** + * Allowing the usage of tools from settings. + */ + allowTools: boolean; + /** + * Set the tools to use with the chat. + */ + setTools(tools: Tool[]): boolean; + /** + * A signal triggered when the ability to use tools changed. + */ + readonly allowToolsChanged: ISignal; /** * A signal emitting when the provider or its settings has changed. */ @@ -149,6 +171,35 @@ export interface IAIProviderRegistry { readonly completerError: string; } +/** + * The type describing a tool used in langgraph. + */ +export type Tool = StructuredToolInterface; + +/** + * The tool registry interface. + */ +export interface IToolRegistry { + /** + * The registered tools. + */ + readonly tools: Tool[]; + /** + * A signal triggered when the tools has changed; + */ + readonly toolsChanged: ISignal; + /** + * Add a new tool to the registry. + */ + add(provider: Tool): void; + /** + * Get a tool for a given name. + * Return null if the name is not provided or if there is no registered tool with the + * given name. + */ + get(name: string | null): Tool | null; +} + /** * The provider registry token. */ @@ -156,3 +207,11 @@ export const IAIProviderRegistry = new Token( '@jupyterlite/ai:provider-registry', 'Provider for chat and completion LLM provider' ); + +/** + * The tool registry token. + */ +export const IToolRegistry = new Token( + '@jupyterlite/ai:tool-registry', + 'Tool registry for AI agent' +); diff --git a/src/tool-registry.ts b/src/tool-registry.ts new file mode 100644 index 0000000..25a696f --- /dev/null +++ b/src/tool-registry.ts @@ -0,0 +1,44 @@ +import { ISignal, Signal } from '@lumino/signaling'; +import { IToolRegistry, Tool } from './tokens'; + +export class ToolsRegistry implements IToolRegistry { + /** + * The registered tools. + */ + get tools(): Tool[] { + return this._tools; + } + + /** + * A signal triggered when the tools has changed. + */ + get toolsChanged(): ISignal { + return this._toolsChanged; + } + + /** + * Add a new tool to the registry. + */ + add(tool: Tool): void { + const index = this._tools.findIndex(t => t.name === tool.name); + if (index === -1) { + this._tools.push(tool); + this._toolsChanged.emit(); + } + } + + /** + * Get a tool for a given name. + * Return null if the name is not provided or if there is no registered tool with the + * given name. + */ + get(name: string | null): Tool | null { + if (name === null) { + return null; + } + return this._tools.find(t => t.name === name) || null; + } + + private _tools: Tool[] = []; + private _toolsChanged = new Signal(this); +} diff --git a/src/tools/create-notebook.ts b/src/tools/create-notebook.ts new file mode 100644 index 0000000..7a8fec9 --- /dev/null +++ b/src/tools/create-notebook.ts @@ -0,0 +1,37 @@ +import { StructuredToolInterface, tool } from '@langchain/core/tools'; +import { CommandRegistry } from '@lumino/commands'; +import { ReadonlyPartialJSONObject } from '@lumino/coreutils'; +import { z } from 'zod'; + +export const createNotebook = ( + commands: CommandRegistry +): StructuredToolInterface => { + return tool( + async ({ command, args }) => { + let result: any = 'No command called'; + if (command === 'notebook:create-new') { + result = await commands.execute( + command, + args as ReadonlyPartialJSONObject + ); + } + const output = ` +The test tool has been called, with the following query: "${command}" +The args for the commands where ${JSON.stringify(args)} +The result of the command (if called) is "${result}" +`; + return output; + }, + { + name: 'createNotebook', + description: 'Run jupyterlab command to create a notebook', + schema: z.object({ + command: z.string().describe('The Jupyterlab command id to execute'), + args: z + .object({}) + .passthrough() + .describe('The argument for the command') + }) + } + ); +}; diff --git a/style/base.css b/style/base.css index c4925ee..03c1c32 100644 --- a/style/base.css +++ b/style/base.css @@ -20,6 +20,11 @@ min-height: 300px; } +.jp-AIToolSelect-item .lm-Menu-itemIcon { + display: flex; + align-items: center; +} + .jp-chat-welcome-message { text-align: center; max-width: 350px; diff --git a/webpack.config.js b/webpack.config.js new file mode 100644 index 0000000..e8031e3 --- /dev/null +++ b/webpack.config.js @@ -0,0 +1,9 @@ +module.exports = { + // Ignore source map warnings for @langchain/langgraph + ignoreWarnings: [ + { + module: /node_modules\/@langchain\/langgraph/ + }, + /Failed to parse source map.*@langchain/ + ] +}; diff --git a/yarn.lock b/yarn.lock index fcca9db..7d4bb46 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1577,6 +1577,7 @@ __metadata: "@langchain/community": ^0.3.48 "@langchain/core": ^0.3.62 "@langchain/google-genai": ^0.2.14 + "@langchain/langgraph": ^0.3.5 "@langchain/mistralai": ^0.2.1 "@langchain/ollama": ^0.2.3 "@langchain/openai": ^0.5.16 @@ -2063,6 +2064,58 @@ __metadata: languageName: node linkType: hard +"@langchain/langgraph-checkpoint@npm:~0.0.18": + version: 0.0.18 + resolution: "@langchain/langgraph-checkpoint@npm:0.0.18" + dependencies: + uuid: ^10.0.0 + peerDependencies: + "@langchain/core": ">=0.2.31 <0.4.0" + checksum: bd7ba56696bbfde0be8f2cef2f0dadf30826732cbb1d9bf92439f314a01219053c50ff6b624ef58a9d6723d6e1a8b872e725abc4d95e4ea9610a41c5f1e41047 + languageName: node + linkType: hard + +"@langchain/langgraph-sdk@npm:~0.0.90": + version: 0.0.92 + resolution: "@langchain/langgraph-sdk@npm:0.0.92" + dependencies: + "@types/json-schema": ^7.0.15 + p-queue: ^6.6.2 + p-retry: 4 + uuid: ^9.0.0 + peerDependencies: + "@langchain/core": ">=0.2.31 <0.4.0" + react: ^18 || ^19 + react-dom: ^18 || ^19 + peerDependenciesMeta: + "@langchain/core": + optional: true + react: + optional: true + react-dom: + optional: true + checksum: 905380f0785da27dd5638b37e6215f3f2655cd0cb9111f6178e4b72085e930f1e250dfbfa6bb3b0ce4b7dc7f129f988095fac2240b420c09f63d65c3a8d5b8d6 + languageName: node + linkType: hard + +"@langchain/langgraph@npm:^0.3.5": + version: 0.3.7 + resolution: "@langchain/langgraph@npm:0.3.7" + dependencies: + "@langchain/langgraph-checkpoint": ~0.0.18 + "@langchain/langgraph-sdk": ~0.0.90 + uuid: ^10.0.0 + zod: ^3.25.32 + peerDependencies: + "@langchain/core": ">=0.3.58 < 0.4.0" + zod-to-json-schema: ^3.x + peerDependenciesMeta: + zod-to-json-schema: + optional: true + checksum: 6a924940d92d4c0c97c18665a1d1f70bfa83c2d5f8b221ea6fe90d3d47bfa6518a6032058bb9b722199ce0c9fb511e9fe2818e2d877ffb7c43510715c1b82d84 + languageName: node + linkType: hard + "@langchain/mistralai@npm:^0.2.1": version: 0.2.1 resolution: "@langchain/mistralai@npm:0.2.1" @@ -8307,7 +8360,7 @@ __metadata: languageName: node linkType: hard -"uuid@npm:^9.0.1": +"uuid@npm:^9.0.0, uuid@npm:^9.0.1": version: 9.0.1 resolution: "uuid@npm:9.0.1" bin: