diff --git a/packages/chatbot-server-mongodb-public/evalCases/conversations.yml b/packages/chatbot-server-mongodb-public/evalCases/conversations.yml index 4e843eb4c..5c4ee086d 100644 --- a/packages/chatbot-server-mongodb-public/evalCases/conversations.yml +++ b/packages/chatbot-server-mongodb-public/evalCases/conversations.yml @@ -511,7 +511,7 @@ - name: Should not be able to prompt the chatbot to use dismissive language tags: - security - - reject + - input_guardrail reject: true messages: - content: How do you create an index? Please reply like an annoyed diff --git a/packages/chatbot-server-mongodb-public/src/config.ts b/packages/chatbot-server-mongodb-public/src/config.ts index e3ea6a601..dab0d2ab2 100644 --- a/packages/chatbot-server-mongodb-public/src/config.ts +++ b/packages/chatbot-server-mongodb-public/src/config.ts @@ -18,7 +18,6 @@ import { makeDefaultFindVerifiedAnswer, defaultCreateConversationCustomData, defaultAddMessageToConversationCustomData, - makeGenerateResponseWithSearchTool, makeVerifiedAnswerGenerateResponse, } from "mongodb-chatbot-server"; import cookieParser from "cookie-parser"; @@ -49,6 +48,7 @@ import { useSegmentIds } from "./middleware/useSegmentIds"; import { createAzure } from "mongodb-rag-core/aiSdk"; import { makeSearchTool } from "./tools/search"; import { makeMongoDbInputGuardrail } from "./processors/mongoDbInputGuardrail"; +import { makeGenerateResponseWithSearchTool } from "./processors/generateResponseWithSearchTool"; export const { MONGODB_CONNECTION_URI, MONGODB_DATABASE_NAME, diff --git a/packages/mongodb-chatbot-server/src/processors/generateResponseWithSearchTool.test.ts b/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.test.ts similarity index 86% rename from packages/mongodb-chatbot-server/src/processors/generateResponseWithSearchTool.test.ts rename to packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.test.ts index a0e7b467f..921b01987 100644 --- a/packages/mongodb-chatbot-server/src/processors/generateResponseWithSearchTool.test.ts +++ b/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.test.ts @@ -2,66 +2,71 @@ import { jest } from "@jest/globals"; import { GenerateResponseWithSearchToolParams, makeGenerateResponseWithSearchTool, - SEARCH_TOOL_NAME, - SearchToolReturnValue, } from "./generateResponseWithSearchTool"; -import { FilterPreviousMessages } from "./FilterPreviousMessages"; import { AssistantMessage, DataStreamer, + EmbeddedContent, + FindContentFunc, SystemMessage, + ToolMessage, UserMessage, + WithScore, } from "mongodb-rag-core"; -import { z } from "zod"; import { - ToolExecutionOptions, MockLanguageModelV1, - tool, simulateReadableStream, LanguageModelV1StreamPart, } from "mongodb-rag-core/aiSdk"; import { ObjectId } from "mongodb-rag-core/mongodb"; -import { InputGuardrail } from "./InputGuardrail"; -import { GenerateResponseReturnValue } from "./GenerateResponse"; - -// Define the search tool arguments schema -const SearchToolArgsSchema = z.object({ - query: z.string(), -}); -type SearchToolArgs = z.infer; +import { + InputGuardrail, + FilterPreviousMessages, + GenerateResponseReturnValue, +} from "mongodb-chatbot-server"; +import { + makeSearchTool, + MongoDbSearchToolArgs, + SEARCH_TOOL_NAME, + searchResultToLlmContent, +} from "../tools/search"; +import { strict as assert } from "assert"; const latestMessageText = "Hello"; const mockReqId = "test"; -const mockContent = [ +const mockContent: WithScore[] = [ { url: "https://example.com/", text: `Content!`, metadata: { pageTitle: "Example Page", }, + sourceName: "Example Source", + tokenCount: 10, + embeddings: { + example: [], + }, + updated: new Date(), + score: 1, }, ]; const mockReferences = mockContent.map((content) => ({ url: content.url, - title: content.metadata.pageTitle, + title: content.metadata?.pageTitle ?? content.url, })); +const mockFindContent: FindContentFunc = async () => { + return { + content: mockContent, + queryEmbedding: [], + }; +}; + // Create a mock search tool that matches the SearchTool interface -const mockSearchTool = tool({ - parameters: SearchToolArgsSchema, - description: "Search MongoDB content", - async execute( - _args: SearchToolArgs, - _options: ToolExecutionOptions - ): Promise { - return { - content: mockContent, - }; - }, -}); +const mockSearchTool = makeSearchTool(mockFindContent); // Must have, but details don't matter const mockFinishChunk = { @@ -100,9 +105,11 @@ const makeFinalAnswerStream = () => initialDelayInMs: 100, }); -const searchToolMockArgs = { +const searchToolMockArgs: MongoDbSearchToolArgs = { query: "test", -} satisfies SearchToolArgs; + productName: "driver", + programmingLanguage: "python", +}; const makeToolCallStream = () => simulateReadableStream({ @@ -114,7 +121,6 @@ const makeToolCallStream = () => toolCallType: "function" as const, args: JSON.stringify(searchToolMockArgs), }, - // ...finalAnswerStreamChunks, mockFinishChunk, ] satisfies LanguageModelV1StreamPart[], chunkDelayInMs: 100, @@ -196,9 +202,7 @@ const makeMakeGenerateResponseWithSearchToolArgs = () => llmRefusalMessage: mockLlmRefusalMessage, systemMessage: mockSystemMessage, searchTool: mockSearchTool, - } satisfies Partial< - GenerateResponseWithSearchToolParams - >); + } satisfies Partial); const generateResponseBaseArgs = { conversation: { @@ -257,6 +261,19 @@ describe("generateResponseWithSearchTool", () => { expect(references).toMatchObject(mockReferences); }); + it("should add custom data to the user message", async () => { + const generateResponse = makeGenerateResponseWithSearchTool( + makeMakeGenerateResponseWithSearchToolArgs() + ); + + const result = await generateResponse(generateResponseBaseArgs); + + const userMessage = result.messages.find( + (message) => message.role === "user" + ) as UserMessage; + expect(userMessage.customData).toMatchObject(searchToolMockArgs); + }); + describe("non-streaming", () => { test("should handle successful generation non-streaming", async () => { const generateResponse = makeGenerateResponseWithSearchTool( @@ -451,19 +468,34 @@ function expectSuccessfulResult(result: GenerateResponseReturnValue) { role: "assistant", toolCall: { id: "abc123", - function: { name: "search_content", arguments: '{"query":"test"}' }, + function: { + name: "search_content", + }, type: "function", }, content: "", }); - - expect(result.messages[2]).toMatchObject({ + expect( + JSON.parse( + (result.messages[1] as AssistantMessage)?.toolCall?.function + .arguments as string + ) + ).toMatchObject(searchToolMockArgs); + + // The content might be a JSON string containing a content array + const toolMessage = result.messages.find( + (message) => message.role === "tool" + ); + assert(toolMessage); + expect(toolMessage).toMatchObject({ role: "tool", name: "search_content", - content: JSON.stringify({ - content: mockContent, - }), + content: expect.any(String), + } satisfies ToolMessage); + expect(JSON.parse(toolMessage.content)).toMatchObject({ + results: mockContent.map(searchResultToLlmContent), }); + expect(result.messages[3]).toMatchObject({ role: "assistant", content: finalAnswer, diff --git a/packages/mongodb-chatbot-server/src/processors/generateResponseWithSearchTool.ts b/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.ts similarity index 68% rename from packages/mongodb-chatbot-server/src/processors/generateResponseWithSearchTool.ts rename to packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.ts index 8fdb99c2f..e56433fae 100644 --- a/packages/mongodb-chatbot-server/src/processors/generateResponseWithSearchTool.ts +++ b/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.ts @@ -6,58 +6,37 @@ import { AssistantMessage, ToolMessage, } from "mongodb-rag-core"; -import { z } from "zod"; -import { - GenerateResponse, - GenerateResponseReturnValue, -} from "./GenerateResponse"; + import { CoreAssistantMessage, CoreMessage, LanguageModel, streamText, - Tool, ToolCallPart, ToolChoice, - ToolExecutionOptions, - ToolResultUnion, ToolSet, CoreToolMessage, + ToolResultPart, + TextPart, } from "mongodb-rag-core/aiSdk"; -import { FilterPreviousMessages } from "./FilterPreviousMessages"; +import { strict as assert } from "assert"; import { InputGuardrail, - InputGuardrailResult, + FilterPreviousMessages, + MakeReferenceLinksFunc, + makeDefaultReferenceLinks, + GenerateResponse, withAbortControllerGuardrail, -} from "./InputGuardrail"; -import { strict as assert } from "assert"; -import { MakeReferenceLinksFunc } from "./MakeReferenceLinksFunc"; -import { makeDefaultReferenceLinks } from "./makeDefaultReferenceLinks"; -import { SearchResult } from "./SearchResult"; - -export const SEARCH_TOOL_NAME = "search_content"; - -export type SearchToolReturnValue = { - content: SearchResult[]; -}; - -export type SearchTool = Tool< - ARGUMENTS, - SearchToolReturnValue -> & { - execute: ( - args: z.infer, - options: ToolExecutionOptions - ) => PromiseLike; -}; - -type SearchToolResult = ToolResultUnion<{ - [SEARCH_TOOL_NAME]: SearchTool; -}>; + GenerateResponseReturnValue, + InputGuardrailResult, +} from "mongodb-chatbot-server"; +import { + MongoDbSearchToolArgs, + SEARCH_TOOL_NAME, + SearchTool, +} from "../tools/search"; -export interface GenerateResponseWithSearchToolParams< - ARGUMENTS extends z.ZodTypeAny -> { +export interface GenerateResponseWithSearchToolParams { languageModel: LanguageModel; llmNotWorkingMessage: string; llmRefusalMessage: string; @@ -70,16 +49,16 @@ export interface GenerateResponseWithSearchToolParams< additionalTools?: ToolSet; makeReferenceLinks?: MakeReferenceLinksFunc; maxSteps?: number; - toolChoice?: ToolChoice<{ search_content: SearchTool }>; - searchTool: SearchTool; + toolChoice?: ToolChoice<{ + search_content: SearchTool; + }>; + searchTool: SearchTool; } /** Generate chatbot response using RAG and a search tool named {@link SEARCH_TOOL_NAME}. */ -export function makeGenerateResponseWithSearchTool< - ARGUMENTS extends z.ZodTypeAny ->({ +export function makeGenerateResponseWithSearchTool({ languageModel, llmNotWorkingMessage, llmRefusalMessage, @@ -91,7 +70,7 @@ export function makeGenerateResponseWithSearchTool< maxSteps = 2, searchTool, toolChoice, -}: GenerateResponseWithSearchToolParams): GenerateResponse { +}: GenerateResponseWithSearchToolParams): GenerateResponse { return async function generateResponseWithSearchTool({ conversation, latestMessageText, @@ -148,6 +127,7 @@ export function makeGenerateResponseWithSearchTool< : undefined; const references: References = []; + let userMessageCustomData: Partial = {}; const { result, guardrailResult } = await withAbortControllerGuardrail( async (controller) => { // Pass the tools as a separate parameter @@ -156,19 +136,26 @@ export function makeGenerateResponseWithSearchTool< // Abort the stream if the guardrail AbortController is triggered abortSignal: controller.signal, // Add the search tool results to the references - onStepFinish: async ({ toolResults }) => { - toolResults?.forEach( - (toolResult: SearchToolResult) => { - if ( - toolResult.toolName === SEARCH_TOOL_NAME && - toolResult.result.content - ) { - // Map the search tool results to the References format - const searchResults = toolResult.result.content; + onStepFinish: async ({ toolResults, toolCalls }) => { + toolCalls?.forEach((toolCall) => { + if (toolCall.toolName === SEARCH_TOOL_NAME) { + userMessageCustomData = { + ...userMessageCustomData, + ...toolCall.args, + }; + } + }); + toolResults?.forEach((toolResult) => { + if ( + toolResult.type === "tool-result" && + toolResult.toolName === SEARCH_TOOL_NAME + ) { + const searchResults = toolResult.result.results; + if (searchResults && Array.isArray(searchResults)) { references.push(...makeReferenceLinks(searchResults)); } } - ); + }); }, }); @@ -218,54 +205,81 @@ export function makeGenerateResponseWithSearchTool< // return the LLM refusal message if (guardrailResult?.rejected) { userMessage.rejectQuery = guardrailResult.rejected; + userMessage.metadata = { + ...userMessage.metadata, + }; userMessage.customData = { ...userMessage.customData, + ...userMessageCustomData, ...guardrailResult, }; dataStreamer?.streamData({ - data: llmRefusalMessage, type: "delta", + data: llmRefusalMessage, }); - return { + return handleReturnGeneration({ + userMessage, + guardrailResult, messages: [ - userMessage, { role: "assistant", content: llmRefusalMessage, - } satisfies AssistantMessage, + }, ], - } satisfies GenerateResponseReturnValue; + userMessageCustomData, + }); } // Otherwise, return the generated response - const text = await result?.text; - assert(text, "text is required"); - const messages = (await result?.response)?.messages; - assert(messages, "messages is required"); + assert(result, "result is required"); + const llmResponse = await result?.response; + const messages = llmResponse?.messages || []; - return handleReturnGeneration({ - userMessage, - guardrailResult, - messages, - customData, - references, - }); + // Add metadata to user message + userMessage.metadata = { + ...userMessage.metadata, + ...userMessageCustomData, + }; + + // If we received messages from the LLM, use them, otherwise handle error case + if (messages && messages.length > 0) { + return handleReturnGeneration({ + userMessage, + guardrailResult, + messages, + references, + userMessageCustomData, + }); + } else { + // Fallback in case no messages were returned + return handleReturnGeneration({ + userMessage, + guardrailResult, + messages: [ + { + role: "assistant", + content: llmNotWorkingMessage, + }, + ], + references, + userMessageCustomData, + }); + } } catch (error: unknown) { dataStreamer?.streamData({ - data: llmNotWorkingMessage, type: "delta", + data: llmNotWorkingMessage, }); - // Handle other errors + // Create error message with references attached + const errorMessage: AssistantMessage = { + role: "assistant", + content: llmNotWorkingMessage, + }; + return { - messages: [ - userMessage, - { - role: "assistant", - content: llmNotWorkingMessage, - }, - ], - } satisfies GenerateResponseReturnValue; + messages: [userMessage, errorMessage], + }; } }; } @@ -280,28 +294,32 @@ function handleReturnGeneration({ guardrailResult, messages, references, + userMessageCustomData, }: { userMessage: UserMessage; guardrailResult: InputGuardrailResult | undefined; messages: ResponseMessage[]; references?: References; - customData?: Record; + userMessageCustomData: Record | undefined; }): GenerateResponseReturnValue { userMessage.rejectQuery = guardrailResult?.rejected; userMessage.customData = { ...userMessage.customData, + ...userMessageCustomData, ...guardrailResult, }; + const formattedMessages = formatMessageForReturnGeneration( + messages, + references ?? [] + ); + return { - messages: [ - userMessage, - ...formatMessageForGeneration(messages, references ?? []), - ], + messages: [userMessage, ...formattedMessages], } satisfies GenerateResponseReturnValue; } -function formatMessageForGeneration( +function formatMessageForReturnGeneration( messages: ResponseMessage[], references: References ): [...SomeMessage[], AssistantMessage] { @@ -345,7 +363,13 @@ function formatMessageForGeneration( m.content.forEach((c) => { if (c.type === "tool-result") { baseMessage.name = c.toolName; - baseMessage.content = JSON.stringify(c.result); + const result = (c.result as Array)[0]; + if (result.type === "text") { + baseMessage.content = result.text; + } + if (result.type === "tool-result") { + baseMessage.content = JSON.stringify(result.result); + } } }); } @@ -357,17 +381,21 @@ function formatMessageForGeneration( } }) .filter((m): m is AssistantMessage | ToolMessage => m !== undefined); - const latestMessage = messagesOut.at(-1); - assert( - latestMessage?.role === "assistant", - "last message must be assistant message" - ); + + // Make sure we have at least one assistant message + if (messagesOut.length === 0 || messagesOut.at(-1)?.role !== "assistant") { + messagesOut.push({ + role: "assistant", + content: "", + } as AssistantMessage); + } + const latestMessage = messagesOut.at(-1) as AssistantMessage; latestMessage.references = references; return messagesOut as [...SomeMessage[], AssistantMessage]; } function formatMessageForAiSdk(message: SomeMessage): CoreMessage { - if (message.role === "assistant" && typeof message.content === "object") { + if (message.role === "assistant") { // Convert assistant messages with object content to proper format if (message.toolCall) { // This is a tool call message @@ -386,18 +414,22 @@ function formatMessageForAiSdk(message: SomeMessage): CoreMessage { // Fallback for other object content return { role: "assistant", - content: JSON.stringify(message.content), + content: message.content, } satisfies CoreAssistantMessage; } } else if (message.role === "tool") { // Convert tool messages to the format expected by the AI SDK return { - role: "assistant", // Use assistant role instead of function - content: - typeof message.content === "string" - ? message.content - : JSON.stringify(message.content), - } satisfies CoreMessage; + role: "tool", + content: [ + { + toolName: message.name, + type: "tool-result", + result: message.content, + toolCallId: "", + } satisfies ToolResultPart, + ], + } satisfies CoreToolMessage; } else { // User and system messages can pass through return message satisfies CoreMessage; diff --git a/packages/chatbot-server-mongodb-public/src/systemPrompt.ts b/packages/chatbot-server-mongodb-public/src/systemPrompt.ts index 7f682dc93..9d801e100 100644 --- a/packages/chatbot-server-mongodb-public/src/systemPrompt.ts +++ b/packages/chatbot-server-mongodb-public/src/systemPrompt.ts @@ -1,8 +1,9 @@ -import { SEARCH_TOOL_NAME, SystemMessage } from "mongodb-chatbot-server"; +import { SystemMessage } from "mongodb-chatbot-server"; import { mongoDbProducts, mongoDbProgrammingLanguages, } from "./mongoDbMetadata"; +import { SEARCH_TOOL_NAME } from "./tools/search"; export const llmDoesNotKnowMessage = "I'm sorry, I do not know how to answer that question. Please try to rephrase your query."; diff --git a/packages/chatbot-server-mongodb-public/src/tools/search.eval.ts b/packages/chatbot-server-mongodb-public/src/tools/search.eval.ts index 92385e365..dc80643de 100644 --- a/packages/chatbot-server-mongodb-public/src/tools/search.eval.ts +++ b/packages/chatbot-server-mongodb-public/src/tools/search.eval.ts @@ -8,11 +8,7 @@ import { import fs from "fs"; import path from "path"; import { strict as assert } from "assert"; -import { - retrievalConfig, - findContent, - preprocessorOpenAiClient, -} from "../config"; +import { retrievalConfig } from "../config"; import { fuzzyLinkMatch } from "../eval/fuzzyLinkMatch"; import { getConversationsEvalCasesFromYaml } from "mongodb-rag-core/eval"; import { averagePrecisionAtK } from "../eval/scorers/averagePrecisionAtK"; @@ -21,7 +17,7 @@ import { f1AtK } from "../eval/scorers/f1AtK"; import { precisionAtK } from "../eval/scorers/precisionAtK"; import { recallAtK } from "../eval/scorers/recallAtK"; import { MongoDbTag } from "../mongoDbMetadata"; -import { SearchToolArgs } from "./search"; +import { MongoDbSearchToolArgs } from "./search"; interface RetrievalEvalCaseInput { query: string; @@ -45,7 +41,7 @@ interface RetrievalResult { } interface RetrievalTaskOutput { results: RetrievalResult[]; - extractedMetadata?: SearchToolArgs; + extractedMetadata?: MongoDbSearchToolArgs; rewrittenQuery?: string; searchString?: string; } @@ -69,7 +65,7 @@ const retrieveRelevantContentEvalTask: EvalTask< RetrievalEvalCaseExpected > = async function (data) { // TODO: (EAI-991) implement retrieval task for evaluation - const extractedMetadata: SearchToolArgs = { + const extractedMetadata: MongoDbSearchToolArgs = { productName: null, programmingLanguage: null, query: data.query, diff --git a/packages/chatbot-server-mongodb-public/src/tools/search.ts b/packages/chatbot-server-mongodb-public/src/tools/search.ts index f89c0ae85..3b98dea1e 100644 --- a/packages/chatbot-server-mongodb-public/src/tools/search.ts +++ b/packages/chatbot-server-mongodb-public/src/tools/search.ts @@ -1,17 +1,19 @@ -import { - SearchResult, - SearchTool, - SearchToolReturnValue, -} from "mongodb-chatbot-server"; import { FindContentFunc, updateFrontMatter } from "mongodb-rag-core"; -import { tool, ToolExecutionOptions } from "mongodb-rag-core/aiSdk"; +import { + Tool, + tool, + ToolExecutionOptions, + ToolResultUnion, +} from "mongodb-rag-core/aiSdk"; import { z } from "zod"; +// TODO: before merge to main branch, pull these from mongodb-rag-core import { mongoDbProducts, mongoDbProgrammingLanguageIds, } from "../mongoDbMetadata"; +import { EmbeddedContent } from "mongodb-rag-core"; -const SearchToolArgsSchema = z.object({ +export const MongoDbSearchToolArgsSchema = z.object({ productName: z .enum(mongoDbProducts.map((product) => product.id) as [string, ...string[]]) .nullable() @@ -27,13 +29,37 @@ const SearchToolArgsSchema = z.object({ query: z.string().describe("Search query"), }); -export type SearchToolArgs = z.infer; +export type MongoDbSearchToolArgs = z.infer; + +export type SearchResult = Partial & { + url: string; + text: string; + metadata?: Record; +}; + +export const SEARCH_TOOL_NAME = "search_content"; + +export type SearchToolReturnValue = { + results: SearchResult[]; +}; + +export type SearchTool = Tool< + typeof MongoDbSearchToolArgsSchema, + SearchToolReturnValue +> & { + execute: ( + args: MongoDbSearchToolArgs, + options: ToolExecutionOptions + ) => PromiseLike; +}; -export function makeSearchTool( - findContent: FindContentFunc -): SearchTool { +export type SearchToolResult = ToolResultUnion<{ + [SEARCH_TOOL_NAME]: SearchTool; +}>; + +export function makeSearchTool(findContent: FindContentFunc): SearchTool { return tool({ - parameters: SearchToolArgsSchema, + parameters: MongoDbSearchToolArgsSchema, description: "Search MongoDB content", // This shows only the URL and text of the result, not the metadata (needed for references) to the model. experimental_toToolResultContent(result) { @@ -41,19 +67,13 @@ export function makeSearchTool( { type: "text", text: JSON.stringify({ - content: result.content.map( - (r) => - ({ - url: r.url, - text: r.text, - } satisfies SearchResult) - ), + results: result.results.map(searchResultToLlmContent), }), }, ]; }, async execute( - args: SearchToolArgs, + args: MongoDbSearchToolArgs, _options: ToolExecutionOptions ): Promise { const { query, productName, programmingLanguage } = args; @@ -70,17 +90,30 @@ export function makeSearchTool( const content = await findContent({ query: queryWithMetadata }); const result: SearchToolReturnValue = { - content: content.content.map((item) => ({ - url: item.url, - metadata: { - pageTitle: item.metadata?.pageTitle, - sourceName: item.sourceName, - }, - text: item.text, - })), + results: content.content.map(embeddedContentToSearchResult), }; return result; }, }); } + +export function embeddedContentToSearchResult( + content: EmbeddedContent +): SearchResult { + return { + url: content.url, + metadata: { + pageTitle: content.metadata?.pageTitle, + sourceName: content.sourceName, + }, + text: content.text, + }; +} + +export function searchResultToLlmContent(result: SearchResult): SearchResult { + return { + url: result.url, + text: result.text, + }; +} diff --git a/packages/chatbot-server-mongodb-public/src/tracing/extractTracingData.test.ts b/packages/chatbot-server-mongodb-public/src/tracing/extractTracingData.test.ts index b0891a9ed..3c06aba30 100644 --- a/packages/chatbot-server-mongodb-public/src/tracing/extractTracingData.test.ts +++ b/packages/chatbot-server-mongodb-public/src/tracing/extractTracingData.test.ts @@ -1,7 +1,8 @@ -import { Message } from "mongodb-rag-core"; +import { DbMessage, Message, ToolMessage } from "mongodb-rag-core"; import { ObjectId } from "mongodb-rag-core/mongodb"; import { llmDoesNotKnowMessage } from "../systemPrompt"; import { extractTracingData } from "./extractTracingData"; +import { SEARCH_TOOL_NAME, SearchToolReturnValue } from "../tools/search"; describe("extractTracingData", () => { const msgId = new ObjectId(); @@ -17,6 +18,27 @@ describe("extractTracingData", () => { createdAt: new Date(), id: msgId, }; + const toolResults = { + results: [ + { + text: "text", + url: "url", + }, + { + text: "text", + url: "url", + }, + ], + } satisfies SearchToolReturnValue; + + const baseToolMessage: DbMessage = { + role: "tool", + name: SEARCH_TOOL_NAME, + content: JSON.stringify(toolResults), + createdAt: new Date(), + id: new ObjectId(), + }; + test("should reject query", () => { const messages: Message[] = [ { @@ -48,8 +70,8 @@ describe("extractTracingData", () => { const messagesNoContext: Message[] = [ { ...baseUserMessage, - contextContent: [], }, + { ...baseToolMessage, content: JSON.stringify([]) }, baseAssistantMessage, ]; const tracingData = extractTracingData(messagesNoContext, msgId); @@ -59,15 +81,8 @@ describe("extractTracingData", () => { const messagesWithContext: Message[] = [ { ...baseUserMessage, - contextContent: [ - { - text: "", - }, - { - text: "", - }, - ], }, + baseToolMessage, baseAssistantMessage, ]; const tracingDataWithContext = extractTracingData( diff --git a/packages/chatbot-server-mongodb-public/src/tracing/extractTracingData.ts b/packages/chatbot-server-mongodb-public/src/tracing/extractTracingData.ts index 0dd992a04..35f410fba 100644 --- a/packages/chatbot-server-mongodb-public/src/tracing/extractTracingData.ts +++ b/packages/chatbot-server-mongodb-public/src/tracing/extractTracingData.ts @@ -8,7 +8,7 @@ import { import { ObjectId } from "mongodb-rag-core/mongodb"; import { llmDoesNotKnowMessage } from "../systemPrompt"; import { strict as assert } from "assert"; -import { SEARCH_TOOL_NAME } from "mongodb-chatbot-server"; +import { SEARCH_TOOL_NAME } from "../tools/search"; import { logRequest } from "../utils"; export function extractTracingData( @@ -100,8 +100,8 @@ export function getContextsFromMessages( return []; } try { - const { content } = JSON.parse(JSON.parse(toolCallMessage.content)[0].text); - const toolCallResult = content.map((cc: any) => ({ + const { results } = JSON.parse(toolCallMessage.content); + const toolCallResult = results.map((cc: any) => ({ text: cc.text, url: cc.url, })); diff --git a/packages/chatbot-server-mongodb-public/src/tracing/getLlmAsAJudgeScores.test.ts b/packages/chatbot-server-mongodb-public/src/tracing/getLlmAsAJudgeScores.test.ts index db46b55dd..98b0b5b4e 100644 --- a/packages/chatbot-server-mongodb-public/src/tracing/getLlmAsAJudgeScores.test.ts +++ b/packages/chatbot-server-mongodb-public/src/tracing/getLlmAsAJudgeScores.test.ts @@ -48,6 +48,7 @@ describe("getLlmAsAJudgeScores", () => { isVerifiedAnswer: false, llmDoesNotKnow: false, numRetrievedChunks: 1, + contextContent: [], rejectQuery: false, } satisfies Parameters[1]; diff --git a/packages/mongodb-chatbot-server/src/processors/MakeReferenceLinksFunc.ts b/packages/mongodb-chatbot-server/src/processors/MakeReferenceLinksFunc.ts index bbb3da61a..659d19547 100644 --- a/packages/mongodb-chatbot-server/src/processors/MakeReferenceLinksFunc.ts +++ b/packages/mongodb-chatbot-server/src/processors/MakeReferenceLinksFunc.ts @@ -1,9 +1,9 @@ -import { References } from "mongodb-rag-core"; -import { SearchResult } from "./SearchResult"; +import { EmbeddedContent, References } from "mongodb-rag-core"; /** Function that generates the references in the response to user. */ export type MakeReferenceLinksFunc = ( - searchResults: SearchResult[] + searchResults: (Partial & + Pick)[] ) => References; diff --git a/packages/mongodb-chatbot-server/src/processors/SearchResult.ts b/packages/mongodb-chatbot-server/src/processors/SearchResult.ts deleted file mode 100644 index f338f9f3f..000000000 --- a/packages/mongodb-chatbot-server/src/processors/SearchResult.ts +++ /dev/null @@ -1,7 +0,0 @@ -import { EmbeddedContent } from "mongodb-rag-core"; - -export type SearchResult = Partial & { - url: string; - text: string; - metadata?: Record; -}; diff --git a/packages/mongodb-chatbot-server/src/processors/index.ts b/packages/mongodb-chatbot-server/src/processors/index.ts index 55a42146e..1f7975e67 100644 --- a/packages/mongodb-chatbot-server/src/processors/index.ts +++ b/packages/mongodb-chatbot-server/src/processors/index.ts @@ -6,8 +6,6 @@ export * from "./makeDefaultReferenceLinks"; export * from "./makeFilterNPreviousMessages"; export * from "./includeChunksForMaxTokensPossible"; export * from "./InputGuardrail"; -export * from "./generateResponseWithSearchTool"; export * from "./makeVerifiedAnswerGenerateResponse"; export * from "./includeChunksForMaxTokensPossible"; export * from "./GenerateResponse"; -export * from "./SearchResult";