Commit 1b482ad

Merge pull request #19 from brichet/openAI
Add OpenAI provider
2 parents: 6f14cfb + 83465be

6 files changed: +112 -25 lines changed

package.json

Lines changed: 1 addition & 0 deletions

@@ -65,6 +65,7 @@
     "@langchain/community": "^0.3.31",
     "@langchain/core": "^0.3.40",
     "@langchain/mistralai": "^0.1.1",
+    "@langchain/openai": "^0.4.4",
     "@lumino/coreutils": "^2.1.2",
     "@lumino/polling": "^2.1.2",
     "@lumino/signaling": "^2.1.2",

schema/ai-provider.json

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@
       "title": "The AI provider",
       "description": "The AI provider to use for chat and completion",
       "default": "None",
-      "enum": ["None", "Anthropic", "ChromeAI", "MistralAI"]
+      "enum": ["None", "Anthropic", "ChromeAI", "MistralAI", "OpenAI"]
     }
   },
   "additionalProperties": true

scripts/settings-generator.js

Lines changed: 26 additions & 18 deletions

@@ -43,6 +43,11 @@ const providers = {
     path: 'node_modules/@langchain/anthropic/dist/chat_models.d.ts',
     type: 'AnthropicInput',
     excludedProps: ['clientOptions']
+  },
+  openAI: {
+    path: 'node_modules/@langchain/openai/dist/chat_models.d.ts',
+    type: 'ChatOpenAIFields',
+    excludedProps: ['configuration']
   }
 };

@@ -53,7 +58,8 @@ Object.entries(providers).forEach(([name, desc], index) => {
     path: desc.path,
     tsconfig: './tsconfig.json',
     type: desc.type,
-    functions: 'hide'
+    functions: 'hide',
+    topRef: false
   };

   const outputPath = path.join(outputDir, `${name}.json`);

@@ -81,33 +87,35 @@ Object.entries(providers).forEach(([name, desc], index) => {
   }

   // Remove the properties from extended class.
-  const providerKeys = Object.keys(schema.definitions[desc.type]['properties']);
-
+  const providerKeys = Object.keys(schema.properties);
   Object.keys(
     schemaBase.definitions?.['BaseLanguageModelParams']['properties']
   ).forEach(key => {
     if (providerKeys.includes(key)) {
-      delete schema.definitions?.[desc.type]['properties'][key];
+      delete schema.properties?.[key];
     }
   });

-  // Remove the useless definitions.
-  let change = true;
-  while (change) {
-    change = false;
-    const temporarySchemaString = JSON.stringify(schema);
-
-    Object.keys(schema.definitions).forEach(key => {
-      const index = temporarySchemaString.indexOf(`#/definitions/${key}`);
-      if (index === -1) {
-        delete schema.definitions?.[key];
-        change = true;
-      }
-    });
+  // Replace all references by their values, and remove the useless definitions.
+  const defKeys = Object.keys(schema.definitions);
+  for (let i = defKeys.length - 1; i >= 0; i--) {
+    let schemaString = JSON.stringify(schema);
+    const key = defKeys[i];
+    const reference = `"$ref":"#/definitions/${key}"`;
+
+    // Replace all references to the definition by its content (after removing the braces).
+    const replacement = JSON.stringify(schema.definitions?.[key]).slice(1, -1);
+    const temporarySchemaString = schemaString.replaceAll(reference, replacement);
+    // Rebuild the schema from the string representation if it changed.
+    if (schemaString !== temporarySchemaString) {
+      schema = JSON.parse(temporarySchemaString);
+    }
+    // Remove the definition.
+    delete schema.definitions?.[key];
   }

   // Transform the default values.
-  Object.values(schema.definitions[desc.type]['properties']).forEach(value => {
+  Object.values(schema.properties).forEach(value => {
     const defaultValue = value.default;
     if (!defaultValue) {
       return;
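
The rewritten generator step above no longer just prunes unused definitions: it splices every "#/definitions/<key>" reference in place and then deletes the definitions map, so each emitted provider schema is flat and can be read as openAI.properties (which is what the utils.ts change below relies on). A minimal stand-alone sketch of that ref-inlining idea, with an illustrative schema shape rather than the script's real inputs:

// Illustrative sketch, not the script itself: inline each "$ref" by splicing
// the referenced definition's body in place, then drop the definitions map.
type Schema = { definitions?: Record<string, object>; [key: string]: unknown };

function inlineDefinitions(schema: Schema): Schema {
  let text = JSON.stringify(schema);
  for (const [key, def] of Object.entries(schema.definitions ?? {})) {
    const ref = `"$ref":"#/definitions/${key}"`;
    // Strip the outer braces so the body replaces the "$ref" pair in place.
    const body = JSON.stringify(def).slice(1, -1);
    text = text.replaceAll(ref, body);
  }
  const flat: Schema = JSON.parse(text);
  delete flat.definitions;
  return flat;
}

Like the script, this resolves nested references only as far as the iteration order allows; treat it as a sketch of the transformation, not a drop-in replacement.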

src/llm-models/openai-completer.ts

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+import {
+  CompletionHandler,
+  IInlineCompletionContext
+} from '@jupyterlab/completer';
+import { BaseChatModel } from '@langchain/core/language_models/chat_models';
+import { AIMessage, SystemMessage } from '@langchain/core/messages';
+import { ChatOpenAI } from '@langchain/openai';
+
+import { BaseCompleter, IBaseCompleter } from './base-completer';
+import { COMPLETION_SYSTEM_PROMPT } from '../provider';
+
+export class OpenAICompleter implements IBaseCompleter {
+  constructor(options: BaseCompleter.IOptions) {
+    this._openAIProvider = new ChatOpenAI({ ...options.settings });
+  }
+
+  get provider(): BaseChatModel {
+    return this._openAIProvider;
+  }
+
+  /**
+   * Getter and setter for the initial prompt.
+   */
+  get prompt(): string {
+    return this._prompt;
+  }
+  set prompt(value: string) {
+    this._prompt = value;
+  }
+
+  async fetch(
+    request: CompletionHandler.IRequest,
+    context: IInlineCompletionContext
+  ) {
+    const { text, offset: cursorOffset } = request;
+    const prompt = text.slice(0, cursorOffset);
+
+    const messages = [new SystemMessage(this._prompt), new AIMessage(prompt)];
+
+    try {
+      const response = await this._openAIProvider.invoke(messages);
+      const items = [];
+      if (typeof response.content === 'string') {
+        items.push({
+          insertText: response.content
+        });
+      } else {
+        response.content.forEach(content => {
+          if (content.type !== 'text') {
+            return;
+          }
+          items.push({
+            insertText: content.text,
+            filterText: prompt.substring(prompt.length)
+          });
+        });
+      }
+      return { items };
+    } catch (error) {
+      console.error('Error fetching completions', error);
+      return { items: [] };
+    }
+  }
+
+  private _openAIProvider: ChatOpenAI;
+  private _prompt: string = COMPLETION_SYSTEM_PROMPT;
+}
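
For orientation, a hypothetical direct use of the new completer outside JupyterLab's inline-completion machinery. The settings keys shown (model, apiKey) are ordinary ChatOpenAI options chosen for illustration and are not dictated by this diff, and the completion context argument is stubbed since fetch above does not read it:

import { OpenAICompleter } from './openai-completer';

// Assumes BaseCompleter.IOptions carries the raw provider settings, as the
// constructor above suggests; the keys are illustrative ChatOpenAI options.
const completer = new OpenAICompleter({
  settings: { model: 'gpt-4o-mini', apiKey: 'sk-...' }
});

// CompletionHandler.IRequest: the cell text plus the cursor offset within it.
const request = { text: 'def fibonacci(n):\n    ', offset: 22 };

// The IInlineCompletionContext is normally supplied by JupyterLab; stubbed here.
const reply = await completer.fetch(request, {} as any);
reply.items.forEach(item => console.log(item.insertText));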

src/llm-models/utils.ts

Lines changed: 15 additions & 5 deletions

@@ -2,17 +2,19 @@ import { ChatAnthropic } from '@langchain/anthropic';
 import { ChromeAI } from '@langchain/community/experimental/llms/chrome_ai';
 import { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import { ChatMistralAI } from '@langchain/mistralai';
-import { JSONObject } from '@lumino/coreutils';
+import { ChatOpenAI } from '@langchain/openai';

 import { IBaseCompleter } from './base-completer';
 import { AnthropicCompleter } from './anthropic-completer';
 import { CodestralCompleter } from './codestral-completer';
 import { ReadonlyPartialJSONObject } from '@lumino/coreutils';
 import { ChromeCompleter } from './chrome-completer';
+import { OpenAICompleter } from './openai-completer';

 import chromeAI from '../_provider-settings/chromeAI.json';
 import mistralAI from '../_provider-settings/mistralAI.json';
 import anthropic from '../_provider-settings/anthropic.json';
+import openAI from '../_provider-settings/openAI.json';

 /**
  * Get an LLM completer from the name.

@@ -27,6 +29,8 @@ export function getCompleter(
     return new AnthropicCompleter({ settings });
   } else if (name === 'ChromeAI') {
     return new ChromeCompleter({ settings });
+  } else if (name === 'OpenAI') {
+    return new OpenAICompleter({ settings });
   }
   return null;
 }

@@ -46,6 +50,8 @@ export function getChatModel(
     // TODO: fix
     // @ts-expect-error: missing properties
     return new ChromeAI({ ...settings });
+  } else if (name === 'OpenAI') {
+    return new ChatOpenAI({ ...settings });
   }
   return null;
 }

@@ -60,20 +66,24 @@ export function getErrorMessage(name: string, error: any): string {
     return error.error.error.message;
   } else if (name === 'ChromeAI') {
     return error.message;
+  } else if (name === 'OpenAI') {
+    return error.message;
   }
   return 'Unknown provider';
 }

 /*
  * Get an LLM completer from the name.
  */
-export function getSettings(name: string): JSONObject | null {
+export function getSettings(name: string): any {
   if (name === 'MistralAI') {
-    return mistralAI.definitions.ChatMistralAIInput.properties;
+    return mistralAI.properties;
   } else if (name === 'Anthropic') {
-    return anthropic.definitions.AnthropicInput.properties;
+    return anthropic.properties;
   } else if (name === 'ChromeAI') {
-    return chromeAI.definitions.ChromeAIInputs.properties;
+    return chromeAI.properties;
+  } else if (name === 'OpenAI') {
+    return openAI.properties;
   }

   return null;
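
Putting the utils.ts changes together, a hedged sketch of how an extension-side caller might resolve the new provider name. The relative import path, the (name, settings) parameter order, and the settings keys are assumptions here, since the full signatures sit outside this diff:

import { getChatModel, getCompleter, getSettings } from './llm-models/utils';

const provider = 'OpenAI';

// Settings-form fields now come straight from the flattened schema (openAI.properties).
const fields = getSettings(provider);

// Illustrative user settings; the real keys are whatever the generated schema exposes.
const settings = { model: 'gpt-4o-mini', apiKey: 'sk-...' };

// Both the chat model and the inline completer branch on the 'OpenAI' name.
const chatModel = getChatModel(provider, settings);
const completer = getCompleter(provider, settings);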

yarn.lock

Lines changed: 2 additions & 1 deletion

@@ -1969,6 +1969,7 @@ __metadata:
     "@langchain/community": ^0.3.31
     "@langchain/core": ^0.3.40
     "@langchain/mistralai": ^0.1.1
+    "@langchain/openai": ^0.4.4
     "@lumino/coreutils": ^2.1.2
     "@lumino/polling": ^2.1.2
     "@lumino/signaling": ^2.1.2

@@ -2449,7 +2450,7 @@
   languageName: node
   linkType: hard

-"@langchain/openai@npm:>=0.2.0 <0.5.0":
+"@langchain/openai@npm:>=0.2.0 <0.5.0, @langchain/openai@npm:^0.4.4":
   version: 0.4.4
   resolution: "@langchain/openai@npm:0.4.4"
   dependencies:
