diff --git a/package-lock.json b/package-lock.json index 5b39e9ec7..9c66c022c 100644 --- a/package-lock.json +++ b/package-lock.json @@ -30,14 +30,48 @@ "dev": true, "license": "MIT" }, + "node_modules/@ai-sdk/azure": { + "version": "1.3.21", + "resolved": "https://registry.npmjs.org/@ai-sdk/azure/-/azure-1.3.21.tgz", + "integrity": "sha512-GiLnGScVUerruvkS6E3Rd55YXBb1TI15c5y9GxphJEPsU8jzVha5GKpN3+9hWM9OBgIrJlWKumlSfpVpbcNFJA==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/openai": "1.3.20", + "@ai-sdk/provider": "1.1.3", + "@ai-sdk/provider-utils": "2.2.7" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.0.0" + } + }, + "node_modules/@ai-sdk/azure/node_modules/@ai-sdk/provider-utils": { + "version": "2.2.7", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.7.tgz", + "integrity": "sha512-kM0xS3GWg3aMChh9zfeM+80vEZfXzR3JEUBdycZLtbRZ2TRT8xOj3WodGHPb06sUK5yD7pAXC/P7ctsi2fvUGQ==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "1.1.3", + "nanoid": "^3.3.8", + "secure-json-parse": "^2.7.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.23.8" + } + }, "node_modules/@ai-sdk/openai": { - "version": "1.3.6", - "resolved": "https://registry.npmjs.org/@ai-sdk/openai/-/openai-1.3.6.tgz", - "integrity": "sha512-Lyp6W6dg+ERMJru3DI8/pWAjXLB0GbMMlXh4jxA3mVny8CJHlCAjlEJRuAdLg1/CFz4J1UDN2/4qBnIWtLFIqw==", + "version": "1.3.20", + "resolved": "https://registry.npmjs.org/@ai-sdk/openai/-/openai-1.3.20.tgz", + "integrity": "sha512-/DflUy7ROG9k6n6YTXMBFPbujBKnbGY58f3CwvicLvDar9nDAloVnUWd3LUoOxpSVnX8vtQ7ngxF52SLWO6RwQ==", "license": "Apache-2.0", "dependencies": { - "@ai-sdk/provider": "1.1.0", - "@ai-sdk/provider-utils": "2.2.3" + "@ai-sdk/provider": "1.1.3", + "@ai-sdk/provider-utils": "2.2.7" }, "engines": { "node": ">=18" @@ -47,12 +81,12 @@ } }, "node_modules/@ai-sdk/openai/node_modules/@ai-sdk/provider-utils": { - "version": "2.2.3", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.3.tgz", - "integrity": "sha512-o3fWTzkxzI5Af7U7y794MZkYNEsxbjLam2nxyoUZSScqkacb7vZ3EYHLh21+xCcSSzEC161C7pZAGHtC0hTUMw==", + "version": "2.2.7", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.7.tgz", + "integrity": "sha512-kM0xS3GWg3aMChh9zfeM+80vEZfXzR3JEUBdycZLtbRZ2TRT8xOj3WodGHPb06sUK5yD7pAXC/P7ctsi2fvUGQ==", "license": "Apache-2.0", "dependencies": { - "@ai-sdk/provider": "1.1.0", + "@ai-sdk/provider": "1.1.3", "nanoid": "^3.3.8", "secure-json-parse": "^2.7.0" }, @@ -64,9 +98,9 @@ } }, "node_modules/@ai-sdk/provider": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.0.tgz", - "integrity": "sha512-0M+qjp+clUD0R1E5eWQFhxEvWLNaOtGQRUaBn8CUABnSKredagq92hUS9VjOzGsTm37xLfpaxl97AVtbeOsHew==", + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.3.tgz", + "integrity": "sha512-qZMxYJ0qqX/RfnuIaab+zp8UAeJn/ygXXAffR5I4N0n1IrvA6qBsjc8hXLmBiMV2zoXlifkacF7sEFnYnjBcqg==", "license": "Apache-2.0", "dependencies": { "json-schema": "^0.4.0" @@ -76,13 +110,15 @@ } }, "node_modules/@ai-sdk/provider-utils": { - "version": "1.0.9", + "version": "1.0.22", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-1.0.22.tgz", + "integrity": "sha512-YHK2rpj++wnLVc9vPGzGFP3Pjeld2MwhKinetA0zKXOoHAT/Jit5O8kZsxcSlJPu9wvcGT1UGZEjZrtO7PfFOQ==", "license": "Apache-2.0", "dependencies": { - "@ai-sdk/provider": "0.0.17", - "eventsource-parser": "1.1.2", - "nanoid": "3.3.6", - "secure-json-parse": "2.7.0" + "@ai-sdk/provider": "0.0.26", + "eventsource-parser": "^1.1.2", + "nanoid": "^3.3.7", + "secure-json-parse": "^2.7.0" }, "engines": { "node": ">=18" @@ -97,44 +133,33 @@ } }, "node_modules/@ai-sdk/provider-utils/node_modules/@ai-sdk/provider": { - "version": "0.0.17", + "version": "0.0.26", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-0.0.26.tgz", + "integrity": "sha512-dQkfBDs2lTYpKM8389oopPdQgIU007GQyCbuPPrV+K6MtSII3HBfE0stUIMXUb44L+LK1t6GXPP7wjSzjO6uKg==", "license": "Apache-2.0", "dependencies": { - "json-schema": "0.4.0" + "json-schema": "^0.4.0" }, "engines": { "node": ">=18" } }, - "node_modules/@ai-sdk/provider-utils/node_modules/nanoid": { - "version": "3.3.6", - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/ai" - } - ], - "license": "MIT", - "bin": { - "nanoid": "bin/nanoid.cjs" - }, - "engines": { - "node": "^10 || ^12 || ^13.7 || ^14 || >=15.0.1" - } - }, "node_modules/@ai-sdk/react": { - "version": "0.0.40", + "version": "0.0.70", + "resolved": "https://registry.npmjs.org/@ai-sdk/react/-/react-0.0.70.tgz", + "integrity": "sha512-GnwbtjW4/4z7MleLiW+TOZC2M29eCg1tOUpuEiYFMmFNZK8mkrqM0PFZMo6UsYeUYMWqEOOcPOU9OQVJMJh7IQ==", "license": "Apache-2.0", "dependencies": { - "@ai-sdk/provider-utils": "1.0.9", - "@ai-sdk/ui-utils": "0.0.28", - "swr": "2.2.5" + "@ai-sdk/provider-utils": "1.0.22", + "@ai-sdk/ui-utils": "0.0.50", + "swr": "^2.2.5", + "throttleit": "2.1.0" }, "engines": { "node": ">=18" }, "peerDependencies": { - "react": "^18 || ^19", + "react": "^18 || ^19 || ^19.0.0-rc", "zod": "^3.0.0" }, "peerDependenciesMeta": { @@ -147,11 +172,13 @@ } }, "node_modules/@ai-sdk/solid": { - "version": "0.0.31", + "version": "0.0.54", + "resolved": "https://registry.npmjs.org/@ai-sdk/solid/-/solid-0.0.54.tgz", + "integrity": "sha512-96KWTVK+opdFeRubqrgaJXoNiDP89gNxFRWUp0PJOotZW816AbhUf4EnDjBjXTLjXL1n0h8tGSE9sZsRkj9wQQ==", "license": "Apache-2.0", "dependencies": { - "@ai-sdk/provider-utils": "1.0.9", - "@ai-sdk/ui-utils": "0.0.28" + "@ai-sdk/provider-utils": "1.0.22", + "@ai-sdk/ui-utils": "0.0.50" }, "engines": { "node": ">=18" @@ -166,18 +193,20 @@ } }, "node_modules/@ai-sdk/svelte": { - "version": "0.0.33", + "version": "0.0.57", + "resolved": "https://registry.npmjs.org/@ai-sdk/svelte/-/svelte-0.0.57.tgz", + "integrity": "sha512-SyF9ItIR9ALP9yDNAD+2/5Vl1IT6kchgyDH8xkmhysfJI6WrvJbtO1wdQ0nylvPLcsPoYu+cAlz1krU4lFHcYw==", "license": "Apache-2.0", "dependencies": { - "@ai-sdk/provider-utils": "1.0.9", - "@ai-sdk/ui-utils": "0.0.28", - "sswr": "2.1.0" + "@ai-sdk/provider-utils": "1.0.22", + "@ai-sdk/ui-utils": "0.0.50", + "sswr": "^2.1.0" }, "engines": { "node": ">=18" }, "peerDependencies": { - "svelte": "^3.0.0 || ^4.0.0" + "svelte": "^3.0.0 || ^4.0.0 || ^5.0.0" }, "peerDependenciesMeta": { "svelte": { @@ -186,12 +215,16 @@ } }, "node_modules/@ai-sdk/ui-utils": { - "version": "0.0.28", + "version": "0.0.50", + "resolved": "https://registry.npmjs.org/@ai-sdk/ui-utils/-/ui-utils-0.0.50.tgz", + "integrity": "sha512-Z5QYJVW+5XpSaJ4jYCCAVG7zIAuKOOdikhgpksneNmKvx61ACFaf98pmOd+xnjahl0pIlc/QIe6O4yVaJ1sEaw==", "license": "Apache-2.0", "dependencies": { - "@ai-sdk/provider": "0.0.17", - "@ai-sdk/provider-utils": "1.0.9", - "secure-json-parse": "2.7.0" + "@ai-sdk/provider": "0.0.26", + "@ai-sdk/provider-utils": "1.0.22", + "json-schema": "^0.4.0", + "secure-json-parse": "^2.7.0", + "zod-to-json-schema": "^3.23.3" }, "engines": { "node": ">=18" @@ -206,22 +239,26 @@ } }, "node_modules/@ai-sdk/ui-utils/node_modules/@ai-sdk/provider": { - "version": "0.0.17", + "version": "0.0.26", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-0.0.26.tgz", + "integrity": "sha512-dQkfBDs2lTYpKM8389oopPdQgIU007GQyCbuPPrV+K6MtSII3HBfE0stUIMXUb44L+LK1t6GXPP7wjSzjO6uKg==", "license": "Apache-2.0", "dependencies": { - "json-schema": "0.4.0" + "json-schema": "^0.4.0" }, "engines": { "node": ">=18" } }, "node_modules/@ai-sdk/vue": { - "version": "0.0.32", + "version": "0.0.59", + "resolved": "https://registry.npmjs.org/@ai-sdk/vue/-/vue-0.0.59.tgz", + "integrity": "sha512-+ofYlnqdc8c4F6tM0IKF0+7NagZRAiqBJpGDJ+6EYhDW8FHLUP/JFBgu32SjxSxC6IKFZxEnl68ZoP/Z38EMlw==", "license": "Apache-2.0", "dependencies": { - "@ai-sdk/provider-utils": "1.0.9", - "@ai-sdk/ui-utils": "0.0.28", - "swrv": "1.0.4" + "@ai-sdk/provider-utils": "1.0.22", + "@ai-sdk/ui-utils": "0.0.50", + "swrv": "^1.0.4" }, "engines": { "node": ">=18" @@ -23327,15 +23364,15 @@ } }, "node_modules/ai": { - "version": "4.3.9", - "resolved": "https://registry.npmjs.org/ai/-/ai-4.3.9.tgz", - "integrity": "sha512-P2RpV65sWIPdUlA4f1pcJ11pB0N1YmqPVLEmC4j8WuBwKY0L3q9vGhYPh0Iv+spKHKyn0wUbMfas+7Z6nTfS0g==", + "version": "4.3.16", + "resolved": "https://registry.npmjs.org/ai/-/ai-4.3.16.tgz", + "integrity": "sha512-KUDwlThJ5tr2Vw0A1ZkbDKNME3wzWhuVfAOwIvFUzl1TPVDFAXDFTXio3p+jaKneB+dKNCvFFlolYmmgHttG1g==", "license": "Apache-2.0", "dependencies": { "@ai-sdk/provider": "1.1.3", - "@ai-sdk/provider-utils": "2.2.7", - "@ai-sdk/react": "1.2.9", - "@ai-sdk/ui-utils": "1.2.8", + "@ai-sdk/provider-utils": "2.2.8", + "@ai-sdk/react": "1.2.12", + "@ai-sdk/ui-utils": "1.2.11", "@opentelemetry/api": "1.9.0", "jsondiffpatch": "0.6.0" }, @@ -23352,22 +23389,10 @@ } } }, - "node_modules/ai/node_modules/@ai-sdk/provider": { - "version": "1.1.3", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.3.tgz", - "integrity": "sha512-qZMxYJ0qqX/RfnuIaab+zp8UAeJn/ygXXAffR5I4N0n1IrvA6qBsjc8hXLmBiMV2zoXlifkacF7sEFnYnjBcqg==", - "license": "Apache-2.0", - "dependencies": { - "json-schema": "^0.4.0" - }, - "engines": { - "node": ">=18" - } - }, "node_modules/ai/node_modules/@ai-sdk/provider-utils": { - "version": "2.2.7", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.7.tgz", - "integrity": "sha512-kM0xS3GWg3aMChh9zfeM+80vEZfXzR3JEUBdycZLtbRZ2TRT8xOj3WodGHPb06sUK5yD7pAXC/P7ctsi2fvUGQ==", + "version": "2.2.8", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.8.tgz", + "integrity": "sha512-fqhG+4sCVv8x7nFzYnFo19ryhAa3w096Kmc3hWxMQfW/TubPOmt3A6tYZhl4mUfQWWQMsuSkLrtjlWuXBVSGQA==", "license": "Apache-2.0", "dependencies": { "@ai-sdk/provider": "1.1.3", @@ -23382,13 +23407,13 @@ } }, "node_modules/ai/node_modules/@ai-sdk/react": { - "version": "1.2.9", - "resolved": "https://registry.npmjs.org/@ai-sdk/react/-/react-1.2.9.tgz", - "integrity": "sha512-/VYm8xifyngaqFDLXACk/1czDRCefNCdALUyp+kIX6DUIYUWTM93ISoZ+qJ8+3E+FiJAKBQz61o8lIIl+vYtzg==", + "version": "1.2.12", + "resolved": "https://registry.npmjs.org/@ai-sdk/react/-/react-1.2.12.tgz", + "integrity": "sha512-jK1IZZ22evPZoQW3vlkZ7wvjYGYF+tRBKXtrcolduIkQ/m/sOAVcVeVDUDvh1T91xCnWCdUGCPZg2avZ90mv3g==", "license": "Apache-2.0", "dependencies": { - "@ai-sdk/provider-utils": "2.2.7", - "@ai-sdk/ui-utils": "1.2.8", + "@ai-sdk/provider-utils": "2.2.8", + "@ai-sdk/ui-utils": "1.2.11", "swr": "^2.2.5", "throttleit": "2.1.0" }, @@ -23406,13 +23431,13 @@ } }, "node_modules/ai/node_modules/@ai-sdk/ui-utils": { - "version": "1.2.8", - "resolved": "https://registry.npmjs.org/@ai-sdk/ui-utils/-/ui-utils-1.2.8.tgz", - "integrity": "sha512-nls/IJCY+ks3Uj6G/agNhXqQeLVqhNfoJbuNgCny+nX2veY5ADB91EcZUqVeQ/ionul2SeUswPY6Q/DxteY29Q==", + "version": "1.2.11", + "resolved": "https://registry.npmjs.org/@ai-sdk/ui-utils/-/ui-utils-1.2.11.tgz", + "integrity": "sha512-3zcwCc8ezzFlwp3ZD15wAPjf2Au4s3vAbKsXQVyhxODHcmu0iyPO2Eua6D/vicq/AUm/BAo60r97O6HU+EI0+w==", "license": "Apache-2.0", "dependencies": { "@ai-sdk/provider": "1.1.3", - "@ai-sdk/provider-utils": "2.2.7", + "@ai-sdk/provider-utils": "2.2.8", "zod-to-json-schema": "^3.24.1" }, "engines": { @@ -25004,32 +25029,33 @@ } }, "node_modules/braintrust/node_modules/ai": { - "version": "3.3.4", + "version": "3.4.33", + "resolved": "https://registry.npmjs.org/ai/-/ai-3.4.33.tgz", + "integrity": "sha512-plBlrVZKwPoRTmM8+D1sJac9Bq8eaa2jiZlHLZIWekKWI1yMWYZvCCEezY9ASPwRhULYDJB2VhKOBUUeg3S5JQ==", "license": "Apache-2.0", "dependencies": { - "@ai-sdk/provider": "0.0.17", - "@ai-sdk/provider-utils": "1.0.9", - "@ai-sdk/react": "0.0.40", - "@ai-sdk/solid": "0.0.31", - "@ai-sdk/svelte": "0.0.33", - "@ai-sdk/ui-utils": "0.0.28", - "@ai-sdk/vue": "0.0.32", + "@ai-sdk/provider": "0.0.26", + "@ai-sdk/provider-utils": "1.0.22", + "@ai-sdk/react": "0.0.70", + "@ai-sdk/solid": "0.0.54", + "@ai-sdk/svelte": "0.0.57", + "@ai-sdk/ui-utils": "0.0.50", + "@ai-sdk/vue": "0.0.59", "@opentelemetry/api": "1.9.0", "eventsource-parser": "1.1.2", - "json-schema": "0.4.0", + "json-schema": "^0.4.0", "jsondiffpatch": "0.6.0", - "nanoid": "3.3.6", - "secure-json-parse": "2.7.0", - "zod-to-json-schema": "3.22.5" + "secure-json-parse": "^2.7.0", + "zod-to-json-schema": "^3.23.3" }, "engines": { "node": ">=18" }, "peerDependencies": { "openai": "^4.42.0", - "react": "^18 || ^19", + "react": "^18 || ^19 || ^19.0.0-rc", "sswr": "^2.1.0", - "svelte": "^3.0.0 || ^4.0.0", + "svelte": "^3.0.0 || ^4.0.0 || ^5.0.0", "zod": "^3.0.0" }, "peerDependenciesMeta": { @@ -25051,10 +25077,12 @@ } }, "node_modules/braintrust/node_modules/ai/node_modules/@ai-sdk/provider": { - "version": "0.0.17", + "version": "0.0.26", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-0.0.26.tgz", + "integrity": "sha512-dQkfBDs2lTYpKM8389oopPdQgIU007GQyCbuPPrV+K6MtSII3HBfE0stUIMXUb44L+LK1t6GXPP7wjSzjO6uKg==", "license": "Apache-2.0", "dependencies": { - "json-schema": "0.4.0" + "json-schema": "^0.4.0" }, "engines": { "node": ">=18" @@ -25134,22 +25162,6 @@ "url": "https://github.com/sponsors/isaacs" } }, - "node_modules/braintrust/node_modules/nanoid": { - "version": "3.3.6", - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/ai" - } - ], - "license": "MIT", - "bin": { - "nanoid": "bin/nanoid.cjs" - }, - "engines": { - "node": "^10 || ^12 || ^13.7 || ^14 || >=15.0.1" - } - }, "node_modules/braintrust/node_modules/openai": { "version": "4.95.0", "resolved": "https://registry.npmjs.org/openai/-/openai-4.95.0.tgz", @@ -25189,13 +25201,6 @@ "node": ">= 8" } }, - "node_modules/braintrust/node_modules/zod-to-json-schema": { - "version": "3.22.5", - "license": "ISC", - "peerDependencies": { - "zod": "^3.22.4" - } - }, "node_modules/brorand": { "version": "1.1.0", "dev": true, @@ -54675,9 +54680,9 @@ } }, "node_modules/zod-to-json-schema": { - "version": "3.24.3", - "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.24.3.tgz", - "integrity": "sha512-HIAfWdYIt1sssHfYZFCXp4rU1w2r8hVVXYIlmoa0r0gABLs5di3RCqPU5DDROogVz1pAdYBaz7HK5n9pSUNs3A==", + "version": "3.24.5", + "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.24.5.tgz", + "integrity": "sha512-/AuWwMP+YqiPbsJx5D6TfgRTc4kTLjsh5SOcd4bLsfUg2RcEXrFMJl1DGgdHy2aCfsIA/cr/1JM0xcB2GZji8g==", "license": "ISC", "peerDependencies": { "zod": "^3.24.1" @@ -58270,6 +58275,8 @@ "version": "0.6.3", "license": "Apache-2.0", "dependencies": { + "@ai-sdk/azure": "^1.3.21", + "@ai-sdk/openai": "^1.3.20", "@apidevtools/swagger-parser": "^10.1.0", "@langchain/anthropic": "^0.3.6", "@langchain/community": "^0.3.10", @@ -58278,7 +58285,7 @@ "@supercharge/promise-pool": "^3.2.0", "acquit": "^1.3.0", "acquit-require": "^0.1.1", - "ai": "^4.3.9", + "ai": "^4.3.16", "braintrust": "^0.0.193", "common-tags": "^1", "deep-equal": "^2.2.3", @@ -59426,147 +59433,6 @@ "vitest": "^3.0.5" } }, - "packages/release-notes-generator/node_modules/@ai-sdk/provider-utils": { - "version": "1.0.22", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-1.0.22.tgz", - "integrity": "sha512-YHK2rpj++wnLVc9vPGzGFP3Pjeld2MwhKinetA0zKXOoHAT/Jit5O8kZsxcSlJPu9wvcGT1UGZEjZrtO7PfFOQ==", - "license": "Apache-2.0", - "dependencies": { - "@ai-sdk/provider": "0.0.26", - "eventsource-parser": "^1.1.2", - "nanoid": "^3.3.7", - "secure-json-parse": "^2.7.0" - }, - "engines": { - "node": ">=18" - }, - "peerDependencies": { - "zod": "^3.0.0" - }, - "peerDependenciesMeta": { - "zod": { - "optional": true - } - } - }, - "packages/release-notes-generator/node_modules/@ai-sdk/provider-utils/node_modules/@ai-sdk/provider": { - "version": "0.0.26", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-0.0.26.tgz", - "integrity": "sha512-dQkfBDs2lTYpKM8389oopPdQgIU007GQyCbuPPrV+K6MtSII3HBfE0stUIMXUb44L+LK1t6GXPP7wjSzjO6uKg==", - "license": "Apache-2.0", - "dependencies": { - "json-schema": "^0.4.0" - }, - "engines": { - "node": ">=18" - } - }, - "packages/release-notes-generator/node_modules/@ai-sdk/react": { - "version": "0.0.70", - "resolved": "https://registry.npmjs.org/@ai-sdk/react/-/react-0.0.70.tgz", - "integrity": "sha512-GnwbtjW4/4z7MleLiW+TOZC2M29eCg1tOUpuEiYFMmFNZK8mkrqM0PFZMo6UsYeUYMWqEOOcPOU9OQVJMJh7IQ==", - "license": "Apache-2.0", - "dependencies": { - "@ai-sdk/provider-utils": "1.0.22", - "@ai-sdk/ui-utils": "0.0.50", - "swr": "^2.2.5", - "throttleit": "2.1.0" - }, - "engines": { - "node": ">=18" - }, - "peerDependencies": { - "react": "^18 || ^19 || ^19.0.0-rc", - "zod": "^3.0.0" - }, - "peerDependenciesMeta": { - "react": { - "optional": true - }, - "zod": { - "optional": true - } - } - }, - "packages/release-notes-generator/node_modules/@ai-sdk/solid": { - "version": "0.0.54", - "resolved": "https://registry.npmjs.org/@ai-sdk/solid/-/solid-0.0.54.tgz", - "integrity": "sha512-96KWTVK+opdFeRubqrgaJXoNiDP89gNxFRWUp0PJOotZW816AbhUf4EnDjBjXTLjXL1n0h8tGSE9sZsRkj9wQQ==", - "license": "Apache-2.0", - "dependencies": { - "@ai-sdk/provider-utils": "1.0.22", - "@ai-sdk/ui-utils": "0.0.50" - }, - "engines": { - "node": ">=18" - }, - "peerDependencies": { - "solid-js": "^1.7.7" - }, - "peerDependenciesMeta": { - "solid-js": { - "optional": true - } - } - }, - "packages/release-notes-generator/node_modules/@ai-sdk/svelte": { - "version": "0.0.57", - "resolved": "https://registry.npmjs.org/@ai-sdk/svelte/-/svelte-0.0.57.tgz", - "integrity": "sha512-SyF9ItIR9ALP9yDNAD+2/5Vl1IT6kchgyDH8xkmhysfJI6WrvJbtO1wdQ0nylvPLcsPoYu+cAlz1krU4lFHcYw==", - "license": "Apache-2.0", - "dependencies": { - "@ai-sdk/provider-utils": "1.0.22", - "@ai-sdk/ui-utils": "0.0.50", - "sswr": "^2.1.0" - }, - "engines": { - "node": ">=18" - }, - "peerDependencies": { - "svelte": "^3.0.0 || ^4.0.0 || ^5.0.0" - }, - "peerDependenciesMeta": { - "svelte": { - "optional": true - } - } - }, - "packages/release-notes-generator/node_modules/@ai-sdk/ui-utils": { - "version": "0.0.50", - "resolved": "https://registry.npmjs.org/@ai-sdk/ui-utils/-/ui-utils-0.0.50.tgz", - "integrity": "sha512-Z5QYJVW+5XpSaJ4jYCCAVG7zIAuKOOdikhgpksneNmKvx61ACFaf98pmOd+xnjahl0pIlc/QIe6O4yVaJ1sEaw==", - "license": "Apache-2.0", - "dependencies": { - "@ai-sdk/provider": "0.0.26", - "@ai-sdk/provider-utils": "1.0.22", - "json-schema": "^0.4.0", - "secure-json-parse": "^2.7.0", - "zod-to-json-schema": "^3.23.3" - }, - "engines": { - "node": ">=18" - }, - "peerDependencies": { - "zod": "^3.0.0" - }, - "peerDependenciesMeta": { - "zod": { - "optional": true - } - } - }, - "packages/release-notes-generator/node_modules/@ai-sdk/ui-utils/node_modules/@ai-sdk/provider": { - "version": "0.0.26", - "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-0.0.26.tgz", - "integrity": "sha512-dQkfBDs2lTYpKM8389oopPdQgIU007GQyCbuPPrV+K6MtSII3HBfE0stUIMXUb44L+LK1t6GXPP7wjSzjO6uKg==", - "license": "Apache-2.0", - "dependencies": { - "json-schema": "^0.4.0" - }, - "engines": { - "node": ">=18" - } - }, "packages/release-notes-generator/node_modules/@anthropic-ai/sdk": { "version": "0.27.3", "resolved": "https://registry.npmjs.org/@anthropic-ai/sdk/-/sdk-0.27.3.tgz", @@ -60292,28 +60158,6 @@ "node": ">=18" } }, - "packages/release-notes-generator/node_modules/ai/node_modules/@ai-sdk/vue": { - "version": "0.0.59", - "resolved": "https://registry.npmjs.org/@ai-sdk/vue/-/vue-0.0.59.tgz", - "integrity": "sha512-+ofYlnqdc8c4F6tM0IKF0+7NagZRAiqBJpGDJ+6EYhDW8FHLUP/JFBgu32SjxSxC6IKFZxEnl68ZoP/Z38EMlw==", - "license": "Apache-2.0", - "dependencies": { - "@ai-sdk/provider-utils": "1.0.22", - "@ai-sdk/ui-utils": "0.0.50", - "swrv": "^1.0.4" - }, - "engines": { - "node": ">=18" - }, - "peerDependencies": { - "vue": "^3.3.4" - }, - "peerDependenciesMeta": { - "vue": { - "optional": true - } - } - }, "packages/release-notes-generator/node_modules/argparse": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/argparse/-/argparse-2.0.1.tgz", diff --git a/packages/chatbot-server-mongodb-public/environments/production.yml b/packages/chatbot-server-mongodb-public/environments/production.yml index 9a6e1574b..f5df2a36a 100644 --- a/packages/chatbot-server-mongodb-public/environments/production.yml +++ b/packages/chatbot-server-mongodb-public/environments/production.yml @@ -10,7 +10,7 @@ env: NODE_ENV: production OPENAI_PREPROCESSOR_CHAT_COMPLETION_DEPLOYMENT: gpt-4o-mini OPENAI_API_VERSION: "2024-06-01" - OPENAI_CHAT_COMPLETION_DEPLOYMENT: gpt-4o + OPENAI_CHAT_COMPLETION_DEPLOYMENT: gpt-4.1 OPENAI_VERIFIED_ANSWER_EMBEDDING_DEPLOYMENT: "docs-chatbot-embedding-ada-002" OPENAI_RETRIEVAL_EMBEDDING_DEPLOYMENT: "text-embedding-3-small" JUDGE_LLM: "gpt-4o-mini" diff --git a/packages/chatbot-server-mongodb-public/environments/staging.yml b/packages/chatbot-server-mongodb-public/environments/staging.yml index 2cba89d94..fe2405cf6 100644 --- a/packages/chatbot-server-mongodb-public/environments/staging.yml +++ b/packages/chatbot-server-mongodb-public/environments/staging.yml @@ -10,7 +10,7 @@ env: NODE_ENV: staging OPENAI_PREPROCESSOR_CHAT_COMPLETION_DEPLOYMENT: gpt-4o-mini OPENAI_API_VERSION: "2024-06-01" - OPENAI_CHAT_COMPLETION_DEPLOYMENT: gpt-4o + OPENAI_CHAT_COMPLETION_DEPLOYMENT: gpt-4.1 OPENAI_VERIFIED_ANSWER_EMBEDDING_DEPLOYMENT: "docs-chatbot-embedding-ada-002" OPENAI_RETRIEVAL_EMBEDDING_DEPLOYMENT: "text-embedding-3-small" BRAINTRUST_CHATBOT_TRACING_PROJECT_NAME: "chatbot-responses-staging" diff --git a/packages/chatbot-server-mongodb-public/src/config.ts b/packages/chatbot-server-mongodb-public/src/config.ts index b2ff31349..c9a76bf75 100644 --- a/packages/chatbot-server-mongodb-public/src/config.ts +++ b/packages/chatbot-server-mongodb-public/src/config.ts @@ -8,7 +8,6 @@ import { makeMongoDbVerifiedAnswerStore, makeOpenAiEmbedder, makeMongoDbConversationsService, - makeOpenAiChatLlm, AppConfig, CORE_ENV_VARS, assertEnvVars, @@ -19,19 +18,25 @@ import { makeDefaultFindVerifiedAnswer, defaultCreateConversationCustomData, defaultAddMessageToConversationCustomData, - makeLegacyGenerateResponse, + makeGenerateResponseWithSearchTool, makeVerifiedAnswerGenerateResponse, } from "mongodb-chatbot-server"; import cookieParser from "cookie-parser"; -import { makeStepBackRagGenerateUserPrompt } from "./processors/makeStepBackRagGenerateUserPrompt"; import { blockGetRequests } from "./middleware/blockGetRequests"; import { getRequestId, logRequest } from "./utils"; import { systemPrompt } from "./systemPrompt"; -import { addReferenceSourceType } from "./processors/makeMongoDbReferences"; +import { + addReferenceSourceType, + makeMongoDbReferences, +} from "./processors/makeMongoDbReferences"; import { redactConnectionUri } from "./middleware/redactConnectionUri"; import path from "path"; import express from "express"; -import { wrapOpenAI, wrapTraced } from "mongodb-rag-core/braintrust"; +import { + wrapOpenAI, + wrapTraced, + wrapAISDKModel, +} from "mongodb-rag-core/braintrust"; import { AzureOpenAI } from "mongodb-rag-core/openai"; import { MongoClient } from "mongodb-rag-core/mongodb"; import { TRACING_ENV_VARS } from "./EnvVars"; @@ -41,6 +46,8 @@ import { makeRateMessageUpdateTrace, } from "./tracing/routesUpdateTraceHandlers"; import { useSegmentIds } from "./middleware/useSegmentIds"; +import { createAzure } from "mongodb-rag-core/aiSdk"; +import { makeSearchTool } from "./tools/search"; export const { MONGODB_CONNECTION_URI, MONGODB_DATABASE_NAME, @@ -79,19 +86,6 @@ export const openAiClient = wrapOpenAI( }) ); -export const llm = makeOpenAiChatLlm({ - openAiClient, - deployment: OPENAI_CHAT_COMPLETION_DEPLOYMENT, - openAiLmmConfigOptions: { - temperature: 0, - max_tokens: 1000, - }, -}); - -llm.answerQuestionAwaited = wrapTraced(llm.answerQuestionAwaited, { - name: "answerQuestionAwaited", -}); - export const embeddedContentStore = makeMongoDbEmbeddedContentStore({ connectionUri: MONGODB_CONNECTION_URI, databaseName: MONGODB_DATABASE_NAME, @@ -173,6 +167,16 @@ export const preprocessorOpenAiClient = wrapOpenAI( apiVersion: OPENAI_API_VERSION, }) ); +export const mongodb = new MongoClient(MONGODB_CONNECTION_URI); + +export const conversations = makeMongoDbConversationsService( + mongodb.db(MONGODB_DATABASE_NAME) +); +const azureOpenAi = createAzure({ + apiKey: OPENAI_API_KEY, + resourceName: process.env.OPENAI_RESOURCE_NAME, +}); +const languageModel = wrapAISDKModel(azureOpenAi("gpt-4.1")); export const generateResponse = wrapTraced( makeVerifiedAnswerGenerateResponse({ @@ -184,17 +188,24 @@ export const generateResponse = wrapTraced( }; }, onNoVerifiedAnswerFound: wrapTraced( - makeLegacyGenerateResponse({ - llm, - generateUserPrompt: makeStepBackRagGenerateUserPrompt({ - openAiClient: preprocessorOpenAiClient, - model: retrievalConfig.preprocessorLlm, - findContent, - numPrecedingMessagesToInclude: 6, - }), + makeGenerateResponseWithSearchTool({ + languageModel, systemMessage: systemPrompt, - llmNotWorkingMessage: "LLM not working. Sad!", - noRelevantContentMessage: "No relevant content found. Sad!", + makeReferenceLinks: makeMongoDbReferences, + filterPreviousMessages: async (conversation) => { + return conversation.messages.filter((message) => { + return ( + message.role === "user" || + // Only include assistant messages that are not tool calls + (message.role === "assistant" && !message.toolCall) + ); + }); + }, + llmNotWorkingMessage: + conversations.conversationConstants.LLM_NOT_WORKING, + searchTool: makeSearchTool(findContent), + toolChoice: "auto", + maxSteps: 5, }), { name: "makeStepBackRagGenerateUserPrompt" } ), @@ -204,12 +215,6 @@ export const generateResponse = wrapTraced( } ); -export const mongodb = new MongoClient(MONGODB_CONNECTION_URI); - -export const conversations = makeMongoDbConversationsService( - mongodb.db(MONGODB_DATABASE_NAME) -); - export const createConversationCustomDataWithAuthUser: AddCustomDataFunc = async (req, res) => { const customData = await defaultCreateConversationCustomData(req, res); diff --git a/packages/chatbot-server-mongodb-public/src/processors/extractMongoDbMetadataFromUserMessage.eval.ts b/packages/chatbot-server-mongodb-public/src/processors/extractMongoDbMetadataFromUserMessage.eval.ts deleted file mode 100644 index 4767c845f..000000000 --- a/packages/chatbot-server-mongodb-public/src/processors/extractMongoDbMetadataFromUserMessage.eval.ts +++ /dev/null @@ -1,232 +0,0 @@ -import { - extractMongoDbMetadataFromUserMessage, - ExtractMongoDbMetadataFunction, -} from "./extractMongoDbMetadataFromUserMessage"; -import { Eval } from "braintrust"; -import { Scorer } from "autoevals"; -import { MongoDbTag } from "../mongoDbMetadata"; -import { - OPENAI_PREPROCESSOR_CHAT_COMPLETION_DEPLOYMENT, - openAiClient, -} from "../eval/evalHelpers"; - -interface ExtractMongoDbMetadataEvalCase { - name: string; - input: string; - expected: ExtractMongoDbMetadataFunction; - tags?: MongoDbTag[]; -} - -const evalCases: ExtractMongoDbMetadataEvalCase[] = [ - { - name: "should identify MongoDB Atlas Search", - input: "Does atlas search support copy to fields", - expected: { - mongoDbProduct: "Atlas Search", - } satisfies ExtractMongoDbMetadataFunction, - tags: ["atlas", "atlas_search"], - }, - { - name: "should identify aggregation stage", - input: "$merge", - expected: { - mongoDbProduct: "Aggregation Framework", - } satisfies ExtractMongoDbMetadataFunction, - }, - { - name: "should know pymongo is python driver", - input: "pymongo insert data", - expected: { - programmingLanguage: "python", - mongoDbProduct: "Drivers", - } satisfies ExtractMongoDbMetadataFunction, - tags: ["driver", "python"], - }, - { - name: "should identify MongoDB Atlas", - input: "how to create a new cluster atlas", - expected: { - mongoDbProduct: "MongoDB Atlas", - } satisfies ExtractMongoDbMetadataFunction, - tags: ["atlas"], - }, - { - name: "should know atlas billing", - input: "how do I see my bill in atlas", - expected: { - mongoDbProduct: "MongoDB Atlas", - } satisfies ExtractMongoDbMetadataFunction, - tags: ["atlas"], - }, - { - name: "should be aware of vector search product", - input: "how to use vector search", - expected: { - mongoDbProduct: "Atlas Vector Search", - } satisfies ExtractMongoDbMetadataFunction, - tags: ["atlas", "atlas_vector_search"], - }, - { - name: "should know change streams", - input: - "how to open a change stream watch on a database and filter the stream", - expected: { - mongoDbProduct: "Drivers", - programmingLanguage: "javascript", - } satisfies ExtractMongoDbMetadataFunction, - tags: ["change_streams"], - }, - { - name: "should know change streams", - input: - "how to open a change stream watch on a database and filter the stream pymongo", - expected: { - mongoDbProduct: "Drivers", - programmingLanguage: "python", - } satisfies ExtractMongoDbMetadataFunction, - tags: ["change_streams"], - }, - { - name: "should know to include programming language when coding task implied.", - input: - "How do I choose the order of fields when creating a compound index?", - expected: { - mongoDbProduct: "MongoDB Server", - programmingLanguage: "javascript", - } satisfies ExtractMongoDbMetadataFunction, - tags: ["indexes"], - }, - { - name: "should detect gridfs usage", - input: "What is the best way to store large files with MongoDB?", - expected: { - mongoDbProduct: "GridFS", - } satisfies ExtractMongoDbMetadataFunction, - tags: ["gridfs"], - }, - { - name: "should recognize MongoDB for analytics", - input: "How do I run real-time analytics on my data?", - expected: { - mongoDbProduct: "MongoDB Server", - } satisfies ExtractMongoDbMetadataFunction, - tags: ["analytics"], - }, - { - name: "should detect transaction management topic", - input: "How do I manage multi-document transactions?", - expected: { - mongoDbProduct: "MongoDB Server", - } satisfies ExtractMongoDbMetadataFunction, - tags: ["server"], - }, - { - name: "should know multi-cloud clustering", - input: "Can I create a multi-cloud cluster with Atlas?", - expected: { - mongoDbProduct: "MongoDB Atlas", - } satisfies ExtractMongoDbMetadataFunction, - tags: ["atlas", "multi_cloud"], - }, - { - name: "should identify usage in Java with the MongoDB driver", - input: "How do I connect to MongoDB using the Java driver?", - expected: { - programmingLanguage: "java", - mongoDbProduct: "Drivers", - } satisfies ExtractMongoDbMetadataFunction, - tags: ["driver", "java"], - }, - { - name: "should know usage of MongoDB in C#", - input: "How do I query a collection using LINQ in C#?", - expected: { - programmingLanguage: "csharp", - mongoDbProduct: "Drivers", - } satisfies ExtractMongoDbMetadataFunction, - tags: ["driver", "csharp"], - }, - { - name: "should recognize Python use in aggregation queries", - input: "How do I perform an aggregation pipeline in pymongo?", - expected: { - programmingLanguage: "python", - mongoDbProduct: "Aggregation Framework", - } satisfies ExtractMongoDbMetadataFunction, - tags: ["driver", "python", "aggregation"], - }, - { - name: "should detect use of Node.js for MongoDB", - input: "How do I handle MongoDB connections in Node.js?", - expected: { - programmingLanguage: "javascript", - mongoDbProduct: "Drivers", - } satisfies ExtractMongoDbMetadataFunction, - tags: ["driver", "javascript"], - }, - { - name: "should identify usage of Go with MongoDB", - input: "How do I insert multiple documents with the MongoDB Go driver?", - expected: { - programmingLanguage: "go", - mongoDbProduct: "Drivers", - } satisfies ExtractMongoDbMetadataFunction, - tags: ["driver", "go"], - }, - { - name: "should know of $vectorSearch stage", - input: "$vectorSearch", - expected: { - mongoDbProduct: "Atlas Vector Search", - } satisfies ExtractMongoDbMetadataFunction, - tags: ["atlas", "atlas_vector_search"], - }, -]; -const ProductNameCorrect: Scorer< - Awaited>, - unknown -> = (args) => { - return { - name: "ProductNameCorrect", - score: args.expected?.mongoDbProduct === args.output.mongoDbProduct ? 1 : 0, - }; -}; -const ProgrammingLanguageCorrect: Scorer< - Awaited>, - unknown -> = (args) => { - return { - name: "ProgrammingLanguageCorrect", - score: - args.expected?.programmingLanguage === args.output.programmingLanguage - ? 1 - : 0, - }; -}; - -const model = OPENAI_PREPROCESSOR_CHAT_COMPLETION_DEPLOYMENT; -Eval("extract-mongodb-metadata", { - data: evalCases, - experimentName: model, - metadata: { - description: - "Evaluates whether the MongoDB user message guardrail is working correctly.", - model, - }, - maxConcurrency: 3, - timeout: 20000, - async task(input) { - try { - return await extractMongoDbMetadataFromUserMessage({ - openAiClient, - model, - userMessageText: input, - }); - } catch (error) { - console.log(`Error evaluating input: ${input}`); - console.log(error); - throw error; - } - }, - scores: [ProductNameCorrect, ProgrammingLanguageCorrect], -}); diff --git a/packages/chatbot-server-mongodb-public/src/processors/extractMongoDbMetadataFromUserMessage.test.ts b/packages/chatbot-server-mongodb-public/src/processors/extractMongoDbMetadataFromUserMessage.test.ts deleted file mode 100644 index cf487de96..000000000 --- a/packages/chatbot-server-mongodb-public/src/processors/extractMongoDbMetadataFromUserMessage.test.ts +++ /dev/null @@ -1,26 +0,0 @@ -import { makeMockOpenAIToolCall } from "../test/mockOpenAi"; -import { - extractMongoDbMetadataFromUserMessage, - ExtractMongoDbMetadataFunction, -} from "./extractMongoDbMetadataFromUserMessage"; -import { OpenAI } from "mongodb-rag-core/openai"; - -jest.mock("mongodb-rag-core/openai", () => { - return makeMockOpenAIToolCall({ - mongoDbProduct: "Aggregation Framework", - } satisfies ExtractMongoDbMetadataFunction); -}); - -describe("extractMongoDbMetadataFromUserMessage", () => { - const args: Parameters[0] = { - openAiClient: new OpenAI({ apiKey: "fake-api-key" }), - model: "best-model-eva", - userMessageText: "hi", - }; - test("should return metadata", async () => { - const res = await extractMongoDbMetadataFromUserMessage(args); - expect(res).toEqual({ - mongoDbProduct: "Aggregation Framework", - }); - }); -}); diff --git a/packages/chatbot-server-mongodb-public/src/processors/extractMongoDbMetadataFromUserMessage.ts b/packages/chatbot-server-mongodb-public/src/processors/extractMongoDbMetadataFromUserMessage.ts deleted file mode 100644 index b3e5e087f..000000000 --- a/packages/chatbot-server-mongodb-public/src/processors/extractMongoDbMetadataFromUserMessage.ts +++ /dev/null @@ -1,93 +0,0 @@ -import { z } from "zod"; -import { - makeAssistantFunctionCallMessage, - makeFewShotUserMessageExtractorFunction, - makeUserMessage, -} from "./makeFewShotUserMessageExtractorFunction"; -import { OpenAI } from "mongodb-rag-core/openai"; -import { - mongoDbProductNames, - mongoDbProgrammingLanguageIds, -} from "../mongoDbMetadata"; - -export const ExtractMongoDbMetadataFunctionSchema = z.object({ - programmingLanguage: z - .enum(mongoDbProgrammingLanguageIds) - .default("javascript") - .describe( - 'Programming language present in the content. If no programming language is present and a code example would answer the question, include "javascript".' - ) - .optional(), - mongoDbProduct: z - .enum(mongoDbProductNames) - .describe( - `Most important MongoDB products present in the content. -Include "Driver" if the user is asking about a programming language with a MongoDB driver. -If the product is ambiguous, say "MongoDB Server".` - ) - .default("MongoDB Server") - .optional(), -}); - -export type ExtractMongoDbMetadataFunction = z.infer< - typeof ExtractMongoDbMetadataFunctionSchema ->; - -const name = "extract_mongodb_metadata"; -const description = "Extract MongoDB-related metadata from a user message"; - -const systemPrompt = `You are an expert data labeler employed by MongoDB. -You must label metadata about the user query based on its context in the conversation. -Your pay is determined by the accuracy of your labels as judged against other expert labelers, so do excellent work to maximize your earnings to support your family.`; - -const fewShotExamples: OpenAI.Chat.ChatCompletionMessageParam[] = [ - // Example 1 - makeUserMessage("aggregate data"), - makeAssistantFunctionCallMessage(name, { - programmingLanguage: "javascript", - mongoDbProduct: "Aggregation Framework", - } satisfies ExtractMongoDbMetadataFunction), - // Example 2 - makeUserMessage("how to create a new cluster atlas"), - makeAssistantFunctionCallMessage(name, { - mongoDbProduct: "MongoDB Atlas", - } satisfies ExtractMongoDbMetadataFunction), - // Example 3 - makeUserMessage("Does atlas search support copy to fields"), - makeAssistantFunctionCallMessage(name, { - mongoDbProduct: "Atlas Search", - } satisfies ExtractMongoDbMetadataFunction), - // Example 4 - makeUserMessage("pymongo insert data"), - makeAssistantFunctionCallMessage(name, { - programmingLanguage: "python", - mongoDbProduct: "Drivers", - } satisfies ExtractMongoDbMetadataFunction), - // Example 5 - makeUserMessage("How do I create an index in MongoDB using the Java driver?"), - makeAssistantFunctionCallMessage(name, { - programmingLanguage: "java", - mongoDbProduct: "Drivers", - } satisfies ExtractMongoDbMetadataFunction), - // Example 6 - makeUserMessage("$lookup"), - makeAssistantFunctionCallMessage(name, { - mongoDbProduct: "Aggregation Framework", - } satisfies ExtractMongoDbMetadataFunction), -]; - -/** - Extract metadata relevant to the MongoDB docs chatbot - from a user message in the conversation. - */ - -export const extractMongoDbMetadataFromUserMessage = - makeFewShotUserMessageExtractorFunction({ - llmFunction: { - name, - description, - schema: ExtractMongoDbMetadataFunctionSchema, - }, - systemPrompt, - fewShotExamples, - }); diff --git a/packages/chatbot-server-mongodb-public/src/processors/makeMongoDbReferences.ts b/packages/chatbot-server-mongodb-public/src/processors/makeMongoDbReferences.ts index 2513c7bac..a9e508d21 100644 --- a/packages/chatbot-server-mongodb-public/src/processors/makeMongoDbReferences.ts +++ b/packages/chatbot-server-mongodb-public/src/processors/makeMongoDbReferences.ts @@ -1,5 +1,4 @@ import { - EmbeddedContent, MakeReferenceLinksFunc, makeDefaultReferenceLinks, } from "mongodb-chatbot-server"; @@ -21,9 +20,7 @@ import { type RichLinkVariantName } from "@lg-chat/rich-links"; } ``` */ -export const makeMongoDbReferences: MakeReferenceLinksFunc = ( - chunks: EmbeddedContent[] -) => { +export const makeMongoDbReferences: MakeReferenceLinksFunc = (chunks) => { return makeDefaultReferenceLinks(chunks).map(addReferenceSourceType); }; diff --git a/packages/chatbot-server-mongodb-public/src/processors/makeStepBackRagGenerateUserPrompt.test.ts b/packages/chatbot-server-mongodb-public/src/processors/makeStepBackRagGenerateUserPrompt.test.ts deleted file mode 100644 index 966a1a873..000000000 --- a/packages/chatbot-server-mongodb-public/src/processors/makeStepBackRagGenerateUserPrompt.test.ts +++ /dev/null @@ -1,173 +0,0 @@ -import { FindContentFunc, FindContentResult } from "mongodb-chatbot-server"; -import { ObjectId } from "mongodb-rag-core/mongodb"; -import { - OPENAI_PREPROCESSOR_CHAT_COMPLETION_DEPLOYMENT, - preprocessorOpenAiClient, -} from "../test/testHelpers"; -import { makeStepBackRagGenerateUserPrompt } from "./makeStepBackRagGenerateUserPrompt"; - -jest.setTimeout(30000); -describe("makeStepBackRagGenerateUserPrompt", () => { - const embeddings = { modelName: [0, 0, 0] }; - const mockFindContent: FindContentFunc = async () => { - return { - queryEmbedding: embeddings.modelName, - content: [ - { - text: "avada kedavra", - embeddings, - score: 1, - sourceName: "mastering-dark-arts", - url: "https://example.com", - tokenCount: 3, - updated: new Date(), - }, - { - url: "https://example.com", - tokenCount: 1, - sourceName: "defending-against-the-dark-arts", - updated: new Date(), - text: "expecto patronum", - embeddings, - score: 1, - }, - ], - } satisfies FindContentResult; - }; - const config = { - openAiClient: preprocessorOpenAiClient, - model: OPENAI_PREPROCESSOR_CHAT_COMPLETION_DEPLOYMENT, - findContent: mockFindContent, - }; - const stepBackRagGenerateUserPrompt = - makeStepBackRagGenerateUserPrompt(config); - test("should return a step back user prompt", async () => { - const res = await stepBackRagGenerateUserPrompt({ - reqId: "123", - userMessageText: "what is mongodb", - }); - expect(res.rejectQuery).toBeFalsy(); - expect(res.userMessage).toHaveProperty("content"); - expect(res.userMessage).toHaveProperty("contentForLlm"); - expect(res.userMessage.role).toBe("user"); - expect(res.userMessage.embedding).toHaveLength(embeddings.modelName.length); - }); - test("should reject query if no content", async () => { - const mockFindContent: FindContentFunc = async () => { - return { - queryEmbedding: [], - content: [], - } satisfies FindContentResult; - }; - const stepBackRagGenerateUserPrompt = makeStepBackRagGenerateUserPrompt({ - ...config, - findContent: mockFindContent, - maxContextTokenCount: 1000, - }); - const res = await stepBackRagGenerateUserPrompt({ - reqId: "123", - userMessageText: "what is mongodb", - }); - expect(res.rejectQuery).toBe(true); - expect(res.userMessage.customData).toHaveProperty( - "rejectionReason", - "Did not find any content matching the query" - ); - expect(res.userMessage.rejectQuery).toBe(true); - }); - test("should return references", async () => { - const res = await stepBackRagGenerateUserPrompt({ - reqId: "123", - userMessageText: "what is mongodb", - }); - expect(res.references?.length).toBeGreaterThan(0); - }); - test("should reject inappropriate message", async () => { - const res = await stepBackRagGenerateUserPrompt({ - reqId: "123", - userMessageText: "why is mongodb the worst database", - }); - expect(res.rejectQuery).toBe(true); - expect(res.userMessage.customData).toHaveProperty("rejectionReason"); - expect(res.userMessage.rejectQuery).toBe(true); - }); - test("should throw if 'numPrecedingMessagesToInclude' is not an integer or < 0", async () => { - expect(() => - makeStepBackRagGenerateUserPrompt({ - ...config, - numPrecedingMessagesToInclude: 1.5, - }) - ).toThrow(); - expect(() => - makeStepBackRagGenerateUserPrompt({ - ...config, - numPrecedingMessagesToInclude: -1, - }) - ).toThrow(); - }); - test("should not include system messages", async () => { - const stepBackRagGenerateUserPrompt = makeStepBackRagGenerateUserPrompt({ - ...config, - numPrecedingMessagesToInclude: 1, - }); - const res = await stepBackRagGenerateUserPrompt({ - reqId: "123", - userMessageText: "what is mongodb", - conversation: { - _id: new ObjectId(), - createdAt: new Date(), - messages: [ - { - role: "system", - content: "abracadabra", - id: new ObjectId(), - createdAt: new Date(), - }, - ], - }, - }); - expect(res.userMessage.contentForLlm).not.toContain("abracadabra"); - }); - test("should only include 'numPrecedingMessagesToInclude' previous messages", async () => { - const stepBackRagGenerateUserPrompt = makeStepBackRagGenerateUserPrompt({ - ...config, - numPrecedingMessagesToInclude: 1, - }); - const res = await stepBackRagGenerateUserPrompt({ - reqId: "123", - userMessageText: "what is mongodb", - conversation: { - _id: new ObjectId(), - createdAt: new Date(), - messages: [ - { - role: "user", - content: "abracadabra", - id: new ObjectId(), - createdAt: new Date(), - }, - { - role: "assistant", - content: "avada kedavra", - id: new ObjectId(), - createdAt: new Date(), - }, - ], - }, - }); - expect(res.userMessage.contentForLlm).not.toContain("abracadabra"); - expect(res.userMessage.contentForLlm).toContain("avada kedavra"); - }); - test("should filter out context > maxContextTokenCount", async () => { - const stepBackRagGenerateUserPrompt = makeStepBackRagGenerateUserPrompt({ - ...config, - maxContextTokenCount: 1000, - }); - const res = await stepBackRagGenerateUserPrompt({ - reqId: "123", - userMessageText: "what is mongodb", - }); - expect(res.userMessage.contentForLlm).not.toContain("abracadabra"); - expect(res.userMessage.contentForLlm).toContain("avada kedavra"); - }); -}); diff --git a/packages/chatbot-server-mongodb-public/src/processors/makeStepBackRagGenerateUserPrompt.ts b/packages/chatbot-server-mongodb-public/src/processors/makeStepBackRagGenerateUserPrompt.ts deleted file mode 100644 index f0cff7edf..000000000 --- a/packages/chatbot-server-mongodb-public/src/processors/makeStepBackRagGenerateUserPrompt.ts +++ /dev/null @@ -1,232 +0,0 @@ -import { - EmbeddedContent, - FindContentFunc, - GenerateUserPromptFunc, - GenerateUserPromptFuncReturnValue, - Message, - UserMessage, -} from "mongodb-chatbot-server"; -import { OpenAI } from "mongodb-rag-core/openai"; -import { stripIndents } from "common-tags"; -import { strict as assert } from "assert"; -import { logRequest } from "../utils"; -import { makeMongoDbReferences } from "./makeMongoDbReferences"; -import { extractMongoDbMetadataFromUserMessage } from "./extractMongoDbMetadataFromUserMessage"; -import { userMessageMongoDbGuardrail } from "./userMessageMongoDbGuardrail"; -import { retrieveRelevantContent } from "./retrieveRelevantContent"; - -interface MakeStepBackGenerateUserPromptProps { - openAiClient: OpenAI; - model: string; - numPrecedingMessagesToInclude?: number; - findContent: FindContentFunc; - maxContextTokenCount?: number; -} - -/** - Generate user prompt using the ["step back" method of prompt engineering](https://arxiv.org/abs/2310.06117) - to construct search query. - Also extract metadata to use in the search query or reject the user message. - */ -export const makeStepBackRagGenerateUserPrompt = ({ - openAiClient, - model, - numPrecedingMessagesToInclude = 0, - findContent, - maxContextTokenCount = 1800, -}: MakeStepBackGenerateUserPromptProps) => { - assert( - numPrecedingMessagesToInclude >= 0, - "'numPrecedingMessagesToInclude' must be >= 0. Got: " + - numPrecedingMessagesToInclude - ); - assert( - Number.isInteger(numPrecedingMessagesToInclude), - "'numPrecedingMessagesToInclude' must be an integer. Got: " + - numPrecedingMessagesToInclude - ); - const stepBackRagGenerateUserPrompt: GenerateUserPromptFunc = async ({ - reqId, - userMessageText, - conversation, - customData, - }) => { - const messages = conversation?.messages ?? []; - const precedingMessagesToInclude = - numPrecedingMessagesToInclude === 0 - ? [] - : messages - .filter((m) => m.role !== "system") - .slice(-numPrecedingMessagesToInclude); - // Run both at once to save time - const [metadata, guardrailResult] = await Promise.all([ - extractMongoDbMetadataFromUserMessage({ - openAiClient, - model, - userMessageText, - messages: precedingMessagesToInclude, - }), - userMessageMongoDbGuardrail({ - userMessageText, - openAiClient, - model, - messages: precedingMessagesToInclude, - }), - ]); - if (guardrailResult.rejectMessage) { - const { reasoning } = guardrailResult; - logRequest({ - reqId, - message: `Rejected user message: ${JSON.stringify({ - userMessageText, - reasoning, - })}`, - }); - return { - userMessage: { - role: "user", - content: userMessageText, - rejectQuery: true, - customData: { - rejectionReason: reasoning, - }, - } satisfies UserMessage, - rejectQuery: true, - }; - } - logRequest({ - reqId, - message: `Extracted metadata from user message: ${JSON.stringify( - metadata - )}`, - }); - const metadataForQuery: Record = {}; - if (metadata.programmingLanguage) { - metadataForQuery.programmingLanguage = metadata.programmingLanguage; - } - if (metadata.mongoDbProduct) { - metadataForQuery.mongoDbProductName = metadata.mongoDbProduct; - } - - const { transformedUserQuery, content, queryEmbedding, searchQuery } = - await retrieveRelevantContent({ - findContent, - metadataForQuery, - model, - openAiClient, - precedingMessagesToInclude, - userMessageText, - }); - - logRequest({ - reqId, - message: `Found ${content.length} results for query: ${content - .map((c) => c.text) - .join("---")}`, - }); - const baseUserMessage = { - role: "user", - embedding: queryEmbedding, - content: userMessageText, - contextContent: content.map((c) => ({ - text: c.text, - url: c.url, - score: c.score, - })), - customData: { - ...customData, - ...metadata, - searchQuery, - transformedUserQuery, - }, - } satisfies UserMessage; - if (content.length === 0) { - return { - userMessage: { - ...baseUserMessage, - rejectQuery: true, - customData: { - ...customData, - rejectionReason: "Did not find any content matching the query", - }, - }, - rejectQuery: true, - references: [], - } satisfies GenerateUserPromptFuncReturnValue; - } - const userPrompt = { - ...baseUserMessage, - contentForLlm: makeUserContentForLlm({ - userMessageText, - stepBackUserQuery: transformedUserQuery, - messages: precedingMessagesToInclude, - metadata, - content, - maxContextTokenCount, - }), - } satisfies UserMessage; - const references = makeMongoDbReferences(content); - logRequest({ - reqId, - message: stripIndents`Generated user prompt for LLM: ${ - userPrompt.contentForLlm - } - Generated references: ${JSON.stringify(references)}`, - }); - return { - userMessage: userPrompt, - references, - } satisfies GenerateUserPromptFuncReturnValue; - }; - return stepBackRagGenerateUserPrompt; -}; - -function makeUserContentForLlm({ - userMessageText, - stepBackUserQuery, - messages, - metadata, - content, - maxContextTokenCount, -}: { - userMessageText: string; - stepBackUserQuery: string; - messages: Message[]; - metadata?: Record; - content: EmbeddedContent[]; - maxContextTokenCount: number; -}) { - const previousConversationMessages = messages - .map((message) => message.role.toUpperCase() + ": " + message.content) - .join("\n"); - const relevantMetadata = JSON.stringify({ - ...(metadata ?? {}), - searchQuery: stepBackUserQuery, - }); - - let currentTotalTokenCount = 0; - const contentForLlm = [...content] - .filter((c) => { - if (currentTotalTokenCount < maxContextTokenCount) { - currentTotalTokenCount += c.tokenCount; - return true; - } - return false; - }) - .map((c) => c.text) - .reverse() - .join("\n---\n"); - return `Use the following information to respond to the "User message". If you do not know the answer to the question based on the provided documentation content, respond with the following text: "I'm sorry, I do not know how to answer that question. Please try to rephrase your query." NEVER include Markdown links in the answer. -${ - previousConversationMessages.length > 0 - ? `Previous conversation messages: ${previousConversationMessages}` - : "" -} - -Content from the MongoDB documentation: -${contentForLlm} - -Relevant metadata: ${relevantMetadata} - -User message: ${userMessageText}`; -} diff --git a/packages/chatbot-server-mongodb-public/src/processors/makeStepBackUserQuery.eval.ts b/packages/chatbot-server-mongodb-public/src/processors/makeStepBackUserQuery.eval.ts deleted file mode 100644 index 83a770ff4..000000000 --- a/packages/chatbot-server-mongodb-public/src/processors/makeStepBackUserQuery.eval.ts +++ /dev/null @@ -1,189 +0,0 @@ -import { Scorer, EmbeddingSimilarity } from "autoevals"; -import { Eval } from "braintrust"; -import { - makeStepBackUserQuery, - StepBackUserQueryMongoDbFunction, -} from "./makeStepBackUserQuery"; -import { Message, updateFrontMatter } from "mongodb-chatbot-server"; -import { ObjectId } from "mongodb-rag-core/mongodb"; -import { MongoDbTag } from "../mongoDbMetadata"; -import { - OPENAI_PREPROCESSOR_CHAT_COMPLETION_DEPLOYMENT, - OPENAI_API_KEY, - OPENAI_ENDPOINT, - OPENAI_API_VERSION, - JUDGE_EMBEDDING_MODEL, - openAiClient, -} from "../eval/evalHelpers"; - -interface ExtractMongoDbMetadataEvalCase { - name: string; - input: { - previousMessages?: Message[]; - userMessageText: string; - }; - expected: StepBackUserQueryMongoDbFunction; - tags?: MongoDbTag[]; -} - -const evalCases: ExtractMongoDbMetadataEvalCase[] = [ - { - name: "Should return a step back user query", - input: { - userMessageText: updateFrontMatter( - "how do i add the values of sale_price in aggregation pipeline?", - { - mongoDbProduct: "Aggregation Framework", - } - ), - }, - expected: { - transformedUserQuery: - "How to calculate the sum of field in MongoDB aggregation?", - } satisfies StepBackUserQueryMongoDbFunction, - tags: ["aggregation"], - }, - { - name: "should step back based on previous messages", - input: { - userMessageText: "code example", - previousMessages: [ - { - role: "user", - content: "add documents node.js", - createdAt: new Date(), - id: new ObjectId(), - }, - { - role: "assistant", - content: - "You can add documents with the node.js driver insert and insertMany methods.", - createdAt: new Date(), - id: new ObjectId(), - }, - ], - }, - expected: { - transformedUserQuery: - "Code example of how to add documents to MongoDB using the Node.js Driver", - } satisfies StepBackUserQueryMongoDbFunction, - tags: ["aggregation"], - }, - { - name: "should not do step back if original message doesn't need to be mutated", - input: { - userMessageText: updateFrontMatter("How do I connect to MongoDB Atlas?", { - mongoDbProduct: "MongoDB Atlas", - }), - }, - expected: { - transformedUserQuery: "How do I connect to MongoDB Atlas?", - } satisfies StepBackUserQueryMongoDbFunction, - tags: ["atlas"], - }, - { - name: "should step back when query about specific data", - input: { - userMessageText: updateFrontMatter("create an index on the email field", { - mongoDbProduct: "Index Management", - }), - }, - expected: { - transformedUserQuery: - "How to create an index on a specific field in MongoDB?", - } satisfies StepBackUserQueryMongoDbFunction, - tags: ["indexes"], - }, - { - name: "should recognize when query doesn't need step back.", - input: { - userMessageText: updateFrontMatter( - "What are MongoDB's replica set election protocols?", - { - mongoDbProduct: "Replication", - } - ), - }, - expected: { - transformedUserQuery: - "What are MongoDB's replica set election protocols?", - } satisfies StepBackUserQueryMongoDbFunction, - tags: ["replication"], - }, - { - name: "Steps back when query involves MongoDB Atlas configuration", - input: { - userMessageText: updateFrontMatter( - "How do I set up multi-region clusters in MongoDB Atlas?", - { - mongoDbProduct: "MongoDB Atlas", - } - ), - }, - expected: { - transformedUserQuery: - "How to configure multi-region clusters in MongoDB Atlas?", - } satisfies StepBackUserQueryMongoDbFunction, - tags: ["atlas"], - }, - { - name: "Handles abstract query related to MongoDB performance tuning", - input: { - userMessageText: updateFrontMatter( - "improve MongoDB query performance with indexes", - { - mongoDbProduct: "Performance Tuning", - } - ), - }, - expected: { - transformedUserQuery: - "How can I use indexes to optimize MongoDB query performance?", - } satisfies StepBackUserQueryMongoDbFunction, - tags: ["performance", "indexes"], - }, -]; - -const QuerySimilarity: Scorer< - Awaited>, - unknown -> = async (args) => { - return await EmbeddingSimilarity({ - expected: args.expected?.transformedUserQuery, - output: args.output.transformedUserQuery, - model: JUDGE_EMBEDDING_MODEL, - azureOpenAi: { - apiKey: OPENAI_API_KEY, - apiVersion: OPENAI_API_VERSION, - endpoint: OPENAI_ENDPOINT, - }, - }); -}; - -const model = OPENAI_PREPROCESSOR_CHAT_COMPLETION_DEPLOYMENT; - -Eval("step-back-user-query", { - data: evalCases, - experimentName: model, - metadata: { - description: - "Evaluate the function that mutates the user query for better search results.", - model, - }, - maxConcurrency: 3, - timeout: 20000, - async task(input) { - try { - return await makeStepBackUserQuery({ - openAiClient, - model, - ...input, - }); - } catch (error) { - console.log(`Error evaluating input: ${input}`); - console.log(error); - throw error; - } - }, - scores: [QuerySimilarity], -}); diff --git a/packages/chatbot-server-mongodb-public/src/processors/makeStepBackUserQuery.test.ts b/packages/chatbot-server-mongodb-public/src/processors/makeStepBackUserQuery.test.ts deleted file mode 100644 index 0b72fdbaa..000000000 --- a/packages/chatbot-server-mongodb-public/src/processors/makeStepBackUserQuery.test.ts +++ /dev/null @@ -1,20 +0,0 @@ -import { makeMockOpenAIToolCall } from "../test/mockOpenAi"; -import { makeStepBackUserQuery } from "./makeStepBackUserQuery"; -import { OpenAI } from "mongodb-rag-core/openai"; -jest.mock("mongodb-rag-core/openai", () => - makeMockOpenAIToolCall({ transformedUserQuery: "foo" }) -); - -describe("makeStepBackUserQuery", () => { - const args: Parameters[0] = { - openAiClient: new OpenAI({ apiKey: "fake-api-key" }), - model: "best-model-ever", - userMessageText: "hi", - }; - - test("should return step back user query", async () => { - expect(await makeStepBackUserQuery(args)).toEqual({ - transformedUserQuery: "foo", - }); - }); -}); diff --git a/packages/chatbot-server-mongodb-public/src/processors/makeStepBackUserQuery.ts b/packages/chatbot-server-mongodb-public/src/processors/makeStepBackUserQuery.ts deleted file mode 100644 index 4ea2da9c2..000000000 --- a/packages/chatbot-server-mongodb-public/src/processors/makeStepBackUserQuery.ts +++ /dev/null @@ -1,129 +0,0 @@ -import { z } from "zod"; -import { - makeAssistantFunctionCallMessage, - makeFewShotUserMessageExtractorFunction, - makeUserMessage, -} from "./makeFewShotUserMessageExtractorFunction"; -import { updateFrontMatter } from "mongodb-chatbot-server"; -import { OpenAI } from "mongodb-rag-core/openai"; - -export const StepBackUserQueryMongoDbFunctionSchema = z.object({ - transformedUserQuery: z.string().describe("Transformed user query"), -}); - -export type StepBackUserQueryMongoDbFunction = z.infer< - typeof StepBackUserQueryMongoDbFunctionSchema ->; - -const name = "step_back_user_query"; -const description = "Create a user query using the 'step back' method."; - -const systemPrompt = `Your purpose is to generate a search query for a given user input. -You are doing this for MongoDB, and all queries relate to MongoDB products. -When constructing the query, take a "step back" to generate a more general search query that finds the data relevant to the user query if relevant. -If the user query is already a "good" search query, do not modify it. -For one word queries like "or", "and", "exists", if the query corresponds to a MongoDB operation, transform it into a fully formed question. Ex: 'what is the $or operator in MongoDB?' -You should also transform the user query into a fully formed question, if relevant.`; - -const fewShotExamples: OpenAI.ChatCompletionMessageParam[] = [ - // Example 1 - makeUserMessage( - updateFrontMatter("aggregate filter where flowerType is rose", { - programmingLanguage: "javascript", - mongoDbProduct: "Aggregation Framework", - }) - ), - makeAssistantFunctionCallMessage(name, { - transformedUserQuery: - "How do I filter by specific field value in a MongoDB aggregation pipeline?", - } satisfies StepBackUserQueryMongoDbFunction), - // Example 2 - makeUserMessage( - updateFrontMatter("How long does it take to import 2GB of data?", { - mongoDbProduct: "MongoDB Atlas", - }) - ), - makeAssistantFunctionCallMessage(name, { - transformedUserQuery: "What affects the rate of data import in MongoDB?", - } satisfies StepBackUserQueryMongoDbFunction), - // Example 3 - makeUserMessage( - updateFrontMatter("how to display the first five", { - mongoDbProduct: "Driver", - }) - ), - makeAssistantFunctionCallMessage(name, { - transformedUserQuery: - "How do I limit the number of results in a MongoDB query?", - } satisfies StepBackUserQueryMongoDbFunction), - // Example 4 - makeUserMessage( - updateFrontMatter("find documents python code example", { - programmingLanguage: "python", - mongoDbProduct: "Driver", - }) - ), - makeAssistantFunctionCallMessage(name, { - transformedUserQuery: - "Code example of how to find documents using the Python driver.", - } satisfies StepBackUserQueryMongoDbFunction), - // Example 5 - makeUserMessage( - updateFrontMatter("aggregate", { - mongoDbProduct: "Aggregation Framework", - }) - ), - makeAssistantFunctionCallMessage(name, { - transformedUserQuery: "Aggregation in MongoDB", - } satisfies StepBackUserQueryMongoDbFunction), - // Example 6 - makeUserMessage( - updateFrontMatter("$match", { - mongoDbProduct: "Aggregation Framework", - }) - ), - makeAssistantFunctionCallMessage(name, { - transformedUserQuery: - "What is the $match stage in a MongoDB aggregation pipeline?", - } satisfies StepBackUserQueryMongoDbFunction), - // Example 7 - makeUserMessage( - updateFrontMatter("How to connect to a MongoDB Atlas cluster?", { - mongoDbProduct: "MongoDB Atlas", - }) - ), - makeAssistantFunctionCallMessage(name, { - transformedUserQuery: "How to connect to a MongoDB Atlas cluster?", - } satisfies StepBackUserQueryMongoDbFunction), - // Example 8 - makeUserMessage( - updateFrontMatter("How to create a new cluster atlas", { - mongoDbProduct: "MongoDB Atlas", - }) - ), - makeAssistantFunctionCallMessage(name, { - transformedUserQuery: "How to create a new cluster in MongoDB Atlas?", - } satisfies StepBackUserQueryMongoDbFunction), - // Example 9 - makeUserMessage( - updateFrontMatter("What is a skill?", { - mongoDbProduct: "MongoDB University", - }) - ), - makeAssistantFunctionCallMessage(name,{ - transformedUserQuery: "What is the skill badge program on MongoDB University?", - } satisfies StepBackUserQueryMongoDbFunction), -]; - -/** - Generate search query using the ["step back" method of prompt engineering](https://arxiv.org/abs/2310.06117). - */ -export const makeStepBackUserQuery = makeFewShotUserMessageExtractorFunction({ - llmFunction: { - name, - description, - schema: StepBackUserQueryMongoDbFunctionSchema, - }, - systemPrompt, - fewShotExamples, -}); diff --git a/packages/chatbot-server-mongodb-public/src/processors/retrieveRelevantContent.test.ts b/packages/chatbot-server-mongodb-public/src/processors/retrieveRelevantContent.test.ts deleted file mode 100644 index cb37ba1f1..000000000 --- a/packages/chatbot-server-mongodb-public/src/processors/retrieveRelevantContent.test.ts +++ /dev/null @@ -1,84 +0,0 @@ -import { FindContentFunc, updateFrontMatter } from "mongodb-rag-core"; -import { retrieveRelevantContent } from "./retrieveRelevantContent"; -import { makeMockOpenAIToolCall } from "../test/mockOpenAi"; -import { StepBackUserQueryMongoDbFunction } from "./makeStepBackUserQuery"; -import { OpenAI } from "mongodb-rag-core/openai"; - -jest.mock("mongodb-rag-core/openai", () => - makeMockOpenAIToolCall({ transformedUserQuery: "transformedUserQuery" }) -); -describe("retrieveRelevantContent", () => { - const model = "model"; - const funcRes = { - transformedUserQuery: "transformedUserQuery", - } satisfies StepBackUserQueryMongoDbFunction; - const fakeEmbedding = [1, 2, 3]; - - const fakeContentBase = { - embeddings: { fakeModelName: fakeEmbedding }, - score: 1, - url: "url", - tokenCount: 3, - sourceName: "sourceName", - updated: new Date(), - }; - const fakeFindContent: FindContentFunc = async ({ query }) => { - return { - content: [ - { - text: "all about " + query, - ...fakeContentBase, - }, - ], - queryEmbedding: fakeEmbedding, - }; - }; - - const mockToolCallOpenAi = new OpenAI({ - apiKey: "apiKey", - }); - const argsBase = { - openAiClient: mockToolCallOpenAi, - model, - userMessageText: "something", - findContent: fakeFindContent, - }; - const metadataForQuery = { - programmingLanguage: "javascript", - mongoDbProduct: "Aggregation Framework", - }; - it("should return content, queryEmbedding, transformedUserQuery, searchQuery with metadata", async () => { - const res = await retrieveRelevantContent({ - ...argsBase, - metadataForQuery, - }); - expect(res).toEqual({ - content: [ - { - text: expect.any(String), - ...fakeContentBase, - }, - ], - queryEmbedding: fakeEmbedding, - transformedUserQuery: funcRes.transformedUserQuery, - searchQuery: updateFrontMatter( - funcRes.transformedUserQuery, - metadataForQuery - ), - }); - }); - it("should return content, queryEmbedding, transformedUserQuery, searchQuery without", async () => { - const res = await retrieveRelevantContent(argsBase); - expect(res).toEqual({ - content: [ - { - text: expect.any(String), - ...fakeContentBase, - }, - ], - queryEmbedding: fakeEmbedding, - transformedUserQuery: funcRes.transformedUserQuery, - searchQuery: funcRes.transformedUserQuery, - }); - }); -}); diff --git a/packages/chatbot-server-mongodb-public/src/processors/retrieveRelevantContent.ts b/packages/chatbot-server-mongodb-public/src/processors/retrieveRelevantContent.ts deleted file mode 100644 index a261d5270..000000000 --- a/packages/chatbot-server-mongodb-public/src/processors/retrieveRelevantContent.ts +++ /dev/null @@ -1,39 +0,0 @@ -import { makeStepBackUserQuery } from "./makeStepBackUserQuery"; -import { FindContentFunc, Message } from "mongodb-rag-core"; -import { updateFrontMatter } from "mongodb-rag-core"; -import { OpenAI } from "mongodb-rag-core/openai"; - -export const retrieveRelevantContent = async function ({ - openAiClient, - model, - precedingMessagesToInclude, - userMessageText, - metadataForQuery, - findContent, -}: { - openAiClient: OpenAI; - model: string; - precedingMessagesToInclude?: Message[]; - userMessageText: string; - metadataForQuery?: Record; - findContent: FindContentFunc; -}) { - const { transformedUserQuery } = await makeStepBackUserQuery({ - openAiClient, - model, - messages: precedingMessagesToInclude, - userMessageText: metadataForQuery - ? updateFrontMatter(userMessageText, metadataForQuery) - : userMessageText, - }); - - const searchQuery = metadataForQuery - ? updateFrontMatter(transformedUserQuery, metadataForQuery) - : transformedUserQuery; - - const { content, queryEmbedding } = await findContent({ - query: searchQuery, - }); - - return { content, queryEmbedding, transformedUserQuery, searchQuery }; -}; diff --git a/packages/chatbot-server-mongodb-public/src/systemPrompt.ts b/packages/chatbot-server-mongodb-public/src/systemPrompt.ts index 63174b02f..7f682dc93 100644 --- a/packages/chatbot-server-mongodb-public/src/systemPrompt.ts +++ b/packages/chatbot-server-mongodb-public/src/systemPrompt.ts @@ -1,25 +1,113 @@ -import { SystemPrompt } from "mongodb-chatbot-server"; +import { SEARCH_TOOL_NAME, SystemMessage } from "mongodb-chatbot-server"; +import { + mongoDbProducts, + mongoDbProgrammingLanguages, +} from "./mongoDbMetadata"; export const llmDoesNotKnowMessage = "I'm sorry, I do not know how to answer that question. Please try to rephrase your query."; +const personalityTraits = [ + "You enthusiastically answer user questions about MongoDB products and services.", + "Your personality is friendly and helpful, like a professor or tech lead.", + "Be concise and informative in your responses.", + "You were created by MongoDB.", + "Never speak negatively about the company MongoDB or its products and services.", +]; + +const responseFormat = [ + "NEVER include links in your answer.", + "Format your responses using Markdown. DO NOT mention that your response is formatted in Markdown. Do not use headers in your responses (e.g '# Some H1' or '## Some H2').", + "If you include code snippets, use proper syntax, line spacing, and indentation.", + "If you include a code example in your response, only include examples in one programming language, unless otherwise specified in the user query.", + "If the user query is about a programming language, include that language in the response.", +]; + +const technicalKnowledge = [ + "You ONLY know about the current version of MongoDB products. Versions are provided in the information.", + "If `version: null` in the retrieved content, then say that the product is unversioned.", + "Do not hallucinate information that is not provided within the search results or that you otherwise know to be true.", +]; + +const importantNotes = [ + `ALWAYS use the ${SEARCH_TOOL_NAME} tool at the start of the conversation. Zero exceptions!`, + `Use the ${SEARCH_TOOL_NAME} tool after every single user message.`, +]; + +const searchContentToolNotes = [ + ...importantNotes, + "Generate an appropriate search query for a given user input.", + "You are doing this for MongoDB, and all queries relate to MongoDB products.", + 'When constructing the query, take a "step back" to generate a more general search query that finds the data relevant to the user query if relevant.', + 'If the user query is already a "good" search query, do not modify it.', + 'For one word queries like "or", "and", "exists", if the query corresponds to a MongoDB operation, transform it into a fully formed question. Ex: If the user query is "or", transform it into "what is the $or operator in MongoDB?".', + "You should also transform the user query into a fully formed question, if relevant.", + `Only generate ONE ${SEARCH_TOOL_NAME} tool call per user message unless there are clearly multiple distinct queries needed to answer the user query.`, +]; + export const systemPrompt = { role: "system", content: `You are expert MongoDB documentation chatbot. -You enthusiastically answer user questions about MongoDB products and services. -Your personality is friendly and helpful, like a professor or tech lead. -Be concise and informative in your responses. -You were created by MongoDB. -Use the provided context information to answer user questions. You can also use your internal knowledge of MongoDB to inform the answer. + + +${makeMarkdownNumberedList(importantNotes)} + + + +You have the following personality: +${makeMarkdownNumberedList(personalityTraits)} + + + If you do not know the answer to the question, respond only with the following text: "${llmDoesNotKnowMessage}" -NEVER include links in your answer. -Format your responses using Markdown. DO NOT mention that your response is formatted in Markdown. Do not use headers in your responses (e.g '# Some H1' or '## Some H2'). -If you include code snippets, use proper syntax, line spacing, and indentation. +Response format: +${makeMarkdownNumberedList(responseFormat)} + + + + + +${makeMarkdownNumberedList(technicalKnowledge)} + + + + + +You know about the following products: +${mongoDbProducts + .map( + (product) => + `* ${product.id}: ${product.name}. ${ + ("description" in product ? product.description : null) ?? "" + }` + ) + .join("\n")} + +You know about the following programming languages: +${mongoDbProgrammingLanguages.map((language) => `* ${language.id}`).join("\n")} + + + + + + + +You have access to the ${SEARCH_TOOL_NAME} tool. Use the ${SEARCH_TOOL_NAME} tool as follows: +${makeMarkdownNumberedList(searchContentToolNotes)} + +When you search, include metadata about the relevant MongoDB programming language and product. + + + + + +${makeMarkdownNumberedList(importantNotes)} +`, +} satisfies SystemMessage; -If you include a code example in your response, only include examples in one programming language, -unless otherwise specified in the user query. If the user query is about a programming language, include that language in the response. -You ONLY know about the current version of MongoDB products. Versions are provided in the information. If \`version: null\`, then say that the product is unversioned.`, -} satisfies SystemPrompt; +function makeMarkdownNumberedList(items: string[]) { + return items.map((item, i) => `${i + 1}. ${item}`).join("\n"); +} diff --git a/packages/chatbot-server-mongodb-public/src/processors/retrieveRelevantContent.eval.ts b/packages/chatbot-server-mongodb-public/src/tools/search.eval.ts similarity index 83% rename from packages/chatbot-server-mongodb-public/src/processors/retrieveRelevantContent.eval.ts rename to packages/chatbot-server-mongodb-public/src/tools/search.eval.ts index cc06adaa0..92385e365 100644 --- a/packages/chatbot-server-mongodb-public/src/processors/retrieveRelevantContent.eval.ts +++ b/packages/chatbot-server-mongodb-public/src/tools/search.eval.ts @@ -21,11 +21,7 @@ import { f1AtK } from "../eval/scorers/f1AtK"; import { precisionAtK } from "../eval/scorers/precisionAtK"; import { recallAtK } from "../eval/scorers/recallAtK"; import { MongoDbTag } from "../mongoDbMetadata"; -import { - extractMongoDbMetadataFromUserMessage, - ExtractMongoDbMetadataFunction, -} from "./extractMongoDbMetadataFromUserMessage"; -import { retrieveRelevantContent } from "./retrieveRelevantContent"; +import { SearchToolArgs } from "./search"; interface RetrievalEvalCaseInput { query: string; @@ -49,7 +45,7 @@ interface RetrievalResult { } interface RetrievalTaskOutput { results: RetrievalResult[]; - extractedMetadata?: ExtractMongoDbMetadataFunction; + extractedMetadata?: SearchToolArgs; rewrittenQuery?: string; searchString?: string; } @@ -69,30 +65,21 @@ const { k } = retrievalConfig.findNearestNeighborsOptions; const retrieveRelevantContentEvalTask: EvalTask< RetrievalEvalCaseInput, - RetrievalTaskOutput + RetrievalTaskOutput, + RetrievalEvalCaseExpected > = async function (data) { - const metadataForQuery = await extractMongoDbMetadataFromUserMessage({ - openAiClient: preprocessorOpenAiClient, - model: retrievalConfig.preprocessorLlm, - userMessageText: data.query, - }); - const results = await retrieveRelevantContent({ - userMessageText: data.query, - model: retrievalConfig.preprocessorLlm, - openAiClient: preprocessorOpenAiClient, - findContent, - metadataForQuery, - }); + // TODO: (EAI-991) implement retrieval task for evaluation + const extractedMetadata: SearchToolArgs = { + productName: null, + programmingLanguage: null, + query: data.query, + }; return { - results: results.content.map((c) => ({ - url: c.url, - content: c.text, - score: c.score, - })), - extractedMetadata: metadataForQuery, - rewrittenQuery: results.transformedUserQuery, - searchString: results.searchQuery, + results: [], + extractedMetadata, + rewrittenQuery: undefined, + searchString: undefined, }; }; diff --git a/packages/chatbot-server-mongodb-public/src/tools/search.ts b/packages/chatbot-server-mongodb-public/src/tools/search.ts new file mode 100644 index 000000000..f89c0ae85 --- /dev/null +++ b/packages/chatbot-server-mongodb-public/src/tools/search.ts @@ -0,0 +1,86 @@ +import { + SearchResult, + SearchTool, + SearchToolReturnValue, +} from "mongodb-chatbot-server"; +import { FindContentFunc, updateFrontMatter } from "mongodb-rag-core"; +import { tool, ToolExecutionOptions } from "mongodb-rag-core/aiSdk"; +import { z } from "zod"; +import { + mongoDbProducts, + mongoDbProgrammingLanguageIds, +} from "../mongoDbMetadata"; + +const SearchToolArgsSchema = z.object({ + productName: z + .enum(mongoDbProducts.map((product) => product.id) as [string, ...string[]]) + .nullable() + .optional() + .describe("Most relevant MongoDB product for query. Leave null if unknown"), + programmingLanguage: z + .enum(mongoDbProgrammingLanguageIds) + .nullable() + .optional() + .describe( + "Most relevant programming language for query. Leave null if unknown" + ), + query: z.string().describe("Search query"), +}); + +export type SearchToolArgs = z.infer; + +export function makeSearchTool( + findContent: FindContentFunc +): SearchTool { + return tool({ + parameters: SearchToolArgsSchema, + description: "Search MongoDB content", + // This shows only the URL and text of the result, not the metadata (needed for references) to the model. + experimental_toToolResultContent(result) { + return [ + { + type: "text", + text: JSON.stringify({ + content: result.content.map( + (r) => + ({ + url: r.url, + text: r.text, + } satisfies SearchResult) + ), + }), + }, + ]; + }, + async execute( + args: SearchToolArgs, + _options: ToolExecutionOptions + ): Promise { + const { query, productName, programmingLanguage } = args; + + const nonNullMetadata: Record = {}; + if (productName) { + nonNullMetadata.productName = productName; + } + if (programmingLanguage) { + nonNullMetadata.programmingLanguage = programmingLanguage; + } + + const queryWithMetadata = updateFrontMatter(query, nonNullMetadata); + const content = await findContent({ query: queryWithMetadata }); + + const result: SearchToolReturnValue = { + content: content.content.map((item) => ({ + url: item.url, + metadata: { + pageTitle: item.metadata?.pageTitle, + sourceName: item.sourceName, + }, + text: item.text, + })), + }; + + return result; + }, + }); +} diff --git a/packages/chatbot-server-mongodb-public/src/tracing/extractTracingData.ts b/packages/chatbot-server-mongodb-public/src/tracing/extractTracingData.ts index 3de4df102..37c9a2371 100644 --- a/packages/chatbot-server-mongodb-public/src/tracing/extractTracingData.ts +++ b/packages/chatbot-server-mongodb-public/src/tracing/extractTracingData.ts @@ -12,6 +12,7 @@ export function extractTracingData( messages: Message[], assistantMessageId: ObjectId ) { + // FIXME: this is throwing after the generation is complete. don't forget to fix before merge of EAI-990 const evalAssistantMessageIdx = messages.findLastIndex( (message) => message.role === "assistant" && message.id.equals(assistantMessageId) @@ -55,7 +56,7 @@ export function extractTracingData( if (isVerifiedAnswer) { tags.push("verified_answer"); } - + // TODO: this is throwing errs now. figure out and fix. const llmDoesNotKnow = evalAssistantMessage?.content.includes( llmDoesNotKnowMessage ); diff --git a/packages/chatbot-server-mongodb-public/src/tracing/routesUpdateTraceHandlers.ts b/packages/chatbot-server-mongodb-public/src/tracing/routesUpdateTraceHandlers.ts index cb283f38d..87bcc08f4 100644 --- a/packages/chatbot-server-mongodb-public/src/tracing/routesUpdateTraceHandlers.ts +++ b/packages/chatbot-server-mongodb-public/src/tracing/routesUpdateTraceHandlers.ts @@ -1,5 +1,5 @@ import { strict as assert } from "assert"; -import { UpdateTraceFunc } from "mongodb-chatbot-server/build/routes/conversations/UpdateTraceFunc"; +import { UpdateTraceFunc } from "mongodb-chatbot-server"; import { ObjectId } from "mongodb-rag-core/mongodb"; import { extractTracingData } from "./extractTracingData"; import { LlmAsAJudge, getLlmAsAJudgeScores } from "./getLlmAsAJudgeScores"; diff --git a/packages/mongodb-chatbot-server/src/processors/InputGuardrail.ts b/packages/mongodb-chatbot-server/src/processors/InputGuardrail.ts new file mode 100644 index 000000000..ffc48bbc7 --- /dev/null +++ b/packages/mongodb-chatbot-server/src/processors/InputGuardrail.ts @@ -0,0 +1,40 @@ +import { GenerateResponseParams } from "./GenerateResponse"; + +export type InputGuardrail< + Metadata extends Record | undefined = Record +> = (generateResponseParams: Omit) => Promise<{ + rejected: boolean; + reason?: string; + message: string; + metadata: Metadata; +}>; + +export function withAbortControllerGuardrail( + fn: (abortController: AbortController) => Promise, + guardrailPromise?: Promise +): Promise<{ result: T | null; guardrailResult: Awaited | undefined }> { + const abortController = new AbortController(); + return (async () => { + try { + // Run both the main function and guardrail function in parallel + const [result, guardrailResult] = await Promise.all([ + fn(abortController).catch((error) => { + // If the main function was aborted by the guardrail, return null + if (error.name === "AbortError") { + return null as T | null; + } + throw error; + }), + guardrailPromise, + ]); + + return { result, guardrailResult }; + } catch (error) { + // If an unexpected error occurs, abort any ongoing operations + if (!abortController.signal.aborted) { + abortController.abort(); + } + throw error; + } + })(); +} diff --git a/packages/mongodb-chatbot-server/src/processors/MakeReferenceLinksFunc.ts b/packages/mongodb-chatbot-server/src/processors/MakeReferenceLinksFunc.ts index 40c197cbb..bbb3da61a 100644 --- a/packages/mongodb-chatbot-server/src/processors/MakeReferenceLinksFunc.ts +++ b/packages/mongodb-chatbot-server/src/processors/MakeReferenceLinksFunc.ts @@ -1,6 +1,9 @@ -import { EmbeddedContent, References } from "mongodb-rag-core"; +import { References } from "mongodb-rag-core"; +import { SearchResult } from "./SearchResult"; /** Function that generates the references in the response to user. */ -export type MakeReferenceLinksFunc = (chunks: EmbeddedContent[]) => References; +export type MakeReferenceLinksFunc = ( + searchResults: SearchResult[] +) => References; diff --git a/packages/mongodb-chatbot-server/src/processors/SearchResult.ts b/packages/mongodb-chatbot-server/src/processors/SearchResult.ts new file mode 100644 index 000000000..f338f9f3f --- /dev/null +++ b/packages/mongodb-chatbot-server/src/processors/SearchResult.ts @@ -0,0 +1,7 @@ +import { EmbeddedContent } from "mongodb-rag-core"; + +export type SearchResult = Partial & { + url: string; + text: string; + metadata?: Record; +}; diff --git a/packages/mongodb-chatbot-server/src/processors/generateResponseWithSearchTool.test.ts b/packages/mongodb-chatbot-server/src/processors/generateResponseWithSearchTool.test.ts new file mode 100644 index 000000000..6c54f426d --- /dev/null +++ b/packages/mongodb-chatbot-server/src/processors/generateResponseWithSearchTool.test.ts @@ -0,0 +1,390 @@ +import { jest } from "@jest/globals"; +import { + makeGenerateResponseWithSearchTool, + SEARCH_TOOL_NAME, + SearchToolReturnValue, +} from "./generateResponseWithSearchTool"; +import { FilterPreviousMessages } from "./FilterPreviousMessages"; +import { + AssistantMessage, + DataStreamer, + SystemMessage, +} from "mongodb-rag-core"; +import { z } from "zod"; +import { + ToolExecutionOptions, + MockLanguageModelV1, + tool, + simulateReadableStream, + LanguageModelV1StreamPart, +} from "mongodb-rag-core/aiSdk"; +import { ObjectId } from "mongodb-rag-core/mongodb"; +import { InputGuardrail } from "./InputGuardrail"; +import { GenerateResponseReturnValue } from "./GenerateResponse"; + +// Define the search tool arguments schema +const SearchToolArgsSchema = z.object({ + query: z.string(), +}); +type SearchToolArgs = z.infer; + +const latestMessageText = "Hello"; + +const mockReqId = "test"; + +const mockContent = [ + { + url: "https://example.com/", + text: `Content!`, + metadata: { + pageTitle: "Example Page", + }, + }, +]; + +const mockReferences = mockContent.map((content) => ({ + url: content.url, + title: content.metadata.pageTitle, +})); + +// Create a mock search tool that matches the SearchTool interface +const mockSearchTool = tool({ + parameters: SearchToolArgsSchema, + description: "Search MongoDB content", + async execute( + _args: SearchToolArgs, + _options: ToolExecutionOptions + ): Promise { + return { + content: mockContent, + }; + }, +}); + +// Must have, but details don't matter +const mockFinishChunk = { + type: "finish" as const, + finishReason: "stop" as const, + usage: { + completionTokens: 10, + promptTokens: 3, + }, +} satisfies LanguageModelV1StreamPart; + +const finalAnswer = "Final answer"; +const finalAnswerChunks = finalAnswer.split(" "); +const finalAnswerStreamChunks = finalAnswerChunks.map((word, i) => { + if (i === 0) { + return { + type: "text-delta" as const, + textDelta: word, + }; + } + return { + type: "text-delta" as const, + textDelta: ` ${word}`, + }; +}); + +// Note: have to make this constructor b/c the ReadableStream +// can only be used once successfully. +const makeFinalAnswerStream = () => + simulateReadableStream({ + chunks: [ + ...finalAnswerStreamChunks, + mockFinishChunk, + ] satisfies LanguageModelV1StreamPart[], + chunkDelayInMs: 100, + }); + +const searchToolMockArgs = { + query: "test", +} satisfies SearchToolArgs; + +const makeToolCallStream = () => + simulateReadableStream({ + chunks: [ + { + type: "tool-call" as const, + toolCallId: "abc123", + toolName: SEARCH_TOOL_NAME, + toolCallType: "function" as const, + args: JSON.stringify(searchToolMockArgs), + }, + // ...finalAnswerStreamChunks, + mockFinishChunk, + ] satisfies LanguageModelV1StreamPart[], + chunkDelayInMs: 100, + }); + +jest.setTimeout(5000); +// Mock language model following the AI SDK testing documentation +// Create a minimalist mock for the language model +const makeMockLanguageModel = () => { + // On first call, return tool call stream + // On second call, return final answer stream + // On subsequent calls, return final answer + let counter = 0; + const doStreamCalls = [ + async () => { + return { + stream: makeToolCallStream(), + rawCall: { rawPrompt: null, rawSettings: {} }, + }; + }, + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore + async () => { + return { + stream: makeFinalAnswerStream(), + rawCall: { rawPrompt: null, rawSettings: {} }, + }; + }, + ]; + return new MockLanguageModelV1({ + doStream: () => { + const streamCallPromise = doStreamCalls[counter](); + if (counter < doStreamCalls.length) { + counter++; + } + return streamCallPromise; + }, + }); +}; + +const mockSystemMessage: SystemMessage = { + role: "system", + content: "You are a helpful assistant.", +}; + +const mockLlmNotWorkingMessage = + "Sorry, I am having trouble with the language model."; + +const mockGuardrail: InputGuardrail = async () => ({ + rejected: true, + message: "Content policy violation", + metadata: { reason: "inappropriate" }, +}); + +const mockThrowingLanguageModel: MockLanguageModelV1 = new MockLanguageModelV1({ + doStream: async () => { + throw new Error("LLM error"); + }, +}); + +const makeMakeGenerateResponseWithSearchToolArgs = () => ({ + languageModel: makeMockLanguageModel(), + llmNotWorkingMessage: mockLlmNotWorkingMessage, + systemMessage: mockSystemMessage, + searchTool: mockSearchTool, +}); + +const generateResponseBaseArgs = { + conversation: { + _id: new ObjectId(), + createdAt: new Date(), + messages: [], + }, + latestMessageText, + shouldStream: false, + reqId: mockReqId, +}; +describe("generateResponseWithSearchTool", () => { + // Reset mocks before each test + beforeEach(() => { + jest.clearAllMocks(); + }); + + describe("makeGenerateResponseWithSearchTool", () => { + const generateResponse = makeGenerateResponseWithSearchTool( + makeMakeGenerateResponseWithSearchToolArgs() + ); + it("should return a function", () => { + expect(typeof generateResponse).toBe("function"); + }); + it("should filter previous messages", async () => { + // Properly type the mock function to match FilterPreviousMessages + const mockFilterPreviousMessages = jest + .fn() + .mockImplementation((_conversation) => + Promise.resolve([]) + ) as FilterPreviousMessages; + const generateResponse = makeGenerateResponseWithSearchTool({ + ...makeMakeGenerateResponseWithSearchToolArgs(), + filterPreviousMessages: mockFilterPreviousMessages, + }); + + // We don't care about the output so not getting the return value + await generateResponse(generateResponseBaseArgs); + + expect(mockFilterPreviousMessages).toHaveBeenCalledWith({ + _id: expect.any(ObjectId), + createdAt: expect.any(Date), + messages: [], + }); + }); + + it("should make reference links", async () => { + const generateResponse = makeGenerateResponseWithSearchTool( + makeMakeGenerateResponseWithSearchToolArgs() + ); + + const result = await generateResponse(generateResponseBaseArgs); + + const references = (result.messages.at(-1) as AssistantMessage) + .references; + expect(references).toMatchObject(mockReferences); + }); + + describe("non-streaming", () => { + test("should handle successful generation non-streaming", async () => { + const generateResponse = makeGenerateResponseWithSearchTool( + makeMakeGenerateResponseWithSearchToolArgs() + ); + + const result = await generateResponse(generateResponseBaseArgs); + + expectSuccessfulResult(result); + }); + + // TODO: (EAI-995): make work as part of guardrail changes + test.skip("should handle guardrail rejection", async () => { + const generateResponse = makeGenerateResponseWithSearchTool({ + ...makeMakeGenerateResponseWithSearchToolArgs(), + inputGuardrail: mockGuardrail, + }); + + const result = await generateResponse(generateResponseBaseArgs); + + expect(result.messages[1].role).toBe("assistant"); + expect(result.messages[1].content).toBe("Content policy violation"); + expect(result.messages[1].metadata).toEqual({ + reason: "inappropriate", + }); + }); + + test("should handle error in language model", async () => { + const generateResponse = makeGenerateResponseWithSearchTool({ + ...makeMakeGenerateResponseWithSearchToolArgs(), + languageModel: mockThrowingLanguageModel, + }); + + const result = await generateResponse(generateResponseBaseArgs); + + expect(result.messages[0].role).toBe("user"); + expect(result.messages[0].content).toBe(latestMessageText); + expect(result.messages.at(-1)?.role).toBe("assistant"); + expect(result.messages.at(-1)?.content).toBe(mockLlmNotWorkingMessage); + }); + }); + + describe("streaming mode", () => { + // Create a mock DataStreamer implementation + const makeMockDataStreamer = () => { + const mockStreamData = jest.fn(); + const mockConnect = jest.fn(); + const mockDisconnect = jest.fn(); + const mockStream = jest.fn().mockImplementation(async () => { + // Process the stream and return a string result + return "Hello"; + }); + const dataStreamer = { + connected: false, + connect: mockConnect, + disconnect: mockDisconnect, + streamData: mockStreamData, + stream: mockStream, + } as DataStreamer; + + return dataStreamer; + }; + test("should handle successful streaming", async () => { + const mockDataStreamer = makeMockDataStreamer(); + const generateResponse = makeGenerateResponseWithSearchTool( + makeMakeGenerateResponseWithSearchToolArgs() + ); + + const result = await generateResponse({ + ...generateResponseBaseArgs, + shouldStream: true, + dataStreamer: mockDataStreamer, + }); + + expect(mockDataStreamer.streamData).toHaveBeenCalledTimes(3); + expect(mockDataStreamer.streamData).toHaveBeenCalledWith({ + data: "Final", + type: "delta", + }); + expect(mockDataStreamer.streamData).toHaveBeenCalledWith({ + type: "references", + data: expect.any(Array), + }); + expectSuccessfulResult(result); + }); + + // TODO: (EAI-995): make work as part of guardrail changes + test.skip("should handle successful generation with guardrail", async () => { + // TODO: add + }); + // TODO: (EAI-995): make work as part of guardrail changes + test.skip("should handle streaming with guardrail rejection", async () => { + // TODO: add + }); + + test("should handle error in language model", async () => { + const generateResponse = makeGenerateResponseWithSearchTool({ + ...makeMakeGenerateResponseWithSearchToolArgs(), + languageModel: mockThrowingLanguageModel, + }); + + const mockDataStreamer = makeMockDataStreamer(); + const result = await generateResponse({ + ...generateResponseBaseArgs, + shouldStream: true, + dataStreamer: mockDataStreamer, + }); + + expect(mockDataStreamer.streamData).toHaveBeenCalledTimes(1); + expect(mockDataStreamer.streamData).toHaveBeenCalledWith({ + data: mockLlmNotWorkingMessage, + type: "delta", + }); + + expect(result.messages[0].role).toBe("user"); + expect(result.messages[0].content).toBe(latestMessageText); + expect(result.messages.at(-1)?.role).toBe("assistant"); + expect(result.messages.at(-1)?.content).toBe(mockLlmNotWorkingMessage); + }); + }); + }); +}); + +function expectSuccessfulResult(result: GenerateResponseReturnValue) { + expect(result).toHaveProperty("messages"); + expect(result.messages).toHaveLength(4); // User + assistant (tool call) + tool result + assistant + expect(result.messages[0]).toMatchObject({ + role: "user", + content: latestMessageText, + }); + expect(result.messages[1]).toMatchObject({ + role: "assistant", + toolCall: { + id: "abc123", + function: { name: "search_content", arguments: '{"query":"test"}' }, + type: "function", + }, + content: "", + }); + + expect(result.messages[2]).toMatchObject({ + role: "tool", + name: "search_content", + content: JSON.stringify({ + content: mockContent, + }), + }); + expect(result.messages[3]).toMatchObject({ + role: "assistant", + content: finalAnswer, + }); +} diff --git a/packages/mongodb-chatbot-server/src/processors/generateResponseWithSearchTool.ts b/packages/mongodb-chatbot-server/src/processors/generateResponseWithSearchTool.ts new file mode 100644 index 000000000..065489b90 --- /dev/null +++ b/packages/mongodb-chatbot-server/src/processors/generateResponseWithSearchTool.ts @@ -0,0 +1,369 @@ +import { + References, + SomeMessage, + SystemMessage, + UserMessage, + AssistantMessage, + ToolMessage, + EmbeddedContent, +} from "mongodb-rag-core"; +import { z } from "zod"; +import { GenerateResponse } from "./GenerateResponse"; +import { + CoreAssistantMessage, + CoreMessage, + LanguageModel, + streamText, + Tool, + ToolCallPart, + ToolChoice, + ToolExecutionOptions, + ToolResultUnion, + ToolSet, + CoreToolMessage, +} from "mongodb-rag-core/aiSdk"; +import { FilterPreviousMessages } from "./FilterPreviousMessages"; +import { InputGuardrail, withAbortControllerGuardrail } from "./InputGuardrail"; +import { strict as assert } from "assert"; +import { MakeReferenceLinksFunc } from "./MakeReferenceLinksFunc"; +import { makeDefaultReferenceLinks } from "./makeDefaultReferenceLinks"; +import { SearchResult } from "./SearchResult"; + +export const SEARCH_TOOL_NAME = "search_content"; + +export type SearchToolReturnValue = { + content: SearchResult[]; +}; + +export type SearchTool = Tool< + ARGUMENTS, + SearchToolReturnValue +> & { + execute: ( + args: z.infer, + options: ToolExecutionOptions + ) => PromiseLike; +}; + +type SearchToolResult = ToolResultUnion<{ + [SEARCH_TOOL_NAME]: SearchTool; +}>; + +export interface GenerateResponseWithSearchToolParams< + ARGUMENTS extends z.ZodTypeAny +> { + languageModel: LanguageModel; + llmNotWorkingMessage: string; + inputGuardrail?: InputGuardrail; + systemMessage: SystemMessage; + filterPreviousMessages?: FilterPreviousMessages; + /** + Required tool for performing content search and gathering {@link References} + */ + additionalTools?: ToolSet; + makeReferenceLinks?: MakeReferenceLinksFunc; + maxSteps?: number; + toolChoice?: ToolChoice<{ search_content: SearchTool }>; + searchTool: SearchTool; +} + +/** + Generate chatbot response using RAG and a search tool named {@link SEARCH_TOOL_NAME}. + */ +export function makeGenerateResponseWithSearchTool< + ARGUMENTS extends z.ZodTypeAny +>({ + languageModel, + llmNotWorkingMessage, + inputGuardrail, + systemMessage, + filterPreviousMessages, + additionalTools, + makeReferenceLinks = makeDefaultReferenceLinks, + maxSteps = 2, + searchTool, + toolChoice, +}: GenerateResponseWithSearchToolParams): GenerateResponse { + return async function generateResponseWithSearchTool({ + conversation, + latestMessageText, + clientContext, + customData, + shouldStream, + reqId, + dataStreamer, + request, + }) { + if (shouldStream) { + assert(dataStreamer, "dataStreamer is required for streaming"); + } + const userMessage = { + role: "user", + content: latestMessageText, + } satisfies UserMessage; + try { + // Get preceding messages to include in the LLM prompt + const filteredPreviousMessages = filterPreviousMessages + ? (await filterPreviousMessages(conversation)).map( + formatMessageForAiSdk + ) + : []; + + const toolSet = { + [SEARCH_TOOL_NAME]: searchTool, + ...(additionalTools ?? {}), + } satisfies ToolSet; + + const generationArgs = { + model: languageModel, + messages: [ + systemMessage, + ...filteredPreviousMessages, + userMessage, + ] satisfies CoreMessage[], + tools: toolSet, + toolChoice, + maxSteps, + }; + + // TODO: EAI-995: validate that this works as part of guardrail changes + // Guardrail used to validate the input + // while the LLM is generating the response + const inputGuardrailPromise = inputGuardrail + ? inputGuardrail({ + conversation, + latestMessageText, + clientContext, + customData, + shouldStream, + reqId, + dataStreamer, + request, + }) + : undefined; + + const references: References = []; + const { result, guardrailResult } = await withAbortControllerGuardrail( + async (controller) => { + // Pass the tools as a separate parameter + const result = streamText({ + ...generationArgs, + // Abort the stream if the guardrail AbortController is triggered + abortSignal: controller.signal, + // Add the search tool results to the references + onStepFinish: async ({ toolResults }) => { + toolResults?.forEach( + (toolResult: SearchToolResult) => { + if ( + toolResult.toolName === SEARCH_TOOL_NAME && + toolResult.result.content + ) { + // Map the search tool results to the References format + const searchResults = toolResult.result.content; + references.push(...makeReferenceLinks(searchResults)); + } + } + ); + }, + }); + + for await (const chunk of result.fullStream) { + switch (chunk.type) { + case "text-delta": + if (shouldStream) { + dataStreamer?.streamData({ + data: chunk.textDelta, + type: "delta", + }); + } + break; + case "tool-call": + // do nothing with tool calls for now... + break; + case "error": + throw new Error( + typeof chunk.error === "string" + ? chunk.error + : String(chunk.error) + ); + default: + break; + } + } + try { + // Transform filtered references to include the required title property + + dataStreamer?.streamData({ + data: references, + type: "references", + }); + return result; + } catch (error: unknown) { + throw new Error(typeof error === "string" ? error : String(error)); + } + }, + inputGuardrailPromise + ); + const text = await result?.text; + assert(text, "text is required"); + const messages = (await result?.response)?.messages; + assert(messages, "messages is required"); + + return handleReturnGeneration({ + userMessage, + guardrailResult, + messages, + customData, + references, + }); + } catch (error: unknown) { + dataStreamer?.streamData({ + data: llmNotWorkingMessage, + type: "delta", + }); + // Handle other errors + return { + messages: [ + userMessage, + { + role: "assistant", + content: llmNotWorkingMessage, + }, + ], + }; + } + }; +} + +type ResponseMessage = CoreAssistantMessage | CoreToolMessage; + +/** + Generate the final messages to send to the user based on guardrail result and text generation result + */ +function handleReturnGeneration({ + userMessage, + guardrailResult, + messages, + references, +}: { + userMessage: UserMessage; + guardrailResult: + | { rejected: boolean; message: string; metadata?: Record } + | undefined; + messages: ResponseMessage[]; + references?: References; + customData?: Record; +}): { messages: SomeMessage[] } { + userMessage.rejectQuery = guardrailResult?.rejected; + userMessage.customData = { + ...userMessage.customData, + ...guardrailResult, + }; + return { + messages: [ + userMessage, + ...formatMessageForGeneration(messages, references ?? []), + ] satisfies SomeMessage[], + }; +} + +function formatMessageForGeneration( + messages: ResponseMessage[], + references: References +): SomeMessage[] { + const messagesOut = messages + .map((m) => { + if (m.role === "assistant") { + const baseMessage: Partial & { role: "assistant" } = { + role: "assistant", + }; + if (typeof m.content === "string") { + baseMessage.content = m.content; + } else { + m.content.forEach((c) => { + if (c.type === "text") { + baseMessage.content = c.text; + } + if (c.type === "tool-call") { + baseMessage.toolCall = { + id: c.toolCallId, + function: { + name: c.toolName, + arguments: JSON.stringify(c.args), + }, + type: "function", + }; + } + }); + } + + return { + ...baseMessage, + content: baseMessage.content ?? "", + } satisfies AssistantMessage; + } else if (m.role === "tool") { + const baseMessage: Partial & { role: "tool" } = { + role: "tool", + }; + if (typeof m.content === "string") { + baseMessage.content = m.content; + } else { + m.content.forEach((c) => { + if (c.type === "tool-result") { + baseMessage.name = c.toolName; + baseMessage.content = JSON.stringify(c.result); + } + }); + } + return { + ...baseMessage, + name: baseMessage.name ?? "", + content: baseMessage.content ?? "", + } satisfies ToolMessage; + } + }) + .filter((m): m is AssistantMessage | ToolMessage => m !== undefined); + const latestMessage = messagesOut.at(-1); + if (latestMessage?.role === "assistant") { + latestMessage.references = references; + } + return messagesOut; +} + +function formatMessageForAiSdk(message: SomeMessage): CoreMessage { + if (message.role === "assistant" && typeof message.content === "object") { + // Convert assistant messages with object content to proper format + if (message.toolCall) { + // This is a tool call message + return { + role: "assistant", + content: [ + { + type: "tool-call", + toolCallId: message.toolCall.id, + toolName: message.toolCall.function.name, + args: message.toolCall.function.arguments, + } satisfies ToolCallPart, + ], + } satisfies CoreAssistantMessage; + } else { + // Fallback for other object content + return { + role: "assistant", + content: JSON.stringify(message.content), + } satisfies CoreAssistantMessage; + } + } else if (message.role === "tool") { + // Convert tool messages to the format expected by the AI SDK + return { + role: "assistant", // Use assistant role instead of function + content: + typeof message.content === "string" + ? message.content + : JSON.stringify(message.content), + } satisfies CoreMessage; + } else { + // User and system messages can pass through + return message satisfies CoreMessage; + } +} diff --git a/packages/mongodb-chatbot-server/src/processors/index.ts b/packages/mongodb-chatbot-server/src/processors/index.ts index bf068f803..55a42146e 100644 --- a/packages/mongodb-chatbot-server/src/processors/index.ts +++ b/packages/mongodb-chatbot-server/src/processors/index.ts @@ -4,6 +4,10 @@ export * from "./QueryPreprocessorFunc"; export * from "./filterOnlySystemPrompt"; export * from "./makeDefaultReferenceLinks"; export * from "./makeFilterNPreviousMessages"; +export * from "./includeChunksForMaxTokensPossible"; +export * from "./InputGuardrail"; +export * from "./generateResponseWithSearchTool"; export * from "./makeVerifiedAnswerGenerateResponse"; export * from "./includeChunksForMaxTokensPossible"; export * from "./GenerateResponse"; +export * from "./SearchResult"; diff --git a/packages/mongodb-chatbot-server/src/processors/makeDefaultReferenceLinks.ts b/packages/mongodb-chatbot-server/src/processors/makeDefaultReferenceLinks.ts index 871ce03dc..88971b8e3 100644 --- a/packages/mongodb-chatbot-server/src/processors/makeDefaultReferenceLinks.ts +++ b/packages/mongodb-chatbot-server/src/processors/makeDefaultReferenceLinks.ts @@ -1,3 +1,4 @@ +import { References } from "mongodb-rag-core"; import { MakeReferenceLinksFunc } from "./MakeReferenceLinksFunc"; /** @@ -13,24 +14,28 @@ import { MakeReferenceLinksFunc } from "./MakeReferenceLinksFunc"; export const makeDefaultReferenceLinks: MakeReferenceLinksFunc = (chunks) => { // Filter chunks with unique URLs const uniqueUrls = new Set(); - const uniqueChunks = chunks.filter((chunk) => { + const uniqueReferenceChunks = chunks.filter((chunk) => { if (!uniqueUrls.has(chunk.url)) { uniqueUrls.add(chunk.url); - return true; // Keep the chunk as it has a unique URL + return true; // Keep the reference as it has a unique URL } - return false; // Discard the chunk as its URL is not unique + return false; // Discard the reference as its URL is not unique }); - return uniqueChunks.map((chunk) => { + return uniqueReferenceChunks.map((chunk) => { const url = new URL(chunk.url).href; - const title = chunk.metadata?.pageTitle ?? url; + // Ensure title is always a string by checking its type + const pageTitle = chunk.metadata?.pageTitle; + const title = typeof pageTitle === "string" ? pageTitle : url; + const sourceName = chunk.sourceName; + return { title, url, metadata: { - sourceName: chunk.sourceName, + sourceName, tags: chunk.metadata?.tags ?? [], }, }; - }); + }) satisfies References; }; diff --git a/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.test.ts b/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.test.ts index 1fb3125c7..4883c60c4 100644 --- a/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.test.ts +++ b/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.test.ts @@ -1,14 +1,10 @@ import request from "supertest"; import "dotenv/config"; import { - assertEnvVars, - CORE_ENV_VARS, - makeMongoDbConversationsService, ConversationsService, Conversation, defaultConversationConstants, Message, - makeOpenAiChatLlm, SomeMessage, } from "mongodb-rag-core"; import { Express } from "express"; @@ -21,14 +17,11 @@ import { ApiConversation, ApiMessage } from "./utils"; import { stripIndent } from "common-tags"; import { makeApp, DEFAULT_API_PREFIX } from "../../app"; import { makeTestApp } from "../../test/testHelpers"; -import { makeTestAppConfig, systemPrompt } from "../../test/testHelpers"; import { AppConfig } from "../../app"; import { strict as assert } from "assert"; -import { NO_VECTOR_CONTENT, REJECT_QUERY_CONTENT } from "../../test/testConfig"; -import { OpenAI } from "mongodb-rag-core/openai"; import { Db, ObjectId } from "mongodb-rag-core/mongodb"; +import { mockAssistantResponse } from "../../test/testConfig"; -const { OPENAI_CHAT_COMPLETION_DEPLOYMENT } = assertEnvVars(CORE_ENV_VARS); jest.setTimeout(100000); describe("POST /conversations/:conversationId/messages", () => { let mongodb: Db; @@ -73,8 +66,7 @@ describe("POST /conversations/:conversationId/messages", () => { .send(requestBody); const message: ApiMessage = res.body; expect(res.statusCode).toEqual(200); - expect(message.role).toBe("assistant"); - expect(message.content).toContain("Realm"); + expect(message).toMatchObject(mockAssistantResponse); const request2Body: AddMessageRequestBody = { message: stripIndent`i'm want to learn more about this Realm thing. a few questions: can i use realm with javascript? @@ -87,8 +79,7 @@ describe("POST /conversations/:conversationId/messages", () => { .send(request2Body); const message2: ApiMessage = res2.body; expect(res2.statusCode).toEqual(200); - expect(message2.role).toBe("assistant"); - expect(message2.content).toContain("Realm"); + expect(message2).toMatchObject(mockAssistantResponse); const conversationInDb = await mongodb .collection("conversations") .findOne({ @@ -357,7 +348,6 @@ describe("POST /conversations/:conversationId/messages", () => { res.body.metadata.conversationId ); expect(conversation?.messages).toHaveLength(2); - console.log(conversation?.messages[0]); expect(conversation?.messages[0]).toMatchObject({ content: message.message, role: "user", diff --git a/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.ts b/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.ts index 013a83a20..1615a7256 100644 --- a/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.ts +++ b/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.ts @@ -4,7 +4,7 @@ import { Request as ExpressRequest, Response as ExpressResponse, } from "express"; -import { DbMessage, FunctionMessage, Message } from "mongodb-rag-core"; +import { DbMessage, Message, ToolMessage } from "mongodb-rag-core"; import { ObjectId } from "mongodb-rag-core/mongodb"; import { ConversationsService, @@ -239,16 +239,16 @@ export function makeAddMessageToConversationRoute({ metadata: message.metadata, }; - if (message.role === "function") { + if (message.role === "tool") { return { - role: "function", + role: "tool", name: message.name, ...baseFields, - } satisfies DbMessage; + } satisfies DbMessage; } else { return { ...baseFields, role: message.role } satisfies Exclude< Message, - FunctionMessage + ToolMessage >; } }), diff --git a/packages/mongodb-chatbot-server/src/routes/conversations/getConversation.test.ts b/packages/mongodb-chatbot-server/src/routes/conversations/getConversation.test.ts index dc618c069..80393441a 100644 --- a/packages/mongodb-chatbot-server/src/routes/conversations/getConversation.test.ts +++ b/packages/mongodb-chatbot-server/src/routes/conversations/getConversation.test.ts @@ -81,13 +81,17 @@ describe("GET /conversations/:conversationId", () => { { role: "assistant", content: "", - functionCall: { - name: "addNumbers", - arguments: `[1, 2, 3, 4, 5]`, + toolCall: { + id: "abc123", + type: "function", + function: { + name: "addNumbers", + arguments: `[1, 2, 3, 4, 5]`, + }, }, }, { - role: "function", + role: "tool", name: "addNumbers", content: "15", }, diff --git a/packages/mongodb-chatbot-server/src/routes/conversations/utils.test.ts b/packages/mongodb-chatbot-server/src/routes/conversations/utils.test.ts index afb7b8ef3..a14313849 100644 --- a/packages/mongodb-chatbot-server/src/routes/conversations/utils.test.ts +++ b/packages/mongodb-chatbot-server/src/routes/conversations/utils.test.ts @@ -1,5 +1,6 @@ import { strict as assert } from "assert"; import { + ApiMessage, areEquivalentIpAddresses, convertConversationFromDbToApi, convertMessageFromDbToApi, @@ -63,15 +64,19 @@ const exampleConversationInDatabase: Conversation = { id: new ObjectId("65ca767e30116ce068e17bb5"), role: "assistant", content: "", - functionCall: { - name: "getBookRecommendations", - arguments: JSON.stringify({ genre: ["fantasy", "sci-fi"] }), + toolCall: { + id: "abc123", + type: "function", + function: { + name: "getBookRecommendations", + arguments: JSON.stringify({ genre: ["fantasy", "sci-fi"] }), + }, }, createdAt: new Date("2024-01-01T00:00:45Z"), }, { id: new ObjectId("65ca768341f9ea61d048aaa8"), - role: "function", + role: "tool", name: "getBookRecommendations", content: JSON.stringify([ { title: "The Way of Kings", author: "Brandon Sanderson" }, @@ -125,14 +130,14 @@ describe("Data Conversion Functions", () => { expect(convertMessageFromDbToApi(functionResultMessage)).toEqual({ id: "65ca768341f9ea61d048aaa8", - role: "function", + role: "tool", content: JSON.stringify([ { title: "The Way of Kings", author: "Brandon Sanderson" }, { title: "Neuromancer", author: "William Gibson" }, { title: "Snow Crash", author: "Neal Stephenson" }, ]), createdAt: 1704067247000, - }); + } satisfies ApiMessage); expect(convertMessageFromDbToApi(assistantMessage)).toEqual({ id: "65ca76874e1df9cf2742bf86", diff --git a/packages/mongodb-chatbot-server/src/routes/conversations/utils.ts b/packages/mongodb-chatbot-server/src/routes/conversations/utils.ts index 1501666ba..81dd7ec9f 100644 --- a/packages/mongodb-chatbot-server/src/routes/conversations/utils.ts +++ b/packages/mongodb-chatbot-server/src/routes/conversations/utils.ts @@ -7,7 +7,7 @@ import { z } from "zod"; export type ApiMessage = z.infer; export const ApiMessage = z.object({ id: z.string(), - role: z.enum(["system", "assistant", "user", "function"]), + role: z.enum(["system", "assistant", "user", "tool"]), content: z.string(), rating: z.boolean().optional(), createdAt: z.number(), @@ -63,8 +63,8 @@ function isMessageAllowedInApiResponse(message: Message) { case "user": return true; case "assistant": - return message.functionCall === undefined; - case "function": + return message.toolCall === undefined; + case "tool": return false; default: // This should never happen - it means we missed a case in the switch. diff --git a/packages/mongodb-chatbot-server/src/routes/index.ts b/packages/mongodb-chatbot-server/src/routes/index.ts index 0d502d515..b9f9da7be 100644 --- a/packages/mongodb-chatbot-server/src/routes/index.ts +++ b/packages/mongodb-chatbot-server/src/routes/index.ts @@ -1,2 +1 @@ export * from "./conversations"; -export * from "./legacyGenerateResponse"; diff --git a/packages/mongodb-chatbot-server/src/routes/legacyGenerateResponse.test.ts b/packages/mongodb-chatbot-server/src/routes/legacyGenerateResponse.test.ts deleted file mode 100644 index 6314b6e5c..000000000 --- a/packages/mongodb-chatbot-server/src/routes/legacyGenerateResponse.test.ts +++ /dev/null @@ -1,634 +0,0 @@ -import { References, SystemMessage, UserMessage } from "mongodb-rag-core"; -import { ObjectId } from "mongodb-rag-core/mongodb"; -import { OpenAI } from "mongodb-rag-core/openai"; - -import { - AssistantMessage, - ChatLlm, - Conversation, - FunctionMessage, - OpenAiChatMessage, - ProcessingStreamEvent, - makeDataStreamer, -} from "mongodb-rag-core"; -import { strict as assert } from "assert"; -import { createResponse } from "node-mocks-http"; -import { Response as ExpressResponse } from "express"; -import { EventEmitter } from "stream-json/Parser"; -import { - MakeLegacyGenerateResponseParams, - makeLegacyGenerateResponse, - awaitGenerateResponseMessage, - streamGenerateResponseMessage, -} from "./legacyGenerateResponse"; -import { GenerateResponseParams } from "../processors/GenerateResponse"; - -const testFuncName = "test_func"; -const mockFunctionInvocation = { - role: "assistant", - content: "", - function_call: { - arguments: JSON.stringify({ foo: "bar" }), - name: testFuncName, - }, - refusal: null, -} satisfies OpenAI.ChatCompletionMessage; - -const mockReject = "mock_reject"; -const mockRejectFunctionInvocation = { - role: "assistant", - content: "", - function_call: { - arguments: JSON.stringify({ fizz: "buzz" }), - name: mockReject, - }, - refusal: null, -} satisfies OpenAI.ChatCompletionMessage; - -const mockReferences: References = [ - { url: "https://example.com/ref", title: "Some title" }, -]; - -const mockFunctionMessage = { - name: testFuncName, - role: "function", - content: "bar", -} satisfies FunctionMessage satisfies OpenAiChatMessage; - -const mockAssistantMessageContent = ["final ", "assistant ", "message"]; -const mockAssistantMessage = { - role: "assistant", - content: mockAssistantMessageContent.join(""), -} satisfies AssistantMessage; - -const mockLlmNotWorking = "llm_not_working"; - -const mockProcessingStreamEvent = { - type: "processing", - data: "Processing tool call", -} satisfies ProcessingStreamEvent; - -const mockChatLlm: ChatLlm = { - async answerQuestionAwaited({ messages }) { - const latestMessage = messages[messages.length - 1]; - if (latestMessage.content === testFuncName) { - return mockFunctionInvocation; - } - if (latestMessage.content === mockReject) { - return mockRejectFunctionInvocation; - } - if (latestMessage.content === mockLlmNotWorking) { - throw new Error("LLM not working"); - } - return mockAssistantMessage; - }, - answerQuestionStream: async ({ messages }) => - (async function* () { - let count = 0; - const latestMessage = messages[messages.length - 1]; - if (latestMessage.content === testFuncName) { - yield { - id: count.toString(), // Unique ID for each item - created: Date.now(), - choices: [ - { - index: 0, - finish_reason: "stop", - delta: { - role: "assistant", - content: "", - toolCalls: [], - function_call: mockFunctionInvocation.function_call, - }, - }, - ], - promptFilterResults: [], - }; - return; - } - if (latestMessage.content === mockReject) { - yield { - id: count.toString(), // Unique ID for each item - created: Date.now(), - choices: [ - { - index: 0, - finish_reason: "stop", - delta: { - role: "assistant", - content: "", - toolCalls: [], - function_call: mockRejectFunctionInvocation.function_call, - }, - }, - ], - promptFilterResults: [], - }; - return; - } - if (latestMessage.content === mockLlmNotWorking) { - throw new Error("LLM not working"); - } - while (count < mockAssistantMessageContent.length) { - yield { - id: count.toString(), // Unique ID for each item - created: Date.now(), - choices: [ - { - index: 0, - finish_reason: "stop", - delta: { - role: "assistant", - content: mockAssistantMessageContent[count], - toolCalls: [], - }, - }, - ], - promptFilterResults: [], - }; - count++; - } - })(), - async callTool({ messages, dataStreamer }) { - const latestMessage = messages[messages.length - 1] as AssistantMessage; - assert(latestMessage.functionCall, "must be a function call"); - if (dataStreamer?.connected) { - dataStreamer.streamData(mockProcessingStreamEvent); - } - if (latestMessage.functionCall.name === mockReject) { - return { - toolCallMessage: mockFunctionMessage, - rejectUserQuery: true, - }; - } else { - return { - toolCallMessage: mockFunctionMessage, - references: mockReferences, - }; - } - }, -}; - -const llmConversation: OpenAiChatMessage[] = [ - { role: "user", content: "hello" }, -]; -const references: References = [ - { url: "https://example.com", title: "Example" }, -]; -const reqId = "foo"; -const llmNotWorkingMessage = "llm not working"; -const noRelevantContentMessage = "no relevant content"; -const conversation: Conversation = { - _id: new ObjectId(), - createdAt: new Date(), - messages: [], -}; -const dataStreamer = makeDataStreamer(); - -const systemMessage: SystemMessage = { - role: "system", - content: "you're a helpful assistant or something....", -}; - -const constructorArgs = { - llm: mockChatLlm, - llmNotWorkingMessage, - noRelevantContentMessage, - async generateUserPrompt({ userMessageText }) { - return { - references, - userMessage: { - role: "user", - content: userMessageText, - } satisfies UserMessage, - }; - }, - systemMessage, -} satisfies MakeLegacyGenerateResponseParams; - -describe("generateResponse", () => { - const baseArgs = { - reqId, - conversation, - dataStreamer, - latestMessageText: "hello", - } satisfies Omit; - const generateResponse = makeLegacyGenerateResponse(constructorArgs); - let res: ReturnType & ExpressResponse; - beforeEach(() => { - res = createResponse({ - eventEmitter: EventEmitter, - }); - dataStreamer.connect(res); - }); - - afterEach(() => { - if (dataStreamer.connected) { - dataStreamer?.disconnect(); - } - }); - it("should stream response if shouldStream is true", async () => { - await generateResponse({ ...baseArgs, shouldStream: true }); - const data = res._getData(); - - for (let i = 0; i < 3; i++) { - expect(data).toContain( - `data: {"type":"delta","data":"${mockAssistantMessageContent[i]}"}` - ); - } - expect(data).toContain( - `{"type":"references","data":${JSON.stringify(references)}}` - ); - }); - - it("should await response if shouldStream is false", async () => { - await generateResponse({ ...baseArgs, shouldStream: false }); - const data = res._getData(); - expect(data).toBe(""); - }); - it("should stream metadata", async () => { - const metadata = { foo: "bar", baz: 42 }; - const staticResponse = { - role: "assistant", - content: "static response", - metadata, - } satisfies AssistantMessage; - - const generateResponse = makeLegacyGenerateResponse({ - ...constructorArgs, - generateUserPrompt: async function () { - return { - userMessage: { - role: "user", - content: "test metadata", - }, - staticResponse, - }; - }, - }); - - await generateResponse({ - ...baseArgs, - shouldStream: true, - }); - - const data = res._getData(); - - const expectedMetadataEvent = `data: {"type":"metadata","data":${JSON.stringify( - metadata - )}}\n\n`; - expect(data).toContain(expectedMetadataEvent); - }); - - it("passes clientContext data to the generateUserPrompt function", async () => { - const generateUserPrompt = jest.fn(async (args) => { - let content = args.userMessageText; - if (args.clientContext) { - content += `\n\nThe user provided the following context: ${JSON.stringify( - args.clientContext - )}`; - } - return { - userMessage: { - role: "user", - content, - } satisfies UserMessage, - }; - }); - const latestMessageText = "hello"; - const clientContext = { - location: "Chicago, IL", - preferredLanguage: "Spanish", - }; - - const generateResponse = makeLegacyGenerateResponse({ - ...constructorArgs, - generateUserPrompt, - }); - const { messages } = await generateResponse({ - ...baseArgs, - shouldStream: false, - latestMessageText, - clientContext, - }); - expect(messages.at(-2)?.content).toContain( - `The user provided the following context: {"location":"Chicago, IL","preferredLanguage":"Spanish"}` - ); - expect(generateUserPrompt).toHaveBeenCalledWith({ - userMessageText: latestMessageText, - clientContext, - conversation, - reqId, - }); - }); - - it("should send a static message", async () => { - const userMessage = { - role: "user", - content: "bad!", - } satisfies OpenAiChatMessage; - const staticResponse = { - role: "assistant", - content: "static response", - } satisfies OpenAiChatMessage; - const generateResponse = makeLegacyGenerateResponse({ - ...constructorArgs, - generateUserPrompt: async () => ({ - userMessage, - staticResponse, - }), - }); - const { messages } = await generateResponse({ - ...baseArgs, - shouldStream: false, - }); - - expect(messages).toMatchObject([userMessage, staticResponse]); - }); - it("should reject query", async () => { - const userMessage = { - role: "user", - content: "bad!", - } satisfies OpenAiChatMessage; - - const generateResponse = makeLegacyGenerateResponse({ - ...constructorArgs, - generateUserPrompt: async () => ({ - userMessage, - rejectQuery: true, - }), - }); - const { messages } = await generateResponse({ - ...baseArgs, - shouldStream: false, - }); - expect(messages).toMatchObject([ - { - role: "user", - content: "bad!", - }, - { - role: "assistant", - content: noRelevantContentMessage, - }, - ]); - }); -}); - -describe("awaitGenerateResponseMessage", () => { - const baseArgs = { - llm: mockChatLlm, - llmConversation, - references, - reqId, - llmNotWorkingMessage, - noRelevantContentMessage, - conversation, - systemMessage, - }; - it("should generate assistant response if no tools", async () => { - const { messages } = await awaitGenerateResponseMessage(baseArgs); - expect(messages).toHaveLength(1); - expect(messages[0]).toMatchObject(mockAssistantMessage); - }); - it("should pass through references with final assistant message", async () => { - const { messages } = await awaitGenerateResponseMessage(baseArgs); - expect( - (messages[messages.length - 1] as AssistantMessage).references - ).toMatchObject(references); - }); - it("should call tool before responding", async () => { - const { messages } = await awaitGenerateResponseMessage({ - ...baseArgs, - llmConversation: [{ role: "user", content: testFuncName }], - }); - expect(messages).toHaveLength(3); - expect(messages[messages.length - 2]).toMatchObject(mockFunctionMessage); - expect(messages[messages.length - 1]).toMatchObject(mockAssistantMessage); - }); - it("should pass references from a tool call", async () => { - const { messages } = await awaitGenerateResponseMessage({ - ...baseArgs, - llmConversation: [{ role: "user", content: testFuncName }], - }); - - expect(messages).toHaveLength(3); - expect(messages[messages.length - 2]).toMatchObject(mockFunctionMessage); - expect(messages[messages.length - 1]).toMatchObject(mockAssistantMessage); - expect( - (messages[messages.length - 1] as AssistantMessage).references - ).toMatchObject([...references, ...mockReferences]); - }); - - it("should reject input in a tool call", async () => { - const { messages } = await awaitGenerateResponseMessage({ - ...baseArgs, - llmConversation: [ - { - role: "user", - content: mockReject, - }, - ], - }); - expect(messages[messages.length - 1]).toMatchObject({ - role: "assistant", - content: noRelevantContentMessage, - }); - }); - it("should only send vector search results and references if LLM not working", async () => { - const { messages } = await awaitGenerateResponseMessage({ - ...baseArgs, - llmConversation: [ - { - role: "user", - content: mockLlmNotWorking, - }, - ], - }); - const finalMessage = messages[messages.length - 1] as AssistantMessage; - expect(finalMessage).toMatchObject({ - role: "assistant", - content: llmNotWorkingMessage, - }); - expect(finalMessage.references?.length).toBeGreaterThanOrEqual(1); - }); -}); - -describe("streamGenerateResponseMessage", () => { - let res: ReturnType & ExpressResponse; - beforeEach(() => { - res = createResponse({ - eventEmitter: EventEmitter, - }); - dataStreamer.connect(res); - }); - - afterEach(() => { - if (dataStreamer.connected) { - dataStreamer?.disconnect(); - } - }); - - const baseArgs = { - llm: mockChatLlm, - llmConversation, - references, - reqId, - llmNotWorkingMessage, - noRelevantContentMessage, - conversation, - dataStreamer, - shouldGenerateMessage: true, - systemMessage, - }; - - it("should generate assistant response if no tools", async () => { - const { messages } = await streamGenerateResponseMessage(baseArgs); - expect(messages).toHaveLength(1); - expect(messages[0]).toMatchObject(mockAssistantMessage); - const data = res._getData(); - for (let i = 0; i < 3; i++) { - expect(data).toContain( - `data: {"type":"delta","data":"${mockAssistantMessageContent[i]}"}` - ); - } - }); - it("should pass through references with final assistant message", async () => { - const { messages } = await streamGenerateResponseMessage(baseArgs); - expect( - (messages[messages.length - 1] as AssistantMessage).references - ).toMatchObject(references); - const data = res._getData(); - expect(data).toContain( - `data: {"type":"references","data":${JSON.stringify(references)}}` - ); - }); - it("should stream references if shouldGenerateMessage is true", async () => { - await streamGenerateResponseMessage(baseArgs); - const data = res._getData(); - - expect(data).toContain( - `{"type":"references","data":${JSON.stringify(references)}}` - ); - }); - it("should stream references if shouldGenerateMessage is false", async () => { - await streamGenerateResponseMessage({ - ...baseArgs, - shouldGenerateMessage: false, - llmConversation: [ - { role: "user", content: "hello" }, - { role: "assistant", content: "hi" }, - ], - }); - const data = res._getData(); - expect(data).toContain( - `{"type":"references","data":${JSON.stringify(references)}}` - ); - }); - it("should call tool before responding", async () => { - const { messages } = await streamGenerateResponseMessage({ - ...baseArgs, - llmConversation: [{ role: "user", content: testFuncName }], - }); - expect(messages).toHaveLength(3); - expect(messages[messages.length - 2]).toMatchObject(mockFunctionMessage); - expect(messages[messages.length - 1]).toMatchObject(mockAssistantMessage); - const data = res._getData(); - expect(data).toContain( - `data: {"type":"processing","data":"${mockProcessingStreamEvent.data}"}` - ); - for (let i = 0; i < 3; i++) { - expect(data).toContain( - `data: {"type":"delta","data":"${mockAssistantMessageContent[i]}"}` - ); - } - expect(data).toContain( - `{"type":"references","data":${JSON.stringify( - references.concat(mockReferences) - )}}` - ); - }); - it("should pass references from a tool call", async () => { - const { messages } = await streamGenerateResponseMessage({ - ...baseArgs, - llmConversation: [{ role: "user", content: testFuncName }], - }); - - expect(messages).toHaveLength(3); - expect(messages[messages.length - 2]).toMatchObject(mockFunctionMessage); - expect(messages[messages.length - 1]).toMatchObject(mockAssistantMessage); - expect( - (messages[messages.length - 1] as AssistantMessage).references - ).toMatchObject([...references, ...mockReferences]); - const data = res._getData(); - expect(data).toContain( - `{"type":"references","data":${JSON.stringify( - references.concat(mockReferences) - )}}` - ); - }); - - it("should reject input in a tool call", async () => { - const { messages } = await streamGenerateResponseMessage({ - ...baseArgs, - llmConversation: [ - { - role: "user", - content: mockReject, - }, - ], - }); - expect(messages[messages.length - 1]).toMatchObject({ - role: "assistant", - content: noRelevantContentMessage, - }); - const data = res._getData(); - expect(data).toContain( - `data: {"type":"processing","data":"${mockProcessingStreamEvent.data}"}` - ); - expect(data).toContain( - `data: {"type":"delta","data":"${noRelevantContentMessage}"}` - ); - expect(data).toContain( - `data: {"type":"references","data":${JSON.stringify(references)}}` - ); - }); - it("should only send vector search results and references if LLM not working", async () => { - const { messages } = await streamGenerateResponseMessage({ - ...baseArgs, - llmConversation: [ - { - role: "user", - content: mockLlmNotWorking, - }, - ], - }); - const finalMessage = messages[messages.length - 1] as AssistantMessage; - expect(finalMessage).toMatchObject({ - role: "assistant", - content: llmNotWorkingMessage, - }); - expect(finalMessage.references?.length).toBeGreaterThanOrEqual(1); - - const data = res._getData(); - expect(data).toContain( - `data: {"type":"delta","data":"${llmNotWorkingMessage}"}` - ); - expect(data).toContain( - `data: {"type":"references","data":${JSON.stringify(references)}}` - ); - }); - it("should stream metadata", async () => { - const metadata = { foo: "bar", baz: 42 }; - await streamGenerateResponseMessage({ - ...baseArgs, - metadata, - }); - const data = res._getData(); - - const expectedMetadataEvent = `data: {"type":"metadata","data":${JSON.stringify( - metadata - )}}`; - expect(data).toContain(expectedMetadataEvent); - }); -}); diff --git a/packages/mongodb-chatbot-server/src/routes/legacyGenerateResponse.ts b/packages/mongodb-chatbot-server/src/routes/legacyGenerateResponse.ts deleted file mode 100644 index 09f49f82d..000000000 --- a/packages/mongodb-chatbot-server/src/routes/legacyGenerateResponse.ts +++ /dev/null @@ -1,618 +0,0 @@ -import { - FindContentFunc, - EmbeddedContent, - UserMessage, - References, - SomeMessage, - escapeNewlines, - OpenAiChatMessage, - AssistantMessage, - ChatLlm, - SystemMessage, - Conversation, - ConversationCustomData, -} from "mongodb-rag-core"; -import { QueryPreprocessorFunc, MakeReferenceLinksFunc } from "../processors"; -import { logRequest } from "../utils"; -import { strict as assert } from "assert"; -import { FilterPreviousMessages } from "../processors/FilterPreviousMessages"; -import { - GenerateResponseParams, - GenerateResponseReturnValue, -} from "../processors/GenerateResponse"; - -export type GenerateUserPromptFuncParams = { - /** - Original user message - */ - userMessageText: string; - - /** - Conversation with preceding messages - */ - conversation?: Conversation; - - /** - Additional contextual information provided by the user's client. This can - include arbitrary data that might be useful for generating a response. For - example, this could include the user's location, the device they are using, - their preferred programming language, etc. - */ - clientContext?: Record; - - /** - String Id for request - */ - reqId: string; - - /** - Custom data for the message request. - */ - customData?: ConversationCustomData; -}; - -export interface GenerateUserPromptFuncReturnValue { - /** - If defined, this message should be sent as a response instead of generating - a response to the user query with the LLM. - */ - staticResponse?: AssistantMessage; - - /** - If true, no response should be generated with an LLM. Instead, return the - `staticResponse` if set or otherwise respond with a standard static - rejection response. - */ - rejectQuery?: boolean; - - /** - The (preprocessed) user message to insert into the conversation. - */ - userMessage: UserMessage; - - /** - References returned with the LLM response - */ - references?: References; -} - -/** - Generate the user prompt sent to the {@link ChatLlm}. - This function is a flexible construct that you can use to customize - the chatbot behavior. For example, you can use this function to - perform retrieval augmented generation (RAG) or chain of thought prompting. - Include whatever logic in here to construct the user message - that the LLM responds to. - - If you are doing RAG, this can include the content from vector search. - */ -export type GenerateUserPromptFunc = ( - params: GenerateUserPromptFuncParams -) => Promise; - -export interface MakeRagGenerateUserPromptParams { - /** - Transform the user's message before sending it to the `findContent` function. - */ - queryPreprocessor?: QueryPreprocessorFunc; - - /** - Find content based on the user's message and preprocessing. - */ - findContent: FindContentFunc; - - /** - If not specified, uses {@link makeDefaultReferenceLinks}. - */ - makeReferenceLinks?: MakeReferenceLinksFunc; - - /** - Number of tokens from the found context to send to the `makeUserMessage` function. - All chunks that exceed this threshold are discarded. - */ - maxChunkContextTokens?: number; - - /** - Construct user message which is sent to the LLM and stored in the database. - */ - makeUserMessage: MakeUserMessageFunc; -} - -export interface MakeUserMessageFuncParams { - content: EmbeddedContent[]; - originalUserMessage: string; - preprocessedUserMessage?: string; - queryEmbedding?: number[]; - rejectQuery?: boolean; -} - -export type MakeUserMessageFunc = ( - params: MakeUserMessageFuncParams -) => Promise; -export interface MakeLegacyGenerateResponseParams { - llm: ChatLlm; - generateUserPrompt?: GenerateUserPromptFunc; - filterPreviousMessages?: FilterPreviousMessages; - llmNotWorkingMessage: string; - noRelevantContentMessage: string; - systemMessage: SystemMessage; -} - -/** - @deprecated Make legacy generate response conform to the current system. - To be replaced later in a later PR in this epic. - */ -export function makeLegacyGenerateResponse({ - llm, - generateUserPrompt, - filterPreviousMessages, - llmNotWorkingMessage, - noRelevantContentMessage, - systemMessage, -}: MakeLegacyGenerateResponseParams) { - return async function generateResponse({ - shouldStream, - latestMessageText, - clientContext, - customData, - dataStreamer, - reqId, - conversation, - request, - }: GenerateResponseParams): Promise { - const { userMessage, references, staticResponse, rejectQuery } = - await (generateUserPrompt - ? generateUserPrompt({ - userMessageText: latestMessageText, - clientContext, - conversation, - reqId, - customData, - }) - : { - userMessage: { - role: "user", - content: latestMessageText, - customData, - } satisfies UserMessage, - }); - // Add request custom data to user message. - const userMessageWithCustomData = customData - ? { - ...userMessage, - // Override request custom data fields with user message custom data fields. - customData: { ...customData, ...(userMessage.customData ?? {}) }, - } - : userMessage; - const newMessages: SomeMessage[] = [userMessageWithCustomData]; - - // Metadata for streaming - let streamingResponseMetadata: Record | undefined; - // Send static response if query rejected or static response provided - if (rejectQuery) { - const rejectionMessage = { - role: "assistant", - content: noRelevantContentMessage, - references: references ?? [], - } satisfies AssistantMessage; - newMessages.push(rejectionMessage); - } else if (staticResponse) { - newMessages.push(staticResponse); - // Need to specify response metadata for streaming - streamingResponseMetadata = staticResponse.metadata; - } - - // Prepare conversation messages for LLM - const previousConversationMessagesForLlm = ( - filterPreviousMessages - ? await filterPreviousMessages(conversation) - : conversation.messages - ).map(convertConversationMessageToLlmMessage); - const newMessagesForLlm = newMessages.map((m) => { - // Use transformed content if it exists for user message - // (e.g. from a custom user prompt, query preprocessor, etc), - // otherwise use original content. - if (m.role === "user") { - return { - content: m.contentForLlm ?? m.content, - role: "user", - } satisfies OpenAiChatMessage; - } - return convertConversationMessageToLlmMessage(m); - }); - const llmConversation = [ - ...previousConversationMessagesForLlm, - ...newMessagesForLlm, - ]; - - const shouldGenerateMessage = !rejectQuery && !staticResponse; - - if (shouldStream) { - assert(dataStreamer, "Data streamer required for streaming"); - const { messages } = await streamGenerateResponseMessage({ - dataStreamer, - reqId, - llm, - llmConversation, - noRelevantContentMessage, - llmNotWorkingMessage, - request, - shouldGenerateMessage, - conversation, - references, - metadata: streamingResponseMetadata, - systemMessage, - }); - newMessages.push(...messages); - } else { - const { messages } = await awaitGenerateResponseMessage({ - reqId, - llm, - llmConversation, - llmNotWorkingMessage, - noRelevantContentMessage, - request, - shouldGenerateMessage, - conversation, - references, - systemMessage, - }); - newMessages.push(...messages); - } - return { messages: newMessages }; - }; -} - -type BaseGenerateResponseMessageParams = Omit< - GenerateResponseParams, - "latestMessageText" | "customData" | "filterPreviousMessages" | "shouldStream" -> & { - references?: References; - shouldGenerateMessage?: boolean; - llmConversation: OpenAiChatMessage[]; -}; - -export type AwaitGenerateResponseParams = Omit< - BaseGenerateResponseMessageParams, - "dataStreamer" ->; - -export async function awaitGenerateResponseMessage({ - reqId, - llmConversation, - llm, - llmNotWorkingMessage, - noRelevantContentMessage, - request, - references, - conversation, - shouldGenerateMessage = true, -}: AwaitGenerateResponseParams & - MakeLegacyGenerateResponseParams): Promise { - const newMessages: SomeMessage[] = []; - const outputReferences: References = []; - - if (references) { - outputReferences.push(...references); - } - - if (shouldGenerateMessage) { - try { - logRequest({ - reqId, - message: `All messages for LLM: ${JSON.stringify(llmConversation)}`, - }); - const answer = await llm.answerQuestionAwaited({ - messages: llmConversation, - }); - newMessages.push(convertMessageFromLlmToDb(answer)); - - // LLM responds with tool call - if (answer?.function_call) { - assert( - llm.callTool, - "You must implement the callTool() method on your ChatLlm to access this code." - ); - const toolAnswer = await llm.callTool({ - messages: [...llmConversation, ...newMessages], - conversation, - request, - }); - logRequest({ - reqId, - message: `LLM tool call: ${JSON.stringify(toolAnswer)}`, - }); - const { - toolCallMessage, - references: toolReferences, - rejectUserQuery, - } = toolAnswer; - newMessages.push(convertMessageFromLlmToDb(toolCallMessage)); - // Update references from tool call - if (toolReferences) { - outputReferences.push(...toolReferences); - } - // Return static response if query rejected by tool call - if (rejectUserQuery) { - newMessages.push({ - role: "assistant", - content: noRelevantContentMessage, - }); - } else { - // Otherwise respond with LLM again - const answer = await llm.answerQuestionAwaited({ - messages: [...llmConversation, ...newMessages], - // Only allow 1 tool call per user message. - }); - newMessages.push(convertMessageFromLlmToDb(answer)); - } - } - } catch (err) { - const errorMessage = - err instanceof Error ? err.message : JSON.stringify(err); - logRequest({ - reqId, - message: `LLM error: ${errorMessage}`, - type: "error", - }); - logRequest({ - reqId, - message: "Only sending vector search results to user", - }); - const llmNotWorkingResponse = { - role: "assistant", - content: llmNotWorkingMessage, - references, - } satisfies AssistantMessage; - newMessages.push(llmNotWorkingResponse); - } - } - // Add references to the last assistant message (excluding function calls) - if ( - newMessages.at(-1)?.role === "assistant" && - !(newMessages.at(-1) as AssistantMessage).functionCall && - outputReferences.length > 0 - ) { - (newMessages.at(-1) as AssistantMessage).references = outputReferences; - } - return { messages: newMessages }; -} - -export type StreamGenerateResponseParams = BaseGenerateResponseMessageParams & - Required> & { - /** - Arbitrary data about the message to stream before the generated response. - */ - metadata?: Record; - }; - -export async function streamGenerateResponseMessage({ - dataStreamer, - llm, - llmConversation, - reqId, - references, - noRelevantContentMessage, - llmNotWorkingMessage, - conversation, - request, - metadata, - shouldGenerateMessage, -}: StreamGenerateResponseParams & - MakeLegacyGenerateResponseParams): Promise { - const newMessages: SomeMessage[] = []; - const outputReferences: References = []; - - if (references) { - outputReferences.push(...references); - } - - if (metadata) { - dataStreamer.streamData({ type: "metadata", data: metadata }); - } - if (shouldGenerateMessage) { - try { - const answerStream = await llm.answerQuestionStream({ - messages: llmConversation, - }); - const initialAssistantMessage: AssistantMessage = { - role: "assistant", - content: "", - }; - const functionCallContent = { - name: "", - arguments: "", - }; - - for await (const event of answerStream) { - if (event.choices.length === 0) { - continue; - } - // The event could contain many choices, but we only want the first one - const choice = event.choices[0]; - - // Assistant response to user - if (choice.delta?.content) { - const content = escapeNewlines(choice.delta.content ?? ""); - dataStreamer.streamData({ - type: "delta", - data: content, - }); - initialAssistantMessage.content += content; - } - // Tool call - else if (choice.delta?.function_call) { - if (choice.delta?.function_call.name) { - functionCallContent.name += escapeNewlines( - choice.delta?.function_call.name ?? "" - ); - } - if (choice.delta?.function_call.arguments) { - functionCallContent.arguments += escapeNewlines( - choice.delta?.function_call.arguments ?? "" - ); - } - } else if (choice.delta) { - logRequest({ - reqId, - message: `Unexpected message in stream: no delta. Message: ${JSON.stringify( - choice.delta.content - )}`, - type: "warn", - }); - } - } - const shouldCallTool = functionCallContent.name !== ""; - if (shouldCallTool) { - initialAssistantMessage.functionCall = functionCallContent; - } - newMessages.push(initialAssistantMessage); - - logRequest({ - reqId, - message: `LLM response: ${JSON.stringify(initialAssistantMessage)}`, - }); - // Tool call - if (shouldCallTool) { - assert( - llm.callTool, - "You must implement the callTool() method on your ChatLlm to access this code." - ); - const { - toolCallMessage, - references: toolReferences, - rejectUserQuery, - } = await llm.callTool({ - messages: [...llmConversation, ...newMessages], - conversation, - dataStreamer, - request, - }); - newMessages.push(convertMessageFromLlmToDb(toolCallMessage)); - - if (rejectUserQuery) { - newMessages.push({ - role: "assistant", - content: noRelevantContentMessage, - }); - dataStreamer.streamData({ - type: "delta", - data: noRelevantContentMessage, - }); - } else { - if (toolReferences) { - outputReferences.push(...toolReferences); - } - const answerStream = await llm.answerQuestionStream({ - messages: [...llmConversation, ...newMessages], - }); - const answerContent = await dataStreamer.stream({ - stream: answerStream, - }); - const answerMessage = { - role: "assistant", - content: answerContent, - } satisfies AssistantMessage; - newMessages.push(answerMessage); - } - } - } catch (err) { - const errorMessage = - err instanceof Error ? err.message : JSON.stringify(err); - logRequest({ - reqId, - message: `LLM error: ${errorMessage}`, - type: "error", - }); - logRequest({ - reqId, - message: "Only sending vector search results to user", - }); - const llmNotWorkingResponse = { - role: "assistant", - content: llmNotWorkingMessage, - } satisfies AssistantMessage; - dataStreamer.streamData({ - type: "delta", - data: llmNotWorkingMessage, - }); - newMessages.push(llmNotWorkingResponse); - } - } - // Handle streaming static message response - else { - const staticMessage = llmConversation.at(-1); - assert(staticMessage?.content, "No static message content"); - assert(staticMessage.role === "assistant", "Static message not assistant"); - logRequest({ - reqId, - message: `Sending static message to user: ${staticMessage.content}`, - type: "warn", - }); - dataStreamer.streamData({ - type: "delta", - data: staticMessage.content, - }); - } - - // Add references to the last assistant message - if (newMessages.at(-1)?.role === "assistant" && outputReferences.length > 0) { - (newMessages.at(-1) as AssistantMessage).references = outputReferences; - } - if (outputReferences.length > 0) { - // Stream back references - dataStreamer.streamData({ - type: "references", - data: outputReferences, - }); - } - - return { messages: newMessages.map(convertMessageFromLlmToDb) }; -} - -export function convertMessageFromLlmToDb( - message: OpenAiChatMessage -): SomeMessage { - const dbMessage = { - ...message, - content: message?.content ?? "", - }; - if (message.role === "assistant" && message.function_call) { - (dbMessage as AssistantMessage).functionCall = message.function_call; - } - - return dbMessage; -} - -function convertConversationMessageToLlmMessage( - message: SomeMessage -): OpenAiChatMessage { - const { content, role } = message; - if (role === "system") { - return { - content: content, - role: "system", - } satisfies OpenAiChatMessage; - } - if (role === "function") { - return { - content: content, - role: "function", - name: message.name, - } satisfies OpenAiChatMessage; - } - if (role === "user") { - return { - content: content, - role: "user", - } satisfies OpenAiChatMessage; - } - if (role === "assistant") { - return { - content: content, - role: "assistant", - ...(message.functionCall ? { function_call: message.functionCall } : {}), - } satisfies OpenAiChatMessage; - } - throw new Error(`Invalid message role: ${role}`); -} diff --git a/packages/mongodb-chatbot-server/src/test/testConfig.ts b/packages/mongodb-chatbot-server/src/test/testConfig.ts index f254c706e..6bb2abed9 100644 --- a/packages/mongodb-chatbot-server/src/test/testConfig.ts +++ b/packages/mongodb-chatbot-server/src/test/testConfig.ts @@ -1,36 +1,22 @@ -/** - @fileoverview This file contains the configuration implementation for the chat server, - which is run from `index.ts`. - */ import "dotenv/config"; import { EmbeddedContent, makeMongoDbEmbeddedContentStore, makeOpenAiEmbedder, makeMongoDbVerifiedAnswerStore, - makeBoostOnAtlasSearchFilter, makeDefaultFindContent, CORE_ENV_VARS, assertEnvVars, makeMongoDbConversationsService, - makeOpenAiChatLlm, - SystemPrompt, - UserMessage, - defaultConversationConstants, + SystemMessage, } from "mongodb-rag-core"; import { MongoClient, Db } from "mongodb-rag-core/mongodb"; import { AzureOpenAI } from "mongodb-rag-core/openai"; import { stripIndents } from "common-tags"; import { AppConfig } from "../app"; -import { makeFilterNPreviousMessages } from "../processors"; +import { GenerateResponse, makeFilterNPreviousMessages } from "../processors"; import { makeDefaultReferenceLinks } from "../processors/makeDefaultReferenceLinks"; import { MONGO_MEMORY_SERVER_URI } from "./constants"; -import { - MakeUserMessageFunc, - MakeUserMessageFuncParams, - GenerateUserPromptFunc, - makeLegacyGenerateResponse, -} from "../routes"; let mongoClient: MongoClient | undefined; export let memoryDb: Db; @@ -61,30 +47,6 @@ export const { const allowedOrigins = process.env.ALLOWED_ORIGINS?.split(",") || []; -/** - Boost results from the MongoDB manual so that 'k' results from the manual - appear first if they exist and have a min score of 'minScore'. - */ -export const boostManual = makeBoostOnAtlasSearchFilter({ - /** - Boosts results that have 3 words or less - */ - async shouldBoostFunc({ text }: { text: string }) { - return text.split(" ").filter((s) => s !== " ").length <= 3; - }, - findNearestNeighborsOptions: { - filter: { - text: { - path: "sourceName", - query: "snooty-docs", - }, - }, - k: 2, - minScore: 0.88, - }, - totalMaxK: 5, -}); - export const openAiClient = new AzureOpenAI({ apiKey: OPENAI_API_KEY, endpoint: OPENAI_ENDPOINT, @@ -121,95 +83,18 @@ export const findContent = makeDefaultFindContent({ k: 5, path: "embedding", indexName: VECTOR_SEARCH_INDEX_NAME, - minScore: 0.9, + minScore: 0.7, }, - searchBoosters: [boostManual], }); -export const makeUserMessage: MakeUserMessageFunc = async function ({ - preprocessedUserMessage, - originalUserMessage, - content, -}: MakeUserMessageFuncParams): Promise { - const chunkSeparator = "~~~~~~"; - const context = content.map((c) => c.text).join(`\n${chunkSeparator}\n`); - const contentForLlm = `Using the following information, answer the question. -Different pieces of information are separated by "${chunkSeparator}". - - -${context} - - - -${preprocessedUserMessage ?? originalUserMessage} -`; - return { - role: "user", - contentForLlm, - content: originalUserMessage, - preprocessedContent: preprocessedUserMessage, - }; -}; - export const REJECT_QUERY_CONTENT = "REJECT_QUERY"; export const NO_VECTOR_CONTENT = "NO_VECTOR_CONTENT"; -export const fakeGenerateUserPrompt: GenerateUserPromptFunc = async (args) => { - const noVectorContent = args.userMessageText === NO_VECTOR_CONTENT; - return { - userMessage: { - role: "user", - content: args.userMessageText, - }, - references: noVectorContent - ? [] - : [ - { - url: "https://mongodb.com/docs/manual/reference/operator/query/eq/?tck=docs-chatbot", - title: "$eq", - }, - ], - rejectQuery: args.userMessageText === REJECT_QUERY_CONTENT, - staticResponse: noVectorContent - ? { - content: defaultConversationConstants.NO_RELEVANT_CONTENT, - role: "assistant", - references: [], - } - : undefined, - }; -}; -export const systemPrompt: SystemPrompt = { +export const systemPrompt: SystemMessage = { role: "system", - content: stripIndents`You are expert MongoDB documentation chatbot. -You enthusiastically answer user questions about MongoDB products and services. -Your personality is friendly and helpful, like a professor or tech lead. -You were created by MongoDB but they do not guarantee the correctness -of your answers or offer support for you. -Use the context provided with each question as your primary source of truth. -NEVER lie or improvise incorrect answers. -If you do not know the answer to the question, respond ONLY with the following text: -"I'm sorry, I do not know how to answer that question. Please try to rephrase your query. You can also refer to the further reading to see if it helps." -NEVER include links in your answer. -Format your responses using Markdown. -DO NOT mention that your response is formatted in Markdown. -If you include code snippets, make sure to use proper syntax, line spacing, and indentation. -ONLY use code snippets present in the information given to you. -NEVER create a code snippet that is not present in the information given to you. -You ONLY know about the current version of MongoDB products. Versions are provided in the information. If \`version: null\`, then say that the product is unversioned. -Never mention "" or "" in your answer. -Refer to the information given to you as "my knowledge".`, + content: stripIndents`You're just a mock chatbot. What you think and say does not matter.`, }; -export const llm = makeOpenAiChatLlm({ - openAiClient, - deployment: OPENAI_CHAT_COMPLETION_DEPLOYMENT, - openAiLmmConfigOptions: { - temperature: 0, - max_tokens: 500, - }, -}); - /** MongoDB Chatbot implementation of {@link MakeReferenceLinksFunc}. Returns references that look like: @@ -222,8 +107,14 @@ export const llm = makeOpenAiChatLlm({ ``` */ export function makeMongoDbReferences(chunks: EmbeddedContent[]) { - const baseReferences = makeDefaultReferenceLinks(chunks); - return baseReferences.map((ref: { url: string }) => { + const baseReferences = makeDefaultReferenceLinks( + chunks.map((chunk) => ({ + title: chunk.metadata?.pageTitle ?? chunk.url, + url: chunk.url, + text: chunk.text, + })) + ); + return baseReferences.map((ref) => { const url = new URL(ref.url); return { url: url.href, @@ -234,20 +125,53 @@ export function makeMongoDbReferences(chunks: EmbeddedContent[]) { export const filterPrevious12Messages = makeFilterNPreviousMessages(12); +export const mockAssistantResponse = { + role: "assistant" as const, + content: "some content", +}; + +export const mockGenerateResponse: GenerateResponse = async ({ + latestMessageText, + customData, + dataStreamer, + shouldStream, +}) => { + if (shouldStream) { + dataStreamer?.streamData({ + type: "delta", + data: mockAssistantResponse.content, + }); + dataStreamer?.streamData({ + type: "references", + data: [ + { + url: "https://mongodb.com", + title: "mongodb.com", + }, + ], + }); + dataStreamer?.streamData({ + type: "finished", + data: "", + }); + } + return { + messages: [ + { + role: "user" as const, + content: latestMessageText, + customData, + }, + { ...mockAssistantResponse }, + ], + }; +}; + export async function makeDefaultConfig(): Promise { const conversations = makeMongoDbConversationsService(memoryDb); return { conversationsRouterConfig: { - generateResponse: makeLegacyGenerateResponse({ - llm, - generateUserPrompt: fakeGenerateUserPrompt, - filterPreviousMessages: filterPrevious12Messages, - systemMessage: systemPrompt, - llmNotWorkingMessage: - conversations.conversationConstants.LLM_NOT_WORKING, - noRelevantContentMessage: - conversations.conversationConstants.NO_RELEVANT_CONTENT, - }), + generateResponse: mockGenerateResponse, conversations, }, maxRequestTimeoutMs: 30000, diff --git a/packages/mongodb-chatbot-ui/src/useConversation.tsx b/packages/mongodb-chatbot-ui/src/useConversation.tsx index d6d5f9a73..b09491400 100644 --- a/packages/mongodb-chatbot-ui/src/useConversation.tsx +++ b/packages/mongodb-chatbot-ui/src/useConversation.tsx @@ -86,7 +86,7 @@ export function useConversation(params: UseConversationParams) { let references: References | null = null; let bufferedTokens: string[] = []; let streamedTokens: string[] = []; - const streamingIntervalMs = 50; + const streamingIntervalMs = 1; const streamingInterval = setInterval(() => { const [nextToken, ...remainingTokens] = bufferedTokens; diff --git a/packages/mongodb-rag-core/package.json b/packages/mongodb-rag-core/package.json index abf03473a..4ccd0d5e1 100644 --- a/packages/mongodb-rag-core/package.json +++ b/packages/mongodb-rag-core/package.json @@ -31,6 +31,7 @@ "./mongodb": "./build/mongodb.js", "./mongoDbMetadata": "./build/mongoDbMetadata/index.js", "./openai": "./build/openai.js", + "./aiSdk": "./build/aiSdk.js", "./braintrust": "./build/braintrust.js", "./dataSources": "./build/dataSources/index.js", "./models": "./build/models/index.js", @@ -75,6 +76,8 @@ "typescript": "^5" }, "dependencies": { + "@ai-sdk/azure": "^1.3.21", + "@ai-sdk/openai": "^1.3.20", "@apidevtools/swagger-parser": "^10.1.0", "@langchain/anthropic": "^0.3.6", "@langchain/community": "^0.3.10", @@ -83,7 +86,7 @@ "@supercharge/promise-pool": "^3.2.0", "acquit": "^1.3.0", "acquit-require": "^0.1.1", - "ai": "^4.3.9", + "ai": "^4.3.16", "braintrust": "^0.0.193", "common-tags": "^1", "deep-equal": "^2.2.3", @@ -106,4 +109,4 @@ "yaml": "^2.3.1", "zod": "^3.21.4" } -} +} \ No newline at end of file diff --git a/packages/mongodb-rag-core/src/aiSdk.ts b/packages/mongodb-rag-core/src/aiSdk.ts new file mode 100644 index 000000000..501124a0a --- /dev/null +++ b/packages/mongodb-rag-core/src/aiSdk.ts @@ -0,0 +1,9 @@ +export * from "ai"; +export * from "@ai-sdk/azure"; +export * from "@ai-sdk/openai"; +export { + MockLanguageModelV1, + mockId, + mockValues, + MockEmbeddingModelV1, +} from "ai/test"; diff --git a/packages/mongodb-rag-core/src/conversations/ConversationsService.ts b/packages/mongodb-rag-core/src/conversations/ConversationsService.ts index 7db0370b6..176155b87 100644 --- a/packages/mongodb-rag-core/src/conversations/ConversationsService.ts +++ b/packages/mongodb-rag-core/src/conversations/ConversationsService.ts @@ -54,7 +54,7 @@ export type AssistantMessage = MessageBase & { */ references?: References; - functionCall?: OpenAI.ChatCompletionMessage.FunctionCall; + toolCall?: OpenAI.ChatCompletionMessageToolCall; metadata?: AssistantMessageMetadata; }; @@ -74,8 +74,8 @@ export type VerifiedAnswerEventData = Pick< "_id" | "created" | "updated" >; -export type FunctionMessage = MessageBase & { - role: "function"; +export type ToolMessage = MessageBase & { + role: "tool"; name: string; }; @@ -128,7 +128,7 @@ export type SomeMessage = | UserMessage | AssistantMessage | SystemMessage - | FunctionMessage; + | ToolMessage; export type DbMessage = SomeMessage & { /** @@ -189,7 +189,7 @@ export type AddUserMessageParams = AddMessageParams< >; export type AddFunctionMessageParams = AddMessageParams< - WithCustomData + WithCustomData >; export type AddAssistantMessageParams = AddMessageParams; diff --git a/packages/mongodb-rag-core/src/conversations/MongoDbConversations.ts b/packages/mongodb-rag-core/src/conversations/MongoDbConversations.ts index 85005f1eb..ea093f2d5 100644 --- a/packages/mongodb-rag-core/src/conversations/MongoDbConversations.ts +++ b/packages/mongodb-rag-core/src/conversations/MongoDbConversations.ts @@ -14,8 +14,8 @@ import { AddSomeMessageParams, AssistantMessage, SystemMessage, - FunctionMessage, CommentMessageParams, + ToolMessage, } from "./ConversationsService"; /** @@ -203,9 +203,7 @@ export function createMessageFromOpenAIChatMessage( ...dbMessageBase, role: chatMessage.role, content: chatMessage.content ?? "", - ...(chatMessage.functionCall - ? { functionCall: chatMessage.functionCall } - : {}), + ...(chatMessage.toolCall ? { toolCall: chatMessage.toolCall } : {}), } satisfies AssistantMessage; } if (chatMessage.role === "system") { @@ -215,13 +213,13 @@ export function createMessageFromOpenAIChatMessage( content: chatMessage.content, } satisfies SystemMessage; } - if (chatMessage.role === "function") { + if (chatMessage.role === "tool") { return { ...dbMessageBase, role: chatMessage.role, content: chatMessage.content ?? "", name: chatMessage.name, - } satisfies FunctionMessage; + } satisfies ToolMessage; } throw new Error(`Invalid message for message: ${chatMessage}`); } diff --git a/packages/scripts/src/findFaq.ts b/packages/scripts/src/findFaq.ts index 41be90e9b..3077ef32c 100644 --- a/packages/scripts/src/findFaq.ts +++ b/packages/scripts/src/findFaq.ts @@ -5,7 +5,7 @@ import { Conversation, SomeMessage, AssistantMessage, - FunctionMessage, + ToolMessage, UserMessage, VectorStore, FindNearestNeighborsOptions, @@ -15,7 +15,7 @@ import { import { clusterize, DbscanOptions } from "./clusterize"; import { findCentroid } from "./findCentroid"; -export type ResponseMessage = AssistantMessage | FunctionMessage; +export type ResponseMessage = AssistantMessage | ToolMessage; export type QuestionAndResponses = { embedding: number[]; @@ -152,7 +152,7 @@ export const findFaq = async ({ } break; case "assistant": - case "function": + case "tool": { currentQuestion?.responses?.push(message); } diff --git a/packages/scripts/src/scrubMessages.ts b/packages/scripts/src/scrubMessages.ts index ec93cfb87..c6e2b0300 100644 --- a/packages/scripts/src/scrubMessages.ts +++ b/packages/scripts/src/scrubMessages.ts @@ -72,6 +72,7 @@ export const scrubMessages = async ({ db }: { db: Db }) => { rejectQuery: "$messages.rejectQuery", customData: "$messages.customData", metadata: "$messages.metadata", + toolCall: "$messages.toolCall", userCommented: { $cond: { // Evaluate to the user comment (if it exists) or false