|
| 1 | +import math |
| 2 | +import os |
| 3 | +from typing import List, Optional, Tuple, Union |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +import torch |
| 7 | +import torch.distributed as dist |
| 8 | +from diffusers.utils import logging |
| 9 | +from transformers import ( |
| 10 | + CLIPTextModel, |
| 11 | + CLIPTextModelWithProjection, |
| 12 | + CLIPTokenizer, |
| 13 | + T5EncoderModel, |
| 14 | + T5TokenizerFast, |
| 15 | +) |
| 16 | + |
| 17 | +logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
| 18 | + |
| 19 | + |
| 20 | +def get_t5_prompt_embeds( |
| 21 | + tokenizer: T5TokenizerFast, |
| 22 | + text_encoder: T5EncoderModel, |
| 23 | + prompt: Union[str, List[str], None] = None, |
| 24 | + num_images_per_prompt: int = 1, |
| 25 | + max_sequence_length: int = 128, |
| 26 | + device: Optional[torch.device] = None, |
| 27 | +): |
| 28 | + device = device or text_encoder.device |
| 29 | + |
| 30 | + if prompt is None: |
| 31 | + prompt = "" |
| 32 | + |
| 33 | + prompt = [prompt] if isinstance(prompt, str) else prompt |
| 34 | + batch_size = len(prompt) |
| 35 | + |
| 36 | + text_inputs = tokenizer( |
| 37 | + prompt, |
| 38 | +        # padding="max_length" is omitted here; the embeddings are zero-padded to max_sequence_length below |
| 39 | + max_length=max_sequence_length, |
| 40 | + truncation=True, |
| 41 | + add_special_tokens=True, |
| 42 | + return_tensors="pt", |
| 43 | + ) |
| 44 | + text_input_ids = text_inputs.input_ids |
| 45 | + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids |
| 46 | + |
| 47 | + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): |
| 48 | + removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) |
| 49 | + logger.warning( |
| 50 | + "The following part of your input was truncated because `max_sequence_length` is set to " |
| 51 | +            f"{max_sequence_length} tokens: {removed_text}" |
| 52 | + ) |
| 53 | + |
| 54 | + prompt_embeds = text_encoder(text_input_ids.to(device))[0] |
| 55 | + |
| 56 | +    # Zero-pad the embeddings up to max_sequence_length |
| 57 | + b, seq_len, dim = prompt_embeds.shape |
| 58 | + if seq_len < max_sequence_length: |
| 59 | + padding = torch.zeros( |
| 60 | + (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device |
| 61 | + ) |
| 62 | + prompt_embeds = torch.concat([prompt_embeds, padding], dim=1) |
| 63 | + |
| 64 | + prompt_embeds = prompt_embeds.to(device=device) |
| 65 | + |
| 66 | + _, seq_len, _ = prompt_embeds.shape |
| 67 | + |
| 68 | + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method |
| 69 | + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) |
| 70 | + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) |
| 71 | + |
| 72 | + return prompt_embeds |
| 73 | + |
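A rough usage sketch for `get_t5_prompt_embeds` (the checkpoint id below is only an illustrative small T5; any T5 tokenizer/encoder pair works):

    tokenizer = T5TokenizerFast.from_pretrained("google/flan-t5-base")
    text_encoder = T5EncoderModel.from_pretrained("google/flan-t5-base")
    embeds = get_t5_prompt_embeds(tokenizer, text_encoder, prompt="a photo of a cat", max_sequence_length=128)
    # embeds: [1, 128, hidden_size]; short prompts are zero-padded up to max_sequence_length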
| 74 | + |
| 75 | +# in order to get the same sigmas as in training and sample from them |
| 76 | +def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000): |
| 77 | + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() |
| 78 | + sigmas = timesteps / num_train_timesteps |
| 79 | + |
| 80 | + inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)] |
| 81 | + new_sigmas = sigmas[inds] |
| 82 | + return new_sigmas |
| 83 | + |
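For example, drawing 50 inference sigmas from the 1000-step training grid (a minimal sketch; the matching timesteps are just the sigmas scaled back by `num_train_timesteps`):

    sigmas = get_original_sigmas(num_train_timesteps=1000, num_inference_steps=50)  # descending 1.0 -> 0.001
    timesteps = sigmas * 1000                                                       # the corresponding training timesteps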
| 84 | + |
| 85 | +def is_ng_none(negative_prompt): |
| 86 | + return ( |
| 87 | + negative_prompt is None |
| 88 | + or negative_prompt == "" |
| 89 | + or (isinstance(negative_prompt, list) and negative_prompt[0] is None) |
| 90 | +        or (isinstance(negative_prompt, list) and negative_prompt[0] == "") |
| 91 | + ) |
| 92 | + |
| 93 | + |
| 94 | +class CudaTimerContext: |
| 95 | + def __init__(self, times_arr): |
| 96 | + self.times_arr = times_arr |
| 97 | + |
| 98 | + def __enter__(self): |
| 99 | + self.before_event = torch.cuda.Event(enable_timing=True) |
| 100 | + self.after_event = torch.cuda.Event(enable_timing=True) |
| 101 | + self.before_event.record() |
| 102 | + |
| 103 | +    def __exit__(self, exc_type, exc_value, traceback): |
| 104 | + self.after_event.record() |
| 105 | + torch.cuda.synchronize() |
| 106 | + elapsed_time = self.before_event.elapsed_time(self.after_event) / 1000 |
| 107 | + self.times_arr.append(elapsed_time) |
| 108 | + |
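A short sketch of how the timer context is meant to be used (it relies on CUDA events, so a CUDA device is required):

    times = []
    with CudaTimerContext(times):
        y = torch.randn(4096, 4096, device="cuda") @ torch.randn(4096, 4096, device="cuda")
    print(f"matmul took {times[-1]:.4f} s")  # elapsed_time is converted from ms to seconds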
| 109 | + |
| 110 | +def get_env_prefix(): |
| 111 | + env = os.environ.get("CLOUD_PROVIDER", "AWS").upper() |
| 112 | + if env == "AWS": |
| 113 | + return "SM_CHANNEL" |
| 114 | + elif env == "AZURE": |
| 115 | + return "AZUREML_DATAREFERENCE" |
| 116 | + |
| 117 | +    raise ValueError(f"Env {env} not supported") |
| 118 | + |
| 119 | + |
| 120 | +def compute_density_for_timestep_sampling( |
| 121 | +    weighting_scheme: str, batch_size: int, logit_mean: Optional[float] = None, logit_std: Optional[float] = None, mode_scale: Optional[float] = None |
| 122 | +): |
| 123 | + """Compute the density for sampling the timesteps when doing SD3 training. |
| 124 | + |
| 125 | + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. |
| 126 | + |
| 127 | + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. |
| 128 | + """ |
| 129 | + if weighting_scheme == "logit_normal": |
| 130 | + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). |
| 131 | + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") |
| 132 | + u = torch.nn.functional.sigmoid(u) |
| 133 | + elif weighting_scheme == "mode": |
| 134 | + u = torch.rand(size=(batch_size,), device="cpu") |
| 135 | + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) |
| 136 | + else: |
| 137 | + u = torch.rand(size=(batch_size,), device="cpu") |
| 138 | + return u |
| 139 | + |
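A minimal sketch of mapping the sampled density onto discrete training sigmas and timesteps; the index arithmetic is illustrative (it mirrors the diffusers SD3 training examples) rather than something defined in this file:

    u = compute_density_for_timestep_sampling(
        weighting_scheme="logit_normal", batch_size=8, logit_mean=0.0, logit_std=1.0
    )
    indices = (u * 1000).long().clamp(max=999)                    # u in (0, 1) -> discrete training step indices
    sigmas = torch.from_numpy(get_original_sigmas(1000, 1000))[indices]
    timesteps = sigmas * 1000                                     # per-sample timesteps for the denoiser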
| 140 | + |
| 141 | +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): |
| 142 | + """Computes loss weighting scheme for SD3 training. |
| 143 | + |
| 144 | + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. |
| 145 | + |
| 146 | + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. |
| 147 | + """ |
| 148 | + if weighting_scheme == "sigma_sqrt": |
| 149 | + weighting = (sigmas**-2.0).float() |
| 150 | + elif weighting_scheme == "cosmap": |
| 151 | + bot = 1 - 2 * sigmas + 2 * sigmas**2 |
| 152 | + weighting = 2 / (math.pi * bot) |
| 153 | + else: |
| 154 | + weighting = torch.ones_like(sigmas) |
| 155 | + return weighting |
| 156 | + |
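A sketch of how the weighting is applied to a per-sample flow-matching loss (`model_pred` and `target` are stand-in tensors):

    sigmas = torch.rand(8).clamp(min=1e-3)                       # illustrative per-sample sigmas
    w = compute_loss_weighting_for_sd3("cosmap", sigmas=sigmas)  # shape [8]
    model_pred, target = torch.randn(8, 16), torch.randn(8, 16)  # stand-ins for the real tensors
    loss = (w[:, None] * (model_pred - target) ** 2).mean()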
| 157 | + |
| 158 | +def initialize_distributed(): |
| 159 | + # Initialize the process group for distributed training |
| 160 | + dist.init_process_group("nccl") |
| 161 | + |
| 162 | + # Get the current process's rank (ID) and the total number of processes (world size) |
| 163 | + rank = dist.get_rank() |
| 164 | + world_size = dist.get_world_size() |
| 165 | + |
| 166 | + print(f"Initialized distributed training: Rank {rank}/{world_size}") |
| 167 | + |
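The helper expects the standard env vars (RANK, WORLD_SIZE, MASTER_ADDR/PORT) set by a launcher such as `torchrun`; a typical call site might look like this (the device-pinning line is a common follow-up, not something the helper does):

    # typically launched via: torchrun --nproc_per_node=<num_gpus> train.py
    initialize_distributed()
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # pin each process to its local GPU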
| 168 | + |
| 169 | +def get_clip_prompt_embeds( |
| 170 | + text_encoder: CLIPTextModel, |
| 171 | + text_encoder_2: CLIPTextModelWithProjection, |
| 172 | + tokenizer: CLIPTokenizer, |
| 173 | + tokenizer_2: CLIPTokenizer, |
| 174 | +    prompt: Union[str, List[str], None] = None, |
| 175 | + num_images_per_prompt: int = 1, |
| 176 | + max_sequence_length: int = 77, |
| 177 | + device: Optional[torch.device] = None, |
| 178 | +): |
| 179 | + device = device or text_encoder.device |
| 180 | + assert max_sequence_length == tokenizer.model_max_length |
| 181 | + prompt = [prompt] if isinstance(prompt, str) else prompt |
| 182 | + |
| 183 | + # Define tokenizers and text encoders |
| 184 | + tokenizers = [tokenizer, tokenizer_2] |
| 185 | + text_encoders = [text_encoder, text_encoder_2] |
| 186 | + |
| 187 | +    # encode the prompt with both CLIP text encoders |
| 188 | + prompt_embeds_list = [] |
| 189 | + prompts = [prompt, prompt] |
| 190 | + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders, strict=False): |
| 191 | + text_inputs = tokenizer( |
| 192 | + prompt, |
| 193 | + padding="max_length", |
| 194 | + max_length=tokenizer.model_max_length, |
| 195 | + truncation=True, |
| 196 | + return_tensors="pt", |
| 197 | + ) |
| 198 | + |
| 199 | + text_input_ids = text_inputs.input_ids |
| 200 | + prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), output_hidden_states=True) |
| 201 | + |
| 202 | +        # only the pooled output of the final (second) text encoder is kept; it is overwritten on each pass |
| 203 | + pooled_prompt_embeds = prompt_embeds[0] |
| 204 | + prompt_embeds = prompt_embeds.hidden_states[-2] |
| 205 | + |
| 206 | + prompt_embeds_list.append(prompt_embeds) |
| 207 | + |
| 208 | + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) |
| 209 | + |
| 210 | + bs_embed, seq_len, _ = prompt_embeds.shape |
| 211 | + # duplicate text embeddings for each generation per prompt, using mps friendly method |
| 212 | + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) |
| 213 | + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) |
| 214 | + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( |
| 215 | + bs_embed * num_images_per_prompt, -1 |
| 216 | + ) |
| 217 | + |
| 218 | + return prompt_embeds, pooled_prompt_embeds |
| 219 | + |
| 220 | + |
| 221 | +def get_1d_rotary_pos_embed( |
| 222 | + dim: int, |
| 223 | + pos: Union[np.ndarray, int], |
| 224 | + theta: float = 10000.0, |
| 225 | + use_real=False, |
| 226 | + linear_factor=1.0, |
| 227 | + ntk_factor=1.0, |
| 228 | + repeat_interleave_real=True, |
| 229 | + freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux) |
| 230 | +): |
| 231 | + """ |
| 232 | + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. |
| 233 | + |
| 234 | +    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the |
| 235 | +    position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values |
| 236 | +    in complex64 data type. |
| 237 | + |
| 238 | + Args: |
| 239 | + dim (`int`): Dimension of the frequency tensor. |
| 240 | + pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar |
| 241 | + theta (`float`, *optional*, defaults to 10000.0): |
| 242 | + Scaling factor for frequency computation. Defaults to 10000.0. |
| 243 | + use_real (`bool`, *optional*): |
| 244 | + If True, return real part and imaginary part separately. Otherwise, return complex numbers. |
| 245 | + linear_factor (`float`, *optional*, defaults to 1.0): |
| 246 | + Scaling factor for the context extrapolation. Defaults to 1.0. |
| 247 | + ntk_factor (`float`, *optional*, defaults to 1.0): |
| 248 | + Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. |
| 249 | + repeat_interleave_real (`bool`, *optional*, defaults to `True`): |
| 250 | + If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. |
| 251 | +            Otherwise, they are concatenated with themselves. |
| 252 | + freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): |
| 253 | + the dtype of the frequency tensor. |
| 254 | + Returns: |
| 255 | + `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] |
| 256 | + """ |
| 257 | + assert dim % 2 == 0 |
| 258 | + |
| 259 | + if isinstance(pos, int): |
| 260 | + pos = torch.arange(pos) |
| 261 | + if isinstance(pos, np.ndarray): |
| 262 | + pos = torch.from_numpy(pos) # type: ignore # [S] |
| 263 | + |
| 264 | + theta = theta * ntk_factor |
| 265 | + freqs = ( |
| 266 | + 1.0 |
| 267 | + / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) |
| 268 | + / linear_factor |
| 269 | + ) # [D/2] |
| 270 | + freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] |
| 271 | + if use_real and repeat_interleave_real: |
| 272 | + # flux, hunyuan-dit, cogvideox |
| 273 | + freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] |
| 274 | + freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] |
| 275 | + return freqs_cos, freqs_sin |
| 276 | + elif use_real: |
| 277 | + # stable audio, allegro |
| 278 | + freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] |
| 279 | + freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] |
| 280 | + return freqs_cos, freqs_sin |
| 281 | + else: |
| 282 | + # lumina |
| 283 | + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] |
| 284 | + return freqs_cis |
| 285 | + |
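For instance, a small self-contained call (shapes noted in the comments):

    pos = np.arange(16, dtype=np.float32)                               # 16 positions
    cos, sin = get_1d_rotary_pos_embed(dim=64, pos=pos, use_real=True)  # cos, sin: [16, 64]
    freqs_cis = get_1d_rotary_pos_embed(dim=64, pos=pos)                # complex64: [16, 32]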
| 286 | + |
| 287 | +class FluxPosEmbed(torch.nn.Module): |
| 288 | + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 |
| 289 | + def __init__(self, theta: int, axes_dim: List[int]): |
| 290 | + super().__init__() |
| 291 | + self.theta = theta |
| 292 | + self.axes_dim = axes_dim |
| 293 | + |
| 294 | +    def forward(self, ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
| 295 | + n_axes = ids.shape[-1] |
| 296 | + cos_out = [] |
| 297 | + sin_out = [] |
| 298 | + pos = ids.float() |
| 299 | + is_mps = ids.device.type == "mps" |
| 300 | + freqs_dtype = torch.float32 if is_mps else torch.float64 |
| 301 | + for i in range(n_axes): |
| 302 | + cos, sin = get_1d_rotary_pos_embed( |
| 303 | + self.axes_dim[i], |
| 304 | + pos[:, i], |
| 305 | + theta=self.theta, |
| 306 | + repeat_interleave_real=True, |
| 307 | + use_real=True, |
| 308 | + freqs_dtype=freqs_dtype, |
| 309 | + ) |
| 310 | + cos_out.append(cos) |
| 311 | + sin_out.append(sin) |
| 312 | + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) |
| 313 | + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) |
| 314 | + return freqs_cos, freqs_sin |
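A small sketch of the module in use; the axis split (16/56/56 over three id columns) mirrors the Flux-style layout, but the concrete numbers here are only illustrative:

    pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])
    ids = torch.zeros(64, 3)               # e.g. 64 latent tokens, one (text/batch, h, w) id triple each
    freqs_cos, freqs_sin = pos_embed(ids)  # each of shape [64, 16 + 56 + 56] = [64, 128]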