Skip to content

Commit 9e5e1ec

Browse files
author
Ubuntu
committed
added bria nodes for bria3.1 and bria3.2
1 parent a139885 commit 9e5e1ec

File tree

6 files changed

+409
-0
lines changed

6 files changed

+409
-0
lines changed

invokeai/nodes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .bria_nodes import *
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import torch
2+
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
3+
from PIL import Image
4+
5+
from invokeai.app.invocations.model import VAEField
6+
from invokeai.app.invocations.primitives import FieldDescriptions, Input, InputField, LatentsField
7+
from invokeai.app.services.shared.invocation_context import InvocationContext
8+
from invokeai.invocation_api import BaseInvocation, Classification, ImageOutput, invocation
9+
10+
11+
@invocation(
    "bria_decoder",
    title="Bria Decoder",
    tags=["image", "bria"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class BriaDecoderInvocation(BaseInvocation):
    """Decode packed Bria latents into a PIL image via the model's VAE.

    Expects latents in packed sequence form (1, seq_len, features), where each
    sequence token holds a 2x2 spatial patch of features // 4 channels.
    """

    latents: LatentsField = InputField(
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        latents = context.tensors.load(self.latents.latents_name)

        # Unpack (1, seq_len, features) -> (1, C, H, W). The original code
        # hard-coded view(1, 64, 64, 4, 2, 2) (i.e. a 64x64 token grid, 4
        # channels, 1024px output); derive grid and channels from the tensor
        # instead so other resolutions work too. seq_len must be a perfect
        # square (square images) and features a multiple of 4 (2x2 patches).
        _, seq_len, features = latents.shape
        grid = round(seq_len**0.5)
        if grid * grid != seq_len or features % 4 != 0:
            raise ValueError(f"Cannot unpack latents of shape {tuple(latents.shape)}")
        channels = features // 4
        latents = (
            latents.view(1, grid, grid, channels, 2, 2)
            .permute(0, 3, 1, 4, 2, 5)
            .reshape(1, channels, grid * 2, grid * 2)
        )

        with context.models.load(self.vae.vae) as vae:
            assert isinstance(vae, AutoencoderKL)
            # Undo the VAE's latent scaling before decoding.
            latents = latents / vae.config.scaling_factor
            latents = latents.to(device=vae.device, dtype=vae.dtype)

            decoded_output = vae.decode(latents)
            image = decoded_output.sample

        # Map [-1, 1] -> [0, 255] uint8, NCHW -> HWC, and drop the batch dim.
        image = ((image.clamp(-1, 1) + 1) / 2 * 255).cpu().detach().permute(0, 2, 3, 1).numpy().astype("uint8")[0]
        img = Image.fromarray(image)
        image_dto = context.images.save(image=img)
        return ImageOutput.build(image_dto)
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import torch
2+
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
3+
4+
from invokeai.app.invocations.fields import Input, InputField
5+
from invokeai.app.invocations.model import SubModelType, TransformerField
6+
from invokeai.app.invocations.primitives import (
7+
BaseInvocationOutput,
8+
FieldDescriptions,
9+
Input,
10+
InputField,
11+
LatentsField,
12+
OutputField,
13+
)
14+
from invokeai.app.services.shared.invocation_context import InvocationContext
15+
from invokeai.invocation_api import BaseInvocation, Classification, InputField, invocation, invocation_output
16+
17+
from invokeai.backend.bria.pipeline import get_original_sigmas, retrieve_timesteps
18+
from invokeai.backend.bria.transformer_bria import BriaTransformer2DModel
19+
20+
@invocation_output("bria_denoise_output")
class BriaDenoiseInvocationOutput(BaseInvocationOutput):
    """Output of the Bria denoise node: the denoised packed latents."""

    latents: LatentsField = OutputField(description=FieldDescriptions.latents)
23+
24+
25+
@invocation(
    "bria_denoise",
    title="Denoise - Bria",
    tags=["image", "bria"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class BriaDenoiseInvocation(BaseInvocation):
    """Run the Bria flow-matching denoising loop over packed latents.

    Applies classifier-free guidance when guidance_scale > 1 by batching the
    negative and positive prompt embeds through the transformer together.
    """

    num_steps: int = InputField(
        default=30, title="Number of Steps", description="The number of steps to use for the denoiser"
    )
    guidance_scale: float = InputField(
        default=5.0, title="Guidance Scale", description="The guidance scale to use for the denoiser"
    )

    transformer: TransformerField = InputField(
        description="Bria model (Transformer) to load",
        input=Input.Connection,
        title="Transformer",
    )
    latents: LatentsField = InputField(
        description="Latents to denoise",
        input=Input.Connection,
        title="Latents",
    )
    latent_image_ids: LatentsField = InputField(
        description="Latent Image IDs to denoise",
        input=Input.Connection,
        title="Latent Image IDs",
    )
    pos_embeds: LatentsField = InputField(
        description="Positive Prompt Embeds",
        input=Input.Connection,
        title="Positive Prompt Embeds",
    )
    neg_embeds: LatentsField = InputField(
        description="Negative Prompt Embeds",
        input=Input.Connection,
        title="Negative Prompt Embeds",
    )
    text_ids: LatentsField = InputField(
        description="Text IDs",
        input=Input.Connection,
        title="Text IDs",
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> BriaDenoiseInvocationOutput:
        latents = context.tensors.load(self.latents.latents_name)
        pos_embeds = context.tensors.load(self.pos_embeds.latents_name)
        neg_embeds = context.tensors.load(self.neg_embeds.latents_name)
        text_ids = context.tensors.load(self.text_ids.latents_name)
        latent_image_ids = context.tensors.load(self.latent_image_ids.latents_name)
        # The scheduler lives alongside the transformer in the same model;
        # swap the submodel type on a copy of the identifier to load it.
        scheduler_identifier = self.transformer.transformer.model_copy(update={"submodel_type": SubModelType.Scheduler})

        device = None
        dtype = None
        with (
            context.models.load(self.transformer.transformer) as transformer,
            context.models.load(scheduler_identifier) as scheduler,
        ):
            assert isinstance(transformer, BriaTransformer2DModel)
            assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)
            dtype = transformer.dtype
            device = transformer.device
            latents, pos_embeds, neg_embeds = map(lambda x: x.to(device, dtype), (latents, pos_embeds, neg_embeds))
            # FIX: the position-id tensors were previously left on whatever
            # device they were loaded on; move them to the transformer's
            # device too, or the forward pass fails on CUDA. Dtype is left
            # unchanged — presumably integer ids; verify against the model.
            latent_image_ids = latent_image_ids.to(device)
            text_ids = text_ids.to(device)
            prompt_embeds = torch.cat([neg_embeds, pos_embeds]) if self.guidance_scale > 1 else pos_embeds

            sigmas = get_original_sigmas(1000, self.num_steps)
            timesteps, _ = retrieve_timesteps(scheduler, self.num_steps, device, None, sigmas, mu=0.0)

            for t in timesteps:
                # Double the batch for classifier-free guidance.
                if self.guidance_scale > 1:
                    latent_model_input = torch.cat([latents] * 2)
                else:
                    latent_model_input = latents

                # Broadcast the scalar timestep across the batch.
                if isinstance(t, torch.Tensor):
                    timestep_tensor = t.expand(latent_model_input.shape[0])
                else:
                    timestep_tensor = torch.tensor([t] * latent_model_input.shape[0], device=device, dtype=torch.float32)

                noise_pred = transformer(
                    latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep_tensor,
                    img_ids=latent_image_ids,
                    txt_ids=text_ids,
                    guidance=None,
                    return_dict=False,
                )[0]

                if self.guidance_scale > 1:
                    noise_uncond, noise_text = noise_pred.chunk(2)
                    noise_pred = noise_uncond + self.guidance_scale * (noise_text - noise_uncond)

                # Scheduler expects a plain float timestep.
                t_step = float(t.item()) if isinstance(t, torch.Tensor) else float(t)

                latents = scheduler.step(noise_pred, t_step, latents, return_dict=False)[0]

        assert isinstance(latents, torch.Tensor)
        saved_input_latents_tensor = context.tensors.save(latents)
        latents_output = LatentsField(latents_name=saved_input_latents_tensor)
        return BriaDenoiseInvocationOutput(latents=latents_output)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import torch
2+
3+
from invokeai.app.invocations.fields import Input, InputField
4+
from invokeai.app.invocations.model import TransformerField
5+
from invokeai.app.invocations.primitives import (
6+
BaseInvocationOutput,
7+
FieldDescriptions,
8+
Input,
9+
LatentsField,
10+
OutputField,
11+
)
12+
from invokeai.backend.model_manager.config import MainDiffusersConfig
13+
from invokeai.invocation_api import (
14+
BaseInvocation,
15+
Classification,
16+
InputField,
17+
InvocationContext,
18+
invocation,
19+
invocation_output,
20+
)
21+
22+
23+
@invocation_output("bria_latent_sampler_output")
class BriaLatentSamplerInvocationOutput(BaseInvocationOutput):
    """Output of the Bria latent sampler: initial noise latents and their position ids."""

    # FIX: the docstring and field descriptions were copy-pasted from a
    # CogView conditioning node; these fields are latents, not conditioning.
    latents: LatentsField = OutputField(description=FieldDescriptions.latents)
    latent_image_ids: LatentsField = OutputField(description=FieldDescriptions.latents)
29+
30+
31+
@invocation(
    "bria_latent_sampler",
    title="Latent Sampler - Bria",
    tags=["image", "bria"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class BriaLatentSamplerInvocation(BaseInvocation):
    """Sample seeded initial noise for Bria and pack it into sequence form.

    Produces packed latents of shape (1, (H/2)*(W/2), C) plus the matching
    (H/2)*(W/2, 3) latent image position ids.
    """

    seed: int = InputField(
        default=42,
        title="Seed",
        description="The seed to use for the latent sampler",
    )
    transformer: TransformerField = InputField(
        description="Bria model (Transformer) to load",
        input=Input.Connection,
        title="Transformer",
    )

    def invoke(self, context: InvocationContext) -> BriaLatentSamplerInvocationOutput:
        # FIX: was hard-coded to torch.device("cuda"), which crashes on
        # CPU-only or MPS hosts; fall back to CPU when CUDA is unavailable.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        transformer_config = context.models.get_config(self.transformer.transformer)
        if not isinstance(transformer_config, MainDiffusersConfig):
            raise ValueError("Transformer config is not a MainDiffusersConfig")
        # TODO: get latent channels from transformer config
        latent_channels = 16
        latent_height, latent_width = 128, 128
        # Packed latents carry a 2x2 spatial patch per token, so the raw
        # noise uses a quarter of the channels.
        shrunk = latent_channels // 4
        gen = torch.Generator(device=device).manual_seed(self.seed)

        # Sample image-form noise, then pack: (1, C/4, H, W) ->
        # (1, (H/2)*(W/2), C) with each token holding a 2x2 patch.
        noise4d = torch.randn((1, shrunk, latent_height, latent_width), device=device, generator=gen)
        latents = noise4d.view(1, shrunk, latent_height // 2, 2, latent_width // 2, 2).permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(1, (latent_height // 2) * (latent_width // 2), shrunk * 4)

        # Position ids: column 1 is the token row index, column 2 the column
        # index (column 0 stays zero).
        latent_image_ids = torch.zeros((latent_height // 2, latent_width // 2, 3), device=device, dtype=torch.long)
        latent_image_ids[..., 1] = torch.arange(latent_height // 2, device=device)[:, None]
        latent_image_ids[..., 2] = torch.arange(latent_width // 2, device=device)[None, :]
        latent_image_ids = latent_image_ids.view(-1, 3)

        saved_latents_tensor = context.tensors.save(latents)
        saved_latent_image_ids_tensor = context.tensors.save(latent_image_ids)
        latents_output = LatentsField(latents_name=saved_latents_tensor)
        latent_image_ids_output = LatentsField(latents_name=saved_latent_image_ids_tensor)

        return BriaLatentSamplerInvocationOutput(
            latents=latents_output,
            latent_image_ids=latent_image_ids_output,
        )
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
2+
from invokeai.app.invocations.model import (
3+
ModelIdentifierField,
4+
SubModelType,
5+
T5EncoderField,
6+
TransformerField,
7+
VAEField,
8+
)
9+
from invokeai.invocation_api import (
10+
BaseInvocation,
11+
BaseInvocationOutput,
12+
Classification,
13+
InputField,
14+
InvocationContext,
15+
OutputField,
16+
invocation,
17+
invocation_output,
18+
)
19+
20+
21+
@invocation_output("bria_model_loader_output")
class BriaModelLoaderOutput(BaseInvocationOutput):
    """Bria base model loader output"""

    # Submodel handles for the denoiser, text encoder stack, and VAE.
    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
    t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
28+
29+
30+
@invocation(
    "bria_model_loader",
    title="Main Model - Bria",
    tags=["model", "bria"],
    version="1.0.0",
    classification=Classification.Prototype,
)
class BriaModelLoaderInvocation(BaseInvocation):
    """Loads a bria base model, outputting its submodels."""

    model: ModelIdentifierField = InputField(
        description="Bria model (Transformer) to load",
        ui_type=UIType.BriaMainModel,
        input=Input.Direct,
    )

    def invoke(self, context: InvocationContext) -> BriaModelLoaderOutput:
        # FIX: was a one-element `for key in [self.model.key]` loop; a direct
        # existence check says the same thing without the fake iteration.
        key = self.model.key
        if not context.models.exists(key):
            raise ValueError(f"Unknown model: {key}")

        # Derive each submodel identifier from the base model identifier.
        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
        text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
        tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
        vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})

        return BriaModelLoaderOutput(
            transformer=TransformerField(transformer=transformer, loras=[]),
            t5_encoder=T5EncoderField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[]),
            vae=VAEField(vae=vae),
        )

0 commit comments

Comments
 (0)