
feat: add prompt caching support for LiteLLM (#5791) #6074

Open · wants to merge 3 commits into base: main
1 change: 1 addition & 0 deletions packages/types/src/provider-settings.ts
@@ -218,6 +218,7 @@ const litellmSchema = baseProviderSettingsSchema.extend({
litellmBaseUrl: z.string().optional(),
litellmApiKey: z.string().optional(),
litellmModelId: z.string().optional(),
litellmUsePromptCache: z.boolean().optional(),
})

const defaultSchema = z.object({
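For reference, a minimal sketch of a LiteLLM provider configuration with the new flag enabled (values are placeholders; only `litellmUsePromptCache` is added by this PR):

```typescript
// Hypothetical settings object; field names follow the litellmSchema above.
const exampleLiteLlmSettings = {
	litellmBaseUrl: "http://localhost:4000", // placeholder proxy URL
	litellmApiKey: "sk-placeholder", // placeholder key
	litellmModelId: "example-model-id", // placeholder model id
	litellmUsePromptCache: true, // opt in to prompt caching for supported models
}
```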
158 changes: 158 additions & 0 deletions src/api/providers/__tests__/lite-llm.spec.ts
@@ -0,0 +1,158 @@
import { describe, it, expect, vi, beforeEach } from "vitest"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

import { LiteLLMHandler } from "../lite-llm"
import { ApiHandlerOptions } from "../../../shared/api"
import { litellmDefaultModelId, litellmDefaultModelInfo } from "@roo-code/types"

// Mock vscode first to avoid import errors
vi.mock("vscode", () => ({}))

// Mock OpenAI
vi.mock("openai", () => {
const mockStream = {
[Symbol.asyncIterator]: vi.fn(),
}

const mockCreate = vi.fn().mockReturnValue({
withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
})

return {
default: vi.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate,
},
},
})),
}
})

// Mock model fetching
vi.mock("../fetchers/modelCache", () => ({
getModels: vi.fn().mockImplementation(() => {
return Promise.resolve({
[litellmDefaultModelId]: litellmDefaultModelInfo,
})
}),
}))

describe("LiteLLMHandler", () => {
let handler: LiteLLMHandler
let mockOptions: ApiHandlerOptions
let mockOpenAIClient: any

beforeEach(() => {
vi.clearAllMocks()
mockOptions = {
litellmApiKey: "test-key",
litellmBaseUrl: "http://localhost:4000",
litellmModelId: litellmDefaultModelId,
}
handler = new LiteLLMHandler(mockOptions)
mockOpenAIClient = new OpenAI()
})

describe("prompt caching", () => {
it("should add cache control headers when litellmUsePromptCache is enabled", async () => {
const optionsWithCache: ApiHandlerOptions = {
...mockOptions,
litellmUsePromptCache: true,
}
handler = new LiteLLMHandler(optionsWithCache)

const systemPrompt = "You are a helpful assistant"
const messages: Anthropic.Messages.MessageParam[] = [
{ role: "user", content: "Hello" },
{ role: "assistant", content: "Hi there!" },
{ role: "user", content: "How are you?" },
]

// Mock the stream response
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
choices: [{ delta: { content: "I'm doing well!" } }],
usage: {
prompt_tokens: 100,
completion_tokens: 50,
cache_creation_input_tokens: 20,
cache_read_input_tokens: 30,
},
}
},
}

mockOpenAIClient.chat.completions.create.mockReturnValue({
withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
})

const generator = handler.createMessage(systemPrompt, messages)
const results = []
for await (const chunk of generator) {
results.push(chunk)
}

// Verify that create was called with cache control headers
const createCall = mockOpenAIClient.chat.completions.create.mock.calls[0][0]

// Check system message has cache control in the proper format
expect(createCall.messages[0]).toMatchObject({
role: "system",
content: [
{
type: "text",
text: systemPrompt,
cache_control: { type: "ephemeral" },
},
],
})

// Check that the last two user messages have cache control
const userMessageIndices = createCall.messages
.map((msg: any, idx: number) => (msg.role === "user" ? idx : -1))
.filter((idx: number) => idx !== -1)

const lastUserIdx = userMessageIndices[userMessageIndices.length - 1]
const secondLastUserIdx = userMessageIndices[userMessageIndices.length - 2]

// Check last user message has proper structure with cache control
expect(createCall.messages[lastUserIdx]).toMatchObject({
role: "user",
content: [
{
type: "text",
text: "How are you?",
cache_control: { type: "ephemeral" },
},
],
})
Review comment: Consider adding an assertion for the second-last user message as well, to fully verify that cache control headers are applied to both of the last two user messages.


// Check second last user message (first user message in this case)
if (secondLastUserIdx !== -1) {
expect(createCall.messages[secondLastUserIdx]).toMatchObject({
role: "user",
content: [
{
type: "text",
text: "Hello",
cache_control: { type: "ephemeral" },
},
],
})
}

// Verify usage includes cache tokens
const usageChunk = results.find((chunk) => chunk.type === "usage")
expect(usageChunk).toMatchObject({
type: "usage",
inputTokens: 100,
outputTokens: 50,
cacheWriteTokens: 20,
cacheReadTokens: 30,
})
})
})
})
88 changes: 79 additions & 9 deletions src/api/providers/lite-llm.ts
@@ -39,18 +39,78 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
): ApiStream {
const { id: modelId, info } = await this.fetchModel()

const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
{ role: "system", content: systemPrompt },
...convertToOpenAiMessages(messages),
]
const openAiMessages = convertToOpenAiMessages(messages)

// Prepare messages with cache control if enabled and supported
let systemMessage: OpenAI.Chat.ChatCompletionMessageParam
let enhancedMessages: OpenAI.Chat.ChatCompletionMessageParam[]

if (this.options.litellmUsePromptCache && info.supportsPromptCache) {
// Create system message with cache control in the proper format
systemMessage = {
role: "system",
content: [
{
type: "text",
text: systemPrompt,
cache_control: { type: "ephemeral" },
} as any,
],
}

// Find the last two user messages to apply caching
const userMsgIndices = openAiMessages.reduce(
(acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
[] as number[],
)
const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
const secondLastUserMsgIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1

// Apply cache_control to the last two user messages
enhancedMessages = openAiMessages.map((message, index) => {
if ((index === lastUserMsgIndex || index === secondLastUserMsgIndex) && message.role === "user") {
// Handle both string and array content types
if (typeof message.content === "string") {
return {
...message,
content: [
{
type: "text",
text: message.content,
cache_control: { type: "ephemeral" },
} as any,
],
}
} else if (Array.isArray(message.content)) {
// Apply cache control to the last content item in the array
return {
...message,
content: message.content.map((content, contentIndex) =>
contentIndex === message.content.length - 1
? ({
...content,
cache_control: { type: "ephemeral" },
} as any)
: content,
),
}
}
}
return message
})
} else {
// No cache control - use simple format
systemMessage = { role: "system", content: systemPrompt }
enhancedMessages = openAiMessages
}

// Required by some providers; others default to max tokens allowed
let maxTokens: number | undefined = info.maxTokens ?? undefined

const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model: modelId,
max_tokens: maxTokens,
messages: openAiMessages,
messages: [systemMessage, ...enhancedMessages],
stream: true,
stream_options: {
include_usage: true,
@@ -80,20 +140,30 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
}

if (lastUsage) {
// Extract cache-related information if available
// LiteLLM may use different field names for cache tokens
const cacheWriteTokens =
lastUsage.cache_creation_input_tokens || (lastUsage as any).prompt_cache_miss_tokens || 0
const cacheReadTokens =
lastUsage.prompt_tokens_details?.cached_tokens ||
(lastUsage as any).cache_read_input_tokens ||
(lastUsage as any).prompt_cache_hit_tokens ||
0

const usageData: ApiStreamUsageChunk = {
type: "usage",
inputTokens: lastUsage.prompt_tokens || 0,
outputTokens: lastUsage.completion_tokens || 0,
cacheWriteTokens: lastUsage.cache_creation_input_tokens || 0,
cacheReadTokens: lastUsage.prompt_tokens_details?.cached_tokens || 0,
cacheWriteTokens: cacheWriteTokens > 0 ? cacheWriteTokens : undefined,
cacheReadTokens: cacheReadTokens > 0 ? cacheReadTokens : undefined,
}

usageData.totalCost = calculateApiCostOpenAI(
info,
usageData.inputTokens,
usageData.outputTokens,
usageData.cacheWriteTokens,
usageData.cacheReadTokens,
usageData.cacheWriteTokens || 0,
usageData.cacheReadTokens || 0,
)

yield usageData
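To make the fallback chain for cache token fields concrete, here is a hedged sketch of two usage payload shapes the extraction above would handle (field availability depends on the provider behind the LiteLLM proxy; values are illustrative):

```typescript
// Anthropic-style usage payload (hypothetical values):
const anthropicStyleUsage = {
	prompt_tokens: 100,
	completion_tokens: 50,
	cache_creation_input_tokens: 20, // read as cacheWriteTokens = 20
	cache_read_input_tokens: 30, // read as cacheReadTokens = 30
}

// OpenAI-style usage payload (hypothetical values):
const openAiStyleUsage = {
	prompt_tokens: 100,
	completion_tokens: 50,
	prompt_tokens_details: { cached_tokens: 30 }, // read as cacheReadTokens = 30
	// no cache-write field here, so cacheWriteTokens stays undefined
}
```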
25 changes: 24 additions & 1 deletion webview-ui/src/components/settings/providers/LiteLLM.tsx
@@ -1,5 +1,5 @@
import { useCallback, useState, useEffect, useRef } from "react"
import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
import { VSCodeTextField, VSCodeCheckbox } from "@vscode/webview-ui-toolkit/react"

import { type ProviderSettings, type OrganizationAllowList, litellmDefaultModelId } from "@roo-code/types"

@@ -151,6 +151,29 @@ export const LiteLLM = ({
organizationAllowList={organizationAllowList}
errorMessage={modelValidationError}
/>

{/* Show prompt caching option if the selected model supports it */}
{(() => {
const selectedModelId = apiConfiguration.litellmModelId || litellmDefaultModelId
const selectedModel = routerModels?.litellm?.[selectedModelId]
if (selectedModel?.supportsPromptCache) {
return (
<div className="mt-4">
<VSCodeCheckbox
checked={apiConfiguration.litellmUsePromptCache || false}
onChange={(e: any) => {
setApiConfigurationField("litellmUsePromptCache", e.target.checked)
}}>
<span className="font-medium">{t("settings:providers.enablePromptCaching")}</span>
</VSCodeCheckbox>
<div className="text-sm text-vscode-descriptionForeground ml-6 mt-1">
{t("settings:providers.enablePromptCachingTitle")}
</div>
</div>
)
}
return null
})()}
</>
)
}
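For orientation, a hedged sketch of the `routerModels` shape the visibility check above relies on (the model key and metadata are illustrative, not taken from the real model cache):

```typescript
// Hypothetical routerModels entry: the checkbox only renders when the
// selected litellm model reports supportsPromptCache as true.
const routerModels = {
	litellm: {
		"example/model-id": {
			supportsPromptCache: true, // enables the prompt caching checkbox
		},
	},
}
```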