From d460f43c7d870ba54bdeda360c64321f1d09f511 Mon Sep 17 00:00:00 2001 From: MuriloFP Date: Tue, 22 Jul 2025 15:45:14 -0300 Subject: [PATCH 1/3] feat: add prompt caching support for LiteLLM (#5791) - Add litellmUsePromptCache configuration option to provider settings - Implement cache control headers in LiteLLM handler when enabled - Add UI checkbox for enabling prompt caching (only shown for supported models) - Track cache read/write tokens in usage data - Add comprehensive test for prompt caching functionality - Reuse existing translation keys for consistency across languages This allows LiteLLM users to benefit from prompt caching with supported models like Claude 3.7, reducing costs and improving response times. --- packages/types/src/provider-settings.ts | 1 + src/api/providers/__tests__/lite-llm.spec.ts | 130 ++++++++++++++++++ src/api/providers/lite-llm.ts | 51 ++++++- .../components/settings/providers/LiteLLM.tsx | 25 +++- 4 files changed, 201 insertions(+), 6 deletions(-) create mode 100644 src/api/providers/__tests__/lite-llm.spec.ts diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts index be74ae6bb4c..4fa759ef4a9 100644 --- a/packages/types/src/provider-settings.ts +++ b/packages/types/src/provider-settings.ts @@ -218,6 +218,7 @@ const litellmSchema = baseProviderSettingsSchema.extend({ litellmBaseUrl: z.string().optional(), litellmApiKey: z.string().optional(), litellmModelId: z.string().optional(), + litellmUsePromptCache: z.boolean().optional(), }) const defaultSchema = z.object({ diff --git a/src/api/providers/__tests__/lite-llm.spec.ts b/src/api/providers/__tests__/lite-llm.spec.ts new file mode 100644 index 00000000000..661b17cc3e7 --- /dev/null +++ b/src/api/providers/__tests__/lite-llm.spec.ts @@ -0,0 +1,130 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import OpenAI from "openai" +import { Anthropic } from "@anthropic-ai/sdk" + +import { LiteLLMHandler } from "../lite-llm" +import { ApiHandlerOptions } from "../../../shared/api" +import { litellmDefaultModelId, litellmDefaultModelInfo } from "@roo-code/types" + +// Mock vscode first to avoid import errors +vi.mock("vscode", () => ({})) + +// Mock OpenAI +vi.mock("openai", () => { + const mockStream = { + [Symbol.asyncIterator]: vi.fn(), + } + + const mockCreate = vi.fn().mockReturnValue({ + withResponse: vi.fn().mockResolvedValue({ data: mockStream }), + }) + + return { + default: vi.fn().mockImplementation(() => ({ + chat: { + completions: { + create: mockCreate, + }, + }, + })), + } +}) + +// Mock model fetching +vi.mock("../fetchers/modelCache", () => ({ + getModels: vi.fn().mockImplementation(() => { + return Promise.resolve({ + [litellmDefaultModelId]: litellmDefaultModelInfo, + }) + }), +})) + +describe("LiteLLMHandler", () => { + let handler: LiteLLMHandler + let mockOptions: ApiHandlerOptions + let mockOpenAIClient: any + + beforeEach(() => { + vi.clearAllMocks() + mockOptions = { + litellmApiKey: "test-key", + litellmBaseUrl: "http://localhost:4000", + litellmModelId: litellmDefaultModelId, + } + handler = new LiteLLMHandler(mockOptions) + mockOpenAIClient = new OpenAI() + }) + + describe("prompt caching", () => { + it("should add cache control headers when litellmUsePromptCache is enabled", async () => { + const optionsWithCache: ApiHandlerOptions = { + ...mockOptions, + litellmUsePromptCache: true, + } + handler = new LiteLLMHandler(optionsWithCache) + + const systemPrompt = "You are a helpful assistant" + const messages: 
Anthropic.Messages.MessageParam[] = [ + { role: "user", content: "Hello" }, + { role: "assistant", content: "Hi there!" }, + { role: "user", content: "How are you?" }, + ] + + // Mock the stream response + const mockStream = { + async *[Symbol.asyncIterator]() { + yield { + choices: [{ delta: { content: "I'm doing well!" } }], + usage: { + prompt_tokens: 100, + completion_tokens: 50, + cache_creation_input_tokens: 20, + cache_read_input_tokens: 30, + }, + } + }, + } + + mockOpenAIClient.chat.completions.create.mockReturnValue({ + withResponse: vi.fn().mockResolvedValue({ data: mockStream }), + }) + + const generator = handler.createMessage(systemPrompt, messages) + const results = [] + for await (const chunk of generator) { + results.push(chunk) + } + + // Verify that create was called with cache control headers + const createCall = mockOpenAIClient.chat.completions.create.mock.calls[0][0] + + // Check system message has cache control + expect(createCall.messages[0]).toMatchObject({ + role: "system", + content: systemPrompt, + cache_control: { type: "ephemeral" }, + }) + + // Check that the last two user messages have cache control + const userMessageIndices = createCall.messages + .map((msg: any, idx: number) => (msg.role === "user" ? idx : -1)) + .filter((idx: number) => idx !== -1) + + const lastUserIdx = userMessageIndices[userMessageIndices.length - 1] + + expect(createCall.messages[lastUserIdx]).toMatchObject({ + cache_control: { type: "ephemeral" }, + }) + + // Verify usage includes cache tokens + const usageChunk = results.find((chunk) => chunk.type === "usage") + expect(usageChunk).toMatchObject({ + type: "usage", + inputTokens: 100, + outputTokens: 50, + cacheWriteTokens: 20, + cacheReadTokens: 30, + }) + }) + }) +}) diff --git a/src/api/providers/lite-llm.ts b/src/api/providers/lite-llm.ts index e8cd58b12c7..97bf4fdbeb9 100644 --- a/src/api/providers/lite-llm.ts +++ b/src/api/providers/lite-llm.ts @@ -44,13 +44,44 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa ...convertToOpenAiMessages(messages), ] + // Apply cache control if prompt caching is enabled and supported + let enhancedMessages = openAiMessages + if (this.options.litellmUsePromptCache && info.supportsPromptCache) { + const cacheControl = { cache_control: { type: "ephemeral" } } + + // Add cache control to system message + enhancedMessages[0] = { + ...enhancedMessages[0], + ...cacheControl, + } + + // Find the last two user messages to apply caching + const userMsgIndices = enhancedMessages.reduce( + (acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc), + [] as number[], + ) + const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1 + const secondLastUserMsgIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1 + + // Apply cache_control to the last two user messages + enhancedMessages = enhancedMessages.map((message, index) => { + if (index === lastUserMsgIndex || index === secondLastUserMsgIndex) { + return { + ...message, + ...cacheControl, + } + } + return message + }) + } + // Required by some providers; others default to max tokens allowed let maxTokens: number | undefined = info.maxTokens ?? 
undefined const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = { model: modelId, max_tokens: maxTokens, - messages: openAiMessages, + messages: enhancedMessages, stream: true, stream_options: { include_usage: true, @@ -80,20 +111,30 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa } if (lastUsage) { + // Extract cache-related information if available + // LiteLLM may use different field names for cache tokens + const cacheWriteTokens = + lastUsage.cache_creation_input_tokens || (lastUsage as any).prompt_cache_miss_tokens || 0 + const cacheReadTokens = + lastUsage.prompt_tokens_details?.cached_tokens || + (lastUsage as any).cache_read_input_tokens || + (lastUsage as any).prompt_cache_hit_tokens || + 0 + const usageData: ApiStreamUsageChunk = { type: "usage", inputTokens: lastUsage.prompt_tokens || 0, outputTokens: lastUsage.completion_tokens || 0, - cacheWriteTokens: lastUsage.cache_creation_input_tokens || 0, - cacheReadTokens: lastUsage.prompt_tokens_details?.cached_tokens || 0, + cacheWriteTokens: cacheWriteTokens > 0 ? cacheWriteTokens : undefined, + cacheReadTokens: cacheReadTokens > 0 ? cacheReadTokens : undefined, } usageData.totalCost = calculateApiCostOpenAI( info, usageData.inputTokens, usageData.outputTokens, - usageData.cacheWriteTokens, - usageData.cacheReadTokens, + usageData.cacheWriteTokens || 0, + usageData.cacheReadTokens || 0, ) yield usageData diff --git a/webview-ui/src/components/settings/providers/LiteLLM.tsx b/webview-ui/src/components/settings/providers/LiteLLM.tsx index a2467b3c0b1..caf7a173feb 100644 --- a/webview-ui/src/components/settings/providers/LiteLLM.tsx +++ b/webview-ui/src/components/settings/providers/LiteLLM.tsx @@ -1,5 +1,5 @@ import { useCallback, useState, useEffect, useRef } from "react" -import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react" +import { VSCodeTextField, VSCodeCheckbox } from "@vscode/webview-ui-toolkit/react" import { type ProviderSettings, type OrganizationAllowList, litellmDefaultModelId } from "@roo-code/types" @@ -151,6 +151,29 @@ export const LiteLLM = ({ organizationAllowList={organizationAllowList} errorMessage={modelValidationError} /> + + {/* Show prompt caching option if the selected model supports it */} + {(() => { + const selectedModelId = apiConfiguration.litellmModelId || litellmDefaultModelId + const selectedModel = routerModels?.litellm?.[selectedModelId] + if (selectedModel?.supportsPromptCache) { + return ( +
+						<div>
+							<VSCodeCheckbox
+								checked={apiConfiguration.litellmUsePromptCache || false}
+								onChange={(e: any) => {
+									setApiConfigurationField("litellmUsePromptCache", e.target.checked)
+								}}>
+								{t("settings:providers.enablePromptCaching")}
+							</VSCodeCheckbox>
+							<div className="text-sm text-vscode-descriptionForeground">
+								{t("settings:providers.enablePromptCachingTitle")}
+							</div>
+						</div>
+ ) + } + return null + })()} ) } From 934504105b34b56ac5f4dda270b19d4a42a4e9f7 Mon Sep 17 00:00:00 2001 From: MuriloFP Date: Wed, 23 Jul 2025 15:51:14 -0300 Subject: [PATCH 2/3] fix: improve LiteLLM prompt caching to work for multi-turn conversations - Convert system message to structured format with cache_control - Handle both string and array content types for user messages - Apply cache_control to content items, not just message level - Update tests to match new message structure This ensures prompt caching works correctly for all messages in a conversation, not just the initial system prompt and first user message. --- src/api/providers/__tests__/lite-llm.spec.ts | 36 ++++++++-- src/api/providers/lite-llm.ts | 69 ++++++++++++++------ 2 files changed, 81 insertions(+), 24 deletions(-) diff --git a/src/api/providers/__tests__/lite-llm.spec.ts b/src/api/providers/__tests__/lite-llm.spec.ts index 661b17cc3e7..26ebbc35258 100644 --- a/src/api/providers/__tests__/lite-llm.spec.ts +++ b/src/api/providers/__tests__/lite-llm.spec.ts @@ -98,11 +98,16 @@ describe("LiteLLMHandler", () => { // Verify that create was called with cache control headers const createCall = mockOpenAIClient.chat.completions.create.mock.calls[0][0] - // Check system message has cache control + // Check system message has cache control in the proper format expect(createCall.messages[0]).toMatchObject({ role: "system", - content: systemPrompt, - cache_control: { type: "ephemeral" }, + content: [ + { + type: "text", + text: systemPrompt, + cache_control: { type: "ephemeral" }, + }, + ], }) // Check that the last two user messages have cache control @@ -111,11 +116,34 @@ describe("LiteLLMHandler", () => { .filter((idx: number) => idx !== -1) const lastUserIdx = userMessageIndices[userMessageIndices.length - 1] + const secondLastUserIdx = userMessageIndices[userMessageIndices.length - 2] + // Check last user message has proper structure with cache control expect(createCall.messages[lastUserIdx]).toMatchObject({ - cache_control: { type: "ephemeral" }, + role: "user", + content: [ + { + type: "text", + text: "How are you?", + cache_control: { type: "ephemeral" }, + }, + ], }) + // Check second last user message (first user message in this case) + if (secondLastUserIdx !== -1) { + expect(createCall.messages[secondLastUserIdx]).toMatchObject({ + role: "user", + content: [ + { + type: "text", + text: "Hello", + cache_control: { type: "ephemeral" }, + }, + ], + }) + } + // Verify usage includes cache tokens const usageChunk = results.find((chunk) => chunk.type === "usage") expect(usageChunk).toMatchObject({ diff --git a/src/api/providers/lite-llm.ts b/src/api/providers/lite-llm.ts index 97bf4fdbeb9..f93f20c5920 100644 --- a/src/api/providers/lite-llm.ts +++ b/src/api/providers/lite-llm.ts @@ -39,24 +39,27 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa ): ApiStream { const { id: modelId, info } = await this.fetchModel() - const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [ - { role: "system", content: systemPrompt }, - ...convertToOpenAiMessages(messages), - ] + const openAiMessages = convertToOpenAiMessages(messages) - // Apply cache control if prompt caching is enabled and supported - let enhancedMessages = openAiMessages - if (this.options.litellmUsePromptCache && info.supportsPromptCache) { - const cacheControl = { cache_control: { type: "ephemeral" } } + // Prepare messages with cache control if enabled and supported + let systemMessage: OpenAI.Chat.ChatCompletionMessageParam 
+ let enhancedMessages: OpenAI.Chat.ChatCompletionMessageParam[] - // Add cache control to system message - enhancedMessages[0] = { - ...enhancedMessages[0], - ...cacheControl, - } + if (this.options.litellmUsePromptCache && info.supportsPromptCache) { + // Create system message with cache control in the proper format + systemMessage = { + role: "system", + content: [ + { + type: "text", + text: systemPrompt, + cache_control: { type: "ephemeral" }, + }, + ], + } as OpenAI.Chat.ChatCompletionSystemMessageParam // Find the last two user messages to apply caching - const userMsgIndices = enhancedMessages.reduce( + const userMsgIndices = openAiMessages.reduce( (acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc), [] as number[], ) @@ -64,15 +67,41 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa const secondLastUserMsgIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1 // Apply cache_control to the last two user messages - enhancedMessages = enhancedMessages.map((message, index) => { - if (index === lastUserMsgIndex || index === secondLastUserMsgIndex) { - return { - ...message, - ...cacheControl, + enhancedMessages = openAiMessages.map((message, index) => { + if ((index === lastUserMsgIndex || index === secondLastUserMsgIndex) && message.role === "user") { + // Handle both string and array content types + if (typeof message.content === "string") { + return { + ...message, + content: [ + { + type: "text", + text: message.content, + cache_control: { type: "ephemeral" }, + }, + ], + } + } else if (Array.isArray(message.content)) { + // Apply cache control to the last content item in the array + return { + ...message, + content: message.content.map((content, contentIndex) => + contentIndex === message.content.length - 1 + ? 
{ + ...content, + cache_control: { type: "ephemeral" }, + } + : content, + ), + } } } return message }) + } else { + // No cache control - use simple format + systemMessage = { role: "system", content: systemPrompt } + enhancedMessages = openAiMessages } // Required by some providers; others default to max tokens allowed @@ -81,7 +110,7 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = { model: modelId, max_tokens: maxTokens, - messages: enhancedMessages, + messages: [systemMessage, ...enhancedMessages], stream: true, stream_options: { include_usage: true, From f5f91e6ac601a5fdb10b244ee89191afff8f13f8 Mon Sep 17 00:00:00 2001 From: MuriloFP Date: Wed, 23 Jul 2025 15:53:01 -0300 Subject: [PATCH 3/3] fix: resolve TypeScript linter error for cache_control property Use type assertion to handle cache_control property that's not in OpenAI types --- src/api/providers/lite-llm.ts | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/api/providers/lite-llm.ts b/src/api/providers/lite-llm.ts index f93f20c5920..7cea7411feb 100644 --- a/src/api/providers/lite-llm.ts +++ b/src/api/providers/lite-llm.ts @@ -54,9 +54,9 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa type: "text", text: systemPrompt, cache_control: { type: "ephemeral" }, - }, + } as any, ], - } as OpenAI.Chat.ChatCompletionSystemMessageParam + } // Find the last two user messages to apply caching const userMsgIndices = openAiMessages.reduce( @@ -78,7 +78,7 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa type: "text", text: message.content, cache_control: { type: "ephemeral" }, - }, + } as any, ], } } else if (Array.isArray(message.content)) { @@ -87,10 +87,10 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa ...message, content: message.content.map((content, contentIndex) => contentIndex === message.content.length - 1 - ? { + ? ({ ...content, cache_control: { type: "ephemeral" }, - } + } as any) : content, ), }
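
For anyone testing this change against a local LiteLLM proxy, the sketch below shows roughly the request body the handler assembles once litellmUsePromptCache is enabled and the model reports supportsPromptCache. It is illustrative only and not part of the patches: the base URL, API key, model ID, and max_tokens value are placeholders, and the `as any` casts mirror patch 3's workaround for cache_control not being part of the OpenAI SDK types.

// Sketch only: approximate request produced by LiteLLMHandler.createMessage() after
// these patches, when litellmUsePromptCache is enabled and the model reports
// supportsPromptCache. Base URL, API key, model ID, and max_tokens are placeholders.
import OpenAI from "openai"

const client = new OpenAI({
	baseURL: "http://localhost:4000", // local LiteLLM proxy (placeholder)
	apiKey: "sk-placeholder",
})

async function main() {
	const stream = await client.chat.completions.create({
		model: "claude-3-7-sonnet", // placeholder model id routed by LiteLLM
		max_tokens: 8192,
		stream: true,
		stream_options: { include_usage: true },
		messages: [
			{
				role: "system",
				// cache_control is not in the OpenAI types, hence the `as any` casts
				// (see patch 3); LiteLLM forwards it to providers that support caching.
				content: [
					{ type: "text", text: "You are a helpful assistant", cache_control: { type: "ephemeral" } } as any,
				],
			},
			{
				role: "user",
				content: [{ type: "text", text: "Hello", cache_control: { type: "ephemeral" } } as any],
			},
		],
	})

	for await (const chunk of stream) {
		// Usage arrives on the final chunk when include_usage is set; cache token fields
		// (cache_creation_input_tokens, prompt_tokens_details.cached_tokens, etc.) vary
		// by provider, which is why the handler checks several field names.
		if (chunk.usage) console.log(chunk.usage)
	}
}

main()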