diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts
index 884337767fe..ea7089a81ea 100644
--- a/packages/types/src/provider-settings.ts
+++ b/packages/types/src/provider-settings.ts
@@ -32,6 +32,7 @@ export const providerNames = [
 	"groq",
 	"chutes",
 	"litellm",
+	"huggingface",
 ] as const
 
 export const providerNamesSchema = z.enum(providerNames)
@@ -219,6 +220,12 @@ const groqSchema = apiModelIdProviderModelSchema.extend({
 	groqApiKey: z.string().optional(),
 })
 
+const huggingFaceSchema = baseProviderSettingsSchema.extend({
+	huggingFaceApiKey: z.string().optional(),
+	huggingFaceModelId: z.string().optional(),
+	huggingFaceInferenceProvider: z.string().optional(),
+})
+
 const chutesSchema = apiModelIdProviderModelSchema.extend({
 	chutesApiKey: z.string().optional(),
 })
@@ -256,6 +263,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProvider", [
 	fakeAiSchema.merge(z.object({ apiProvider: z.literal("fake-ai") })),
 	xaiSchema.merge(z.object({ apiProvider: z.literal("xai") })),
 	groqSchema.merge(z.object({ apiProvider: z.literal("groq") })),
+	huggingFaceSchema.merge(z.object({ apiProvider: z.literal("huggingface") })),
 	chutesSchema.merge(z.object({ apiProvider: z.literal("chutes") })),
 	litellmSchema.merge(z.object({ apiProvider: z.literal("litellm") })),
 	defaultSchema,
@@ -285,6 +293,7 @@ export const providerSettingsSchema = z.object({
 	...fakeAiSchema.shape,
 	...xaiSchema.shape,
 	...groqSchema.shape,
+	...huggingFaceSchema.shape,
 	...chutesSchema.shape,
 	...litellmSchema.shape,
 	...codebaseIndexProviderSchema.shape,
@@ -304,6 +313,7 @@ export const MODEL_ID_KEYS: Partial<keyof ProviderSettings>[] = [
 	"unboundModelId",
 	"requestyModelId",
 	"litellmModelId",
+	"huggingFaceModelId",
 ]
 
 export const getModelId = (settings: ProviderSettings): string | undefined => {
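Worth noting how the discriminated union behaves with the new member: a settings object is only validated against `huggingFaceSchema`'s keys when `apiProvider` is `"huggingface"`. A minimal sketch of that (illustrative only, not part of the diff; the token value is a placeholder):

```ts
import { providerSettingsSchemaDiscriminated } from "@roo-code/types"

const result = providerSettingsSchemaDiscriminated.safeParse({
	apiProvider: "huggingface",
	huggingFaceApiKey: "hf_xxx", // placeholder token
	huggingFaceModelId: "meta-llama/Llama-3.3-70B-Instruct",
	huggingFaceInferenceProvider: "auto",
})

console.log(result.success) // true — matched the "huggingface" branch of the union
```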
diff --git a/packages/types/src/providers/huggingface.ts b/packages/types/src/providers/huggingface.ts
new file mode 100644
index 00000000000..79455f10333
--- /dev/null
+++ b/packages/types/src/providers/huggingface.ts
@@ -0,0 +1,61 @@
+import { z } from "zod"
+import { modelInfoSchema } from "../model.js"
+
+export const huggingFaceDefaultModelId = "meta-llama/Llama-3.3-70B-Instruct"
+
+export const huggingFaceModels = {
+	"meta-llama/Llama-3.3-70B-Instruct": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: false,
+		supportsPromptCache: false,
+	},
+	"meta-llama/Llama-3.2-11B-Vision-Instruct": {
+		maxTokens: 4096,
+		contextWindow: 131072,
+		supportsImages: true,
+		supportsPromptCache: false,
+	},
+	"Qwen/Qwen2.5-72B-Instruct": {
+		maxTokens: 8192,
+		contextWindow: 131072,
+		supportsImages: false,
+		supportsPromptCache: false,
+	},
+	"mistralai/Mistral-7B-Instruct-v0.3": {
+		maxTokens: 8192,
+		contextWindow: 32768,
+		supportsImages: false,
+		supportsPromptCache: false,
+	},
+} as const
+
+export type HuggingFaceModelId = keyof typeof huggingFaceModels
+
+export const huggingFaceModelSchema = z.enum(
+	Object.keys(huggingFaceModels) as [HuggingFaceModelId, ...HuggingFaceModelId[]],
+)
+
+export const huggingFaceModelInfoSchema = z
+	.discriminatedUnion("id", [
+		z.object({
+			id: z.literal("meta-llama/Llama-3.3-70B-Instruct"),
+			info: modelInfoSchema.optional(),
+		}),
+		z.object({
+			id: z.literal("meta-llama/Llama-3.2-11B-Vision-Instruct"),
+			info: modelInfoSchema.optional(),
+		}),
+		z.object({
+			id: z.literal("Qwen/Qwen2.5-72B-Instruct"),
+			info: modelInfoSchema.optional(),
+		}),
+		z.object({
+			id: z.literal("mistralai/Mistral-7B-Instruct-v0.3"),
+			info: modelInfoSchema.optional(),
+		}),
+	])
+	.transform(({ id, info }) => ({
+		id,
+		info: { ...huggingFaceModels[id], ...info },
+	}))
diff --git a/packages/types/src/providers/index.ts b/packages/types/src/providers/index.ts
index e4e506b8a7b..f5061f152c0 100644
--- a/packages/types/src/providers/index.ts
+++ b/packages/types/src/providers/index.ts
@@ -6,6 +6,7 @@ export * from "./deepseek.js"
 export * from "./gemini.js"
 export * from "./glama.js"
 export * from "./groq.js"
+export * from "./huggingface.js"
 export * from "./lite-llm.js"
 export * from "./lm-studio.js"
 export * from "./mistral.js"
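The `.transform` at the end of the schema is the interesting part: the static entries in `huggingFaceModels` act as defaults, and any supplied `info` is spread over them last, so it wins field by field. A sketch of the merge semantics, assuming the exports above:

```ts
import { huggingFaceModelInfoSchema } from "@roo-code/types"

const { id, info } = huggingFaceModelInfoSchema.parse({
	id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
	// "info" omitted: the static table supplies every field
})

console.log(info.supportsImages) // true — from huggingFaceModels
console.log(info.contextWindow) // 131072
```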
diff --git a/src/api/huggingface-models.ts b/src/api/huggingface-models.ts
new file mode 100644
index 00000000000..ec1915d0e3d
--- /dev/null
+++ b/src/api/huggingface-models.ts
@@ -0,0 +1,17 @@
+import { fetchHuggingFaceModels, type HuggingFaceModel } from "../services/huggingface-models"
+
+export interface HuggingFaceModelsResponse {
+	models: HuggingFaceModel[]
+	cached: boolean
+	timestamp: number
+}
+
+export async function getHuggingFaceModels(): Promise<HuggingFaceModelsResponse> {
+	const models = await fetchHuggingFaceModels()
+
+	return {
+		models,
+		cached: false, // We could enhance this to track if data came from cache
+		timestamp: Date.now(),
+	}
+}
diff --git a/src/api/index.ts b/src/api/index.ts
index 4598a711b2a..bda390848cd 100644
--- a/src/api/index.ts
+++ b/src/api/index.ts
@@ -26,6 +26,7 @@ import {
 	FakeAIHandler,
 	XAIHandler,
 	GroqHandler,
+	HuggingFaceHandler,
 	ChutesHandler,
 	LiteLLMHandler,
 	ClaudeCodeHandler,
@@ -108,6 +109,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
 			return new XAIHandler(options)
 		case "groq":
 			return new GroqHandler(options)
+		case "huggingface":
+			return new HuggingFaceHandler(options)
 		case "chutes":
 			return new ChutesHandler(options)
 		case "litellm":
diff --git a/src/api/providers/huggingface.ts b/src/api/providers/huggingface.ts
new file mode 100644
index 00000000000..913605bd929
--- /dev/null
+++ b/src/api/providers/huggingface.ts
@@ -0,0 +1,99 @@
+import OpenAI from "openai"
+import { Anthropic } from "@anthropic-ai/sdk"
+
+import type { ApiHandlerOptions } from "../../shared/api"
+import { ApiStream } from "../transform/stream"
+import { convertToOpenAiMessages } from "../transform/openai-format"
+import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
+import { DEFAULT_HEADERS } from "./constants"
+import { BaseProvider } from "./base-provider"
+
+export class HuggingFaceHandler extends BaseProvider implements SingleCompletionHandler {
+	private client: OpenAI
+	private options: ApiHandlerOptions
+
+	constructor(options: ApiHandlerOptions) {
+		super()
+		this.options = options
+
+		if (!this.options.huggingFaceApiKey) {
+			throw new Error("Hugging Face API key is required")
+		}
+
+		this.client = new OpenAI({
+			baseURL: "https://router.huggingface.co/v1",
+			apiKey: this.options.huggingFaceApiKey,
+			defaultHeaders: DEFAULT_HEADERS,
+		})
+	}
+
+	override async *createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
+	): ApiStream {
+		const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
+		const temperature = this.options.modelTemperature ?? 0.7
+
+		const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
+			model: modelId,
+			temperature,
+			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
+			stream: true,
+			stream_options: { include_usage: true },
+		}
+
+		const stream = await this.client.chat.completions.create(params)
+
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta
+
+			if (delta?.content) {
+				yield {
+					type: "text",
+					text: delta.content,
+				}
+			}
+
+			if (chunk.usage) {
+				yield {
+					type: "usage",
+					inputTokens: chunk.usage.prompt_tokens || 0,
+					outputTokens: chunk.usage.completion_tokens || 0,
+				}
+			}
+		}
+	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
+
+		try {
+			const response = await this.client.chat.completions.create({
+				model: modelId,
+				messages: [{ role: "user", content: prompt }],
+			})
+
+			return response.choices[0]?.message.content || ""
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(`Hugging Face completion error: ${error.message}`)
+			}
+
+			throw error
+		}
+	}
+
+	override getModel() {
+		const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct"
+		return {
+			id: modelId,
+			info: {
+				maxTokens: 8192,
+				contextWindow: 131072,
+				supportsImages: false,
+				supportsPromptCache: false,
+			},
+		}
+	}
+}
diff --git a/src/api/providers/index.ts b/src/api/providers/index.ts
index 89d4c203adf..1cefd0616b4 100644
--- a/src/api/providers/index.ts
+++ b/src/api/providers/index.ts
@@ -9,6 +9,7 @@ export { FakeAIHandler } from "./fake-ai"
 export { GeminiHandler } from "./gemini"
 export { GlamaHandler } from "./glama"
 export { GroqHandler } from "./groq"
+export { HuggingFaceHandler } from "./huggingface"
 export { HumanRelayHandler } from "./human-relay"
 export { LiteLLMHandler } from "./lite-llm"
 export { LmStudioHandler } from "./lm-studio"
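For reviewers unfamiliar with the `ApiStream` contract: the handler yields `text` chunks as deltas arrive and a trailing `usage` chunk once `stream_options.include_usage` is honored. A hedged usage sketch (chunk shapes follow the yields above; the token is assumed to be in the environment; not part of the diff):

```ts
import { HuggingFaceHandler } from "./huggingface"

async function main() {
	const handler = new HuggingFaceHandler({
		huggingFaceApiKey: process.env.HF_TOKEN, // constructor throws if this is missing
		huggingFaceModelId: "Qwen/Qwen2.5-72B-Instruct",
	})

	for await (const chunk of handler.createMessage("You are terse.", [{ role: "user", content: "Say hi" }])) {
		if (chunk.type === "text") process.stdout.write(chunk.text)
		if (chunk.type === "usage") console.log(`\n${chunk.inputTokens} in / ${chunk.outputTokens} out`)
	}
}

main()
```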
diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts
index 780d40df891..ebe95530f25 100644
--- a/src/core/webview/webviewMessageHandler.ts
+++ b/src/core/webview/webviewMessageHandler.ts
@@ -674,6 +674,22 @@ export const webviewMessageHandler = async (
 			// TODO: Cache like we do for OpenRouter, etc?
 			provider.postMessageToWebview({ type: "vsCodeLmModels", vsCodeLmModels })
 			break
+		case "requestHuggingFaceModels":
+			try {
+				const { getHuggingFaceModels } = await import("../../api/huggingface-models")
+				const huggingFaceModelsResponse = await getHuggingFaceModels()
+				provider.postMessageToWebview({
+					type: "huggingFaceModels",
+					huggingFaceModels: huggingFaceModelsResponse.models,
+				})
+			} catch (error) {
+				console.error("Failed to fetch Hugging Face models:", error)
+				provider.postMessageToWebview({
+					type: "huggingFaceModels",
+					huggingFaceModels: [],
+				})
+			}
+			break
 		case "openImage":
 			openImage(message.text!, { values: message.values })
 			break
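The webview side of this round trip is a fire-and-forget request plus a message listener; note that a fetch failure still posts an empty array, so the UI never stalls waiting for a reply. A sketch, assuming it runs inside the React webview with the same `vscode` util the settings UI imports:

```ts
import { vscode } from "@src/utils/vscode"

vscode.postMessage({ type: "requestHuggingFaceModels" })

window.addEventListener("message", (event: MessageEvent) => {
	const message = event.data
	if (message.type === "huggingFaceModels") {
		// Arrives as [] on fetch failure, so the UI can always stop its spinner.
		console.log(`Received ${message.huggingFaceModels?.length ?? 0} models`)
	}
})
```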
diff --git a/src/services/huggingface-models.ts b/src/services/huggingface-models.ts
new file mode 100644
index 00000000000..9c0bc406f93
--- /dev/null
+++ b/src/services/huggingface-models.ts
@@ -0,0 +1,171 @@
+export interface HuggingFaceModel {
+	_id: string
+	id: string
+	inferenceProviderMapping: InferenceProviderMapping[]
+	trendingScore: number
+	config: ModelConfig
+	tags: string[]
+	pipeline_tag: "text-generation" | "image-text-to-text"
+	library_name?: string
+}
+
+export interface InferenceProviderMapping {
+	provider: string
+	providerId: string
+	status: "live" | "staging" | "error"
+	task: "conversational"
+}
+
+export interface ModelConfig {
+	architectures: string[]
+	model_type: string
+	tokenizer_config?: {
+		chat_template?: string | Array<{ name: string; template: string }>
+		model_max_length?: number
+	}
+}
+
+interface HuggingFaceApiParams {
+	pipeline_tag?: "text-generation" | "image-text-to-text"
+	filter: string
+	inference_provider: string
+	limit: number
+	expand: string[]
+}
+
+const DEFAULT_PARAMS: HuggingFaceApiParams = {
+	filter: "conversational",
+	inference_provider: "all",
+	limit: 100,
+	expand: [
+		"inferenceProviderMapping",
+		"config",
+		"library_name",
+		"pipeline_tag",
+		"tags",
+		"mask_token",
+		"trendingScore",
+	],
+}
+
+const BASE_URL = "https://huggingface.co/api/models"
+const CACHE_DURATION = 1000 * 60 * 60 // 1 hour
+
+interface CacheEntry {
+	data: HuggingFaceModel[]
+	timestamp: number
+	status: "success" | "partial" | "error"
+}
+
+let cache: CacheEntry | null = null
+
+function buildApiUrl(params: HuggingFaceApiParams): string {
+	const url = new URL(BASE_URL)
+
+	// Add simple params
+	Object.entries(params).forEach(([key, value]) => {
+		if (!Array.isArray(value)) {
+			url.searchParams.append(key, String(value))
+		}
+	})
+
+	// Handle array params specially
+	params.expand.forEach((item) => {
+		url.searchParams.append("expand[]", item)
+	})
+
+	return url.toString()
+}
+
+const headers: HeadersInit = {
+	"Upgrade-Insecure-Requests": "1",
+	"Sec-Fetch-Dest": "document",
+	"Sec-Fetch-Mode": "navigate",
+	"Sec-Fetch-Site": "none",
+	"Sec-Fetch-User": "?1",
+	Priority: "u=0, i",
+	Pragma: "no-cache",
+	"Cache-Control": "no-cache",
+}
+
+const requestInit: RequestInit = {
+	credentials: "include",
+	headers,
+	method: "GET",
+	mode: "cors",
+}
+
+export async function fetchHuggingFaceModels(): Promise<HuggingFaceModel[]> {
+	const now = Date.now()
+
+	// Check cache
+	if (cache && now - cache.timestamp < CACHE_DURATION) {
+		console.log("Using cached Hugging Face models")
+		return cache.data
+	}
+
+	try {
+		console.log("Fetching Hugging Face models from API...")
+
+		// Fetch both text-generation and image-text-to-text models in parallel
+		const [textGenResponse, imgTextResponse] = await Promise.allSettled([
+			fetch(buildApiUrl({ ...DEFAULT_PARAMS, pipeline_tag: "text-generation" }), requestInit),
+			fetch(buildApiUrl({ ...DEFAULT_PARAMS, pipeline_tag: "image-text-to-text" }), requestInit),
+		])
+
+		let textGenModels: HuggingFaceModel[] = []
+		let imgTextModels: HuggingFaceModel[] = []
+		let hasErrors = false
+
+		// Process text-generation models
+		if (textGenResponse.status === "fulfilled" && textGenResponse.value.ok) {
+			textGenModels = await textGenResponse.value.json()
+		} else {
+			console.error("Failed to fetch text-generation models:", textGenResponse)
+			hasErrors = true
+		}
+
+		// Process image-text-to-text models
+		if (imgTextResponse.status === "fulfilled" && imgTextResponse.value.ok) {
+			imgTextModels = await imgTextResponse.value.json()
+		} else {
+			console.error("Failed to fetch image-text-to-text models:", imgTextResponse)
+			hasErrors = true
+		}
+
+		// Combine and filter models
+		const allModels = [...textGenModels, ...imgTextModels]
+			.filter((model) => model.inferenceProviderMapping.length > 0)
+			.sort((a, b) => a.id.toLowerCase().localeCompare(b.id.toLowerCase()))
+
+		// Update cache
+		cache = {
+			data: allModels,
+			timestamp: now,
+			status: hasErrors ? "partial" : "success",
+		}
+
+		console.log(`Fetched ${allModels.length} Hugging Face models (status: ${cache.status})`)
+		return allModels
+	} catch (error) {
+		console.error("Error fetching Hugging Face models:", error)
+
+		// Return cached data if available
+		if (cache) {
+			console.log("Using stale cached data due to fetch error")
+			cache.status = "error"
+			return cache.data
+		}
+
+		// No cache available, return empty array
+		return []
+	}
+}
+
+export function getCachedModels(): HuggingFaceModel[] | null {
+	return cache?.data || null
+}
+
+export function clearCache(): void {
+	cache = null
+}
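A note on `buildApiUrl`: scalar params are appended in insertion order and the `expand` array is flattened into repeated `expand[]` keys, which `URLSearchParams` percent-encodes. Roughly, for the first request (the helper is module-private; called directly here for illustration):

```ts
const url = buildApiUrl({ ...DEFAULT_PARAMS, pipeline_tag: "text-generation" })
// => https://huggingface.co/api/models?filter=conversational&inference_provider=all
//      &limit=100&pipeline_tag=text-generation
//      &expand%5B%5D=inferenceProviderMapping&expand%5B%5D=config&... (one pair per expand entry)
```

Also note the module-level cache: a second call within `CACHE_DURATION` (one hour) returns `cache.data` without touching the network, and a failed refresh falls back to stale data when any exists.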
diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts
index 4f2aa2da159..827a4a956dc 100644
--- a/src/shared/ExtensionMessage.ts
+++ b/src/shared/ExtensionMessage.ts
@@ -18,6 +18,7 @@ import { McpServer } from "./mcp"
 import { Mode } from "./modes"
 import { RouterModels } from "./api"
 import type { MarketplaceItem } from "@roo-code/types"
+import type { HuggingFaceModel } from "../services/huggingface-models"
 
 // Type for marketplace installed metadata
 export interface MarketplaceInstalledMetadata {
@@ -67,6 +68,7 @@ export interface ExtensionMessage {
 		| "ollamaModels"
 		| "lmStudioModels"
 		| "vsCodeLmModels"
+		| "huggingFaceModels"
 		| "vsCodeLmApiAvailable"
 		| "updatePrompt"
 		| "systemPrompt"
@@ -135,6 +137,7 @@ export interface ExtensionMessage {
 	ollamaModels?: string[]
 	lmStudioModels?: string[]
 	vsCodeLmModels?: { vendor?: string; family?: string; version?: string; id?: string }[]
+	huggingFaceModels?: HuggingFaceModel[]
 	mcpServers?: McpServer[]
 	commits?: GitCommit[]
 	listApiConfig?: ProviderSettingsEntry[]
diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts
index 1f56829f7b3..b0529c5a276 100644
--- a/src/shared/WebviewMessage.ts
+++ b/src/shared/WebviewMessage.ts
@@ -67,6 +67,7 @@ export interface WebviewMessage {
 		| "requestOllamaModels"
 		| "requestLmStudioModels"
 		| "requestVsCodeLmModels"
+		| "requestHuggingFaceModels"
 		| "openImage"
 		| "saveImage"
 		| "openFile"
diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx
index 6c6c621956c..155f694929d 100644
--- a/webview-ui/src/components/settings/ApiOptions.tsx
+++ b/webview-ui/src/components/settings/ApiOptions.tsx
@@ -25,6 +25,7 @@ import {
 	chutesDefaultModelId,
 	bedrockDefaultModelId,
 	vertexDefaultModelId,
+	huggingFaceDefaultModelId,
 } from "@roo-code/types"
 
 import { vscode } from "@src/utils/vscode"
@@ -59,6 +60,7 @@ import {
 	Gemini,
 	Glama,
 	Groq,
+	HuggingFace,
 	LMStudio,
 	LiteLLM,
 	Mistral,
@@ -296,6 +298,7 @@ const ApiOptions = ({
 		chutes: { field: "apiModelId", default: chutesDefaultModelId },
 		bedrock: { field: "apiModelId", default: bedrockDefaultModelId },
 		vertex: { field: "apiModelId", default: vertexDefaultModelId },
+		huggingface: { field: "huggingFaceModelId", default: huggingFaceDefaultModelId },
 		openai: { field: "openAiModelId" },
 		ollama: { field: "ollamaModelId" },
 		lmstudio: { field: "lmStudioModelId" },
@@ -500,6 +503,10 @@ const ApiOptions = ({
 				/>
 			)}
 
+			{selectedProvider === "huggingface" && (
+				<HuggingFace apiConfiguration={apiConfiguration} setApiConfigurationField={setApiConfigurationField} />
+			)}
+
 			{selectedProvider === "human-relay" && (
 				<>
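One behavioral consequence of the new map entry, sketched under the assumption that ApiOptions' provider-change handler applies `field`/`default` as written (the handler itself is outside this hunk, and the standalone map name below is hypothetical): switching to Hugging Face seeds `huggingFaceModelId` with the default model.

```ts
// Hypothetical standalone version of the lookup the component performs.
const modelDefaults = {
	huggingface: { field: "huggingFaceModelId", default: huggingFaceDefaultModelId },
} as const

const entry = modelDefaults["huggingface"]
setApiConfigurationField(entry.field, entry.default)
// => huggingFaceModelId becomes "meta-llama/Llama-3.3-70B-Instruct"
```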
diff --git a/webview-ui/src/components/settings/constants.ts b/webview-ui/src/components/settings/constants.ts
index 1140e4c0bcf..a24d9dccfb3 100644
--- a/webview-ui/src/components/settings/constants.ts
+++ b/webview-ui/src/components/settings/constants.ts
@@ -53,4 +53,5 @@ export const PROVIDERS = [
 	{ value: "groq", label: "Groq" },
 	{ value: "chutes", label: "Chutes AI" },
 	{ value: "litellm", label: "LiteLLM" },
+	{ value: "huggingface", label: "HuggingFace" },
 ].sort((a, b) => a.label.localeCompare(b.label))
diff --git a/webview-ui/src/components/settings/providers/HuggingFace.tsx b/webview-ui/src/components/settings/providers/HuggingFace.tsx
new file mode 100644
index 00000000000..a87c59ec72b
--- /dev/null
+++ b/webview-ui/src/components/settings/providers/HuggingFace.tsx
@@ -0,0 +1,210 @@
+import { useCallback, useState, useEffect, useMemo } from "react"
+import { useEvent } from "react-use"
+import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
+
+import type { ProviderSettings } from "@roo-code/types"
+
+import { ExtensionMessage } from "@roo/ExtensionMessage"
+import { vscode } from "@src/utils/vscode"
+import { useAppTranslation } from "@src/i18n/TranslationContext"
+import { VSCodeButtonLink } from "@src/components/common/VSCodeButtonLink"
+import { SearchableSelect, type SearchableSelectOption } from "@src/components/ui"
+
+import { inputEventTransform } from "../transforms"
+
+type HuggingFaceModel = {
+	_id: string
+	id: string
+	inferenceProviderMapping: Array<{
+		provider: string
+		providerId: string
+		status: "live" | "staging" | "error"
+		task: "conversational"
+	}>
+	trendingScore: number
+	config: {
+		architectures: string[]
+		model_type: string
+		tokenizer_config?: {
+			chat_template?: string | Array<{ name: string; template: string }>
+			model_max_length?: number
+		}
+	}
+	tags: string[]
+	pipeline_tag: "text-generation" | "image-text-to-text"
+	library_name?: string
+}
+
+type HuggingFaceProps = {
+	apiConfiguration: ProviderSettings
+	setApiConfigurationField: (field: keyof ProviderSettings, value: ProviderSettings[keyof ProviderSettings]) => void
+}
+
+export const HuggingFace = ({ apiConfiguration, setApiConfigurationField }: HuggingFaceProps) => {
+	const { t } = useAppTranslation()
+	const [models, setModels] = useState<HuggingFaceModel[]>([])
+	const [loading, setLoading] = useState(false)
+	const [selectedProvider, setSelectedProvider] = useState(
+		apiConfiguration?.huggingFaceInferenceProvider || "auto",
+	)
+
+	const handleInputChange = useCallback(
+		<K extends keyof ProviderSettings, E>(
+			field: K,
+			transform: (event: E) => ProviderSettings[K] = inputEventTransform,
+		) =>
+			(event: E | Event) => {
+				setApiConfigurationField(field, transform(event as E))
+			},
+		[setApiConfigurationField],
+	)
+
+	// Fetch models when component mounts
+	useEffect(() => {
+		setLoading(true)
+		vscode.postMessage({ type: "requestHuggingFaceModels" })
+	}, [])
+
+	// Handle messages from extension
+	const onMessage = useCallback((event: MessageEvent) => {
+		const message: ExtensionMessage = event.data
+
+		switch (message.type) {
+			case "huggingFaceModels":
+				setModels(message.huggingFaceModels || [])
+				setLoading(false)
+				break
+		}
+	}, [])
+
+	useEvent("message", onMessage)
+
+	// Get current model and its providers
+	const currentModel = models.find((m) => m.id === apiConfiguration?.huggingFaceModelId)
+	const availableProviders = useMemo(
+		() => currentModel?.inferenceProviderMapping || [],
+		[currentModel?.inferenceProviderMapping],
+	)
+
+	// Set default provider when model changes
+	useEffect(() => {
+		if (currentModel && availableProviders.length > 0) {
+			const savedProvider = apiConfiguration?.huggingFaceInferenceProvider
+			if (savedProvider) {
+				// Use saved provider if it exists
+				setSelectedProvider(savedProvider)
+			} else {
+				const currentProvider = availableProviders.find((p) => p.provider === selectedProvider)
+				if (!currentProvider) {
+					// Set to "auto" as default
+					const defaultProvider = "auto"
+					setSelectedProvider(defaultProvider)
+					setApiConfigurationField("huggingFaceInferenceProvider", defaultProvider)
+				}
+			}
+		}
+	}, [
+		currentModel,
+		availableProviders,
+		selectedProvider,
+		apiConfiguration?.huggingFaceInferenceProvider,
+		setApiConfigurationField,
+	])
+
+	const handleModelSelect = (modelId: string) => {
+		setApiConfigurationField("huggingFaceModelId", modelId)
+		// Reset provider selection when model changes
+		const defaultProvider = "auto"
+		setSelectedProvider(defaultProvider)
+		setApiConfigurationField("huggingFaceInferenceProvider", defaultProvider)
+	}
+
+	const handleProviderSelect = (provider: string) => {
+		setSelectedProvider(provider)
+		setApiConfigurationField("huggingFaceInferenceProvider", provider)
+	}
+
+	// Format provider name for display
+	const formatProviderName = (provider: string) => {
+		const nameMap: Record<string, string> = {
+			sambanova: "SambaNova",
+			"fireworks-ai": "Fireworks",
+			together: "Together AI",
+			nebius: "Nebius AI Studio",
+			hyperbolic: "Hyperbolic",
+			novita: "Novita",
+			cohere: "Cohere",
+			"hf-inference": "HF Inference API",
+			replicate: "Replicate",
+		}
+		return nameMap[provider] || provider.charAt(0).toUpperCase() + provider.slice(1)
+	}
+
+	return (
+		<>
+			<VSCodeTextField
+				value={apiConfiguration?.huggingFaceApiKey || ""}
+				type="password"
+				onInput={handleInputChange("huggingFaceApiKey")}
+				className="w-full">
+				<label className="block font-medium mb-1">{t("settings:providers.huggingFaceApiKey")}</label>
+			</VSCodeTextField>
+
+			<div>
+				<label className="block font-medium mb-1">{t("settings:providers.huggingFaceModelId")}</label>
+				<SearchableSelect
+					value={apiConfiguration?.huggingFaceModelId || ""}
+					onValueChange={handleModelSelect}
+					options={models.map(
+						(model): SearchableSelectOption => ({
+							value: model.id,
+							label: model.id,
+						}),
+					)}
+					placeholder="Select a model..."
+					searchPlaceholder="Search models..."
+					emptyMessage="No models found"
+					disabled={loading}
+				/>
+			</div>
+
+			{currentModel && availableProviders.length > 0 && (
+				<div>
+					<label className="block font-medium mb-1">Provider</label>
+					<SearchableSelect
+						value={selectedProvider}
+						onValueChange={handleProviderSelect}
+						options={[
+							{ value: "auto", label: "Auto" },
+							...availableProviders.map(
+								(mapping): SearchableSelectOption => ({
+									value: mapping.provider,
+									label: `${formatProviderName(mapping.provider)} (${mapping.status})`,
+								}),
+							),
+						]}
+						placeholder="Select a provider..."
+						searchPlaceholder="Search providers..."
+						emptyMessage="No providers found"
+					/>
+				</div>
+			)}
+
+			<div className="text-sm text-vscode-descriptionForeground">
+				{t("settings:providers.apiKeyStorageNotice")}
+			</div>
+
+			{!apiConfiguration?.huggingFaceApiKey && (
+				<VSCodeButtonLink href="https://huggingface.co/settings/tokens" appearance="secondary">
+					{t("settings:providers.getHuggingFaceApiKey")}
+				</VSCodeButtonLink>
+			)}
+		</>
+	)
+}
diff --git a/webview-ui/src/components/settings/providers/index.ts b/webview-ui/src/components/settings/providers/index.ts
index 54974f7200b..6c6fdddaee0 100644
--- a/webview-ui/src/components/settings/providers/index.ts
+++ b/webview-ui/src/components/settings/providers/index.ts
@@ -6,6 +6,7 @@ export { DeepSeek } from "./DeepSeek"
 export { Gemini } from "./Gemini"
 export { Glama } from "./Glama"
 export { Groq } from "./Groq"
+export { HuggingFace } from "./HuggingFace"
 export { LMStudio } from "./LMStudio"
 export { Mistral } from "./Mistral"
 export { Moonshot } from "./Moonshot"
diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts
index 928ebb42f46..b788bc57699 100644
--- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts
+++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts
@@ -34,6 +34,8 @@ import {
 	litellmDefaultModelId,
 	claudeCodeDefaultModelId,
 	claudeCodeModels,
+	huggingFaceDefaultModelId,
+	huggingFaceModels,
 } from "@roo-code/types"
 
 import type { RouterModels } from "@roo/api"
@@ -214,11 +216,16 @@ function getSelectedModel({
 			const info = claudeCodeModels[id as keyof typeof claudeCodeModels]
 			return { id, info: { ...openAiModelInfoSaneDefaults, ...info } }
 		}
+		case "huggingface": {
+			const id = apiConfiguration.huggingFaceModelId ?? huggingFaceDefaultModelId
+			const info = huggingFaceModels[id as keyof typeof huggingFaceModels]
+			return { id, info: info || openAiModelInfoSaneDefaults }
+		}
 		// case "anthropic":
 		// case "human-relay":
 		// case "fake-ai":
 		default: {
-			provider satisfies "anthropic" | "gemini-cli" | "human-relay" | "fake-ai"
+			provider satisfies "anthropic" | "gemini-cli" | "human-relay" | "fake-ai" | "huggingface"
 			const id = apiConfiguration.apiModelId ?? anthropicDefaultModelId
 			const info = anthropicModels[id as keyof typeof anthropicModels]
 			return { id, info }
diff --git a/webview-ui/src/i18n/locales/en/settings.json b/webview-ui/src/i18n/locales/en/settings.json
index 4a826bddab4..5c60edbd908 100644
--- a/webview-ui/src/i18n/locales/en/settings.json
+++ b/webview-ui/src/i18n/locales/en/settings.json
@@ -264,6 +264,9 @@
 		"apiKey": "API Key",
 		"openAiBaseUrl": "Base URL",
 		"getOpenAiApiKey": "Get OpenAI API Key",
+		"huggingFaceApiKey": "HuggingFace API Key",
+		"getHuggingFaceApiKey": "Get HuggingFace API Key",
+		"huggingFaceModelId": "Model",
 		"mistralApiKey": "Mistral API Key",
 		"getMistralApiKey": "Get Mistral / Codestral API Key",
 		"codestralBaseUrl": "Codestral Base URL (Optional)",
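Finally, a sketch of what the new `useSelectedModel` case does for a model id fetched at runtime that has no entry in the static `huggingFaceModels` table (the id below is hypothetical):

```ts
const id = "example-org/example-model" // hypothetical id returned by the HF API
const info = huggingFaceModels[id as keyof typeof huggingFaceModels] // undefined
const selected = { id, info: info || openAiModelInfoSaneDefaults } // sane defaults fill in
```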