From 9b209b09ca5198a6d2d28cd559620b8783d4f453 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 13 Jun 2025 15:30:34 +0530 Subject: [PATCH 1/3] show how metadata stuff should be incorporated in training scripts. --- .../dreambooth/test_dreambooth_lora_flux.py | 45 +++++++++++++++++++ .../dreambooth/train_dreambooth_lora_flux.py | 22 ++++++--- src/diffusers/training_utils.py | 8 ++++ 3 files changed, 70 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/test_dreambooth_lora_flux.py b/examples/dreambooth/test_dreambooth_lora_flux.py index a76825e29448..837a537b5a4e 100644 --- a/examples/dreambooth/test_dreambooth_lora_flux.py +++ b/examples/dreambooth/test_dreambooth_lora_flux.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging import os import sys @@ -20,6 +21,8 @@ import safetensors +from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY + sys.path.append("..") from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 @@ -234,3 +237,45 @@ def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_mult run_command(self._launch_args + resume_run_args) self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) + + def test_dreambooth_lora_with_metadata(self): + # Use a `lora_alpha` that is different from `rank`. + lora_alpha = 8 + rank = 4 + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --lora_alpha={lora_alpha} + --rank={rank} + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") + self.assertTrue(os.path.isfile(state_dict_file)) + + # Check if the metadata was properly serialized. 
+            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
+                metadata = f.metadata() or {}
+
+            metadata.pop("format", None)
+            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+            self.assertTrue(raw is not None)
+            raw = json.loads(raw)
+
+            loaded_lora_alpha = raw["transformer.lora_alpha"]
+            self.assertTrue(loaded_lora_alpha == lora_alpha)
+            loaded_lora_rank = raw["transformer.r"]
+            self.assertTrue(loaded_lora_rank == rank)
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 1caf9c62d79b..f27af490b1f1 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -27,7 +27,6 @@
 
 import numpy as np
 import torch
-import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -53,6 +52,7 @@
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     _set_state_dict_into_text_encoder,
     cast_training_params,
     compute_density_for_timestep_sampling,
@@ -358,7 +358,12 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
-
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
 
     parser.add_argument(
@@ -1238,7 +1243,7 @@ def main(args):
     # now we will add new LoRA weights the transformer layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
@@ -1247,7 +1252,7 @@ def main(args):
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
             r=args.rank,
-            lora_alpha=args.rank,
+            lora_alpha=args.lora_alpha,
             lora_dropout=args.lora_dropout,
             init_lora_weights="gaussian",
             target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
@@ -1264,12 +1269,14 @@ def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
             text_encoder_one_lora_layers_to_save = None
-
+            modules_to_save = []
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 elif isinstance(model, type(unwrap_model(text_encoder_one))):
                     text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["text_encoder"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
@@ -1280,6 +1287,7 @@ def save_model_hook(models, weights, output_dir):
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
                 text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )
 
     def load_model_hook(models, input_dir):
@@ -1889,16 +1897,19 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        modules_to_save = {}
         transformer = unwrap_model(transformer)
         if args.upcast_before_saving:
             transformer.to(torch.float32)
         else:
             transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
+        modules_to_save["transformer"] = transformer
 
         if args.train_text_encoder:
             text_encoder_one = unwrap_model(text_encoder_one)
             text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+            modules_to_save["text_encoder"] = text_encoder_one
         else:
             text_encoder_lora_layers = None
 
@@ -1906,6 +1917,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
             text_encoder_lora_layers=text_encoder_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )
 
     # Final inference
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 43bf0010d7b3..211fbf7448f7 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -247,6 +247,14 @@ def _set_state_dict_into_text_encoder(
     set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
 
 
+def _collate_lora_metadata(modules_to_save: Dict[str, torch.nn.Module]) -> dict[str, Any]:
+    metadatas = {}
+    for module_name, module in modules_to_save.items():
+        if module is not None:
+            metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
+    return metadatas
+
+
 def compute_density_for_timestep_sampling(
     weighting_scheme: str,
     batch_size: int,

From b08733dda27a1a1fd47c7258260c05a1d117a654 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 13 Jun 2025 15:33:19 +0530
Subject: [PATCH 2/3] typing

---
 src/diffusers/training_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 211fbf7448f7..bc30411d8726 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -247,7 +247,7 @@ def _set_state_dict_into_text_encoder(
     set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
 
 
-def _collate_lora_metadata(modules_to_save: Dict[str, torch.nn.Module]) -> dict[str, Any]:
+def _collate_lora_metadata(modules_to_save: Dict[str, torch.nn.Module]) -> Dict[str, Any]:
     metadatas = {}
     for module_name, module in modules_to_save.items():
         if module is not None:

From e8559bed73d8015fbb0ca5a5377ad5ad3fe5652d Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 13 Jun 2025 15:51:56 +0530
Subject: [PATCH 3/3] fix

---
 examples/dreambooth/train_dreambooth_lora_flux.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index f27af490b1f1..9c529cbb92ca 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -1269,7 +1269,7 @@ def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
             text_encoder_one_lora_layers_to_save = None
-            modules_to_save = []
+            modules_to_save = {}
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
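
A note on consuming the serialized metadata: the sketch below shows how the
adapter config written by these patches can be read back from a saved
checkpoint. It is a minimal illustration, not part of the patches; the
checkpoint path is hypothetical, and the "transformer.*" key names follow
the assertions in the test added by PATCH 1/3.

    import json

    from safetensors import safe_open

    from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY

    # Hypothetical path: a checkpoint written by train_dreambooth_lora_flux.py
    # after these patches, with --rank 4 and --lora_alpha 8.
    state_dict_file = "output/pytorch_lora_weights.safetensors"

    # The LoRA weights and their metadata live in the same safetensors file;
    # safe_open exposes the metadata without loading any tensors.
    with safe_open(state_dict_file, framework="pt", device="cpu") as f:
        metadata = f.metadata() or {}

    # Each module's LoraConfig fields end up prefixed with the module name
    # ("transformer", "text_encoder") and stored as one JSON string under
    # LORA_ADAPTER_METADATA_KEY, as exercised by the test above.
    raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
    adapter_config = json.loads(raw) if raw else {}

    print(adapter_config.get("transformer.r"))           # e.g. 4
    print(adapter_config.get("transformer.lora_alpha"))  # e.g. 8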