
Commit 4fac3da

Merge pull request RooCodeInc#1217 from RooVetGit/cte/control-max-tokens
2 parents: 28cdf0e + cf69b0f

File tree: 10 files changed (+203, −66 lines)


.changeset/wild-emus-dream.md

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+---
+"roo-cline": patch
+---
+
+Allow control over maxTokens for thinking models

src/api/providers/anthropic.ts

Lines changed: 10 additions & 2 deletions
@@ -31,7 +31,7 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 		let stream: AnthropicStream<Anthropic.Messages.RawMessageStreamEvent>
 		const cacheControl: CacheControlEphemeral = { type: "ephemeral" }
 		let { id: modelId, info: modelInfo } = this.getModel()
-		const maxTokens = modelInfo.maxTokens || 8192
+		const maxTokens = this.options.modelMaxTokens || modelInfo.maxTokens || 8192
 		let temperature = this.options.modelTemperature ?? ANTHROPIC_DEFAULT_TEMPERATURE
 		let thinking: BetaThinkingConfigParam | undefined = undefined

@@ -41,7 +41,15 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 			// `claude-3-7-sonnet-20250219` model with a thinking budget.
 			// We can handle this more elegantly in the future.
 			modelId = "claude-3-7-sonnet-20250219"
-			const budgetTokens = this.options.anthropicThinking ?? Math.max(maxTokens * 0.8, 1024)
+
+			// Clamp the thinking budget to be at most 80% of max tokens and at
+			// least 1024 tokens.
+			const maxBudgetTokens = Math.floor(maxTokens * 0.8)
+			const budgetTokens = Math.max(
+				Math.min(this.options.anthropicThinking ?? maxBudgetTokens, maxBudgetTokens),
+				1024,
+			)
+
 			thinking = { type: "enabled", budget_tokens: budgetTokens }
 			temperature = 1.0
 		}
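
The clamp is the substantive change here: previously a user-supplied `anthropicThinking` budget passed through unchecked and could exceed the model's max output tokens. Read in isolation, the new computation behaves as below (a minimal sketch; `clampThinkingBudget` is a hypothetical name for illustration, the PR inlines this logic in the handler):

// Sketch of the budget computation above. `requested` stands in for
// `this.options.anthropicThinking`; `maxTokens` is the resolved max
// output tokens for the request.
function clampThinkingBudget(requested: number | undefined, maxTokens: number): number {
	// At most 80% of max tokens, at least 1024 tokens.
	const maxBudgetTokens = Math.floor(maxTokens * 0.8)
	return Math.max(Math.min(requested ?? maxBudgetTokens, maxBudgetTokens), 1024)
}

clampThinkingBudget(undefined, 64_000) // => 51200 (defaults to 80% of max)
clampThinkingBudget(100_000, 64_000) // => 51200 (oversized request clamped down)
clampThinkingBudget(500, 64_000) // => 1024 (raised to the floor)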

src/api/providers/openrouter.ts

Lines changed: 10 additions & 3 deletions
@@ -108,12 +108,19 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 			topP = 0.95
 		}

+		const maxTokens = this.options.modelMaxTokens || modelInfo.maxTokens
 		let temperature = this.options.modelTemperature ?? defaultTemperature
 		let thinking: BetaThinkingConfigParam | undefined = undefined

 		if (modelInfo.thinking) {
-			const maxTokens = modelInfo.maxTokens || 8192
-			const budgetTokens = this.options.anthropicThinking ?? Math.max(maxTokens * 0.8, 1024)
+			// Clamp the thinking budget to be at most 80% of max tokens and at
+			// least 1024 tokens.
+			const maxBudgetTokens = Math.floor((maxTokens || 8192) * 0.8)
+			const budgetTokens = Math.max(
+				Math.min(this.options.anthropicThinking ?? maxBudgetTokens, maxBudgetTokens),
+				1024,
+			)
+
 			thinking = { type: "enabled", budget_tokens: budgetTokens }
 			temperature = 1.0
 		}

@@ -271,7 +278,7 @@ export async function getOpenRouterModels() {
 			modelInfo.supportsPromptCache = true
 			modelInfo.cacheWritesPrice = 3.75
 			modelInfo.cacheReadsPrice = 0.3
-			modelInfo.maxTokens = 16384
+			modelInfo.maxTokens = 64_000
 			break
 		case rawModel.id.startsWith("anthropic/claude-3.5-sonnet-20240620"):
 			modelInfo.supportsPromptCache = true
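
Two details differ from the Anthropic handler: `maxTokens` may be undefined for OpenRouter models, hence the `(maxTokens || 8192)` fallback inside the clamp, and Claude 3.7 Sonnet's default `maxTokens` rises from 16384 to 64_000. With assumed inputs, the fallback path works out as follows (illustrative values only):

// Assumed: OpenRouter model metadata without maxTokens, plus an oversized
// user-configured thinking budget (values are illustrative).
const maxTokens: number | undefined = undefined
const requested: number | undefined = 100_000

// The clamp falls back to 8192 before taking 80%...
const maxBudgetTokens = Math.floor((maxTokens || 8192) * 0.8) // => 6553

// ...and the oversized request is pulled back to that ceiling.
const budgetTokens = Math.max(Math.min(requested ?? maxBudgetTokens, maxBudgetTokens), 1024) // => 6553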

src/core/Cline.ts

Lines changed: 14 additions & 4 deletions
@@ -87,6 +87,7 @@ export type ClineOptions = {

 export class Cline {
 	readonly taskId: string
+	readonly apiConfiguration: ApiConfiguration
 	api: ApiHandler
 	private terminalManager: TerminalManager
 	private urlContentFetcher: UrlContentFetcher

@@ -148,6 +149,7 @@ export class Cline {
 		}

 		this.taskId = crypto.randomUUID()
+		this.apiConfiguration = apiConfiguration
 		this.api = buildApiHandler(apiConfiguration)
 		this.terminalManager = new TerminalManager()
 		this.urlContentFetcher = new UrlContentFetcher(provider.context)

@@ -961,13 +963,21 @@ export class Cline {
 				cacheWrites = 0,
 				cacheReads = 0,
 			}: ClineApiReqInfo = JSON.parse(previousRequest)
+
 			const totalTokens = tokensIn + tokensOut + cacheWrites + cacheReads

-			const trimmedMessages = truncateConversationIfNeeded(
-				this.apiConversationHistory,
+			const modelInfo = this.api.getModel().info
+			const maxTokens = modelInfo.thinking
+				? this.apiConfiguration.modelMaxTokens || modelInfo.maxTokens
+				: modelInfo.maxTokens
+			const contextWindow = modelInfo.contextWindow
+
+			const trimmedMessages = truncateConversationIfNeeded({
+				messages: this.apiConversationHistory,
 				totalTokens,
-				this.api.getModel().info,
-			)
+				maxTokens,
+				contextWindow,
+			})

 			if (trimmedMessages !== this.apiConversationHistory) {
 				await this.overwriteApiConversationHistory(trimmedMessages)
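
The net effect is that context truncation now reserves the same output budget the request will actually use: for thinking models the user's `modelMaxTokens` takes precedence, while non-thinking models keep the model default. A condensed sketch of the selection under assumed values (illustrative only):

// Assumed metadata for a thinking model, plus a user-configured cap.
const modelInfo = { thinking: true, maxTokens: 64_000, contextWindow: 200_000 }
const modelMaxTokens: number | undefined = 32_000

// For thinking models, prefer the user's cap; otherwise the model default.
const maxTokens = modelInfo.thinking
	? modelMaxTokens || modelInfo.maxTokens // => 32000 (user cap wins)
	: modelInfo.maxTokens

// truncateConversationIfNeeded will then allow up to
// contextWindow - maxTokens = 200000 - 32000 = 168000 conversation tokens.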

src/core/sliding-window/__tests__/sliding-window.test.ts

Lines changed: 88 additions & 14 deletions
@@ -119,11 +119,21 @@ describe("getMaxTokens", () => {
 		// Max tokens = 100000 - 50000 = 50000

 		// Below max tokens - no truncation
-		const result1 = truncateConversationIfNeeded(messages, 49999, modelInfo)
+		const result1 = truncateConversationIfNeeded({
+			messages,
+			totalTokens: 49999,
+			contextWindow: modelInfo.contextWindow,
+			maxTokens: modelInfo.maxTokens,
+		})
 		expect(result1).toEqual(messages)

 		// Above max tokens - truncate
-		const result2 = truncateConversationIfNeeded(messages, 50001, modelInfo)
+		const result2 = truncateConversationIfNeeded({
+			messages,
+			totalTokens: 50001,
+			contextWindow: modelInfo.contextWindow,
+			maxTokens: modelInfo.maxTokens,
+		})
 		expect(result2).not.toEqual(messages)
 		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
 	})

@@ -133,11 +143,21 @@ describe("getMaxTokens", () => {
 		// Max tokens = 100000 - (100000 * 0.2) = 80000

 		// Below max tokens - no truncation
-		const result1 = truncateConversationIfNeeded(messages, 79999, modelInfo)
+		const result1 = truncateConversationIfNeeded({
+			messages,
+			totalTokens: 79999,
+			contextWindow: modelInfo.contextWindow,
+			maxTokens: modelInfo.maxTokens,
+		})
 		expect(result1).toEqual(messages)

 		// Above max tokens - truncate
-		const result2 = truncateConversationIfNeeded(messages, 80001, modelInfo)
+		const result2 = truncateConversationIfNeeded({
+			messages,
+			totalTokens: 80001,
+			contextWindow: modelInfo.contextWindow,
+			maxTokens: modelInfo.maxTokens,
+		})
 		expect(result2).not.toEqual(messages)
 		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
 	})

@@ -147,11 +167,21 @@ describe("getMaxTokens", () => {
 		// Max tokens = 50000 - 10000 = 40000

 		// Below max tokens - no truncation
-		const result1 = truncateConversationIfNeeded(messages, 39999, modelInfo)
+		const result1 = truncateConversationIfNeeded({
+			messages,
+			totalTokens: 39999,
+			contextWindow: modelInfo.contextWindow,
+			maxTokens: modelInfo.maxTokens,
+		})
 		expect(result1).toEqual(messages)

 		// Above max tokens - truncate
-		const result2 = truncateConversationIfNeeded(messages, 40001, modelInfo)
+		const result2 = truncateConversationIfNeeded({
+			messages,
+			totalTokens: 40001,
+			contextWindow: modelInfo.contextWindow,
+			maxTokens: modelInfo.maxTokens,
+		})
 		expect(result2).not.toEqual(messages)
 		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
 	})

@@ -161,11 +191,21 @@ describe("getMaxTokens", () => {
 		// Max tokens = 200000 - 30000 = 170000

 		// Below max tokens - no truncation
-		const result1 = truncateConversationIfNeeded(messages, 169999, modelInfo)
+		const result1 = truncateConversationIfNeeded({
+			messages,
+			totalTokens: 169999,
+			contextWindow: modelInfo.contextWindow,
+			maxTokens: modelInfo.maxTokens,
+		})
 		expect(result1).toEqual(messages)

 		// Above max tokens - truncate
-		const result2 = truncateConversationIfNeeded(messages, 170001, modelInfo)
+		const result2 = truncateConversationIfNeeded({
+			messages,
+			totalTokens: 170001,
+			contextWindow: modelInfo.contextWindow,
+			maxTokens: modelInfo.maxTokens,
+		})
 		expect(result2).not.toEqual(messages)
 		expect(result2.length).toBe(3) // Truncated with 0.5 fraction
 	})

@@ -194,7 +234,12 @@ describe("truncateConversationIfNeeded", () => {
 		const maxTokens = 100000 - 30000 // 70000
 		const totalTokens = 69999 // Below threshold

-		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
+		const result = truncateConversationIfNeeded({
+			messages,
+			totalTokens,
+			contextWindow: modelInfo.contextWindow,
+			maxTokens: modelInfo.maxTokens,
+		})
 		expect(result).toEqual(messages) // No truncation occurs
 	})

@@ -207,7 +252,12 @@ describe("truncateConversationIfNeeded", () => {
 		// With 4 messages after the first, 0.5 fraction means remove 2 messages
 		const expectedResult = [messages[0], messages[3], messages[4]]

-		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
+		const result = truncateConversationIfNeeded({
+			messages,
+			totalTokens,
+			contextWindow: modelInfo.contextWindow,
+			maxTokens: modelInfo.maxTokens,
+		})
 		expect(result).toEqual(expectedResult)
 	})

@@ -218,14 +268,38 @@ describe("truncateConversationIfNeeded", () => {

 		// Test below threshold
 		const belowThreshold = 69999
-		expect(truncateConversationIfNeeded(messages, belowThreshold, modelInfo1)).toEqual(
-			truncateConversationIfNeeded(messages, belowThreshold, modelInfo2),
+		expect(
+			truncateConversationIfNeeded({
+				messages,
+				totalTokens: belowThreshold,
+				contextWindow: modelInfo1.contextWindow,
+				maxTokens: modelInfo1.maxTokens,
+			}),
+		).toEqual(
+			truncateConversationIfNeeded({
+				messages,
+				totalTokens: belowThreshold,
+				contextWindow: modelInfo2.contextWindow,
+				maxTokens: modelInfo2.maxTokens,
+			}),
 		)

 		// Test above threshold
 		const aboveThreshold = 70001
-		expect(truncateConversationIfNeeded(messages, aboveThreshold, modelInfo1)).toEqual(
-			truncateConversationIfNeeded(messages, aboveThreshold, modelInfo2),
+		expect(
+			truncateConversationIfNeeded({
+				messages,
+				totalTokens: aboveThreshold,
+				contextWindow: modelInfo1.contextWindow,
+				maxTokens: modelInfo1.maxTokens,
+			}),
+		).toEqual(
+			truncateConversationIfNeeded({
+				messages,
+				totalTokens: aboveThreshold,
+				contextWindow: modelInfo2.contextWindow,
+				maxTokens: modelInfo2.maxTokens,
+			}),
 		)
 	})
 })

src/core/sliding-window/index.ts

Lines changed: 18 additions & 19 deletions
@@ -1,7 +1,5 @@
 import { Anthropic } from "@anthropic-ai/sdk"

-import { ModelInfo } from "../../shared/api"
-
 /**
  * Truncates a conversation by removing a fraction of the messages.
  *

@@ -26,28 +24,29 @@ export function truncateConversation(
 }

 /**
- * Conditionally truncates the conversation messages if the total token count exceeds the model's limit.
+ * Conditionally truncates the conversation messages if the total token count
+ * exceeds the model's limit.
  *
  * @param {Anthropic.Messages.MessageParam[]} messages - The conversation messages.
  * @param {number} totalTokens - The total number of tokens in the conversation.
- * @param {ModelInfo} modelInfo - Model metadata including context window size.
+ * @param {number} contextWindow - The context window size.
+ * @param {number} maxTokens - The maximum number of tokens allowed.
  * @returns {Anthropic.Messages.MessageParam[]} The original or truncated conversation messages.
  */
-export function truncateConversationIfNeeded(
-	messages: Anthropic.Messages.MessageParam[],
-	totalTokens: number,
-	modelInfo: ModelInfo,
-): Anthropic.Messages.MessageParam[] {
-	return totalTokens < getMaxTokens(modelInfo) ? messages : truncateConversation(messages, 0.5)
+
+type TruncateOptions = {
+	messages: Anthropic.Messages.MessageParam[]
+	totalTokens: number
+	contextWindow: number
+	maxTokens?: number
 }

-/**
- * Calculates the maximum allowed tokens
- *
- * @param {ModelInfo} modelInfo - The model information containing the context window size.
- * @returns {number} The maximum number of tokens allowed
- */
-function getMaxTokens(modelInfo: ModelInfo): number {
-	// The buffer needs to be at least as large as `modelInfo.maxTokens`, or 20% of the context window if for some reason it's not set.
-	return modelInfo.contextWindow - (modelInfo.maxTokens || modelInfo.contextWindow * 0.2)
+export function truncateConversationIfNeeded({
+	messages,
+	totalTokens,
+	contextWindow,
+	maxTokens,
+}: TruncateOptions): Anthropic.Messages.MessageParam[] {
+	const allowedTokens = contextWindow - (maxTokens || contextWindow * 0.2)
+	return totalTokens < allowedTokens ? messages : truncateConversation(messages, 0.5)
 }
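
With the options-object signature, `getMaxTokens` disappears as a separate helper and the threshold is computed inline as `contextWindow - (maxTokens || contextWindow * 0.2)`. A usage sketch under assumed numbers (message contents abbreviated; assumes `truncateConversationIfNeeded` is imported from this module):

import { Anthropic } from "@anthropic-ai/sdk"

const messages: Anthropic.Messages.MessageParam[] = [
	{ role: "user", content: "first" },
	{ role: "assistant", content: "second" },
	{ role: "user", content: "third" },
	{ role: "assistant", content: "fourth" },
	{ role: "user", content: "fifth" },
]

// allowedTokens = 200000 - 64000 = 136000; 150000 exceeds it, so half of the
// messages after the first are dropped: [first, fourth, fifth].
const trimmed = truncateConversationIfNeeded({
	messages,
	totalTokens: 150_000,
	contextWindow: 200_000,
	maxTokens: 64_000,
})

// Omitting maxTokens reserves 20% of the context window instead:
// allowedTokens = 200000 - 40000 = 160000.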

src/core/webview/ClineProvider.ts

Lines changed: 5 additions & 0 deletions
@@ -1671,6 +1671,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			requestyModelId,
 			requestyModelInfo,
 			modelTemperature,
+			modelMaxTokens,
 		} = apiConfiguration
 		await Promise.all([
 			this.updateGlobalState("apiProvider", apiProvider),

@@ -1719,6 +1720,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			this.updateGlobalState("requestyModelId", requestyModelId),
 			this.updateGlobalState("requestyModelInfo", requestyModelInfo),
 			this.updateGlobalState("modelTemperature", modelTemperature),
+			this.updateGlobalState("modelMaxTokens", modelMaxTokens),
 		])
 		if (this.cline) {
 			this.cline.api = buildApiHandler(apiConfiguration)

@@ -2210,6 +2212,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			requestyModelId,
 			requestyModelInfo,
 			modelTemperature,
+			modelMaxTokens,
 			maxOpenTabsContext,
 		] = await Promise.all([
 			this.getGlobalState("apiProvider") as Promise<ApiProvider | undefined>,

@@ -2293,6 +2296,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			this.getGlobalState("requestyModelId") as Promise<string | undefined>,
 			this.getGlobalState("requestyModelInfo") as Promise<ModelInfo | undefined>,
 			this.getGlobalState("modelTemperature") as Promise<number | undefined>,
+			this.getGlobalState("modelMaxTokens") as Promise<number | undefined>,
 			this.getGlobalState("maxOpenTabsContext") as Promise<number | undefined>,
 		])

@@ -2358,6 +2362,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			requestyModelId,
 			requestyModelInfo,
 			modelTemperature,
+			modelMaxTokens,
 		},
 		lastShownAnnouncementId,
 		customInstructions,
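
The new setting rides the existing persistence path for per-model options such as `modelTemperature`: one entry in the save-side `Promise.all`, one in the hydrate-side batch, and one in the state object handed to the webview. A condensed sketch of the round trip, using a plain Map as a hypothetical stand-in for VS Code's global state (not the real API):

// Stand-in for the extension's globalState store (illustration only).
const globalState = new Map<string, unknown>()

// Save side: persist the user's cap alongside the other API settings.
globalState.set("modelMaxTokens", 32_000)

// Hydrate side: read it back with the same undefined-tolerant cast
// pattern the provider uses.
const modelMaxTokens = globalState.get("modelMaxTokens") as number | undefined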
