
Commit 3a01622

mikelfried, Mishig, nsarrazin, and mishig25 authored
Add embedding models configurable, from both transformers.js and TEI (#646)
* Add embedding models configurable, from both Xenova and TEI
* fix lint and format
* Fix bug in sentenceSimilarity
* Batches for TEI using /info route
* Fix web search disapear when finish searching
* Fix lint and format
* Add more options for better embedding model usage
* Fixing CR issues
* Fix websearch disapear in later PR
* Fix lint
* Fix more minor code CR
* Valiadate embeddingModelName field in model config
* Add embeddingModel into shared conversation
* Fix lint and format
* Add default embedding model, and more readme explanation
* Fix minor embedding model readme detailed
* Update settings.json
* Update README.md (Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>)
* Update README.md (Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>)
* Apply suggestions from code review (Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>)
* Resolved more issues
* lint
* Fix more issues
* Fix format
* fix small typo
* lint
* fix default model
* Rn `maxSequenceLength` -> `chunkCharLength`
* format
* add "authorization" example
* format

Co-authored-by: Mishig <mishig.davaadorj@coloradocollege.edu>
Co-authored-by: Nathan Sarrazin <sarrazin.nathan@gmail.com>
Co-authored-by: Mishig Davaadorj <dmishig@gmail.com>
1 parent 69c0464 commit 3a01622

File tree

18 files changed: +419 −66 lines

.env

Lines changed: 12 additions & 0 deletions
@@ -46,6 +46,18 @@
 CA_PATH=#
 CLIENT_KEY_PASSWORD=#
 REJECT_UNAUTHORIZED=true
+
+TEXT_EMBEDDING_MODELS = `[
+  {
+    "name": "Xenova/gte-small",
+    "displayName": "Xenova/gte-small",
+    "description": "Local embedding model running on the server.",
+    "chunkCharLength": 512,
+    "endpoints": [
+      { "type": "transformersjs" }
+    ]
+  }
+]`
 
 # 'name', 'userMessageToken', 'assistantMessageToken' are required
 MODELS=`[
   {

.env.template

Lines changed: 0 additions & 1 deletion
@@ -204,7 +204,6 @@ TASK_MODEL='mistralai/Mistral-7B-Instruct-v0.2'
 #   "stop": ["</s>"]
 # }}`
-
 APP_BASE="/chat"
 PUBLIC_ORIGIN=https://huggingface.co
 PUBLIC_SHARE_PREFIX=https://hf.co/chat

README.md

Lines changed: 84 additions & 4 deletions
@@ -20,9 +20,10 @@ A chat interface using open source models, eg OpenAssistant or Llama. It is a Sv
 1. [Setup](#setup)
 2. [Launch](#launch)
 3. [Web Search](#web-search)
-4. [Extra parameters](#extra-parameters)
-5. [Deploying to a HF Space](#deploying-to-a-hf-space)
-6. [Building](#building)
+4. [Text Embedding Models](#text-embedding-models)
+5. [Extra parameters](#extra-parameters)
+6. [Deploying to a HF Space](#deploying-to-a-hf-space)
+7. [Building](#building)

 ## No Setup Deploy

@@ -78,10 +79,50 @@ Chat UI features a powerful Web Search feature. It works by:

 1. Generating an appropriate search query from the user prompt.
 2. Performing web search and extracting content from webpages.
-3. Creating embeddings from texts using [transformers.js](https://huggingface.co/docs/transformers.js). Specifically, using the [Xenova/gte-small](https://huggingface.co/Xenova/gte-small) model.
+3. Creating embeddings from texts using a text embedding model.
 4. From these embeddings, finding the ones that are closest to the user query using a vector similarity search. Specifically, we use `inner product` distance.
 5. Getting the corresponding texts for those closest embeddings and performing [Retrieval-Augmented Generation](https://huggingface.co/papers/2005.11401) (i.e. expanding the user prompt by adding those texts so that an LLM can use this information).
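Step 4 above ranks passage embeddings by inner-product distance to the query embedding. A minimal TypeScript sketch of that ranking (function names are illustrative, not from the chat-ui codebase); since the embeddings are normalized, a larger inner product means a closer match:

```typescript
type Embedding = number[];

// Inner product of two equal-length vectors.
function innerProduct(a: Embedding, b: Embedding): number {
  return a.reduce((sum, v, i) => sum + v * b[i], 0);
}

// Return the indices of the k passages closest to the query.
// For unit-length vectors, distance = 1 - innerProduct, so a
// smaller distance means a more similar passage.
function topKClosest(query: Embedding, passages: Embedding[], k: number): number[] {
  return passages
    .map((p, i) => ({ i, dist: 1 - innerProduct(query, p) }))
    .sort((a, b) => a.dist - b.dist)
    .slice(0, k)
    .map(({ i }) => i);
}
```

The selected indices map back to the extracted webpage texts, which are then appended to the prompt for retrieval-augmented generation.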

## Text Embedding Models

By default (for backward compatibility), when the `TEXT_EMBEDDING_MODELS` environment variable is not defined, a [transformers.js](https://huggingface.co/docs/transformers.js) embedding model is used for embedding tasks: specifically, the [Xenova/gte-small](https://huggingface.co/Xenova/gte-small) model.

You can customize the embedding model by setting `TEXT_EMBEDDING_MODELS` in your `.env.local` file. For example:

```env
TEXT_EMBEDDING_MODELS = `[
  {
    "name": "Xenova/gte-small",
    "displayName": "Xenova/gte-small",
    "description": "locally running embedding",
    "chunkCharLength": 512,
    "endpoints": [
      { "type": "transformersjs" }
    ]
  },
  {
    "name": "intfloat/e5-base-v2",
    "displayName": "intfloat/e5-base-v2",
    "description": "hosted embedding model",
    "chunkCharLength": 768,
    "preQuery": "query: ", // See https://huggingface.co/intfloat/e5-base-v2#faq
    "prePassage": "passage: ", // See https://huggingface.co/intfloat/e5-base-v2#faq
    "endpoints": [
      {
        "type": "tei",
        "url": "http://127.0.0.1:8080/",
        "authorization": "TOKEN_TYPE TOKEN" // optional authorization field. Example: "Basic VVNFUjpQQVNT"
      }
    ]
  }
]`
```

The required fields are `name`, `chunkCharLength`, and `endpoints`.
Supported text embedding backends are [`transformers.js`](https://huggingface.co/docs/transformers.js) and [`TEI`](https://github.com/huggingface/text-embeddings-inference). `transformers.js` models run locally as part of `chat-ui`, whereas `TEI` models run in a separate environment and are accessed through an API endpoint.

When more than one embedding model is supplied in the `.env.local` file, the first one is used by default, and the others are used only by models whose `embeddingModel` field is set to that embedding model's name.
## Extra parameters
87128
### OpenID connect
@@ -425,6 +466,45 @@ If you're using a certificate signed by a private CA, you will also need to add

If you're using a self-signed certificate, e.g. for testing or development purposes, you can set the `REJECT_UNAUTHORIZED` parameter to `false` in your `.env.local`. This will disable certificate validation and allow Chat UI to connect to your custom endpoint.

#### Specific Embedding Model

A model can use any of the embedding models defined in `.env.local` (currently used when web searching). By default the first embedding model is used, but this can be changed with the `embeddingModel` field:

```env
TEXT_EMBEDDING_MODELS = `[
  {
    "name": "Xenova/gte-small",
    "chunkCharLength": 512,
    "endpoints": [
      { "type": "transformersjs" }
    ]
  },
  {
    "name": "intfloat/e5-base-v2",
    "chunkCharLength": 768,
    "endpoints": [
      { "type": "tei", "url": "http://127.0.0.1:8080/", "authorization": "Basic VVNFUjpQQVNT" },
      { "type": "tei", "url": "http://127.0.0.1:8081/" }
    ]
  }
]`

MODELS=`[
  {
    "name": "Ollama Mistral",
    "chatPromptTemplate": "...",
    "embeddingModel": "intfloat/e5-base-v2",
    "parameters": {
      ...
    },
    "endpoints": [
      ...
    ]
  }
]`
```
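The fallback behavior described above (the first configured embedding model is used unless a chat model sets `embeddingModel`) boils down to a lookup like this sketch (names are illustrative, not chat-ui's actual helpers):

```typescript
interface EmbeddingModelCfg {
  name: string;
}

// Resolve which embedding model a chat model should use:
// the named one if `requestedName` is set, otherwise the first configured model.
function resolveEmbeddingModel(
  models: EmbeddingModelCfg[],
  requestedName?: string
): EmbeddingModelCfg {
  if (!requestedName) return models[0]; // default: first model in TEXT_EMBEDDING_MODELS
  const found = models.find((m) => m.name === requestedName);
  if (!found) throw new Error(`unknown embedding model: ${requestedName}`);
  return found;
}
```

In chat-ui itself this check happens at config-parse time, via a zod enum built from the configured model names, so an unknown `embeddingModel` value fails at startup rather than at request time.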
## Deploying to a HF Space

Create a `DOTENV_LOCAL` secret in your HF Space with the content of your `.env.local`, and it will be picked up automatically when you run.

src/lib/components/OpenWebSearchResults.svelte

Lines changed: 2 additions & 2 deletions
@@ -30,8 +30,8 @@
 {:else}
 	<CarbonCheckmark class="my-auto text-gray-500" />
 {/if}
-<span class="px-2 font-medium" class:text-red-700={error} class:dark:text-red-500={error}
-	>Web search
+<span class="px-2 font-medium" class:text-red-700={error} class:dark:text-red-500={error}>
+	Web search
 </span>
 <div class="my-auto transition-all" class:rotate-90={detailsOpen}>
 	<CarbonCaretRight />
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
import { z } from "zod";
import type { EmbeddingEndpoint, Embedding } from "$lib/types/EmbeddingEndpoints";
import { chunk } from "$lib/utils/chunk";

export const embeddingEndpointTeiParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("tei"),
	url: z.string().url(),
	authorization: z.string().optional(),
});

// Query the TEI server's /info route for its batching limits.
const getModelInfoByUrl = async (url: string, authorization?: string) => {
	const { origin } = new URL(url);

	const response = await fetch(`${origin}/info`, {
		headers: {
			Accept: "application/json",
			"Content-Type": "application/json",
			...(authorization ? { Authorization: authorization } : {}),
		},
	});

	const json = await response.json();
	return json;
};

export async function embeddingEndpointTei(
	input: z.input<typeof embeddingEndpointTeiParametersSchema>
): Promise<EmbeddingEndpoint> {
	const { url, model, authorization } = embeddingEndpointTeiParametersSchema.parse(input);

	// Pass the authorization header here too, in case /info is also protected.
	const { max_client_batch_size, max_batch_tokens } = await getModelInfoByUrl(url, authorization);
	const maxBatchSize = Math.min(
		max_client_batch_size,
		Math.floor(max_batch_tokens / model.chunkCharLength)
	);

	return async ({ inputs }) => {
		const { origin } = new URL(url);

		const batchesInputs = chunk(inputs, maxBatchSize);

		const batchesResults = await Promise.all(
			batchesInputs.map(async (batchInputs) => {
				const response = await fetch(`${origin}/embed`, {
					method: "POST",
					headers: {
						Accept: "application/json",
						"Content-Type": "application/json",
						...(authorization ? { Authorization: authorization } : {}),
					},
					body: JSON.stringify({ inputs: batchInputs, normalize: true, truncate: true }),
				});

				const embeddings: Embedding[] = await response.json();
				return embeddings;
			})
		);

		const flatAllEmbeddings = batchesResults.flat();

		return flatAllEmbeddings;
	};
}
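The TEI endpoint above batches its inputs with a `chunk` helper imported from `$lib/utils/chunk`. A minimal sketch of such a helper, inferred from how it is called here (this is an assumption about its behavior, not the actual chat-ui implementation):

```typescript
// Split an array into consecutive batches of at most `size` items,
// e.g. chunk([1,2,3,4,5], 2) -> [[1,2],[3,4],[5]].
function chunk<T>(arr: T[], size: number): T[][] {
  if (size <= 0) throw new Error("chunk size must be positive");
  const out: T[][] = [];
  for (let i = 0; i < arr.length; i += size) {
    out.push(arr.slice(i, i + size));
  }
  return out;
}
```

Batching this way keeps each `/embed` request within the server's `max_client_batch_size` and `max_batch_tokens` limits reported by `/info`.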
src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
import { z } from "zod";
import type { EmbeddingEndpoint } from "$lib/types/EmbeddingEndpoints";
import type { Tensor, Pipeline } from "@xenova/transformers";
import { pipeline } from "@xenova/transformers";

export const embeddingEndpointTransformersJSParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("transformersjs"),
});

// Use the Singleton pattern to enable lazy construction of the pipeline.
class TransformersJSModelsSingleton {
	static instances: Array<[string, Promise<Pipeline>]> = [];

	static async getInstance(modelName: string): Promise<Pipeline> {
		const modelPipelineInstance = this.instances.find(([name]) => name === modelName);

		if (modelPipelineInstance) {
			const [, modelPipeline] = modelPipelineInstance;
			return modelPipeline;
		}

		const newModelPipeline = pipeline("feature-extraction", modelName);
		this.instances.push([modelName, newModelPipeline]);

		return newModelPipeline;
	}
}

export async function calculateEmbedding(modelName: string, inputs: string[]) {
	const extractor = await TransformersJSModelsSingleton.getInstance(modelName);
	const output: Tensor = await extractor(inputs, { pooling: "mean", normalize: true });

	return output.tolist();
}

export function embeddingEndpointTransformersJS(
	input: z.input<typeof embeddingEndpointTransformersJSParametersSchema>
): EmbeddingEndpoint {
	const { model } = embeddingEndpointTransformersJSParametersSchema.parse(input);

	return async ({ inputs }) => {
		return calculateEmbedding(model.name, inputs);
	};
}
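The extractor above is called with `{ pooling: "mean", normalize: true }`, which reduces the per-token vectors to a single unit-length vector per input. A plain TypeScript sketch of that computation (illustrative only, not the transformers.js internals):

```typescript
// Mean-pool a sequence of token embeddings into one vector, then
// L2-normalize it, mirroring { pooling: "mean", normalize: true }.
function meanPoolNormalize(tokenVectors: number[][]): number[] {
  const dim = tokenVectors[0].length;
  const mean = new Array(dim).fill(0);
  for (const vec of tokenVectors) {
    for (let i = 0; i < dim; i++) mean[i] += vec[i] / tokenVectors.length;
  }
  // Guard against an all-zero vector to avoid dividing by zero.
  const norm = Math.sqrt(mean.reduce((s, v) => s + v * v, 0)) || 1;
  return mean.map((v) => v / norm);
}
```

Normalizing to unit length is what makes the inner-product distance used by the web-search ranking behave like cosine similarity.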

src/lib/server/embeddingModels.ts

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
import { TEXT_EMBEDDING_MODELS } from "$env/static/private";

import { z } from "zod";
import { sum } from "$lib/utils/sum";
import {
	embeddingEndpoints,
	embeddingEndpointSchema,
	type EmbeddingEndpoint,
} from "$lib/types/EmbeddingEndpoints";
import { embeddingEndpointTransformersJS } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints";

const modelConfig = z.object({
	/** Used as an identifier in DB */
	id: z.string().optional(),
	/** Used to link to the model page, and for inference */
	name: z.string().min(1),
	displayName: z.string().min(1).optional(),
	description: z.string().min(1).optional(),
	websiteUrl: z.string().url().optional(),
	modelUrl: z.string().url().optional(),
	endpoints: z.array(embeddingEndpointSchema).nonempty(),
	chunkCharLength: z.number().positive(),
	preQuery: z.string().default(""),
	prePassage: z.string().default(""),
});

// Default embedding model for backward compatibility
const rawEmbeddingModelJSON =
	TEXT_EMBEDDING_MODELS ||
	`[
	{
		"name": "Xenova/gte-small",
		"chunkCharLength": 512,
		"endpoints": [
			{ "type": "transformersjs" }
		]
	}
]`;

const embeddingModelsRaw = z.array(modelConfig).parse(JSON.parse(rawEmbeddingModelJSON));

const processEmbeddingModel = async (m: z.infer<typeof modelConfig>) => ({
	...m,
	id: m.id || m.name,
});

const addEndpoint = (m: Awaited<ReturnType<typeof processEmbeddingModel>>) => ({
	...m,
	getEndpoint: async (): Promise<EmbeddingEndpoint> => {
		if (!m.endpoints) {
			return embeddingEndpointTransformersJS({
				type: "transformersjs",
				weight: 1,
				model: m,
			});
		}

		// Pick an endpoint at random, with probability proportional to its weight
		const totalWeight = sum(m.endpoints.map((e) => e.weight));
		let random = Math.random() * totalWeight;

		for (const endpoint of m.endpoints) {
			if (random < endpoint.weight) {
				const args = { ...endpoint, model: m };

				switch (args.type) {
					case "tei":
						return embeddingEndpoints.tei(args);
					case "transformersjs":
						return embeddingEndpoints.transformersjs(args);
				}
			}

			random -= endpoint.weight;
		}

		throw new Error(`Failed to select embedding endpoint`);
	},
});

export const embeddingModels = await Promise.all(
	embeddingModelsRaw.map((e) => processEmbeddingModel(e).then(addEndpoint))
);

export const defaultEmbeddingModel = embeddingModels[0];

const validateEmbeddingModel = (_models: EmbeddingBackendModel[], key: "id" | "name") => {
	return z.enum([_models[0][key], ..._models.slice(1).map((m) => m[key])]);
};

export const validateEmbeddingModelById = (_models: EmbeddingBackendModel[]) => {
	return validateEmbeddingModel(_models, "id");
};

export const validateEmbeddingModelByName = (_models: EmbeddingBackendModel[]) => {
	return validateEmbeddingModel(_models, "name");
};

export type EmbeddingBackendModel = typeof defaultEmbeddingModel;
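The `getEndpoint` loop above samples an endpoint with probability proportional to its `weight`. The same scheme in a self-contained form (the random draw is passed in as a parameter so the logic is testable; `pickWeighted` is a name invented here, not part of chat-ui):

```typescript
// Pick one item with probability proportional to its weight,
// mirroring the endpoint-selection loop in embeddingModels.ts.
function pickWeighted<T extends { weight: number }>(items: T[], rand: number): T {
  const total = items.reduce((s, it) => s + it.weight, 0);
  let r = rand * total; // rand is expected in [0, 1)
  for (const it of items) {
    if (r < it.weight) return it;
    r -= it.weight; // shift the draw past this item's slot
  }
  throw new Error("failed to select item");
}
```

With weights 1 and 3, the second item is selected roughly three times as often as the first, which is how `weight` lets you spread embedding traffic across multiple endpoints.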

src/lib/server/models.ts

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
 import { z } from "zod";
 import endpoints, { endpointSchema, type Endpoint } from "./endpoints/endpoints";
 import endpointTgi from "./endpoints/tgi/endpointTgi";
 import { sum } from "$lib/utils/sum";
+import { embeddingModels, validateEmbeddingModelByName } from "./embeddingModels";

 import JSON5 from "json5";

@@ -68,6 +69,7 @@ const modelConfig = z.object({
 		.optional(),
 	multimodal: z.boolean().default(false),
 	unlisted: z.boolean().default(false),
+	embeddingModel: validateEmbeddingModelByName(embeddingModels).optional(),
 });

 const modelsRaw = z.array(modelConfig).parse(JSON5.parse(MODELS));

0 commit comments