Skip to content

Commit 77399ca

Browse files
nsarrazin and Mishig authored
Continue generation feature (#707)
* Initial work on continue feature
* Move continue button
* Fix websearch with continue
* Make it work with every model
* Update src/routes/conversation/[id]/+server.ts

Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>

* fixes
* async all the things
* add reduce comment
* remove log
* Only show loading indicator if not continuing

---------

Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>
1 parent 6e0b0ea commit 77399ca

File tree

11 files changed: +211 -83 lines changed

11 files changed: +211 -83 lines changed

.env.template

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ MODELS=`[
5757
"repetition_penalty": 1.2,
5858
"top_k": 50,
5959
"truncate": 3072,
60-
"max_new_tokens": 1024
60+
"max_new_tokens": 1024,
61+
"stop" : ["</s>", " </s><s>[INST] "]
6162
}
6263
},
6364
{
@@ -116,7 +117,8 @@ MODELS=`[
116117
"repetition_penalty": 1.2,
117118
"top_k": 50,
118119
"truncate": 4096,
119-
"max_new_tokens": 4096
120+
"max_new_tokens": 4096,
121+
"stop": [" </s><s>[INST] "]
120122
}
121123
},
122124
{

src/lib/buildPrompt.ts

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ interface buildPromptOptions {
1313
webSearch?: WebSearch;
1414
preprompt?: string;
1515
files?: File[];
16+
continue?: boolean;
1617
}
1718

1819
export async function buildPrompt({
@@ -22,37 +23,38 @@ export async function buildPrompt({
2223
preprompt,
2324
id,
2425
}: buildPromptOptions): Promise<string> {
26+
let modifiedMessages = [...messages];
27+
2528
if (webSearch && webSearch.context) {
26-
const lastMsg = messages.slice(-1)[0];
27-
const messagesWithoutLastUsrMsg = messages.slice(0, -1);
28-
const previousUserMessages = messages.filter((el) => el.from === "user").slice(0, -1);
29+
// find index of the last user message
30+
const lastUsrMsgIndex = modifiedMessages.map((el) => el.from).lastIndexOf("user");
2931

32+
// combine all the other previous questions into one string
33+
const previousUserMessages = modifiedMessages.filter((el) => el.from === "user").slice(0, -1);
3034
const previousQuestions =
3135
previousUserMessages.length > 0
3236
? `Previous questions: \n${previousUserMessages
3337
.map(({ content }) => `- ${content}`)
3438
.join("\n")}`
3539
: "";
40+
3641
const currentDate = format(new Date(), "MMMM d, yyyy");
37-
messages = [
38-
...messagesWithoutLastUsrMsg,
39-
{
40-
from: "user",
41-
content: `I searched the web using the query: ${webSearch.searchQuery}. Today is ${currentDate} and here are the results:
42+
43+
// update the last user message directly (that way if the last message is an assistant partial answer, we keep the beginning of that answer)
44+
modifiedMessages[lastUsrMsgIndex] = {
45+
from: "user",
46+
content: `I searched the web using the query: ${webSearch.searchQuery}. Today is ${currentDate} and here are the results:
4247
=====================
4348
${webSearch.context}
4449
=====================
4550
${previousQuestions}
46-
Answer the question: ${lastMsg.content}
47-
`,
48-
},
49-
];
51+
Answer the question: ${messages[lastUsrMsgIndex].content} `,
52+
};
5053
}
51-
5254
// section to handle potential files input
5355
if (model.multimodal) {
54-
messages = await Promise.all(
55-
messages.map(async (el) => {
56+
modifiedMessages = await Promise.all(
57+
modifiedMessages.map(async (el) => {
5658
let content = el.content;
5759

5860
if (el.from === "user") {
@@ -83,7 +85,7 @@ export async function buildPrompt({
8385

8486
return (
8587
model
86-
.chatPromptRender({ messages, preprompt })
88+
.chatPromptRender({ messages: modifiedMessages, preprompt })
8789
// Not super precise, but it's truncated in the model's backend anyway
8890
.split(" ")
8991
.slice(-(model.parameters?.truncate ?? 0))

src/lib/components/ContinueBtn.svelte

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
<script lang="ts">
2+
import CarbonContinue from "~icons/carbon/continue";
3+
4+
export let classNames = "";
5+
</script>
6+
7+
<button
8+
type="button"
9+
on:click
10+
class="btn flex h-8 rounded-lg border bg-white px-3 py-1 text-gray-500 shadow-sm transition-all hover:bg-gray-100 dark:border-gray-600 dark:bg-gray-700 dark:text-gray-300 dark:hover:bg-gray-600 {classNames}"
11+
>
12+
<CarbonContinue class="mr-2 text-xs " /> Continue
13+
</button>

src/lib/components/chat/ChatMessage.svelte

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import CarbonDownload from "~icons/carbon/download";
1414
import CarbonThumbsUp from "~icons/carbon/thumbs-up";
1515
import CarbonThumbsDown from "~icons/carbon/thumbs-down";
16+
1617
import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken";
1718
import type { Model } from "$lib/types/Model";
1819

src/lib/components/chat/ChatMessages.svelte

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,12 @@
5454
webSearchMessages={i === messages.length - 1 ? webSearchMessages : []}
5555
on:retry
5656
on:vote
57+
on:continue
5758
/>
5859
{:else}
5960
<ChatIntroduction {models} {currentModel} on:message />
6061
{/each}
61-
{#if pending}
62+
{#if pending && messages[messages.length - 1]?.from === "user"}
6263
<ChatMessage
6364
message={{ from: "assistant", content: "", id: randomUUID() }}
6465
model={currentModel}

src/lib/components/chat/ChatWindow.svelte

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import UploadBtn from "../UploadBtn.svelte";
2525
import file2base64 from "$lib/utils/file2base64";
2626
import { useSettingsStore } from "$lib/stores/settings";
27+
import ContinueBtn from "../ContinueBtn.svelte";
2728
2829
export let messages: Message[] = [];
2930
export let loading = false;
@@ -48,6 +49,7 @@
4849
share: void;
4950
stop: void;
5051
retry: { id: Message["id"]; content: string };
52+
continue: { id: Message["id"] };
5153
}>();
5254
5355
const handleSubmit = () => {
@@ -124,6 +126,7 @@
124126
}
125127
}}
126128
on:vote
129+
on:continue
127130
on:retry={(ev) => {
128131
if (!loading) dispatch("retry", ev.detail);
129132
}}
@@ -173,8 +176,20 @@
173176
content: messages[messages.length - 1].content,
174177
})}
175178
/>
176-
{:else if currentModel.multimodal}
177-
<UploadBtn bind:files classNames="ml-auto" />
179+
{:else}
180+
<div class="ml-auto gap-2">
181+
{#if currentModel.multimodal}
182+
<UploadBtn bind:files classNames="ml-auto" />
183+
{/if}
184+
{#if messages && messages[messages.length - 1]?.interrupted && !isReadOnly}
185+
<ContinueBtn
186+
on:click={() =>
187+
dispatch("continue", {
188+
id: messages[messages.length - 1].id,
189+
})}
190+
/>
191+
{/if}
192+
</div>
178193
{/if}
179194
</div>
180195
<form

src/lib/server/endpoints/endpoints.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ interface EndpointParameters {
1414
preprompt?: Conversation["preprompt"];
1515
_id?: Conversation["_id"];
1616
};
17+
continue?: boolean;
1718
}
1819

1920
interface CommonEndpoint {

src/lib/server/endpoints/tgi/endpointTgi.ts

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,26 @@ export const endpointTgiParametersSchema = z.object({
1515

1616
export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>): Endpoint {
1717
const { url, accessToken, model, authorization } = endpointTgiParametersSchema.parse(input);
18-
return async ({ conversation }) => {
19-
const prompt = await buildPrompt({
18+
19+
return async ({ conversation, continue: messageContinue }) => {
20+
let prompt = await buildPrompt({
2021
messages: conversation.messages,
2122
webSearch: conversation.messages[conversation.messages.length - 1].webSearch,
2223
preprompt: conversation.preprompt,
2324
model,
2425
id: conversation._id,
2526
});
2627

28+
if (messageContinue) {
29+
// start with the full prompt, and for each stop token, try to remove it from the end of the prompt
30+
prompt = model.parameters.stop.reduce((acc: string, curr: string) => {
31+
if (acc.endsWith(curr)) {
32+
return acc.slice(0, acc.length - curr.length);
33+
}
34+
return acc;
35+
}, prompt.trimEnd());
36+
}
37+
2738
return textGenerationStream(
2839
{
2940
parameters: { ...model.parameters, return_full_text: false },

src/lib/types/Message.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ export type Message = Partial<Timestamps> & {
1111
webSearch?: WebSearch;
1212
score?: -1 | 0 | 1;
1313
files?: string[]; // can contain either the hash of the file or the b64 encoded image data on the client side when uploading
14+
interrupted?: boolean;
1415
};

src/routes/conversation/[id]/+page.svelte

Lines changed: 76 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -64,23 +64,39 @@
6464
}
6565
}
6666
// this function is used to send new message to the backends
67-
async function writeMessage(message: string, messageId = randomUUID()) {
68-
if (!message.trim()) return;
69-
67+
async function writeMessage({
68+
prompt,
69+
messageId = randomUUID(),
70+
isRetry = false,
71+
isContinue = false,
72+
}: {
73+
prompt?: string;
74+
messageId?: ReturnType<typeof randomUUID>;
75+
isRetry?: boolean;
76+
isContinue?: boolean;
77+
}): Promise<void> {
7078
try {
7179
$isAborted = false;
7280
loading = true;
7381
pending = true;
7482
7583
// first we check if the messageId already exists, indicating a retry or a continue
7684
77-
let retryMessageIndex = messages.findIndex((msg) => msg.id === messageId);
78-
const isRetry = retryMessageIndex !== -1;
79-
// if it's not a retry we just use the whole array
80-
if (!isRetry) {
81-
retryMessageIndex = messages.length;
85+
let msgIndex = messages.findIndex((msg) => msg.id === messageId);
86+
87+
if (msgIndex === -1) {
88+
msgIndex = messages.length - 1;
89+
}
90+
if (isRetry && messages[msgIndex].from === "assistant") {
91+
throw new Error("Trying to retry a message that is not from user");
92+
}
93+
94+
if (isContinue && messages[msgIndex].from === "user") {
95+
throw new Error("Trying to continue a message that is not from assistant");
8296
}
8397
98+
// const isNewMessage = !isRetry && !isContinue;
99+
84100
const module = await import("browser-image-resizer");
85101
86102
// currently, only IDEFICS is supported by TGI
@@ -99,25 +115,42 @@
99115
);
100116
101117
// slice up to the point of the retry
102-
messages = [
103-
...messages.slice(0, retryMessageIndex),
104-
{
105-
from: "user",
106-
content: message,
107-
id: messageId,
108-
files: isRetry ? messages[retryMessageIndex].files : resizedImages,
109-
},
110-
];
118+
if (isRetry) {
119+
messages = [
120+
...messages.slice(0, msgIndex),
121+
{
122+
from: "user",
123+
content: messages[msgIndex].content,
124+
id: messageId,
125+
files: messages[msgIndex].files,
126+
},
127+
];
128+
} else if (!isContinue) {
129+
// or add a new message if it's not a continue request
130+
if (!prompt) {
131+
throw new Error("Prompt is undefined");
132+
}
133+
messages = [
134+
...messages,
135+
{
136+
from: "user",
137+
content: prompt ?? "",
138+
id: messageId,
139+
files: resizedImages,
140+
},
141+
];
142+
}
111143
112144
files = [];
113145
114146
const response = await fetch(`${base}/conversation/${$page.params.id}`, {
115147
method: "POST",
116148
headers: { "Content-Type": "application/json" },
117149
body: JSON.stringify({
118-
inputs: message,
150+
inputs: prompt,
119151
id: messageId,
120152
is_retry: isRetry,
153+
is_continue: isContinue,
121154
web_search: $webSearchParameters.useSearch,
122155
files: isRetry ? undefined : resizedImages,
123156
}),
@@ -282,37 +315,54 @@
282315
// only used in case of creating new conversations (from the parent POST endpoint)
283316
if ($pendingMessage) {
284317
files = $pendingMessage.files;
285-
await writeMessage($pendingMessage.content);
318+
await writeMessage({ prompt: $pendingMessage.content });
286319
$pendingMessage = undefined;
287320
}
288321
});
289322
290323
async function onMessage(event: CustomEvent<string>) {
291324
if (!data.shared) {
292-
writeMessage(event.detail);
325+
await writeMessage({ prompt: event.detail });
293326
} else {
294-
convFromShared()
327+
await convFromShared()
295328
.then(async (convId) => {
296329
await goto(`${base}/conversation/${convId}`, { invalidateAll: true });
297330
})
298-
.then(() => writeMessage(event.detail))
331+
.then(async () => await writeMessage({ prompt: event.detail }))
299332
.finally(() => (loading = false));
300333
}
301334
}
302335
303336
async function onRetry(event: CustomEvent<{ id: Message["id"]; content: string }>) {
304337
if (!data.shared) {
305-
writeMessage(event.detail.content, event.detail.id);
338+
await writeMessage({
339+
prompt: event.detail.content,
340+
messageId: event.detail.id,
341+
isRetry: true,
342+
});
306343
} else {
307-
convFromShared()
344+
await convFromShared()
308345
.then(async (convId) => {
309346
await goto(`${base}/conversation/${convId}`, { invalidateAll: true });
310347
})
311-
.then(() => writeMessage(event.detail.content, event.detail.id))
348+
.then(
349+
async () =>
350+
await writeMessage({
351+
prompt: event.detail.content,
352+
messageId: event.detail.id,
353+
isRetry: true,
354+
})
355+
)
312356
.finally(() => (loading = false));
313357
}
314358
}
315359
360+
async function onContinue(event: CustomEvent<{ id: Message["id"] }>) {
361+
if (!data.shared) {
362+
writeMessage({ messageId: event.detail.id, isContinue: true });
363+
}
364+
}
365+
316366
$: $page.params.id, (($isAborted = true), (loading = false));
317367
$: title = data.conversations.find((conv) => conv.id === $page.params.id)?.title ?? data.title;
318368
</script>
@@ -337,6 +387,7 @@
337387
bind:files
338388
on:message={onMessage}
339389
on:retry={onRetry}
390+
on:continue={onContinue}
340391
on:vote={(event) => voteMessage(event.detail.score, event.detail.id)}
341392
on:share={() => shareConversation($page.params.id, data.title)}
342393
on:stop={() => (($isAborted = true), (loading = false))}

0 commit comments

Comments
 (0)