diff --git a/src/api/providers/cerebras.ts b/src/api/providers/cerebras.ts index a0421844e815..d2ae84b65e57 100644 --- a/src/api/providers/cerebras.ts +++ b/src/api/providers/cerebras.ts @@ -6,7 +6,7 @@ import type { ApiHandlerOptions } from "../../shared/api" import { calculateApiCostOpenAI } from "../../shared/cost" import { ApiStream } from "../transform/stream" import { convertToOpenAiMessages } from "../transform/openai-format" -import { XmlMatcher } from "../../utils/xml-matcher" +import { ReasoningXmlMatcher } from "../../utils/reasoning-xml-matcher" import type { ApiHandlerCreateMessageMetadata, SingleCompletionHandler } from "../index" import { BaseProvider } from "./base-provider" @@ -187,9 +187,8 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan throw new Error(t("common:errors.cerebras.noResponseBody")) } - // Initialize XmlMatcher to parse ... tags - const matcher = new XmlMatcher( - "think", + // Initialize ReasoningXmlMatcher to parse reasoning tags + const matcher = new ReasoningXmlMatcher( (chunk) => ({ type: chunk.matched ? "reasoning" : "text", @@ -228,7 +227,7 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan if (parsed.choices?.[0]?.delta?.content) { const content = parsed.choices[0].delta.content - // Use XmlMatcher to parse ... tags + // Use ReasoningXmlMatcher to parse reasoning tags for (const chunk of matcher.update(content)) { yield chunk } diff --git a/src/api/providers/chutes.ts b/src/api/providers/chutes.ts index 62121bd19dc0..e755dd1c0594 100644 --- a/src/api/providers/chutes.ts +++ b/src/api/providers/chutes.ts @@ -3,7 +3,7 @@ import { Anthropic } from "@anthropic-ai/sdk" import OpenAI from "openai" import type { ApiHandlerOptions } from "../../shared/api" -import { XmlMatcher } from "../../utils/xml-matcher" +import { ReasoningXmlMatcher } from "../../utils/reasoning-xml-matcher" import { convertToR1Format } from "../transform/r1-format" import { convertToOpenAiMessages } from "../transform/openai-format" import { ApiStream } from "../transform/stream" @@ -53,8 +53,7 @@ export class ChutesHandler extends BaseOpenAiCompatibleProvider { messages: convertToR1Format([{ role: "user", content: systemPrompt }, ...messages]), }) - const matcher = new XmlMatcher( - "think", + const matcher = new ReasoningXmlMatcher( (chunk) => ({ type: chunk.matched ? "reasoning" : "text", diff --git a/src/api/providers/featherless.ts b/src/api/providers/featherless.ts index 56d7177de7c9..3bca76e6bf8a 100644 --- a/src/api/providers/featherless.ts +++ b/src/api/providers/featherless.ts @@ -1,9 +1,14 @@ -import { DEEP_SEEK_DEFAULT_TEMPERATURE, type FeatherlessModelId, featherlessDefaultModelId, featherlessModels } from "@roo-code/types" +import { + DEEP_SEEK_DEFAULT_TEMPERATURE, + type FeatherlessModelId, + featherlessDefaultModelId, + featherlessModels, +} from "@roo-code/types" import { Anthropic } from "@anthropic-ai/sdk" import OpenAI from "openai" import type { ApiHandlerOptions } from "../../shared/api" -import { XmlMatcher } from "../../utils/xml-matcher" +import { ReasoningXmlMatcher } from "../../utils/reasoning-xml-matcher" import { convertToR1Format } from "../transform/r1-format" import { convertToOpenAiMessages } from "../transform/openai-format" import { ApiStream } from "../transform/stream" @@ -53,8 +58,7 @@ export class FeatherlessHandler extends BaseOpenAiCompatibleProvider ({ type: chunk.matched ? "reasoning" : "text", diff --git a/src/api/providers/lm-studio.ts b/src/api/providers/lm-studio.ts index 6c58a96ae1fa..0953808abcc9 100644 --- a/src/api/providers/lm-studio.ts +++ b/src/api/providers/lm-studio.ts @@ -6,7 +6,7 @@ import { type ModelInfo, openAiModelInfoSaneDefaults, LMSTUDIO_DEFAULT_TEMPERATU import type { ApiHandlerOptions } from "../../shared/api" -import { XmlMatcher } from "../../utils/xml-matcher" +import { ReasoningXmlMatcher } from "../../utils/reasoning-xml-matcher" import { convertToOpenAiMessages } from "../transform/openai-format" import { ApiStream } from "../transform/stream" @@ -100,8 +100,7 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan throw handleOpenAIError(error, this.providerName) } - const matcher = new XmlMatcher( - "think", + const matcher = new ReasoningXmlMatcher( (chunk) => ({ type: chunk.matched ? "reasoning" : "text", diff --git a/src/api/providers/native-ollama.ts b/src/api/providers/native-ollama.ts index 83a5c7b36ea8..38cad6d80508 100644 --- a/src/api/providers/native-ollama.ts +++ b/src/api/providers/native-ollama.ts @@ -5,7 +5,7 @@ import { ApiStream } from "../transform/stream" import { BaseProvider } from "./base-provider" import type { ApiHandlerOptions } from "../../shared/api" import { getOllamaModels } from "./fetchers/ollama" -import { XmlMatcher } from "../../utils/xml-matcher" +import { ReasoningXmlMatcher } from "../../utils/reasoning-xml-matcher" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" interface OllamaChatOptions { @@ -179,8 +179,7 @@ export class NativeOllamaHandler extends BaseProvider implements SingleCompletio ...convertToOllamaMessages(messages), ] - const matcher = new XmlMatcher( - "think", + const matcher = new ReasoningXmlMatcher( (chunk) => ({ type: chunk.matched ? "reasoning" : "text", diff --git a/src/api/providers/ollama.ts b/src/api/providers/ollama.ts index ab9df116aa84..58b493c5e381 100644 --- a/src/api/providers/ollama.ts +++ b/src/api/providers/ollama.ts @@ -5,7 +5,7 @@ import { type ModelInfo, openAiModelInfoSaneDefaults, DEEP_SEEK_DEFAULT_TEMPERAT import type { ApiHandlerOptions } from "../../shared/api" -import { XmlMatcher } from "../../utils/xml-matcher" +import { ReasoningXmlMatcher } from "../../utils/reasoning-xml-matcher" import { convertToOpenAiMessages } from "../transform/openai-format" import { convertToR1Format } from "../transform/r1-format" @@ -68,8 +68,7 @@ export class OllamaHandler extends BaseProvider implements SingleCompletionHandl } catch (error) { throw handleOpenAIError(error, this.providerName) } - const matcher = new XmlMatcher( - "think", + const matcher = new ReasoningXmlMatcher( (chunk) => ({ type: chunk.matched ? "reasoning" : "text", diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts index aebe671712a7..33f0ae64fee5 100644 --- a/src/api/providers/openai.ts +++ b/src/api/providers/openai.ts @@ -12,7 +12,7 @@ import { import type { ApiHandlerOptions } from "../../shared/api" -import { XmlMatcher } from "../../utils/xml-matcher" +import { ReasoningXmlMatcher } from "../../utils/reasoning-xml-matcher" import { convertToOpenAiMessages } from "../transform/openai-format" import { convertToR1Format } from "../transform/r1-format" @@ -179,8 +179,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl throw handleOpenAIError(error, this.providerName) } - const matcher = new XmlMatcher( - "think", + const matcher = new ReasoningXmlMatcher( (chunk) => ({ type: chunk.matched ? "reasoning" : "text", diff --git a/src/core/assistant-message/presentAssistantMessage.ts b/src/core/assistant-message/presentAssistantMessage.ts index 689675999fd1..f0b0806ce250 100644 --- a/src/core/assistant-message/presentAssistantMessage.ts +++ b/src/core/assistant-message/presentAssistantMessage.ts @@ -93,21 +93,24 @@ export async function presentAssistantMessage(cline: Task) { if (content) { // Have to do this for partial and complete since sending - // content in thinking tags to markdown renderer will + // content in reasoning tags to markdown renderer will // automatically be removed. - // Remove end substrings of (with optional line break - // after) and (with optional line break before). + // Remove all instances of reasoning tags: , , , + // (with optional line break after opening tags) and their closing tags + // (with optional line break before closing tags). // - Needs to be separate since we dont want to remove the line // break before the first tag. // - Needs to happen before the xml parsing below. - content = content.replace(/\s?/g, "") - content = content.replace(/\s?<\/thinking>/g, "") + const reasoningTags = ["think", "thinking", "reasoning", "thought"] + reasoningTags.forEach((tag) => { + // Remove opening tags with optional line break after + const openingRegex = new RegExp(`<${tag}>\\s?`, "g") + content = content.replace(openingRegex, "") + // Remove closing tags with optional line break before + const closingRegex = new RegExp(`\\s?<\\/${tag}>`, "g") + content = content.replace(closingRegex, "") + }) // Remove partial XML tag at the very end of the content (for // tool use and thinking tags), Prevents scrollview from @@ -136,14 +139,20 @@ export async function presentAssistantMessage(cline: Task) { // (letters and underscores only). const isLikelyTagName = /^[a-zA-Z_]+$/.test(tagContent) + // Check if it's a partial reasoning tag + const reasoningTags = ["think", "thinking", "reasoning", "thought"] + const isPartialReasoningTag = reasoningTags.some( + (tag) => tag.startsWith(tagContent) || tagContent.startsWith(tag), + ) + // Preemptively remove < or { + it("should match tags", () => { + const matcher = new ReasoningXmlMatcher() + const input = "Some text This is reasoning content more text" + const results = matcher.final(input) + + expect(results).toHaveLength(3) + expect(results[0]).toEqual({ matched: false, data: "Some text " }) + expect(results[1]).toEqual({ matched: true, data: "This is reasoning content" }) + expect(results[2]).toEqual({ matched: false, data: " more text" }) + }) + + it("should match tags", () => { + const matcher = new ReasoningXmlMatcher() + const input = "Some text This is reasoning content more text" + const results = matcher.final(input) + + expect(results).toHaveLength(3) + expect(results[0]).toEqual({ matched: false, data: "Some text " }) + expect(results[1]).toEqual({ matched: true, data: "This is reasoning content" }) + expect(results[2]).toEqual({ matched: false, data: " more text" }) + }) + + it("should match tags", () => { + const matcher = new ReasoningXmlMatcher() + const input = "Some text This is reasoning content more text" + const results = matcher.final(input) + + expect(results).toHaveLength(3) + expect(results[0]).toEqual({ matched: false, data: "Some text " }) + expect(results[1]).toEqual({ matched: true, data: "This is reasoning content" }) + expect(results[2]).toEqual({ matched: false, data: " more text" }) + }) + + it("should match tags", () => { + const matcher = new ReasoningXmlMatcher() + const input = "Some text This is reasoning content more text" + const results = matcher.final(input) + + expect(results).toHaveLength(3) + expect(results[0]).toEqual({ matched: false, data: "Some text " }) + expect(results[1]).toEqual({ matched: true, data: "This is reasoning content" }) + expect(results[2]).toEqual({ matched: false, data: " more text" }) + }) + + it("should handle streaming updates for all tag variants", () => { + const testCases = [ + { tag: "think", content: "Thinking about the problem" }, + { tag: "thinking", content: "Processing the request" }, + { tag: "reasoning", content: "Analyzing the situation" }, + { tag: "thought", content: "Considering options" }, + ] + + testCases.forEach(({ tag, content }) => { + const matcher = new ReasoningXmlMatcher() + + // Simulate streaming + const chunks = [ + "Initial text ", + `<${tag}>`, + content.slice(0, 10), + content.slice(10), + ``, + " final text", + ] + + let allResults: any[] = [] + chunks.forEach((chunk) => { + const results = matcher.update(chunk) + allResults.push(...results) + }) + + // Get final results + const finalResults = matcher.final() + allResults.push(...finalResults) + + // Verify we got the expected matched content + const matchedResults = allResults.filter((r) => r.matched) + const unmatchedResults = allResults.filter((r) => !r.matched) + + expect(matchedResults.length).toBeGreaterThan(0) + const fullMatchedContent = matchedResults.map((r) => r.data).join("") + expect(fullMatchedContent).toContain(content) + + const fullUnmatchedContent = unmatchedResults.map((r) => r.data).join("") + expect(fullUnmatchedContent).toContain("Initial text") + expect(fullUnmatchedContent).toContain("final text") + }) + }) + + it("should handle nested tags correctly", () => { + const matcher = new ReasoningXmlMatcher() + const input = "Outer Inner content" + const results = matcher.final(input) + + // Should match the entire nested structure + expect(results).toHaveLength(1) + expect(results[0]).toEqual({ + matched: true, + data: "Outer Inner content", + }) + }) + + it("should handle multiple different reasoning tags in sequence", () => { + const matcher = new ReasoningXmlMatcher() + const input = "Text Think content middle Thinking content end" + const results = matcher.final(input) + + // Should match only the first tag type encountered + expect(results.filter((r) => r.matched).length).toBeGreaterThan(0) + expect(results.some((r) => r.data.includes("Think content"))).toBe(true) + }) + + it("should apply custom transform function", () => { + const transform = (chunk: { matched: boolean; data: string }) => ({ + type: chunk.matched ? "reasoning" : "text", + text: chunk.data, + }) + + const matcher = new ReasoningXmlMatcher(transform) + const input = "Normal text Reasoning here more text" + const results = matcher.final(input) + + expect(results[0]).toEqual({ type: "text", text: "Normal text " }) + expect(results[1]).toEqual({ type: "reasoning", text: "Reasoning here" }) + expect(results[2]).toEqual({ type: "text", text: " more text" }) + }) +}) diff --git a/src/utils/reasoning-xml-matcher.ts b/src/utils/reasoning-xml-matcher.ts new file mode 100644 index 000000000000..b672873bf15a --- /dev/null +++ b/src/utils/reasoning-xml-matcher.ts @@ -0,0 +1,102 @@ +import { XmlMatcher, XmlMatcherResult } from "./xml-matcher" + +/** + * A wrapper around XmlMatcher that can match multiple tag names for reasoning blocks. + * This handles , , , and tags uniformly. + * + * It works by using a single XmlMatcher configured to match the shortest tag name + * and then validates if the full tag is one of the reasoning variants. + */ +export class ReasoningXmlMatcher { + private reasoningTags = ["think", "thinking", "reasoning", "thought"] + private results: Result[] = [] + private buffer = "" + private isProcessing = false + + constructor( + private readonly transform?: (chunks: XmlMatcherResult) => Result, + private readonly position = 0, + ) {} + + private processWithTag(input: string, tagName: string): XmlMatcherResult[] { + const matcher = new XmlMatcher(tagName, undefined, this.position) + return matcher.final(input) + } + + private extractMatchedResults(input: string): Result[] { + // Try each tag type to find matches + for (const tag of this.reasoningTags) { + // Check if the input contains this tag + if (input.includes(`<${tag}>`) || input.includes(``)) { + const results = this.processWithTag(input, tag) + if (results.length > 0) { + // Transform results if needed + if (this.transform) { + return results.map(this.transform) + } + return results as Result[] + } + } + } + + // No reasoning tags found, return the input as unmatched + const unmatchedResult: XmlMatcherResult = { + matched: false, + data: input, + } + + if (this.transform) { + return [this.transform(unmatchedResult)] + } + return [unmatchedResult as Result] + } + + update(chunk: string): Result[] { + this.buffer += chunk + this.results = [] + + // Don't process until we have a complete tag or enough content + // This prevents partial processing issues + if (!this.buffer.includes(">")) { + return this.results + } + + // Check if we have any complete reasoning blocks + let hasCompleteBlock = false + for (const tag of this.reasoningTags) { + const openTag = `<${tag}>` + const closeTag = `` + if (this.buffer.includes(openTag) && this.buffer.includes(closeTag)) { + const openIndex = this.buffer.indexOf(openTag) + const closeIndex = this.buffer.indexOf(closeTag, openIndex) + if (closeIndex > openIndex) { + hasCompleteBlock = true + break + } + } + } + + // If we have a complete block, process it + if (hasCompleteBlock) { + const results = this.extractMatchedResults(this.buffer) + this.buffer = "" + this.results = results + } + + return this.results + } + + final(chunk?: string): Result[] { + if (chunk) { + this.buffer += chunk + } + + if (this.buffer.length === 0) { + return [] + } + + const results = this.extractMatchedResults(this.buffer) + this.buffer = "" + return results + } +}