Skip to content

Commit 7140f2e

Browse files
Ilan TchenakUbuntu
authored andcommitted
Setup Probe and UI to accept bria controlnet models
1 parent 9e5e1ec commit 7140f2e

File tree

12 files changed

+162
-4
lines changed

12 files changed

+162
-4
lines changed

invokeai/app/invocations/fields.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
4343
CogView4MainModel = "CogView4MainModelField"
4444
FluxMainModel = "FluxMainModelField"
4545
BriaMainModel = "BriaMainModelField"
46+
BriaControlNetModel = "BriaControlNetModelField"
4647
SD3MainModel = "SD3MainModelField"
4748
SDXLMainModel = "SDXLMainModelField"
4849
SDXLRefinerModel = "SDXLRefinerModelField"

invokeai/backend/model_manager/legacy_probe.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ class ModelProbe(object):
126126

127127
CLASS2TYPE = {
128128
"BriaPipeline": ModelType.Main,
129+
"BriaControlNetModel": ModelType.ControlNet,
129130
"FluxPipeline": ModelType.Main,
130131
"StableDiffusionPipeline": ModelType.Main,
131132
"StableDiffusionInpaintPipeline": ModelType.Main,
@@ -1013,6 +1014,9 @@ def get_base_type(self) -> BaseModelType:
10131014
if config.get("_class_name", None) == "FluxControlNetModel":
10141015
return BaseModelType.Flux
10151016

1017+
if config.get("_class_name", None) == "BriaControlNetModel":
1018+
return BaseModelType.Bria
1019+
10161020
# no obvious way to distinguish between sd2-base and sd2-768
10171021
dimension = config["cross_attention_dim"]
10181022
if dimension == 768:

invokeai/backend/model_manager/load/model_loaders/bria.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
AnyModelConfig,
66
CheckpointConfigBase,
77
DiffusersConfigBase,
8+
ControlNetDiffusersConfig,
9+
ControlNetCheckpointConfig,
810
)
911
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
1012
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
@@ -17,6 +19,45 @@
1719
)
1820

1921

22+
@ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
23+
class BriaControlNetDiffusersModel(GenericDiffusersLoader):
24+
"""Class to load Bria control net models."""
25+
26+
def _load_model(
27+
self,
28+
config: AnyModelConfig,
29+
submodel_type: Optional[SubModelType] = None,
30+
) -> AnyModel:
31+
if isinstance(config, ControlNetCheckpointConfig):
32+
raise NotImplementedError("CheckpointConfigBase is not implemented for Bria models.")
33+
34+
if submodel_type is None:
35+
raise Exception("A submodel type must be provided when loading control net pipelines.")
36+
37+
model_path = Path(config.path)
38+
load_class = self.get_hf_load_class(model_path, submodel_type)
39+
repo_variant = config.repo_variant if isinstance(config, ControlNetDiffusersConfig) else None
40+
variant = repo_variant.value if repo_variant else None
41+
model_path = model_path / submodel_type.value
42+
43+
dtype = self._torch_dtype
44+
45+
try:
46+
result: AnyModel = load_class.from_pretrained(
47+
model_path,
48+
torch_dtype=dtype,
49+
variant=variant,
50+
)
51+
except OSError as e:
52+
if variant and "no file named" in str(
53+
e
54+
): # try without the variant, just in case user's preferences changed
55+
result = load_class.from_pretrained(model_path, torch_dtype=dtype)
56+
else:
57+
raise e
58+
59+
return result
60+
2061
@ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.Main, format=ModelFormat.Diffusers)
2162
class BriaDiffusersModel(GenericDiffusersLoader):
2263
"""Class to load Bria main models."""

invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@ import {
2727
isBoardFieldInputTemplate,
2828
isBooleanFieldInputInstance,
2929
isBooleanFieldInputTemplate,
30-
isChatGPT4oModelFieldInputInstance,
31-
isChatGPT4oModelFieldInputTemplate,
30+
isBriaControlNetModelFieldInputInstance,
31+
isBriaControlNetModelFieldInputTemplate,
3232
isBriaMainModelFieldInputInstance,
3333
isBriaMainModelFieldInputTemplate,
34+
isChatGPT4oModelFieldInputInstance,
35+
isChatGPT4oModelFieldInputTemplate,
3436
isCLIPEmbedModelFieldInputInstance,
3537
isCLIPEmbedModelFieldInputTemplate,
3638
isCLIPGEmbedModelFieldInputInstance,
@@ -119,6 +121,7 @@ import { assert } from 'tsafe';
119121

120122
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
121123
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
124+
import BriaControlNetModelFieldInputComponent from './inputs/BriaControlNetModelFieldInputComponent';
122125
import BriaMainModelFieldInputComponent from './inputs/BriaMainModelFieldInputComponent';
123126
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
124127
import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent';
@@ -458,6 +461,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
458461
return <BriaMainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
459462
}
460463

