 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
@@ -323,9 +324,13 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
-
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
-
     parser.add_argument(
         "--with_prior_preservation",
         default=False,
@@ -1023,7 +1028,7 @@ def main(args):
     # now we will add new LoRA weights to the transformer layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
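
Note (not part of the diff): with PEFT's default (non-rsLoRA) behavior, the LoRA update is scaled by lora_alpha / r, so exposing --lora_alpha decouples the update magnitude from the rank instead of always pinning it to 1.0. A minimal sketch, assuming peft's LoraConfig; the target_modules values are illustrative only:

# Minimal sketch, assuming peft is installed; target module names are illustrative.
from peft import LoraConfig

config = LoraConfig(
    r=4,                          # --rank
    lora_alpha=8,                 # --lora_alpha (previously hard-coded to the rank)
    lora_dropout=0.0,             # --lora_dropout
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # illustrative targets
)
print(config.lora_alpha / config.r)  # 2.0 -> effective scaling applied to the LoRA update
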
@@ -1039,10 +1044,11 @@ def unwrap_model(model):
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
-
+            modules_to_save = {}
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")

@@ -1052,6 +1058,7 @@ def save_model_hook(models, weights, output_dir):
             SanaPipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )

     def load_model_hook(models, input_dir):
@@ -1507,15 +1514,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         transformer = unwrap_model(transformer)
+        modules_to_save = {}
         if args.upcast_before_saving:
             transformer.to(torch.float32)
         else:
             transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
+        modules_to_save["transformer"] = transformer

         SanaPipeline.save_lora_weights(
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )

         # Final inference
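
Note (not part of the diff): a hedged usage sketch for loading the adapter written by SanaPipeline.save_lora_weights back into the pipeline; the base model id, output path, and prompt below are placeholders, not taken from the change.

# Hedged sketch: model id, LoRA directory, and prompt are placeholders.
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",  # placeholder base model
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.load_lora_weights("output_dir")  # directory written by SanaPipeline.save_lora_weights
image = pipe(prompt="a photo of sks dog", num_inference_steps=20).images[0]
image.save("sample.png")
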