Skip to content

Commit adf4cc7

Browse files
fix(ui): Fix LoRA picker to default to current base model architecture (#8135)
Enhance LoRA picker to default filter by current base model architecture ## Summary Fixes new LoRA picker to auto select the architecture filter for the current model group ## Related Issues / Discussions N/A ## QA Instructions Open LoRA menu with any model group selected. The right models should be filtered. ## Merge Plan Merge when ready. ## Checklist - [X] _The PR has a short but descriptive title, suitable for a changelog_ - [ ] _Tests added / updated (if applicable)_ - [ ] _Documentation added / updated (if applicable)_ - [ ] _Updated `What's New` copy (if doing a release after this PR)_
2 parents 1320a2c + 9f1ea9d commit adf4cc7

File tree

3 files changed

+35
-6
lines changed

3 files changed

+35
-6
lines changed

invokeai/frontend/web/src/common/components/Picker/Picker.tsx

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ type PickerProps<T extends object> = {
198198
* Whether the picker should be searchable. If true, renders a search input.
199199
*/
200200
searchable?: boolean;
201+
/**
202+
* Initial state for group toggles. If provided, groups will start with these states instead of all being disabled.
203+
*/
204+
initialGroupStates?: GroupStatusMap;
201205
};
202206

203207
export type PickerContextState<T extends object> = {
@@ -310,9 +314,9 @@ const flattenOptions = <T extends object>(options: OptionOrGroup<T>[]): T[] => {
310314
return flattened;
311315
};
312316

313-
type GroupStatusMap = Record<string, boolean>;
317+
export type GroupStatusMap = Record<string, boolean>;
314318

315-
const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[]) => {
319+
const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[], initialGroupStates?: GroupStatusMap) => {
316320
const groupsWithOptions = useMemo(() => {
317321
const ids: string[] = [];
318322
for (const optionOrGroup of options) {
@@ -332,14 +336,16 @@ const useTogglableGroups = <T extends object>(options: OptionOrGroup<T>[]) => {
332336
const groupStatusMap = $groupStatusMap.get();
333337
const newMap: GroupStatusMap = {};
334338
for (const id of groupsWithOptions) {
335-
if (newMap[id] === undefined) {
336-
newMap[id] = false;
339+
if (initialGroupStates && initialGroupStates[id] !== undefined) {
340+
newMap[id] = initialGroupStates[id];
337341
} else if (groupStatusMap[id] !== undefined) {
338342
newMap[id] = groupStatusMap[id];
343+
} else {
344+
newMap[id] = false;
339345
}
340346
}
341347
$groupStatusMap.set(newMap);
342-
}, [groupsWithOptions, $groupStatusMap]);
348+
}, [groupsWithOptions, $groupStatusMap, initialGroupStates]);
343349

344350
const toggleGroup = useCallback(
345351
(idToToggle: string) => {
@@ -511,10 +517,14 @@ export const Picker = typedMemo(<T extends object>(props: PickerProps<T>) => {
511517
OptionComponent = DefaultOptionComponent,
512518
NextToSearchBar,
513519
searchable,
520+
initialGroupStates,
514521
} = props;
515522
const rootRef = useRef<HTMLDivElement>(null);
516523
const inputRef = useRef<HTMLInputElement>(null);
517-
const { $groupStatusMap, $areAllGroupsDisabled, toggleGroup } = useTogglableGroups(optionsOrGroups);
524+
const { $groupStatusMap, $areAllGroupsDisabled, toggleGroup } = useTogglableGroups(
525+
optionsOrGroups,
526+
initialGroupStates
527+
);
518528
const $activeOptionId = useAtom(getFirstOptionId(optionsOrGroups, getOptionId));
519529
const $compactView = useAtom(true);
520530
const $optionsOrGroups = useAtom(optionsOrGroups);

invokeai/frontend/web/src/features/lora/components/LoRASelect.tsx

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ import { FormControl, FormLabel } from '@invoke-ai/ui-library';
22
import { createSelector } from '@reduxjs/toolkit';
33
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
44
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
5+
import type { GroupStatusMap } from 'common/components/Picker/Picker';
56
import { useRelatedGroupedModelCombobox } from 'common/hooks/useRelatedGroupedModelCombobox';
67
import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
78
import { selectBase } from 'features/controlLayers/store/paramsSlice';
89
import { ModelPicker } from 'features/parameters/components/ModelPicker';
10+
import { API_BASE_MODELS } from 'features/parameters/types/constants';
911
import { memo, useCallback, useMemo } from 'react';
1012
import { useTranslation } from 'react-i18next';
1113
import { useLoRAModels } from 'services/api/hooks/modelsByType';
@@ -58,6 +60,19 @@ const LoRASelect = () => {
5860
return t('models.addLora');
5961
}, [isLoading, options.length, t]);
6062

63+
// Calculate initial group states to default to the current base model architecture
64+
const initialGroupStates = useMemo(() => {
65+
if (!currentBaseModel) {
66+
return undefined;
67+
}
68+
69+
// Determine the group ID for the current base model
70+
const groupId = API_BASE_MODELS.includes(currentBaseModel) ? 'api' : currentBaseModel;
71+
72+
// Return a map with only the current base model group enabled
73+
return { [groupId]: true } satisfies GroupStatusMap;
74+
}, [currentBaseModel]);
75+
6176
return (
6277
<FormControl gap={2}>
6378
<InformationalPopover feature="lora">
@@ -72,6 +87,7 @@ const LoRASelect = () => {
7287
placeholder={placeholder}
7388
getIsOptionDisabled={getIsDisabled}
7489
noOptionsText={t('models.noLoRAsInstalled')}
90+
initialGroupStates={initialGroupStates}
7591
/>
7692
</FormControl>
7793
);

invokeai/frontend/web/src/features/parameters/components/ModelPicker.tsx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ export const ModelPicker = typedMemo(
125125
isInvalid,
126126
className,
127127
noOptionsText,
128+
initialGroupStates,
128129
}: {
129130
modelConfigs: T[];
130131
selectedModelConfig: T | undefined;
@@ -137,6 +138,7 @@ export const ModelPicker = typedMemo(
137138
isInvalid?: boolean;
138139
className?: string;
139140
noOptionsText?: string;
141+
initialGroupStates?: Record<string, boolean>;
140142
}) => {
141143
const { t } = useTranslation();
142144
const options = useMemo<T[] | Group<T>[]>(() => {
@@ -244,6 +246,7 @@ export const ModelPicker = typedMemo(
244246
NextToSearchBar={<NavigateToModelManagerButton />}
245247
getIsOptionDisabled={getIsOptionDisabled}
246248
searchable
249+
initialGroupStates={initialGroupStates}
247250
/>
248251
</PopoverBody>
249252
</PopoverContent>

0 commit comments

Comments
 (0)