From 8ed7c6e45a10a950b55945a0d2f1c74ecd9e1847 Mon Sep 17 00:00:00 2001 From: Nathan Sarrazin Date: Tue, 23 Jan 2024 13:15:53 +0100 Subject: [PATCH 01/12] Add rate-limited image generating endpoint --- .env | 1 + .env.template | 1 + src/lib/server/database.ts | 1 + src/lib/types/MessageEvent.ts | 1 + src/routes/conversation/[id]/+server.ts | 5 +- src/routes/generate/+server.ts | 62 +++++++++++++++++++++++++ 6 files changed, 69 insertions(+), 2 deletions(-) create mode 100644 src/routes/generate/+server.ts diff --git a/.env b/.env index 421fce63526..46da5c3eeff 100644 --- a/.env +++ b/.env @@ -110,6 +110,7 @@ PARQUET_EXPORT_HF_TOKEN= PARQUET_EXPORT_SECRET= RATE_LIMIT= # requests per minute +IMAGE_RATE_LIMIT= MESSAGES_BEFORE_LOGIN=# how many messages a user can send in a conversation before having to login. set to 0 to force login right away APP_BASE="" # base path of the app, e.g. /chat, left blank as default diff --git a/.env.template b/.env.template index 8fe36ccac71..bc40030e692 100644 --- a/.env.template +++ b/.env.template @@ -246,6 +246,7 @@ PUBLIC_APP_DATA_SHARING=1 PUBLIC_APP_DISCLAIMER=1 RATE_LIMIT=16 +IMAGE_RATE_LIMIT=6 MESSAGES_BEFORE_LOGIN=5# how many messages a user can send in a conversation before having to login. set to 0 to force login right away PUBLIC_GOOGLE_ANALYTICS_ID=G-8Q63TH4CSL diff --git a/src/lib/server/database.ts b/src/lib/server/database.ts index 7facc7ada38..543c29092e8 100644 --- a/src/lib/server/database.ts +++ b/src/lib/server/database.ts @@ -70,6 +70,7 @@ client.on("open", () => { users.createIndex({ hfUserId: 1 }, { unique: true }).catch(console.error); users.createIndex({ sessionId: 1 }, { unique: true, sparse: true }).catch(console.error); messageEvents.createIndex({ createdAt: 1 }, { expireAfterSeconds: 60 }).catch(console.error); + messageEvents.createIndex({ type: 1 }).catch(console.error); sessions.createIndex({ expiresAt: 1 }, { expireAfterSeconds: 0 }).catch(console.error); sessions.createIndex({ sessionId: 1 }, { unique: true }).catch(console.error); assistants.createIndex({ createdBy: 1 }).catch(console.error); diff --git a/src/lib/types/MessageEvent.ts b/src/lib/types/MessageEvent.ts index 9843cb29d8c..32810c2d349 100644 --- a/src/lib/types/MessageEvent.ts +++ b/src/lib/types/MessageEvent.ts @@ -5,4 +5,5 @@ import type { User } from "./User"; export interface MessageEvent extends Pick { userId: User["_id"] | Session["sessionId"]; ip?: string; + type: "message" | "image"; } diff --git a/src/routes/conversation/[id]/+server.ts b/src/routes/conversation/[id]/+server.ts index db0b2a9eec9..d6a22ee778a 100644 --- a/src/routes/conversation/[id]/+server.ts +++ b/src/routes/conversation/[id]/+server.ts @@ -40,6 +40,7 @@ export async function POST({ request, locals, params, getClientAddress }) { // register the event for ratelimiting await collections.messageEvents.insertOne({ userId, + type: "message", createdAt: new Date(), ip: getClientAddress(), }); @@ -70,8 +71,8 @@ export async function POST({ request, locals, params, getClientAddress }) { // check if the user is rate limited const nEvents = Math.max( - await collections.messageEvents.countDocuments({ userId }), - await collections.messageEvents.countDocuments({ ip: getClientAddress() }) + await collections.messageEvents.countDocuments({ userId, type: "message" }), + await collections.messageEvents.countDocuments({ ip: getClientAddress(), type: "message" }) ); if (RATE_LIMIT != "" && nEvents > parseInt(RATE_LIMIT)) { diff --git a/src/routes/generate/+server.ts b/src/routes/generate/+server.ts new file mode 100644 index 00000000000..befae4c03e0 --- /dev/null +++ b/src/routes/generate/+server.ts @@ -0,0 +1,62 @@ +import { ASSISTANTS_GENERATE_AVATAR, HF_TOKEN, RATE_LIMIT } from "$env/static/private"; +import { requiresUser } from "$lib/server/auth"; +import { collections } from "$lib/server/database.js"; +import { ERROR_MESSAGES } from "$lib/stores/errors.js"; +import { generateAvatar } from "$lib/utils/generateAvatar.js"; +import { timeout } from "$lib/utils/timeout.js"; +import { error } from "@sveltejs/kit"; +import { z } from "zod"; + +const avatarSchema = z.object({ + name: z.string().min(1), + description: z.string().optional(), +}); + +export async function POST({ request, locals, getClientAddress }) { + if (ASSISTANTS_GENERATE_AVATAR === "true" && HF_TOKEN !== "") { + throw new Error("ASSISTANTS_GENERATE_AVATAR is not true, or HF_TOKEN is not set"); + } + + const userId = locals.user?._id ?? locals.sessionId; + + // rate limit check + await collections.messageEvents.insertOne({ + userId, + type: "image", + createdAt: new Date(), + ip: getClientAddress(), + }); + + const nEvents = Math.max( + await collections.messageEvents.countDocuments({ userId, type: "image" }), + await collections.messageEvents.countDocuments({ ip: getClientAddress(), type: "image" }) + ); + + if (RATE_LIMIT != "" && nEvents > parseInt(RATE_LIMIT)) { + throw error(429, ERROR_MESSAGES.rateLimited); + } + + const formData = Object.fromEntries(await request.formData()); + + // can only create assistants when logged in, IF login is setup + if (!locals.user && requiresUser) { + throw error(400, "Must be logged in. Unauthorized"); + } + + const parse = avatarSchema.safeParse(formData); + + if (!parse.success) { + throw error(400, "Avatar generation failed. Input validation failed."); + } + + try { + const avatar = await timeout(generateAvatar(parse.data.description, parse.data.name), 30000); + return new Response(avatar, { + headers: { + "Content-Type": "image/png", + }, + }); + } catch (e) { + throw error(400, "Avatar generation failed. Try again or disable the feature."); + } +} From 35acebcea3fbf91545e6372c059fef4aa3597578 Mon Sep 17 00:00:00 2001 From: Nathan Sarrazin Date: Tue, 23 Jan 2024 15:03:48 +0100 Subject: [PATCH 02/12] Add generate avatar button --- src/lib/components/AssistantSettings.svelte | 131 +++++++++++++------- src/lib/components/GenerateAvatarBtn.svelte | 23 ++++ src/lib/stores/errors.ts | 2 +- src/routes/generate/+server.ts | 17 ++- 4 files changed, 124 insertions(+), 49 deletions(-) create mode 100644 src/lib/components/GenerateAvatarBtn.svelte diff --git a/src/lib/components/AssistantSettings.svelte b/src/lib/components/AssistantSettings.svelte index 6511ee7bee0..dd0e3477f10 100644 --- a/src/lib/components/AssistantSettings.svelte +++ b/src/lib/components/AssistantSettings.svelte @@ -7,9 +7,11 @@ import { applyAction, enhance } from "$app/forms"; import { base } from "$app/paths"; import CarbonPen from "~icons/carbon/pen"; + import CarbonUpload from "~icons/carbon/upload"; import { useSettingsStore } from "$lib/stores/settings"; import { page } from "$app/stores"; import IconLoading from "./icons/IconLoading.svelte"; + import GenerateAvatarBtn from "./GenerateAvatarBtn.svelte"; type ActionData = { error: boolean; @@ -45,6 +47,7 @@ const inputEl = e.target as HTMLInputElement; if (inputEl.files?.length) { files = inputEl.files; + form = null; } } @@ -53,8 +56,48 @@ } let loading = false; + let generatingAvatar = false; - let generateAvatar = false; + async function onGenerate() { + generatingAvatar = true; + form = null; + + const generatedAvatar = await fetch(base + "/generate", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + description, + name, + }), + }) + .then(async (res) => { + if (!res.ok) { + throw new Error((await res.json()).message); + } else { + return res.blob(); + } + }) + .catch((error) => { + form = { + error: true, + errors: [{ field: "avatar", message: error.message }], + }; + generatingAvatar = false; + }); + + if (generatedAvatar) { + files = [ + new File([generatedAvatar], "avatar.png", { type: "image/png" }), + ] as unknown as FileList; + } + + generatingAvatar = false; + } + + let name = assistant?.name ?? ""; + let description = assistant?.description ?? "";
-