feat: add prompt caching support for LiteLLM (#5791) #6074

Status: Open · wants to merge 3 commits into main
1 change: 1 addition & 0 deletions packages/types/src/provider-settings.ts
@@ -218,6 +218,7 @@ const litellmSchema = baseProviderSettingsSchema.extend({
litellmBaseUrl: z.string().optional(),
litellmApiKey: z.string().optional(),
litellmModelId: z.string().optional(),
litellmUsePromptCache: z.boolean().optional(),
})

const defaultSchema = z.object({
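For illustration only (not part of this diff), a provider configuration that opts into the new flag might look like the following; every field except litellmUsePromptCache already existed in the schema above, and the concrete values are hypothetical:

const litellmConfig = {
	// Existing LiteLLM settings defined in the schema above
	litellmBaseUrl: "http://localhost:4000",
	litellmApiKey: "sk-example",
	litellmModelId: "anthropic/claude-3-5-sonnet",
	// New optional flag introduced by this PR
	litellmUsePromptCache: true,
}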
130 changes: 130 additions & 0 deletions src/api/providers/__tests__/lite-llm.spec.ts
@@ -0,0 +1,130 @@
import { describe, it, expect, vi, beforeEach } from "vitest"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

import { LiteLLMHandler } from "../lite-llm"
import { ApiHandlerOptions } from "../../../shared/api"
import { litellmDefaultModelId, litellmDefaultModelInfo } from "@roo-code/types"

// Mock vscode first to avoid import errors
vi.mock("vscode", () => ({}))

// Mock OpenAI
vi.mock("openai", () => {
const mockStream = {
[Symbol.asyncIterator]: vi.fn(),
}

const mockCreate = vi.fn().mockReturnValue({
withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
})

return {
default: vi.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate,
},
},
})),
}
})

// Mock model fetching
vi.mock("../fetchers/modelCache", () => ({
getModels: vi.fn().mockImplementation(() => {
return Promise.resolve({
[litellmDefaultModelId]: litellmDefaultModelInfo,
})
}),
}))

describe("LiteLLMHandler", () => {
let handler: LiteLLMHandler
let mockOptions: ApiHandlerOptions
let mockOpenAIClient: any

beforeEach(() => {
vi.clearAllMocks()
mockOptions = {
litellmApiKey: "test-key",
litellmBaseUrl: "http://localhost:4000",
litellmModelId: litellmDefaultModelId,
}
handler = new LiteLLMHandler(mockOptions)
mockOpenAIClient = new OpenAI()
})

describe("prompt caching", () => {
it("should add cache control headers when litellmUsePromptCache is enabled", async () => {
const optionsWithCache: ApiHandlerOptions = {
...mockOptions,
litellmUsePromptCache: true,
}
handler = new LiteLLMHandler(optionsWithCache)

const systemPrompt = "You are a helpful assistant"
const messages: Anthropic.Messages.MessageParam[] = [
{ role: "user", content: "Hello" },
{ role: "assistant", content: "Hi there!" },
{ role: "user", content: "How are you?" },
]

// Mock the stream response
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
choices: [{ delta: { content: "I'm doing well!" } }],
usage: {
prompt_tokens: 100,
completion_tokens: 50,
cache_creation_input_tokens: 20,
cache_read_input_tokens: 30,
},
}
},
}

mockOpenAIClient.chat.completions.create.mockReturnValue({
withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
})

const generator = handler.createMessage(systemPrompt, messages)
const results = []
for await (const chunk of generator) {
results.push(chunk)
}

// Verify that create was called with cache control headers
const createCall = mockOpenAIClient.chat.completions.create.mock.calls[0][0]

// Check system message has cache control
expect(createCall.messages[0]).toMatchObject({
role: "system",
content: systemPrompt,
cache_control: { type: "ephemeral" },
})

// Check that the last two user messages have cache control
const userMessageIndices = createCall.messages
.map((msg: any, idx: number) => (msg.role === "user" ? idx : -1))
.filter((idx: number) => idx !== -1)

const lastUserIdx = userMessageIndices[userMessageIndices.length - 1]

expect(createCall.messages[lastUserIdx]).toMatchObject({
cache_control: { type: "ephemeral" },
})
Reviewer comment: Consider adding an assertion for the second-last user message as well, to fully verify that cache control headers are applied to both of the last two user messages.
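A minimal sketch of the extra assertion the reviewer suggests, reusing the userMessageIndices array computed above (illustrative only, not part of the committed test):

// Hypothetical addition: the second-last user message should also carry cache control
const secondLastUserIdx = userMessageIndices[userMessageIndices.length - 2]
expect(createCall.messages[secondLastUserIdx]).toMatchObject({
	cache_control: { type: "ephemeral" },
})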


// Verify usage includes cache tokens
const usageChunk = results.find((chunk) => chunk.type === "usage")
expect(usageChunk).toMatchObject({
type: "usage",
inputTokens: 100,
outputTokens: 50,
cacheWriteTokens: 20,
cacheReadTokens: 30,
})
})
})
})
51 changes: 46 additions & 5 deletions src/api/providers/lite-llm.ts
@@ -44,13 +44,44 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
...convertToOpenAiMessages(messages),
]

// Apply cache control if prompt caching is enabled and supported
let enhancedMessages = openAiMessages
if (this.options.litellmUsePromptCache && info.supportsPromptCache) {
const cacheControl = { cache_control: { type: "ephemeral" } }

// Add cache control to system message
enhancedMessages[0] = {
...enhancedMessages[0],
...cacheControl,
}

// Find the last two user messages to apply caching
const userMsgIndices = enhancedMessages.reduce(
(acc, msg, index) => (msg.role === "user" ? [...acc, index] : acc),
[] as number[],
)
const lastUserMsgIndex = userMsgIndices[userMsgIndices.length - 1] ?? -1
const secondLastUserMsgIndex = userMsgIndices[userMsgIndices.length - 2] ?? -1

// Apply cache_control to the last two user messages
enhancedMessages = enhancedMessages.map((message, index) => {
if (index === lastUserMsgIndex || index === secondLastUserMsgIndex) {
return {
...message,
...cacheControl,
}
}
return message
})
}

// Required by some providers; others default to max tokens allowed
let maxTokens: number | undefined = info.maxTokens ?? undefined

const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model: modelId,
max_tokens: maxTokens,
messages: openAiMessages,
messages: enhancedMessages,
stream: true,
stream_options: {
include_usage: true,
@@ -80,20 +111,30 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
}

if (lastUsage) {
// Extract cache-related information if available
// LiteLLM may use different field names for cache tokens
const cacheWriteTokens =
lastUsage.cache_creation_input_tokens || (lastUsage as any).prompt_cache_miss_tokens || 0
const cacheReadTokens =
lastUsage.prompt_tokens_details?.cached_tokens ||
(lastUsage as any).cache_read_input_tokens ||
(lastUsage as any).prompt_cache_hit_tokens ||
0

const usageData: ApiStreamUsageChunk = {
type: "usage",
inputTokens: lastUsage.prompt_tokens || 0,
outputTokens: lastUsage.completion_tokens || 0,
cacheWriteTokens: lastUsage.cache_creation_input_tokens || 0,
cacheReadTokens: lastUsage.prompt_tokens_details?.cached_tokens || 0,
cacheWriteTokens: cacheWriteTokens > 0 ? cacheWriteTokens : undefined,
cacheReadTokens: cacheReadTokens > 0 ? cacheReadTokens : undefined,
}

usageData.totalCost = calculateApiCostOpenAI(
info,
usageData.inputTokens,
usageData.outputTokens,
usageData.cacheWriteTokens,
usageData.cacheReadTokens,
usageData.cacheWriteTokens || 0,
usageData.cacheReadTokens || 0,
)

yield usageData
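To illustrate the effect of the cache-control logic in createMessage above, here is a sketch (hypothetical conversation, model with supportsPromptCache set) of the enhancedMessages array that would be sent to the LiteLLM proxy; the cache_control markers land on the system message and the last two user messages:

const enhancedMessages = [
	{ role: "system", content: "You are a helpful assistant", cache_control: { type: "ephemeral" } },
	{ role: "user", content: "Hello", cache_control: { type: "ephemeral" } },
	{ role: "assistant", content: "Hi there!" },
	{ role: "user", content: "How are you?", cache_control: { type: "ephemeral" } },
]

On the usage side, the updated handler also falls back to LiteLLM's alternate field names (prompt_cache_miss_tokens and prompt_cache_hit_tokens) when the Anthropic-style cache_creation_input_tokens and prompt_tokens_details.cached_tokens fields are absent.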
25 changes: 24 additions & 1 deletion webview-ui/src/components/settings/providers/LiteLLM.tsx
@@ -1,5 +1,5 @@
import { useCallback, useState, useEffect, useRef } from "react"
import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react"
import { VSCodeTextField, VSCodeCheckbox } from "@vscode/webview-ui-toolkit/react"

import { type ProviderSettings, type OrganizationAllowList, litellmDefaultModelId } from "@roo-code/types"

@@ -151,6 +151,29 @@ export const LiteLLM = ({
organizationAllowList={organizationAllowList}
errorMessage={modelValidationError}
/>

{/* Show prompt caching option if the selected model supports it */}
{(() => {
const selectedModelId = apiConfiguration.litellmModelId || litellmDefaultModelId
const selectedModel = routerModels?.litellm?.[selectedModelId]
if (selectedModel?.supportsPromptCache) {
return (
<div className="mt-4">
<VSCodeCheckbox
checked={apiConfiguration.litellmUsePromptCache || false}
onChange={(e: any) => {
setApiConfigurationField("litellmUsePromptCache", e.target.checked)
}}>
<span className="font-medium">{t("settings:providers.enablePromptCaching")}</span>
</VSCodeCheckbox>
<div className="text-sm text-vscode-descriptionForeground ml-6 mt-1">
{t("settings:providers.enablePromptCachingTitle")}
</div>
</div>
)
}
return null
})()}
</>
)
}
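For reference, a hypothetical routerModels entry that would make the checkbox above render; supportsPromptCache is the only field the component actually checks here, and the other values are illustrative:

const routerModels = {
	litellm: {
		"anthropic/claude-3-5-sonnet": {
			maxTokens: 8192, // illustrative
			supportsPromptCache: true, // gates the prompt-caching checkbox
		},
	},
}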