feat: add SambaNova provider integration #6080

Open · wants to merge 2 commits into main
7 changes: 7 additions & 0 deletions packages/types/src/provider-settings.ts
@@ -32,6 +32,7 @@ export const providerNames = [
	"groq",
	"chutes",
	"litellm",
	"sambanova",
] as const

export const providerNamesSchema = z.enum(providerNames)
@@ -229,6 +230,10 @@ const litellmSchema = baseProviderSettingsSchema.extend({
	litellmModelId: z.string().optional(),
})

const sambaNovaSchema = apiModelIdProviderModelSchema.extend({
	sambaNovaApiKey: z.string().optional(),
})

const defaultSchema = z.object({
	apiProvider: z.undefined(),
})
@@ -258,6 +263,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProvider", [
	groqSchema.merge(z.object({ apiProvider: z.literal("groq") })),
	chutesSchema.merge(z.object({ apiProvider: z.literal("chutes") })),
	litellmSchema.merge(z.object({ apiProvider: z.literal("litellm") })),
	sambaNovaSchema.merge(z.object({ apiProvider: z.literal("sambanova") })),
	defaultSchema,
])

@@ -287,6 +293,7 @@ export const providerSettingsSchema = z.object({
	...groqSchema.shape,
	...chutesSchema.shape,
	...litellmSchema.shape,
	...sambaNovaSchema.shape,
	...codebaseIndexProviderSchema.shape,
})

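The wiring above means a SambaNova profile only validates when the `apiProvider` discriminator matches. A minimal sketch of what the discriminated schema accepts; it assumes `providerSettingsSchemaDiscriminated` is re-exported from `@roo-code/types` and that `apiModelIdProviderModelSchema` contributes an optional `apiModelId` field:

```ts
import { providerSettingsSchemaDiscriminated } from "@roo-code/types"

// Valid: the discriminator selects sambaNovaSchema, so sambaNovaApiKey is recognized.
const ok = providerSettingsSchemaDiscriminated.safeParse({
	apiProvider: "sambanova",
	apiModelId: "Meta-Llama-3.3-70B-Instruct",
	sambaNovaApiKey: "test-key",
})
console.log(ok.success) // true

// Invalid: an unknown discriminator value fails fast instead of falling through.
const bad = providerSettingsSchemaDiscriminated.safeParse({ apiProvider: "samba-nova" })
console.log(bad.success) // false
```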
1 change: 1 addition & 0 deletions packages/types/src/providers/index.ts
@@ -14,6 +14,7 @@ export * from "./ollama.js"
export * from "./openai.js"
export * from "./openrouter.js"
export * from "./requesty.js"
export * from "./sambanova.js"
export * from "./unbound.js"
export * from "./vertex.js"
export * from "./vscode-llm.js"
141 changes: 141 additions & 0 deletions packages/types/src/providers/sambanova.ts
@@ -0,0 +1,141 @@
import type { ModelInfo } from "../model.js"

// https://docs.sambanova.ai/cloud/docs/get-started/supported-models
export type SambaNovaModelId =
	| "Meta-Llama-3.1-8B-Instruct"
	| "Meta-Llama-3.1-70B-Instruct"
	| "Meta-Llama-3.1-405B-Instruct"
	| "Meta-Llama-3.2-1B-Instruct"
	| "Meta-Llama-3.2-3B-Instruct"
	| "Meta-Llama-3.3-70B-Instruct"
	| "Llama-3.2-11B-Vision-Instruct"
	| "Llama-3.2-90B-Vision-Instruct"
	| "QwQ-32B-Preview"
	| "Qwen2.5-72B-Instruct"
	| "Qwen2.5-Coder-32B-Instruct"
	| "deepseek-r1"
	| "deepseek-r1-distill-llama-70b"

export const sambaNovaDefaultModelId: SambaNovaModelId = "Meta-Llama-3.3-70B-Instruct"

export const sambaNovaModels = {
	"Meta-Llama-3.1-8B-Instruct": {
		maxTokens: 8192,
		contextWindow: 131072,
		supportsImages: false,
		supportsPromptCache: false,
		inputPrice: 0.1,
		outputPrice: 0.2,
		description: "Meta Llama 3.1 8B Instruct model with 128K context window.",
	},
	"Meta-Llama-3.1-70B-Instruct": {
		maxTokens: 8192,
		contextWindow: 131072,
		supportsImages: false,
		supportsPromptCache: false,
		inputPrice: 0.64,
		outputPrice: 0.8,
		description: "Meta Llama 3.1 70B Instruct model with 128K context window.",
	},
	"Meta-Llama-3.1-405B-Instruct": {
		maxTokens: 8192,
		contextWindow: 131072,
		supportsImages: false,
		supportsPromptCache: false,
		inputPrice: 3.0,
		outputPrice: 15.0,
		description: "Meta Llama 3.1 405B Instruct model with 128K context window.",
	},
	"Meta-Llama-3.2-1B-Instruct": {
		maxTokens: 8192,
		contextWindow: 131072,
		supportsImages: false,
		supportsPromptCache: false,
		inputPrice: 0.04,
		outputPrice: 0.04,
		description: "Meta Llama 3.2 1B Instruct model with 128K context window.",
	},
	"Meta-Llama-3.2-3B-Instruct": {
		maxTokens: 8192,
		contextWindow: 131072,
		supportsImages: false,
		supportsPromptCache: false,
		inputPrice: 0.06,
		outputPrice: 0.06,
		description: "Meta Llama 3.2 3B Instruct model with 128K context window.",
	},
	"Meta-Llama-3.3-70B-Instruct": {
		maxTokens: 8192,
		contextWindow: 131072,
		supportsImages: false,
		supportsPromptCache: false,
		inputPrice: 0.64,
		outputPrice: 0.8,
		description: "Meta Llama 3.3 70B Instruct model with 128K context window.",
	},
	"Llama-3.2-11B-Vision-Instruct": {
		maxTokens: 8192,
		contextWindow: 131072,
		supportsImages: true,
		supportsPromptCache: false,
		inputPrice: 0.18,
		outputPrice: 0.2,
		description: "Meta Llama 3.2 11B Vision Instruct model with image support.",
	},
	"Llama-3.2-90B-Vision-Instruct": {
		maxTokens: 8192,
		contextWindow: 131072,
		supportsImages: true,
		supportsPromptCache: false,
		inputPrice: 0.9,
		outputPrice: 1.1,
		description: "Meta Llama 3.2 90B Vision Instruct model with image support.",
	},
	"QwQ-32B-Preview": {
		maxTokens: 32768,
		contextWindow: 32768,
		supportsImages: false,
		supportsPromptCache: false,
		supportsReasoningBudget: true,
		inputPrice: 0.15,
		outputPrice: 0.15,
		description: "Alibaba QwQ 32B Preview reasoning model.",
	},
	"Qwen2.5-72B-Instruct": {
		maxTokens: 8192,
		contextWindow: 131072,
		supportsImages: false,
		supportsPromptCache: false,
		inputPrice: 0.59,
		outputPrice: 0.79,
		description: "Alibaba Qwen 2.5 72B Instruct model with 128K context window.",
	},
	"Qwen2.5-Coder-32B-Instruct": {
		maxTokens: 8192,
		contextWindow: 131072,
		supportsImages: false,
		supportsPromptCache: false,
		inputPrice: 0.29,
		outputPrice: 0.39,
		description: "Alibaba Qwen 2.5 Coder 32B Instruct model optimized for coding tasks.",
	},
	"deepseek-r1": {
		maxTokens: 8192,
		contextWindow: 131072,
		supportsImages: false,
		supportsPromptCache: false,
		supportsReasoningBudget: true,
		inputPrice: 0.55,
		outputPrice: 2.19,
		description: "DeepSeek R1 reasoning model with 128K context window.",
	},
	"deepseek-r1-distill-llama-70b": {
		maxTokens: 8192,
		contextWindow: 131072,
		supportsImages: false,
		supportsPromptCache: false,
		inputPrice: 0.27,
		outputPrice: 1.08,
		description: "DeepSeek R1 distilled Llama 70B model with 128K context window.",
	},
} as const satisfies Record<string, ModelInfo>
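Since every entry carries pricing, the table doubles as a quick cost estimator. A sketch, assuming (as with the other provider files in this package) that `inputPrice` and `outputPrice` are USD per million tokens; `estimateCostUsd` is a hypothetical helper, not part of this PR:

```ts
import { sambaNovaModels, type SambaNovaModelId } from "@roo-code/types"

// Hypothetical helper: rough USD cost of one request against a SambaNova model.
function estimateCostUsd(modelId: SambaNovaModelId, inputTokens: number, outputTokens: number): number {
	const { inputPrice, outputPrice } = sambaNovaModels[modelId]
	return (inputPrice * inputTokens + outputPrice * outputTokens) / 1_000_000
}

// Default model, 10K prompt + 2K completion tokens:
// (0.64 * 10_000 + 0.8 * 2_000) / 1e6 ≈ $0.008
console.log(estimateCostUsd("Meta-Llama-3.3-70B-Instruct", 10_000, 2_000))
```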
3 changes: 3 additions & 0 deletions src/api/index.ts
@@ -29,6 +29,7 @@ import {
	ChutesHandler,
	LiteLLMHandler,
	ClaudeCodeHandler,
	SambaNovaHandler,
} from "./providers"

export interface SingleCompletionHandler {
@@ -112,6 +113,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
			return new ChutesHandler(options)
		case "litellm":
			return new LiteLLMHandler(options)
		case "sambanova":
			return new SambaNovaHandler(options)
		default:
			apiProvider satisfies "gemini-cli" | undefined
			return new AnthropicHandler(options)
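With the new switch arm, a SambaNova profile flows through the factory like any other provider. A minimal usage sketch (import path illustrative; `getModel()` is assumed to be on the `ApiHandler` interface, as the tests below exercise):

```ts
import { buildApiHandler } from "../api" // path assumed for illustration

const handler = buildApiHandler({
	apiProvider: "sambanova",
	apiModelId: "Qwen2.5-Coder-32B-Instruct",
	sambaNovaApiKey: process.env.SAMBANOVA_API_KEY!,
})

// The handler resolves the model from the sambaNovaModels table.
console.log(handler.getModel().id) // "Qwen2.5-Coder-32B-Instruct"
```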
163 changes: 163 additions & 0 deletions src/api/providers/__tests__/sambanova.spec.ts
@@ -0,0 +1,163 @@
import { describe, it, expect, vi, beforeEach } from "vitest"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

import { type SambaNovaModelId, sambaNovaModels } from "@roo-code/types"

import { SambaNovaHandler } from "../sambanova"

// Mock OpenAI
vi.mock("openai", () => {
	const mockCreate = vi.fn()
	return {
		default: vi.fn(() => ({
			chat: {
				completions: {
					create: mockCreate,
				},
			},
		})),
	}
})

describe("SambaNovaHandler", () => {
	let handler: SambaNovaHandler
	let mockCreate: any

	beforeEach(() => {
		vi.clearAllMocks()
		mockCreate = (OpenAI as unknown as any)().chat.completions.create
		handler = new SambaNovaHandler({ sambaNovaApiKey: "test-sambanova-api-key" })
	})

	it("should use the correct SambaNova base URL", () => {
		new SambaNovaHandler({ sambaNovaApiKey: "test-sambanova-api-key" })
		expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "https://api.sambanova.ai/v1" }))
	})

	it("should use the provided API key", () => {
		const sambaNovaApiKey = "test-sambanova-api-key"
		new SambaNovaHandler({ sambaNovaApiKey })
		expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: sambaNovaApiKey }))
	})

	it("should throw an error if API key is not provided", () => {
		expect(() => new SambaNovaHandler({} as any)).toThrow("API key is required")
	})

	it("should use the specified model when provided", () => {
		const testModelId: SambaNovaModelId = "Meta-Llama-3.3-70B-Instruct"
		const handlerWithModel = new SambaNovaHandler({
			apiModelId: testModelId,
			sambaNovaApiKey: "test-sambanova-api-key",
		})
		const model = handlerWithModel.getModel()
		expect(model.id).toBe(testModelId)
		expect(model.info).toEqual(sambaNovaModels[testModelId])
	})

	it("should use the default model when no model is specified", () => {
		const model = handler.getModel()
		expect(model.id).toBe("Meta-Llama-3.3-70B-Instruct")
		expect(model.info).toEqual(sambaNovaModels["Meta-Llama-3.3-70B-Instruct"])
	})

	describe("createMessage", () => {
		it("should create a streaming chat completion with correct parameters", async () => {
			const systemPrompt = "You are a helpful assistant"
			const messages: Anthropic.Messages.MessageParam[] = [
				{
					role: "user",
					content: "Hello",
				},
			]

			mockCreate.mockImplementation(() => {
				const chunks = [
					{
						choices: [{ delta: { content: "Hi there!" } }],
					},
					{
						choices: [{ delta: {} }],
						usage: { prompt_tokens: 10, completion_tokens: 5 },
					},
				]

				return {
					[Symbol.asyncIterator]: async function* () {
						for (const chunk of chunks) {
							yield chunk
						}
					},
				}
			})

			const stream = handler.createMessage(systemPrompt, messages)
			const results = []
			for await (const chunk of stream) {
				results.push(chunk)
			}

			expect(mockCreate).toHaveBeenCalledWith(
				expect.objectContaining({
					model: "Meta-Llama-3.3-70B-Instruct",
					max_tokens: 8192,
					temperature: 0.7,
					messages: [
						{ role: "system", content: systemPrompt },
						{ role: "user", content: "Hello" },
					],
					stream: true,
					stream_options: { include_usage: true },
				}),
			)

			expect(results).toEqual([
				{ type: "text", text: "Hi there!" },
				{ type: "usage", inputTokens: 10, outputTokens: 5 },
			])
		})
	})

	describe("completePrompt", () => {
		it("should complete a prompt successfully", async () => {
			const prompt = "Test prompt"
			const expectedResponse = "Test response"

			mockCreate.mockResolvedValue({
				choices: [{ message: { content: expectedResponse } }],
			})

			const result = await handler.completePrompt(prompt)

			expect(mockCreate).toHaveBeenCalledWith({
				model: "Meta-Llama-3.3-70B-Instruct",
				messages: [{ role: "user", content: prompt }],
			})
			expect(result).toBe(expectedResponse)
		})

		it("should handle errors properly", async () => {
			const prompt = "Test prompt"
			const errorMessage = "API Error"

			mockCreate.mockRejectedValue(new Error(errorMessage))

			await expect(handler.completePrompt(prompt)).rejects.toThrow(`SambaNova completion error: ${errorMessage}`)
		})
	})

	describe("model selection", () => {
		it.each(Object.keys(sambaNovaModels) as SambaNovaModelId[])("should correctly handle model %s", (modelId) => {
			const modelInfo = sambaNovaModels[modelId]
			const handlerWithModel = new SambaNovaHandler({
				apiModelId: modelId,
				sambaNovaApiKey: "test-sambanova-api-key",
			})

			const model = handlerWithModel.getModel()
			expect(model.id).toBe(modelId)
			expect(model.info).toEqual(modelInfo)
		})
	})
})
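For reviewers tracing the streaming assertions: the `{ type: "text" }` and `{ type: "usage" }` chunks come from the inherited OpenAI-compatible streaming loop, which roughly maps SSE content deltas to text chunks and the trailing usage record to a usage chunk. A rough sketch of that translation, not the actual base-class code:

```ts
import OpenAI from "openai"

// Illustrative only: the approximate shape of the delta-to-chunk translation
// that the base provider performs on the OpenAI SDK stream.
async function* toApiChunks(stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>) {
	for await (const chunk of stream) {
		const delta = chunk.choices[0]?.delta
		if (delta?.content) {
			yield { type: "text", text: delta.content } as const
		}
		if (chunk.usage) {
			yield {
				type: "usage",
				inputTokens: chunk.usage.prompt_tokens ?? 0,
				outputTokens: chunk.usage.completion_tokens ?? 0,
			} as const
		}
	}
}
```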
1 change: 1 addition & 0 deletions src/api/providers/index.ts
@@ -18,6 +18,7 @@ export { OpenAiNativeHandler } from "./openai-native"
export { OpenAiHandler } from "./openai"
export { OpenRouterHandler } from "./openrouter"
export { RequestyHandler } from "./requesty"
export { SambaNovaHandler } from "./sambanova"
export { UnboundHandler } from "./unbound"
export { VertexHandler } from "./vertex"
export { VsCodeLmHandler } from "./vscode-lm"
19 changes: 19 additions & 0 deletions src/api/providers/sambanova.ts
@@ -0,0 +1,19 @@
import { type SambaNovaModelId, sambaNovaDefaultModelId, sambaNovaModels } from "@roo-code/types"

import type { ApiHandlerOptions } from "../../shared/api"

import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"

export class SambaNovaHandler extends BaseOpenAiCompatibleProvider<SambaNovaModelId> {
	constructor(options: ApiHandlerOptions) {
		super({
			...options,
			providerName: "SambaNova",
			baseURL: "https://api.sambanova.ai/v1",
			apiKey: options.sambaNovaApiKey,
			defaultProviderModelId: sambaNovaDefaultModelId,
			providerModels: sambaNovaModels,
			defaultTemperature: 0.7,
		})
	}
}
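End to end, the handler behaves like the other OpenAI-compatible providers. A minimal sketch of driving it directly, assuming a real API key in the environment:

```ts
import { Anthropic } from "@anthropic-ai/sdk"

import { SambaNovaHandler } from "./sambanova"

const handler = new SambaNovaHandler({
	apiModelId: "deepseek-r1-distill-llama-70b",
	sambaNovaApiKey: process.env.SAMBANOVA_API_KEY!,
})

const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }]

// createMessage yields text chunks as they stream, then a usage chunk at the end.
for await (const chunk of handler.createMessage("You are a helpful assistant", messages)) {
	if (chunk.type === "text") process.stdout.write(chunk.text)
}
```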