
Commit f542313

Author: Ubuntu (committed)
added support for loading bria transformer
1 parent 9c9265c commit f542313

File tree

5 files changed: +1096 -0 lines changed


invokeai/backend/bria/__init__.py

Whitespace-only changes.

invokeai/backend/bria/bria_utils.py

Lines changed: 314 additions & 0 deletions
@@ -0,0 +1,314 @@
import math
import os
from typing import List, Optional, Union

import numpy as np
import torch
import torch.distributed as dist
from diffusers.utils import logging
from transformers import (
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    T5EncoderModel,
    T5TokenizerFast,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def get_t5_prompt_embeds(
    tokenizer: T5TokenizerFast,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str], None] = None,
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 128,
    device: Optional[torch.device] = None,
):
    device = device or text_encoder.device

    if prompt is None:
        prompt = ""

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        # padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f" {max_sequence_length} tokens: {removed_text}"
        )

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]

    # Concat zeros to max_sequence
    b, seq_len, dim = prompt_embeds.shape
    if seq_len < max_sequence_length:
        padding = torch.zeros(
            (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
        )
        prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)

    prompt_embeds = prompt_embeds.to(device=device)

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds


# in order to get the same sigmas as in training and sample from them
def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
    timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
    sigmas = timesteps / num_train_timesteps

    inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)]
    new_sigmas = sigmas[inds]
    return new_sigmas


def is_ng_none(negative_prompt):
    return (
        negative_prompt is None
        or negative_prompt == ""
        or (isinstance(negative_prompt, list) and negative_prompt[0] is None)
        or (isinstance(negative_prompt, list) and negative_prompt[0] == "")
    )


class CudaTimerContext:
    def __init__(self, times_arr):
        self.times_arr = times_arr

    def __enter__(self):
        self.before_event = torch.cuda.Event(enable_timing=True)
        self.after_event = torch.cuda.Event(enable_timing=True)
        self.before_event.record()

    def __exit__(self, type, value, traceback):
        self.after_event.record()
        torch.cuda.synchronize()
        elapsed_time = self.before_event.elapsed_time(self.after_event) / 1000
        self.times_arr.append(elapsed_time)


def get_env_prefix():
    env = os.environ.get("CLOUD_PROVIDER", "AWS").upper()
    if env == "AWS":
        return "SM_CHANNEL"
    elif env == "AZURE":
        return "AZUREML_DATAREFERENCE"

    raise Exception(f"Env {env} not supported")


def compute_density_for_timestep_sampling(
    weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
    """Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        u = torch.rand(size=(batch_size,), device="cpu")
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size,), device="cpu")
    return u


def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    """Computes loss weighting scheme for SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "sigma_sqrt":
        weighting = (sigmas**-2.0).float()
    elif weighting_scheme == "cosmap":
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        weighting = 2 / (math.pi * bot)
    else:
        weighting = torch.ones_like(sigmas)
    return weighting


def initialize_distributed():
    # Initialize the process group for distributed training
    dist.init_process_group("nccl")

    # Get the current process's rank (ID) and the total number of processes (world size)
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    print(f"Initialized distributed training: Rank {rank}/{world_size}")


def get_clip_prompt_embeds(
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    prompt: Union[str, List[str]] = None,
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 77,
    device: Optional[torch.device] = None,
):
    device = device or text_encoder.device
    assert max_sequence_length == tokenizer.model_max_length
    prompt = [prompt] if isinstance(prompt, str) else prompt

    # Define tokenizers and text encoders
    tokenizers = [tokenizer, tokenizer_2]
    text_encoders = [text_encoder, text_encoder_2]

    # textual inversion: process multi-vector tokens if necessary
    prompt_embeds_list = []
    prompts = [prompt, prompt]
    for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders, strict=False):
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True)

        # We are only ALWAYS interested in the pooled output of the final text encoder
        pooled_prompt_embeds = prompt_embeds[0]
        prompt_embeds = prompt_embeds.hidden_states[-2]

        prompt_embeds_list.append(prompt_embeds)

    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )

    return prompt_embeds, pooled_prompt_embeds


def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
    index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
    data type.

    Args:
        dim (`int`): Dimension of the frequency tensor.
        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (`bool`, *optional*):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
        linear_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the context extrapolation. Defaults to 1.0.
        ntk_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
            If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
            Otherwise, they are concatenated with themselves.
        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
            the dtype of the frequency tensor.
    Returns:
        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # type: ignore # [S]

    theta = theta * ntk_factor
    freqs = (
        1.0
        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
        / linear_factor
    )  # [D/2]
    freqs = torch.outer(pos, freqs)  # type: ignore # [S, D/2]
    if use_real and repeat_interleave_real:
        # flux, hunyuan-dit, cogvideox
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
        return freqs_cos, freqs_sin
    elif use_real:
        # stable audio, allegro
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # lumina
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64 # [S, D/2]
        return freqs_cis


class FluxPosEmbed(torch.nn.Module):
    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        pos = ids.float()
        is_mps = ids.device.type == "mps"
        freqs_dtype = torch.float32 if is_mps else torch.float64
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin
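
For a quick sanity check of the T5 helper added above, a minimal sketch like the following should work; the small checkpoint ("google/t5-v1_1-small") and CPU device are illustrative assumptions, not the encoder the Bria pipeline actually loads:

import torch
from transformers import T5EncoderModel, T5TokenizerFast

from invokeai.backend.bria.bria_utils import get_t5_prompt_embeds

# Any T5 checkpoint works for a smoke test; the real pipeline supplies its own encoder.
tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-small")
text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-small")

with torch.no_grad():
    embeds = get_t5_prompt_embeds(
        tokenizer,
        text_encoder,
        prompt="a photo of a cat",
        num_images_per_prompt=2,
        max_sequence_length=128,
        device=torch.device("cpu"),
    )

# Shorter prompts are zero-padded to max_sequence_length, so the shape is
# (batch_size * num_images_per_prompt, 128, hidden_dim).
print(embeds.shape)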

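The scheduler helpers compose in the obvious way; the 30-step count and the "cosmap" weighting scheme below are arbitrary choices for illustration:

import torch

from invokeai.backend.bria.bria_utils import compute_loss_weighting_for_sd3, get_original_sigmas

# 1000 training timesteps subsampled to 30 inference steps; sigmas run from 1.0 down toward 1/1000.
sigmas = torch.from_numpy(get_original_sigmas(num_train_timesteps=1000, num_inference_steps=30))
print(sigmas[:3])  # approximately tensor([1.0000, 0.9660, 0.9320])

# Per-sigma loss weights under the "cosmap" scheme from the SD3 paper.
weights = compute_loss_weighting_for_sd3("cosmap", sigmas=sigmas)
print(weights.shape)  # torch.Size([30])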
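And a sketch of the rotary-embedding path; the axes_dim split (16/56/56) and the 4x4 latent grid are assumptions for illustration, not values read from the Bria transformer config:

import torch

from invokeai.backend.bria.bria_utils import FluxPosEmbed

# Flux-style packed position ids: one row per latent token, columns = (batch/text axis, y, x).
height, width = 4, 4
ids = torch.zeros(height * width, 3)
ids[:, 1] = torch.arange(height).repeat_interleave(width)
ids[:, 2] = torch.arange(width).repeat(height)

# axes_dim entries must each be even; their sum is the per-head rotary dimension.
pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])
freqs_cos, freqs_sin = pos_embed(ids)
print(freqs_cos.shape, freqs_sin.shape)  # torch.Size([16, 128]) for both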