Skip to content

Commit c259899

Browse files
feat(ui): support for FLUX Redux in canvas
User facing: When a FLUX main model is selected, users may now add Regional Reference Image layers. When switching between FLUX Redux and FLUX IP Adapter, the settings will change to match the model type. (IP Adapter has weight, begin/end step, but Redux does not.) The image will be retained when switching between the two. Otherwise it works the same way as IP Adapter - both in Global and Regional Reference Image layers. --- Internal state handling: Slightly awkward, but it was easiest to make FLUX Redux a second type of IP Adapter in redux state. Global and regional reference images still have a single `ipAdapter` field, but it can have a type of `ip_adapter` or `flux_redux`. Ideally, this field is called `config` or `settings` or something, but we are past that point. We _could_ do a migration to rename it, but I don't think it's worth the effort. --- Other changes: - Updated canvas layer validators to handle FLUX Redux. - Updated model list loading logic to un-set FLUX Redux models in Canvas if they are not in the list (e.g. if the user deletes the model in the main app). - Updated graph builders - new `addFLUXRedux` util & updated `addRegions` util. - Updated the `buildModelsHook` util to return a hook that accepts a filter callback. This handles a discrepancy: FLUX IP Adapter does not support regional guidance, but FLUX Redux does. The Regional Guidance settings provide the filter to filter out FLUX IP Adapter models from the combined list of IP Adapter ahd Redux models.
1 parent f62b9ad commit c259899

File tree

20 files changed

+494
-189
lines changed

20 files changed

+494
-189
lines changed

invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import type { AnyModelConfig } from 'services/api/types';
3131
import {
3232
isCLIPEmbedModelConfig,
3333
isControlLayerModelConfig,
34+
isFluxReduxModelConfig,
3435
isFluxVAEModelConfig,
3536
isIPAdapterModelConfig,
3637
isLoRAModelConfig,
@@ -77,6 +78,7 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
7778
handleT5EncoderModels(models, state, dispatch, log);
7879
handleCLIPEmbedModels(models, state, dispatch, log);
7980
handleFLUXVAEModels(models, state, dispatch, log);
81+
handleFLUXReduxModels(models, state, dispatch, log);
8082
},
8183
});
8284
};
@@ -209,6 +211,10 @@ const handleControlAdapterModels: ModelHandler = (models, state, dispatch, log)
209211
const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
210212
const ipaModels = models.filter(isIPAdapterModelConfig);
211213
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
214+
if (entity.ipAdapter.type !== 'ip_adapter') {
215+
return;
216+
}
217+
212218
const selectedIPAdapterModel = entity.ipAdapter.model;
213219
// `null` is a valid IP adapter model - no need to do anything.
214220
if (!selectedIPAdapterModel) {
@@ -224,6 +230,10 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
224230

225231
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
226232
entity.referenceImages.forEach(({ id: referenceImageId, ipAdapter }) => {
233+
if (ipAdapter.type !== 'ip_adapter') {
234+
return;
235+
}
236+
227237
const selectedIPAdapterModel = ipAdapter.model;
228238
// `null` is a valid IP adapter model - no need to do anything.
229239
if (!selectedIPAdapterModel) {
@@ -241,6 +251,49 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
241251
});
242252
};
243253

254+
const handleFLUXReduxModels: ModelHandler = (models, state, dispatch, log) => {
255+
const fluxReduxModels = models.filter(isFluxReduxModelConfig);
256+
257+
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
258+
if (entity.ipAdapter.type !== 'flux_redux') {
259+
return;
260+
}
261+
const selectedFLUXReduxModel = entity.ipAdapter.model;
262+
// `null` is a valid FLUX Redux model - no need to do anything.
263+
if (!selectedFLUXReduxModel) {
264+
return;
265+
}
266+
const isModelAvailable = fluxReduxModels.some((m) => m.key === selectedFLUXReduxModel.key);
267+
if (isModelAvailable) {
268+
return;
269+
}
270+
log.debug({ selectedFLUXReduxModel }, 'Selected FLUX Redux model is not available, clearing');
271+
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
272+
});
273+
274+
selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
275+
entity.referenceImages.forEach(({ id: referenceImageId, ipAdapter }) => {
276+
if (ipAdapter.type !== 'flux_redux') {
277+
return;
278+
}
279+
280+
const selectedFLUXReduxModel = ipAdapter.model;
281+
// `null` is a valid FLUX Redux model - no need to do anything.
282+
if (!selectedFLUXReduxModel) {
283+
return;
284+
}
285+
const isModelAvailable = fluxReduxModels.some((m) => m.key === selectedFLUXReduxModel.key);
286+
if (isModelAvailable) {
287+
return;
288+
}
289+
log.debug({ selectedFLUXReduxModel }, 'Selected FLUX Redux model is not available, clearing');
290+
dispatch(
291+
rgIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), referenceImageId, modelConfig: null })
292+
);
293+
});
294+
});
295+
};
296+
244297
const handlePostProcessingModel: ModelHandler = (models, state, dispatch, log) => {
245298
const selectedPostProcessingModel = state.upscale.postProcessingModel;
246299
const allSpandrelModels = models.filter(isSpandrelImageToImageModelConfig);

invokeai/frontend/web/src/features/controlLayers/components/CanvasAddEntityButtons.tsx

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import {
99
useAddRegionalGuidance,
1010
useAddRegionalReferenceImage,
1111
} from 'features/controlLayers/hooks/addLayerHooks';
12-
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
12+
import { selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
1313
import { memo } from 'react';
1414
import { useTranslation } from 'react-i18next';
1515
import { PiPlusBold } from 'react-icons/pi';
@@ -22,7 +22,6 @@ export const CanvasAddEntityButtons = memo(() => {
2222
const addControlLayer = useAddControlLayer();
2323
const addGlobalReferenceImage = useAddGlobalReferenceImage();
2424
const addRegionalReferenceImage = useAddRegionalReferenceImage();
25-
const isFLUX = useAppSelector(selectIsFLUX);
2625
const isSD3 = useAppSelector(selectIsSD3);
2726

2827
return (
@@ -75,7 +74,7 @@ export const CanvasAddEntityButtons = memo(() => {
7574
justifyContent="flex-start"
7675
leftIcon={<PiPlusBold />}
7776
onClick={addRegionalReferenceImage}
78-
isDisabled={isFLUX || isSD3}
77+
isDisabled={isSD3}
7978
>
8079
{t('controlLayers.regionalReferenceImage')}
8180
</Button>

invokeai/frontend/web/src/features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu.tsx

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import {
99
useAddRegionalReferenceImage,
1010
} from 'features/controlLayers/hooks/addLayerHooks';
1111
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
12-
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
12+
import { selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
1313
import { memo } from 'react';
1414
import { useTranslation } from 'react-i18next';
1515
import { PiPlusBold } from 'react-icons/pi';
@@ -23,7 +23,6 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
2323
const addRegionalReferenceImage = useAddRegionalReferenceImage();
2424
const addRasterLayer = useAddRasterLayer();
2525
const addControlLayer = useAddControlLayer();
26-
const isFLUX = useAppSelector(selectIsFLUX);
2726
const isSD3 = useAppSelector(selectIsSD3);
2827

2928
return (
@@ -52,7 +51,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
5251
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isSD3}>
5352
{t('controlLayers.regionalGuidance')}
5453
</MenuItem>
55-
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isFLUX || isSD3}>
54+
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isSD3}>
5655
{t('controlLayers.regionalReferenceImage')}
5756
</MenuItem>
5857
</MenuGroup>
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
2+
import { Combobox, FormControl } from '@invoke-ai/ui-library';
3+
import { useAppSelector } from 'app/store/storeHooks';
4+
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
5+
import type { CLIPVisionModelV2 } from 'features/controlLayers/store/types';
6+
import { isCLIPVisionModelV2 } from 'features/controlLayers/store/types';
7+
import { memo, useCallback, useMemo } from 'react';
8+
import { useTranslation } from 'react-i18next';
9+
import { assert } from 'tsafe';
10+
11+
// at this time, ViT-L is the only supported clip model for FLUX IP adapter
12+
const FLUX_CLIP_VISION = 'ViT-L';
13+
14+
const CLIP_VISION_OPTIONS = [
15+
{ label: 'ViT-H', value: 'ViT-H' },
16+
{ label: 'ViT-G', value: 'ViT-G' },
17+
{ label: FLUX_CLIP_VISION, value: FLUX_CLIP_VISION },
18+
];
19+
20+
type Props = {
21+
model: CLIPVisionModelV2;
22+
onChange: (clipVisionModel: CLIPVisionModelV2) => void;
23+
};
24+
25+
export const CLIPVisionModel = memo(({ model, onChange }: Props) => {
26+
const { t } = useTranslation();
27+
28+
const _onChangeCLIPVisionModel = useCallback<ComboboxOnChange>(
29+
(v) => {
30+
assert(isCLIPVisionModelV2(v?.value));
31+
onChange(v.value);
32+
},
33+
[onChange]
34+
);
35+
36+
const isFLUX = useAppSelector(selectIsFLUX);
37+
38+
const clipVisionOptions = useMemo(() => {
39+
return CLIP_VISION_OPTIONS.map((option) => ({
40+
...option,
41+
isDisabled: isFLUX && option.value !== FLUX_CLIP_VISION,
42+
}));
43+
}, [isFLUX]);
44+
45+
const clipVisionModelValue = useMemo(() => {
46+
return CLIP_VISION_OPTIONS.find((o) => o.value === model);
47+
}, [model]);
48+
49+
return (
50+
<FormControl width="max-content" minWidth={28}>
51+
<Combobox
52+
options={clipVisionOptions}
53+
placeholder={t('common.placeholderSelectAModel')}
54+
value={clipVisionModelValue}
55+
onChange={_onChangeCLIPVisionModel}
56+
/>
57+
</FormControl>
58+
);
59+
});
60+
61+
CLIPVisionModel.displayName = 'CLIPVisionModel';

invokeai/frontend/web/src/features/controlLayers/components/IPAdapter/IPAdapterModel.tsx

Lines changed: 32 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,36 @@
1-
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
2-
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
1+
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
32
import { useAppSelector } from 'app/store/storeHooks';
43
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
5-
import { selectBase, selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
6-
import type { CLIPVisionModelV2 } from 'features/controlLayers/store/types';
7-
import { isCLIPVisionModelV2 } from 'features/controlLayers/store/types';
4+
import { selectBase } from 'features/controlLayers/store/paramsSlice';
85
import { memo, useCallback, useMemo } from 'react';
96
import { useTranslation } from 'react-i18next';
10-
import { useIPAdapterModels } from 'services/api/hooks/modelsByType';
11-
import type { AnyModelConfig, IPAdapterModelConfig } from 'services/api/types';
12-
import { assert } from 'tsafe';
13-
14-
// at this time, ViT-L is the only supported clip model for FLUX IP adapter
15-
const FLUX_CLIP_VISION = 'ViT-L';
16-
17-
const CLIP_VISION_OPTIONS = [
18-
{ label: 'ViT-H', value: 'ViT-H' },
19-
{ label: 'ViT-G', value: 'ViT-G' },
20-
{ label: FLUX_CLIP_VISION, value: FLUX_CLIP_VISION },
21-
];
7+
import { useIPAdapterOrFLUXReduxModels } from 'services/api/hooks/modelsByType';
8+
import type { AnyModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';
229

2310
type Props = {
11+
isRegionalGuidance: boolean;
2412
modelKey: string | null;
25-
onChangeModel: (modelConfig: IPAdapterModelConfig) => void;
26-
clipVisionModel: CLIPVisionModelV2;
27-
onChangeCLIPVisionModel: (clipVisionModel: CLIPVisionModelV2) => void;
13+
onChangeModel: (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => void;
2814
};
2915

30-
export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel, onChangeCLIPVisionModel }: Props) => {
16+
export const IPAdapterModel = memo(({ isRegionalGuidance, modelKey, onChangeModel }: Props) => {
3117
const { t } = useTranslation();
3218
const currentBaseModel = useAppSelector(selectBase);
33-
const [modelConfigs, { isLoading }] = useIPAdapterModels();
19+
const filter = useCallback(
20+
(config: IPAdapterModelConfig | FLUXReduxModelConfig) => {
21+
// FLUX supports regional guidance for FLUX Redux models only - not IP Adapter models.
22+
if (isRegionalGuidance && config.base === 'flux' && config.type === 'ip_adapter') {
23+
return false;
24+
}
25+
return true;
26+
},
27+
[isRegionalGuidance]
28+
);
29+
const [modelConfigs, { isLoading }] = useIPAdapterOrFLUXReduxModels(filter);
3430
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
3531

3632
const _onChangeModel = useCallback(
37-
(modelConfig: IPAdapterModelConfig | null) => {
33+
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | null) => {
3834
if (!modelConfig) {
3935
return;
4036
}
@@ -43,21 +39,11 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
4339
[onChangeModel]
4440
);
4541

46-
const _onChangeCLIPVisionModel = useCallback<ComboboxOnChange>(
47-
(v) => {
48-
assert(isCLIPVisionModelV2(v?.value));
49-
onChangeCLIPVisionModel(v.value);
50-
},
51-
[onChangeCLIPVisionModel]
52-
);
53-
54-
const isFLUX = useAppSelector(selectIsFLUX);
55-
5642
const getIsDisabled = useCallback(
5743
(model: AnyModelConfig): boolean => {
58-
const isCompatible = currentBaseModel === model.base;
5944
const hasMainModel = Boolean(currentBaseModel);
60-
return !hasMainModel || !isCompatible;
45+
const hasSameBase = currentBaseModel === model.base;
46+
return !hasMainModel || !hasSameBase;
6147
},
6248
[currentBaseModel]
6349
);
@@ -70,41 +56,18 @@ export const IPAdapterModel = memo(({ modelKey, onChangeModel, clipVisionModel,
7056
isLoading,
7157
});
7258

73-
const clipVisionOptions = useMemo(() => {
74-
return CLIP_VISION_OPTIONS.map((option) => ({
75-
...option,
76-
isDisabled: isFLUX && option.value !== FLUX_CLIP_VISION,
77-
}));
78-
}, [isFLUX]);
79-
80-
const clipVisionModelValue = useMemo(() => {
81-
return CLIP_VISION_OPTIONS.find((o) => o.value === clipVisionModel);
82-
}, [clipVisionModel]);
83-
8459
return (
85-
<Flex gap={2}>
86-
<Tooltip label={selectedModel?.description}>
87-
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full">
88-
<Combobox
89-
options={options}
90-
placeholder={t('common.placeholderSelectAModel')}
91-
value={value}
92-
onChange={onChange}
93-
noOptionsMessage={noOptionsMessage}
94-
/>
95-
</FormControl>
96-
</Tooltip>
97-
{selectedModel?.format === 'checkpoint' && (
98-
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} width="max-content" minWidth={28}>
99-
<Combobox
100-
options={clipVisionOptions}
101-
placeholder={t('common.placeholderSelectAModel')}
102-
value={clipVisionModelValue}
103-
onChange={_onChangeCLIPVisionModel}
104-
/>
105-
</FormControl>
106-
)}
107-
</Flex>
60+
<Tooltip label={selectedModel?.description}>
61+
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full">
62+
<Combobox
63+
options={options}
64+
placeholder={t('common.placeholderSelectAModel')}
65+
value={value}
66+
onChange={onChange}
67+
noOptionsMessage={noOptionsMessage}
68+
/>
69+
</FormControl>
70+
</Tooltip>
10871
);
10972
});
11073

0 commit comments

Comments
 (0)