From 82386eae3bd01f29a1516b0d2c046beea202ff70 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Wed, 18 Jun 2025 15:16:49 -0700 Subject: [PATCH 01/18] initial commit --- packages/types/src/provider-settings.ts | 10 ++ packages/types/src/providers/archgw.ts | 13 ++ packages/types/src/providers/index.ts | 1 + src/api/index.ts | 3 + src/api/providers/archgw.ts | 139 ++++++++++++++++ src/api/providers/fetchers/archgw.ts | 56 +++++++ src/api/providers/fetchers/litellm.ts | 1 + src/api/providers/fetchers/modelCache.ts | 7 + src/api/providers/index.ts | 1 + src/core/webview/webviewMessageHandler.ts | 16 ++ src/shared/api.ts | 3 +- .../src/components/settings/ApiOptions.tsx | 10 ++ .../src/components/settings/ModelPicker.tsx | 8 +- .../src/components/settings/constants.ts | 1 + .../components/settings/providers/archgw.tsx | 151 ++++++++++++++++++ .../components/settings/providers/index.ts | 1 + .../components/ui/hooks/useSelectedModel.ts | 9 ++ webview-ui/src/utils/validate.ts | 3 + 18 files changed, 431 insertions(+), 2 deletions(-) create mode 100644 packages/types/src/providers/archgw.ts create mode 100644 src/api/providers/archgw.ts create mode 100644 src/api/providers/fetchers/archgw.ts create mode 100644 webview-ui/src/components/settings/providers/archgw.tsx diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts index 65e3f9b5b65..c30f3d53516 100644 --- a/packages/types/src/provider-settings.ts +++ b/packages/types/src/provider-settings.ts @@ -21,6 +21,7 @@ export const providerNames = [ "openai-native", "mistral", "deepseek", + "archgw", "unbound", "requesty", "human-relay", @@ -168,6 +169,12 @@ const deepSeekSchema = apiModelIdProviderModelSchema.extend({ deepSeekApiKey: z.string().optional(), }) +const archgwSchema = apiModelIdProviderModelSchema.extend({ + archgwBaseUrl: z.string().optional(), + archgwApiKey: z.string().optional(), + archgwModelId: z.string().optional(), +}) + const unboundSchema = baseProviderSettingsSchema.extend({ unboundApiKey: z.string().optional(), unboundModelId: z.string().optional(), @@ -220,6 +227,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv openAiNativeSchema.merge(z.object({ apiProvider: z.literal("openai-native") })), mistralSchema.merge(z.object({ apiProvider: z.literal("mistral") })), deepSeekSchema.merge(z.object({ apiProvider: z.literal("deepseek") })), + archgwSchema.merge(z.object({ apiProvider: z.literal("archgw") })), unboundSchema.merge(z.object({ apiProvider: z.literal("unbound") })), requestySchema.merge(z.object({ apiProvider: z.literal("requesty") })), humanRelaySchema.merge(z.object({ apiProvider: z.literal("human-relay") })), @@ -246,6 +254,7 @@ export const providerSettingsSchema = z.object({ ...openAiNativeSchema.shape, ...mistralSchema.shape, ...deepSeekSchema.shape, + ...archgwSchema.shape, ...unboundSchema.shape, ...requestySchema.shape, ...humanRelaySchema.shape, @@ -271,6 +280,7 @@ export const MODEL_ID_KEYS: Partial[] = [ "unboundModelId", "requestyModelId", "litellmModelId", + "archgwModelId", ] export const getModelId = (settings: ProviderSettings): string | undefined => { diff --git a/packages/types/src/providers/archgw.ts b/packages/types/src/providers/archgw.ts new file mode 100644 index 00000000000..c4eb36fb958 --- /dev/null +++ b/packages/types/src/providers/archgw.ts @@ -0,0 +1,13 @@ +import type { ModelInfo } from "../model.js" + +export const archgwDefaultModelId = "openai/gpt-4.1" + +export const archgwDefaultModelInfo: ModelInfo = { + maxTokens: 32_768, + contextWindow: 1_047_576, + supportsImages: true, + supportsPromptCache: true, + inputPrice: 2, + outputPrice: 8, + cacheReadsPrice: 0.5, +} diff --git a/packages/types/src/providers/index.ts b/packages/types/src/providers/index.ts index 5f1c08041f7..5d69160b95d 100644 --- a/packages/types/src/providers/index.ts +++ b/packages/types/src/providers/index.ts @@ -15,3 +15,4 @@ export * from "./unbound.js" export * from "./vertex.js" export * from "./vscode-llm.js" export * from "./xai.js" +export * from "./archgw.js" diff --git a/src/api/index.ts b/src/api/index.ts index 8b09bf4cf9b..52a7392bc89 100644 --- a/src/api/index.ts +++ b/src/api/index.ts @@ -27,6 +27,7 @@ import { GroqHandler, ChutesHandler, LiteLLMHandler, + ArchGwHandler, } from "./providers" export interface SingleCompletionHandler { @@ -86,6 +87,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler { return new OpenAiNativeHandler(options) case "deepseek": return new DeepSeekHandler(options) + case "archgw": + return new ArchGwHandler(options) case "vscode-lm": return new VsCodeLmHandler(options) case "mistral": diff --git a/src/api/providers/archgw.ts b/src/api/providers/archgw.ts new file mode 100644 index 00000000000..f74102ebefd --- /dev/null +++ b/src/api/providers/archgw.ts @@ -0,0 +1,139 @@ +import OpenAI from "openai" +import { Anthropic } from "@anthropic-ai/sdk" // Keep for type usage only + +import { archgwDefaultModelId, archgwDefaultModelInfo } from "@roo-code/types" + +import { calculateApiCostOpenAI } from "../../shared/cost" + +import { ApiHandlerOptions } from "../../shared/api" + +import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" +import { convertToOpenAiMessages } from "../transform/openai-format" + +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import { RouterProvider } from "./router-provider" + +/** + * LiteLLM provider handler + * + * This handler uses the LiteLLM API to proxy requests to various LLM providers. + * It follows the OpenAI API format for compatibility. + */ +export class ArchGwHandler extends RouterProvider implements SingleCompletionHandler { + constructor(options: ApiHandlerOptions) { + super({ + options, + name: "archgw", + baseURL: `${options.archgwBaseUrl || "http://localhost:12000/v1"}`, + // baseURL: "http://localhost:12000/v1", + apiKey: options.archgwApiKey || "dummy-key", + modelId: options.archgwModelId, + defaultModelId: archgwDefaultModelId, + defaultModelInfo: archgwDefaultModelInfo, + }) + } + + override async *createMessage( + systemPrompt: string, + messages: Anthropic.Messages.MessageParam[], + metadata?: ApiHandlerCreateMessageMetadata, + ): ApiStream { + const { id: modelId, info } = await this.fetchModel() + + const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ + { role: "system", content: systemPrompt }, + ...convertToOpenAiMessages(messages), + ] + + // Required by some providers; others default to max tokens allowed + let maxTokens: number | undefined = info.maxTokens ?? undefined + + const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = { + model: modelId, + max_tokens: maxTokens, + messages: openAiMessages, + stream: true, + stream_options: { + include_usage: true, + }, + } + + if (this.supportsTemperature(modelId)) { + requestOptions.temperature = this.options.modelTemperature ?? 0 + } + + try { + const { data: completion } = await this.client.chat.completions.create(requestOptions).withResponse() + + let lastUsage + + for await (const chunk of completion) { + const delta = chunk.choices[0]?.delta + const usage = chunk.usage as ArchgwUsage + + if (delta?.content) { + yield { type: "text", text: delta.content } + } + + if (usage) { + lastUsage = usage + } + } + + if (lastUsage) { + const usageData: ApiStreamUsageChunk = { + type: "usage", + inputTokens: lastUsage.prompt_tokens || 0, + outputTokens: lastUsage.completion_tokens || 0, + cacheWriteTokens: lastUsage.cache_creation_input_tokens || 0, + cacheReadTokens: lastUsage.prompt_tokens_details?.cached_tokens || 0, + } + + usageData.totalCost = calculateApiCostOpenAI( + info, + usageData.inputTokens, + usageData.outputTokens, + usageData.cacheWriteTokens, + usageData.cacheReadTokens, + ) + + yield usageData + } + } catch (error) { + if (error instanceof Error) { + throw new Error(`archgw streaming error: ${error.message}`) + } + throw error + } + } + + async completePrompt(prompt: string): Promise { + const { id: modelId, info } = await this.fetchModel() + + try { + const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = { + model: modelId, + messages: [{ role: "user", content: prompt }], + } + + if (this.supportsTemperature(modelId)) { + requestOptions.temperature = this.options.modelTemperature ?? 0 + } + + requestOptions.max_tokens = info.maxTokens + + const response = await this.client.chat.completions.create(requestOptions) + return response.choices[0]?.message.content || "" + } catch (error) { + if (error instanceof Error) { + throw new Error(`archgw completion error: ${error.message}`) + } + throw error + } + } +} + +// archgw usage may include an extra field for Anthropic use cases. +interface ArchgwUsage extends OpenAI.CompletionUsage { + cache_creation_input_tokens?: number +} diff --git a/src/api/providers/fetchers/archgw.ts b/src/api/providers/fetchers/archgw.ts new file mode 100644 index 00000000000..5c75ef84366 --- /dev/null +++ b/src/api/providers/fetchers/archgw.ts @@ -0,0 +1,56 @@ +import axios from "axios" + +import type { ModelInfo } from "@roo-code/types" + +import { parseApiPrice } from "../../../shared/cost" + +export async function getArchGwModels(apiKey: string, baseUrl: string): Promise> { + const models: Record = {} + + console.log("Fetching archgw models...") + + try { + const headers: Record = { + "Content-Type": "application/json", + } + + if (apiKey) { + headers["Authorization"] = `Bearer ${apiKey}` + } + + const url = new URL("/v1/models", baseUrl).href + const response = await axios.get(url, { headers, timeout: 5000 }) + const rawModels = response.data + + for (const rawModel of rawModels.data) { + const modelInfo: ModelInfo = { + maxTokens: rawModel.maxTokensOutput, + contextWindow: rawModel.maxTokensInput, + supportsImages: rawModel.capabilities?.includes("input:image"), + supportsComputerUse: rawModel.capabilities?.includes("computer_use"), + supportsPromptCache: rawModel.capabilities?.includes("caching"), + inputPrice: parseApiPrice(rawModel.pricePerToken?.input), + outputPrice: parseApiPrice(rawModel.pricePerToken?.output), + description: undefined, + cacheWritesPrice: parseApiPrice(rawModel.pricePerToken?.cacheWrite), + cacheReadsPrice: parseApiPrice(rawModel.pricePerToken?.cacheRead), + } + + switch (rawModel.id) { + case rawModel.id.startsWith("anthropic/"): + modelInfo.maxTokens = 8192 + break + default: + break + } + + models[rawModel.id] = modelInfo + } + } catch (error) { + console.error(`Error fetching archgw models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`) + } + + console.log("Fetched archgw models:", models) + + return models +} diff --git a/src/api/providers/fetchers/litellm.ts b/src/api/providers/fetchers/litellm.ts index 47617cd3908..3d8b4b05ca1 100644 --- a/src/api/providers/fetchers/litellm.ts +++ b/src/api/providers/fetchers/litellm.ts @@ -13,6 +13,7 @@ import type { ModelRecord } from "../../../shared/api" * @throws Will throw an error if the request fails or the response is not as expected. */ export async function getLiteLLMModels(apiKey: string, baseUrl: string): Promise { + console.log("[getLiteLLMModels] Fetching LiteLLM models...") try { const headers: Record = { "Content-Type": "application/json", diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts index 12d636bc46c..69bb61bf2bd 100644 --- a/src/api/providers/fetchers/modelCache.ts +++ b/src/api/providers/fetchers/modelCache.ts @@ -14,6 +14,7 @@ import { getGlamaModels } from "./glama" import { getUnboundModels } from "./unbound" import { getLiteLLMModels } from "./litellm" import { GetModelsOptions } from "../../../shared/api" +import { getArchGwModels } from "./archgw" const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 }) async function writeModels(router: RouterName, data: ModelRecord) { @@ -42,9 +43,11 @@ async function readModels(router: RouterName): Promise * @returns The models from the cache or the fetched models. */ export const getModels = async (options: GetModelsOptions): Promise => { + console.log("[getModels] Fetching models for provider:", options.provider) const { provider } = options let models = memoryCache.get(provider) if (models) { + console.log(`[getModels] Models for ${provider} found in memory cache`) return models } @@ -68,6 +71,10 @@ export const getModels = async (options: GetModelsOptions): Promise // Type safety ensures apiKey and baseUrl are always provided for litellm models = await getLiteLLMModels(options.apiKey, options.baseUrl) break + case "archgw": + console.log("[getModels] Fetching ArchGw models...") + models = await getArchGwModels(options.apiKey, options.baseUrl) + break default: { // Ensures router is exhaustively checked if RouterName is a strict union const exhaustiveCheck: never = provider diff --git a/src/api/providers/index.ts b/src/api/providers/index.ts index b305118188c..4b48f193a8b 100644 --- a/src/api/providers/index.ts +++ b/src/api/providers/index.ts @@ -20,3 +20,4 @@ export { UnboundHandler } from "./unbound" export { VertexHandler } from "./vertex" export { VsCodeLmHandler } from "./vscode-lm" export { XAIHandler } from "./xai" +export { ArchGwHandler } from "./archgw" diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index a60c5fea41e..8a2567c3ea3 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -361,6 +361,22 @@ export const webviewMessageHandler = async ( }) } + const archgwApiKey = apiConfiguration.archgwApiKey || message?.values?.archgwApiKey + const archgwBaseUrl = + apiConfiguration.archgwBaseUrl || message?.values?.archgwBaseUrl || "http://localhost:12000/v1" + console.log( + `[webviewMessageHandler] requestRouterModels - archgwApiKey: ${archgwApiKey} archgwBaseUrl: ${archgwBaseUrl}`, + ) + // const archgwBaseUrl = "http://localhost:12000/v1" + if (archgwBaseUrl) { + modelFetchPromises.push({ + key: "archgw", + options: { provider: "archgw", apiKey: archgwApiKey, baseUrl: archgwBaseUrl }, + }) + } + + console.log("[webviewMessageHandler] requestRouterModels - modelFetchPromises:", modelFetchPromises) + const results = await Promise.allSettled( modelFetchPromises.map(async ({ key, options }) => { const models = await safeGetModels(options) diff --git a/src/shared/api.ts b/src/shared/api.ts index 8ad88286589..c3254a3d2e3 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -6,7 +6,7 @@ export type ApiHandlerOptions = Omit // RouterName -const routerNames = ["openrouter", "requesty", "glama", "unbound", "litellm"] as const +const routerNames = ["openrouter", "requesty", "glama", "unbound", "litellm", "archgw"] as const export type RouterName = (typeof routerNames)[number] @@ -82,3 +82,4 @@ export type GetModelsOptions = | { provider: "requesty"; apiKey?: string } | { provider: "unbound"; apiKey?: string } | { provider: "litellm"; apiKey: string; baseUrl: string } + | { provider: "archgw"; apiKey: string; baseUrl: string } diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 905f34a8600..d0db9f8456f 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -27,6 +27,7 @@ import { Bedrock, Chutes, DeepSeek, + ArchGw, Gemini, Glama, Groq, @@ -170,6 +171,7 @@ const ApiOptions = ({ apiConfiguration?.ollamaBaseUrl, apiConfiguration?.lmStudioBaseUrl, apiConfiguration?.litellmBaseUrl, + apiConfiguration?.archgwBaseUrl, apiConfiguration?.litellmApiKey, customHeaders, ], @@ -379,6 +381,14 @@ const ApiOptions = ({ )} + {selectedProvider === "archgw" && ( + + )} + {selectedProvider === "vscode-lm" && ( )} diff --git a/webview-ui/src/components/settings/ModelPicker.tsx b/webview-ui/src/components/settings/ModelPicker.tsx index 906b98e47e9..3cddc6a78d5 100644 --- a/webview-ui/src/components/settings/ModelPicker.tsx +++ b/webview-ui/src/components/settings/ModelPicker.tsx @@ -26,7 +26,13 @@ import { ModelInfoView } from "./ModelInfoView" type ModelIdKey = keyof Pick< ProviderSettings, - "glamaModelId" | "openRouterModelId" | "unboundModelId" | "requestyModelId" | "openAiModelId" | "litellmModelId" + | "glamaModelId" + | "openRouterModelId" + | "unboundModelId" + | "requestyModelId" + | "openAiModelId" + | "litellmModelId" + | "archgwModelId" > interface ModelPickerProps { diff --git a/webview-ui/src/components/settings/constants.ts b/webview-ui/src/components/settings/constants.ts index 5b808643e59..916574f6f74 100644 --- a/webview-ui/src/components/settings/constants.ts +++ b/webview-ui/src/components/settings/constants.ts @@ -47,4 +47,5 @@ export const PROVIDERS = [ { value: "groq", label: "Groq" }, { value: "chutes", label: "Chutes AI" }, { value: "litellm", label: "LiteLLM" }, + { value: "archgw", label: "Arch LLM Gateway" }, ].sort((a, b) => a.label.localeCompare(b.label)) diff --git a/webview-ui/src/components/settings/providers/archgw.tsx b/webview-ui/src/components/settings/providers/archgw.tsx new file mode 100644 index 00000000000..7f4200ad645 --- /dev/null +++ b/webview-ui/src/components/settings/providers/archgw.tsx @@ -0,0 +1,151 @@ +import { useCallback, useState, useEffect, useRef } from "react" +import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react" + +import { archgwDefaultModelId, type OrganizationAllowList, type ProviderSettings } from "@roo-code/types" + +import { useAppTranslation } from "@src/i18n/TranslationContext" +import { VSCodeButtonLink } from "@src/components/common/VSCodeButtonLink" +import { RouterName } from "@roo/api" +import { ExtensionMessage } from "@roo/ExtensionMessage" + +import { inputEventTransform } from "../transforms" +import { useExtensionState } from "@src/context/ExtensionStateContext" +import { Button } from "@src/components/ui" +import { vscode } from "@src/utils/vscode" +import { ModelPicker } from "../ModelPicker" + +type ArchGwProps = { + apiConfiguration: ProviderSettings + setApiConfigurationField: (field: keyof ProviderSettings, value: ProviderSettings[keyof ProviderSettings]) => void + organizationAllowList: OrganizationAllowList +} + +export const ArchGw = ({ apiConfiguration, setApiConfigurationField, organizationAllowList }: ArchGwProps) => { + const { t } = useAppTranslation() + const { routerModels } = useExtensionState() + const [refreshStatus, setRefreshStatus] = useState<"idle" | "loading" | "success" | "error">("idle") + const archGwErrorJustReceived = useRef(false) + const [refreshError, setRefreshError] = useState() + + useEffect(() => { + const handleMessage = (event: MessageEvent) => { + const message = event.data + if (message.type === "singleRouterModelFetchResponse" && !message.success) { + const providerName = message.values?.provider as RouterName + if (providerName === "archgw") { + archGwErrorJustReceived.current = true + setRefreshStatus("error") + setRefreshError(message.error) + } + } else if (message.type === "routerModels") { + // If we were loading and no specific error for litellm was just received, mark as success. + // The ModelPicker will show available models or "no models found". + if (refreshStatus === "loading") { + if (!archGwErrorJustReceived.current) { + setRefreshStatus("success") + } + // If litellmErrorJustReceived.current is true, status is already (or will be) "error". + } + } + } + + window.addEventListener("message", handleMessage) + return () => { + window.removeEventListener("message", handleMessage) + } + }, [refreshStatus, refreshError, setRefreshStatus, setRefreshError]) + + const handleInputChange = useCallback( + ( + field: K, + transform: (event: E) => ProviderSettings[K] = inputEventTransform, + ) => + (event: E | Event) => { + setApiConfigurationField(field, transform(event as E)) + }, + [setApiConfigurationField], + ) + + const handleRefreshModels = useCallback(() => { + archGwErrorJustReceived.current = false // Reset flag on new refresh action + setRefreshStatus("loading") + setRefreshError(undefined) + + const key = apiConfiguration.archgwApiKey + const url = apiConfiguration.archgwBaseUrl + + if (!key || !url) { + setRefreshStatus("error") + setRefreshError(t("settings:providers.refreshModels.missingConfig")) + return + } + + vscode.postMessage({ type: "requestRouterModels", values: { archgwApiKey: key, archgwBaseUrl: url } }) + }, [apiConfiguration, setRefreshStatus, setRefreshError, t]) + + return ( + <> + + + + + + + + + + {refreshStatus === "loading" && ( +
+ {t("settings:providers.refreshModels.loading")} +
+ )} + {refreshStatus === "success" && ( +
{t("settings:providers.refreshModels.success")}
+ )} + {refreshStatus === "error" && ( +
+ {refreshError || t("settings:providers.refreshModels.error")} +
+ )} + +
+ {t("settings:providers.apiKeyStorageNotice")} +
+ + + + ) +} diff --git a/webview-ui/src/components/settings/providers/index.ts b/webview-ui/src/components/settings/providers/index.ts index b244fb515c4..1ac5a7d732b 100644 --- a/webview-ui/src/components/settings/providers/index.ts +++ b/webview-ui/src/components/settings/providers/index.ts @@ -17,3 +17,4 @@ export { Vertex } from "./Vertex" export { VSCodeLM } from "./VSCodeLM" export { XAI } from "./XAI" export { LiteLLM } from "./LiteLLM" +export { ArchGw } from "./archgw" diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index 9f77cbe3707..975b111862d 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -121,6 +121,15 @@ function getSelectedModel({ ? { id, info } : { id: litellmDefaultModelId, info: routerModels.litellm[litellmDefaultModelId] } } + + case "archgw": { + const id = apiConfiguration.archgwModelId ?? "openai/gpt-4.1" + const info = routerModels.archgw[id] + return info + ? { id, info } + : { id: litellmDefaultModelId, info: routerModels.litellm[litellmDefaultModelId] } + } + case "xai": { const id = apiConfiguration.apiModelId ?? xaiDefaultModelId const info = xaiModels[id as keyof typeof xaiModels] diff --git a/webview-ui/src/utils/validate.ts b/webview-ui/src/utils/validate.ts index 5122ca58d41..35bf28c1570 100644 --- a/webview-ui/src/utils/validate.ts +++ b/webview-ui/src/utils/validate.ts @@ -219,6 +219,9 @@ export function validateModelId(apiConfiguration: ProviderSettings, routerModels case "litellm": modelId = apiConfiguration.litellmModelId break + case "archgw": + modelId = apiConfiguration.archgwModelId + break } if (!modelId) { From 369933c64d32fc0ae7aaaa2ccb4ce4ab48a0caf4 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Wed, 18 Jun 2025 16:53:24 -0700 Subject: [PATCH 02/18] add more changes --- packages/types/src/provider-settings.ts | 1 - src/api/providers/archgw.ts | 2 - src/core/webview/webviewMessageHandler.ts | 8 +- .../src/components/settings/ApiOptions.tsx | 11 ++- .../components/settings/providers/archgw.tsx | 85 ++++++++++--------- .../components/settings/providers/index.ts | 2 +- .../components/ui/hooks/useSelectedModel.ts | 5 +- .../src/context/ExtensionStateContext.tsx | 1 + 8 files changed, 61 insertions(+), 54 deletions(-) diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts index c30f3d53516..d3bdaee5bc7 100644 --- a/packages/types/src/provider-settings.ts +++ b/packages/types/src/provider-settings.ts @@ -171,7 +171,6 @@ const deepSeekSchema = apiModelIdProviderModelSchema.extend({ const archgwSchema = apiModelIdProviderModelSchema.extend({ archgwBaseUrl: z.string().optional(), - archgwApiKey: z.string().optional(), archgwModelId: z.string().optional(), }) diff --git a/src/api/providers/archgw.ts b/src/api/providers/archgw.ts index f74102ebefd..40d746f2575 100644 --- a/src/api/providers/archgw.ts +++ b/src/api/providers/archgw.ts @@ -25,8 +25,6 @@ export class ArchGwHandler extends RouterProvider implements SingleCompletionHan options, name: "archgw", baseURL: `${options.archgwBaseUrl || "http://localhost:12000/v1"}`, - // baseURL: "http://localhost:12000/v1", - apiKey: options.archgwApiKey || "dummy-key", modelId: options.archgwModelId, defaultModelId: archgwDefaultModelId, defaultModelInfo: archgwDefaultModelInfo, diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index 8a2567c3ea3..067da6b96a0 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -361,17 +361,13 @@ export const webviewMessageHandler = async ( }) } - const archgwApiKey = apiConfiguration.archgwApiKey || message?.values?.archgwApiKey const archgwBaseUrl = apiConfiguration.archgwBaseUrl || message?.values?.archgwBaseUrl || "http://localhost:12000/v1" - console.log( - `[webviewMessageHandler] requestRouterModels - archgwApiKey: ${archgwApiKey} archgwBaseUrl: ${archgwBaseUrl}`, - ) - // const archgwBaseUrl = "http://localhost:12000/v1" + console.log(`[webviewMessageHandler] requestRouterModels - archgwBaseUrl: ${archgwBaseUrl}`) if (archgwBaseUrl) { modelFetchPromises.push({ key: "archgw", - options: { provider: "archgw", apiKey: archgwApiKey, baseUrl: archgwBaseUrl }, + options: { provider: "archgw", baseUrl: archgwBaseUrl }, }) } diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index d0db9f8456f..6eb47bd8436 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -11,6 +11,7 @@ import { glamaDefaultModelId, unboundDefaultModelId, litellmDefaultModelId, + archgwDefaultModelId, } from "@roo-code/types" import { vscode } from "@src/utils/vscode" @@ -160,6 +161,8 @@ const ApiOptions = ({ vscode.postMessage({ type: "requestVsCodeLmModels" }) } else if (selectedProvider === "litellm") { vscode.postMessage({ type: "requestRouterModels" }) + } else if (selectedProvider === "archgw") { + vscode.postMessage({ type: "requestRouterModels" }) } }, 250, @@ -171,8 +174,8 @@ const ApiOptions = ({ apiConfiguration?.ollamaBaseUrl, apiConfiguration?.lmStudioBaseUrl, apiConfiguration?.litellmBaseUrl, - apiConfiguration?.archgwBaseUrl, apiConfiguration?.litellmApiKey, + apiConfiguration?.archgwBaseUrl, customHeaders, ], ) @@ -232,6 +235,11 @@ const ApiOptions = ({ setApiConfigurationField("litellmModelId", litellmDefaultModelId) } break + case "archgw": + if (!apiConfiguration.archgwModelId) { + setApiConfigurationField("archgwModelId", archgwDefaultModelId) + } + break } setApiConfigurationField("apiProvider", value) @@ -243,6 +251,7 @@ const ApiOptions = ({ apiConfiguration.unboundModelId, apiConfiguration.requestyModelId, apiConfiguration.litellmModelId, + apiConfiguration.archgwModelId, ], ) diff --git a/webview-ui/src/components/settings/providers/archgw.tsx b/webview-ui/src/components/settings/providers/archgw.tsx index 7f4200ad645..b082279bd5c 100644 --- a/webview-ui/src/components/settings/providers/archgw.tsx +++ b/webview-ui/src/components/settings/providers/archgw.tsx @@ -1,17 +1,23 @@ import { useCallback, useState, useEffect, useRef } from "react" import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react" +import { Checkbox } from "vscrui" -import { archgwDefaultModelId, type OrganizationAllowList, type ProviderSettings } from "@roo-code/types" +import { + type ProviderSettings, + type OrganizationAllowList, + litellmDefaultModelId, + archgwDefaultModelId, +} from "@roo-code/types" -import { useAppTranslation } from "@src/i18n/TranslationContext" -import { VSCodeButtonLink } from "@src/components/common/VSCodeButtonLink" import { RouterName } from "@roo/api" import { ExtensionMessage } from "@roo/ExtensionMessage" -import { inputEventTransform } from "../transforms" +import { vscode } from "@src/utils/vscode" import { useExtensionState } from "@src/context/ExtensionStateContext" +import { useAppTranslation } from "@src/i18n/TranslationContext" import { Button } from "@src/components/ui" -import { vscode } from "@src/utils/vscode" + +import { inputEventTransform } from "../transforms" import { ModelPicker } from "../ModelPicker" type ArchGwProps = { @@ -24,8 +30,10 @@ export const ArchGw = ({ apiConfiguration, setApiConfigurationField, organizatio const { t } = useAppTranslation() const { routerModels } = useExtensionState() const [refreshStatus, setRefreshStatus] = useState<"idle" | "loading" | "success" | "error">("idle") - const archGwErrorJustReceived = useRef(false) const [refreshError, setRefreshError] = useState() + const archgwErrorJustReceived = useRef(false) + + const [archgwBaseUrlSelected, setArchgwBaseUrlSelected] = useState(!!apiConfiguration?.archgwBaseUrl) useEffect(() => { const handleMessage = (event: MessageEvent) => { @@ -33,18 +41,18 @@ export const ArchGw = ({ apiConfiguration, setApiConfigurationField, organizatio if (message.type === "singleRouterModelFetchResponse" && !message.success) { const providerName = message.values?.provider as RouterName if (providerName === "archgw") { - archGwErrorJustReceived.current = true + archgwErrorJustReceived.current = true setRefreshStatus("error") setRefreshError(message.error) } } else if (message.type === "routerModels") { - // If we were loading and no specific error for litellm was just received, mark as success. + // If we were loading and no specific error for archgw was just received, mark as success. // The ModelPicker will show available models or "no models found". if (refreshStatus === "loading") { - if (!archGwErrorJustReceived.current) { + if (!archgwErrorJustReceived.current) { setRefreshStatus("success") } - // If litellmErrorJustReceived.current is true, status is already (or will be) "error". + // If archgwErrorJustReceived.current is true, status is already (or will be) "error". } } } @@ -67,47 +75,49 @@ export const ArchGw = ({ apiConfiguration, setApiConfigurationField, organizatio ) const handleRefreshModels = useCallback(() => { - archGwErrorJustReceived.current = false // Reset flag on new refresh action + archgwErrorJustReceived.current = false // Reset flag on new refresh action setRefreshStatus("loading") setRefreshError(undefined) - const key = apiConfiguration.archgwApiKey const url = apiConfiguration.archgwBaseUrl - if (!key || !url) { + if (!url) { setRefreshStatus("error") setRefreshError(t("settings:providers.refreshModels.missingConfig")) return } - - vscode.postMessage({ type: "requestRouterModels", values: { archgwApiKey: key, archgwBaseUrl: url } }) + vscode.postMessage({ type: "requestRouterModels", values: { archgwBaseUrl: url } }) }, [apiConfiguration, setRefreshStatus, setRefreshError, t]) return ( <> - - - - - - - + { + setArchgwBaseUrlSelected(checked) + + if (!checked) { + setApiConfigurationField("archgwBaseUrl", "") + } + }}> + {t("settings:providers.useCustomBaseUrl")} + + {archgwBaseUrlSelected && ( + <> + + + )}