From bc14d9916c65c4e82bf617c8656d53142a9d0482 Mon Sep 17 00:00:00 2001 From: Nathan Sarrazin Date: Mon, 25 Sep 2023 10:35:09 +0200 Subject: [PATCH 01/12] wip agents --- src/lib/components/chat/ChatMessage.svelte | 27 ++ src/lib/server/database.ts | 4 +- ...ultEndpoint.ts => generateFromEndpoint.ts} | 5 +- src/lib/server/summarize.ts | 4 +- src/lib/server/tools/uploadFile.ts | 18 ++ src/lib/server/websearch/generateQuery.ts | 4 +- src/lib/types/FileMetaData.ts | 7 + src/lib/types/Message.ts | 7 + src/lib/types/MessageUpdate.ts | 8 +- src/routes/conversation/[id]/+page.svelte | 4 +- src/routes/conversation/[id]/+server.ts | 283 +++++++++++++----- .../[id]/output/[sha256]/+server.ts | 60 ++++ .../conversation/[id]/upload/+server.ts | 39 +++ 13 files changed, 385 insertions(+), 85 deletions(-) rename src/lib/server/{generateFromDefaultEndpoint.ts => generateFromEndpoint.ts} (94%) create mode 100644 src/lib/server/tools/uploadFile.ts create mode 100644 src/lib/types/FileMetaData.ts create mode 100644 src/routes/conversation/[id]/output/[sha256]/+server.ts create mode 100644 src/routes/conversation/[id]/upload/+server.ts diff --git a/src/lib/components/chat/ChatMessage.svelte b/src/lib/components/chat/ChatMessage.svelte index 799d0cf9d83..bfa109ddea9 100644 --- a/src/lib/components/chat/ChatMessage.svelte +++ b/src/lib/components/chat/ChatMessage.svelte @@ -152,6 +152,33 @@ class="prose max-w-none dark:prose-invert max-sm:prose-sm prose-headings:font-semibold prose-h1:text-lg prose-h2:text-base prose-h3:text-base prose-pre:bg-gray-800 dark:prose-pre:bg-gray-900" bind:this={contentEl} > + {#if message.files && message.files.length > 0} +
+ {#each message.files as file} +
+ {#if file.mime === "image/jpeg"} + tool output + {:else if file.mime === "audio/wav"} + + {/if} + {#if file.model} + Content generated using {file.model} + {/if} +
+ {/each} +
+
+ {/if} {#each tokens as token} {#if token.type === "code"} diff --git a/src/lib/server/database.ts b/src/lib/server/database.ts index 0925a8a6a3d..a8b3febd0e3 100644 --- a/src/lib/server/database.ts +++ b/src/lib/server/database.ts @@ -1,5 +1,5 @@ import { MONGODB_URL, MONGODB_DB_NAME, MONGODB_DIRECT_CONNECTION } from "$env/static/private"; -import { MongoClient } from "mongodb"; +import { GridFSBucket, MongoClient } from "mongodb"; import type { Conversation } from "$lib/types/Conversation"; import type { SharedConversation } from "$lib/types/SharedConversation"; import type { WebSearch } from "$lib/types/WebSearch"; @@ -29,6 +29,7 @@ const settings = db.collection("settings"); const users = db.collection("users"); const webSearches = db.collection("webSearches"); const messageEvents = db.collection("messageEvents"); +const bucket = new GridFSBucket(db, { bucketName: "toolOutputs" }); export { client, db }; export const collections = { @@ -39,6 +40,7 @@ export const collections = { users, webSearches, messageEvents, + bucket, }; client.on("open", () => { diff --git a/src/lib/server/generateFromDefaultEndpoint.ts b/src/lib/server/generateFromEndpoint.ts similarity index 94% rename from src/lib/server/generateFromDefaultEndpoint.ts rename to src/lib/server/generateFromEndpoint.ts index b65e8d98100..dde420b3bea 100644 --- a/src/lib/server/generateFromDefaultEndpoint.ts +++ b/src/lib/server/generateFromEndpoint.ts @@ -11,8 +11,9 @@ interface Parameters { max_new_tokens: number; stop: string[]; } -export async function generateFromDefaultEndpoint( +export async function generateFromEndpoint( prompt: string, + model?: typeof defaultModel, parameters?: Partial ): Promise { const newParameters = { @@ -21,7 +22,7 @@ export async function generateFromDefaultEndpoint( return_full_text: false, }; - const randomEndpoint = modelEndpoint(defaultModel); + const randomEndpoint = modelEndpoint(model ?? 
defaultModel); const abortController = new AbortController(); diff --git a/src/lib/server/summarize.ts b/src/lib/server/summarize.ts index 3398cebd633..76161c230cc 100644 --- a/src/lib/server/summarize.ts +++ b/src/lib/server/summarize.ts @@ -1,5 +1,5 @@ import { buildPrompt } from "$lib/buildPrompt"; -import { generateFromDefaultEndpoint } from "$lib/server/generateFromDefaultEndpoint"; +import { generateFromEndpoint } from "$lib/server/generateFromEndpoint"; import { defaultModel } from "$lib/server/models"; export async function summarize(prompt: string) { @@ -12,7 +12,7 @@ export async function summarize(prompt: string) { model: defaultModel, }); - const generated_text = await generateFromDefaultEndpoint(summaryPrompt).catch((e) => { + const generated_text = await generateFromEndpoint(summaryPrompt).catch((e) => { console.error(e); return null; }); diff --git a/src/lib/server/tools/uploadFile.ts b/src/lib/server/tools/uploadFile.ts new file mode 100644 index 00000000000..73b1ea5f00b --- /dev/null +++ b/src/lib/server/tools/uploadFile.ts @@ -0,0 +1,18 @@ +import type { Conversation } from "$lib/types/Conversation"; +import { sha256 } from "$lib/utils/sha256"; +import type { Tool } from "@huggingface/agents/src/types"; +import { collections } from "../database"; + +export async function uploadFile(file: Blob, conv: Conversation, tool?: Tool) { + const sha = await sha256(await file.text()); + const filename = `${conv._id}-${sha}`; + + const upload = collections.bucket.openUploadStream(filename, { + metadata: { conversation: conv._id.toString(), model: tool?.model, mime: tool?.mime }, + }); + + upload.write((await file.arrayBuffer()) as unknown as Buffer); + upload.end(); + + return filename; +} diff --git a/src/lib/server/websearch/generateQuery.ts b/src/lib/server/websearch/generateQuery.ts index d812bff4d24..2d26a4ee9dd 100644 --- a/src/lib/server/websearch/generateQuery.ts +++ b/src/lib/server/websearch/generateQuery.ts @@ -1,6 +1,6 @@ import type { Message } 
from "$lib/types/Message"; import { format } from "date-fns"; -import { generateFromDefaultEndpoint } from "../generateFromDefaultEndpoint"; +import { generateFromEndpoint } from "../generateFromEndpoint"; import { defaultModel } from "../models"; export async function generateQuery(messages: Message[]) { @@ -13,7 +13,7 @@ export async function generateQuery(messages: Message[]) { previousMessages: previousUserMessages.map(({ content }) => content).join(" "), currentDate, }); - const searchQuery = await generateFromDefaultEndpoint(promptSearchQuery).then((query) => { + const searchQuery = await generateFromEndpoint(promptSearchQuery).then((query) => { // example of generating google query: // case 1 // user: tell me what happened yesterday diff --git a/src/lib/types/FileMetaData.ts b/src/lib/types/FileMetaData.ts new file mode 100644 index 00000000000..6d681ee9101 --- /dev/null +++ b/src/lib/types/FileMetaData.ts @@ -0,0 +1,7 @@ +export interface FileMetaData { + convId: string; + sha256: string; + createdAt: Date; + model: string; + tool: string; +} diff --git a/src/lib/types/Message.ts b/src/lib/types/Message.ts index 2d092c10f0b..9f1cce8d129 100644 --- a/src/lib/types/Message.ts +++ b/src/lib/types/Message.ts @@ -2,6 +2,12 @@ import type { MessageUpdate } from "./MessageUpdate"; import type { Timestamps } from "./Timestamps"; import type { WebSearch } from "./WebSearch"; +export interface File { + sha256: string; + model?: string; + mime?: string; +} + export type Message = Partial & { from: "user" | "assistant"; id: ReturnType; @@ -10,4 +16,5 @@ export type Message = Partial & { webSearchId?: WebSearch["_id"]; // legacy version webSearch?: WebSearch; score?: -1 | 0 | 1; + files?: File[]; // filenames }; diff --git a/src/lib/types/MessageUpdate.ts b/src/lib/types/MessageUpdate.ts index 613b92e05b8..d9253250a30 100644 --- a/src/lib/types/MessageUpdate.ts +++ b/src/lib/types/MessageUpdate.ts @@ -1,4 +1,5 @@ import type { WebSearchSource } from "./WebSearch"; 
+import type { Update } from "@huggingface/agents/src/types"; export type FinalAnswer = { type: "finalAnswer"; @@ -10,12 +11,9 @@ export type TextStreamUpdate = { token: string; }; -export type AgentUpdate = { +export interface AgentUpdate extends Update { type: "agent"; - agent: string; - content: string; - binary?: Blob; -}; +} export type WebSearchUpdate = { type: "webSearch"; diff --git a/src/routes/conversation/[id]/+page.svelte b/src/routes/conversation/[id]/+page.svelte index 9de3d10aaf5..0819df6eb80 100644 --- a/src/routes/conversation/[id]/+page.svelte +++ b/src/routes/conversation/[id]/+page.svelte @@ -68,7 +68,7 @@ id: messageId, response_id: responseId, is_retry: isRetry, - web_search: $webSearchParameters.useSearch, + tools: $webSearchParameters.useSearch ? ["textToImage", "webSearch", "textToSpeech"] : [], }), }); @@ -128,6 +128,8 @@ } } else if (update.type === "webSearch") { webSearchMessages = [...webSearchMessages, update]; + } else if (update.type === "agent") { + console.log(update); } } catch (parseError) { // in case of parsing error we wait for the next message diff --git a/src/routes/conversation/[id]/+server.ts b/src/routes/conversation/[id]/+server.ts index c3d1b8d0486..45952e02a91 100644 --- a/src/routes/conversation/[id]/+server.ts +++ b/src/routes/conversation/[id]/+server.ts @@ -6,7 +6,7 @@ import { collections } from "$lib/server/database"; import { modelEndpoint } from "$lib/server/modelEndpoint"; import { models } from "$lib/server/models"; import { ERROR_MESSAGES } from "$lib/stores/errors"; -import type { Message } from "$lib/types/Message"; +import type { File, Message } from "$lib/types/Message"; import { trimPrefix } from "$lib/utils/trimPrefix"; import { trimSuffix } from "$lib/utils/trimSuffix"; import { textGenerationStream } from "@huggingface/inference"; @@ -14,11 +14,15 @@ import { error } from "@sveltejs/kit"; import { ObjectId } from "mongodb"; import { z } from "zod"; import { AwsClient } from "aws4fetch"; -import type 
{ MessageUpdate } from "$lib/types/MessageUpdate"; +import type { AgentUpdate, MessageUpdate } from "$lib/types/MessageUpdate"; import { runWebSearch } from "$lib/server/websearch/runWebSearch"; import type { WebSearch } from "$lib/types/WebSearch"; import { abortedGenerations } from "$lib/server/abortedGenerations"; import { summarize } from "$lib/server/summarize"; +import type { TextGenerationStreamOutput } from "@huggingface/inference"; +import { defaultTools, HfChatAgent } from "@huggingface/agents"; +import { uploadFile } from "$lib/server/tools/uploadFile.js"; +import type { Tool } from "@huggingface/agents/src/types.js"; export async function POST({ request, fetch, locals, params, getClientAddress }) { const id = z.string().parse(params.id); @@ -84,14 +88,14 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) response_id: responseId, id: messageId, is_retry, - web_search: webSearch, + tools, } = z .object({ inputs: z.string().trim().min(1), id: z.optional(z.string().uuid()), response_id: z.optional(z.string().uuid()), is_retry: z.optional(z.boolean()), - web_search: z.optional(z.boolean()), + tools: z.array(z.string()), }) .parse(json); @@ -122,6 +126,36 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) ]; })() satisfies Message[]; + // save user prompt + await collections.conversations.updateOne( + { + _id: convId, + }, + { + $set: { + messages, + title: (await summarize(newPrompt)) ?? 
conv.title, + updatedAt: new Date(), + }, + } + ); + + // fetch the endpoint + const randomEndpoint = modelEndpoint(model); + + let usedFetch = fetch; + + if (randomEndpoint.host === "sagemaker") { + const aws = new AwsClient({ + accessKeyId: randomEndpoint.accessKey, + secretAccessKey: randomEndpoint.secretKey, + sessionToken: randomEndpoint.sessionToken, + service: "sagemaker", + }); + + usedFetch = aws.fetch.bind(aws) as typeof fetch; + } + // we now build the stream const stream = new ReadableStream({ async start(controller) { @@ -131,40 +165,33 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) if (newUpdate.type !== "stream") { updates.push(newUpdate); } - controller.enqueue(JSON.stringify(newUpdate) + "\n"); - } - - update({ type: "status", status: "started" }); - - let webSearchResults: WebSearch | undefined; - - if (webSearch) { - webSearchResults = await runWebSearch(conv, newPrompt, update); + try { + controller.enqueue(JSON.stringify(newUpdate) + "\n"); + } catch (e) { + console.error(e); + } } - // we can now build the prompt using the messages - const prompt = await buildPrompt({ - messages, - model, - webSearch: webSearchResults, - preprompt: settings?.customPrompts?.[model.id] ?? 
model.preprompt, - locals: locals, - }); - - // fetch the endpoint - const randomEndpoint = modelEndpoint(model); - - let usedFetch = fetch; - - if (randomEndpoint.host === "sagemaker") { - const aws = new AwsClient({ - accessKeyId: randomEndpoint.accessKey, - secretAccessKey: randomEndpoint.secretKey, - sessionToken: randomEndpoint.sessionToken, - service: "sagemaker", - }); + function getStream(inputs: string) { + if (!conv) { + throw new Error("Conversation not found"); + } - usedFetch = aws.fetch.bind(aws) as typeof fetch; + return textGenerationStream( + { + inputs, + parameters: { + ...models.find((m) => m.id === conv.model)?.parameters, + return_full_text: false, + }, + model: randomEndpoint.url, + accessToken: randomEndpoint.host === "sagemaker" ? undefined : HF_ACCESS_TOKEN, + }, + { + use_cache: false, + fetch: usedFetch, + } + ); } async function saveLast(generated_text: string) { @@ -175,16 +202,6 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) const lastMessage = messages[messages.length - 1]; if (lastMessage) { - // We could also check if PUBLIC_ASSISTANT_MESSAGE_TOKEN is present and use it to slice the text - if (generated_text.startsWith(prompt)) { - generated_text = generated_text.slice(prompt.length); - } - - generated_text = trimSuffix( - trimPrefix(generated_text, "<|startoftext|>"), - PUBLIC_SEP_TOKEN - ).trimEnd(); - // remove the stop tokens for (const stop of [...(model?.parameters?.stop ?? []), "<|endoftext|>"]) { if (generated_text.endsWith(stop)) { @@ -213,33 +230,11 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) } } - const tokenStream = textGenerationStream( - { - parameters: { - ...models.find((m) => m.id === conv.model)?.parameters, - return_full_text: false, - }, - model: randomEndpoint.url, - inputs: prompt, - accessToken: randomEndpoint.host === "sagemaker" ? 
undefined : HF_ACCESS_TOKEN, - }, - { - use_cache: false, - fetch: usedFetch, - } - ); - - for await (const output of tokenStream) { - // if not generated_text is here it means the generation is not done + const streamCallback = async (output: TextGenerationStreamOutput) => { if (!output.generated_text) { // else we get the next token if (!output.token.special) { const lastMessage = messages[messages.length - 1]; - update({ - type: "stream", - token: output.token.text, - }); - // if the last message is not from assistant, it means this is the first token if (lastMessage?.from !== "assistant") { // so we create a new message @@ -250,7 +245,7 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) { from: "assistant", content: output.token.text.trimStart(), - webSearch: webSearchResults, + webSearch: undefined, updates: updates, id: (responseId as Message["id"]) || crypto.randomUUID(), createdAt: new Date(), @@ -259,21 +254,165 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) ]; } else { const date = abortedGenerations.get(convId.toString()); + if (date && date > promptedAt) { saveLast(lastMessage.content); } + if (!output) { - break; + return; } // otherwise we just concatenate tokens lastMessage.content += output.token.text; } + + update({ + type: "stream", + token: output.token.text, + }); } - } else { - saveLast(output.generated_text); } + }; + + const files: File[] = []; + + const webSearchTool: Tool = { + name: "webSearch", + description: + "This tool can be used to search the web for extra information. 
It will return the most relevant paragraphs from the web", + examples: [ + { + prompt: "What are the best restaurants in Paris?", + code: '{"tool" : "imageToText", "input" : "What are the best restaurants in Paris?"}', + tools: ["webSearch"], + }, + { + prompt: "Who is the president of the United States?", + code: '{"tool" : "imageToText", "input" : "Who is the president of the United States?"}', + tools: ["webSearch"], + }, + ], + call: async (input, _) => { + const data = await input; + if (typeof data !== "string") throw "Input must be a string."; + + const results = await runWebSearch(conv, data, update); + return results.context; + }, + }; + + const SDXLTool: Tool = { + name: "textToImage", + description: + "This tool can be used to generate an image from text. It will return the image.", + mime: "image/jpeg", + model: "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0", + examples: [ + { + prompt: "Generate an image of a cat wearing a top hat", + code: '{"tool" : "textToImage", "input" : "a cat wearing a top hat"}', + tools: ["textToImage"], + }, + { + prompt: "Draw a brown dog on a beach", + code: '{"tool" : "textToImage", "input" : "drawing of a brown dog on a beach"}', + tools: ["textToImage"], + }, + ], + call: async (input, inference) => { + const data = await input; + if (typeof data !== "string") throw "Input must be a string."; + + const imageBase = await inference.textToImage( + { + inputs: data, + model: "stabilityai/stable-diffusion-xl-base-1.0", + }, + { wait_for_model: true } + ); + + const imageRefined = await inference.imageToImage( + { + inputs: imageBase, + model: "stabilityai/stable-diffusion-xl-refiner-1.0", + parameters: { + prompt: data, + }, + }, + { + wait_for_model: true, + } + ); + return imageRefined; + }, + }; + + // const listTools = [ + // ...defaultTools.filter((t) => t.name !== "textToImage"), + // webSearchTool, + // SDXLTool, + // ]; + + const listTools = [...defaultTools, webSearchTool]; + + const agent = new 
HfChatAgent({ + accessToken: HF_ACCESS_TOKEN, + llm: getStream, + chatFormat: (inputs: { messages: Message[] }) => + model.chatPromptRender({ + messages: inputs.messages, + preprompt: settings?.customPrompts?.[model.id] ?? model.preprompt, + }), + callbacks: { + onFile: async (file, tool) => { + const filename = await uploadFile(file, conv, tool); + files.push({ + sha256: filename.split("-")[1], + model: tool?.model, + mime: tool?.mime, + }); + }, + onUpdate: async (agentUpdate) => { + update({ ...agentUpdate, type: "agent" } satisfies AgentUpdate); + }, + onStream: streamCallback, + onFinalAnswer: async (answer) => { + saveLast(answer); + }, + }, + chatHistory: messages, + tools: listTools, + }); + + update({ type: "status", status: "started" }); + + try { + await agent.chat(newPrompt); + + const lastMessage = messages[messages.length - 1]; + if (lastMessage && lastMessage.from === "assistant") { + lastMessage.files = files; + } + } catch (e) { + console.error(e); + return new Error((e as Error).message); } + + // let webSearchResults: WebSearch | undefined; + + // if (tools.includes("websearch")) { + // webSearchResults = await runWebSearch(conv, newPrompt, update); + // } + + // // we can now build the prompt using the messages + // const prompt = await buildPrompt({ + // messages, + // model, + // webSearch: webSearchResults, + // preprompt: settings?.customPrompts?.[model.id] ?? 
model.preprompt, + // locals: locals, + // }); }, async cancel() { await collections.conversations.updateOne( diff --git a/src/routes/conversation/[id]/output/[sha256]/+server.ts b/src/routes/conversation/[id]/output/[sha256]/+server.ts new file mode 100644 index 00000000000..1ef6f21ede5 --- /dev/null +++ b/src/routes/conversation/[id]/output/[sha256]/+server.ts @@ -0,0 +1,60 @@ +import { authCondition } from "$lib/server/auth"; +import { collections } from "$lib/server/database"; +import { error } from "@sveltejs/kit"; +import { ObjectId } from "mongodb"; +import { z } from "zod"; +import type { RequestHandler } from "./$types"; + +export const GET: RequestHandler = async ({ locals, params }) => { + const convId = new ObjectId(z.string().parse(params.id)); + const sha256 = z.string().parse(params.sha256); + + const userId = locals.user?._id ?? locals.sessionId; + + // check user + if (!userId) { + throw error(401, "Unauthorized"); + } + + // check if the user has access to the conversation + const conv = await collections.conversations.findOne({ + _id: convId, + ...authCondition(locals), + }); + + if (!conv) { + throw error(404, "Conversation not found"); + } + + const fileId = collections.bucket.find({ filename: `${convId}-${sha256}` }); + let mime; + + const content = await fileId.next().then(async (file) => { + if (!file) { + throw error(404, "File not found"); + } + + if (file.metadata?.conversation !== convId.toString()) { + throw error(403, "You don't have access to this file."); + } + + mime = file.metadata?.mime; + + const fileStream = collections.bucket.openDownloadStream(file._id); + + const fileBuffer = await new Promise((resolve, reject) => { + const chunks: Uint8Array[] = []; + fileStream.on("data", (chunk) => chunks.push(chunk)); + fileStream.on("error", reject); + fileStream.on("end", () => resolve(Buffer.concat(chunks))); + }); + + return fileBuffer; + }); + + return new Response(content, { + headers: { + "Content-Type": mime ?? 
"application/octet-stream", + }, + }); +}; diff --git a/src/routes/conversation/[id]/upload/+server.ts b/src/routes/conversation/[id]/upload/+server.ts new file mode 100644 index 00000000000..e57db0a952d --- /dev/null +++ b/src/routes/conversation/[id]/upload/+server.ts @@ -0,0 +1,39 @@ +import { authCondition } from "$lib/server/auth"; +import { collections } from "$lib/server/database"; +import { uploadFile } from "$lib/server/tools/uploadFile"; +import { error } from "@sveltejs/kit"; +import { ObjectId } from "mongodb"; +import { z } from "zod"; +import type { RequestHandler } from "../$types"; + +export const POST: RequestHandler = async ({ locals, params, request }) => { + const convId = new ObjectId(z.string().parse(params.id)); + const data = await request.formData(); + + const userId = locals.user?._id ?? locals.sessionId; + + // check user + if (!userId) { + throw error(401, "Unauthorized"); + } + + // check if the user has access to the conversation + const conv = await collections.conversations.findOne({ + _id: convId, + ...authCondition(locals), + }); + + if (!conv) { + throw error(404, "Conversation not found"); + } + + const file = data.get("file") as File; + + if (!file) { + throw error(400, "No file provided"); + } + + const filename = await uploadFile(file, conv); + + return new Response(filename); +}; From 7cb9885e07e82cb8c9c0e3b55efbf627aa7086ad Mon Sep 17 00:00:00 2001 From: Nathan Sarrazin Date: Mon, 25 Sep 2023 10:42:32 +0200 Subject: [PATCH 02/12] filter tools --- src/routes/conversation/[id]/+server.ts | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/routes/conversation/[id]/+server.ts b/src/routes/conversation/[id]/+server.ts index 45952e02a91..4ea812cfa09 100644 --- a/src/routes/conversation/[id]/+server.ts +++ b/src/routes/conversation/[id]/+server.ts @@ -1,14 +1,10 @@ import { HF_ACCESS_TOKEN, MESSAGES_BEFORE_LOGIN, RATE_LIMIT } from "$env/static/private"; -import { buildPrompt } from "$lib/buildPrompt"; 
-import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken"; import { authCondition, requiresUser } from "$lib/server/auth"; import { collections } from "$lib/server/database"; import { modelEndpoint } from "$lib/server/modelEndpoint"; import { models } from "$lib/server/models"; import { ERROR_MESSAGES } from "$lib/stores/errors"; import type { File, Message } from "$lib/types/Message"; -import { trimPrefix } from "$lib/utils/trimPrefix"; -import { trimSuffix } from "$lib/utils/trimSuffix"; import { textGenerationStream } from "@huggingface/inference"; import { error } from "@sveltejs/kit"; import { ObjectId } from "mongodb"; @@ -16,7 +12,6 @@ import { z } from "zod"; import { AwsClient } from "aws4fetch"; import type { AgentUpdate, MessageUpdate } from "$lib/types/MessageUpdate"; import { runWebSearch } from "$lib/server/websearch/runWebSearch"; -import type { WebSearch } from "$lib/types/WebSearch"; import { abortedGenerations } from "$lib/server/abortedGenerations"; import { summarize } from "$lib/server/summarize"; import type { TextGenerationStreamOutput } from "@huggingface/inference"; @@ -302,6 +297,7 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) }, }; + // eslint-disable-next-line @typescript-eslint/no-unused-vars const SDXLTool: Tool = { name: "textToImage", description: @@ -382,7 +378,7 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) }, }, chatHistory: messages, - tools: listTools, + tools: listTools.filter((t) => tools.includes(t.name)), }); update({ type: "status", status: "started" }); From 512d7f3140831705b95dcd60cd39f11cc9eb98b8 Mon Sep 17 00:00:00 2001 From: Nathan Sarrazin Date: Tue, 26 Sep 2023 09:33:11 +0200 Subject: [PATCH 03/12] hook updates to collapse --- .../components/OpenWebSearchResults.svelte | 67 +++++++------------ src/lib/components/WebSearchToggle.svelte | 6 +- src/lib/components/chat/ChatMessage.svelte | 29 ++++---- 
src/lib/components/chat/ChatMessages.svelte | 12 ++-- src/lib/components/chat/ChatWindow.svelte | 8 +-- src/lib/server/websearch/runWebSearch.ts | 2 +- src/routes/conversation/[id]/+page.svelte | 13 ++-- src/routes/conversation/[id]/+server.ts | 1 + 8 files changed, 63 insertions(+), 75 deletions(-) diff --git a/src/lib/components/OpenWebSearchResults.svelte b/src/lib/components/OpenWebSearchResults.svelte index aac5fa54141..2dd1f545885 100644 --- a/src/lib/components/OpenWebSearchResults.svelte +++ b/src/lib/components/OpenWebSearchResults.svelte @@ -1,5 +1,5 @@
{/if} Web search + >Agents
@@ -39,48 +42,30 @@
- {#if webSearchMessages.length === 0} + {#if messagesToDisplay.length === 0}
{:else}
    - {#each webSearchMessages as message} - {#if message.messageType === "update"} -
  1. -
    -
    -

    - {message.message} -

    -
    - {#if message.args} -

    - {message.args} -

    - {/if} -
  2. - {:else if message.messageType === "error"} -
  3. -
    - -

    - {message.message} -

    -
    - {#if message.args} -

    - {message.args} -

    - {/if} -
  4. - {/if} + {#each messagesToDisplay as message} +
  5. +
    +
    +

    + {message.message} +

    +
    + {#if message.type === "webSearch" && message.args} +

    + {message.args} +

    + {/if} +
  6. {/each}
{/if} diff --git a/src/lib/components/WebSearchToggle.svelte b/src/lib/components/WebSearchToggle.svelte index 66295e7637c..6e93d891d43 100644 --- a/src/lib/components/WebSearchToggle.svelte +++ b/src/lib/components/WebSearchToggle.svelte @@ -12,15 +12,15 @@ on:keypress={toggle} > -
Search web
+
Agents ✨

- When enabled, the model will try to complement its answer with information queried from the - web. + When enabled, the model might try to use tools to search the web for an answer or to produce + audio and images.

diff --git a/src/lib/components/chat/ChatMessage.svelte b/src/lib/components/chat/ChatMessage.svelte index bfa109ddea9..ed0e45300df 100644 --- a/src/lib/components/chat/ChatMessage.svelte +++ b/src/lib/components/chat/ChatMessage.svelte @@ -16,7 +16,7 @@ import type { Model } from "$lib/types/Model"; import OpenWebSearchResults from "../OpenWebSearchResults.svelte"; - import type { WebSearchUpdate } from "$lib/types/MessageUpdate"; + import type { MessageUpdate, WebSearchUpdate } from "$lib/types/MessageUpdate"; function sanitizeMd(md: string) { let ret = md @@ -48,8 +48,9 @@ export let readOnly = false; export let isTapped = false; - export let webSearchMessages: WebSearchUpdate[]; + export let updateMessages: MessageUpdate[]; + console.log(updateMessages); const dispatch = createEventDispatcher<{ retry: { content: string; id: Message["id"] }; vote: { score: Message["score"]; id: Message["id"] }; @@ -104,23 +105,21 @@ } }); - let searchUpdates: WebSearchUpdate[] = []; - - $: searchUpdates = ((webSearchMessages.length > 0 - ? webSearchMessages - : message.updates?.filter(({ type }) => type === "webSearch")) ?? []) as WebSearchUpdate[]; - $: downloadLink = message.from === "user" ? `${$page.url.pathname}/message/${message.id}/prompt` : undefined; let webSearchIsDone = true; $: webSearchIsDone = - searchUpdates.length > 0 && searchUpdates[searchUpdates.length - 1].messageType === "sources"; + updateMessages.length > 0 && updateMessages[updateMessages.length - 1].type === "finalAnswer"; $: webSearchSources = - searchUpdates && - searchUpdates?.filter(({ messageType }) => messageType === "sources")?.[0]?.sources; + updateMessages && + (updateMessages?.filter(({ type }) => type === "webSearch") as WebSearchUpdate[]).filter( + ({ messageType }) => messageType === "sources" + )?.[0]?.sources; + + console.log(updateMessages); {#if message.from === "assistant"} @@ -137,14 +136,14 @@
- {#if searchUpdates && searchUpdates.length > 0} + {#if updateMessages && updateMessages.length > 0} {/if} - {#if !message.content && (webSearchIsDone || (webSearchMessages && webSearchMessages.length === 0))} + {#if !message.content && (webSearchIsDone || (updateMessages && updateMessages.length === 0))} {/if} diff --git a/src/lib/components/chat/ChatMessages.svelte b/src/lib/components/chat/ChatMessages.svelte index e46a41f74bf..227f8794f95 100644 --- a/src/lib/components/chat/ChatMessages.svelte +++ b/src/lib/components/chat/ChatMessages.svelte @@ -8,7 +8,7 @@ import type { LayoutData } from "../../../routes/$types"; import ChatIntroduction from "./ChatIntroduction.svelte"; import ChatMessage from "./ChatMessage.svelte"; - import type { WebSearchUpdate } from "$lib/types/MessageUpdate"; + import type { MessageUpdate } from "$lib/types/MessageUpdate"; import { browser } from "$app/environment"; export let messages: Message[]; @@ -22,7 +22,7 @@ let chatContainer: HTMLElement; - export let webSearchMessages: WebSearchUpdate[] = []; + export let updateMessages: MessageUpdate[] = []; async function scrollToBottom() { await tick(); @@ -37,7 +37,7 @@
@@ -48,7 +48,9 @@ {isAuthor} {readOnly} model={currentModel} - webSearchMessages={i === messages.length - 1 ? webSearchMessages : []} + updateMessages={!message.updates && i === messages.length - 1 + ? updateMessages + : message.updates ?? []} on:retry on:vote /> @@ -59,7 +61,7 @@ {/if}
diff --git a/src/lib/components/chat/ChatWindow.svelte b/src/lib/components/chat/ChatWindow.svelte index aefe926bc6c..70532ac985f 100644 --- a/src/lib/components/chat/ChatWindow.svelte +++ b/src/lib/components/chat/ChatWindow.svelte @@ -14,7 +14,7 @@ import type { LayoutData } from "../../../routes/$types"; import WebSearchToggle from "../WebSearchToggle.svelte"; import LoginModal from "../LoginModal.svelte"; - import type { WebSearchUpdate } from "$lib/types/MessageUpdate"; + import type { MessageUpdate } from "$lib/types/MessageUpdate"; export let messages: Message[] = []; export let loading = false; @@ -23,7 +23,7 @@ export let currentModel: Model; export let models: Model[]; export let settings: LayoutData["settings"]; - export let webSearchMessages: WebSearchUpdate[] = []; + export let updateMessages: MessageUpdate[] = []; export let loginRequired = false; $: isReadOnly = !models.some((model) => model.id === currentModel.id); @@ -58,7 +58,7 @@ {messages} readOnly={isReadOnly} isAuthor={!shared} - {webSearchMessages} + {updateMessages} on:message on:vote on:retry={(ev) => { @@ -135,7 +135,7 @@ type="button" on:click={() => dispatch("share")} > - +
Share this conversation
{/if} diff --git a/src/lib/server/websearch/runWebSearch.ts b/src/lib/server/websearch/runWebSearch.ts index e0c62264615..d57f949e876 100644 --- a/src/lib/server/websearch/runWebSearch.ts +++ b/src/lib/server/websearch/runWebSearch.ts @@ -63,7 +63,7 @@ export async function runWebSearch( text = await parseWeb(link); appendUpdate("Browsing webpage", [link]); } catch (e) { - console.error(`Error parsing webpage "${link}"`, e); + // console.error(`Error parsing webpage "${link}"`, e); } const MAX_N_CHUNKS = 100; const texts = chunk(text, CHUNK_CAR_LEN).slice(0, MAX_N_CHUNKS); diff --git a/src/routes/conversation/[id]/+page.svelte b/src/routes/conversation/[id]/+page.svelte index 0819df6eb80..6f438dcaa6d 100644 --- a/src/routes/conversation/[id]/+page.svelte +++ b/src/routes/conversation/[id]/+page.svelte @@ -14,7 +14,7 @@ import { webSearchParameters } from "$lib/stores/webSearchParameters"; import type { Message } from "$lib/types/Message"; import { PUBLIC_APP_DISCLAIMER } from "$env/static/public"; - import type { MessageUpdate, WebSearchUpdate } from "$lib/types/MessageUpdate"; + import type { MessageUpdate } from "$lib/types/MessageUpdate"; export let data; @@ -22,7 +22,7 @@ let lastLoadedMessages = data.messages; let isAborted = false; - let webSearchMessages: WebSearchUpdate[] = []; + let updateMessages: MessageUpdate[] = []; // Since we modify the messages array locally, we don't want to reset it if an old version is passed $: if (data.messages !== lastLoadedMessages) { @@ -110,6 +110,7 @@ try { let update = JSON.parse(el) as MessageUpdate; if (update.type === "finalAnswer") { + updateMessages = [...updateMessages, update]; finalAnswer = update.text; invalidate(UrlDependency.Conversation); } else if (update.type === "stream") { @@ -127,9 +128,9 @@ messages = [...messages]; } } else if (update.type === "webSearch") { - webSearchMessages = [...webSearchMessages, update]; + updateMessages = [...updateMessages, update]; } else if (update.type === "agent") { - 
console.log(update); + updateMessages = [...updateMessages, update]; } } catch (parseError) { // in case of parsing error we wait for the next message @@ -140,7 +141,7 @@ } // reset the websearchmessages - webSearchMessages = []; + updateMessages = []; await invalidate(UrlDependency.ConversationList); } catch (err) { @@ -220,7 +221,7 @@ {loading} {pending} {messages} - bind:webSearchMessages + bind:updateMessages on:message={(event) => writeMessage(event.detail)} on:retry={(event) => writeMessage(event.detail.content, event.detail.id)} on:vote={(event) => voteMessage(event.detail.score, event.detail.id)} diff --git a/src/routes/conversation/[id]/+server.ts b/src/routes/conversation/[id]/+server.ts index 4ea812cfa09..2f1a59d4139 100644 --- a/src/routes/conversation/[id]/+server.ts +++ b/src/routes/conversation/[id]/+server.ts @@ -374,6 +374,7 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) }, onStream: streamCallback, onFinalAnswer: async (answer) => { + update({ type: "finalAnswer", text: answer }); saveLast(answer); }, }, From bdeacc03019d9a67903d2ad10506202897a584f6 Mon Sep 17 00:00:00 2001 From: Nathan Sarrazin Date: Tue, 26 Sep 2023 10:07:39 +0200 Subject: [PATCH 04/12] remove logs --- src/lib/components/chat/ChatMessage.svelte | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/lib/components/chat/ChatMessage.svelte b/src/lib/components/chat/ChatMessage.svelte index ed0e45300df..1db5696dc39 100644 --- a/src/lib/components/chat/ChatMessage.svelte +++ b/src/lib/components/chat/ChatMessage.svelte @@ -50,7 +50,6 @@ export let updateMessages: MessageUpdate[]; - console.log(updateMessages); const dispatch = createEventDispatcher<{ retry: { content: string; id: Message["id"] }; vote: { score: Message["score"]; id: Message["id"] }; @@ -118,8 +117,6 @@ (updateMessages?.filter(({ type }) => type === "webSearch") as WebSearchUpdate[]).filter( ({ messageType }) => messageType === "sources" )?.[0]?.sources; - - 
console.log(updateMessages); {#if message.from === "assistant"} From b55e5f7256199d078a0dfb226b28c93c74f52085 Mon Sep 17 00:00:00 2001 From: Nathan Sarrazin Date: Tue, 26 Sep 2023 10:22:59 +0200 Subject: [PATCH 05/12] message creation occurs outside of streaming --- src/routes/conversation/[id]/+server.ts | 52 +++++++++++-------------- 1 file changed, 22 insertions(+), 30 deletions(-) diff --git a/src/routes/conversation/[id]/+server.ts b/src/routes/conversation/[id]/+server.ts index 2f1a59d4139..9dbe10e65c1 100644 --- a/src/routes/conversation/[id]/+server.ts +++ b/src/routes/conversation/[id]/+server.ts @@ -96,7 +96,7 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) // get the list of messages // while checking for retries - let messages = (() => { + const messages = (() => { if (is_retry && messageId) { // if the message is a retry, replace the message and remove the messages after it let retryMessageIdx = conv.messages.findIndex((message) => message.id === messageId); @@ -156,6 +156,16 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) async start(controller) { const updates: MessageUpdate[] = []; + messages.push({ + from: "assistant", + content: "", + webSearch: undefined, + updates: updates, + id: (responseId as Message["id"]) || crypto.randomUUID(), + createdAt: new Date(), + updatedAt: new Date(), + }); + function update(newUpdate: MessageUpdate) { if (newUpdate.type !== "stream") { updates.push(newUpdate); @@ -231,37 +241,19 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) if (!output.token.special) { const lastMessage = messages[messages.length - 1]; // if the last message is not from assistant, it means this is the first token - if (lastMessage?.from !== "assistant") { - // so we create a new message - messages = [ - ...messages, - // id doesn't match the backend id but it's not important for assistant messages - // First token has a space at the 
beginning, trim it - { - from: "assistant", - content: output.token.text.trimStart(), - webSearch: undefined, - updates: updates, - id: (responseId as Message["id"]) || crypto.randomUUID(), - createdAt: new Date(), - updatedAt: new Date(), - }, - ]; - } else { - const date = abortedGenerations.get(convId.toString()); - - if (date && date > promptedAt) { - saveLast(lastMessage.content); - } - - if (!output) { - return; - } - - // otherwise we just concatenate tokens - lastMessage.content += output.token.text; + const date = abortedGenerations.get(convId.toString()); + + if (date && date > promptedAt) { + saveLast(lastMessage.content); + } + + if (!output) { + return; } + // otherwise we just concatenate tokens + lastMessage.content += output.token.text; + update({ type: "stream", token: output.token.text, From 711c63fc07e4d9a289a31b6d1182068bd85eeadd Mon Sep 17 00:00:00 2001 From: Nathan Sarrazin Date: Tue, 26 Sep 2023 16:57:37 +0200 Subject: [PATCH 06/12] simplified scope a lot --- .../components/OpenWebSearchResults.svelte | 2 +- src/lib/components/WebSearchToggle.svelte | 52 ++++++++++++------ src/lib/components/chat/ChatMessage.svelte | 25 ++------- src/lib/components/chat/ChatMessages.svelte | 9 ---- src/lib/components/chat/ChatWindow.svelte | 2 - src/lib/server/tools/uploadFile.ts | 9 +++- src/lib/stores/webSearchParameters.ts | 2 + src/lib/types/MessageUpdate.ts | 9 +++- src/routes/conversation/[id]/+page.svelte | 42 ++++++++------- src/routes/conversation/[id]/+server.ts | 54 +++++++------------ 10 files changed, 98 insertions(+), 108 deletions(-) diff --git a/src/lib/components/OpenWebSearchResults.svelte b/src/lib/components/OpenWebSearchResults.svelte index 2dd1f545885..9fa74b07a40 100644 --- a/src/lib/components/OpenWebSearchResults.svelte +++ b/src/lib/components/OpenWebSearchResults.svelte @@ -34,7 +34,7 @@ {/if} Agents + >Tools
diff --git a/src/lib/components/WebSearchToggle.svelte b/src/lib/components/WebSearchToggle.svelte index 6e93d891d43..8f8a4a0bdb4 100644 --- a/src/lib/components/WebSearchToggle.svelte +++ b/src/lib/components/WebSearchToggle.svelte @@ -3,25 +3,47 @@ import CarbonInformation from "~icons/carbon/information"; import Switch from "./Switch.svelte"; - const toggle = () => ($webSearchParameters.useSearch = !$webSearchParameters.useSearch); + const toggleWebSearch = () => ($webSearchParameters.useSearch = !$webSearchParameters.useSearch); + const toggleSDXL = () => ($webSearchParameters.useSDXL = !$webSearchParameters.useSDXL);
- -
Agents ✨
-
- -
-

- When enabled, the model might try to use tools to search the web for an answer or to produce - audio and images. -

+
+ +
Web Search
+
+ +
+

+ When enabled, the request will be completed with relevant context fetched from the web. +

+
+
+
+
+ +
SDXL Images
+
+ +
+

+ When enabled, the model will try to generate images to go along with the answers. +

+
diff --git a/src/lib/components/chat/ChatMessage.svelte b/src/lib/components/chat/ChatMessage.svelte index 1db5696dc39..f200b9daac4 100644 --- a/src/lib/components/chat/ChatMessage.svelte +++ b/src/lib/components/chat/ChatMessage.svelte @@ -56,8 +56,6 @@ }>(); let contentEl: HTMLElement; - let loadingEl: IconLoading; - let pendingTimeout: ReturnType; const renderer = new marked.Renderer(); // For code blocks with simple backticks @@ -87,23 +85,6 @@ $: tokens = marked.lexer(sanitizeMd(message.content)); - afterUpdate(() => { - loadingEl?.$destroy(); - clearTimeout(pendingTimeout); - - // Add loading animation to the last message if update takes more than 600ms - if (loading) { - pendingTimeout = setTimeout(() => { - if (contentEl) { - loadingEl = new IconLoading({ - target: deepestChild(contentEl), - props: { classNames: "loading inline ml-2" }, - }); - } - }, 600); - } - }); - $: downloadLink = message.from === "user" ? `${$page.url.pathname}/message/${message.id}/prompt` : undefined; @@ -133,7 +114,7 @@
- {#if updateMessages && updateMessages.length > 0} + {#if updateMessages && updateMessages.filter(({ type }) => type === "agent").length > 0} {#each message.files as file}
- {#if file.mime === "image/jpeg"} + {#if file.mime?.startsWith("image")} tool output - {:else if file.mime === "audio/wav"} + {:else if file.mime?.startsWith("audio")} diff --git a/src/lib/components/chat/ChatMessages.svelte b/src/lib/components/chat/ChatMessages.svelte index 227f8794f95..3cf6f0b1b77 100644 --- a/src/lib/components/chat/ChatMessages.svelte +++ b/src/lib/components/chat/ChatMessages.svelte @@ -3,7 +3,6 @@ import { snapScrollToBottom } from "$lib/actions/snapScrollToBottom"; import ScrollToBottomBtn from "$lib/components/ScrollToBottomBtn.svelte"; import { tick } from "svelte"; - import { randomUUID } from "$lib/utils/randomUuid"; import type { Model } from "$lib/types/Model"; import type { LayoutData } from "../../../routes/$types"; import ChatIntroduction from "./ChatIntroduction.svelte"; @@ -13,7 +12,6 @@ export let messages: Message[]; export let loading: boolean; - export let pending: boolean; export let isAuthor: boolean; export let currentModel: Model; export let settings: LayoutData["settings"]; @@ -57,13 +55,6 @@ {:else} {/each} - {#if pending} - - {/if}
{ const sha = await sha256(await file.text()); const filename = `${conv._id}-${sha}`; @@ -14,5 +14,10 @@ export async function uploadFile(file: Blob, conv: Conversation, tool?: Tool) { upload.write((await file.arrayBuffer()) as unknown as Buffer); upload.end(); - return filename; + // only return the filename when upload throws a finish event or a 10s time out occurs + return new Promise((resolve, reject) => { + upload.once("finish", () => resolve(filename)); + upload.once("error", reject); + setTimeout(() => reject(new Error("Upload timed out")), 10000); + }); } diff --git a/src/lib/stores/webSearchParameters.ts b/src/lib/stores/webSearchParameters.ts index fd088a60621..868eeb92d31 100644 --- a/src/lib/stores/webSearchParameters.ts +++ b/src/lib/stores/webSearchParameters.ts @@ -1,9 +1,11 @@ import { writable } from "svelte/store"; export interface WebSearchParameters { useSearch: boolean; + useSDXL: boolean; nItems: number; } export const webSearchParameters = writable({ useSearch: false, + useSDXL: false, nItems: 5, }); diff --git a/src/lib/types/MessageUpdate.ts b/src/lib/types/MessageUpdate.ts index d9253250a30..a92617575c5 100644 --- a/src/lib/types/MessageUpdate.ts +++ b/src/lib/types/MessageUpdate.ts @@ -1,3 +1,4 @@ +import type { File } from "./Message"; import type { WebSearchSource } from "./WebSearch"; import type { Update } from "@huggingface/agents/src/types"; @@ -29,9 +30,15 @@ export type StatusUpdate = { message?: string; }; +export type FileUpdate = { + type: "file"; + file: File; +}; + export type MessageUpdate = | FinalAnswer | TextStreamUpdate | AgentUpdate | WebSearchUpdate - | StatusUpdate; + | StatusUpdate + | FileUpdate; diff --git a/src/routes/conversation/[id]/+page.svelte b/src/routes/conversation/[id]/+page.svelte index 6f438dcaa6d..9dde7b604d5 100644 --- a/src/routes/conversation/[id]/+page.svelte +++ b/src/routes/conversation/[id]/+page.svelte @@ -31,7 +31,6 @@ } let loading = false; - let pending = false; let loginRequired = false; 
// this function is used to send new message to the backends @@ -41,7 +40,6 @@ try { isAborted = false; loading = true; - pending = true; // first we check if the messageId already exists, indicating a retry @@ -58,8 +56,20 @@ { from: "user", content: message, id: messageId }, ]; + messages = [...messages, { from: "assistant", id: randomUUID(), content: "", files: [] }]; + const responseId = randomUUID(); + const toolsToBeUsed = []; + + if ($webSearchParameters.useSearch) { + toolsToBeUsed.push("webSearch"); + } + + if ($webSearchParameters.useSDXL) { + toolsToBeUsed.push("textToImage"); + } + const response = await fetch(`${base}/conversation/${$page.params.id}`, { method: "POST", headers: { "Content-Type": "application/json" }, @@ -68,7 +78,7 @@ id: messageId, response_id: responseId, is_retry: isRetry, - tools: $webSearchParameters.useSearch ? ["textToImage", "webSearch", "textToSpeech"] : [], + tools: toolsToBeUsed, }), }); @@ -83,7 +93,7 @@ // this is a bit ugly // we read the stream until we get the final answer - while (finalAnswer === "") { + while (finalAnswer === "" && !isAborted) { // await new Promise((r) => setTimeout(r, 25)); // check for abort @@ -112,25 +122,17 @@ if (update.type === "finalAnswer") { updateMessages = [...updateMessages, update]; finalAnswer = update.text; - invalidate(UrlDependency.Conversation); } else if (update.type === "stream") { - pending = false; - let lastMessage = messages[messages.length - 1]; - - if (lastMessage.from !== "assistant") { - messages = [ - ...messages, - { from: "assistant", id: randomUUID(), content: update.token }, - ]; - } else { - lastMessage.content += update.token; - messages = [...messages]; - } + lastMessage.content += update.token; + messages = [...messages]; } else if (update.type === "webSearch") { updateMessages = [...updateMessages, update]; } else if (update.type === "agent") { updateMessages = [...updateMessages, update]; + } else if (update.type === "file") { + messages[messages.length - 
1].files?.push(update.file); + messages = [...messages]; } } catch (parseError) { // in case of parsing error we wait for the next message @@ -142,7 +144,6 @@ // reset the websearchmessages updateMessages = []; - await invalidate(UrlDependency.ConversationList); } catch (err) { if (err instanceof Error && err.message.includes("overloaded")) { @@ -157,7 +158,9 @@ console.error(err); } finally { loading = false; - pending = false; + // wait 500ms + await new Promise((r) => setTimeout(r, 500)); + invalidate(UrlDependency.Conversation); } } @@ -219,7 +222,6 @@ writeMessage(event.detail)} diff --git a/src/routes/conversation/[id]/+server.ts b/src/routes/conversation/[id]/+server.ts index 9dbe10e65c1..355a5fb6c4f 100644 --- a/src/routes/conversation/[id]/+server.ts +++ b/src/routes/conversation/[id]/+server.ts @@ -4,7 +4,7 @@ import { collections } from "$lib/server/database"; import { modelEndpoint } from "$lib/server/modelEndpoint"; import { models } from "$lib/server/models"; import { ERROR_MESSAGES } from "$lib/stores/errors"; -import type { File, Message } from "$lib/types/Message"; +import type { Message } from "$lib/types/Message"; import { textGenerationStream } from "@huggingface/inference"; import { error } from "@sveltejs/kit"; import { ObjectId } from "mongodb"; @@ -161,11 +161,14 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) content: "", webSearch: undefined, updates: updates, + files: [], id: (responseId as Message["id"]) || crypto.randomUUID(), createdAt: new Date(), updatedAt: new Date(), }); + const lastMessage = messages[messages.length - 1]; + function update(newUpdate: MessageUpdate) { if (newUpdate.type !== "stream") { updates.push(newUpdate); @@ -173,7 +176,12 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) try { controller.enqueue(JSON.stringify(newUpdate) + "\n"); } catch (e) { - console.error(e); + try { + stream.cancel(); + } catch (f) { + console.error(f); + // ignore + } 
} } @@ -188,6 +196,7 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) parameters: { ...models.find((m) => m.id === conv.model)?.parameters, return_full_text: false, + max_new_tokens: 4000, }, model: randomEndpoint.url, accessToken: randomEndpoint.host === "sagemaker" ? undefined : HF_ACCESS_TOKEN, @@ -204,8 +213,6 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) throw new Error("Conversation not found"); } - const lastMessage = messages[messages.length - 1]; - if (lastMessage) { // remove the stop tokens for (const stop of [...(model?.parameters?.stop ?? []), "<|endoftext|>"]) { @@ -228,10 +235,7 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) } ); - update({ - type: "finalAnswer", - text: generated_text, - }); + update({ type: "finalAnswer", text: generated_text }); } } @@ -239,7 +243,6 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) if (!output.generated_text) { // else we get the next token if (!output.token.special) { - const lastMessage = messages[messages.length - 1]; // if the last message is not from assistant, it means this is the first token const date = abortedGenerations.get(convId.toString()); @@ -262,8 +265,6 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) } }; - const files: File[] = []; - const webSearchTool: Tool = { name: "webSearch", description: @@ -355,11 +356,14 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) callbacks: { onFile: async (file, tool) => { const filename = await uploadFile(file, conv, tool); - files.push({ + + const fileObject = { sha256: filename.split("-")[1], model: tool?.model, mime: tool?.mime, - }); + }; + lastMessage.files?.push(fileObject); + update({ type: "file", file: fileObject }); }, onUpdate: async (agentUpdate) => { update({ ...agentUpdate, type: "agent" } satisfies AgentUpdate); @@ -370,38 +374,16 @@ 
export async function POST({ request, fetch, locals, params, getClientAddress }) saveLast(answer); }, }, - chatHistory: messages, + chatHistory: [...messages], tools: listTools.filter((t) => tools.includes(t.name)), }); - update({ type: "status", status: "started" }); - try { await agent.chat(newPrompt); - - const lastMessage = messages[messages.length - 1]; - if (lastMessage && lastMessage.from === "assistant") { - lastMessage.files = files; - } } catch (e) { console.error(e); return new Error((e as Error).message); } - - // let webSearchResults: WebSearch | undefined; - - // if (tools.includes("websearch")) { - // webSearchResults = await runWebSearch(conv, newPrompt, update); - // } - - // // we can now build the prompt using the messages - // const prompt = await buildPrompt({ - // messages, - // model, - // webSearch: webSearchResults, - // preprompt: settings?.customPrompts?.[model.id] ?? model.preprompt, - // locals: locals, - // }); }, async cancel() { await collections.conversations.updateOne( From beebf044954bec54248ec00af5ff050dfedffbb6 Mon Sep 17 00:00:00 2001 From: Nathan Sarrazin Date: Tue, 26 Sep 2023 16:59:36 +0200 Subject: [PATCH 07/12] lint --- src/lib/components/chat/ChatWindow.svelte | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/components/chat/ChatWindow.svelte b/src/lib/components/chat/ChatWindow.svelte index a3270bebc52..7314690a1a4 100644 --- a/src/lib/components/chat/ChatWindow.svelte +++ b/src/lib/components/chat/ChatWindow.svelte @@ -133,7 +133,7 @@ type="button" on:click={() => dispatch("share")} > - +
Share this conversation
{/if} From 773fdbfae82ecf94ee3dd74991a8d1999eecf62c Mon Sep 17 00:00:00 2001 From: Nathan Sarrazin Date: Wed, 27 Sep 2023 13:05:51 +0200 Subject: [PATCH 08/12] moved tool definition to .env & made tools modular --- .env | 1 + src/lib/components/WebSearchToggle.svelte | 73 +++++---- src/lib/components/chat/ChatMessage.svelte | 3 +- src/lib/components/chat/ChatWindow.svelte | 7 +- src/lib/server/tools.ts | 58 ++++++++ src/routes/+layout.server.ts | 8 +- src/routes/conversation/[id]/+server.ts | 164 ++++++++++----------- 7 files changed, 189 insertions(+), 125 deletions(-) create mode 100644 src/lib/server/tools.ts diff --git a/.env b/.env index 2a6736ae0cd..8732bcbb06f 100644 --- a/.env +++ b/.env @@ -87,6 +87,7 @@ PUBLIC_APP_COLOR=blue # can be any of tailwind colors: https://tailwindcss.com/d PUBLIC_APP_DATA_SHARING=#set to 1 to enable options & text regarding data sharing PUBLIC_APP_DISCLAIMER=#set to 1 to show a disclaimer on login page +TOOLS = [] # PUBLIC_APP_NAME=HuggingChat # PUBLIC_APP_ASSETS=huggingchat # PUBLIC_APP_COLOR=yellow diff --git a/src/lib/components/WebSearchToggle.svelte b/src/lib/components/WebSearchToggle.svelte index 8f8a4a0bdb4..6475e747b61 100644 --- a/src/lib/components/WebSearchToggle.svelte +++ b/src/lib/components/WebSearchToggle.svelte @@ -3,6 +3,11 @@ import CarbonInformation from "~icons/carbon/information"; import Switch from "./Switch.svelte"; + export let tools: { + webSearch: boolean; + textToImage: boolean; + }; + const toggleWebSearch = () => ($webSearchParameters.useSearch = !$webSearchParameters.useSearch); const toggleSDXL = () => ($webSearchParameters.useSDXL = !$webSearchParameters.useSDXL); @@ -10,40 +15,44 @@
-
- -
Web Search
-
- -
-

- When enabled, the request will be completed with relevant context fetched from the web. -

+ {#if tools.webSearch} +
+ +
Web Search
+
+ +
+

+ When enabled, the request will be completed with relevant context fetched from the web. +

+
-
-
- -
SDXL Images
-
- -
-

- When enabled, the model will try to generate images to go along with the answers. -

+ {/if} + {#if tools.textToImage} +
+ +
SDXL Images
+
+ +
+

+ When enabled, the model will try to generate images to go along with the answers. +

+
-
+ {/if}
diff --git a/src/lib/components/chat/ChatMessage.svelte b/src/lib/components/chat/ChatMessage.svelte index f200b9daac4..84ba96886db 100644 --- a/src/lib/components/chat/ChatMessage.svelte +++ b/src/lib/components/chat/ChatMessage.svelte @@ -2,8 +2,7 @@ import { marked } from "marked"; import markedKatex from "marked-katex-extension"; import type { Message } from "$lib/types/Message"; - import { afterUpdate, createEventDispatcher } from "svelte"; - import { deepestChild } from "$lib/utils/deepestChild"; + import { createEventDispatcher } from "svelte"; import { page } from "$app/stores"; import CodeBlock from "../CodeBlock.svelte"; diff --git a/src/lib/components/chat/ChatWindow.svelte b/src/lib/components/chat/ChatWindow.svelte index 7314690a1a4..be77227d46c 100644 --- a/src/lib/components/chat/ChatWindow.svelte +++ b/src/lib/components/chat/ChatWindow.svelte @@ -42,6 +42,7 @@ dispatch("message", message); message = ""; }; + const showTools = settings?.tools.webSearch || settings?.tools.textToImage;
@@ -67,12 +68,12 @@ class="dark:via-gray-80 pointer-events-none absolute inset-x-0 bottom-0 z-0 mx-auto flex w-full max-w-3xl flex-col items-center justify-center bg-gradient-to-t from-white via-white/80 to-white/0 px-3.5 py-4 dark:border-gray-800 dark:from-gray-900 dark:to-gray-900/0 max-md:border-t max-md:bg-white max-md:dark:bg-gray-900 sm:px-5 md:py-8 xl:max-w-4xl [&>*]:pointer-events-auto" >
- {#if settings?.searchEnabled} - + {#if showTools} + {/if} {#if loading} dispatch("stop")} /> {/if} diff --git a/src/lib/server/tools.ts b/src/lib/server/tools.ts new file mode 100644 index 00000000000..f333a2bcc91 --- /dev/null +++ b/src/lib/server/tools.ts @@ -0,0 +1,58 @@ +import { SERPAPI_KEY, SERPER_API_KEY, TOOLS } from "$env/static/private"; +import { z } from "zod"; + +const webSearchTool = z.object({ + name: z.literal("webSearch"), + key: z.union([ + z.object({ + type: z.literal("serpapi"), + apiKey: z.string().min(1).default(SERPAPI_KEY), + }), + z.object({ + type: z.literal("serper"), + apiKey: z.string().min(1).default(SERPER_API_KEY), + }), + ]), +}); + +const textToImageTool = z.object({ + name: z.literal("textToImage"), + model: z.string().min(1).default("stabilityai/stable-diffusion-xl-base-1.0"), + parameters: z.optional( + z.object({ + negative_prompt: z.string().optional(), + height: z.number().optional(), + width: z.number().optional(), + num_inference_steps: z.number().optional(), + guidance_scale: z.number().optional(), + }) + ), +}); + +const toolsDefinition = z.array(z.discriminatedUnion("name", [webSearchTool, textToImageTool])); + +export const tools = toolsDefinition.parse(JSON.parse(TOOLS)); + +// check if SERPAPI_KEY or SERPER_API_KEY are defined, and if so append them to the tools + +if (SERPAPI_KEY) { + tools.push({ + name: "webSearch", + key: { + type: "serpapi", + apiKey: SERPAPI_KEY, + }, + }); +} else if (SERPER_API_KEY) { + tools.push({ + name: "webSearch", + key: { + type: "serper", + apiKey: SERPER_API_KEY, + }, + }); +} + +export type Tool = z.infer[number]; +export type WebSearchTool = z.infer; +export type TextToImageTool = z.infer; diff --git a/src/routes/+layout.server.ts b/src/routes/+layout.server.ts index ba71c157875..51dacd7ee56 100644 --- a/src/routes/+layout.server.ts +++ b/src/routes/+layout.server.ts @@ -6,7 +6,8 @@ import { UrlDependency } from "$lib/types/UrlDependency"; import { defaultModel, models, oldModels, 
validateModel } from "$lib/server/models"; import { authCondition, requiresUser } from "$lib/server/auth"; import { DEFAULT_SETTINGS } from "$lib/types/Settings"; -import { SERPAPI_KEY, SERPER_API_KEY, MESSAGES_BEFORE_LOGIN } from "$env/static/private"; +import { MESSAGES_BEFORE_LOGIN } from "$env/static/private"; +import { tools } from "$lib/server/tools"; export const load: LayoutServerLoad = async ({ locals, depends, url }) => { const { conversations } = collections; @@ -61,7 +62,10 @@ export const load: LayoutServerLoad = async ({ locals, depends, url }) => { DEFAULT_SETTINGS.shareConversationsWithModelAuthors, ethicsModalAcceptedAt: settings?.ethicsModalAcceptedAt ?? null, activeModel: settings?.activeModel ?? DEFAULT_SETTINGS.activeModel, - searchEnabled: !!(SERPAPI_KEY || SERPER_API_KEY), + tools: { + webSearch: tools.some((tool) => tool.name === "webSearch"), + textToImage: tools.some((tool) => tool.name === "textToImage"), + }, customPrompts: settings?.customPrompts ?? {}, }, models: models.map((model) => ({ diff --git a/src/routes/conversation/[id]/+server.ts b/src/routes/conversation/[id]/+server.ts index 355a5fb6c4f..b4ab3a9bc16 100644 --- a/src/routes/conversation/[id]/+server.ts +++ b/src/routes/conversation/[id]/+server.ts @@ -15,9 +15,10 @@ import { runWebSearch } from "$lib/server/websearch/runWebSearch"; import { abortedGenerations } from "$lib/server/abortedGenerations"; import { summarize } from "$lib/server/summarize"; import type { TextGenerationStreamOutput } from "@huggingface/inference"; -import { defaultTools, HfChatAgent } from "@huggingface/agents"; +import { HfChatAgent } from "@huggingface/agents"; import { uploadFile } from "$lib/server/tools/uploadFile.js"; import type { Tool } from "@huggingface/agents/src/types.js"; +import { tools as toolSettings, type TextToImageTool } from "$lib/server/tools.js"; export async function POST({ request, fetch, locals, params, getClientAddress }) { const id = z.string().parse(params.id); @@ -155,20 
+156,6 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) const stream = new ReadableStream({ async start(controller) { const updates: MessageUpdate[] = []; - - messages.push({ - from: "assistant", - content: "", - webSearch: undefined, - updates: updates, - files: [], - id: (responseId as Message["id"]) || crypto.randomUUID(), - createdAt: new Date(), - updatedAt: new Date(), - }); - - const lastMessage = messages[messages.length - 1]; - function update(newUpdate: MessageUpdate) { if (newUpdate.type !== "stream") { updates.push(newUpdate); @@ -208,6 +195,18 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) ); } + messages.push({ + from: "assistant", + content: "", + updates: updates, + files: [], + id: (responseId as Message["id"]) || crypto.randomUUID(), + createdAt: new Date(), + updatedAt: new Date(), + }); + + const lastMessage = messages[messages.length - 1]; + async function saveLast(generated_text: string) { if (!conv) { throw new Error("Conversation not found"); @@ -265,85 +264,78 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) } }; - const webSearchTool: Tool = { - name: "webSearch", - description: - "This tool can be used to search the web for extra information. 
It will return the most relevant paragraphs from the web", - examples: [ - { - prompt: "What are the best restaurants in Paris?", - code: '{"tool" : "imageToText", "input" : "What are the best restaurants in Paris?"}', - tools: ["webSearch"], - }, - { - prompt: "Who is the president of the United States?", - code: '{"tool" : "imageToText", "input" : "Who is the president of the United States?"}', - tools: ["webSearch"], - }, - ], - call: async (input, _) => { - const data = await input; - if (typeof data !== "string") throw "Input must be a string."; + const listTools: Tool[] = []; - const results = await runWebSearch(conv, data, update); - return results.context; - }, - }; + if (toolSettings.some((t) => t.name === "webSearch")) { + const webSearchTool: Tool = { + name: "webSearch", + description: + "This tool can be used to search the web for extra information. It will return the most relevant paragraphs from the web", + examples: [ + { + prompt: "What are the best restaurants in Paris?", + code: '{"tool" : "imageToText", "input" : "What are the best restaurants in Paris?"}', + tools: ["webSearch"], + }, + { + prompt: "Who is the president of the United States?", + code: '{"tool" : "imageToText", "input" : "Who is the president of the United States?"}', + tools: ["webSearch"], + }, + ], + call: async (input, _) => { + const data = await input; + if (typeof data !== "string") throw "Input must be a string."; - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const SDXLTool: Tool = { - name: "textToImage", - description: - "This tool can be used to generate an image from text. 
It will return the image.", - mime: "image/jpeg", - model: "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0", - examples: [ - { - prompt: "Generate an image of a cat wearing a top hat", - code: '{"tool" : "textToImage", "input" : "a cat wearing a top hat"}', - tools: ["textToImage"], + const results = await runWebSearch(conv, data, update); + return results.context; }, - { - prompt: "Draw a brown dog on a beach", - code: '{"tool" : "textToImage", "input" : "drawing of a brown dog on a beach"}', - tools: ["textToImage"], - }, - ], - call: async (input, inference) => { - const data = await input; - if (typeof data !== "string") throw "Input must be a string."; + }; - const imageBase = await inference.textToImage( - { - inputs: data, - model: "stabilityai/stable-diffusion-xl-base-1.0", - }, - { wait_for_model: true } - ); + listTools.push(webSearchTool); + } - const imageRefined = await inference.imageToImage( + if (toolSettings.some((t) => t.name === "textToImage")) { + const toolParameters = toolSettings.find( + (t) => t.name === "textToImage" + ) as TextToImageTool; + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const SDXLTool: Tool = { + name: "textToImage", + description: + "This tool can be used to generate an image from text. 
It will return the image.", + mime: "image/jpeg", + model: "https://huggingface.co/" + toolParameters.model, + examples: [ { - inputs: imageBase, - model: "stabilityai/stable-diffusion-xl-refiner-1.0", - parameters: { - prompt: data, - }, + prompt: "Generate an image of a cat wearing a top hat", + code: '{"tool" : "textToImage", "input" : "a cat wearing a top hat"}', + tools: ["textToImage"], }, { - wait_for_model: true, - } - ); - return imageRefined; - }, - }; - - // const listTools = [ - // ...defaultTools.filter((t) => t.name !== "textToImage"), - // webSearchTool, - // SDXLTool, - // ]; + prompt: "Draw a brown dog on a beach", + code: '{"tool" : "textToImage", "input" : "drawing of a brown dog on a beach"}', + tools: ["textToImage"], + }, + ], + call: async (input, inference) => { + const data = await input; + if (typeof data !== "string") throw "Input must be a string."; + + const imageBase = await inference.textToImage( + { + inputs: data, + model: toolParameters.model, + parameters: toolParameters.parameters, + }, + { wait_for_model: true } + ); + return imageBase; + }, + }; - const listTools = [...defaultTools, webSearchTool]; + listTools.push(SDXLTool); + } const agent = new HfChatAgent({ accessToken: HF_ACCESS_TOKEN, @@ -374,7 +366,7 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) saveLast(answer); }, }, - chatHistory: [...messages], + chatHistory: messages, tools: listTools.filter((t) => tools.includes(t.name)), }); From 3ad8abe410219dc20a9f4a5248427d453c5904ee Mon Sep 17 00:00:00 2001 From: Nathan Sarrazin Date: Wed, 27 Sep 2023 13:32:07 +0200 Subject: [PATCH 09/12] update readme to reflect change in .env --- README.md | 35 +++++++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ca47b163366..961baf0d087 100644 --- a/README.md +++ b/README.md @@ -162,11 +162,11 @@ MODELS=`[ You can change things like the parameters, or customize the preprompt to 
better suit your needs. You can also add more models by adding more objects to the array, with different preprompts for example. -#### Custom prompt templates: +#### Custom prompt templates By default the prompt is constructed using `userMessageToken`, `assistantMessageToken`, `userMessageEndToken`, `assistantMessageEndToken`, `preprompt` parameters and a series of default templates. -However, these templates can be modified by setting the `chatPromptTemplate` and `webSearchQueryPromptTemplate` parameters. Note that if WebSearch is not enabled, only `chatPromptTemplate` needs to be set. The template language is https://handlebarsjs.com. The templates have access to the model's prompt parameters (`preprompt`, etc.). However, if the templates are specified it is recommended to inline the prompt parameters, as using the references (`{{preprompt}}`) is deprecated. +However, these templates can be modified by setting the `chatPromptTemplate` and `webSearchQueryPromptTemplate` parameters. Note that if WebSearch is not enabled, only `chatPromptTemplate` needs to be set. The template language is <https://handlebarsjs.com>. The templates have access to the model's prompt parameters (`preprompt`, etc.). However, if the templates are specified it is recommended to inline the prompt parameters, as using the references (`{{preprompt}}`) is deprecated. For example: @@ -300,6 +300,37 @@ If the model being hosted will be available on multiple servers/instances add th ``` +### Tools + +chat-ui supports two tools currently: + +- `webSearch` +- `textToImage` + +You can enable them by adding the following JSON to your `.env.local`: + +``` +TOOLS = `[ + { + "name" : "textToImage", + "model" : "[model name from the hub here]" + }, + { + "name" : "webSearch", + "key" : { + "type" : "serper", + "apiKey" : "[your key here]" + } + } +]` +``` + +Or a subset of these if you only want to enable some of the tools. + +The web search key `type` can be either `serper` or `serpapi`.
+ +The `textToImage` model can be [any model from the hub](https://huggingface.co/tasks/text-to-image) that matches the right task as long as the inference endpoint for it is enabled. + ## Deploying to a HF Space Create a `DOTENV_LOCAL` secret to your HF space with the content of your .env.local, and they will be picked up automatically when you run. From 97986db651291a3820c69158bcbfae289dffd9bf Mon Sep 17 00:00:00 2001 From: Nathan Sarrazin Date: Wed, 27 Sep 2023 16:02:53 +0200 Subject: [PATCH 10/12] prevent double websearch tools --- src/lib/server/tools.ts | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/src/lib/server/tools.ts b/src/lib/server/tools.ts index f333a2bcc91..d7fc9ac4260 100644 --- a/src/lib/server/tools.ts +++ b/src/lib/server/tools.ts @@ -35,22 +35,24 @@ export const tools = toolsDefinition.parse(JSON.parse(TOOLS)); // check if SERPAPI_KEY or SERPER_API_KEY are defined, and if so append them to the tools -if (SERPAPI_KEY) { - tools.push({ - name: "webSearch", - key: { - type: "serpapi", - apiKey: SERPAPI_KEY, - }, - }); -} else if (SERPER_API_KEY) { - tools.push({ - name: "webSearch", - key: { - type: "serper", - apiKey: SERPER_API_KEY, - }, - }); +if (!tools.some((tool) => tool.name === "webSearch")) { + if (SERPAPI_KEY) { + tools.push({ + name: "webSearch", + key: { + type: "serpapi", + apiKey: SERPAPI_KEY, + }, + }); + } else if (SERPER_API_KEY) { + tools.push({ + name: "webSearch", + key: { + type: "serper", + apiKey: SERPER_API_KEY, + }, + }); + } } export type Tool = z.infer[number]; From 7b8d90846d59ac6120f2ad4f898e7e8cd9218e3d Mon Sep 17 00:00:00 2001 From: Nathan Sarrazin Date: Wed, 27 Sep 2023 16:09:27 +0200 Subject: [PATCH 11/12] websearch fix --- src/lib/server/websearch/searchWeb.ts | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/lib/server/websearch/searchWeb.ts b/src/lib/server/websearch/searchWeb.ts index eab3c3d5f7e..60af1b865a8 100644 --- 
a/src/lib/server/websearch/searchWeb.ts +++ b/src/lib/server/websearch/searchWeb.ts @@ -1,14 +1,14 @@ -import { SERPAPI_KEY, SERPER_API_KEY } from "$env/static/private"; - import { getJson } from "serpapi"; import type { GoogleParameters } from "serpapi"; +import { tools, type WebSearchTool } from "../tools"; +const webSearchTool = tools.find((tool) => tool.name === "webSearch") as WebSearchTool; // Show result as JSON export async function searchWeb(query: string) { - if (SERPER_API_KEY) { + if (webSearchTool.key.type === "serper") { return await searchWebSerper(query); } - if (SERPAPI_KEY) { + if (webSearchTool.key.type === "serpapi") { return await searchWebSerpApi(query); } throw new Error("No Serper.dev or SerpAPI key found"); @@ -25,7 +25,7 @@ export async function searchWebSerper(query: string) { method: "POST", body: JSON.stringify(params), headers: { - "x-api-key": SERPER_API_KEY, + "x-api-key": webSearchTool.key.apiKey, "Content-type": "application/json; charset=UTF-8", }, }); @@ -51,7 +51,7 @@ export async function searchWebSerpApi(query: string) { hl: "en", gl: "us", google_domain: "google.com", - api_key: SERPAPI_KEY, + api_key: webSearchTool.key.apiKey, } satisfies GoogleParameters; // Show result as JSON From 20b43d33752b8b064a1286212b8bdbab19e71d7d Mon Sep 17 00:00:00 2001 From: Nathan Sarrazin Date: Wed, 27 Sep 2023 17:45:03 +0200 Subject: [PATCH 12/12] display errors on tool fail --- .../components/OpenWebSearchResults.svelte | 36 ++++++++++++++----- src/lib/types/MessageUpdate.ts | 8 ++++- src/routes/conversation/[id]/+server.ts | 3 ++ 3 files changed, 37 insertions(+), 10 deletions(-) diff --git a/src/lib/components/OpenWebSearchResults.svelte b/src/lib/components/OpenWebSearchResults.svelte index 9fa74b07a40..8205bafaa27 100644 --- a/src/lib/components/OpenWebSearchResults.svelte +++ b/src/lib/components/OpenWebSearchResults.svelte @@ -1,5 +1,10 @@
-
-

+ {#if message.type === "error"} + + {:else} +
+ {/if} +

{message.message}

diff --git a/src/lib/types/MessageUpdate.ts b/src/lib/types/MessageUpdate.ts index a92617575c5..707039cc911 100644 --- a/src/lib/types/MessageUpdate.ts +++ b/src/lib/types/MessageUpdate.ts @@ -35,10 +35,16 @@ export type FileUpdate = { file: File; }; +export type ErrorUpdate = { + type: "error"; + message: string; +}; + export type MessageUpdate = | FinalAnswer | TextStreamUpdate | AgentUpdate | WebSearchUpdate | StatusUpdate - | FileUpdate; + | FileUpdate + | ErrorUpdate; diff --git a/src/routes/conversation/[id]/+server.ts b/src/routes/conversation/[id]/+server.ts index b4ab3a9bc16..aa9f4c6804a 100644 --- a/src/routes/conversation/[id]/+server.ts +++ b/src/routes/conversation/[id]/+server.ts @@ -365,6 +365,9 @@ export async function POST({ request, fetch, locals, params, getClientAddress }) update({ type: "finalAnswer", text: answer }); saveLast(answer); }, + onError: async (errorUpdate) => { + update({ type: "error", message: errorUpdate.message }); + }, }, chatHistory: messages, tools: listTools.filter((t) => tools.includes(t.name)),