Skip to content

Commit 50dff08

Browse files
feat: add OpenAI provider with structured output support (#28)
* feat: add OpenAI provider with structured output support Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: add @ai-sdk/openai dependency and fix modelConfigs access Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: correct indentation in agent.ts Co-Authored-By: Han Xiao <han.xiao@jina.ai> * refactor: centralize model initialization in config.ts Co-Authored-By: Han Xiao <han.xiao@jina.ai> * refactor: improve model config access patterns Co-Authored-By: Han Xiao <han.xiao@jina.ai> * fix: remove unused imports Co-Authored-By: Han Xiao <han.xiao@jina.ai> * refactor: clean up --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: Han Xiao <han.xiao@jina.ai>
1 parent f1c7ada commit 50dff08

15 files changed

+271
-100
lines changed

Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ COPY . .
1515

1616
# Set environment variables
1717
ENV GEMINI_API_KEY=${GEMINI_API_KEY}
18+
ENV OPENAI_API_KEY=${OPENAI_API_KEY}
1819
ENV JINA_API_KEY=${JINA_API_KEY}
1920
ENV BRAVE_API_KEY=${BRAVE_API_KEY}
2021

README.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,7 @@ flowchart LR
2525

2626
## Install
2727

28-
We use gemini for llm, [jina reader](https://jina.ai/reader) for searching and reading webpages.
29-
3028
```bash
31-
export GEMINI_API_KEY=... # for gemini api, ask han
32-
export JINA_API_KEY=jina_... # free jina api key, get from https://jina.ai/reader
33-
3429
git clone https://github.com/jina-ai/node-DeepResearch.git
3530
cd node-DeepResearch
3631
npm install
@@ -39,7 +34,14 @@ npm install
3934

4035
## Usage
4136

37+
We use Gemini/OpenAI for reasoning and [Jina Reader](https://jina.ai/reader) for searching and reading webpages; you can get a free API key with 1M tokens from jina.ai.
38+
4239
```bash
40+
export GEMINI_API_KEY=... # for gemini
41+
# export OPENAI_API_KEY=... # for openai
42+
# export LLM_PROVIDER=openai # for openai
43+
export JINA_API_KEY=jina_... # free jina api key, get from https://jina.ai/reader
44+
4345
npm run dev $QUERY
4446
```
4547

docker-compose.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ services:
77
dockerfile: Dockerfile
88
environment:
99
- GEMINI_API_KEY=${GEMINI_API_KEY}
10+
- OPENAI_API_KEY=${OPENAI_API_KEY}
1011
- JINA_API_KEY=${JINA_API_KEY}
1112
- BRAVE_API_KEY=${BRAVE_API_KEY}
1213
ports:

package-lock.json

Lines changed: 20 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"description": "",
2020
"dependencies": {
2121
"@ai-sdk/google": "^1.0.0",
22+
"@ai-sdk/openai": "^1.1.9",
2223
"@types/cors": "^2.8.17",
2324
"@types/express": "^5.0.0",
2425
"@types/node-fetch": "^2.6.12",

src/agent.ts

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import {createGoogleGenerativeAI} from '@ai-sdk/google';
21
import {z} from 'zod';
32
import {generateObject} from 'ai';
3+
import {getModel, getMaxTokens, SEARCH_PROVIDER, STEP_SLEEP} from "./config";
44
import {readUrl} from "./tools/read";
55
import {handleGenerateObjectError} from './utils/error-handling';
66
import fs from 'fs/promises';
@@ -10,7 +10,6 @@ import {rewriteQuery} from "./tools/query-rewriter";
1010
import {dedupQueries} from "./tools/dedup";
1111
import {evaluateAnswer} from "./tools/evaluator";
1212
import {analyzeSteps} from "./tools/error-analyzer";
13-
import {SEARCH_PROVIDER, STEP_SLEEP, modelConfigs} from "./config";
1413
import {TokenTracker} from "./utils/token-tracker";
1514
import {ActionTracker} from "./utils/action-tracker";
1615
import {StepAction, AnswerAction} from "./types";
@@ -325,15 +324,15 @@ export async function getResponse(question: string, tokenBudget: number = 1_000_
325324
false
326325
);
327326

328-
const model = createGoogleGenerativeAI({apiKey: process.env.GEMINI_API_KEY})(modelConfigs.agent.model);
327+
const model = getModel('agent');
329328
let object;
330329
let totalTokens = 0;
331330
try {
332331
const result = await generateObject({
333332
model,
334333
schema: getSchema(allowReflect, allowRead, allowAnswer, allowSearch),
335334
prompt,
336-
maxTokens: modelConfigs.agent.maxTokens
335+
maxTokens: getMaxTokens('agent')
337336
});
338337
object = result.object;
339338
totalTokens = result.usage?.totalTokens || 0;
@@ -671,15 +670,15 @@ You decided to think out of the box or cut from a completely different angle.`);
671670
true
672671
);
673672

