Skip to content

Commit c87dc74

Browse files
mongodben and Ben Perlmutter authored
(EAI-988): Refactor GenerateResponse for tool call support (#687)
* refactor GenerateRespose * Clean up imports * consolidate generate user prompt to the legacy file * update test config imports * Fix broken tests * (EAI-989): Refactor verified answers to wrap `GenerateResponse` (#688) verified answer generate response Co-authored-by: Ben Perlmutter <mongodben@mongodb.com> * handle streaming * separate generateresponse * typo fix --------- Co-authored-by: Ben Perlmutter <mongodben@mongodb.com>
1 parent f7bac39 commit c87dc74

22 files changed

+717
-1075
lines changed

packages/chatbot-server-mongodb-public/src/config.ts

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@ import {
1616
requireValidIpAddress,
1717
requireRequestOrigin,
1818
AddCustomDataFunc,
19-
makeVerifiedAnswerGenerateUserPrompt,
2019
makeDefaultFindVerifiedAnswer,
2120
defaultCreateConversationCustomData,
2221
defaultAddMessageToConversationCustomData,
22+
makeLegacyGenerateResponse,
23+
makeVerifiedAnswerGenerateResponse,
2324
} from "mongodb-chatbot-server";
2425
import cookieParser from "cookie-parser";
2526
import { makeStepBackRagGenerateUserPrompt } from "./processors/makeStepBackRagGenerateUserPrompt";
@@ -173,8 +174,8 @@ export const preprocessorOpenAiClient = wrapOpenAI(
173174
})
174175
);
175176

