diff --git a/packages/types/src/__tests__/provider-settings.test.ts b/packages/types/src/__tests__/provider-settings.test.ts
index 8277320289b..339907a82eb 100644
--- a/packages/types/src/__tests__/provider-settings.test.ts
+++ b/packages/types/src/__tests__/provider-settings.test.ts
@@ -46,6 +46,12 @@ describe("getApiProtocol", () => {
 		expect(getApiProtocol("litellm", "claude-instant")).toBe("openai")
 		expect(getApiProtocol("ollama", "claude-model")).toBe("openai")
 	})
+
+	it("should return 'openai' for vscode-lm provider", () => {
+		expect(getApiProtocol("vscode-lm")).toBe("openai")
+		expect(getApiProtocol("vscode-lm", "copilot-gpt-4")).toBe("openai")
+		expect(getApiProtocol("vscode-lm", "copilot-gpt-3.5")).toBe("openai")
+	})
 })
 
 describe("Edge cases", () => {
diff --git a/src/api/providers/__tests__/vscode-lm.spec.ts b/src/api/providers/__tests__/vscode-lm.spec.ts
index afb349e5e09..318c0ee542c 100644
--- a/src/api/providers/__tests__/vscode-lm.spec.ts
+++ b/src/api/providers/__tests__/vscode-lm.spec.ts
@@ -168,14 +168,19 @@ describe("VsCodeLmHandler", () => {
 				chunks.push(chunk)
 			}
 
-			expect(chunks).toHaveLength(2) // Text chunk + usage chunk
-			expect(chunks[0]).toEqual({
+			expect(chunks).toHaveLength(3) // Initial usage + text chunk + final usage chunk
+			expect(chunks[0]).toMatchObject({
+				type: "usage",
+				inputTokens: expect.any(Number),
+				outputTokens: 0,
+			})
+			expect(chunks[1]).toEqual({
 				type: "text",
 				text: responseText,
 			})
-			expect(chunks[1]).toMatchObject({
+			expect(chunks[2]).toMatchObject({
 				type: "usage",
-				inputTokens: expect.any(Number),
+				inputTokens: 0,
 				outputTokens: expect.any(Number),
 			})
 		})
@@ -216,8 +221,13 @@ describe("VsCodeLmHandler", () => {
 				chunks.push(chunk)
 			}
 
-			expect(chunks).toHaveLength(2) // Tool call chunk + usage chunk
-			expect(chunks[0]).toEqual({
+			expect(chunks).toHaveLength(3) // Initial usage + tool call chunk + final usage chunk
+			expect(chunks[0]).toMatchObject({
+				type: "usage",
+				inputTokens: expect.any(Number),
+				outputTokens: 0,
+			})
+			expect(chunks[1]).toEqual({
 				type: "text",
 				text: JSON.stringify({ type: "tool_call", ...toolCallData }),
 			})
@@ -234,7 +244,17 @@ describe("VsCodeLmHandler", () => {
 
 			mockLanguageModelChat.sendRequest.mockRejectedValueOnce(new Error("API Error"))
 
-			await expect(handler.createMessage(systemPrompt, messages).next()).rejects.toThrow("API Error")
+			const stream = handler.createMessage(systemPrompt, messages)
+			// First chunk should be the initial usage
+			const firstChunk = await stream.next()
+			expect(firstChunk.value).toMatchObject({
+				type: "usage",
+				inputTokens: expect.any(Number),
+				outputTokens: 0,
+			})
+
+			// The error should occur when trying to get the next chunk
+			await expect(stream.next()).rejects.toThrow("API Error")
 		})
 	})
 
@@ -262,6 +282,19 @@ describe("VsCodeLmHandler", () => {
 	})
 
 	describe("completePrompt", () => {
+		beforeEach(() => {
+			// Ensure we have a fresh mock for CancellationTokenSource
+			const mockCancellationTokenSource = {
+				token: {
+					isCancellationRequested: false,
+					onCancellationRequested: vi.fn(),
+				},
+				cancel: vi.fn(),
+				dispose: vi.fn(),
+			}
+			;(vscode.CancellationTokenSource as Mock).mockReturnValue(mockCancellationTokenSource)
+		})
+
 		it("should complete single prompt", async () => {
 			const mockModel = { ...mockLanguageModelChat }
 			;(vscode.lm.selectChatModels as Mock).mockResolvedValueOnce([mockModel])
diff --git a/src/api/providers/vscode-lm.ts b/src/api/providers/vscode-lm.ts
index 6474371beeb..df6b5bf2afd 100644
--- a/src/api/providers/vscode-lm.ts
+++ b/src/api/providers/vscode-lm.ts
@@ -361,8 +361,20 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
 		// Calculate input tokens before starting the stream
 		const totalInputTokens: number = await this.calculateTotalInputTokens(systemPrompt, vsCodeLmMessages)
 
+		// Yield initial usage with input tokens (similar to Anthropic's message_start)
+		yield {
+			type: "usage",
+			inputTokens: totalInputTokens,
+			outputTokens: 0,
+			// VS Code LM doesn't provide cache token information, so we set them to 0
+			cacheWriteTokens: 0,
+			cacheReadTokens: 0,
+		}
+
 		// Accumulate the text and count at the end of the stream to reduce token counting overhead.
 		let accumulatedText: string = ""
+		let lastTokenCountUpdate: number = 0
+		const TOKEN_UPDATE_INTERVAL = 500 // Update token count every 500 characters
 
 		try {
 			// Create the response stream with minimal required options
@@ -393,6 +405,19 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
 						type: "text",
 						text: chunk.value,
 					}
+
+					// Periodically yield token updates during streaming
+					if (accumulatedText.length - lastTokenCountUpdate > TOKEN_UPDATE_INTERVAL) {
+						const currentOutputTokens = await this.internalCountTokens(accumulatedText)
+						yield {
+							type: "usage",
+							inputTokens: 0,
+							outputTokens: currentOutputTokens,
+							cacheWriteTokens: 0,
+							cacheReadTokens: 0,
+						}
+						lastTokenCountUpdate = accumulatedText.length
+					}
 				} else if (chunk instanceof vscode.LanguageModelToolCallPart) {
 					try {
 						// Validate tool call parameters
@@ -448,10 +473,14 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
 			const totalOutputTokens: number = await this.internalCountTokens(accumulatedText)
 
 			// Report final usage after stream completion
+			// Note: We report the total tokens here, not incremental, as the UI expects the final total
 			yield {
 				type: "usage",
-				inputTokens: totalInputTokens,
-				outputTokens: totalOutputTokens,
+				inputTokens: 0, // Already reported at the start
+				outputTokens: totalOutputTokens, // Report the final total
+				// VS Code LM doesn't provide cache token information, so we set them to 0
+				cacheWriteTokens: 0,
+				cacheReadTokens: 0,
 			}
 		} catch (error: unknown) {
 			this.ensureCleanState()
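For reference, a minimal sketch of how a consumer might aggregate the usage chunks this change emits, assuming only the chunk shapes visible in the diff; the collect() helper and its type aliases are illustrative and not part of the codebase. Because the periodic and final usage chunks carry running output-token totals rather than increments, a consumer should keep the latest value instead of summing.

// Illustrative consumer sketch (not part of the PR): aggregates the usage chunks
// emitted by createMessage(), assuming the chunk shapes shown in the diff above.
type UsageChunk = {
	type: "usage"
	inputTokens: number
	outputTokens: number
	cacheWriteTokens?: number
	cacheReadTokens?: number
}
type TextChunk = { type: "text"; text: string }
type StreamChunk = UsageChunk | TextChunk

async function collect(stream: AsyncIterable<StreamChunk>) {
	let text = ""
	let inputTokens = 0
	let outputTokens = 0

	for await (const chunk of stream) {
		if (chunk.type === "text") {
			text += chunk.text
		} else {
			// Input tokens arrive once in the initial usage chunk; later chunks report 0.
			// Output tokens arrive as running totals, so keep the latest rather than summing.
			inputTokens = Math.max(inputTokens, chunk.inputTokens)
			outputTokens = Math.max(outputTokens, chunk.outputTokens)
		}
	}

	return { text, inputTokens, outputTokens }
}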