feat: add HuggingFace provider support #6127

Status: Closed · wants to merge 2 commits
10 changes: 10 additions & 0 deletions packages/types/src/provider-settings.ts
@@ -32,6 +32,7 @@ export const providerNames = [
"groq",
"chutes",
"litellm",
"huggingface",
] as const

export const providerNamesSchema = z.enum(providerNames)
@@ -219,6 +220,12 @@ const groqSchema = apiModelIdProviderModelSchema.extend({
groqApiKey: z.string().optional(),
})

const huggingFaceSchema = baseProviderSettingsSchema.extend({
huggingFaceApiKey: z.string().optional(),
huggingFaceModelId: z.string().optional(),
huggingFaceInferenceProvider: z.string().optional(),
})

const chutesSchema = apiModelIdProviderModelSchema.extend({
chutesApiKey: z.string().optional(),
})
@@ -256,6 +263,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProvider", [
fakeAiSchema.merge(z.object({ apiProvider: z.literal("fake-ai") })),
xaiSchema.merge(z.object({ apiProvider: z.literal("xai") })),
groqSchema.merge(z.object({ apiProvider: z.literal("groq") })),
huggingFaceSchema.merge(z.object({ apiProvider: z.literal("huggingface") })),
chutesSchema.merge(z.object({ apiProvider: z.literal("chutes") })),
litellmSchema.merge(z.object({ apiProvider: z.literal("litellm") })),
defaultSchema,
@@ -285,6 +293,7 @@ export const providerSettingsSchema = z.object({
...fakeAiSchema.shape,
...xaiSchema.shape,
...groqSchema.shape,
...huggingFaceSchema.shape,
...chutesSchema.shape,
...litellmSchema.shape,
...codebaseIndexProviderSchema.shape,
@@ -304,6 +313,7 @@ export const MODEL_ID_KEYS: Partial<keyof ProviderSettings>[] = [
"unboundModelId",
"requestyModelId",
"litellmModelId",
"huggingFaceModelId",
]

export const getModelId = (settings: ProviderSettings): string | undefined => {
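
A quick sketch of how the new union branch is consumed. The parse below assumes the schemas are exported from this package; the import path is hypothetical:

import { providerSettingsSchemaDiscriminated } from "@roo-code/types" // hypothetical import path

// Validates and narrows to the "huggingface" branch of the discriminated union above
const settings = providerSettingsSchemaDiscriminated.parse({
	apiProvider: "huggingface",
	huggingFaceApiKey: "hf_xxx", // placeholder, not a real key
	huggingFaceModelId: "meta-llama/Llama-3.3-70B-Instruct",
})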
61 changes: 61 additions & 0 deletions packages/types/src/providers/huggingface.ts
@@ -0,0 +1,61 @@
import { z } from "zod"
import { modelInfoSchema } from "../model.js"

export const huggingFaceDefaultModelId = "meta-llama/Llama-3.3-70B-Instruct"

export const huggingFaceModels = {
"meta-llama/Llama-3.3-70B-Instruct": {
maxTokens: 8192,
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: false,
},
"meta-llama/Llama-3.2-11B-Vision-Instruct": {
maxTokens: 4096,
contextWindow: 131072,
supportsImages: true,
supportsPromptCache: false,
},
"Qwen/Qwen2.5-72B-Instruct": {
maxTokens: 8192,
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: false,
},
"mistralai/Mistral-7B-Instruct-v0.3": {
maxTokens: 8192,
contextWindow: 32768,
supportsImages: false,
supportsPromptCache: false,
},
} as const

export type HuggingFaceModelId = keyof typeof huggingFaceModels

export const huggingFaceModelSchema = z.enum(
Object.keys(huggingFaceModels) as [HuggingFaceModelId, ...HuggingFaceModelId[]],
)

export const huggingFaceModelInfoSchema = z
.discriminatedUnion("id", [
z.object({
id: z.literal("meta-llama/Llama-3.3-70B-Instruct"),
info: modelInfoSchema.optional(),
}),
z.object({
id: z.literal("meta-llama/Llama-3.2-11B-Vision-Instruct"),
info: modelInfoSchema.optional(),
}),
z.object({
id: z.literal("Qwen/Qwen2.5-72B-Instruct"),
info: modelInfoSchema.optional(),
}),
z.object({
id: z.literal("mistralai/Mistral-7B-Instruct-v0.3"),
info: modelInfoSchema.optional(),
}),
])
.transform(({ id, info }) => ({
id,
info: { ...huggingFaceModels[id], ...info },
}))
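
The transform back-fills the static model info when none is supplied (spreading undefined is a no-op), so a bare id parses to a fully populated entry. A minimal sketch of the schema in use:

const { id, info } = huggingFaceModelInfoSchema.parse({ id: "Qwen/Qwen2.5-72B-Instruct" })
// info === { maxTokens: 8192, contextWindow: 131072, supportsImages: false, supportsPromptCache: false }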
1 change: 1 addition & 0 deletions packages/types/src/providers/index.ts
@@ -6,6 +6,7 @@ export * from "./deepseek.js"
export * from "./gemini.js"
export * from "./glama.js"
export * from "./groq.js"
export * from "./huggingface.js"
export * from "./lite-llm.js"
export * from "./lm-studio.js"
export * from "./mistral.js"
17 changes: 17 additions & 0 deletions src/api/huggingface-models.ts
@@ -0,0 +1,17 @@
import { fetchHuggingFaceModels, type HuggingFaceModel } from "../services/huggingface-models"

export interface HuggingFaceModelsResponse {
models: HuggingFaceModel[]
cached: boolean
timestamp: number
}

export async function getHuggingFaceModels(): Promise<HuggingFaceModelsResponse> {
const models = await fetchHuggingFaceModels()

return {
models,
cached: false, // We could enhance this to track if data came from cache
timestamp: Date.now(),
}
}
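
The cached flag above is hardwired to false. If the TODO comment is picked up, a minimal in-memory TTL cache could look like the sketch below (function name and TTL are hypothetical, not part of this PR):

let modelCache: { models: HuggingFaceModel[]; timestamp: number } | undefined
const CACHE_TTL_MS = 5 * 60 * 1000 // hypothetical 5-minute lifetime

export async function getHuggingFaceModelsWithCache(): Promise<HuggingFaceModelsResponse> {
	// Serve from the in-memory cache while it is still fresh
	if (modelCache && Date.now() - modelCache.timestamp < CACHE_TTL_MS) {
		return { models: modelCache.models, cached: true, timestamp: modelCache.timestamp }
	}
	const models = await fetchHuggingFaceModels()
	modelCache = { models, timestamp: Date.now() }
	return { models, cached: false, timestamp: modelCache.timestamp }
}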
3 changes: 3 additions & 0 deletions src/api/index.ts
@@ -26,6 +26,7 @@ import {
FakeAIHandler,
XAIHandler,
GroqHandler,
HuggingFaceHandler,
ChutesHandler,
LiteLLMHandler,
ClaudeCodeHandler,
@@ -108,6 +109,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
return new XAIHandler(options)
case "groq":
return new GroqHandler(options)
case "huggingface":
return new HuggingFaceHandler(options)
case "chutes":
return new ChutesHandler(options)
case "litellm":
99 changes: 99 additions & 0 deletions src/api/providers/huggingface.ts
@@ -0,0 +1,99 @@
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

import type { ApiHandlerOptions } from "../../shared/api"
import { ApiStream } from "../transform/stream"
import { convertToOpenAiMessages } from "../transform/openai-format"
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
import { DEFAULT_HEADERS } from "./constants"
import { BaseProvider } from "./base-provider"

export class HuggingFaceHandler extends BaseProvider implements SingleCompletionHandler {
private client: OpenAI
private options: ApiHandlerOptions

constructor(options: ApiHandlerOptions) {
super()
this.options = options

if (!this.options.huggingFaceApiKey) {
throw new Error("Hugging Face API key is required")
}

this.client = new OpenAI({
baseURL: "https://router.huggingface.co/v1",
apiKey: this.options.huggingFaceApiKey,
defaultHeaders: DEFAULT_HEADERS,
})
}

override async *createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
metadata?: ApiHandlerCreateMessageMetadata,
): ApiStream {
const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
const temperature = this.options.modelTemperature ?? 0.7

const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model: modelId,
temperature,
messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
stream: true,
stream_options: { include_usage: true },
}

const stream = await this.client.chat.completions.create(params)

for await (const chunk of stream) {
const delta = chunk.choices[0]?.delta

if (delta?.content) {
yield {
type: "text",
text: delta.content,
}
}

if (chunk.usage) {
yield {
type: "usage",
inputTokens: chunk.usage.prompt_tokens || 0,
outputTokens: chunk.usage.completion_tokens || 0,
}
}
}
}

async completePrompt(prompt: string): Promise<string> {
const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"

try {
const response = await this.client.chat.completions.create({
model: modelId,
messages: [{ role: "user", content: prompt }],
})

return response.choices[0]?.message.content || ""
} catch (error) {
if (error instanceof Error) {
throw new Error(`Hugging Face completion error: ${error.message}`)
}

throw error
}
}

override getModel() {
const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
return {
id: modelId,
info: {
maxTokens: 8192,
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: false,
},
}
}
}
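
End to end, the handler is obtained through buildApiHandler (see src/api/index.ts above) and streamed like any other provider. A usage sketch, with systemPrompt and anthropicMessages left as placeholders:

const handler = buildApiHandler({
	apiProvider: "huggingface",
	huggingFaceApiKey: process.env.HUGGINGFACE_API_KEY,
	huggingFaceModelId: "Qwen/Qwen2.5-72B-Instruct",
})

for await (const chunk of handler.createMessage(systemPrompt, anthropicMessages)) {
	if (chunk.type === "text") process.stdout.write(chunk.text)
	if (chunk.type === "usage") console.log(`in/out tokens: ${chunk.inputTokens}/${chunk.outputTokens}`)
}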
1 change: 1 addition & 0 deletions src/api/providers/index.ts
@@ -9,6 +9,7 @@ export { FakeAIHandler } from "./fake-ai"
export { GeminiHandler } from "./gemini"
export { GlamaHandler } from "./glama"
export { GroqHandler } from "./groq"
export { HuggingFaceHandler } from "./huggingface"
export { HumanRelayHandler } from "./human-relay"
export { LiteLLMHandler } from "./lite-llm"
export { LmStudioHandler } from "./lm-studio"
16 changes: 16 additions & 0 deletions src/core/webview/webviewMessageHandler.ts
@@ -674,6 +674,22 @@ export const webviewMessageHandler = async (
// TODO: Cache like we do for OpenRouter, etc?
provider.postMessageToWebview({ type: "vsCodeLmModels", vsCodeLmModels })
break
case "requestHuggingFaceModels":
try {
const { getHuggingFaceModels } = await import("../../api/huggingface-models")
const huggingFaceModelsResponse = await getHuggingFaceModels()
provider.postMessageToWebview({
type: "huggingFaceModels",
huggingFaceModels: huggingFaceModelsResponse.models,
})
} catch (error) {
console.error("Failed to fetch Hugging Face models:", error)
provider.postMessageToWebview({
type: "huggingFaceModels",
huggingFaceModels: [],
})
}
break
case "openImage":
openImage(message.text!, { values: message.values })
break
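
On the webview side, the counterpart is a request message and a listener for the response; a sketch (helper names hypothetical, message shapes mirroring the handler above):

// Inside the webview bundle
vscode.postMessage({ type: "requestHuggingFaceModels" })

window.addEventListener("message", (event) => {
	const message = event.data
	if (message.type === "huggingFaceModels") {
		// HuggingFaceModel[]; an empty array signals the fetch failed
		renderModelPicker(message.huggingFaceModels) // hypothetical UI helper
	}
})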