Skip to content

Commit 2928c80

Browse files
authored
Merge pull request RooCodeInc#977 from d-oit/mistral
Additional models for mistral api provider
2 parents c815ea7 + a9d8a1d commit 2928c80

File tree

5 files changed

+219
-10
lines changed

5 files changed

+219
-10
lines changed
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import { MistralHandler } from "../mistral"
2+
import { ApiHandlerOptions, mistralDefaultModelId } from "../../../shared/api"
3+
import { Anthropic } from "@anthropic-ai/sdk"
4+
import { ApiStreamTextChunk } from "../../transform/stream"
5+
6+
// Mock Mistral client
7+
const mockCreate = jest.fn()
8+
jest.mock("@mistralai/mistralai", () => {
9+
return {
10+
Mistral: jest.fn().mockImplementation(() => ({
11+
chat: {
12+
stream: mockCreate.mockImplementation(async (options) => {
13+
const stream = {
14+
[Symbol.asyncIterator]: async function* () {
15+
yield {
16+
data: {
17+
choices: [
18+
{
19+
delta: { content: "Test response" },
20+
index: 0,
21+
},
22+
],
23+
},
24+
}
25+
},
26+
}
27+
return stream
28+
}),
29+
},
30+
})),
31+
}
32+
})
33+
34+
describe("MistralHandler", () => {
35+
let handler: MistralHandler
36+
let mockOptions: ApiHandlerOptions
37+
38+
beforeEach(() => {
39+
mockOptions = {
40+
apiModelId: "codestral-latest", // Update to match the actual model ID
41+
mistralApiKey: "test-api-key",
42+
includeMaxTokens: true,
43+
modelTemperature: 0,
44+
}
45+
handler = new MistralHandler(mockOptions)
46+
mockCreate.mockClear()
47+
})
48+
49+
describe("constructor", () => {
50+
it("should initialize with provided options", () => {
51+
expect(handler).toBeInstanceOf(MistralHandler)
52+
expect(handler.getModel().id).toBe(mockOptions.apiModelId)
53+
})
54+
55+
it("should throw error if API key is missing", () => {
56+
expect(() => {
57+
new MistralHandler({
58+
...mockOptions,
59+
mistralApiKey: undefined,
60+
})
61+
}).toThrow("Mistral API key is required")
62+
})
63+
64+
it("should use custom base URL if provided", () => {
65+
const customBaseUrl = "https://custom.mistral.ai/v1"
66+
const handlerWithCustomUrl = new MistralHandler({
67+
...mockOptions,
68+
mistralCodestralUrl: customBaseUrl,
69+
})
70+
expect(handlerWithCustomUrl).toBeInstanceOf(MistralHandler)
71+
})
72+
})
73+
74+
describe("getModel", () => {
75+
it("should return correct model info", () => {
76+
const model = handler.getModel()
77+
expect(model.id).toBe(mockOptions.apiModelId)
78+
expect(model.info).toBeDefined()
79+
expect(model.info.supportsPromptCache).toBe(false)
80+
})
81+
})
82+
83+
describe("createMessage", () => {
84+
const systemPrompt = "You are a helpful assistant."
85+
const messages: Anthropic.Messages.MessageParam[] = [
86+
{
87+
role: "user",
88+
content: [{ type: "text", text: "Hello!" }],
89+
},
90+
]
91+
92+
it("should create message successfully", async () => {
93+
const iterator = handler.createMessage(systemPrompt, messages)
94+
const result = await iterator.next()
95+
96+
expect(mockCreate).toHaveBeenCalledWith({
97+
model: mockOptions.apiModelId,
98+
messages: expect.any(Array),
99+
maxTokens: expect.any(Number),
100+
temperature: 0,
101+
})
102+
103+
expect(result.value).toBeDefined()
104+
expect(result.done).toBe(false)
105+
})
106+
107+
it("should handle streaming response correctly", async () => {
108+
const iterator = handler.createMessage(systemPrompt, messages)
109+
const results: ApiStreamTextChunk[] = []
110+
111+
for await (const chunk of iterator) {
112+
if ("text" in chunk) {
113+
results.push(chunk as ApiStreamTextChunk)
114+
}
115+
}
116+
117+
expect(results.length).toBeGreaterThan(0)
118+
expect(results[0].text).toBe("Test response")
119+
})
120+
121+
it("should handle errors gracefully", async () => {
122+
mockCreate.mockRejectedValueOnce(new Error("API Error"))
123+
await expect(handler.createMessage(systemPrompt, messages).next()).rejects.toThrow("API Error")
124+
})
125+
})
126+
})

