Skip to content

Commit a5a3bd6

Browse files
authored
Improves the relevance of codestral completion (#18)
* Improves the relevances of codestral completion * Add a timeout to avoid endless requests * Remove unused dependency * Fetch again if the prompt has changed between the request and the response * lint
1 parent 8043bf9 commit a5a3bd6

File tree

5 files changed

+78
-17
lines changed

5 files changed

+78
-17
lines changed

src/completion-provider.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ export class CompletionProvider implements IInlineCompletionProvider {
1616

1717
constructor(options: CompletionProvider.IOptions) {
1818
const { name, settings } = options;
19+
this._requestCompletion = options.requestCompletion;
1920
this.setCompleter(name, settings);
2021
}
2122

@@ -28,6 +29,9 @@ export class CompletionProvider implements IInlineCompletionProvider {
2829
setCompleter(name: string, settings: ReadonlyPartialJSONObject) {
2930
try {
3031
this._completer = getCompleter(name, settings);
32+
if (this._completer) {
33+
this._completer.requestCompletion = this._requestCompletion;
34+
}
3135
this._name = this._completer === null ? 'None' : name;
3236
} catch (e: any) {
3337
this._completer = null;
@@ -65,11 +69,13 @@ export class CompletionProvider implements IInlineCompletionProvider {
6569
}
6670

6771
private _name: string = 'None';
72+
private _requestCompletion: () => void;
6873
private _completer: IBaseCompleter | null = null;
6974
}
7075

7176
export namespace CompletionProvider {
7277
export interface IOptions extends BaseCompleter.IOptions {
7378
name: string;
79+
requestCompletion: () => void;
7480
}
7581
}

src/index.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,10 @@ const aiProviderPlugin: JupyterFrontEndPlugin<IAIProvider> = {
100100
manager: ICompletionProviderManager,
101101
settingRegistry: ISettingRegistry
102102
): IAIProvider => {
103-
const aiProvider = new AIProvider({ completionProviderManager: manager });
103+
const aiProvider = new AIProvider({
104+
completionProviderManager: manager,
105+
requestCompletion: () => app.commands.execute('inline-completer:invoke')
106+
});
104107

105108
settingRegistry
106109
.load(aiProviderPlugin.id)

src/llm-models/base-completer.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ export interface IBaseCompleter {
1111
*/
1212
provider: LLM;
1313

14+
/**
15+
* The function to fetch a new completion.
16+
*/
17+
requestCompletion?: () => void;
18+
1419
/**
1520
* The fetch request for the LLM completer.
1621
*/

src/llm-models/codestral-completer.ts

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,34 +9,67 @@ import { CompletionRequest } from '@mistralai/mistralai';
99

1010
import { BaseCompleter, IBaseCompleter } from './base-completer';
1111

12-
/*
12+
/**
1313
* The Mistral API has a rate limit of 1 request per second
1414
*/
1515
const INTERVAL = 1000;
1616

17+
/**
18+
* Timeout to avoid endless requests
19+
*/
20+
const REQUEST_TIMEOUT = 3000;
21+
1722
export class CodestralCompleter implements IBaseCompleter {
1823
constructor(options: BaseCompleter.IOptions) {
24+
// this._requestCompletion = options.requestCompletion;
1925
this._mistralProvider = new MistralAI({ ...options.settings });
20-
this._throttler = new Throttler(async (data: CompletionRequest) => {
21-
const response = await this._mistralProvider.completionWithRetry(
22-
data,
23-
{},
24-
false
25-
);
26-
const items = response.choices.map((choice: any) => {
27-
return { insertText: choice.message.content as string };
28-
});
26+
this._throttler = new Throttler(
27+
async (data: CompletionRequest) => {
28+
const invokedData = data;
29+
30+
// Request completion.
31+
const request = this._mistralProvider.completionWithRetry(
32+
data,
33+
{},
34+
false
35+
);
36+
const timeoutPromise = new Promise<null>(resolve => {
37+
return setTimeout(() => resolve(null), REQUEST_TIMEOUT);
38+
});
39+
40+
// Fetch again if the request is too long or if the prompt has changed.
41+
const response = await Promise.race([request, timeoutPromise]);
42+
if (
43+
response === null ||
44+
invokedData.prompt !== this._currentData?.prompt
45+
) {
46+
return {
47+
items: [],
48+
fetchAgain: true
49+
};
50+
}
2951

30-
return {
31-
items
32-
};
33-
}, INTERVAL);
52+
// Extract results of completion request.
53+
const items = response.choices.map((choice: any) => {
54+
return { insertText: choice.message.content as string };
55+
});
56+
57+
return {
58+
items
59+
};
60+
},
61+
{ limit: INTERVAL }
62+
);
3463
}
3564

3665
get provider(): LLM {
3766
return this._mistralProvider;
3867
}
3968

69+
set requestCompletion(value: () => void) {
70+
this._requestCompletion = value;
71+
}
72+
4073
async fetch(
4174
request: CompletionHandler.IRequest,
4275
context: IInlineCompletionContext
@@ -59,13 +92,22 @@ export class CodestralCompleter implements IBaseCompleter {
5992
};
6093

6194
try {
62-
return this._throttler.invoke(data);
95+
this._currentData = data;
96+
const completionResult = await this._throttler.invoke(data);
97+
if (completionResult.fetchAgain) {
98+
if (this._requestCompletion) {
99+
this._requestCompletion();
100+
}
101+
}
102+
return { items: completionResult.items };
63103
} catch (error) {
64104
console.error('Error fetching completions', error);
65105
return { items: [] };
66106
}
67107
}
68108

109+
private _requestCompletion?: () => void;
69110
private _throttler: Throttler;
70111
private _mistralProvider: MistralAI;
112+
private _currentData: CompletionRequest | null = null;
71113
}

src/provider.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ export class AIProvider implements IAIProvider {
1212
constructor(options: AIProvider.IOptions) {
1313
this._completionProvider = new CompletionProvider({
1414
name: 'None',
15-
settings: {}
15+
settings: {},
16+
requestCompletion: options.requestCompletion
1617
});
1718
options.completionProviderManager.registerInlineProvider(
1819
this._completionProvider
@@ -103,6 +104,10 @@ export namespace AIProvider {
103104
* The completion provider manager in which register the LLM completer.
104105
*/
105106
completionProviderManager: ICompletionProviderManager;
107+
/**
108+
* The application commands registry.
109+
*/
110+
requestCompletion: () => void;
106111
}
107112

108113
/**

0 commit comments

Comments
 (0)