176-
export const generateUserPrompt = wrapTraced(
177-
makeVerifiedAnswerGenerateUserPrompt({
177+
export const generateResponse = wrapTraced(
178+
makeVerifiedAnswerGenerateResponse({
178179
findVerifiedAnswer,
179180
onVerifiedAnswerFound: (verifiedAnswer) => {
180181
return {
@@ -183,11 +184,17 @@ export const generateUserPrompt = wrapTraced(
183184
};
184185
},
185186
onNoVerifiedAnswerFound: wrapTraced(
186-
makeStepBackRagGenerateUserPrompt({
187-
openAiClient: preprocessorOpenAiClient,
188-
model: retrievalConfig.preprocessorLlm,
189-
findContent,
190-
numPrecedingMessagesToInclude: 6,
187+
makeLegacyGenerateResponse({
188+
llm,
189+
generateUserPrompt: makeStepBackRagGenerateUserPrompt({
190+
openAiClient: preprocessorOpenAiClient,
191+
model: retrievalConfig.preprocessorLlm,
192+
findContent,
193+
numPrecedingMessagesToInclude: 6,
194+
}),
195+
systemMessage: systemPrompt,
196+
llmNotWorkingMessage: "LLM not working. Sad!",
197+
noRelevantContentMessage: "No relevant content found. Sad!",
191198
}),
192199
{ name: "makeStepBackRagGenerateUserPrompt" }
193200
),
@@ -237,14 +244,13 @@ const segmentConfig = SEGMENT_WRITE_KEY
237244

238245
export const config: AppConfig = {
239246
conversationsRouterConfig: {
240-
llm,
241247
middleware: [
242248
blockGetRequests,
243249
requireValidIpAddress(),
244250
requireRequestOrigin(),
245251
useSegmentIds(),
246-
cookieParser(),
247252
redactConnectionUri(),
253+
cookieParser(),
248254
],
249255
createConversationCustomData: !isProduction
250256
? createConversationCustomDataWithAuthUser
@@ -294,8 +300,7 @@ export const config: AppConfig = {
294300
: undefined,
295301
segment: segmentConfig,
296302
}),
297-
generateUserPrompt,
298-
systemPrompt,
303+
generateResponse,
299304
maxUserMessagesInConversation: 50,
300305
maxUserCommentLength: 500,
301306
conversations,
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import {
2+
ConversationCustomData,
3+
DataStreamer,
4+
Conversation,
5+
SomeMessage,
6+
} from "mongodb-rag-core";
7+
import { Request as ExpressRequest } from "express";
8+
9+
/**
  Arbitrary key-value data supplied by the client application alongside a
  request. NOTE(review): keys and values are unconstrained here — confirm the
  expected shape with callers.
 */
export type ClientContext = Record<string, unknown>;

/**
  Parameters passed to a {@link GenerateResponse} implementation.
 */
export interface GenerateResponseParams {
  // Whether the response should be streamed back to the client.
  shouldStream: boolean;
  // Text of the latest user message to generate a response for.
  latestMessageText: string;
  // Optional context data supplied by the client.
  clientContext?: ClientContext;
  // Optional custom data attached to the conversation.
  customData?: ConversationCustomData;
  // Streamer used to emit incremental output; presumably required when
  // `shouldStream` is true — TODO confirm with implementations.
  dataStreamer?: DataStreamer;
  // Identifier of the originating request, for logging/tracing.
  reqId: string;
  // The conversation the response belongs to.
  conversation: Conversation;
  // The underlying Express request, when available.
  request?: ExpressRequest;
}

/**
  Result of generating a response: the new messages produced for the
  conversation.
 */
export interface GenerateResponseReturnValue {
  messages: SomeMessage[];
}

/**
  Function that generates the response messages for the latest user message
  in a conversation.
 */
export type GenerateResponse = (
  params: GenerateResponseParams
) => Promise<GenerateResponseReturnValue>;

packages/mongodb-chatbot-server/src/processors/GenerateUserPromptFunc.ts

Lines changed: 0 additions & 76 deletions
This file was deleted.
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import { EmbeddedContent } from "mongodb-rag-core";
2+
import { includeChunksForMaxTokensPossible } from "./includeChunksForMaxTokensPossible";
3+
4+
const embeddings = {
5+
modelName: [0.1, 0.2, 0.3],
6+
};
7+
8+
describe("includeChunksForMaxTokensPossible()", () => {
9+
const content: EmbeddedContent[] = [
10+
{
11+
url: "https://mongodb.com/docs/realm/sdk/node/",
12+
text: "foo foo foo",
13+
tokenCount: 100,
14+
embeddings,
15+
sourceName: "realm",
16+
updated: new Date(),
17+
},
18+
{
19+
url: "https://mongodb.com/docs/realm/sdk/node/",
20+
text: "bar bar bar",
21+
tokenCount: 100,
22+
embeddings,
23+
sourceName: "realm",
24+
updated: new Date(),
25+
},
26+
{
27+
url: "https://mongodb.com/docs/realm/sdk/node/",
28+
text: "baz baz baz",
29+
tokenCount: 100,
30+
embeddings,
31+
sourceName: "realm",
32+
updated: new Date(),
33+
},
34+
];
35+
test("Should include all chunks if less that max tokens", () => {
36+
const maxTokens = 1000;
37+
const includedChunks = includeChunksForMaxTokensPossible({
38+
content,
39+
maxTokens,
40+
});
41+
expect(includedChunks).toStrictEqual(content);
42+
});
43+
test("should only include subset of chunks that fit within max tokens, inclusive", () => {
44+
const maxTokens = 200;
45+
const includedChunks = includeChunksForMaxTokensPossible({
46+
content,
47+
maxTokens,
48+
});
49+
expect(includedChunks).toStrictEqual(content.slice(0, 2));
50+
const maxTokens2 = maxTokens + 1;
51+
const includedChunks2 = includeChunksForMaxTokensPossible({
52+
content,
53+
maxTokens: maxTokens2,
54+
});
55+
expect(includedChunks2).toStrictEqual(content.slice(0, 2));
56+
});
57+
});
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import { EmbeddedContent } from "mongodb-rag-core";
2+
3+
/**
4+
This function returns the chunks that can fit in the maxTokens.
5+
It limits the number of tokens that are sent to the LLM.
6+
*/
7+
export function includeChunksForMaxTokensPossible({
8+
maxTokens,
9+
content,
10+
}: {
11+
maxTokens: number;
12+
content: EmbeddedContent[];
13+
}): EmbeddedContent[] {
14+
let total = 0;
15+
const fitRangeEndIndex = content.findIndex(
16+
({ tokenCount }) => (total += tokenCount) > maxTokens
17+
);
18+
return fitRangeEndIndex === -1 ? content : content.slice(0, fitRangeEndIndex);
19+
}
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
export * from "./FilterPreviousMessages";
2-
export * from "./GenerateUserPromptFunc";
32
export * from "./MakeReferenceLinksFunc";
43
export * from "./QueryPreprocessorFunc";
54
export * from "./filterOnlySystemPrompt";
65
export * from "./makeDefaultReferenceLinks";
76
export * from "./makeFilterNPreviousMessages";
8-
export * from "./makeRagGenerateUserPrompt";
9-
export * from "./makeVerifiedAnswerGenerateUserPrompt";
7+
export * from "./makeVerifiedAnswerGenerateResponse";
8+
export * from "./includeChunksForMaxTokensPossible";
9+
export * from "./GenerateResponse";

packages/mongodb-chatbot-server/src/processors/makeFilterNPreviousMessages.test.ts

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,6 @@ const mockConversationBase: Conversation = {
77
messages: [],
88
createdAt: new Date(),
99
};
10-
const systemMessage = {
11-
role: "system",
12-
content: "Hello",
13-
id: new ObjectId(),
14-
createdAt: new Date(),
15-
} satisfies Message;
1610
const userMessage = {
1711
role: "user",
1812
content: "Hi",
@@ -21,25 +15,7 @@ const userMessage = {
2115
} satisfies Message;
2216

2317
describe("makeFilterNPreviousMessages", () => {
24-
it("should throw an error when there are no messages", async () => {
25-
const filterNPreviousMessages = makeFilterNPreviousMessages(2);
26-
await expect(filterNPreviousMessages(mockConversationBase)).rejects.toThrow(
27-
"First message must be system prompt"
28-
);
29-
});
30-
31-
it("should throw an error when the first message is not a system message", async () => {
32-
const filterNPreviousMessages = makeFilterNPreviousMessages(2);
33-
const conversation = {
34-
...mockConversationBase,
35-
messages: [userMessage],
36-
};
37-
await expect(filterNPreviousMessages(conversation)).rejects.toThrow(
38-
"First message must be system prompt"
39-
);
40-
});
41-
42-
it("should return the system message and the n latest messages when there are more than n messages", async () => {
18+
it("should return the n latest messages when there are more than n messages", async () => {
4319
const filterNPreviousMessages = makeFilterNPreviousMessages(2);
4420
const userMessage2 = {
4521
role: "user",
@@ -56,9 +32,9 @@ describe("makeFilterNPreviousMessages", () => {
5632

5733
const conversation = {
5834
...mockConversationBase,
59-
messages: [systemMessage, userMessage, userMessage2, userMessage3],
35+
messages: [userMessage, userMessage2, userMessage3],
6036
};
6137
const result = await filterNPreviousMessages(conversation);
62-
expect(result).toEqual([systemMessage, userMessage2, userMessage3]);
38+
expect(result).toEqual([userMessage2, userMessage3]);
6339
});
6440
});
Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,13 @@
11
import { FilterPreviousMessages } from "./FilterPreviousMessages";
2-
import { strict as assert } from "assert";
32
/**
43
Creates a filter that only includes the previous n messages in the conversations.
5-
The first message in the conversation **must** be the system prompt.
64
@param n - Number of previous messages to include.
75
*/
86
export const makeFilterNPreviousMessages = (
97
n: number
108
): FilterPreviousMessages => {
119
return async (conversation) => {
12-
assert(
13-
conversation.messages[0]?.role === "system",
14-
"First message must be system prompt"
15-
);
16-
// Always include the system prompt.
17-
const systemPrompt = conversation.messages[0];
18-
1910
// Get the n latest messages.
20-
const nLatestMessages = conversation.messages.slice(1).slice(-n);
21-
22-
return [systemPrompt, ...nLatestMessages];
11+
return conversation.messages.slice(-n);
2312
};
2413
};

0 commit comments

Comments
 (0)