674-
const model = createGoogleGenerativeAI({apiKey: process.env.GEMINI_API_KEY})(modelConfigs.agentBeastMode.model);
673+
const model = getModel('agentBeastMode');
675674
let object;
676675
let totalTokens = 0;
677676
try {
678677
const result = await generateObject({
679678
model,
680679
schema: getSchema(false, false, allowAnswer, false),
681680
prompt,
682-
maxTokens: modelConfigs.agentBeastMode.maxTokens
681+
maxTokens: getMaxTokens('agentBeastMode')
683682
});
684683
object = result.object;
685684
totalTokens = result.usage?.totalTokens || 0;

src/config.ts

Lines changed: 93 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,35 @@
11
import dotenv from 'dotenv';
22
import { ProxyAgent, setGlobalDispatcher } from 'undici';
3+
import { createGoogleGenerativeAI } from '@ai-sdk/google';
4+
import { createOpenAI } from '@ai-sdk/openai';
35

4-
interface ModelConfig {
6+
export type LLMProvider = 'openai' | 'gemini';
7+
export type ToolName = keyof ToolConfigs;
8+
9+
function isValidProvider(provider: string): provider is LLMProvider {
10+
return provider === 'openai' || provider === 'gemini';
11+
}
12+
13+
function validateModelConfig(config: ModelConfig, toolName: string): ModelConfig {
14+
if (typeof config.model !== 'string' || config.model.length === 0) {
15+
throw new Error(`Invalid model name for ${toolName}`);
16+
}
17+
if (typeof config.temperature !== 'number' || config.temperature < 0 || config.temperature > 1) {
18+
throw new Error(`Invalid temperature for ${toolName}`);
19+
}
20+
if (typeof config.maxTokens !== 'number' || config.maxTokens <= 0) {
21+
throw new Error(`Invalid maxTokens for ${toolName}`);
22+
}
23+
return config;
24+
}
25+
26+
export interface ModelConfig {
527
model: string;
628
temperature: number;
729
maxTokens: number;
830
}
931

10-
interface ToolConfigs {
32+
export interface ToolConfigs {
1133
dedup: ModelConfig;
1234
evaluator: ModelConfig;
1335
errorAnalyzer: ModelConfig;
@@ -31,44 +53,87 @@ if (process.env.https_proxy) {
3153
}
3254

3355
export const GEMINI_API_KEY = process.env.GEMINI_API_KEY as string;
56+
export const OPENAI_API_KEY = process.env.OPENAI_API_KEY as string;
3457
export const JINA_API_KEY = process.env.JINA_API_KEY as string;
3558
export const BRAVE_API_KEY = process.env.BRAVE_API_KEY as string;
36-
export const SEARCH_PROVIDER: 'brave' | 'jina' | 'duck' = 'jina'
59+
export const SEARCH_PROVIDER: 'brave' | 'jina' | 'duck' = 'jina';
60+
export const LLM_PROVIDER: LLMProvider = (() => {
61+
const provider = process.env.LLM_PROVIDER || 'gemini';
62+
if (!isValidProvider(provider)) {
63+
throw new Error(`Invalid LLM provider: ${provider}`);
64+
}
65+
return provider;
66+
})();
3767

38-
const DEFAULT_MODEL = 'gemini-1.5-flash';
68+
const DEFAULT_GEMINI_MODEL = 'gemini-1.5-flash';
69+
const DEFAULT_OPENAI_MODEL = 'gpt-4o-mini';
3970

40-
const defaultConfig: ModelConfig = {
41-
model: DEFAULT_MODEL,
71+
const defaultGeminiConfig: ModelConfig = {
72+
model: DEFAULT_GEMINI_MODEL,
4273
temperature: 0,
4374
maxTokens: 1000
4475
};
4576

46-
export const modelConfigs: ToolConfigs = {
47-
dedup: {
48-
...defaultConfig,
49-
temperature: 0.1
50-
},
51-
evaluator: {
52-
...defaultConfig
53-
},
54-
errorAnalyzer: {
55-
...defaultConfig
56-
},
57-
queryRewriter: {
58-
...defaultConfig,
59-
temperature: 0.1
60-
},
61-
agent: {
62-
...defaultConfig,
63-
temperature: 0.7
77+
const defaultOpenAIConfig: ModelConfig = {
78+
model: DEFAULT_OPENAI_MODEL,
79+
temperature: 0,
80+
maxTokens: 1000
81+
};
82+
83+
export const modelConfigs: Record<LLMProvider, ToolConfigs> = {
84+
gemini: {
85+
dedup: validateModelConfig({ ...defaultGeminiConfig, temperature: 0.1 }, 'dedup'),
86+
evaluator: validateModelConfig({ ...defaultGeminiConfig }, 'evaluator'),
87+
errorAnalyzer: validateModelConfig({ ...defaultGeminiConfig }, 'errorAnalyzer'),
88+
queryRewriter: validateModelConfig({ ...defaultGeminiConfig, temperature: 0.1 }, 'queryRewriter'),
89+
agent: validateModelConfig({ ...defaultGeminiConfig, temperature: 0.7 }, 'agent'),
90+
agentBeastMode: validateModelConfig({ ...defaultGeminiConfig, temperature: 0.7 }, 'agentBeastMode')
6491
},
65-
agentBeastMode: {
66-
...defaultConfig,
67-
temperature: 0.7
92+
openai: {
93+
dedup: validateModelConfig({ ...defaultOpenAIConfig, temperature: 0.1 }, 'dedup'),
94+
evaluator: validateModelConfig({ ...defaultOpenAIConfig }, 'evaluator'),
95+
errorAnalyzer: validateModelConfig({ ...defaultOpenAIConfig }, 'errorAnalyzer'),
96+
queryRewriter: validateModelConfig({ ...defaultOpenAIConfig, temperature: 0.1 }, 'queryRewriter'),
97+
agent: validateModelConfig({ ...defaultOpenAIConfig, temperature: 0.7 }, 'agent'),
98+
agentBeastMode: validateModelConfig({ ...defaultOpenAIConfig, temperature: 0.7 }, 'agentBeastMode')
6899
}
69100
};
70101

102+
export function getToolConfig(toolName: ToolName): ModelConfig {
103+
if (!modelConfigs[LLM_PROVIDER][toolName]) {
104+
throw new Error(`Invalid tool name: ${toolName}`);
105+
}
106+
return modelConfigs[LLM_PROVIDER][toolName];
107+
}
108+
109+
export function getMaxTokens(toolName: ToolName): number {
110+
return getToolConfig(toolName).maxTokens;
111+
}
112+
113+
114+
export function getModel(toolName: ToolName) {
115+
const config = getToolConfig(toolName);
116+
117+
if (LLM_PROVIDER === 'openai') {
118+
if (!OPENAI_API_KEY) {
119+
throw new Error('OPENAI_API_KEY not found');
120+
}
121+
return createOpenAI({
122+
apiKey: OPENAI_API_KEY,
123+
compatibility: 'strict'
124+
})(config.model);
125+
}
126+
127+
if (!GEMINI_API_KEY) {
128+
throw new Error('GEMINI_API_KEY not found');
129+
}
130+
return createGoogleGenerativeAI({ apiKey: GEMINI_API_KEY })(config.model);
131+
}
132+
71133
export const STEP_SLEEP = 1000;
72134

73-
if (!GEMINI_API_KEY) throw new Error("GEMINI_API_KEY not found");
135+
if (LLM_PROVIDER === 'gemini' && !GEMINI_API_KEY) throw new Error("GEMINI_API_KEY not found");
136+
if (LLM_PROVIDER === 'openai' && !OPENAI_API_KEY) throw new Error("OPENAI_API_KEY not found");
74137
if (!JINA_API_KEY) throw new Error("JINA_API_KEY not found");
138+
139+
console.log('LLM Provider:', LLM_PROVIDER)

src/tools/__tests__/dedup.test.ts

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,37 @@
11
import { dedupQueries } from '../dedup';
2+
import { LLMProvider } from '../../config';
23

34
describe('dedupQueries', () => {
4-
it('should remove duplicate queries', async () => {
5-
jest.setTimeout(10000); // Increase timeout to 10s
6-
const queries = ['typescript tutorial', 'typescript tutorial', 'javascript basics'];
7-
const { unique_queries } = await dedupQueries(queries, []);
8-
expect(unique_queries).toHaveLength(2);
9-
expect(unique_queries).toContain('javascript basics');
5+
const providers: Array<LLMProvider> = ['openai', 'gemini'];
6+
const originalEnv = process.env;
7+
8+
beforeEach(() => {
9+
jest.resetModules();
10+
process.env = { ...originalEnv };
11+
});
12+
13+
afterEach(() => {
14+
process.env = originalEnv;
1015
});
1116

12-
it('should handle empty input', async () => {
13-
const { unique_queries } = await dedupQueries([], []);
14-
expect(unique_queries).toHaveLength(0);
17+
providers.forEach(provider => {
18+
describe(`with ${provider} provider`, () => {
19+
beforeEach(() => {
20+
process.env.LLM_PROVIDER = provider;
21+
});
22+
23+
it('should remove duplicate queries', async () => {
24+
jest.setTimeout(10000);
25+
const queries = ['typescript tutorial', 'typescript tutorial', 'javascript basics'];
26+
const { unique_queries } = await dedupQueries(queries, []);
27+
expect(unique_queries).toHaveLength(2);
28+
expect(unique_queries).toContain('javascript basics');
29+
});
30+
31+
it('should handle empty input', async () => {
32+
const { unique_queries } = await dedupQueries([], []);
33+
expect(unique_queries).toHaveLength(0);
34+
});
35+
});
1536
});
1637
});
Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,31 @@
11
import { analyzeSteps } from '../error-analyzer';
2+
import { LLMProvider } from '../../config';
23

34
describe('analyzeSteps', () => {
4-
it('should analyze error steps', async () => {
5-
const { response } = await analyzeSteps(['Step 1: Search failed', 'Step 2: Invalid query']);
6-
expect(response).toHaveProperty('recap');
7-
expect(response).toHaveProperty('blame');
8-
expect(response).toHaveProperty('improvement');
5+
const providers: Array<LLMProvider> = ['openai', 'gemini'];
6+
const originalEnv = process.env;
7+
8+
beforeEach(() => {
9+
jest.resetModules();
10+
process.env = { ...originalEnv };
11+
});
12+
13+
afterEach(() => {
14+
process.env = originalEnv;
15+
});
16+
17+
providers.forEach(provider => {
18+
describe(`with ${provider} provider`, () => {
19+
beforeEach(() => {
20+
process.env.LLM_PROVIDER = provider;
21+
});
22+
23+
it('should analyze error steps', async () => {
24+
const { response } = await analyzeSteps(['Step 1: Search failed', 'Step 2: Invalid query']);
25+
expect(response).toHaveProperty('recap');
26+
expect(response).toHaveProperty('blame');
27+
expect(response).toHaveProperty('improvement');
28+
});
29+
});
930
});
1031
});

0 commit comments

Comments
 (0)