Skip to content

Commit 9c9265c

Browse files
brandonrisingUbuntu
authored andcommitted
Setup Probe and UI to accept bria main models
1 parent 0d67ee6 commit 9c9265c

File tree

14 files changed

+173
-3
lines changed

14 files changed

+173
-3
lines changed

invokeai/app/invocations/fields.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
4242
MainModel = "MainModelField"
4343
CogView4MainModel = "CogView4MainModelField"
4444
FluxMainModel = "FluxMainModelField"
45+
BriaMainModel = "BriaMainModelField"
4546
SD3MainModel = "SD3MainModelField"
4647
SDXLMainModel = "SDXLMainModelField"
4748
SDXLRefinerModel = "SDXLRefinerModelField"

invokeai/backend/model_manager/legacy_probe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ class ModelProbe(object):
125125
}
126126

127127
CLASS2TYPE = {
128+
"BriaPipeline": ModelType.Main,
128129
"FluxPipeline": ModelType.Main,
129130
"StableDiffusionPipeline": ModelType.Main,
130131
"StableDiffusionInpaintPipeline": ModelType.Main,
@@ -861,6 +862,8 @@ def get_base_type(self) -> BaseModelType:
861862
return BaseModelType.StableDiffusion3
862863
elif transformer_conf["_class_name"] == "CogView4Transformer2DModel":
863864
return BaseModelType.CogView4
865+
elif transformer_conf["_class_name"] == "BriaTransformer2DModel":
866+
return BaseModelType.Bria
864867
else:
865868
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
866869

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from pathlib import Path
2+
from typing import Optional
3+
4+
from invokeai.backend.model_manager.config import (
5+
AnyModelConfig,
6+
CheckpointConfigBase,
7+
DiffusersConfigBase,
8+
)
9+
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
10+
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
11+
from invokeai.backend.model_manager.taxonomy import (
12+
AnyModel,
13+
BaseModelType,
14+
ModelFormat,
15+
ModelType,
16+
SubModelType,
17+
)
18+
19+
20+
@ModelLoaderRegistry.register(base=BaseModelType.Bria, type=ModelType.Main, format=ModelFormat.Diffusers)
21+
class BriaDiffusersModel(GenericDiffusersLoader):
22+
"""Class to load Bria main models."""
23+
24+
def _load_model(
25+
self,
26+
config: AnyModelConfig,
27+
submodel_type: Optional[SubModelType] = None,
28+
) -> AnyModel:
29+
if isinstance(config, CheckpointConfigBase):
30+
raise NotImplementedError("CheckpointConfigBase is not implemented for Bria models.")
31+
32+
if submodel_type is None:
33+
raise Exception("A submodel type must be provided when loading main pipelines.")
34+
35+
model_path = Path(config.path)
36+
load_class = self.get_hf_load_class(model_path, submodel_type)
37+
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
38+
variant = repo_variant.value if repo_variant else None
39+
model_path = model_path / submodel_type.value
40+
41+
dtype = self._torch_dtype
42+
try:
43+
result: AnyModel = load_class.from_pretrained(
44+
model_path,
45+
torch_dtype=dtype,
46+
variant=variant,
47+
)
48+
except OSError as e:
49+
if variant and "no file named" in str(
50+
e
51+
): # try without the variant, just in case user's preferences changed
52+
result = load_class.from_pretrained(model_path, torch_dtype=dtype)
53+
else:
54+
raise e
55+
56+
return result

invokeai/backend/model_manager/taxonomy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class BaseModelType(str, Enum):
3030
Imagen4 = "imagen4"
3131
ChatGPT4o = "chatgpt-4o"
3232
FluxKontext = "flux-kontext"
33+
Bria = "bria"
3334

3435

3536
class ModelType(str, Enum):

invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ import {
2929
isBooleanFieldInputTemplate,
3030
isChatGPT4oModelFieldInputInstance,
3131
isChatGPT4oModelFieldInputTemplate,
32+
isBriaMainModelFieldInputInstance,
33+
isBriaMainModelFieldInputTemplate,
3234
isCLIPEmbedModelFieldInputInstance,
3335
isCLIPEmbedModelFieldInputTemplate,
3436
isCLIPGEmbedModelFieldInputInstance,
@@ -117,6 +119,7 @@ import { assert } from 'tsafe';
117119

118120
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
119121
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
122+
import BriaMainModelFieldInputComponent from './inputs/BriaMainModelFieldInputComponent';
120123
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
121124
import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent';
122125
import CLIPLEmbedModelFieldInputComponent from './inputs/CLIPLEmbedModelFieldInputComponent';
@@ -448,6 +451,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
448451
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
449452
}
450453

454+
if (isBriaMainModelFieldInputTemplate(template)) {
455+
if (!isBriaMainModelFieldInputInstance(field)) {
456+
return null;
457+
}
458+
return <BriaMainModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
459+
}
460+
451461
if (isSD3MainModelFieldInputTemplate(template)) {
452462
if (!isSD3MainModelFieldInputInstance(field)) {
453463
return null;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import { useAppDispatch } from 'app/store/storeHooks';
2+
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
3+
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
4+
import type { BriaMainModelFieldInputInstance, BriaMainModelFieldInputTemplate } from 'features/nodes/types/field';
5+
import { memo, useCallback } from 'react';
6+
import { useBriaModels } from 'services/api/hooks/modelsByType';
7+
import type { MainModelConfig } from 'services/api/types';
8+
9+
import type { FieldComponentProps } from './types';
10+
11+
type Props = FieldComponentProps<BriaMainModelFieldInputInstance, BriaMainModelFieldInputTemplate>;
12+
13+
const BriaMainModelFieldInputComponent = (props: Props) => {
14+
const { nodeId, field } = props;
15+
const dispatch = useAppDispatch();
16+
const [modelConfigs, { isLoading }] = useBriaModels();
17+
const onChange = useCallback(
18+
(value: MainModelConfig | null) => {
19+
if (!value) {
20+
return;
21+
}
22+
dispatch(
23+
fieldMainModelValueChanged({
24+
nodeId,
25+
fieldName: field.name,
26+
value,
27+
})
28+
);
29+
},
30+
[dispatch, field.name, nodeId]
31+
);
32+
33+
return (
34+
<ModelFieldCombobox
35+
value={field.value}
36+
modelConfigs={modelConfigs}
37+
isLoadingConfigs={isLoading}
38+
onChange={onChange}
39+
required={props.fieldTemplate.required}
40+
/>
41+
);
42+
};
43+
44+
export default memo(BriaMainModelFieldInputComponent);

invokeai/frontend/web/src/features/nodes/types/common.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ const zBaseModel = z.enum([
7676
'imagen4',
7777
'chatgpt-4o',
7878
'flux-kontext',
79+
'bria',
7980
]);
8081
export type BaseModelType = z.infer<typeof zBaseModel>;
8182
export const zMainModelBase = z.enum([
@@ -89,6 +90,7 @@ export const zMainModelBase = z.enum([
8990
'imagen4',
9091
'chatgpt-4o',
9192
'flux-kontext',
93+
'bria',
9294
]);
9395
type MainModelBase = z.infer<typeof zMainModelBase>;
9496
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;

invokeai/frontend/web/src/features/nodes/types/constants.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
5252
LoRAModelField: 'teal.500',
5353
MainModelField: 'teal.500',
5454
FluxMainModelField: 'teal.500',
55+
BriaMainModelField: 'teal.500',
5556
SD3MainModelField: 'teal.500',
5657
CogView4MainModelField: 'teal.500',
5758
SDXLMainModelField: 'teal.500',

invokeai/frontend/web/src/features/nodes/types/field.ts

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,10 @@ const zFluxMainModelFieldType = zFieldTypeBase.extend({
185185
name: z.literal('FluxMainModelField'),
186186
originalType: zStatelessFieldType.optional(),
187187
});
188+
const zBriaMainModelFieldType = zFieldTypeBase.extend({
189+
name: z.literal('BriaMainModelField'),
190+
originalType: zStatelessFieldType.optional(),
191+
});
188192
const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
189193
name: z.literal('SDXLRefinerModelField'),
190194
originalType: zStatelessFieldType.optional(),
@@ -325,6 +329,7 @@ const zStatefulFieldType = z.union([
325329
zIntegerGeneratorFieldType,
326330
zStringGeneratorFieldType,
327331
zImageGeneratorFieldType,
332+
zBriaMainModelFieldType,
328333
]);
329334
export type StatefulFieldType = z.infer<typeof zStatefulFieldType>;
330335
const statefulFieldTypeNames = zStatefulFieldType.options.map((o) => o.shape.name.value);
@@ -341,6 +346,7 @@ const modelFieldTypeNames = [
341346
zSD3MainModelFieldType.shape.name.value,
342347
zCogView4MainModelFieldType.shape.name.value,
343348
zFluxMainModelFieldType.shape.name.value,
349+
zBriaMainModelFieldType.shape.name.value,
344350
zSDXLRefinerModelFieldType.shape.name.value,
345351
zVAEModelFieldType.shape.name.value,
346352
zLoRAModelFieldType.shape.name.value,
@@ -888,6 +894,26 @@ export const isFluxMainModelFieldInputTemplate =
888894
buildTemplateTypeGuard<FluxMainModelFieldInputTemplate>('FluxMainModelField');
889895
// #endregion
890896

897+
// #region BriaMainModelField
898+
const zBriaMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only.
899+
const zBriaMainModelFieldInputInstance = zFieldInputInstanceBase.extend({
900+
value: zBriaMainModelFieldValue,
901+
});
902+
const zBriaMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
903+
type: zBriaMainModelFieldType,
904+
originalType: zFieldType.optional(),
905+
default: zBriaMainModelFieldValue,
906+
});
907+
const zBriaMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
908+
type: zBriaMainModelFieldType,
909+
});
910+
export type BriaMainModelFieldInputInstance = z.infer<typeof zBriaMainModelFieldInputInstance>;
911+
export type BriaMainModelFieldInputTemplate = z.infer<typeof zBriaMainModelFieldInputTemplate>;
912+
export const isBriaMainModelFieldInputInstance = buildInstanceTypeGuard(zBriaMainModelFieldInputInstance);
913+
export const isBriaMainModelFieldInputTemplate =
914+
buildTemplateTypeGuard<BriaMainModelFieldInputTemplate>('BriaMainModelField');
915+
// #endregion
916+
891917
// #region SDXLRefinerModelField
892918
/** @alias */ // tells knip to ignore this duplicate export
893919
export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only.
@@ -1887,6 +1913,7 @@ export const zStatefulFieldValue = z.union([
18871913
zMainModelFieldValue,
18881914
zSDXLMainModelFieldValue,
18891915
zFluxMainModelFieldValue,
1916+
zBriaMainModelFieldValue,
18901917
zSD3MainModelFieldValue,
18911918
zCogView4MainModelFieldValue,
18921919
zSDXLRefinerModelFieldValue,
@@ -1938,6 +1965,7 @@ const zStatefulFieldInputInstance = z.union([
19381965
zModelIdentifierFieldInputInstance,
19391966
zMainModelFieldInputInstance,
19401967
zFluxMainModelFieldInputInstance,
1968+
zBriaMainModelFieldInputInstance,
19411969
zSD3MainModelFieldInputInstance,
19421970
zCogView4MainModelFieldInputInstance,
19431971
zSDXLMainModelFieldInputInstance,
@@ -1980,6 +2008,7 @@ const zStatefulFieldInputTemplate = z.union([
19802008
zModelIdentifierFieldInputTemplate,
19812009
zMainModelFieldInputTemplate,
19822010
zFluxMainModelFieldInputTemplate,
2011+
zBriaMainModelFieldInputTemplate,
19832012
zSD3MainModelFieldInputTemplate,
19842013
zCogView4MainModelFieldInputTemplate,
19852014
zSDXLMainModelFieldInputTemplate,
@@ -2032,6 +2061,7 @@ const zStatefulFieldOutputTemplate = z.union([
20322061
zModelIdentifierFieldOutputTemplate,
20332062
zMainModelFieldOutputTemplate,
20342063
zFluxMainModelFieldOutputTemplate,
2064+
zBriaMainModelFieldOutputTemplate,
20352065
zSD3MainModelFieldOutputTemplate,
20362066
zCogView4MainModelFieldOutputTemplate,
20372067
zSDXLMainModelFieldOutputTemplate,

invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
1717
SchedulerField: 'dpmpp_3m_k',
1818
SDXLMainModelField: undefined,
1919
FluxMainModelField: undefined,
20+
BriaMainModelField: undefined,
2021
SD3MainModelField: undefined,
2122
CogView4MainModelField: undefined,
2223
SDXLRefinerModelField: undefined,

0 commit comments

Comments
 (0)