From 3f478067194096eb9b1c5bc48042914a62bcace5 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Tue, 19 Aug 2025 02:45:26 -0400
Subject: [PATCH] Add MGPO to Flux network

---
 networks/lora_flux.py                 | 138 +++++++++++++++++++++++-
 tests/networks/test_lora_flux_mgpo.py | 119 ++++++++++++++++++++
 train_network.py                      |   8 +-
 3 files changed, 262 insertions(+), 3 deletions(-)
 create mode 100644 tests/networks/test_lora_flux_mgpo.py

diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index e9ad5f68d..f3ef301e6 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -48,6 +48,8 @@ def __init__(
         split_dims: Optional[List[int]] = None,
         ggpo_beta: Optional[float] = None,
         ggpo_sigma: Optional[float] = None,
+        mgpo_rho: float | None = None,
+        mgpo_beta: float | None = None,
     ):
         """
         if alpha == 0 or None, alpha is rank (no scaling).
@@ -117,6 +119,19 @@ def __init__(
             self.initialize_norm_cache(org_module.weight)
             self.org_module_shape: tuple[int] = org_module.weight.shape
 
+        self.mgpo_rho = mgpo_rho
+        self.mgpo_beta = mgpo_beta
+
+        # EMA of gradient magnitudes for adaptive normalization
+        self._grad_magnitude_ema_down = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)
+        self._grad_magnitude_ema_up = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)
+
+        self.optimizer: torch.optim.Optimizer | None = None
+
+        # Cache the weight shape unconditionally: the GGPO branch above only sets
+        # org_module_shape when GGPO is enabled, but MGPO needs it as well.
+        self.org_module_shape: tuple[int] = org_module.weight.shape
+
     def apply_to(self):
         self.org_forward = self.org_module.forward
         self.org_module.forward = self.forward
@@ -155,6 +170,18 @@ def forward(self, x):
 
         lx = self.lora_up(lx)
 
+        # LoRA Momentum-Guided Perturbation Optimization (MGPO)
+        if (
+            self.training
+            and hasattr(self, "mgpo_rho")
+            and self.mgpo_rho is not None
+            and hasattr(self, "optimizer")
+            and self.optimizer is not None
+        ):
+            mgpo_perturbation_output = self.get_mgpo_output_perturbation(x)
+            if mgpo_perturbation_output is not None:
+                return org_forwarded + (self.multiplier * scale * lx) + mgpo_perturbation_output
+
         # LoRA Gradient-Guided Perturbation Optimization
         if (
             self.training
@@ -301,6 +328,98 @@ def update_grad_norms(self):
         approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
         self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
 
+    def update_gradient_ema(self):
+        """
+        Update EMA of gradient magnitudes for adaptive perturbation normalization
+
+        Formula: ḡₗ⁽ᵗ⁾ = β * ḡₗ⁽ᵗ⁻¹⁾ + (1 - β) * ||∇ΔWₗL||₂
+        """
+        if self.mgpo_beta is None or self.split_dims is not None:  # split_dims falls back to the averaged EMA
+            return
+
+        # Update EMA for lora_down gradient magnitude
+        if self.lora_down.weight.grad is not None:
+            current_grad_norm = torch.norm(self.lora_down.weight.grad, p=2)
+            self._grad_magnitude_ema_down.data = (
+                self.mgpo_beta * self._grad_magnitude_ema_down.data + (1 - self.mgpo_beta) * current_grad_norm
+            )
+
+        # Update EMA for lora_up gradient magnitude
+        if self.lora_up.weight.grad is not None:
+            current_grad_norm = torch.norm(self.lora_up.weight.grad, p=2)
+            self._grad_magnitude_ema_up.data = (
+                self.mgpo_beta * self._grad_magnitude_ema_up.data + (1 - self.mgpo_beta) * current_grad_norm
+            )
+
+    def get_mgpo_output_perturbation(self, x: Tensor) -> Tensor | None:
+        """
+        Generate a random output-space perturbation scaled by the momentum magnitude and an EMA of the gradient magnitude
+
+        Reference formula: ε = -ρ · (vₜ / ||vₜ||₂) · (ḡₗ⁽ᵗ⁾)⁻¹
+        Where:
+        - ε = perturbation vector
+        - ρ = perturbation radius (mgpo_rho)
+        - vₜ = momentum vector from optimizer (exp_avg) - provides the scale here
+        - ||vₜ||₂ = L2 norm of momentum (a random draw below stands in for the unit direction)
+        - ḡₗ⁽ᵗ⁾ = EMA of gradient magnitude - provides ADAPTIVE SCALING
+
+        Two separate EMAs:
+        1. Momentum EMA (from Adam): vₜ = β₁ * vₜ₋₁ + (1 - β₁) * ∇L(Wₜ)
+        2. Gradient Magnitude EMA: ḡₗ⁽ᵗ⁾ = β * ḡₗ⁽ᵗ⁻¹⁾ + (1 - β) * ||∇L(Wₜ)||₂
+        """
+        if self.optimizer is None or self.mgpo_rho is None or self.mgpo_beta is None:
+            return None
+
+        total_perturbation_scale = 0.0
+        valid_params = 0
+
+        # Handle both single and split dims cases
+        if self.split_dims is None:
+            params_and_emas = [
+                (self.lora_down.weight, self._grad_magnitude_ema_down),
+                (self.lora_up.weight, self._grad_magnitude_ema_up),
+            ]
+        else:
+            # For split dims, use average EMA (or extend to per-param EMAs)
+            avg_ema = (self._grad_magnitude_ema_down + self._grad_magnitude_ema_up) / 2
+            params_and_emas = []
+            for lora_down in self.lora_down:
+                params_and_emas.append((lora_down.weight, avg_ema))
+            for lora_up in self.lora_up:
+                params_and_emas.append((lora_up.weight, avg_ema))
+
+        for param, grad_ema in params_and_emas:
+            if param in self.optimizer.state and "exp_avg" in self.optimizer.state[param]:
+                # Momentum magnitude ||vₜ||₂ from the optimizer's first-moment state
+                momentum = self.optimizer.state[param]["exp_avg"]
+                momentum_norm = torch.norm(momentum, p=2)
+
+                if momentum_norm > 1e-8 and grad_ema > 1e-8:
+                    # Scalar MGPO scale: ρ · ||vₜ||₂ · (1 / ḡₗ⁽ᵗ⁾)
+                    momentum_magnitude = momentum_norm  # scalar magnitude, not a direction
+                    adaptive_scale = 1.0 / grad_ema  # Adaptive normalization
+
+                    perturbation_scale = self.mgpo_rho * momentum_magnitude * adaptive_scale
+                    total_perturbation_scale += perturbation_scale.item()
+                    valid_params += 1
+
+        if valid_params == 0:
+            return None
+
+        # Average perturbation scale across all valid parameters
+        avg_perturbation_scale = total_perturbation_scale / valid_params
+
+        with torch.no_grad():
+            # Generate random perturbation scaled by MGPO formula
+            perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
+            perturbation.mul_(avg_perturbation_scale)
+            perturbation_output = x @ perturbation.T  # Result: (..., out_features)
+
+        return perturbation_output
+
+    def register_optimizer(self, optimizer):
+        self.optimizer = optimizer
+
     @property
     def device(self):
         return next(self.parameters()).device
@@ -571,6 +690,15 @@ def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
     if ggpo_sigma is not None:
         ggpo_sigma = float(ggpo_sigma)
 
+    mgpo_beta = kwargs.get("mgpo_beta", None)
+    mgpo_rho = kwargs.get("mgpo_rho", None)
+
+    if mgpo_beta is not None:
+        mgpo_beta = float(mgpo_beta)
+
+    if mgpo_rho is not None:
+        mgpo_rho = float(mgpo_rho)
+
     # train T5XXL
     train_t5xxl = kwargs.get("train_t5xxl", False)
     if train_t5xxl is not None:
@@ -639,6 +767,8 @@ def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]:
         reg_dims=reg_dims,
         ggpo_beta=ggpo_beta,
         ggpo_sigma=ggpo_sigma,
+        mgpo_rho=mgpo_rho,
+        mgpo_beta=mgpo_beta,
         reg_lrs=reg_lrs,
         verbose=verbose,
     )
@@ -738,6 +868,8 @@ def __init__(
         reg_dims: Optional[Dict[str, int]] = None,
         ggpo_beta: Optional[float] = None,
         ggpo_sigma: Optional[float] = None,
+        mgpo_rho: Optional[float] = None,
+        mgpo_beta: Optional[float] = None,
         reg_lrs: Optional[Dict[str, float]] = None,
         verbose: Optional[bool] = False,
     ) -> None:
@@ -783,6 +915,8 @@ def __init__(
 
         if ggpo_beta is not None and ggpo_sigma is not None:
             logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}")
+        if mgpo_beta is not None and mgpo_rho is not None:
+            logger.info(f"LoRA-MGPO training rho: {mgpo_rho} beta: {mgpo_beta}")
         if self.split_qkv:
             logger.info(f"split qkv for LoRA")
         if self.train_blocks is not None:
@@ -842,7 +976,7 @@ def create_modules(
                                 break
 
                     # if modules_dim is None, we use default lora_dim. if modules_dim is not None, we use the specified dim (no default)
-                    if dim is None and modules_dim is None: 
+                    if dim is None and modules_dim is None:
                         if is_linear or is_conv2d_1x1:
                             dim = default_dim if default_dim is not None else self.lora_dim
                             alpha = self.alpha
@@ -917,6 +1051,8 @@ def create_modules(
                                 split_dims=split_dims,
                                 ggpo_beta=ggpo_beta,
                                 ggpo_sigma=ggpo_sigma,
+                                mgpo_rho=mgpo_rho,
+                                mgpo_beta=mgpo_beta,
                             )
                             loras.append(lora)
 
diff --git a/tests/networks/test_lora_flux_mgpo.py b/tests/networks/test_lora_flux_mgpo.py
new file mode 100644
index 000000000..9404244bb
--- /dev/null
+++ b/tests/networks/test_lora_flux_mgpo.py
@@ -0,0 +1,119 @@
+import pytest
+import torch
+import math
+from networks.lora_flux import LoRAModule
+
+
+class MockLinear(torch.nn.Module):
+    def __init__(self, in_features, out_features):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
+        self.in_features = in_features
+        self.out_features = out_features
+
+    def forward(self, x):
+        return torch.matmul(x, self.weight.t())
+
+    def state_dict(self):
+        return {"weight": self.weight}
+
+
+class MockOptimizer:
+    def __init__(self, param):
+        self.state = {param: {"exp_avg": torch.randn_like(param)}}
+
+
+@pytest.fixture
+def lora_module():
+    org_module = MockLinear(10, 20)
+    lora_module = LoRAModule("lora_test", org_module, multiplier=1.0, lora_dim=4, alpha=1.0, mgpo_rho=0.1, mgpo_beta=0.9)
+    # org_module_shape is also cached by __init__; set it explicitly to keep the test self-contained
+    lora_module.org_module_shape = org_module.weight.shape
+    return lora_module
+
+
+def test_mgpo_parameter_initialization(lora_module):
+    """Test MGPO-specific parameter initialization."""
+    # Check MGPO-specific attributes
+    assert hasattr(lora_module, "mgpo_rho")
+    assert hasattr(lora_module, "mgpo_beta")
+    assert lora_module.mgpo_rho == 0.1
+    assert lora_module.mgpo_beta == 0.9
+
+    # Check EMA parameters initialization
+    assert hasattr(lora_module, "_grad_magnitude_ema_down")
+    assert hasattr(lora_module, "_grad_magnitude_ema_up")
+    assert isinstance(lora_module._grad_magnitude_ema_down, torch.nn.Parameter)
+    assert isinstance(lora_module._grad_magnitude_ema_up, torch.nn.Parameter)
+    assert lora_module._grad_magnitude_ema_down.requires_grad == False
+    assert lora_module._grad_magnitude_ema_up.requires_grad == False
+    assert lora_module._grad_magnitude_ema_down.item() == 1.0
+    assert lora_module._grad_magnitude_ema_up.item() == 1.0
+
+
+def test_update_gradient_ema(lora_module):
+    """Test gradient EMA update method."""
+    # Ensure method works when mgpo_beta is set
+    lora_module.lora_down.weight.grad = torch.randn_like(lora_module.lora_down.weight)
+    lora_module.lora_up.weight.grad = torch.randn_like(lora_module.lora_up.weight)
+
+    # Store initial EMA values
+    initial_down_ema = lora_module._grad_magnitude_ema_down.clone()
+    initial_up_ema = lora_module._grad_magnitude_ema_up.clone()
+
+    # Update gradient EMA
+    lora_module.update_gradient_ema()
+
+    # Check EMA update logic
+    down_grad_norm = torch.norm(lora_module.lora_down.weight.grad, p=2)
+    up_grad_norm = torch.norm(lora_module.lora_up.weight.grad, p=2)
+
+    # Verify EMA calculation
+    expected_down_ema = lora_module.mgpo_beta * initial_down_ema + (1 - lora_module.mgpo_beta) * down_grad_norm
+    expected_up_ema = lora_module.mgpo_beta * initial_up_ema + (1 - lora_module.mgpo_beta) * up_grad_norm
+
+    assert torch.allclose(lora_module._grad_magnitude_ema_down, expected_down_ema, rtol=1e-5)
+    assert torch.allclose(lora_module._grad_magnitude_ema_up, expected_up_ema, rtol=1e-5)
+
+    # Test when mgpo_beta is None
+    lora_module.mgpo_beta = None
+    lora_module.update_gradient_ema()  # Should not raise an exception
+
+
+def test_get_mgpo_output_perturbation(lora_module):
+    """Test MGPO perturbation generation."""
+    # Create a mock optimizer
+    mock_optimizer = MockOptimizer(lora_module.lora_down.weight)
+    lora_module.register_optimizer(mock_optimizer)
+
+    # Prepare input
+    x = torch.randn(5, 10)  # batch × input_dim
+
+    # Ensure method works with valid conditions
+    perturbation = lora_module.get_mgpo_output_perturbation(x)
+
+    # Verify perturbation characteristics
+    assert perturbation is not None
+    assert isinstance(perturbation, torch.Tensor)
+    assert perturbation.shape == (x.shape[0], lora_module.org_module.out_features)
+
+    # Test when conditions are not met
+    lora_module.optimizer = None
+    lora_module.mgpo_rho = None
+    lora_module.mgpo_beta = None
+
+    no_perturbation = lora_module.get_mgpo_output_perturbation(x)
+    assert no_perturbation is None
+
+
+def test_register_optimizer(lora_module):
+    """Test optimizer registration method."""
+    # Create a mock optimizer
+    mock_optimizer = MockOptimizer(lora_module.lora_down.weight)
+
+    # Register optimizer
+    lora_module.register_optimizer(mock_optimizer)
+
+    # Verify optimizer is correctly registered
+    assert hasattr(lora_module, "optimizer")
+    assert lora_module.optimizer == mock_optimizer
diff --git a/train_network.py b/train_network.py
index 7861e7404..bc21ad382 100644
--- a/train_network.py
+++ b/train_network.py
@@ -414,13 +414,12 @@ def process_batch(
 
         if text_encoder_outputs_list is not None:
             text_encoder_conds = text_encoder_outputs_list  # List of text encoder outputs
-
         if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
             # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
             with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
                 # Get the text embedding for conditioning
                 if args.weighted_captions:
-                    input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch['captions'])
+                    input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
                 encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
                     tokenize_strategy,
                     self.get_models_for_text_encoding(args, accelerator, text_encoders),
@@ -748,6 +747,9 @@ def train(self, args):
         optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
         optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
 
+        if hasattr(network, "register_optimizer"):
+            network.register_optimizer(optimizer)
+
         # prepare dataloader
         # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
         # some strategies can be None
@@ -1430,6 +1432,8 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen
                     network.update_grad_norms()
                 if hasattr(network, "update_norms"):
                     network.update_norms()
+                if hasattr(network, "update_gradient_ema"):
+                    network.update_gradient_ema()
 
                 optimizer.step()
                 lr_scheduler.step()
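
Note on wiring: train_network.py guards the new calls with hasattr(network, ...),
but this patch only defines register_optimizer() and update_gradient_ema() on
LoRAModule. Unless LoRANetwork also exposes them, both guards evaluate False and
MGPO silently never engages. A minimal sketch of the network-level delegators,
assuming LoRANetwork keeps its modules in self.text_encoder_loras and
self.unet_loras the way the existing update_grad_norms()/update_norms() helpers
do:

    def register_optimizer(self, optimizer: torch.optim.Optimizer) -> None:
        # hand the optimizer to every module so each can read its own
        # exp_avg (first-moment) state when building its perturbation
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.register_optimizer(optimizer)

    def update_gradient_ema(self) -> None:
        # refresh each module's gradient-magnitude EMA; called just before
        # optimizer.step(), mirroring how update_grad_norms() serves GGPO
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.update_gradient_ema()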
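
Usage sketch (assumed, following the convention the existing GGPO options use):
mgpo_rho and mgpo_beta are read from the network kwargs, so MGPO would be
enabled per run via --network_args, and both values must be set because either
one left as None disables the perturbation path:

    --network_args "mgpo_rho=0.05" "mgpo_beta=0.9"

The 0.05/0.9 values here are illustrative placeholders, not tuned
recommendations.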