Commit 8043bf9

Refactoring AIProvider and handling errors (#15)

* Better handling of errors with providers
* Add a method to get the error message when catching an error

1 parent f0637eb commit 8043bf9

File tree: 8 files changed, +146 -67 lines

src/chat-handler.ts

Lines changed: 46 additions & 24 deletions

@@ -16,6 +16,8 @@ import {
   mergeMessageRuns
 } from '@langchain/core/messages';
 import { UUID } from '@lumino/coreutils';
+import { getErrorMessage } from './llm-models';
+import { IAIProvider } from './token';
 
 export type ConnectionMessage = {
   type: 'connection';
@@ -25,14 +27,14 @@ export type ConnectionMessage = {
 export class ChatHandler extends ChatModel {
   constructor(options: ChatHandler.IOptions) {
     super(options);
-    this._provider = options.provider;
+    this._aiProvider = options.aiProvider;
+    this._aiProvider.modelChange.connect(() => {
+      this._errorMessage = this._aiProvider.chatError;
+    });
   }
 
   get provider(): BaseChatModel | null {
-    return this._provider;
-  }
-  set provider(provider: BaseChatModel | null) {
-    this._provider = provider;
+    return this._aiProvider.chatModel;
   }
 
   async sendMessage(message: INewMessage): Promise<boolean> {
@@ -46,15 +48,15 @@ export class ChatHandler extends ChatModel {
     };
     this.messageAdded(msg);
 
-    if (this._provider === null) {
-      const botMsg: IChatMessage = {
+    if (this._aiProvider.chatModel === null) {
+      const errorMsg: IChatMessage = {
         id: UUID.uuid4(),
-        body: '**AI provider not configured for the chat**',
+        body: `**${this._errorMessage ? this._errorMessage : this._defaultErrorMessage}**`,
         sender: { username: 'ERROR' },
         time: Date.now(),
         type: 'msg'
       };
-      this.messageAdded(botMsg);
+      this.messageAdded(errorMsg);
       return false;
     }
 
@@ -69,19 +71,37 @@ export class ChatHandler extends ChatModel {
       })
     );
 
-    const response = await this._provider.invoke(messages);
-    // TODO: fix deprecated response.text
-    const content = response.text;
-    const botMsg: IChatMessage = {
-      id: UUID.uuid4(),
-      body: content,
-      sender: { username: 'Bot' },
-      time: Date.now(),
-      type: 'msg'
-    };
-    this.messageAdded(botMsg);
-    this._history.messages.push(botMsg);
-    return true;
+    this.updateWriters([{ username: 'AI' }]);
+    return this._aiProvider.chatModel
+      .invoke(messages)
+      .then(response => {
+        const content = response.content;
+        const botMsg: IChatMessage = {
+          id: UUID.uuid4(),
+          body: content.toString(),
+          sender: { username: 'AI' },
+          time: Date.now(),
+          type: 'msg'
+        };
+        this.messageAdded(botMsg);
+        this._history.messages.push(botMsg);
+        return true;
+      })
+      .catch(reason => {
+        const error = getErrorMessage(this._aiProvider.name, reason);
+        const errorMsg: IChatMessage = {
+          id: UUID.uuid4(),
+          body: `**${error}**`,
+          sender: { username: 'ERROR' },
+          time: Date.now(),
+          type: 'msg'
+        };
+        this.messageAdded(errorMsg);
+        return false;
+      })
+      .finally(() => {
+        this.updateWriters([]);
+      });
   }
 
   async getHistory(): Promise<IChatHistory> {
@@ -96,12 +116,14 @@ export class ChatHandler extends ChatModel {
     super.messageAdded(message);
   }
 
-  private _provider: BaseChatModel | null;
+  private _aiProvider: IAIProvider;
+  private _errorMessage: string = '';
  private _history: IChatHistory = { messages: [] };
+  private _defaultErrorMessage = 'AI provider not configured';
 }
 
 export namespace ChatHandler {
   export interface IOptions extends ChatModel.IOptions {
-    provider: BaseChatModel | null;
+    aiProvider: IAIProvider;
   }
 }
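
The reworked sendMessage wraps the model call in a promise chain so that a rejection becomes an ERROR chat message instead of an unhandled failure, and the typing indicator is cleared either way. A minimal standalone sketch of that pattern, assuming a generic async model call (the invokeModel callback and console output are illustrative stand-ins, not part of the commit):

// Sketch: a rejected promise is converted into an error string (as
// getErrorMessage does in the commit), and the "writers" indicator is
// cleared in all cases by finally().
async function askModel(
  invokeModel: (prompt: string) => Promise<string>, // stand-in for chatModel.invoke
  prompt: string
): Promise<boolean> {
  return invokeModel(prompt)
    .then(content => {
      console.log('AI:', content); // the real code calls this.messageAdded(botMsg)
      return true;
    })
    .catch(reason => {
      const error = reason instanceof Error ? reason.message : String(reason);
      console.log('ERROR:', `**${error}**`); // rendered as a bold chat message
      return false;
    })
    .finally(() => {
      // equivalent of this.updateWriters([]): hide the "AI is writing" state
    });
}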

src/completion-provider.ts

Lines changed: 24 additions & 10 deletions

@@ -5,7 +5,8 @@ import {
 } from '@jupyterlab/completer';
 import { LLM } from '@langchain/core/language_models/llms';
 
-import { getCompleter, IBaseCompleter } from './llm-models';
+import { getCompleter, IBaseCompleter, BaseCompleter } from './llm-models';
+import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
 
 /**
  * The generic completion provider to register to the completion provider manager.
@@ -14,23 +15,36 @@ export class CompletionProvider implements IInlineCompletionProvider {
   readonly identifier = '@jupyterlite/ai';
 
   constructor(options: CompletionProvider.IOptions) {
-    this.name = options.name;
+    const { name, settings } = options;
+    this.setCompleter(name, settings);
   }
 
   /**
-   * Getter and setter of the name.
-   * The setter will create the appropriate completer, accordingly to the name.
+   * Set the completer.
+   *
+   * @param name - the name of the completer.
+   * @param settings - The settings associated to the completer.
+   */
+  setCompleter(name: string, settings: ReadonlyPartialJSONObject) {
+    try {
+      this._completer = getCompleter(name, settings);
+      this._name = this._completer === null ? 'None' : name;
+    } catch (e: any) {
+      this._completer = null;
+      this._name = 'None';
+      throw e;
+    }
+  }
+
+  /**
+   * Get the current completer name.
    */
   get name(): string {
     return this._name;
   }
-  set name(name: string) {
-    this._name = name;
-    this._completer = getCompleter(name);
-  }
 
   /**
-   * get the current completer.
+   * Get the current completer.
    */
   get completer(): IBaseCompleter | null {
     return this._completer;
@@ -55,7 +69,7 @@ export class CompletionProvider implements IInlineCompletionProvider {
 }
 
 export namespace CompletionProvider {
-  export interface IOptions {
+  export interface IOptions extends BaseCompleter.IOptions {
     name: string;
   }
 }
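
The new setCompleter resets the provider to a safe 'None' state before rethrowing, so a misconfigured provider is never left half-initialized while the caller still sees the error. A condensed sketch of that reset-then-rethrow contract, with a hypothetical factory callback in place of getCompleter:

// Reset-then-rethrow: on failure the holder returns to a safe 'None' state
// and the caller (AIProvider.setModels in this commit) records e.message.
class CompleterHolder<T> {
  private _name = 'None';
  private _completer: T | null = null;

  setCompleter(name: string, factory: (name: string) => T | null): void {
    try {
      this._completer = factory(name);
      this._name = this._completer === null ? 'None' : name;
    } catch (e) {
      this._completer = null;
      this._name = 'None';
      throw e;
    }
  }

  get name(): string {
    return this._name;
  }
}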

src/index.ts

Lines changed: 1 addition & 5 deletions

@@ -41,14 +41,10 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {
   }
 
   const chatHandler = new ChatHandler({
-    provider: aiProvider.chatModel,
+    aiProvider: aiProvider,
     activeCellManager: activeCellManager
   });
 
-  aiProvider.modelChange.connect(() => {
-    chatHandler.provider = aiProvider.chatModel;
-  });
-
   let sendWithShiftEnter = false;
   let enableCodeToolbar = true;

src/llm-models/base-completer.ts

Lines changed: 13 additions & 0 deletions

@@ -3,6 +3,7 @@ import {
   IInlineCompletionContext
 } from '@jupyterlab/completer';
 import { LLM } from '@langchain/core/language_models/llms';
+import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
 
 export interface IBaseCompleter {
   /**
@@ -18,3 +19,15 @@ export interface IBaseCompleter {
     context: IInlineCompletionContext
   ): Promise<any>;
 }
+
+/**
+ * The namespace for the base completer.
+ */
+export namespace BaseCompleter {
+  /**
+   * The options for the constructor of a completer.
+   */
+  export interface IOptions {
+    settings: ReadonlyPartialJSONObject;
+  }
+}
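
BaseCompleter.IOptions is the small contract every completer constructor now shares: configuration arrives as a settings object instead of hard-coded values (compare the codestral-completer.ts diff below). A hypothetical completer written against this shape, assuming nothing beyond the interface above:

import { ReadonlyPartialJSONObject } from '@lumino/coreutils';

// Hypothetical completer: provider configuration (apiKey, model, ...) comes
// through options.settings instead of constants in the constructor.
class EchoCompleter {
  constructor(options: { settings: ReadonlyPartialJSONObject }) {
    this._settings = options.settings;
  }

  async fetch(prompt: string): Promise<string> {
    // a real completer would configure its LLM client from this._settings
    return `completion for "${prompt}" (model: ${this._settings.model ?? 'unset'})`;
  }

  private _settings: ReadonlyPartialJSONObject;
}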

src/llm-models/codestral-completer.ts

Lines changed: 3 additions & 6 deletions

@@ -7,19 +7,16 @@ import { MistralAI } from '@langchain/mistralai';
 import { Throttler } from '@lumino/polling';
 import { CompletionRequest } from '@mistralai/mistralai';
 
-import { IBaseCompleter } from './base-completer';
+import { BaseCompleter, IBaseCompleter } from './base-completer';
 
 /*
  * The Mistral API has a rate limit of 1 request per second
 */
 const INTERVAL = 1000;
 
 export class CodestralCompleter implements IBaseCompleter {
-  constructor() {
-    this._mistralProvider = new MistralAI({
-      apiKey: 'TMP',
-      model: 'codestral-latest'
-    });
+  constructor(options: BaseCompleter.IOptions) {
+    this._mistralProvider = new MistralAI({ ...options.settings });
     this._throttler = new Throttler(async (data: CompletionRequest) => {
       const response = await this._mistralProvider.completionWithRetry(
         data,

src/llm-models/utils.ts

Lines changed: 21 additions & 4 deletions

@@ -2,23 +2,40 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import { ChatMistralAI } from '@langchain/mistralai';
 import { IBaseCompleter } from './base-completer';
 import { CodestralCompleter } from './codestral-completer';
+import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
 
 /**
  * Get an LLM completer from the name.
  */
-export function getCompleter(name: string): IBaseCompleter | null {
+export function getCompleter(
+  name: string,
+  settings: ReadonlyPartialJSONObject
+): IBaseCompleter | null {
   if (name === 'MistralAI') {
-    return new CodestralCompleter();
+    return new CodestralCompleter({ settings });
   }
   return null;
 }
 
 /**
  * Get an LLM chat model from the name.
 */
-export function getChatModel(name: string): BaseChatModel | null {
+export function getChatModel(
+  name: string,
+  settings: ReadonlyPartialJSONObject
+): BaseChatModel | null {
   if (name === 'MistralAI') {
-    return new ChatMistralAI({ apiKey: 'TMP' });
+    return new ChatMistralAI({ ...settings });
   }
   return null;
 }
+
+/**
+ * Get the error message from provider.
+ */
+export function getErrorMessage(name: string, error: any): string {
+  if (name === 'MistralAI') {
+    return error.message;
+  }
+  return 'Unknown provider';
+}
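
The three factories dispatch on the provider name, returning null for unknown names rather than throwing. A short usage sketch inside the extension's source tree; the settings values are illustrative only (real values come from the JupyterLab settings system):

import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
import { getChatModel, getErrorMessage } from './llm-models';

const settings: ReadonlyPartialJSONObject = {
  apiKey: 'my-api-key', // placeholder
  model: 'codestral-latest'
};

const chatModel = getChatModel('MistralAI', settings);
if (chatModel === null) {
  console.log('Unknown provider name: no chat model created');
}

// When a later call rejects, the provider-specific message is extracted:
// const text = getErrorMessage('MistralAI', reason);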

src/provider.ts

Lines changed: 35 additions & 17 deletions

@@ -10,7 +10,10 @@ import { IAIProvider } from './token';
 
 export class AIProvider implements IAIProvider {
   constructor(options: AIProvider.IOptions) {
-    this._completionProvider = new CompletionProvider({ name: 'None' });
+    this._completionProvider = new CompletionProvider({
+      name: 'None',
+      settings: {}
+    });
     options.completionProviderManager.registerInlineProvider(
       this._completionProvider
     );
@@ -21,7 +24,7 @@ export class AIProvider implements IAIProvider {
   }
 
   /**
-   * get the current completer of the completion provider.
+   * Get the current completer of the completion provider.
    */
   get completer(): IBaseCompleter | null {
     if (this._name === null) {
@@ -31,7 +34,7 @@ export class AIProvider implements IAIProvider {
   }
 
   /**
-   * get the current llm chat model.
+   * Get the current llm chat model.
   */
   get chatModel(): BaseChatModel | null {
     if (this._name === null) {
@@ -40,6 +43,20 @@ export class AIProvider implements IAIProvider {
     return this._llmChatModel;
   }
 
+  /**
+   * Get the current chat error;
+   */
+  get chatError(): string {
+    return this._chatError;
+  }
+
+  /**
+   * get the current completer error.
+   */
+  get completerError(): string {
+    return this._completerError;
+  }
+
   /**
    * Set the models (chat model and completer).
    * Creates the models if the name has changed, otherwise only updates their config.
@@ -48,22 +65,21 @@ export class AIProvider implements IAIProvider {
   * @param settings - the settings for the models.
   */
  setModels(name: string, settings: ReadonlyPartialJSONObject) {
-    if (name !== this._name) {
-      this._name = name;
-      this._completionProvider.name = name;
-      this._llmChatModel = getChatModel(name);
-      this._modelChange.emit();
+    try {
+      this._completionProvider.setCompleter(name, settings);
+      this._completerError = '';
+    } catch (e: any) {
+      this._completerError = e.message;
     }
-
-    // Update the inline completion provider settings.
-    if (this._completionProvider.llmCompleter) {
-      AIProvider.updateConfig(this._completionProvider.llmCompleter, settings);
-    }
-
-    // Update the chat LLM settings.
-    if (this._llmChatModel) {
-      AIProvider.updateConfig(this._llmChatModel, settings);
+    try {
+      this._llmChatModel = getChatModel(name, settings);
+      this._chatError = '';
+    } catch (e: any) {
+      this._chatError = e.message;
+      this._llmChatModel = null;
     }
+    this._name = name;
+    this._modelChange.emit();
  }
 
  get modelChange(): ISignal<IAIProvider, void> {
@@ -74,6 +90,8 @@ export class AIProvider implements IAIProvider {
  private _llmChatModel: BaseChatModel | null = null;
  private _name: string = 'None';
  private _modelChange = new Signal<IAIProvider, void>(this);
+  private _chatError: string = '';
+  private _completerError: string = '';
 }
 
 export namespace AIProvider {
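
setModels now creates the completer and the chat model in independent try/catch blocks, so a failure in one does not prevent the other from updating, and both error strings are in place before modelChange is emitted. A condensed sketch of that strategy, with hypothetical factory callbacks standing in for setCompleter and getChatModel:

// Independent error capture: each model is attempted separately, and
// listeners read the resulting error strings after the change signal fires.
function setModelsSketch(
  createCompleter: () => void,
  createChatModel: () => void
): { completerError: string; chatError: string } {
  let completerError = '';
  let chatError = '';
  try {
    createCompleter();
  } catch (e: any) {
    completerError = e.message;
  }
  try {
    createChatModel();
  } catch (e: any) {
    chatError = e.message;
  }
  // in the commit, this._modelChange.emit() happens here
  return { completerError, chatError };
}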

src/token.ts

Lines changed: 3 additions & 1 deletion

@@ -5,10 +5,12 @@ import { ISignal } from '@lumino/signaling';
 import { IBaseCompleter } from './llm-models';
 
 export interface IAIProvider {
-  name: string | null;
+  name: string;
   completer: IBaseCompleter | null;
   chatModel: BaseChatModel | null;
   modelChange: ISignal<IAIProvider, void>;
+  chatError: string;
+  completerError: string;
 }
 
 export const IAIProvider = new Token<IAIProvider>(
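
With chatError and completerError on the token's interface, any plugin that requires IAIProvider can surface provider failures. A hypothetical consumer sketch (the plugin id and console reporting are illustrative, not part of the commit):

import { JupyterFrontEnd, JupyterFrontEndPlugin } from '@jupyterlab/application';
import { IAIProvider } from './token';

const errorReporter: JupyterFrontEndPlugin<void> = {
  id: '@jupyterlite/ai:error-reporter', // illustrative id
  autoStart: true,
  requires: [IAIProvider],
  activate: (_app: JupyterFrontEnd, aiProvider: IAIProvider) => {
    // Errors are refreshed by setModels before modelChange is emitted.
    aiProvider.modelChange.connect(() => {
      if (aiProvider.chatError) {
        console.warn(`Chat model error: ${aiProvider.chatError}`);
      }
      if (aiProvider.completerError) {
        console.warn(`Completer error: ${aiProvider.completerError}`);
      }
    });
  }
};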
