Skip to content

Commit f9fd4b2

Browse files
jonstrutz11zacps
authored andcommitted
Add support for individual model permissions
1 parent 6dd3ae0 commit f9fd4b2

File tree

18 files changed

+499
-119
lines changed

18 files changed

+499
-119
lines changed

README.md

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,49 @@ We currently support [IDEFICS](https://huggingface.co/blog/idefics) (hosted on T
337337
}
338338
```
339339

340+
#### Group-based Model Permissions
341+
342+
If [logging in with OpenID](#openid-connect) via a supported provider, then user groups can be used in combination with the `allowed_groups` field for each model to show/hide models to users based on their group membership.
343+
344+
For all providers, see the following. Then, see additional instructions for your provider below.
345+
346+
1. Add `PROVIDER: "<provider-name-here>"` to your `.env.local` (you will enter the actual provider name later). Also, add `groups` to the `OPENID_CONFIG.SCOPES` field in your `.env.local` file:
347+
```env
348+
OPENID_CONFIG=`{
349+
// rest of OPENID_CONFIG here
350+
PROVIDER: "<provider-name-here>",
351+
SCOPES: "openid profile groups",
352+
// rest of OPENID_CONFIG here
353+
}`
354+
```
355+
356+
2. Use the `allowed_groups` parameter for each model to specify which group(s) should have access to that model. If not specified, all users will be able to access the model.
357+
358+
> [!WARNING]
359+
> The first model in your `.env.local` file is considered the "default" model and should be available to all users, so we strongly recommend against setting `allowed_groups` for this model.
360+
361+
> Note that during development, it is common to have `APP_BASE=""` in your `.env.local` - however, due to the cookies created by using a provider, this value should not be empty (e.g. setting `APP_BASE="/"` in `.env.local` would work).
362+
363+
#### Provider: Microsoft Entra
364+
365+
In order to enable use of [Microsoft Entra Security Groups](https://learn.microsoft.com/en-us/entra/fundamentals/concept-learn-about-groups) to show/hide models, do the following:
366+
367+
1. Replace `<provider-name-here>` with `entra` in `.env.local`.
368+
369+
2. `allowed_groups` for each model in `.env.local` should be a list of Microsoft Entra **Group IDs** (not group names), e.g.:
370+
371+
```env
372+
{
373+
// rest of the model config here
374+
"allowed_groups": ["123abcde-1234-abcd-cdef-1234567890ab", "abcde123-abcd-1234-cdef-abcdef123456"]
375+
}
376+
```
377+
378+
3. Finally, configure your app in Microsoft Entra so that the app can access user groups via the MS Graph API:
379+
- [Add groups claim](https://learn.microsoft.com/en-gb/entra/identity-platform/optional-claims?tabs=appui#configure-groups-optional-claims) to your app
380+
- [Enable ID Tokens](https://learn.microsoft.com/en-us/entra/identity-platform/v2-protocols-oidc#enable-id-tokens) for your app
381+
382+
340383
#### Running your own models using a custom endpoint
341384

342385
If you want to, instead of hitting models on the Hugging Face Inference API, you can run your own models locally.

src/app.d.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ declare global {
1010
// interface Error {}
1111
interface Locals {
1212
sessionId: string;
13-
user?: User & { logoutDisabled?: boolean };
13+
user?: User & { logoutDisabled?: boolean; groups?: string[] };
1414
}
1515

1616
interface Error {

src/hooks.server.ts

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import { sha256 } from "$lib/utils/sha256";
99
import { addWeeks } from "date-fns";
1010
import { checkAndRunMigrations } from "$lib/migrations/migrations";
1111
import { building } from "$app/environment";
12+
import { logout, OIDConfig, ProviderCookieNames } from "$lib/server/auth";
13+
import { type AccessToken, providers } from "$lib/server/providers/providers";
1214
import { logger } from "$lib/server/logger";
1315
import { AbortedGenerations } from "$lib/server/abortedGenerations";
1416
import { MetricsServer } from "$lib/server/metrics";
@@ -229,7 +231,11 @@ export const handle: Handle = async ({ event, resolve }) => {
229231
...(envPublic.PUBLIC_ORIGIN ? [new URL(envPublic.PUBLIC_ORIGIN).host] : []),
230232
];
231233

232-
if (!validOrigins.includes(new URL(origin).host)) {
234+
// origin is null for some reason when the POST request callback comes from an auth provider like MS entra so we skip this check (CSRF token is still validated)
235+
if (
236+
event.url.pathname !== `${base}/login/callback` &&
237+
!validOrigins.includes(new URL(origin).host)
238+
) {
233239
return errorResponse(403, "Invalid referer for POST request");
234240
}
235241
}
@@ -278,6 +284,55 @@ export const handle: Handle = async ({ event, resolve }) => {
278284
}
279285
}
280286

287+
// Get user groups for allowed models
288+
if (OIDConfig.PROVIDER && OIDConfig.SCOPES.includes("groups")) {
289+
const provider = providers[OIDConfig.PROVIDER];
290+
const session_exists = event.cookies.get(env.COOKIE_NAME) !== undefined;
291+
292+
let accessToken: AccessToken = JSON.parse(
293+
event.cookies.get(ProviderCookieNames.ACCESS_TOKEN)?.toString() || "{}"
294+
);
295+
let providerParameters = JSON.parse(
296+
event.cookies.get(ProviderCookieNames.PROVIDER_PARAMS)?.toString() || "{}"
297+
);
298+
299+
// If user is logged in, get/refresh access token and use it to retrieve user groups
300+
if (event.locals.user) {
301+
// Get access token upon login with id token
302+
if (accessToken && providerParameters.idToken) {
303+
[accessToken, providerParameters] = await provider.getAccessToken(
304+
event.cookies,
305+
providerParameters
306+
);
307+
event.locals.user.groups = await provider.getUserGroups(accessToken, providerParameters);
308+
}
309+
// Refresh access token on subsequent requests
310+
else if (accessToken.refreshToken && providerParameters.userTid) {
311+
accessToken = await provider.refreshAccessToken(
312+
event.cookies,
313+
accessToken,
314+
providerParameters
315+
);
316+
event.locals.user.groups = await provider.getUserGroups(accessToken, providerParameters);
317+
}
318+
// Logout user automatically if session exists but access token and/or provider params cookies have expired
319+
else if (session_exists) {
320+
event.locals.user.groups = undefined;
321+
await logout(event.cookies, event.locals);
322+
}
323+
}
324+
} else if (OIDConfig.SCOPES.includes("groups")) {
325+
return errorResponse(
326+
500,
327+
"'groups' has been set in OPENID_CONFIG.SCOPES, but OPENID_CONFIG.PROVIDER is undefined in .env file"
328+
);
329+
} else if (OIDConfig.PROVIDER) {
330+
return errorResponse(
331+
500,
332+
"OPENID_CONFIG.PROVIDER has been set, but 'groups' scope not set in OPENID_CONFIG.SCOPES in .env file"
333+
);
334+
}
335+
281336
let replaced = false;
282337

283338
const response = await resolve(event, {

src/lib/components/AssistantSettings.svelte

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@
265265
class="w-full rounded-lg border-2 border-gray-200 bg-gray-100 p-2"
266266
bind:value={modelId}
267267
>
268-
{#each models.filter((model) => !model.unlisted) as model}
268+
{#each models as model}
269269
<option value={model.id}>{model.displayName}</option>
270270
{/each}
271271
<p class="text-xs text-red-500">{getError("modelId", form)}</p>

src/lib/components/NavMenu.svelte

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
older: "Older",
4444
} as const;
4545
46-
const nModels: number = $page.data.models.filter((el: Model) => !el.unlisted).length;
46+
const nModels: number = $page.data.models.length;
4747
</script>
4848

4949
<div class="sticky top-0 flex flex-none items-center justify-between px-1.5 py-3.5 max-sm:pt-0">

src/lib/server/auth.ts

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,23 @@ import {
55
type TokenSet,
66
custom,
77
} from "openid-client";
8+
import { redirect } from "@sveltejs/kit";
89
import { addHours, addWeeks } from "date-fns";
910
import { env } from "$env/dynamic/private";
1011
import { sha256 } from "$lib/utils/sha256";
1112
import { z } from "zod";
1213
import { dev } from "$app/environment";
14+
import { base } from "$app/paths";
1315
import type { Cookies } from "@sveltejs/kit";
1416
import { collections } from "$lib/server/database";
1517
import JSON5 from "json5";
1618
import { logger } from "$lib/server/logger";
1719

1820
export interface OIDCSettings {
1921
redirectURI: string;
22+
response_type?: string;
23+
response_mode?: string | undefined;
24+
nonce?: string | undefined;
2025
}
2126

2227
export interface OIDCUserInfo {
@@ -34,6 +39,7 @@ export const OIDConfig = z
3439
.object({
3540
CLIENT_ID: stringWithDefault(env.OPENID_CLIENT_ID),
3641
CLIENT_SECRET: stringWithDefault(env.OPENID_CLIENT_SECRET),
42+
PROVIDER: stringWithDefault(env.OPENID_PROVIDER || ""),
3743
PROVIDER_URL: stringWithDefault(env.OPENID_PROVIDER_URL),
3844
SCOPES: stringWithDefault(env.OPENID_SCOPES),
3945
NAME_CLAIM: stringWithDefault(env.OPENID_NAME_CLAIM).refine(
@@ -46,8 +52,15 @@ export const OIDConfig = z
4652
})
4753
.parse(JSON5.parse(env.OPENID_CONFIG || "{}"));
4854

55+
export const ProviderCookieNames = {
56+
ACCESS_TOKEN: OIDConfig.PROVIDER !== "" ? OIDConfig.PROVIDER + "-access-token" : "",
57+
PROVIDER_PARAMS: OIDConfig.PROVIDER !== "" ? OIDConfig.PROVIDER + "-params" : "",
58+
};
59+
4960
export const requiresUser = !!OIDConfig.CLIENT_ID && !!OIDConfig.CLIENT_SECRET;
5061

62+
export const responseType = OIDConfig.SCOPES.includes("groups") ? "code id_token" : "code";
63+
5164
const sameSite = z
5265
.enum(["lax", "none", "strict"])
5366
.default(dev || env.ALLOW_INSECURE_COOKIES === "true" ? "lax" : "none")
@@ -108,7 +121,7 @@ async function getOIDCClient(settings: OIDCSettings): Promise<BaseClient> {
108121
client_id: OIDConfig.CLIENT_ID,
109122
client_secret: OIDConfig.CLIENT_SECRET,
110123
redirect_uris: [settings.redirectURI],
111-
response_types: ["code"],
124+
response_types: ["code", "id_token"],
112125
[custom.clock_tolerance]: OIDConfig.TOLERANCE || undefined,
113126
id_token_signed_response_alg: OIDConfig.ID_TOKEN_SIGNED_RESPONSE_ALG || undefined,
114127
};
@@ -131,8 +144,13 @@ export async function getOIDCAuthorizationUrl(
131144

132145
return client.authorizationUrl({
133146
scope: OIDConfig.SCOPES,
134-
state: csrfToken,
147+
state: Buffer.from(JSON.stringify({ csrfToken, sessionId: params.sessionId })).toString(
148+
"base64"
149+
),
135150
resource: OIDConfig.RESOURCE || undefined,
151+
response_type: settings.response_type,
152+
response_mode: settings.response_mode,
153+
nonce: settings.nonce,
136154
});
137155
}
138156

@@ -142,7 +160,11 @@ export async function getOIDCUserData(
142160
iss?: string
143161
): Promise<OIDCUserInfo> {
144162
const client = await getOIDCClient(settings);
145-
const token = await client.callback(settings.redirectURI, { code, iss });
163+
const token = await client.callback(
164+
settings.redirectURI,
165+
{ code, iss },
166+
{ nonce: settings.nonce }
167+
);
146168
const userData = await client.userinfo(token);
147169

148170
return { token, userData };
@@ -175,3 +197,26 @@ export async function validateAndParseCsrfToken(
175197
}
176198
return null;
177199
}
200+
201+
export async function logout(cookies: Cookies, locals: App.Locals) {
202+
await collections.sessions.deleteOne({ sessionId: locals.sessionId });
203+
204+
const cookie_names = [env.COOKIE_NAME];
205+
if (ProviderCookieNames.ACCESS_TOKEN) {
206+
cookie_names.push(ProviderCookieNames.ACCESS_TOKEN);
207+
}
208+
if (ProviderCookieNames.PROVIDER_PARAMS) {
209+
cookie_names.push(ProviderCookieNames.PROVIDER_PARAMS);
210+
}
211+
212+
for (const cookie_name of cookie_names) {
213+
cookies.delete(cookie_name, {
214+
path: env.APP_BASE,
215+
// So that it works inside the space's iframe
216+
sameSite: dev || env.ALLOW_INSECURE_COOKIES === "true" ? "lax" : "none",
217+
secure: !dev && !(env.ALLOW_INSECURE_COOKIES === "true"),
218+
httpOnly: true,
219+
});
220+
}
221+
redirect(303, `${base}/`);
222+
}

src/lib/server/models.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ const modelConfig = z.object({
8181
multimodal: z.boolean().default(false),
8282
multimodalAcceptedMimetypes: z.array(z.string()).optional(),
8383
tools: z.boolean().default(false),
84+
allowed_groups: z.array(z.string()).optional(),
8485
unlisted: z.boolean().default(false),
8586
embeddingModel: validateEmbeddingModelByName(embeddingModels).optional(),
8687
/** Used to enable/disable system prompt usage */

0 commit comments

Comments
 (0)