Skip to content

Commit a1afcb6

Browse files
galen-ftnsarrazin
andauthored
Add support for passing an API key or any other custom token in the authorization header (#579)
* Add support for passing an API key or any other custom token in the authorization header * Make linter happy * Fix README as per linter suggestions * Refactor endpoints to actually parse zod config * Remove top level env var and simplify header addition * Skip section on API key or other, remove obsolete comment in endpointTgi.ts and remote CUSTOM_AUTHORIZATION_TOKEN from .env --------- Co-authored-by: Nathan Sarrazin <sarrazin.nathan@gmail.com>
1 parent 2da78f5 commit a1afcb6

File tree

6 files changed

+35
-30
lines changed

6 files changed

+35
-30
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,8 @@ You can then add the generated information and the `authorization` parameter to
397397
]
398398
```
399399

400+
Please note that if `HF_ACCESS_TOKEN` is also set or not empty, it will take precedence.
401+
400402
#### Models hosted on multiple custom endpoints
401403

402404
If the model being hosted will be available on multiple servers/instances add the `weight` parameter to your `.env.local`. The `weight` will be used to determine the probability of requesting a particular endpoint.

src/lib/server/endpoints/aws/endpointAws.ts

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,19 @@ export const endpointAwsParametersSchema = z.object({
1515
region: z.string().optional(),
1616
});
1717

18-
export async function endpointAws({
19-
url,
20-
accessKey,
21-
secretKey,
22-
sessionToken,
23-
model,
24-
region,
25-
service,
26-
}: z.infer<typeof endpointAwsParametersSchema>): Promise<Endpoint> {
18+
export async function endpointAws(
19+
input: z.input<typeof endpointAwsParametersSchema>
20+
): Promise<Endpoint> {
2721
let AwsClient;
2822
try {
2923
AwsClient = (await import("aws4fetch")).AwsClient;
3024
} catch (e) {
3125
throw new Error("Failed to import aws4fetch");
3226
}
3327

28+
const { url, accessKey, secretKey, sessionToken, model, region, service } =
29+
endpointAwsParametersSchema.parse(input);
30+
3431
const aws = new AwsClient({
3532
accessKeyId: accessKey,
3633
secretAccessKey: secretKey,

src/lib/server/endpoints/llamacpp/endpointLlamacpp.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ export const endpointLlamacppParametersSchema = z.object({
1212
accessToken: z.string().min(1).default(HF_ACCESS_TOKEN),
1313
});
1414

15-
export function endpointLlamacpp({
16-
url,
17-
model,
18-
}: z.infer<typeof endpointLlamacppParametersSchema>): Endpoint {
15+
export function endpointLlamacpp(
16+
input: z.input<typeof endpointLlamacppParametersSchema>
17+
): Endpoint {
18+
const { url, model } = endpointLlamacppParametersSchema.parse(input);
1919
return async ({ conversation }) => {
2020
const prompt = await buildPrompt({
2121
messages: conversation.messages,

src/lib/server/endpoints/ollama/endpointOllama.ts

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,9 @@ export const endpointOllamaParametersSchema = z.object({
1111
ollamaName: z.string().min(1).optional(),
1212
});
1313

14-
export function endpointOllama({
15-
url,
16-
model,
17-
ollamaName,
18-
}: z.infer<typeof endpointOllamaParametersSchema>): Endpoint {
14+
export function endpointOllama(input: z.input<typeof endpointOllamaParametersSchema>): Endpoint {
15+
const { url, model, ollamaName } = endpointOllamaParametersSchema.parse(input);
16+
1917
return async ({ conversation }) => {
2018
const prompt = await buildPrompt({
2119
messages: conversation.messages,

src/lib/server/endpoints/openai/endpointOai.ts

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,10 @@ export const endpointOAIParametersSchema = z.object({
1616
.default("chat_completions"),
1717
});
1818

19-
export async function endpointOai({
20-
baseURL,
21-
apiKey,
22-
completion,
23-
model,
24-
}: z.infer<typeof endpointOAIParametersSchema>): Promise<Endpoint> {
19+
export async function endpointOai(
20+
input: z.input<typeof endpointOAIParametersSchema>
21+
): Promise<Endpoint> {
22+
const { baseURL, apiKey, completion, model } = endpointOAIParametersSchema.parse(input);
2523
let OpenAI;
2624
try {
2725
OpenAI = (await import("openai")).OpenAI;

src/lib/server/endpoints/tgi/endpointTgi.ts

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,11 @@ export const endpointTgiParametersSchema = z.object({
1010
type: z.literal("tgi"),
1111
url: z.string().url(),
1212
accessToken: z.string().default(HF_ACCESS_TOKEN),
13+
authorization: z.string().optional(),
1314
});
1415

15-
export function endpointTgi({
16-
url,
17-
accessToken,
18-
model,
19-
}: z.infer<typeof endpointTgiParametersSchema>): Endpoint {
16+
export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>): Endpoint {
17+
const { url, accessToken, model, authorization } = endpointTgiParametersSchema.parse(input);
2018
return async ({ conversation }) => {
2119
const prompt = await buildPrompt({
2220
messages: conversation.messages,
@@ -33,7 +31,19 @@ export function endpointTgi({
3331
inputs: prompt,
3432
accessToken,
3533
},
36-
{ use_cache: false }
34+
{
35+
use_cache: false,
36+
fetch: async (endpointUrl, info) => {
37+
if (info && authorization && !accessToken) {
38+
// Set authorization header if it is defined and HF_ACCESS_TOKEN is empty
39+
info.headers = {
40+
...info.headers,
41+
Authorization: authorization,
42+
};
43+
}
44+
return fetch(endpointUrl, info);
45+
},
46+
}
3747
);
3848
};
3949
}

0 commit comments

Comments
 (0)