Commit 6760300

imbr92 and linoytsaban authored
Add --lora_alpha and metadata handling to train_dreambooth_lora_sana.py (#11744)
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
1 parent 798265f commit 6760300
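
The change itself is small: the script previously pinned `lora_alpha` to `--rank`, and this commit exposes it as its own `--lora_alpha` flag (default 4) that is forwarded to the PEFT `LoraConfig`, while the save paths now serialize the adapter config into the checkpoint metadata. A minimal sketch of why the split matters, assuming PEFT's usual `lora_alpha / r` scaling of the LoRA update (the values and `target_modules` below are illustrative, not taken from this commit):

```python
# Illustrative only: mirrors how the training script builds its LoraConfig,
# with placeholder target_modules (the script derives its own list).
from peft import LoraConfig

transformer_lora_config = LoraConfig(
    r=4,            # --rank: dimension of the LoRA update matrices
    lora_alpha=8,   # --lora_alpha: previously forced to equal the rank
    lora_dropout=0.0,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v"],
)
# With standard PEFT scaling the adapter output is multiplied by lora_alpha / r,
# so alpha acts as a strength knob that can now be tuned independently of rank.
```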

File tree

2 files changed: +56, −4 lines

examples/dreambooth/test_dreambooth_lora_sana.py

Lines changed: 42 additions & 0 deletions
```diff
@@ -13,13 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import logging
 import os
 import sys
 import tempfile
 
 import safetensors
 
+from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+
 
 sys.path.append("..")
 from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
@@ -204,3 +207,42 @@ def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit_removes_mult
         run_command(self._launch_args + resume_run_args)
 
         self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+
+    def test_dreambooth_lora_sana_with_metadata(self):
+        lora_alpha = 8
+        rank = 4
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+                --instance_data_dir={self.instance_data_dir}
+                --output_dir={tmpdir}
+                --resolution=32
+                --train_batch_size=1
+                --gradient_accumulation_steps=1
+                --max_train_steps=4
+                --lora_alpha={lora_alpha}
+                --rank={rank}
+                --checkpointing_steps=2
+                --max_sequence_length 166
+                """.split()
+
+            test_args.extend(["--instance_prompt", ""])
+            run_command(self._launch_args + test_args)
+
+            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            self.assertTrue(os.path.isfile(state_dict_file))
+
+            # Check if the metadata was properly serialized.
+            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
+                metadata = f.metadata() or {}
+
+            metadata.pop("format", None)
+            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+            if raw:
+                raw = json.loads(raw)
+
+            loaded_lora_alpha = raw["transformer.lora_alpha"]
+            self.assertTrue(loaded_lora_alpha == lora_alpha)
+            loaded_lora_rank = raw["transformer.r"]
+            self.assertTrue(loaded_lora_rank == rank)
```
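
Outside the test harness, the same check works on any checkpoint the script saves. A minimal sketch, where the output directory name is a placeholder and the `transformer.`-prefixed keys follow what the test above asserts:

```python
import json

import safetensors.torch

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY

# Placeholder path: the --output_dir you trained into.
weights = "lora-output/pytorch_lora_weights.safetensors"

with safetensors.torch.safe_open(weights, framework="pt", device="cpu") as f:
    metadata = f.metadata() or {}

adapter_config = json.loads(metadata.get(LORA_ADAPTER_METADATA_KEY, "{}"))
# The transformer's PEFT config is stored flattened with a "transformer." prefix.
print(adapter_config.get("transformer.r"), adapter_config.get("transformer.lora_alpha"))
```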

examples/dreambooth/train_dreambooth_lora_sana.py

Lines changed: 14 additions & 4 deletions
```diff
@@ -52,6 +52,7 @@
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
@@ -323,9 +324,13 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
-
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
-
     parser.add_argument(
         "--with_prior_preservation",
         default=False,
@@ -1023,7 +1028,7 @@ def main(args):
     # now we will add new LoRA weights the transformer layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
@@ -1039,10 +1044,11 @@ def unwrap_model(model):
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
-
+            modules_to_save = {}
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
@@ -1052,6 +1058,7 @@ def save_model_hook(models, weights, output_dir):
             SanaPipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )
 
     def load_model_hook(models, input_dir):
@@ -1507,15 +1514,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
+        modules_to_save = {}
         if args.upcast_before_saving:
             transformer.to(torch.float32)
         else:
             transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
+        modules_to_save["transformer"] = transformer
 
         SanaPipeline.save_lora_weights(
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )
 
     # Final inference
```
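
For completeness, a sketch of consuming the result, assuming a placeholder Sana checkpoint id and LoRA directory, and assuming `load_lora_weights` restores rank and alpha from the serialized metadata rather than re-deriving alpha from the rank:

```python
import torch

from diffusers import SanaPipeline

# Placeholder model id and LoRA path; substitute your own.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("lora-output")  # reads pytorch_lora_weights.safetensors and its metadata

image = pipe("a photo of sks dog", num_inference_steps=20).images[0]
image.save("sana_lora_sample.png")
```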
