Skip to content

Use langgraph (agent) to call tools #110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .eslintignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ coverage
tests
**/__tests__
ui-tests
webpack.config.js
4 changes: 3 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
"@langchain/community": "^0.3.48",
"@langchain/core": "^0.3.62",
"@langchain/google-genai": "^0.2.14",
"@langchain/langgraph": "^0.3.5",
"@langchain/mistralai": "^0.2.1",
"@langchain/ollama": "^0.2.3",
"@langchain/openai": "^0.5.16",
Expand Down Expand Up @@ -127,6 +128,7 @@
"jupyterlab": {
"extension": true,
"outputDir": "jupyterlite_ai/labextension",
"schemaDir": "schema"
"schemaDir": "schema",
"webpackConfig": "./webpack.config.js"
}
}
6 changes: 6 additions & 0 deletions schema/provider-registry.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
"description": "Whether to use or not the secrets manager. If not, secrets will be stored in the browser (local storage)",
"default": true
},
"AllowToolsUsage": {
"type": "boolean",
"title": "Allow tools usage",
"description": "Whether tools are available in chat or not",
"default": true
},
"UniqueProvider": {
"type": "boolean",
"title": "Use the same provider for chat and completer",
Expand Down
193 changes: 159 additions & 34 deletions src/chat-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ import {
IChatMessage,
IChatModel,
IInputModel,
INewMessage
INewMessage,
IUser
} from '@jupyter/chat';
import {
AIMessage,
BaseMessage,
HumanMessage,
mergeMessageRuns,
SystemMessage
Expand All @@ -25,7 +27,7 @@ import { UUID } from '@lumino/coreutils';

import { DEFAULT_CHAT_SYSTEM_PROMPT } from './default-prompts';
import { jupyternautLiteIcon } from './icons';
import { IAIProviderRegistry } from './tokens';
import { IAIProviderRegistry, IToolRegistry } from './tokens';
import { AIChatModel } from './types/ai-model';

/**
Expand Down Expand Up @@ -56,13 +58,38 @@ export class ChatHandler extends AbstractChatModel {
constructor(options: ChatHandler.IOptions) {
super(options);
this._providerRegistry = options.providerRegistry;
this._toolRegistry = options.toolRegistry;

this._providerRegistry.providerChanged.connect(() => {
this._errorMessage = this._providerRegistry.chatError;
});
}

get provider(): AIChatModel | null {
/**
* The provider registry.
*/
get providerRegistry(): IAIProviderRegistry {
return this._providerRegistry;
}

/**
* Get the tool registry.
*/
get toolRegistry(): IToolRegistry | undefined {
return this._toolRegistry;
}

/**
* Get the agent from the provider registry.
*/
get agent(): AIChatModel | null {
return this._providerRegistry.currentAgent;
}

/**
* Get the chat model from the provider registry.
*/
get chatModel(): AIChatModel | null {
return this._providerRegistry.currentChatModel;
}

Expand All @@ -84,12 +111,15 @@ export class ChatHandler extends AbstractChatModel {
}

/**
* Get/set the system prompt for the chat.
* Get the system prompt for the chat.
*/
get systemPrompt(): string {
return (
this._providerRegistry.chatSystemPrompt ?? DEFAULT_CHAT_SYSTEM_PROMPT
);
let prompt =
this._providerRegistry.chatSystemPrompt ?? DEFAULT_CHAT_SYSTEM_PROMPT;
if (this.agent !== null) {
prompt = prompt.concat('\nPlease use the tool that is provided');
}
return prompt;
}

async sendMessage(message: INewMessage): Promise<boolean> {
Expand All @@ -110,7 +140,9 @@ export class ChatHandler extends AbstractChatModel {
};
this.messageAdded(msg);

if (this._providerRegistry.currentChatModel === null) {
const chatModel = this.chatModel;

if (chatModel === null) {
const errorMsg: IChatMessage = {
id: UUID.uuid4(),
body: `**${this._errorMessage ? this._errorMessage : this._defaultErrorMessage}**`,
Expand All @@ -137,23 +169,52 @@ export class ChatHandler extends AbstractChatModel {
const sender = { username: this._personaName, avatar_url: AI_AVATAR };
this.updateWriters([{ user: sender }]);

// create an empty message to be filled by the AI provider
if (this.agent !== null) {
return this._sendAgentMessage(this.agent, messages, sender);
}

return this._sentChatMessage(chatModel, messages, sender);
}

async getHistory(): Promise<IChatHistory> {
return this._history;
}

dispose(): void {
super.dispose();
}

messageAdded(message: IChatMessage): void {
super.messageAdded(message);
}

stopStreaming(): void {
this._controller?.abort();
}

createChatContext(): IChatContext {
return new ChatHandler.ChatContext({ model: this });
}

private async _sentChatMessage(
chatModel: AIChatModel,
messages: BaseMessage[],
sender: IUser
): Promise<boolean> {
// Create an empty message to be filled by the AI provider
const botMsg: IChatMessage = {
id: UUID.uuid4(),
body: '',
sender,
time: Private.getTimestampMs(),
type: 'msg'
};

let content = '';

this._controller = new AbortController();
try {
for await (const chunk of await this._providerRegistry.currentChatModel.stream(
messages,
{ signal: this._controller.signal }
)) {
for await (const chunk of await chatModel.stream(messages, {
signal: this._controller.signal
})) {
content += chunk.content ?? chunk;
botMsg.body = content;
this.messageAdded(botMsg);
Expand All @@ -177,24 +238,72 @@ export class ChatHandler extends AbstractChatModel {
}
}

async getHistory(): Promise<IChatHistory> {
return this._history;
}

dispose(): void {
super.dispose();
}

messageAdded(message: IChatMessage): void {
super.messageAdded(message);
}

stopStreaming(): void {
this._controller?.abort();
}

createChatContext(): IChatContext {
return new ChatHandler.ChatContext({ model: this });
private async _sendAgentMessage(
agent: AIChatModel,
messages: BaseMessage[],
sender: IUser
): Promise<boolean> {
this._controller = new AbortController();
try {
for await (const chunk of await agent.stream(
{ messages },
{
streamMode: 'updates',
signal: this._controller.signal
}
)) {
if ((chunk as any).agent) {
messages = (chunk as any).agent.messages;
messages.forEach(message => {
const contents: string[] = [];
if (typeof message.content === 'string') {
contents.push(message.content);
} else if (Array.isArray(message.content)) {
message.content.forEach(content => {
if (content.type === 'text') {
contents.push(content.text);
}
});
}
contents.forEach(content => {
this.messageAdded({
id: UUID.uuid4(),
body: content,
sender,
time: Private.getTimestampMs(),
type: 'msg'
});
});
});
} else if ((chunk as any).tools) {
messages = (chunk as any).tools.messages;
messages.forEach(message => {
this.messageAdded({
id: UUID.uuid4(),
body: message.content as string,
sender: { username: `Tool "${message.name}"` },
time: Private.getTimestampMs(),
type: 'msg'
});
});
}
}
return true;
} catch (reason) {
const error = this._providerRegistry.formatErrorMessage(reason);
const errorMsg: IChatMessage = {
id: UUID.uuid4(),
body: `**${error}**`,
sender: { username: 'ERROR' },
time: Private.getTimestampMs(),
type: 'msg'
};
this.messageAdded(errorMsg);
return false;
} finally {
this.updateWriters([]);
this._controller = null;
}
}

private _providerRegistry: IAIProviderRegistry;
Expand All @@ -203,6 +312,7 @@ export class ChatHandler extends AbstractChatModel {
private _history: IChatHistory = { messages: [] };
private _defaultErrorMessage = 'AI provider not configured';
private _controller: AbortController | null = null;
private _toolRegistry?: IToolRegistry;
}

export namespace ChatHandler {
Expand All @@ -211,13 +321,28 @@ export namespace ChatHandler {
*/
export interface IOptions extends IChatModel.IOptions {
providerRegistry: IAIProviderRegistry;
toolRegistry?: IToolRegistry;
}

/**
* The minimal chat context.
* The chat context.
*/
export class ChatContext extends AbstractChatContext {
users = [];

/**
* The provider registry.
*/
get providerRegistry(): IAIProviderRegistry {
return (this._model as ChatHandler).providerRegistry;
}

/**
* The tool registry.
*/
get toolsRegistry(): IToolRegistry | undefined {
return (this._model as ChatHandler).toolRegistry;
}
}

/**
Expand Down
7 changes: 7 additions & 0 deletions src/components/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
/*
* Copyright (c) Jupyter Development Team.
* Distributed under the terms of the Modified BSD License.
*/

export * from './stop-button';
export * from './tool-select';
Loading
Loading