Changes from all commits
44 commits
f34b865
Patchification research to condition the model
ArEnSc Jan 14, 2025
b341b79
Adding Conditioning Blocks to the Diffusion Transformer.
ArEnSc Jan 15, 2025
315431b
Recommit the conditioning code.
ArEnSc Jan 16, 2025
cd66e6e
fixed the training pipeline to address new changes from upstream
ArEnSc Jan 17, 2025
47356cd
Testing the original code path.
ArEnSc Jan 17, 2025
b73d61a
saving more changes related to args.
ArEnSc Jan 17, 2025
59afd55
Retested changes without conditioning it works.
ArEnSc Jan 17, 2025
44df9e3
reformulate the conditioning and patchification
ArEnSc Jan 18, 2025
2720d24
Moved the norm back to prepare latents step.
ArEnSc Jan 18, 2025
fa422c4
Adding conditioned transformer layer change.
ArEnSc Jan 18, 2025
3c605a1
Adding in the conditioned transformer 3d model
ArEnSc Jan 18, 2025
48d2b61
Need to test saving and loading.
ArEnSc Jan 18, 2025
34996c1
replaced the pipeline functions.
ArEnSc Jan 21, 2025
9fde660
Fixing weights copy.
ArEnSc Jan 22, 2025
07e755c
modifications to the loading of the weights.
ArEnSc Jan 22, 2025
9817671
modify the forward pass
ArEnSc Jan 22, 2025
731f326
Added residual
ArEnSc Jan 22, 2025
d631737
Post patchify the latents.
ArEnSc Jan 22, 2025
d6b3615
Ok adapter trains need to get the pass for validation using a skeleton.
ArEnSc Jan 22, 2025
b7f26bd
removed some code made things clearer to replace the pipeline inferen…
ArEnSc Jan 23, 2025
64898ef
Test the conditioned pipeline and ensure it loads.
ArEnSc Jan 24, 2025
440738e
Add conditioned pipeline ...
ArEnSc Jan 24, 2025
305a4a9
Need to pass in image ref video.
ArEnSc Jan 24, 2025
c6490c0
Reducing cogload on the framework.
ArEnSc Jan 24, 2025
06e26ec
test and trace the changes.
ArEnSc Jan 24, 2025
e99d70f
Fixed some issues with imports and the original lora.py code.
ArEnSc Jan 25, 2025
80f653f
Passed data in for validation.
ArEnSc Jan 25, 2025
dfa03f1
Fix before trying again.
ArEnSc Jan 27, 2025
8ec35d5
trying training again but broke something so going back.
ArEnSc Jan 28, 2025
4ab3b25
Make a fix to the video tensors.
ArEnSc Jan 28, 2025
707e001
The encoder hidden states and encode attention mask sizes do not matc…
ArEnSc Jan 29, 2025
a162851
Noted place that requires intervention.
ArEnSc Jan 29, 2025
536dfc8
Figure out why classifier free guidance breaks the loop. I suspect it…
ArEnSc Jan 29, 2025
bca44f6
Remove some comments.
ArEnSc Jan 29, 2025
e4a14af
Add back in the pose for both channels.
ArEnSc Jan 29, 2025
41eec16
Renamed a few variables to reduce cog load.
ArEnSc Jan 29, 2025
8fe0ed2
Conditioned Adapter make residual optional.
ArEnSc Jan 29, 2025
be1cb10
make residual optional.
ArEnSc Jan 30, 2025
09e3f3d
move residual arg.
ArEnSc Jan 30, 2025
e399684
Make residuals none.
ArEnSc Jan 30, 2025
ea3cd5c
Full Finetune should work with this adapter removed the residual_x
ArEnSc Jan 30, 2025
28165e6
Adapter weights can be loaded.
ArEnSc Feb 4, 2025
3fe4de9
Modified the inference loop to try something. Also added a copy of the…
ArEnSc Feb 4, 2025
d65958b
Checking inputs.
ArEnSc Feb 6, 2025
5 changes: 3 additions & 2 deletions .gitignore
@@ -163,13 +163,14 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

ltx-video/samples
# manually added
wandb/
*.txt
dump*
outputs*
*.slurm
.vscode/

ltx-video/ltxv_strip
!requirements.txt
video-dataset-pose-test
40 changes: 40 additions & 0 deletions finetrainers/args.py
@@ -249,6 +249,8 @@ class Args:
dataset_file: Optional[str] = None
video_column: str = None
caption_column: str = None
pose_column: str = None

id_token: Optional[str] = None
image_resolution_buckets: List[Tuple[int, int]] = None
video_resolution_buckets: List[Tuple[int, int, int]] = None
@@ -314,6 +316,11 @@ class Args:
validation_prompts: List[str] = None
validation_images: List[str] = None
validation_videos: List[str] = None

# Condition
validation_pose_videos: List[str] = None
validation_img_ref_videos: List[str] = None

validation_heights: List[int] = None
validation_widths: List[int] = None
validation_num_frames: List[int] = None
@@ -353,6 +360,7 @@ def to_dict(self) -> Dict[str, Any]:
"dataset_file": self.dataset_file,
"video_column": self.video_column,
"caption_column": self.caption_column,
"pose_column": self.pose_column,
"id_token": self.id_token,
"image_resolution_buckets": self.image_resolution_buckets,
"video_resolution_buckets": self.video_resolution_buckets,
@@ -550,6 +558,12 @@ def parse_video_resolution_bucket(resolution_bucket: str) -> Tuple[int, int, int
default="text",
help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--data_root` folder containing the line-separated instance prompts.",
)
parser.add_argument(
"--pose_column",
type=str,
default=None,
help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--data_root` folder containing the line-separated instance prompts.",
)
parser.add_argument(
"--id_token",
type=str,
@@ -870,6 +884,21 @@ def _add_validation_arguments(parser: argparse.ArgumentParser) -> None:
default=None,
help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
)

parser.add_argument(
"--validation_pose_videos",
type=str,
default=None,
help="One or more pose_videos path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
)

parser.add_argument(
"--validation_img_ref_videos",
type=str,
default=None,
help="One or more img_ref video or img will just take the first frame and create a video from it path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
)

parser.add_argument(
"--validation_videos",
type=str,
@@ -1006,6 +1035,7 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
result_args.dataset_file = args.dataset_file
result_args.video_column = args.video_column
result_args.caption_column = args.caption_column
result_args.pose_column = args.pose_column
result_args.id_token = args.id_token
result_args.image_resolution_buckets = args.image_resolution_buckets or DEFAULT_IMAGE_RESOLUTION_BUCKETS
result_args.video_resolution_buckets = args.video_resolution_buckets or DEFAULT_VIDEO_RESOLUTION_BUCKETS
@@ -1069,6 +1099,12 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
validation_prompts = args.validation_prompts.split(args.validation_separator) if args.validation_prompts else []
validation_images = args.validation_images.split(args.validation_separator) if args.validation_images else None
validation_videos = args.validation_videos.split(args.validation_separator) if args.validation_videos else None

# extend with pose conditioning ...
validation_pose_videos = args.validation_pose_videos.split(args.validation_separator) if args.validation_pose_videos else None
validation_img_ref_videos = args.validation_img_ref_videos.split(args.validation_separator) if args.validation_img_ref_videos else None


stripped_validation_prompts = []
validation_heights = []
validation_widths = []
@@ -1097,6 +1133,10 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
result_args.validation_images = validation_images
result_args.validation_videos = validation_videos

# extend with pose conditioning ...
result_args.validation_pose_videos = validation_pose_videos
result_args.validation_img_ref_videos = validation_img_ref_videos

result_args.num_validation_videos_per_prompt = args.num_validation_videos
result_args.validation_every_n_epochs = args.validation_epochs
result_args.validation_every_n_steps = args.validation_steps
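
The new validation inputs are split with the same separator as the validation prompts and are expected to line up index-for-index with them. A minimal sketch of that expectation (the separator and file paths below are hypothetical, not values from this PR):

separator = ":::"  # stand-in for whatever the validation separator flag is set to
prompts = "a person dancing:::a person waving".split(separator)
pose_videos = "poses/dance.mp4:::poses/wave.mp4".split(separator)
img_ref_videos = "refs/dance.mp4:::refs/wave.mp4".split(separator)

# Each pose / image-reference video pairs with the prompt at the same index.
assert len(prompts) == len(pose_videos) == len(img_ref_videos)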
196 changes: 196 additions & 0 deletions finetrainers/conditioning/LTXVideoConditionedTransformer3DModel.py
@@ -0,0 +1,196 @@
import math
from typing import Any, Dict, Optional, Tuple
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import LTXVideoTransformer3DModel
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import is_torch_version
from finetrainers.conditioning.conditioned_residual_adapter_bottleneck import ConditionedResidualAdapterBottleneck

@maybe_allow_in_graph
class LTXVideoConditionedTransformer3DModel(LTXVideoTransformer3DModel):
def __init__(self,
in_channels: int = 128,
out_channels: int = 128,
patch_size: int = 1,
patch_size_t: int = 1,
num_attention_heads: int = 32,
attention_head_dim: int = 64,
cross_attention_dim: int = 2048,
num_layers: int = 28,
activation_fn: str = "gelu-approximate",
qk_norm: str = "rms_norm_across_heads",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
caption_channels: int = 4096,
attention_bias: bool = True,
attention_out_bias: bool = True,
adapter_in_dim: int = 256):

super().__init__(
in_channels=in_channels,
out_channels=out_channels,
patch_size=patch_size,
patch_size_t=patch_size_t,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
cross_attention_dim=cross_attention_dim,
num_layers=num_layers,
activation_fn=activation_fn,
qk_norm=qk_norm,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
caption_channels=caption_channels,
attention_bias=attention_bias,
attention_out_bias=attention_out_bias
)

# adapter.down_proj.weight
self.adapter = ConditionedResidualAdapterBottleneck(
input_dim=adapter_in_dim,
output_dim=128,
bottleneck_dim=64,
adapter_dropout=0.1,
adapter_init_scale=1e-3
)

@classmethod  # TODO: rewrite so this load path is used only when no adapter weights exist; when adapter weights are present, it should not be used.
def from_pretrained(cls, pretrained_model_name_or_path, save_directory, **kwargs):
# First create an empty model with the desired architecture
print("Pretrain Loading")
model = cls()
model_dict = model.state_dict()
# Then load the pretrained weights into it
pretrained = LTXVideoTransformer3DModel.from_pretrained(pretrained_model_name_or_path, **kwargs)

# Copy over the pretrained weights for the shared components
pretrained_dict = pretrained.state_dict()

# Keep only the pretrained weights whose keys exist in the conditioned model
filtered_dict = {}
for k, v in pretrained_dict.items():
if k in model_dict:

filtered_dict[k] = v
else:
print(k)  # pretrained key with no counterpart in the conditioned model
# Update model with pretrained weights
model_dict.update(filtered_dict)

adapter_weights_path = os.path.join(save_directory, "adapter_weights.pth")
if os.path.exists(adapter_weights_path):
adapter_weights = torch.load(adapter_weights_path)
model.adapter.load_state_dict(adapter_weights)
print("Adapter weights loaded successfully.")
else:
print("No adapter weights file found. Using default adapter weights.")

model.load_state_dict(model_dict)
return model

def save_pretrained(self, save_directory):
model_state_dict = self.state_dict()
torch.save(model_state_dict, save_directory)

adapter_state_dict = self.adapter.state_dict()
save_directory = os.path.dirname(save_directory)
path = os.path.join(save_directory, "adapter_weights.pth")
torch.save(adapter_state_dict, path)
print("Saving Model with Adapter")

def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_attention_mask: torch.Tensor,
num_frames: int,
height: int,
width: int,
rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
residual_x: Optional[torch.Tensor] = None
) -> torch.Tensor:

image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)

# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

batch_size = hidden_states.size(0)

# inject the condition and the residual then project it into the pretrained proj_in
hidden_states = self.adapter(residual_x=residual_x,
conditioned_x=hidden_states)
# what's this value?
hidden_states = self.proj_in(hidden_states)

temb, embedded_timestep = self.time_embed(
timestep.flatten(),
batch_size=batch_size,
hidden_dtype=hidden_states.dtype,
)

temb = temb.view(batch_size, -1, temb.size(-1))
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))

encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))

for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)

return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint( # crashes here
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
encoder_attention_mask,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
encoder_attention_mask=encoder_attention_mask,
)

scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states * (1 + scale) + shift
output = self.proj_out(hidden_states)


if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)


def apply_rotary_emb(x, freqs):
cos, sin = freqs
x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D // 2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
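
ConditionedResidualAdapterBottleneck is imported above but is not part of this diff. For orientation only, here is a minimal sketch of the kind of bottleneck adapter its constructor arguments (input_dim=256, output_dim=128, bottleneck_dim=64, dropout, small init scale) suggest: down-project, nonlinearity, up-project into the transformer's 128-dim token space, optionally add a residual. Every name and detail below is an assumption, not the PR's implementation.

import torch
import torch.nn as nn

class BottleneckAdapterSketch(nn.Module):
    """Hypothetical stand-in for ConditionedResidualAdapterBottleneck; not this PR's code."""

    def __init__(self, input_dim=256, output_dim=128, bottleneck_dim=64,
                 adapter_dropout=0.1, adapter_init_scale=1e-3):
        super().__init__()
        self.down_proj = nn.Linear(input_dim, bottleneck_dim)
        self.up_proj = nn.Linear(bottleneck_dim, output_dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(adapter_dropout)
        # Near-zero init so the conditioning path barely perturbs the pretrained model at first.
        nn.init.normal_(self.up_proj.weight, std=adapter_init_scale)
        nn.init.zeros_(self.up_proj.bias)

    def forward(self, conditioned_x, residual_x=None):
        out = self.up_proj(self.dropout(self.act(self.down_proj(conditioned_x))))
        return out if residual_x is None else out + residual_x

# Shape check on packed-latent-style tokens [B, sequence_length, features]:
tokens = torch.randn(1, 1280, 256)
print(BottleneckAdapterSketch()(tokens).shape)  # torch.Size([1, 1280, 128])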

1 change: 1 addition & 0 deletions finetrainers/conditioning/__init__.py
@@ -0,0 +1 @@
from .condition_latents_prepare import post_conditioned_latent_patchify, prepare_latents_for_conditioning
81 changes: 81 additions & 0 deletions finetrainers/conditioning/condition_latents_prepare.py
@@ -0,0 +1,81 @@
from typing import Any, Dict, Optional
import torch
from diffusers import AutoencoderKLLTXVideo

def _normalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Normalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = (latents - latents_mean) * scaling_factor / latents_std
return latents

def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
# Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
# The patch dimensions are then permuted and collapsed into the channel dimension of shape:
# [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
# dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
batch_size, num_channels, num_frames, height, width = latents.shape
post_patch_num_frames = num_frames // patch_size_t
post_patch_height = height // patch_size
post_patch_width = width // patch_size

# Resulting dimensions, for reference:
# dim1 = num_frames // patch_size_t * height // patch_size * width // patch_size
# dim2 = num_channels * patch_size_t * patch_size * patch_size

latents = latents.reshape(
batch_size,
-1,
post_patch_num_frames,
patch_size_t,
post_patch_height,
patch_size,
post_patch_width,
patch_size,
)
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
return latents

def post_conditioned_latent_patchify(
latents: torch.Tensor,
num_frames: int,
height: int,
width: int,
patch_size: int = 1,
patch_size_t: int = 1,
**kwargs,
) -> Dict[str, Any]:
latents = _pack_latents(latents, patch_size, patch_size_t)
return {"latents": latents, "num_frames": num_frames, "height": height, "width": width}

def prepare_latents_for_conditioning(
vae: AutoencoderKLLTXVideo,
image_or_video: torch.Tensor,
patch_size: int = 1,
patch_size_t: int = 1,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
generator: Optional[torch.Generator] = None,
) -> Dict[str, Any]:
device = device or vae.device

if image_or_video.ndim == 4:
image_or_video = image_or_video.unsqueeze(2)
assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor"

image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W]

latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)

latents = latents.to(dtype=dtype)
_, _, num_frames, height, width = latents.shape
latents = _normalize_latents(latents, vae.latents_mean, vae.latents_std)
# latents = _pack_latents(latents, patch_size, patch_size_t)
return {"latents": latents,
"num_frames": num_frames,
"height": height,
"width": width }