
Commit b73ec69

Add missing chat inference snippets (#1637)
Follow-up after #1636. Better to review commits individually. This PR:

- adds a snippet using JS fetch for conversational models (e2f6c6f)
- adds tests for existing snippets using the "auto" provider (conversational) (4356373)
- adds snippets for "auto" + "conversational" for cURL, Python openai, Python requests, JS openai, and JS requests (33d3094)

Before this change, only snippets for huggingface_hub / huggingface.js were displayed.
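
For context, a minimal sketch of how these generated snippets are consumed. That `@huggingface/inference` exposes this file under a `snippets` namespace, and the exact `getInferenceSnippets` argument list, are assumptions based on this package's layout, not verbatim from the PR:

// Sketch: generating "auto"-provider snippets for a conversational model.
// The model object shape follows the ModelDataMinimal fields used in the
// test cases below; `pipeline_tag` here is an assumed value.
import { snippets } from "@huggingface/inference";

const model = {
    id: "meta-llama/Llama-3.1-8B-Instruct",
    pipeline_tag: "text-generation",
    tags: ["conversational"],
    inference: "",
};

// After this PR, provider "auto" + a conversational task yields snippets for
// every client (cURL, Python openai/requests, JS openai/fetch, ...), not just
// huggingface_hub / huggingface.js.
const generated = snippets.getInferenceSnippets(model, "auto");
console.log(generated.map((s) => `${s.language}/${s.client}`)); // field names assumed
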
1 parent a4ca182 commit b73ec69

File tree: 37 files changed, +956 −12 lines

packages/inference/src/snippets/getInferenceSnippets.ts

Lines changed: 15 additions & 8 deletions
@@ -36,9 +36,11 @@ const CLIENTS: Record<InferenceSnippetLanguage, Client[]> = {
     sh: [...SH_CLIENTS],
 };
 
-const CLIENTS_AUTO_POLICY: Partial<Record<InferenceSnippetLanguage, Client[]>> = {
+// The "auto"-provider policy is only available through the HF SDKs (huggingface.js / huggingface_hub)
+// except for conversational tasks for which we have https://router.huggingface.co/v1/chat/completions
+const CLIENTS_NON_CONVERSATIONAL_AUTO_POLICY: Partial<Record<InferenceSnippetLanguage, Client[]>> = {
     js: ["huggingface.js"],
-    python: ["huggingface_hub", "openai"],
+    python: ["huggingface_hub"],
 };
 
 type InputPreparationFn = (model: ModelDataMinimal, opts?: Record<string, unknown>) => object;
@@ -206,11 +208,16 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
     // Inputs for the "auto" route is strictly the same as "inputs", except the model includes the provider
     // If not "auto" route, use the providerInputs
     const autoInputs =
-        provider !== "auto" && !opts?.endpointUrl && !opts?.directRequest
-            ? {
-                  ...inputs,
-                  model: `${model.id}:${provider}`,
-              }
+        !opts?.endpointUrl && !opts?.directRequest
+            ? provider !== "auto"
+                ? {
+                      ...inputs,
+                      model: `${model.id}:${provider}`,
+                  }
+                : {
+                      ...inputs,
+                      model: `${model.id}`, // if no :provider => auto
+                  }
             : providerInputs;
 
     /// Prepare template injection data
@@ -259,7 +266,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
     };
 
     /// Iterate over clients => check if a snippet exists => generate
-    const clients = provider === "auto" ? CLIENTS_AUTO_POLICY : CLIENTS;
+    const clients = provider === "auto" && task !== "conversational" ? CLIENTS_NON_CONVERSATIONAL_AUTO_POLICY : CLIENTS;
     return inferenceSnippetLanguages
         .map((language) => {
             const langClients = clients[language] ?? [];
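
Stripped of the surrounding generator plumbing, the client-selection change amounts to the following standalone sketch. Only the two policy entries are taken from the diff; the concrete client lists in CLIENTS are illustrative:

type Client = string;

// Illustrative full client map; the real one is assembled from SH_CLIENTS etc.
const CLIENTS: Record<string, Client[]> = {
    js: ["fetch", "huggingface.js", "openai"],
    python: ["huggingface_hub", "openai", "requests"],
    sh: ["curl"],
};

// Restricted map for provider === "auto": outside conversational tasks, only
// the HF SDKs can resolve a provider automatically.
const CLIENTS_NON_CONVERSATIONAL_AUTO_POLICY: Partial<Record<string, Client[]>> = {
    js: ["huggingface.js"],
    python: ["huggingface_hub"],
};

function selectClients(provider: string, task: string) {
    // Conversational "auto" requests can go through the OpenAI-compatible
    // https://router.huggingface.co/v1/chat/completions, so all clients apply.
    return provider === "auto" && task !== "conversational"
        ? CLIENTS_NON_CONVERSATIONAL_AUTO_POLICY
        : CLIENTS;
}
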
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+async function query(data) {
+    const response = await fetch(
+        "{{ fullUrl }}",
+        {
+            headers: {
+                Authorization: "{{ authorizationHeader }}",
+                "Content-Type": "application/json",
+{% if billTo %}
+                "X-HF-Bill-To": "{{ billTo }}",
+{% endif %}            },
+            method: "POST",
+            body: JSON.stringify(data),
+        }
+    );
+    const result = await response.json();
+    return result;
+}
+
+query({
+    {{ autoInputs.asTsString }}
+}).then((response) => {
+    console.log(JSON.stringify(response));
+});
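
The `{% if billTo %}` block above is what toggles the `X-HF-Bill-To` header in the rendered fixtures further down. A sketch of rendering such a template with `@huggingface/jinja` (that this repo renders its snippet templates exactly this way is an assumption; `Template#render` itself is a real API):

import { Template } from "@huggingface/jinja";

// Placeholder: stands in for the contents of the .jinja file above.
const templateSource = `/* contents of the template above */`;

const template = new Template(templateSource);
const rendered = template.render({
    fullUrl: "https://router.huggingface.co/v1/chat/completions",
    authorizationHeader: "Bearer hf_xxx", // placeholder token
    billTo: "huggingface", // omit this key to drop the X-HF-Bill-To header
    autoInputs: { asTsString: 'model: "meta-llama/Llama-3.1-8B-Instruct"' },
});
console.log(rendered);
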

packages/tasks-gen/scripts/generate-snippets-fixtures.ts

Lines changed: 4 additions & 4 deletions
@@ -56,7 +56,7 @@ const TEST_CASES: {
         tags: ["conversational"],
         inference: "",
     },
-    providers: ["hf-inference", "together"],
+    providers: ["hf-inference", "together", "auto"],
     opts: { streaming: false },
 },
 {
@@ -68,7 +68,7 @@ const TEST_CASES: {
         tags: ["conversational"],
         inference: "",
     },
-    providers: ["hf-inference", "together"],
+    providers: ["hf-inference", "together", "auto"],
     opts: { streaming: true },
 },
 {
@@ -80,7 +80,7 @@ const TEST_CASES: {
         tags: ["conversational"],
         inference: "",
     },
-    providers: ["hf-inference", "fireworks-ai"],
+    providers: ["hf-inference", "fireworks-ai", "auto"],
     opts: { streaming: false },
 },
 {
@@ -92,7 +92,7 @@ const TEST_CASES: {
         tags: ["conversational"],
         inference: "",
     },
-    providers: ["hf-inference", "fireworks-ai"],
+    providers: ["hf-inference", "fireworks-ai", "auto"],
     opts: { streaming: true },
 },
 {
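
Fields outside the hunks (such as the test name and model id) don't appear in this diff, so the following full TEST_CASES entry is a reconstruction for illustration only:

{
    testName: "conversational-llm-non-stream", // assumed name
    model: {
        id: "meta-llama/Llama-3.1-8B-Instruct", // assumed model id
        pipeline_tag: "text-generation", // assumed
        tags: ["conversational"],
        inference: "",
    },
    providers: ["hf-inference", "together", "auto"],
    opts: { streaming: false },
},
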
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+async function query(data) {
+    const response = await fetch(
+        "https://router.huggingface.co/v1/chat/completions",
+        {
+            headers: {
+                Authorization: `Bearer ${process.env.HF_TOKEN}`,
+                "Content-Type": "application/json",
+                "X-HF-Bill-To": "huggingface",
+            },
+            method: "POST",
+            body: JSON.stringify(data),
+        }
+    );
+    const result = await response.json();
+    return result;
+}
+
+query({
+    messages: [
+        {
+            role: "user",
+            content: "What is the capital of France?",
+        },
+    ],
+    model: "meta-llama/Llama-3.1-8B-Instruct:hf-inference",
+}).then((response) => {
+    console.log(JSON.stringify(response));
+});
Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+async function query(data) {
+    const response = await fetch(
+        "http://localhost:8080/v1/chat/completions",
+        {
+            headers: {
+                Authorization: `Bearer ${process.env.API_TOKEN}`,
+                "Content-Type": "application/json",
+            },
+            method: "POST",
+            body: JSON.stringify(data),
+        }
+    );
+    const result = await response.json();
+    return result;
+}
+
+query({
+    messages: [
+        {
+            role: "user",
+            content: "What is the capital of France?",
+        },
+    ],
+    model: "meta-llama/Llama-3.1-8B-Instruct",
+}).then((response) => {
+    console.log(JSON.stringify(response));
+});
Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+async function query(data) {
+    const response = await fetch(
+        "https://router.huggingface.co/v1/chat/completions",
+        {
+            headers: {
+                Authorization: `Bearer ${process.env.HF_TOKEN}`,
+                "Content-Type": "application/json",
+            },
+            method: "POST",
+            body: JSON.stringify(data),
+        }
+    );
+    const result = await response.json();
+    return result;
+}
+
+query({
+    messages: [
+        {
+            role: "user",
+            content: "What is the capital of France?",
+        },
+    ],
+    model: "meta-llama/Llama-3.1-8B-Instruct",
+}).then((response) => {
+    console.log(JSON.stringify(response));
+});
Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+async function query(data) {
+    const response = await fetch(
+        "https://router.huggingface.co/v1/chat/completions",
+        {
+            headers: {
+                Authorization: `Bearer ${process.env.HF_TOKEN}`,
+                "Content-Type": "application/json",
+            },
+            method: "POST",
+            body: JSON.stringify(data),
+        }
+    );
+    const result = await response.json();
+    return result;
+}
+
+query({
+    messages: [
+        {
+            role: "user",
+            content: "What is the capital of France?",
+        },
+    ],
+    model: "meta-llama/Llama-3.1-8B-Instruct:hf-inference",
+}).then((response) => {
+    console.log(JSON.stringify(response));
+});
Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+async function query(data) {
+    const response = await fetch(
+        "https://router.huggingface.co/v1/chat/completions",
+        {
+            headers: {
+                Authorization: `Bearer ${process.env.HF_TOKEN}`,
+                "Content-Type": "application/json",
+            },
+            method: "POST",
+            body: JSON.stringify(data),
+        }
+    );
+    const result = await response.json();
+    return result;
+}
+
+query({
+    messages: [
+        {
+            role: "user",
+            content: "What is the capital of France?",
+        },
+    ],
+    model: "meta-llama/Llama-3.1-8B-Instruct:together",
+}).then((response) => {
+    console.log(JSON.stringify(response));
+});
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+import { InferenceClient } from "@huggingface/inference";
+
+const client = new InferenceClient(process.env.HF_TOKEN);
+
+const chatCompletion = await client.chatCompletion({
+    provider: "auto",
+    model: "meta-llama/Llama-3.1-8B-Instruct",
+    messages: [
+        {
+            role: "user",
+            content: "What is the capital of France?",
+        },
+    ],
+});
+
+console.log(chatCompletion.choices[0].message);
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+import { OpenAI } from "openai";
+
+const client = new OpenAI({
+    baseURL: "https://router.huggingface.co/v1",
+    apiKey: process.env.HF_TOKEN,
+});
+
+const chatCompletion = await client.chat.completions.create({
+    model: "meta-llama/Llama-3.1-8B-Instruct",
+    messages: [
+        {
+            role: "user",
+            content: "What is the capital of France?",
+        },
+    ],
+});
+
+console.log(chatCompletion.choices[0].message);
