Skip to content

Commit f0637eb

Browse files
authored
Making the LLM providers more generic (#10)
* WIP on making the LLM providers more generic * Update changes in settings to the chat and completion LLM * Cleaning and lint * Provides only one completion provider * Rename 'client' to 'provider' and LlmProvider to AIProvider for better readability
1 parent f68be53 commit f0637eb

14 files changed

+416
-153
lines changed

package.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@
6363
"@langchain/core": "^0.3.13",
6464
"@langchain/mistralai": "^0.1.1",
6565
"@lumino/coreutils": "^2.1.2",
66-
"@lumino/polling": "^2.1.2"
66+
"@lumino/polling": "^2.1.2",
67+
"@lumino/signaling": "^2.1.2"
6768
},
6869
"devDependencies": {
6970
"@jupyterlab/builder": "^4.0.0",

schema/ai-provider.json

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
{
2+
"title": "AI provider",
3+
"description": "Provider settings",
4+
"type": "object",
5+
"properties": {
6+
"provider": {
7+
"type": "string",
8+
"title": "The AI provider",
9+
"description": "The AI provider to use for chat and completion",
10+
"default": "None",
11+
"enum": ["None", "MistralAI"]
12+
},
13+
"apiKey": {
14+
"type": "string",
15+
"title": "The Codestral API key",
16+
"description": "The API key to use for Codestral",
17+
"default": ""
18+
}
19+
},
20+
"additionalProperties": false
21+
}

schema/inline-provider.json

Lines changed: 0 additions & 14 deletions
This file was deleted.

src/handler.ts renamed to src/chat-handler.ts

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,30 @@ import {
99
IChatMessage,
1010
INewMessage
1111
} from '@jupyter/chat';
12-
import { UUID } from '@lumino/coreutils';
13-
import type { ChatMistralAI } from '@langchain/mistralai';
12+
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
1413
import {
1514
AIMessage,
1615
HumanMessage,
1716
mergeMessageRuns
1817
} from '@langchain/core/messages';
18+
import { UUID } from '@lumino/coreutils';
1919

2020
export type ConnectionMessage = {
2121
type: 'connection';
2222
client_id: string;
2323
};
2424

25-
export class CodestralHandler extends ChatModel {
26-
constructor(options: CodestralHandler.IOptions) {
25+
export class ChatHandler extends ChatModel {
26+
constructor(options: ChatHandler.IOptions) {
2727
super(options);
28-
this._mistralClient = options.mistralClient;
28+
this._provider = options.provider;
29+
}
30+
31+
get provider(): BaseChatModel | null {
32+
return this._provider;
33+
}
34+
set provider(provider: BaseChatModel | null) {
35+
this._provider = provider;
2936
}
3037

3138
async sendMessage(message: INewMessage): Promise<boolean> {
@@ -38,6 +45,19 @@ export class CodestralHandler extends ChatModel {
3845
type: 'msg'
3946
};
4047
this.messageAdded(msg);
48+
49+
if (this._provider === null) {
50+
const botMsg: IChatMessage = {
51+
id: UUID.uuid4(),
52+
body: '**AI provider not configured for the chat**',
53+
sender: { username: 'ERROR' },
54+
time: Date.now(),
55+
type: 'msg'
56+
};
57+
this.messageAdded(botMsg);
58+
return false;
59+
}
60+
4161
this._history.messages.push(msg);
4262

4363
const messages = mergeMessageRuns(
@@ -48,13 +68,14 @@ export class CodestralHandler extends ChatModel {
4868
return new AIMessage(msg.body);
4969
})
5070
);
51-
const response = await this._mistralClient.invoke(messages);
71+
72+
const response = await this._provider.invoke(messages);
5273
// TODO: fix deprecated response.text
5374
const content = response.text;
5475
const botMsg: IChatMessage = {
5576
id: UUID.uuid4(),
5677
body: content,
57-
sender: { username: 'Codestral' },
78+
sender: { username: 'Bot' },
5879
time: Date.now(),
5980
type: 'msg'
6081
};
@@ -75,12 +96,12 @@ export class CodestralHandler extends ChatModel {
7596
super.messageAdded(message);
7697
}
7798

78-
private _mistralClient: ChatMistralAI;
99+
private _provider: BaseChatModel | null;
79100
private _history: IChatHistory = { messages: [] };
80101
}
81102

82-
export namespace CodestralHandler {
103+
export namespace ChatHandler {
83104
export interface IOptions extends ChatModel.IOptions {
84-
mistralClient: ChatMistralAI;
105+
provider: BaseChatModel | null;
85106
}
86107
}

src/completion-provider.ts

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import {
2+
CompletionHandler,
3+
IInlineCompletionContext,
4+
IInlineCompletionProvider
5+
} from '@jupyterlab/completer';
6+
import { LLM } from '@langchain/core/language_models/llms';
7+
8+
import { getCompleter, IBaseCompleter } from './llm-models';
9+
10+
/**
11+
* The generic completion provider to register to the completion provider manager.
12+
*/
13+
export class CompletionProvider implements IInlineCompletionProvider {
14+
readonly identifier = '@jupyterlite/ai';
15+
16+
constructor(options: CompletionProvider.IOptions) {
17+
this.name = options.name;
18+
}
19+
20+
/**
21+
* Getter and setter of the name.
22+
* The setter will create the appropriate completer, accordingly to the name.
23+
*/
24+
get name(): string {
25+
return this._name;
26+
}
27+
set name(name: string) {
28+
this._name = name;
29+
this._completer = getCompleter(name);
30+
}
31+
32+
/**
33+
* get the current completer.
34+
*/
35+
get completer(): IBaseCompleter | null {
36+
return this._completer;
37+
}
38+
39+
/**
40+
* Get the LLM completer.
41+
*/
42+
get llmCompleter(): LLM | null {
43+
return this._completer?.provider || null;
44+
}
45+
46+
async fetch(
47+
request: CompletionHandler.IRequest,
48+
context: IInlineCompletionContext
49+
) {
50+
return this._completer?.fetch(request, context);
51+
}
52+
53+
private _name: string = 'None';
54+
private _completer: IBaseCompleter | null = null;
55+
}
56+
57+
export namespace CompletionProvider {
58+
export interface IOptions {
59+
name: string;
60+
}
61+
}

src/index.ts

Lines changed: 47 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -13,55 +13,20 @@ import { ICompletionProviderManager } from '@jupyterlab/completer';
1313
import { INotebookTracker } from '@jupyterlab/notebook';
1414
import { IRenderMimeRegistry } from '@jupyterlab/rendermime';
1515
import { ISettingRegistry } from '@jupyterlab/settingregistry';
16-
import { ChatMistralAI, MistralAI } from '@langchain/mistralai';
1716

18-
import { CodestralHandler } from './handler';
19-
import { CodestralProvider } from './provider';
20-
21-
const inlineProviderPlugin: JupyterFrontEndPlugin<void> = {
22-
id: 'jupyterlab-codestral:inline-provider',
23-
autoStart: true,
24-
requires: [ICompletionProviderManager, ISettingRegistry],
25-
activate: (
26-
app: JupyterFrontEnd,
27-
manager: ICompletionProviderManager,
28-
settingRegistry: ISettingRegistry
29-
): void => {
30-
const mistralClient = new MistralAI({
31-
model: 'codestral-latest',
32-
apiKey: 'TMP'
33-
});
34-
const provider = new CodestralProvider({ mistralClient });
35-
manager.registerInlineProvider(provider);
36-
37-
settingRegistry
38-
.load(inlineProviderPlugin.id)
39-
.then(settings => {
40-
const updateKey = () => {
41-
const apiKey = settings.get('apiKey').composite as string;
42-
mistralClient.apiKey = apiKey;
43-
};
44-
45-
settings.changed.connect(() => updateKey());
46-
updateKey();
47-
})
48-
.catch(reason => {
49-
console.error(
50-
`Failed to load settings for ${inlineProviderPlugin.id}`,
51-
reason
52-
);
53-
});
54-
}
55-
};
17+
import { ChatHandler } from './chat-handler';
18+
import { AIProvider } from './provider';
19+
import { IAIProvider } from './token';
5620

5721
const chatPlugin: JupyterFrontEndPlugin<void> = {
5822
id: 'jupyterlab-codestral:chat',
59-
description: 'Codestral chat extension',
23+
description: 'LLM chat extension',
6024
autoStart: true,
6125
optional: [INotebookTracker, ISettingRegistry, IThemeManager],
62-
requires: [IRenderMimeRegistry],
26+
requires: [IAIProvider, IRenderMimeRegistry],
6327
activate: async (
6428
app: JupyterFrontEnd,
29+
aiProvider: IAIProvider,
6530
rmRegistry: IRenderMimeRegistry,
6631
notebookTracker: INotebookTracker | null,
6732
settingsRegistry: ISettingRegistry | null,
@@ -75,15 +40,15 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {
7540
});
7641
}
7742

78-
const mistralClient = new ChatMistralAI({
79-
model: 'codestral-latest',
80-
apiKey: 'TMP'
81-
});
82-
const chatHandler = new CodestralHandler({
83-
mistralClient,
43+
const chatHandler = new ChatHandler({
44+
provider: aiProvider.chatModel,
8445
activeCellManager: activeCellManager
8546
});
8647

48+
aiProvider.modelChange.connect(() => {
49+
chatHandler.provider = aiProvider.chatModel;
50+
});
51+
8752
let sendWithShiftEnter = false;
8853
let enableCodeToolbar = true;
8954

@@ -94,25 +59,6 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {
9459
chatHandler.config = { sendWithShiftEnter, enableCodeToolbar };
9560
}
9661

97-
// TODO: handle the apiKey better
98-
settingsRegistry
99-
?.load(inlineProviderPlugin.id)
100-
.then(settings => {
101-
const updateKey = () => {
102-
const apiKey = settings.get('apiKey').composite as string;
103-
mistralClient.apiKey = apiKey;
104-
};
105-
106-
settings.changed.connect(() => updateKey());
107-
updateKey();
108-
})
109-
.catch(reason => {
110-
console.error(
111-
`Failed to load settings for ${inlineProviderPlugin.id}`,
112-
reason
113-
);
114-
});
115-
11662
Promise.all([app.restored, settingsRegistry?.load(chatPlugin.id)])
11763
.then(([, settings]) => {
11864
if (!settings) {
@@ -148,4 +94,38 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {
14894
}
14995
};
15096

151-
export default [inlineProviderPlugin, chatPlugin];
97+
const aiProviderPlugin: JupyterFrontEndPlugin<IAIProvider> = {
98+
id: 'jupyterlab-codestral:ai-provider',
99+
autoStart: true,
100+
requires: [ICompletionProviderManager, ISettingRegistry],
101+
provides: IAIProvider,
102+
activate: (
103+
app: JupyterFrontEnd,
104+
manager: ICompletionProviderManager,
105+
settingRegistry: ISettingRegistry
106+
): IAIProvider => {
107+
const aiProvider = new AIProvider({ completionProviderManager: manager });
108+
109+
settingRegistry
110+
.load(aiProviderPlugin.id)
111+
.then(settings => {
112+
const updateProvider = () => {
113+
const provider = settings.get('provider').composite as string;
114+
aiProvider.setModels(provider, settings.composite);
115+
};
116+
117+
settings.changed.connect(() => updateProvider());
118+
updateProvider();
119+
})
120+
.catch(reason => {
121+
console.error(
122+
`Failed to load settings for ${aiProviderPlugin.id}`,
123+
reason
124+
);
125+
});
126+
127+
return aiProvider;
128+
}
129+
};
130+
131+
export default [chatPlugin, aiProviderPlugin];

src/llm-models/base-completer.ts

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import {
2+
CompletionHandler,
3+
IInlineCompletionContext
4+
} from '@jupyterlab/completer';
5+
import { LLM } from '@langchain/core/language_models/llms';
6+
7+
export interface IBaseCompleter {
8+
/**
9+
* The LLM completer.
10+
*/
11+
provider: LLM;
12+
13+
/**
14+
* The fetch request for the LLM completer.
15+
*/
16+
fetch(
17+
request: CompletionHandler.IRequest,
18+
context: IInlineCompletionContext
19+
): Promise<any>;
20+
}

0 commit comments

Comments
 (0)