Skip to content

Commit 98e5f79

Browse files
committed
Setup flux model loading in the UI
1 parent 45792cc commit 98e5f79

File tree

22 files changed

+814
-137
lines changed

22 files changed

+814
-137
lines changed

invokeai/app/invocations/fields.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
4040

4141
# region Model Field Types
4242
MainModel = "MainModelField"
43+
FluxMainModel = "FluxMainModelField"
4344
SDXLMainModel = "SDXLMainModelField"
4445
SDXLRefinerModel = "SDXLRefinerModelField"
4546
ONNXModel = "ONNXModelField"
@@ -126,12 +127,14 @@ class FieldDescriptions:
126127
noise = "Noise tensor"
127128
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
128129
unet = "UNet (scheduler, LoRAs)"
130+
transformer = "Transformer"
129131
vae = "VAE"
130132
cond = "Conditioning tensor"
131133
controlnet_model = "ControlNet model to load"
132134
vae_model = "VAE model to load"
133135
lora_model = "LoRA model to load"
134136
main_model = "Main model (UNet, VAE, CLIP) to load"
137+
flux_model = "Flux model (Transformer, VAE, CLIP) to load"
135138
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
136139
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
137140
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"

invokeai/app/invocations/flux_text_to_image.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from pathlib import Path
22
from typing import Literal
3+
from pydantic import Field
34

45
import torch
56
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
67
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
78
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
9+
from invokeai.app.invocations.model import ModelIdentifierField
810
from optimum.quanto import qfloat8
911
from PIL import Image
1012
from transformers.models.auto import AutoModelForTextEncoding
@@ -17,6 +19,7 @@
1719
InputField,
1820
WithBoard,
1921
WithMetadata,
22+
UIType,
2023
)
2124
from invokeai.app.invocations.primitives import ImageOutput
2225
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -46,6 +49,11 @@ class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
4649
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
4750
"""Text-to-image generation using a FLUX model."""
4851

