Skip to content

Commit f6c61eb

Browse files
nsarrazin, mikelfried, Mishig
authored
[Assistants] Use textToImage task for avatar generation (#662)
* Generate assistants avatar using stablediffusion
* wording
* Update +page.server.ts
  Co-authored-by: Michael Fried <mikelfried@gmail.com>
* Add timeout & controls to avatar generation
* Add controls for avatar generation in .env
* Update src/routes/+layout.server.ts
  Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>
* Update src/lib/components/AssistantSettings.svelte
  Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>
* Fix avatar gen feature flag
* Can only upload avatar if generate is unchecked
---------
Co-authored-by: Michael Fried <mikelfried@gmail.com>
Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>
1 parent 8eb6b5b commit f6c61eb

File tree

7 files changed

+139
-5
lines changed

7 files changed

+139
-5
lines changed

.env

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,4 +128,6 @@ LLM_SUMMERIZATION=true
128128
# PUBLIC_APP_DATA_SHARING=1
129129
# PUBLIC_APP_DISCLAIMER=1
130130

131-
ENABLE_ASSISTANTS=false #set to true to enable assistants feature
131+
ENABLE_ASSISTANTS=false #set to true to enable assistants feature
132+
ASSISTANTS_GENERATE_AVATAR=true #requires an hf token, uses the model description and name to generate an avatar using a text to image model
133+
TEXT_TO_IMAGE_MODEL="runwayml/stable-diffusion-v1-5"

src/lib/components/AssistantSettings.svelte

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
import type { Assistant } from "$lib/types/Assistant";
55
66
import { onMount } from "svelte";
7-
import { enhance } from "$app/forms";
7+
import { applyAction, enhance } from "$app/forms";
88
import { base } from "$app/paths";
99
import CarbonPen from "~icons/carbon/pen";
1010
import { useSettingsStore } from "$lib/stores/settings";
11+
import { page } from "$app/stores";
12+
import IconLoading from "./icons/IconLoading.svelte";
1113
1214
type ActionData = {
1315
error: boolean;
@@ -49,13 +51,18 @@
4951
function getError(field: string, returnForm: ActionData) {
5052
return returnForm?.errors.find((error) => error.field === field)?.message ?? "";
5153
}
54+
55+
let loading = false;
56+
57+
let generateAvatar = false;
5258
</script>
5359

5460
<form
5561
method="POST"
5662
class="h-full w-full overflow-x-clip"
5763
enctype="multipart/form-data"
5864
use:enhance={async ({ formData }) => {
65+
loading = true;
5966
const avatar = formData.get("avatar");
6067

6168
if (avatar && typeof avatar !== "string" && avatar.size > 0 && compress) {
@@ -67,6 +74,11 @@
6774
formData.set("avatar", resizedImage);
6875
});
6976
}
77+
78+
return async ({ result }) => {
79+
loading = false;
80+
await applyAction(result);
81+
};
7082
}}
7183
>
7284
{#if assistant}
@@ -90,6 +102,7 @@
90102
accept="image/*"
91103
name="avatar"
92104
class="invisible z-10 block h-0 w-0"
105+
disabled={generateAvatar}
93106
on:change={onFilesChange}
94107
/>
95108
{#if (files && files[0]) || assistant?.avatar}
@@ -122,9 +135,24 @@
122135
Reset
123136
</button>
124137
{:else}
125-
<span class="text-xs text-gray-500 hover:underline">Click to upload</span>
138+
<span
139+
class="text-xs text-gray-500"
140+
class:hover:underline={!generateAvatar}
141+
class:cursor-pointer={!generateAvatar}>Click to upload</span
142+
>
126143
{/if}
127144
<p class="text-xs text-red-500">{getError("avatar", form)}</p>
145+
{#if !files?.[0] && $page.data.avatarGeneration && !assistant?.avatar}
146+
<label class="text-xs text-gray-500">
147+
<input
148+
type="checkbox"
149+
name="generateAvatar"
150+
class="text-xs text-gray-500"
151+
bind:checked={generateAvatar}
152+
/>
153+
Generate avatar from description
154+
</label>
155+
{/if}
128156
</label>
129157

130158
<label>
@@ -220,8 +248,19 @@
220248
class="rounded-full bg-gray-200 px-8 py-2 font-semibold text-gray-600">Cancel</a
221249
>
222250

223-
<button type="submit" class="rounded-full bg-black px-8 py-2 font-semibold text-white md:px-20"
224-
>{assistant ? "Save" : "Create"}</button
251+
<button
252+
type="submit"
253+
disabled={loading}
254+
aria-disabled={loading}
255+
class="rounded-full bg-black px-8 py-2 font-semibold md:px-20"
256+
class:bg-gray-200={loading}
257+
class:text-gray-600={loading}
258+
class:text-white={!loading}
225259
>
260+
{assistant ? "Save" : "Create"}
261+
{#if loading}
262+
<IconLoading classNames="ml-2 h-min" />
263+
{/if}
264+
</button>
226265
</div>
227266
</form>

src/lib/utils/generateAvatar.ts

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import { HF_TOKEN, TEXT_TO_IMAGE_MODEL } from "$env/static/private";
2+
import { generateFromDefaultEndpoint } from "$lib/server/generateFromDefaultEndpoint";
3+
import { HfInference } from "@huggingface/inference";
4+
5+
/**
 * Generates an avatar image for an assistant from its name and description.
 *
 * Works in two steps: the default text-generation endpoint is first asked to turn
 * the name/description into a short image prompt, and that prompt is then sent to
 * the text-to-image model configured in `TEXT_TO_IMAGE_MODEL` through an
 * `HfInference` client constructed with `HF_TOKEN`.
 *
 * @param description - Optional assistant description; when omitted the field is left blank.
 * @param name - Optional assistant name; when omitted the field is left blank.
 * @returns The generated image wrapped in a `File` named `avatar.png`.
 * @throws Nothing is caught here, so any rejection from the text-generation call
 *   or from `textToImage` propagates to the caller.
 */
export async function generateAvatar(description?: string, name?: string): Promise<File> {
	// Both arguments are optional: without the fallback an omitted value would be
	// interpolated as the literal text "undefined" and leak into the model prompt.
	const queryPrompt = `Generate a prompt for an image-generation model for the following:
Name: ${name ?? ""}
Description: ${description ?? ""}
`;

	// NOTE(review): assumes generateFromDefaultEndpoint resolves to the generated text — confirm.
	const imagePrompt = await generateFromDefaultEndpoint({
		messages: [{ from: "user", content: queryPrompt }],
		preprompt:
			"You are an assistant tasked with generating simple image descriptions. The user will ask you for an image, based on the name and a description of what they want, and you should reply with a short, concise, safe, descriptive sentence.",
	});

	const hf = new HfInference(HF_TOKEN);

	const blob = await hf.textToImage({
		inputs: imagePrompt,
		model: TEXT_TO_IMAGE_MODEL,
	});

	return new File([blob], "avatar.png");
}

src/lib/utils/timeout.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
/**
 * Races a promise against a timer.
 *
 * Settles exactly like `prom` when it finishes first. If `time` milliseconds
 * elapse before that, the returned promise rejects with an `Error` describing
 * the timeout. The timer is always cleared once the race settles, so no
 * pending timer is left behind.
 *
 * Note that `prom` itself is not cancelled on timeout; only the returned
 * promise rejects.
 *
 * @param prom - The promise to guard.
 * @param time - Maximum time to wait, in milliseconds.
 * @returns A promise with the outcome of `prom`, or a timeout rejection.
 */
export const timeout = <T>(prom: Promise<T>, time: number): Promise<T> => {
	let timer: ReturnType<typeof setTimeout> | undefined;

	const expiry = new Promise<never>((_resolve, reject) => {
		// Reject with a real Error (rather than `undefined`) so callers that log or
		// inspect the failure can tell a timeout apart from other problems.
		timer = setTimeout(() => reject(new Error(`Operation timed out after ${time}ms`)), time);
	});

	return Promise.race([prom, expiry]).finally(() => clearTimeout(timer));
};

src/routes/+layout.server.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import {
1313
YDC_API_KEY,
1414
USE_LOCAL_WEBSEARCH,
1515
ENABLE_ASSISTANTS,
16+
ASSISTANTS_GENERATE_AVATAR,
17+
TEXT_TO_IMAGE_MODEL,
1618
} from "$env/static/private";
1719
import { ObjectId } from "mongodb";
1820
import type { ConvSidebar } from "$lib/types/ConvSidebar";
@@ -161,6 +163,7 @@ export const load: LayoutServerLoad = async ({ locals, depends }) => {
161163
email: locals.user.email,
162164
},
163165
assistant,
166+
avatarGeneration: ASSISTANTS_GENERATE_AVATAR === "true" && !!TEXT_TO_IMAGE_MODEL,
164167
enableAssistants,
165168
loginRequired,
166169
loginEnabled: requiresUser,

src/routes/settings/assistants/[assistantId]/edit/+page.server.ts

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ import { ObjectId } from "mongodb";
77
import { z } from "zod";
88
import sizeof from "image-size";
99
import { sha256 } from "$lib/utils/sha256";
10+
import { ASSISTANTS_GENERATE_AVATAR, HF_TOKEN } from "$env/static/private";
11+
import { generateAvatar } from "$lib/utils/generateAvatar";
12+
import { timeout } from "$lib/utils/timeout";
1013

1114
const newAsssistantSchema = z.object({
1215
name: z.string().min(1),
@@ -18,6 +21,10 @@ const newAsssistantSchema = z.object({
1821
exampleInput3: z.string().optional(),
1922
exampleInput4: z.string().optional(),
2023
avatar: z.instanceof(File).optional(),
24+
generateAvatar: z
25+
.literal("on")
26+
.optional()
27+
.transform((el) => !!el),
2128
});
2229

2330
const uploadAvatar = async (avatar: File, assistantId: ObjectId): Promise<string> => {
@@ -99,6 +106,29 @@ export const actions: Actions = {
99106
}
100107

101108
hash = await uploadAvatar(parse.data.avatar, assistant._id);
109+
} else if (
110+
ASSISTANTS_GENERATE_AVATAR === "true" &&
111+
HF_TOKEN !== "" &&
112+
parse.data.generateAvatar
113+
) {
114+
try {
115+
const avatar = await timeout(
116+
generateAvatar(parse.data.description, parse.data.name),
117+
30000
118+
);
119+
120+
hash = await uploadAvatar(avatar, assistant._id);
121+
} catch (err) {
122+
return fail(400, {
123+
error: true,
124+
errors: [
125+
{
126+
field: "avatar",
127+
message: "Avatar generation failed. Try again or disable the feature.",
128+
},
129+
],
130+
});
131+
}
102132
}
103133

104134
const { acknowledged } = await collections.assistants.replaceOne(

src/routes/settings/assistants/new/+page.server.ts

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ import { ObjectId } from "mongodb";
77
import { z } from "zod";
88
import sizeof from "image-size";
99
import { sha256 } from "$lib/utils/sha256";
10+
import { ASSISTANTS_GENERATE_AVATAR, HF_TOKEN } from "$env/static/private";
11+
import { timeout } from "$lib/utils/timeout";
12+
import { generateAvatar } from "$lib/utils/generateAvatar";
1013

1114
const newAsssistantSchema = z.object({
1215
name: z.string().min(1),
@@ -18,6 +21,10 @@ const newAsssistantSchema = z.object({
1821
exampleInput3: z.string().optional(),
1922
exampleInput4: z.string().optional(),
2023
avatar: z.instanceof(File).optional(),
24+
generateAvatar: z
25+
.literal("on")
26+
.optional()
27+
.transform((el) => !!el),
2128
});
2229

2330
const uploadAvatar = async (avatar: File, assistantId: ObjectId): Promise<string> => {
@@ -88,6 +95,29 @@ export const actions: Actions = {
8895
}
8996

9097
hash = await uploadAvatar(parse.data.avatar, newAssistantId);
98+
} else if (
99+
ASSISTANTS_GENERATE_AVATAR === "true" &&
100+
HF_TOKEN !== "" &&
101+
parse.data.generateAvatar
102+
) {
103+
try {
104+
const avatar = await timeout(
105+
generateAvatar(parse.data.description, parse.data.name),
106+
30000
107+
);
108+
109+
hash = await uploadAvatar(avatar, newAssistantId);
110+
} catch (err) {
111+
return fail(400, {
112+
error: true,
113+
errors: [
114+
{
115+
field: "avatar",
116+
message: "Avatar generation failed. Try again or disable the feature.",
117+
},
118+
],
119+
});
120+
}
91121
}
92122

93123
const { insertedId } = await collections.assistants.insertOne({

0 commit comments

Comments
 (0)