diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts
index be74ae6bb4c..4fa759ef4a9 100644
--- a/packages/types/src/provider-settings.ts
+++ b/packages/types/src/provider-settings.ts
@@ -218,6 +218,7 @@ const litellmSchema = baseProviderSettingsSchema.extend({
 	litellmBaseUrl: z.string().optional(),
 	litellmApiKey: z.string().optional(),
 	litellmModelId: z.string().optional(),
+	litellmUsePromptCache: z.boolean().optional(),
 })
 
 const defaultSchema = z.object({
diff --git a/src/api/providers/__tests__/lite-llm.spec.ts b/src/api/providers/__tests__/lite-llm.spec.ts
new file mode 100644
index 00000000000..26ebbc35258
--- /dev/null
+++ b/src/api/providers/__tests__/lite-llm.spec.ts
@@ -0,0 +1,158 @@
+import { describe, it, expect, vi, beforeEach } from "vitest"
+import OpenAI from "openai"
+import { Anthropic } from "@anthropic-ai/sdk"
+
+import { LiteLLMHandler } from "../lite-llm"
+import { ApiHandlerOptions } from "../../../shared/api"
+import { litellmDefaultModelId, litellmDefaultModelInfo } from "@roo-code/types"
+
+// Mock vscode first to avoid import errors
+vi.mock("vscode", () => ({}))
+
+// Mock OpenAI
+vi.mock("openai", () => {
+	const mockStream = {
+		[Symbol.asyncIterator]: vi.fn(),
+	}
+
+	const mockCreate = vi.fn().mockReturnValue({
+		withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
+	})
+
+	return {
+		default: vi.fn().mockImplementation(() => ({
+			chat: {
+				completions: {
+					create: mockCreate,
+				},
+			},
+		})),
+	}
+})
+
+// Mock model fetching
+vi.mock("../fetchers/modelCache", () => ({
+	getModels: vi.fn().mockImplementation(() => {
+		return Promise.resolve({
+			[litellmDefaultModelId]: litellmDefaultModelInfo,
+		})
+	}),
+}))
+
+describe("LiteLLMHandler", () => {
+	let handler: LiteLLMHandler
+	let mockOptions: ApiHandlerOptions
+	let mockOpenAIClient: any
+
+	beforeEach(() => {
+		vi.clearAllMocks()
+		mockOptions = {
+			litellmApiKey: "test-key",
+			litellmBaseUrl: "http://localhost:4000",
+			litellmModelId: litellmDefaultModelId,
+		}
+		handler = new LiteLLMHandler(mockOptions)
+		mockOpenAIClient = new OpenAI()
+	})
+
+	describe("prompt caching", () => {
+		it("should add cache control headers when litellmUsePromptCache is enabled", async () => {
+			const optionsWithCache: ApiHandlerOptions = {
+				...mockOptions,
+				litellmUsePromptCache: true,
+			}
+			handler = new LiteLLMHandler(optionsWithCache)
+
+			const systemPrompt = "You are a helpful assistant"
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{ role: "user", content: "Hello" },
+				{ role: "assistant", content: "Hi there!" },
+				{ role: "user", content: "How are you?" },
+			]
+
+			// Mock the stream response
+			const mockStream = {
+				async *[Symbol.asyncIterator]() {
+					yield {
+						choices: [{ delta: { content: "I'm doing well!" } }],
+						usage: {
+							prompt_tokens: 100,
+							completion_tokens: 50,
+							cache_creation_input_tokens: 20,
+							cache_read_input_tokens: 30,
+						},
+					}
+				},
+			}
+
+			mockOpenAIClient.chat.completions.create.mockReturnValue({
+				withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
+			})
+
+			const generator = handler.createMessage(systemPrompt, messages)
+			const results = []
+			for await (const chunk of generator) {
+				results.push(chunk)
+			}
+
+			// Verify that create was called with cache control headers
+			const createCall = mockOpenAIClient.chat.completions.create.mock.calls[0][0]
+
+			// Check system message has cache control in the proper format
+			expect(createCall.messages[0]).toMatchObject({
+				role: "system",
+				content: [
+					{
+						type: "text",
+						text: systemPrompt,
+						cache_control: { type: "ephemeral" },
+					},
+				],
+			})
+
+			// Check that the last two user messages have cache control
+			const userMessageIndices = createCall.messages
+				.map((msg: any, idx: number) => (msg.role === "user" ? idx : -1))
+				.filter((idx: number) => idx !== -1)
+
+			const lastUserIdx = userMessageIndices[userMessageIndices.length - 1]
+			const secondLastUserIdx = userMessageIndices[userMessageIndices.length - 2]
+
+			// Check last user message has proper structure with cache control
+			expect(createCall.messages[lastUserIdx]).toMatchObject({
+				role: "user",
+				content: [
+					{
+						type: "text",
+						text: "How are you?",
+						cache_control: { type: "ephemeral" },
+					},
+				],
+			})
+
+			// Check second last user message (first user message in this case)
+			if (secondLastUserIdx !== -1) {
+				expect(createCall.messages[secondLastUserIdx]).toMatchObject({
+					role: "user",
+					content: [
+						{
+							type: "text",
+							text: "Hello",
+							cache_control: { type: "ephemeral" },
+						},
+					],
+				})
+			}
+
+			// Verify usage includes cache tokens
+			const usageChunk = results.find((chunk) => chunk.type === "usage")
+			expect(usageChunk).toMatchObject({
+				type: "usage",
+				inputTokens: 100,
+				outputTokens: 50,
+				cacheWriteTokens: 20,
+				cacheReadTokens: 30,
+			})
+		})
+	})
+})
diff --git a/src/api/providers/lite-llm.ts b/src/api/providers/lite-llm.ts
index e8cd58b12c7..7cea7411feb 100644
--- a/src/api/providers/lite-llm.ts
+++ b/src/api/providers/lite-llm.ts
@@ -39,10 +39,70 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
 	): ApiStream {
 		const { id: modelId, info } = await this.fetchModel()
 
-		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
-			{ role: "system", content: systemPrompt },
-			...convertToOpenAiMessages(messages),
-		]
+		const openAiMessages = convertToOpenAiMessages(messages)
+
+		// Prepare messages with cache control if enabled and supported
+		let systemMessage: OpenAI.Chat.ChatCompletionMessageParam
+		let enhancedMessages: OpenAI.Chat.ChatCompletionMessageParam[]
+
+		if (this.options.litellmUsePromptCache && info.supportsPromptCache) {
+			// Create system message with cache control in the proper format
+			systemMessage = {
+				role: "system",
+				content: [
+					{
+						type: "text",
+						text: systemPrompt,
+						cache_control: { type: "ephemeral" },
+					} as any,
+				],
+			}
+
+			// Find the last two user messages to apply caching
+			const userMsgIndices = openAiMessages.reduce(
+				(acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
+				[] as number[],
+			)
+			const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
+			const secondLastUserMsgIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1
+
+			// Apply cache_control to the last two user messages
+			enhancedMessages = openAiMessages.map((message, index) => {
+				if ((index === lastUserMsgIndex || index === secondLastUserMsgIndex) && message.role === "user") {
+					// Handle both string and array content types
+					if (typeof message.content === "string") {
+						return {
+							...message,
+							content: [
+								{
+									type: "text",
+									text: message.content,
+									cache_control: { type: "ephemeral" },
+								} as any,
+							],
+						}
+					} else if (Array.isArray(message.content)) {
+						// Apply cache control to the last content item in the array
+						return {
+							...message,
+							content: message.content.map((content, contentIndex) =>
+								contentIndex === message.content.length - 1
+									? ({
+											...content,
+											cache_control: { type: "ephemeral" },
+										} as any)
+									: content,
+							),
+						}
+					}
+				}
+				return message
+			})
+		} else {
+			// No cache control - use simple format
+			systemMessage = { role: "system", content: systemPrompt }
+			enhancedMessages = openAiMessages
+		}
 
 		// Required by some providers; others default to max tokens allowed
 		let maxTokens: number | undefined = info.maxTokens ?? undefined
@@ -50,7 +110,7 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
 		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
 			model: modelId,
 			max_tokens: maxTokens,
-			messages: openAiMessages,
+			messages: [systemMessage, ...enhancedMessages],
 			stream: true,
 			stream_options: {
 				include_usage: true,
@@ -80,20 +140,30 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
 		}
 
 		if (lastUsage) {
+			// Extract cache-related information if available
+			// LiteLLM may use different field names for cache tokens
+			const cacheWriteTokens =
+				lastUsage.cache_creation_input_tokens || (lastUsage as any).prompt_cache_miss_tokens || 0
+			const cacheReadTokens =
+				lastUsage.prompt_tokens_details?.cached_tokens ||
+				(lastUsage as any).cache_read_input_tokens ||
+				(lastUsage as any).prompt_cache_hit_tokens ||
+				0
+
 			const usageData: ApiStreamUsageChunk = {
 				type: "usage",
 				inputTokens: lastUsage.prompt_tokens || 0,
 				outputTokens: lastUsage.completion_tokens || 0,
-				cacheWriteTokens: lastUsage.cache_creation_input_tokens || 0,
-				cacheReadTokens: lastUsage.prompt_tokens_details?.cached_tokens || 0,
+				cacheWriteTokens: cacheWriteTokens > 0 ? cacheWriteTokens : undefined,
+				cacheReadTokens: cacheReadTokens > 0 ? cacheReadTokens : undefined,
 			}
 
 			usageData.totalCost = calculateApiCostOpenAI(
 				info,
 				usageData.inputTokens,
 				usageData.outputTokens,
-				usageData.cacheWriteTokens,
-				usageData.cacheReadTokens,
+				usageData.cacheWriteTokens || 0,
+				usageData.cacheReadTokens || 0,
 			)
 
 			yield usageData
diff --git a/webview-ui/src/components/settings/providers/LiteLLM.tsx b/webview-ui/src/components/settings/providers/LiteLLM.tsx
index a2467b3c0b1..caf7a173feb 100644
--- a/webview-ui/src/components/settings/providers/LiteLLM.tsx
+++ b/webview-ui/src/components/settings/providers/LiteLLM.tsx
@@ -1,5 +1,5 @@
 import { useCallback, useState, useEffect, useRef } from "react"
-import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
+import { VSCodeTextField, VSCodeCheckbox } from "@vscode/webview-ui-toolkit/react"
 
 import { type ProviderSettings, type OrganizationAllowList, litellmDefaultModelId } from "@roo-code/types"
 
@@ -151,6 +151,29 @@ export const LiteLLM = ({
 				organizationAllowList={organizationAllowList}
 				errorMessage={modelValidationError}
 			/>
+
+			{/* Show prompt caching option if the selected model supports it */}
+			{(() => {
+				const selectedModelId = apiConfiguration.litellmModelId || litellmDefaultModelId
+				const selectedModel = routerModels?.litellm?.[selectedModelId]
+				if (selectedModel?.supportsPromptCache) {
+					return (
+
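
A minimal sketch (not part of the diff) of the request messages that createMessage builds when litellmUsePromptCache is enabled and the selected model reports supportsPromptCache, using the fixtures from lite-llm.spec.ts above; only the variable name is illustrative.

// Illustrative only: the system prompt and the last two user messages are wrapped in
// text parts carrying cache_control: { type: "ephemeral" }; other messages pass through
// unchanged, and this array is spread into the chat.completions.create call.
const annotatedMessages = [
	{
		role: "system",
		content: [{ type: "text", text: "You are a helpful assistant", cache_control: { type: "ephemeral" } }],
	},
	{ role: "user", content: [{ type: "text", text: "Hello", cache_control: { type: "ephemeral" } }] },
	{ role: "assistant", content: "Hi there!" },
	{ role: "user", content: [{ type: "text", text: "How are you?", cache_control: { type: "ephemeral" } }] },
]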