52+
flux_model: ModelIdentifierField = InputField(
53+
description="The Flux model",
54+
input=Input.Any,
55+
ui_type=UIType.FluxMainModel
56+
)
4957
model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
5058
use_8bit: bool = InputField(
5159
default=False, description="Whether to quantize the transformer model to 8-bit precision."

invokeai/app/invocations/model.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ class CLIPField(BaseModel):
6060
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
6161

6262

63+
64+
class TransformerField(BaseModel):
65+
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
66+
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
67+
68+
6369
class VAEField(BaseModel):
6470
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
6571
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@@ -122,6 +128,49 @@ def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
122128
return ModelIdentifierOutput(model=self.model)
123129

124130

131+
@invocation_output("flux_model_loader_output")
132+
class FluxModelLoaderOutput(BaseInvocationOutput):
133+
"""Flux base model loader output"""
134+
135+
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
136+
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
137+
clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
138+
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
139+
140+
141+
@invocation("flux_model_loader", title="Flux Main Model", tags=["model", "flux"], category="model", version="1.0.3")
142+
class FluxModelLoaderInvocation(BaseInvocation):
143+
"""Loads a flux base model, outputting its submodels."""
144+
145+
model: ModelIdentifierField = InputField(
146+
description=FieldDescriptions.flux_model,
147+
ui_type=UIType.FluxMainModel,
148+
input=Input.Direct,
149+
)
150+
151+
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
152+
model_key = self.model.key
153+
154+
# TODO: not found exceptions
155+
if not context.models.exists(model_key):
156+
raise Exception(f"Unknown model: {model_key}")
157+
158+
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
159+
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
160+
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
161+
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
162+
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
163+
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
164+
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
165+
166+
return FluxModelLoaderOutput(
167+
transformer=TransformerField(transformer=transformer, scheduler=scheduler),
168+
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
169+
clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
170+
vae=VAEField(vae=vae),
171+
)
172+
173+
125174
@invocation(
126175
"main_model_loader",
127176
title="Main Model",

invokeai/backend/model_manager/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class BaseModelType(str, Enum):
5252
StableDiffusion2 = "sd-2"
5353
StableDiffusionXL = "sdxl"
5454
StableDiffusionXLRefiner = "sdxl-refiner"
55+
Flux = "flux"
5556
# Kandinsky2_1 = "kandinsky-2.1"
5657

5758

@@ -74,6 +75,7 @@ class SubModelType(str, Enum):
7475
"""Submodel type."""
7576

7677
UNet = "unet"
78+
Transformer = "transformer"
7779
TextEncoder = "text_encoder"
7880
TextEncoder2 = "text_encoder_2"
7981
Tokenizer = "tokenizer"

invokeai/backend/model_manager/probe.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class ModelProbe(object):
9595
}
9696

9797
CLASS2TYPE = {
98+
"FluxPipeline": ModelType.Main,
9899
"StableDiffusionPipeline": ModelType.Main,
99100
"StableDiffusionInpaintPipeline": ModelType.Main,
100101
"StableDiffusionXLPipeline": ModelType.Main,
@@ -626,6 +627,10 @@ def get_repo_variant(self) -> ModelRepoVariant:
626627

627628
class PipelineFolderProbe(FolderProbeBase):
628629
def get_base_type(self) -> BaseModelType:
630+
with open(f"{self.model_path}/model_index.json", "r") as file:
631+
conf = json.load(file)
632+
if "_class_name" in conf and conf.get("_class_name") == "FluxPipeline":
633+
return BaseModelType.Flux
629634
with open(self.model_path / "unet" / "config.json", "r") as file:
630635
unet_conf = json.load(file)
631636
if unet_conf["cross_attention_dim"] == 768:

invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
1313
'sd-2': 'teal',
1414
sdxl: 'invokeBlue',
1515
'sdxl-refiner': 'invokeBlue',
16+
flux: 'invokeBlue',
1617
};
1718

1819
const ModelBaseBadge = ({ base }: Props) => {

invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import {
1414
isEnumFieldInputTemplate,
1515
isFloatFieldInputInstance,
1616
isFloatFieldInputTemplate,
17+
isFluxMainModelFieldInputInstance,
18+
isFluxMainModelFieldInputTemplate,
1719
isImageFieldInputInstance,
1820
isImageFieldInputTemplate,
1921
isIntegerFieldInputInstance,
@@ -48,6 +50,7 @@ import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
4850
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
4951
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
5052
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
53+
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
5154
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
5255
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
5356
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
@@ -69,6 +72,7 @@ type InputFieldProps = {
6972
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
7073
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
7174
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
75+
window.console.log("Hit 0")
7276

7377
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
7478
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
@@ -145,6 +149,9 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
145149
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
146150
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
147151
}
152+
if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
153+
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
154+
}
148155

149156
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
150157
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
2+
import { useAppDispatch } from 'app/store/storeHooks';
3+
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
4+
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
5+
import type { FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate } from 'features/nodes/types/field';
6+
import { memo, useCallback } from 'react';
7+
import { useFluxModels } from 'services/api/hooks/modelsByType';
8+
import type { MainModelConfig } from 'services/api/types';
9+
10+
import type { FieldComponentProps } from './types';
11+
12+
type Props = FieldComponentProps<FluxMainModelFieldInputInstance, FluxMainModelFieldInputTemplate>;
13+
14+
const FluxMainModelFieldInputComponent = (props: Props) => {
15+
const { nodeId, field } = props;
16+
const dispatch = useAppDispatch();
17+
const [modelConfigs, { isLoading }] = useFluxModels();
18+
const _onChange = useCallback(
19+
(value: MainModelConfig | null) => {
20+
if (!value) {
21+
return;
22+
}
23+
dispatch(
24+
fieldMainModelValueChanged({
25+
nodeId,
26+
fieldName: field.name,
27+
value,
28+
})
29+
);
30+
},
31+
[dispatch, field.name, nodeId]
32+
);
33+
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
34+
modelConfigs,
35+
onChange: _onChange,
36+
isLoading,
37+
selectedModel: field.value,
38+
});
39+
40+
return (
41+
<Flex w="full" alignItems="center" gap={2}>
42+
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
43+
<Combobox
44+
value={value}
45+
placeholder={placeholder}
46+
options={options}
47+
onChange={onChange}
48+
noOptionsMessage={noOptionsMessage}
49+
/>
50+
</FormControl>
51+
</Flex>
52+
);
53+
};
54+
55+
export default memo(FluxMainModelFieldInputComponent);

invokeai/frontend/web/src/features/nodes/types/common.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
6161
// #endregion
6262

6363
// #region Model-related schemas
64-
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
64+
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner', 'flux']);
6565
const zModelType = z.enum([
6666
'main',
6767
'vae',
@@ -76,6 +76,7 @@ const zModelType = z.enum([
7676
]);
7777
const zSubModelType = z.enum([
7878
'unet',
79+
'transformer',
7980
'text_encoder',
8081
'text_encoder_2',
8182
'tokenizer',

invokeai/frontend/web/src/features/nodes/types/constants.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ export const MODEL_TYPES = [
3131
'ControlNetModelField',
3232
'LoRAModelField',
3333
'MainModelField',
34+
'FluxMainModelField',
3435
'SDXLMainModelField',
3536
'SDXLRefinerModelField',
3637
'VaeModelField',
@@ -61,13 +62,15 @@ export const FIELD_COLORS: { [key: string]: string } = {
6162
LatentsField: 'pink.500',
6263
LoRAModelField: 'teal.500',
6364
MainModelField: 'teal.500',
65+
FluxMainModelField: 'teal.500',
6466
SDXLMainModelField: 'teal.500',
6567
SDXLRefinerModelField: 'teal.500',
6668
SpandrelImageToImageModelField: 'teal.500',
6769
StringField: 'yellow.500',
6870
T2IAdapterField: 'teal.500',
6971
T2IAdapterModelField: 'teal.500',
7072
UNetField: 'red.500',
73+
TransformerField: 'red.500',
7174
VAEField: 'blue.500',
7275
VAEModelField: 'teal.500',
7376
};

0 commit comments

Comments
 (0)