464+
if (isBriaControlNetModelFieldInputTemplate(template)) {
465+
if (!isBriaControlNetModelFieldInputInstance(field)) {
466+
return null;
467+
}
468+
return <BriaControlNetModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
469+
}
470+
461471
if (isSD3MainModelFieldInputTemplate(template)) {
462472
if (!isSD3MainModelFieldInputInstance(field)) {
463473
return null;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import { useAppDispatch } from 'app/store/storeHooks';
2+
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
3+
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
4+
import type {
5+
BriaControlNetModelFieldInputInstance,
6+
BriaControlNetModelFieldInputTemplate,
7+
} from 'features/nodes/types/field';
8+
import { memo, useCallback } from 'react';
9+
import { useBriaModels } from 'services/api/hooks/modelsByType';
10+
import type { MainModelConfig } from 'services/api/types';
11+
12+
import type { FieldComponentProps } from './types';
13+
14+
type Props = FieldComponentProps<BriaControlNetModelFieldInputInstance, BriaControlNetModelFieldInputTemplate>;
15+
16+
const BriaControlNetModelFieldInputComponent = (props: Props) => {
17+
const { nodeId, field } = props;
18+
const dispatch = useAppDispatch();
19+
const [modelConfigs, { isLoading }] = useBriaModels();
20+
const onChange = useCallback(
21+
(value: MainModelConfig | null) => {
22+
if (!value) {
23+
return;
24+
}
25+
dispatch(
26+
fieldMainModelValueChanged({
27+
nodeId,
28+
fieldName: field.name,
29+
value,
30+
})
31+
);
32+
},
33+
[dispatch, field.name, nodeId]
34+
);
35+
36+
return (
37+
<ModelFieldCombobox
38+
value={field.value}
39+
modelConfigs={modelConfigs}
40+
isLoadingConfigs={isLoading}
41+
onChange={onChange}
42+
required={props.fieldTemplate.required}
43+
/>
44+
);
45+
};
46+
47+
export default memo(BriaControlNetModelFieldInputComponent);

invokeai/frontend/web/src/features/nodes/types/constants.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
5353
MainModelField: 'teal.500',
5454
FluxMainModelField: 'teal.500',
5555
BriaMainModelField: 'teal.500',
56+
BriaControlNetModelField: 'teal.500',
5657
SD3MainModelField: 'teal.500',
5758
CogView4MainModelField: 'teal.500',
5859
SDXLMainModelField: 'teal.500',

invokeai/frontend/web/src/features/nodes/types/field.ts

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,10 @@ const zBriaMainModelFieldType = zFieldTypeBase.extend({
189189
name: z.literal('BriaMainModelField'),
190190
originalType: zStatelessFieldType.optional(),
191191
});
192+
const zBriaControlNetModelFieldType = zFieldTypeBase.extend({
193+
name: z.literal('BriaControlNetModelField'),
194+
originalType: zStatelessFieldType.optional(),
195+
});
192196
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
193197
name: z.literal('SDXLRefinerModelField'),
194198
originalType: zStatelessFieldType.optional(),
@@ -330,6 +334,7 @@ const zStatefulFieldType = z.union([
330334
zStringGeneratorFieldType,
331335
zImageGeneratorFieldType,
332336
zBriaMainModelFieldType,
337+
zBriaControlNetModelFieldType,
333338
]);
334339
export type StatefulFieldType = z.infer<typeof zStatefulFieldType>;
335340
const statefulFieldTypeNames = zStatefulFieldType.options.map((o) => o.shape.name.value);
@@ -347,6 +352,7 @@ const modelFieldTypeNames = [
347352
zCogView4MainModelFieldType.shape.name.value,
348353
zFluxMainModelFieldType.shape.name.value,
349354
zBriaMainModelFieldType.shape.name.value,
355+
zBriaControlNetModelFieldType.shape.name.value,
350356
zSDXLRefinerModelFieldType.shape.name.value,
351357
zVAEModelFieldType.shape.name.value,
352358
zLoRAModelFieldType.shape.name.value,
@@ -914,6 +920,26 @@ export const isBriaMainModelFieldInputTemplate =
914920
buildTemplateTypeGuard<BriaMainModelFieldInputTemplate>('BriaMainModelField');
915921
// #endregion
916922

923+
// #region BriaControlNetModelField
924+
const zBriaControlNetModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
925+
const zBriaControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
926+
value: zBriaControlNetModelFieldValue,
927+
});
928+
const zBriaControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({
929+
type: zBriaControlNetModelFieldType,
930+
originalType: zFieldType.optional(),
931+
default: zBriaControlNetModelFieldValue,
932+
});
933+
const zBriaControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
934+
type: zBriaControlNetModelFieldType,
935+
});
936+
export type BriaControlNetModelFieldInputInstance = z.infer<typeof zBriaControlNetModelFieldInputInstance>;
937+
export type BriaControlNetModelFieldInputTemplate = z.infer<typeof zBriaControlNetModelFieldInputTemplate>;
938+
export const isBriaControlNetModelFieldInputInstance = buildInstanceTypeGuard(zBriaControlNetModelFieldInputInstance);
939+
export const isBriaControlNetModelFieldInputTemplate =
940+
buildTemplateTypeGuard<BriaControlNetModelFieldInputTemplate>('BriaControlNetModelField');
941+
// #endregion
942+
917943
// #region SDXLRefinerModelField
918944
/** @alias */ // tells knip to ignore this duplicate export
919945
export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only.
@@ -1914,6 +1940,7 @@ export const zStatefulFieldValue = z.union([
19141940
zSDXLMainModelFieldValue,
19151941
zFluxMainModelFieldValue,
19161942
zBriaMainModelFieldValue,
1943+
zBriaControlNetModelFieldValue,
19171944
zSD3MainModelFieldValue,
19181945
zCogView4MainModelFieldValue,
19191946
zSDXLRefinerModelFieldValue,
@@ -1966,6 +1993,7 @@ const zStatefulFieldInputInstance = z.union([
19661993
zMainModelFieldInputInstance,
19671994
zFluxMainModelFieldInputInstance,
19681995
zBriaMainModelFieldInputInstance,
1996+
zBriaControlNetModelFieldInputInstance,
19691997
zSD3MainModelFieldInputInstance,
19701998
zCogView4MainModelFieldInputInstance,
19711999
zSDXLMainModelFieldInputInstance,
@@ -2009,6 +2037,7 @@ const zStatefulFieldInputTemplate = z.union([
20092037
zMainModelFieldInputTemplate,
20102038
zFluxMainModelFieldInputTemplate,
20112039
zBriaMainModelFieldInputTemplate,
2040+
zBriaControlNetModelFieldInputTemplate,
20122041
zSD3MainModelFieldInputTemplate,
20132042
zCogView4MainModelFieldInputTemplate,
20142043
zSDXLMainModelFieldInputTemplate,
@@ -2062,6 +2091,7 @@ const zStatefulFieldOutputTemplate = z.union([
20622091
zMainModelFieldOutputTemplate,
20632092
zFluxMainModelFieldOutputTemplate,
20642093
zBriaMainModelFieldOutputTemplate,
2094+
zBriaControlNetModelFieldOutputTemplate,
20652095
zSD3MainModelFieldOutputTemplate,
20662096
zCogView4MainModelFieldOutputTemplate,
20672097
zSDXLMainModelFieldOutputTemplate,

invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
1818
SDXLMainModelField: undefined,
1919
FluxMainModelField: undefined,
2020
BriaMainModelField: undefined,
21+
BriaControlNetModelField: undefined,
2122
SD3MainModelField: undefined,
2223
CogView4MainModelField: undefined,
2324
SDXLRefinerModelField: undefined,

invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ import { FieldParseError } from 'features/nodes/types/error';
33
import type {
44
BoardFieldInputTemplate,
55
BooleanFieldInputTemplate,
6-
ChatGPT4oModelFieldInputTemplate,
6+
BriaControlNetModelFieldInputTemplate,
77
BriaMainModelFieldInputTemplate,
8+
ChatGPT4oModelFieldInputTemplate,
89
CLIPEmbedModelFieldInputTemplate,
910
CLIPGEmbedModelFieldInputTemplate,
1011
CLIPLEmbedModelFieldInputTemplate,
@@ -357,6 +358,20 @@ const buildBriaMainModelFieldInputTemplate: FieldInputTemplateBuilder<BriaMainMo
357358
return template;
358359
};
359360

361+
const buildBriaControlNetModelFieldInputTemplate: FieldInputTemplateBuilder<BriaControlNetModelFieldInputTemplate> = ({
362+
schemaObject,
363+
baseField,
364+
fieldType,
365+
}) => {
366+
const template: BriaControlNetModelFieldInputTemplate = {
367+
...baseField,
368+
type: fieldType,
369+
default: schemaObject.default ?? undefined,
370+
};
371+
372+
return template;
373+
};
374+
360375
const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder<SD3MainModelFieldInputTemplate> = ({
361376
schemaObject,
362377
baseField,
@@ -850,6 +865,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
850865
CogView4MainModelField: buildCogView4MainModelFieldInputTemplate,
851866
FluxMainModelField: buildFluxMainModelFieldInputTemplate,
852867
BriaMainModelField: buildBriaMainModelFieldInputTemplate,
868+
BriaControlNetModelField: buildBriaControlNetModelFieldInputTemplate,
853869
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
854870
StringField: buildStringFieldInputTemplate,
855871
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,

invokeai/frontend/web/src/services/api/hooks/modelsByType.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ import {
99
} from 'services/api/endpoints/models';
1010
import type { AnyModelConfig } from 'services/api/types';
1111
import {
12+
isBriaControlNetModelConfig,
1213
isBriaMainModelModelConfig,
14+
isChatGPT4oModelConfig,
1315
isCLIPEmbedModelConfig,
1416
isCLIPVisionModelConfig,
1517
isCogView4MainModelModelConfig,
@@ -66,6 +68,7 @@ export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
6668
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
6769
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);
6870
export const useBriaModels = buildModelsHook(isBriaMainModelModelConfig);
71+
export const useBriaControlNetModels = buildModelsHook(isBriaControlNetModelConfig);
6972
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
7073
export const useCogView4Models = buildModelsHook(isCogView4MainModelModelConfig);
7174
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);

0 commit comments

Comments
 (0)