src/api/providers/mistral.ts

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,36 @@ export class MistralHandler implements ApiHandler {
2121
private client: Mistral
2222

2323
constructor(options: ApiHandlerOptions) {
24+
if (!options.mistralApiKey) {
25+
throw new Error("Mistral API key is required")
26+
}
27+
2428
this.options = options
29+
const baseUrl = this.getBaseUrl()
30+
console.debug(`[Roo Code] MistralHandler using baseUrl: ${baseUrl}`)
2531
this.client = new Mistral({
26-
serverURL: "https://codestral.mistral.ai",
32+
serverURL: baseUrl,
2733
apiKey: this.options.mistralApiKey,
2834
})
2935
}
3036

37+
private getBaseUrl(): string {
38+
const modelId = this.options.apiModelId
39+
if (modelId?.startsWith("codestral-")) {
40+
return this.options.mistralCodestralUrl || "https://codestral.mistral.ai"
41+
}
42+
return "https://api.mistral.ai"
43+
}
44+
3145
async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
32-
const stream = await this.client.chat.stream({
33-
model: this.getModel().id,
34-
// max_completion_tokens: this.getModel().info.maxTokens,
46+
const response = await this.client.chat.stream({
47+
model: this.options.apiModelId || mistralDefaultModelId,
48+
messages: convertToMistralMessages(messages),
49+
maxTokens: this.options.includeMaxTokens ? this.getModel().info.maxTokens : undefined,
3550
temperature: this.options.modelTemperature ?? MISTRAL_DEFAULT_TEMPERATURE,
36-
messages: [{ role: "system", content: systemPrompt }, ...convertToMistralMessages(messages)],
37-
stream: true,
3851
})
3952

40-
for await (const chunk of stream) {
53+
for await (const chunk of response) {
4154
const delta = chunk.data.choices[0]?.delta
4255
if (delta?.content) {
4356
let content: string = ""

src/core/webview/ClineProvider.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ type GlobalStateKey =
127127
| "requestyModelInfo"
128128
| "unboundModelInfo"
129129
| "modelTemperature"
130+
| "mistralCodestralUrl"
130131
| "maxOpenTabsContext"
131132

132133
export const GlobalFileNames = {
@@ -1637,6 +1638,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
16371638
openRouterUseMiddleOutTransform,
16381639
vsCodeLmModelSelector,
16391640
mistralApiKey,
1641+
mistralCodestralUrl,
16401642
unboundApiKey,
16411643
unboundModelId,
16421644
unboundModelInfo,
@@ -1682,6 +1684,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
16821684
await this.updateGlobalState("openRouterUseMiddleOutTransform", openRouterUseMiddleOutTransform)
16831685
await this.updateGlobalState("vsCodeLmModelSelector", vsCodeLmModelSelector)
16841686
await this.storeSecret("mistralApiKey", mistralApiKey)
1687+
await this.updateGlobalState("mistralCodestralUrl", mistralCodestralUrl)
16851688
await this.storeSecret("unboundApiKey", unboundApiKey)
16861689
await this.updateGlobalState("unboundModelId", unboundModelId)
16871690
await this.updateGlobalState("unboundModelInfo", unboundModelInfo)
@@ -2521,6 +2524,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
25212524
openAiNativeApiKey,
25222525
deepSeekApiKey,
25232526
mistralApiKey,
2527+
mistralCodestralUrl,
25242528
azureApiVersion,
25252529
openAiStreamingEnabled,
25262530
openRouterModelId,
@@ -2602,6 +2606,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
26022606
this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
26032607
this.getSecret("deepSeekApiKey") as Promise<string | undefined>,
26042608
this.getSecret("mistralApiKey") as Promise<string | undefined>,
2609+
this.getGlobalState("mistralCodestralUrl") as Promise<string | undefined>,
26052610
this.getGlobalState("azureApiVersion") as Promise<string | undefined>,
26062611
this.getGlobalState("openAiStreamingEnabled") as Promise<boolean | undefined>,
26072612
this.getGlobalState("openRouterModelId") as Promise<string | undefined>,
@@ -2700,6 +2705,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
27002705
openAiNativeApiKey,
27012706
deepSeekApiKey,
27022707
mistralApiKey,
2708+
mistralCodestralUrl,
27032709
azureApiVersion,
27042710
openAiStreamingEnabled,
27052711
openRouterModelId,

src/shared/api.ts

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ export interface ApiHandlerOptions {
5252
geminiApiKey?: string
5353
openAiNativeApiKey?: string
5454
mistralApiKey?: string
55+
mistralCodestralUrl?: string // New option for Codestral URL
5556
azureApiVersion?: string
5657
openRouterUseMiddleOutTransform?: boolean
5758
openAiStreamingEnabled?: boolean
@@ -670,13 +671,53 @@ export type MistralModelId = keyof typeof mistralModels
670671
export const mistralDefaultModelId: MistralModelId = "codestral-latest"
671672
export const mistralModels = {
672673
"codestral-latest": {
673-
maxTokens: 32_768,
674+
maxTokens: 256_000,
674675
contextWindow: 256_000,
675676
supportsImages: false,
676677
supportsPromptCache: false,
677678
inputPrice: 0.3,
678679
outputPrice: 0.9,
679680
},
681+
"mistral-large-latest": {
682+
maxTokens: 131_000,
683+
contextWindow: 131_000,
684+
supportsImages: false,
685+
supportsPromptCache: false,
686+
inputPrice: 2.0,
687+
outputPrice: 6.0,
688+
},
689+
"ministral-8b-latest": {
690+
maxTokens: 131_000,
691+
contextWindow: 131_000,
692+
supportsImages: false,
693+
supportsPromptCache: false,
694+
inputPrice: 0.1,
695+
outputPrice: 0.1,
696+
},
697+
"ministral-3b-latest": {
698+
maxTokens: 131_000,
699+
contextWindow: 131_000,
700+
supportsImages: false,
701+
supportsPromptCache: false,
702+
inputPrice: 0.04,
703+
outputPrice: 0.04,
704+
},
705+
"mistral-small-latest": {
706+
maxTokens: 32_000,
707+
contextWindow: 32_000,
708+
supportsImages: false,
709+
supportsPromptCache: false,
710+
inputPrice: 0.2,
711+
outputPrice: 0.6,
712+
},
713+
"pixtral-large-latest": {
714+
maxTokens: 131_000,
715+
contextWindow: 131_000,
716+
supportsImages: true,
717+
supportsPromptCache: false,
718+
inputPrice: 2.0,
719+
outputPrice: 6.0,
720+
},
680721
} as const satisfies Record<string, ModelInfo>
681722

682723
// Unbound Security

webview-ui/src/components/settings/ApiOptions.tsx

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage, fromWelcomeView }: A
314314
placeholder="Enter API Key...">
315315
<span style={{ fontWeight: 500 }}>Mistral API Key</span>
316316
</VSCodeTextField>
317+
317318
<p
318319
style={{
319320
fontSize: "12px",
@@ -323,15 +324,37 @@ const ApiOptions = ({ apiErrorMessage, modelIdErrorMessage, fromWelcomeView }: A
323324
This key is stored locally and only used to make API requests from this extension.
324325
{!apiConfiguration?.mistralApiKey && (
325326
<VSCodeLink
326-
href="https://console.mistral.ai/codestral/"
327+
href="https://console.mistral.ai/"
327328
style={{
328329
display: "inline",
329330
fontSize: "inherit",
330331
}}>
331-
You can get a Mistral API key by signing up here.
332+
You can get a La Plateforme (api.mistral.ai) / Codestral (codestral.mistral.ai) API key
333+
by signing up here.
332334
</VSCodeLink>
333335
)}
334336
</p>
337+
338+
{apiConfiguration?.apiModelId?.startsWith("codestral-") && (
339+
<div>
340+
<VSCodeTextField
341+
value={apiConfiguration?.mistralCodestralUrl || ""}
342+
style={{ width: "100%", marginTop: "10px" }}
343+
type="url"
344+
onBlur={handleInputChange("mistralCodestralUrl")}
345+
placeholder="Default: https://codestral.mistral.ai">
346+
<span style={{ fontWeight: 500 }}>Codestral Base URL (Optional)</span>
347+
</VSCodeTextField>
348+
<p
349+
style={{
350+
fontSize: "12px",
351+
marginTop: 3,
352+
color: "var(--vscode-descriptionForeground)",
353+
}}>
354+
Set alternative URL for Codestral model: https://api.mistral.ai
355+
</p>
356+
</div>
357+
)}
335358
</div>
336359
)}
337360

0 commit comments

Comments
 (0)