Skip to content

Commit 56cd839

Browse files
feat(ui): support for ref images for chatgpt on canvas
1 parent 7b446ee commit 56cd839

File tree

14 files changed

+228
-104
lines changed

14 files changed

+228
-104
lines changed

invokeai/frontend/web/public/locales/en.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1322,7 +1322,8 @@
13221322
"unableToCopyDesc": "Your browser does not support clipboard access. Firefox users may be able to fix this by following ",
13231323
"unableToCopyDesc_theseSteps": "these steps",
13241324
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks.",
1325-
"imagen3IncompatibleGenerationMode": "Imagen3 only supports Text to Image. Use other models for Image to Image, Inpainting and Outpainting tasks.",
1325+
"imagen3IncompatibleGenerationMode": "Google Imagen3 supports Text to Image only. Ensure the bounding box is empty, or use other models for Image to Image, Inpainting and Outpainting tasks.",
1326+
"chatGPT4oIncompatibleGenerationMode": "ChatGPT 4o supports Text to Image only. Ensure the bounding box is empty, or use other models for Image to Image, Inpainting and Outpainting tasks.",
13261327
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
13271328
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
13281329
"workflowUnpublished": "Workflow Unpublished"
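invokeai/frontend/web/src/features/controlLayers/components/IPAdapter/GlobalReferenceImageModel.tsx

Lines changed: 65 additions & 0 deletions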
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
2+
import { useAppSelector } from 'app/store/storeHooks';
3+
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
4+
import { selectBase } from 'features/controlLayers/store/paramsSlice';
5+
import { NavigateToModelManagerButton } from 'features/parameters/components/MainModel/NavigateToModelManagerButton';
6+
import { memo, useCallback, useMemo } from 'react';
7+
import { useTranslation } from 'react-i18next';
8+
import { useGlobalReferenceImageModels } from 'services/api/hooks/modelsByType';
9+
import type { AnyModelConfig, ApiModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';
10+
11+
type Props = {
12+
modelKey: string | null;
13+
onChangeModel: (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => void;
14+
};
15+
16+
/**
 * Model picker for a global reference image.
 *
 * Lists the configs returned by `useGlobalReferenceImageModels()` in a grouped combobox and
 * forwards a non-null selection to `onChangeModel`.
 *
 * @param modelKey Key of the currently selected model, or `null` when nothing is selected.
 * @param onChangeModel Called with the chosen model config; never called with `null`.
 */
export const GlobalReferenceImageModel = memo(({ modelKey, onChangeModel }: Props) => {
17+
const { t } = useTranslation();
18+
const currentBaseModel = useAppSelector(selectBase);
19+
const [modelConfigs, { isLoading }] = useGlobalReferenceImageModels();
20+
// Resolve the selected config from its key; `undefined` when the key matches no loaded config.
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
21+
22+
// The combobox may report `null`; in that case the parent callback is not invoked.
const _onChangeModel = useCallback(
23+
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig | null) => {
24+
if (!modelConfig) {
25+
return;
26+
}
27+
onChangeModel(modelConfig);
28+
},
29+
[onChangeModel]
30+
);
31+
32+
// An option is disabled when there is no current base model or when its base differs from it.
const getIsDisabled = useCallback(
33+
(model: AnyModelConfig): boolean => {
34+
const hasMainModel = Boolean(currentBaseModel);
35+
const hasSameBase = currentBaseModel === model.base;
36+
return !hasMainModel || !hasSameBase;
37+
},
38+
[currentBaseModel]
39+
);
40+
41+
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
42+
modelConfigs,
43+
onChange: _onChangeModel,
44+
selectedModel,
45+
getIsDisabled,
46+
isLoading,
47+
});
48+
49+
// The form control is flagged invalid when the combobox has no value or when the selected
// model's base does not equal the current base model.
return (
50+
<Tooltip label={selectedModel?.description}>
51+
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full">
52+
<Combobox
53+
options={options}
54+
placeholder={t('common.placeholderSelectAModel')}
55+
value={value}
56+
onChange={onChange}
57+
noOptionsMessage={noOptionsMessage}
58+
/>
59+
<NavigateToModelManagerButton />
60+
</FormControl>
61+
</Tooltip>
62+
);
63+
});
64+
65+
GlobalReferenceImageModel.displayName = 'GlobalReferenceImageModel';

invokeai/frontend/web/src/features/controlLayers/components/IPAdapter/IPAdapterSettings.tsx

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/c
66
import { Weight } from 'features/controlLayers/components/common/Weight';
77
import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLIPVisionModel';
88
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/IPAdapter/FLUXReduxImageInfluence';
9+
import { GlobalReferenceImageModel } from 'features/controlLayers/components/IPAdapter/GlobalReferenceImageModel';
910
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
1011
import { IPAdapterSettingsEmptyState } from 'features/controlLayers/components/IPAdapter/IPAdapterSettingsEmptyState';
1112
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
@@ -33,10 +34,9 @@ import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
3334
import { memo, useCallback, useMemo } from 'react';
3435
import { useTranslation } from 'react-i18next';
3536
import { PiBoundingBoxBold } from 'react-icons/pi';
36-
import type { FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
37+
import type { ApiModelConfig, FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
3738

3839
import { IPAdapterImagePreview } from './IPAdapterImagePreview';
39-
import { IPAdapterModel } from './IPAdapterModel';
4040

4141
const buildSelectIPAdapter = (entityIdentifier: CanvasEntityIdentifier<'reference_image'>) =>
4242
createSelector(
@@ -80,7 +80,7 @@ const IPAdapterSettingsContent = memo(() => {
8080
);
8181

8282
const onChangeModel = useCallback(
83-
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => {
83+
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => {
8484
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier, modelConfig }));
8585
},
8686
[dispatch, entityIdentifier]
@@ -113,11 +113,7 @@ const IPAdapterSettingsContent = memo(() => {
113113
<CanvasEntitySettingsWrapper>
114114
<Flex flexDir="column" gap={2} position="relative" w="full">
115115
<Flex gap={2} alignItems="center" w="full">
116-
<IPAdapterModel
117-
isRegionalGuidance={false}
118-
modelKey={ipAdapter.model?.key ?? null}
119-
onChangeModel={onChangeModel}
120-
/>
116+
<GlobalReferenceImageModel modelKey={ipAdapter.model?.key ?? null} onChangeModel={onChangeModel} />
121117
{ipAdapter.type === 'ip_adapter' && (
122118
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
123119
)}

invokeai/frontend/web/src/features/controlLayers/components/IPAdapter/IPAdapterModel.tsx renamed to invokeai/frontend/web/src/features/controlLayers/components/IPAdapter/RegionalReferenceImageModel.tsx

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,26 @@ import { selectBase } from 'features/controlLayers/store/paramsSlice';
55
import { NavigateToModelManagerButton } from 'features/parameters/components/MainModel/NavigateToModelManagerButton';
66
import { memo, useCallback, useMemo } from 'react';
77
import { useTranslation } from 'react-i18next';
8-
import { useIPAdapterOrFLUXReduxModels } from 'services/api/hooks/modelsByType';
8+
import { useRegionalReferenceImageModels } from 'services/api/hooks/modelsByType';
99
import type { AnyModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';
1010

1111
type Props = {
12-
isRegionalGuidance: boolean;
1312
modelKey: string | null;
1413
onChangeModel: (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig) => void;
1514
};
1615

17-
export const IPAdapterModel = memo(({ isRegionalGuidance, modelKey, onChangeModel }: Props) => {
16+
/**
 * Predicate passed to `useRegionalReferenceImageModels` to choose which model configs are offered
 * for regional reference images. Returns `false` for IP Adapter configs whose base is FLUX and
 * `true` for every other config.
 *
 * Declared at module scope, so the same function reference is passed on every render.
 */
const filter = (config: IPAdapterModelConfig | FLUXReduxModelConfig) => {
17+
// FLUX supports regional guidance for FLUX Redux models only - not IP Adapter models.
18+
if (config.base === 'flux' && config.type === 'ip_adapter') {
19+
return false;
20+
}
21+
return true;
22+
};
23+
24+
export const RegionalReferenceImageModel = memo(({ modelKey, onChangeModel }: Props) => {
1825
const { t } = useTranslation();
1926
const currentBaseModel = useAppSelector(selectBase);
20-
const filter = useCallback(
21-
(config: IPAdapterModelConfig | FLUXReduxModelConfig) => {
22-
// FLUX supports regional guidance for FLUX Redux models only - not IP Adapter models.
23-
if (isRegionalGuidance && config.base === 'flux' && config.type === 'ip_adapter') {
24-
return false;
25-
}
26-
return true;
27-
},
28-
[isRegionalGuidance]
29-
);
30-
const [modelConfigs, { isLoading }] = useIPAdapterOrFLUXReduxModels(filter);
27+
const [modelConfigs, { isLoading }] = useRegionalReferenceImageModels(filter);
3128
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
3229

3330
const _onChangeModel = useCallback(
@@ -73,4 +70,4 @@ export const IPAdapterModel = memo(({ isRegionalGuidance, modelKey, onChangeMode
7370
);
7471
});
7572

76-
IPAdapterModel.displayName = 'IPAdapterModel';
73+
RegionalReferenceImageModel.displayName = 'RegionalReferenceImageModel';

invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettings.tsx

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import { CLIPVisionModel } from 'features/controlLayers/components/IPAdapter/CLI
77
import { FLUXReduxImageInfluence } from 'features/controlLayers/components/IPAdapter/FLUXReduxImageInfluence';
88
import { IPAdapterImagePreview } from 'features/controlLayers/components/IPAdapter/IPAdapterImagePreview';
99
import { IPAdapterMethod } from 'features/controlLayers/components/IPAdapter/IPAdapterMethod';
10-
import { IPAdapterModel } from 'features/controlLayers/components/IPAdapter/IPAdapterModel';
10+
import { RegionalReferenceImageModel } from 'features/controlLayers/components/IPAdapter/RegionalReferenceImageModel';
1111
import { RegionalGuidanceIPAdapterSettingsEmptyState } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState';
1212
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
1313
import { usePullBboxIntoRegionalGuidanceReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
@@ -140,11 +140,7 @@ const RegionalGuidanceIPAdapterSettingsContent = memo(({ referenceImageId }: Pro
140140
</Flex>
141141
<Flex flexDir="column" gap={2} position="relative" w="full">
142142
<Flex gap={2} alignItems="center" w="full">
143-
<IPAdapterModel
144-
isRegionalGuidance={true}
145-
modelKey={ipAdapter.model?.key ?? null}
146-
onChangeModel={onChangeModel}
147-
/>
143+
<RegionalReferenceImageModel modelKey={ipAdapter.model?.key ?? null} onChangeModel={onChangeModel} />
148144
{ipAdapter.type === 'ip_adapter' && (
149145
<CLIPVisionModel model={ipAdapter.clipVisionModel} onChange={onChangeCLIPVisionModel} />
150146
)}

invokeai/frontend/web/src/features/controlLayers/hooks/addLayerHooks.ts

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,26 @@ import { selectBase } from 'features/controlLayers/store/paramsSlice';
1717
import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors';
1818
import type {
1919
CanvasEntityIdentifier,
20+
CanvasReferenceImageState,
2021
CanvasRegionalGuidanceState,
2122
ControlLoRAConfig,
2223
ControlNetConfig,
2324
IPAdapterConfig,
2425
T2IAdapterConfig,
2526
} from 'features/controlLayers/store/types';
26-
import { initialControlNet, initialIPAdapter, initialT2IAdapter } from 'features/controlLayers/store/util';
27+
import {
28+
initialChatGPT4oReferenceImage,
29+
initialControlNet,
30+
initialIPAdapter,
31+
initialT2IAdapter,
32+
} from 'features/controlLayers/store/util';
2733
import { zModelIdentifierField } from 'features/nodes/types/common';
2834
import { useCallback } from 'react';
29-
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
35+
import {
36+
modelConfigsAdapterSelectors,
37+
selectMainModelConfig,
38+
selectModelConfigsQuery,
39+
} from 'services/api/endpoints/models';
3040
import type {
3141
ControlLoRAModelConfig,
3242
ControlNetModelConfig,
@@ -64,6 +74,35 @@ export const selectDefaultControlAdapter = createSelector(
6474
}
6575
);
6676

77+
/**
 * Selects the default `ipAdapter` config used when a new global reference image is added.
 *
 * - When the selected main model's base is `chatgpt-4o`, returns a clone of
 *   `initialChatGPT4oReferenceImage` whose model is the main model itself.
 * - Otherwise returns a clone of `initialIPAdapter`. Its model is the first IP Adapter config whose
 *   base matches the current base, falling back to the first IP Adapter config of any base. The
 *   model is left as in `initialIPAdapter` when the query has no data or no IP Adapter config exists.
 */
const selectDefaultRefImageConfig = createSelector(
78+
selectMainModelConfig,
79+
selectModelConfigsQuery,
80+
selectBase,
81+
(selectedMainModel, query, base): CanvasReferenceImageState['ipAdapter'] => {
82+
if (selectedMainModel?.base === 'chatgpt-4o') {
83+
// Clone the initial state so the shared constant is never mutated.
const referenceImage = deepClone(initialChatGPT4oReferenceImage);
84+
referenceImage.model = zModelIdentifierField.parse(selectedMainModel);
85+
return referenceImage;
86+
}
87+
88+
const { data } = query;
89+
let model: IPAdapterModelConfig | null = null;
90+
if (data) {
91+
const modelConfigs = modelConfigsAdapterSelectors.selectAll(data).filter(isIPAdapterModelConfig);
92+
// With no base set, every IP Adapter config counts as compatible.
const compatibleModels = modelConfigs.filter((m) => (base ? m.base === base : true));
93+
// Prefer a base-compatible model; otherwise fall back to any IP Adapter model.
model = compatibleModels[0] ?? modelConfigs[0] ?? null;
94+
}
95+
const ipAdapter = deepClone(initialIPAdapter);
96+
if (model) {
97+
ipAdapter.model = zModelIdentifierField.parse(model);
98+
// FLUX IP Adapter models are paired with the ViT-L CLIP Vision model here.
if (model.base === 'flux') {
99+
ipAdapter.clipVisionModel = 'ViT-L';
100+
}
101+
}
102+
return ipAdapter;
103+
}
104+
);
105+
67106
/**
68107
* Selects the default IP adapter configuration based on the model configurations and the base.
69108
*
@@ -146,11 +185,11 @@ export const useAddRegionalReferenceImage = () => {
146185

147186
export const useAddGlobalReferenceImage = () => {
148187
const dispatch = useAppDispatch();
149-
const defaultIPAdapter = useAppSelector(selectDefaultIPAdapter);
188+
const defaultRefImage = useAppSelector(selectDefaultRefImageConfig);
150189
const func = useCallback(() => {
151-
const overrides = { ipAdapter: deepClone(defaultIPAdapter) };
190+
const overrides = { ipAdapter: deepClone(defaultRefImage) };
152191
dispatch(referenceImageAdded({ isSelected: true, overrides }));
153-
}, [defaultIPAdapter, dispatch]);
192+
}, [defaultRefImage, dispatch]);
154193

155194
return func;
156195
};

invokeai/frontend/web/src/features/controlLayers/hooks/useIsEntityTypeEnabled.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ export const useIsEntityTypeEnabled = (entityType: CanvasEntityType) => {
1919
const isEntityTypeEnabled = useMemo<boolean>(() => {
2020
switch (entityType) {
2121
case 'reference_image':
22-
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
22+
return !isSD3 && !isCogView4 && !isImagen3;
2323
case 'regional_guidance':
2424
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
2525
case 'control_layer':

invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,10 @@ import { isMainModelBase, zModelIdentifierField } from 'features/nodes/types/com
3434
import { ASPECT_RATIO_MAP } from 'features/parameters/components/Bbox/constants';
3535
import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
3636
import type { IRect } from 'konva/lib/types';
37-
import { merge } from 'lodash-es';
37+
import { isEqual, merge } from 'lodash-es';
3838
import type { UndoableOptions } from 'redux-undo';
3939
import type {
40+
ApiModelConfig,
4041
ControlLoRAModelConfig,
4142
ControlNetModelConfig,
4243
FLUXReduxModelConfig,
@@ -76,6 +77,7 @@ import {
7677
getReferenceImageState,
7778
getRegionalGuidanceState,
7879
imageDTOToImageWithDims,
80+
initialChatGPT4oReferenceImage,
7981
initialControlLoRA,
8082
initialControlNet,
8183
initialFLUXRedux,
@@ -644,48 +646,70 @@ export const canvasSlice = createSlice({
644646
referenceImageIPAdapterModelChanged: (
645647
state,
646648
action: PayloadAction<
647-
EntityIdentifierPayload<{ modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | null }, 'reference_image'>
649+
EntityIdentifierPayload<
650+
{ modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig | null },
651+
'reference_image'
652+
>
648653
>
649654
) => {
650655
const { entityIdentifier, modelConfig } = action.payload;
651656
const entity = selectEntity(state, entityIdentifier);
652657
if (!entity) {
653658
return;
654659
}
660+
661+
const oldModel = entity.ipAdapter.model;
662+
663+
// First set the new model
655664
entity.ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null;
656665

657666
if (!entity.ipAdapter.model) {
658667
return;
659668
}
660669

661-
if (entity.ipAdapter.type === 'ip_adapter' && entity.ipAdapter.model.type === 'flux_redux') {
662-
// Switching from ip_adapter to flux_redux
670+
if (isEqual(oldModel, entity.ipAdapter.model)) {
671+
// Nothing changed, so we don't need to do anything
672+
return;
673+
}
674+
675+
// The type of ref image depends on the model. When the user switches the model, we rebuild the ref image.
676+
// When we switch the model, we keep the image the same, but change the other parameters.
677+
678+
if (entity.ipAdapter.model.base === 'chatgpt-4o') {
679+
// Switching to chatgpt-4o ref image
663680
entity.ipAdapter = {
664-
...initialFLUXRedux,
681+
...initialChatGPT4oReferenceImage,
665682
image: entity.ipAdapter.image,
666683
model: entity.ipAdapter.model,
667684
};
668685
return;
669686
}
670687

671-
if (entity.ipAdapter.type === 'flux_redux' && entity.ipAdapter.model.type === 'ip_adapter') {
672-
// Switching from flux_redux to ip_adapter
688+
if (entity.ipAdapter.model.type === 'flux_redux') {
689+
// Switching to flux_redux
673690
entity.ipAdapter = {
674-
...initialIPAdapter,
691+
...initialFLUXRedux,
675692
image: entity.ipAdapter.image,
676693
model: entity.ipAdapter.model,
677694
};
678695
return;
679696
}
680697

681-
if (entity.ipAdapter.type === 'ip_adapter') {
698+
if (entity.ipAdapter.model.type === 'ip_adapter') {
699+
// Switching to ip_adapter
700+
entity.ipAdapter = {
701+
...initialIPAdapter,
702+
image: entity.ipAdapter.image,
703+
model: entity.ipAdapter.model,
704+
};
682705
// Ensure that the IP Adapter model is compatible with the CLIP Vision model
683706
if (entity.ipAdapter.model?.base === 'flux') {
684707
entity.ipAdapter.clipVisionModel = 'ViT-L';
685708
} else if (entity.ipAdapter.clipVisionModel === 'ViT-L') {
686709
// Fall back to ViT-H (ViT-G would also work)
687710
entity.ipAdapter.clipVisionModel = 'ViT-H';
688711
}
712+
return;
689713
}
690714
},
691715
referenceImageIPAdapterCLIPVisionModelChanged: (

0 commit comments

Comments
 (0)