144 changes: 143 additions & 1 deletion networks/lora_flux.py
@@ -48,6 +48,8 @@ def __init__(
split_dims: Optional[List[int]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
mgpo_rho: Optional[float] = None,
mgpo_beta: Optional[float] = None,
):
"""
if alpha == 0 or None, alpha is rank (no scaling).
@@ -117,6 +119,25 @@ def __init__(
self.initialize_norm_cache(org_module.weight)
self.org_module_shape: tuple[int] = org_module.weight.shape

self.mgpo_rho = mgpo_rho
self.mgpo_beta = mgpo_beta

# EMA of gradient magnitudes for adaptive perturbation normalization (MGPO)
self._grad_magnitude_ema_down = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)
self._grad_magnitude_ema_up = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)

# set via register_optimizer(); needed to read the optimizer's exp_avg at forward time
self.optimizer: Optional[torch.optim.Optimizer] = None

def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
@@ -155,6 +176,18 @@ def forward(self, x):

lx = self.lora_up(lx)

# LoRA Momentum-Guided Perturbation Optimization (MGPO)
if (
self.training
and hasattr(self, "mgpo_rho")
and self.mgpo_rho is not None
and hasattr(self, "optimizer")
and self.optimizer is not None
):
mgpo_perturbation_output = self.get_mgpo_output_perturbation(x)
if mgpo_perturbation_output is not None:
return org_forwarded + (self.multiplier * scale * lx) + mgpo_perturbation_output

# LoRA Gradient-Guided Perturbation Optimization
if (
self.training
@@ -301,6 +334,98 @@ def update_grad_norms(self):
approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)

def update_gradient_ema(self):
"""
Update EMA of gradient magnitudes for adaptive perturbation normalization

Formula: ḡₗ⁽ᵗ⁾ = β * ḡₗ⁽ᵗ⁻¹⁾ + (1 - β) * ||∇ΔWₗL||₂
"""
if self.mgpo_beta is None:
return

# Update EMA for lora_down gradient magnitude
if self.lora_down.weight.grad is not None:
current_grad_norm = torch.norm(self.lora_down.weight.grad, p=2)
self._grad_magnitude_ema_down.data = (
self.mgpo_beta * self._grad_magnitude_ema_down.data + (1 - self.mgpo_beta) * current_grad_norm
)

# Update EMA for lora_up gradient magnitude
if self.lora_up.weight.grad is not None:
current_grad_norm = torch.norm(self.lora_up.weight.grad, p=2)
self._grad_magnitude_ema_up.data = (
self.mgpo_beta * self._grad_magnitude_ema_up.data + (1 - self.mgpo_beta) * current_grad_norm
)

def get_mgpo_output_perturbation(self, x: Tensor) -> Tensor | None:
"""
Generate MGPO perturbation using both momentum direction and gradient magnitude normalization

Full MGPO Formula: ε = -ρ · (vₜ / ||vₜ||₂) · (ḡₗ⁽ᵗ⁾)⁻¹
Where:
- ε = perturbation vector
- ρ = perturbation radius (mgpo_rho)
- vₜ = momentum vector from optimizer (exp_avg) - provides DIRECTION
- ||vₜ||₂ = L2 norm of momentum for unit direction
- ḡₗ⁽ᵗ⁾ = EMA of gradient magnitude - provides ADAPTIVE SCALING

Two separate EMAs:
1. Momentum EMA (from Adam): vₜ = β₁ * vₜ₋₁ + (1 - β₁) * ∇L(Wₜ)
2. Gradient Magnitude EMA: ḡₗ⁽ᵗ⁾ = β * ḡₗ⁽ᵗ⁻¹⁾ + (1 - β) * ||∇L(Wₜ)||₂
"""
if self.optimizer is None or self.mgpo_rho is None or self.mgpo_beta is None:
return None

total_perturbation_scale = 0.0
valid_params = 0

# Handle both single and split dims cases
if self.split_dims is None:
params_and_emas = [
(self.lora_down.weight, self._grad_magnitude_ema_down),
(self.lora_up.weight, self._grad_magnitude_ema_up),
]
else:
# For split dims, use average EMA (or extend to per-param EMAs)
avg_ema = (self._grad_magnitude_ema_down + self._grad_magnitude_ema_up) / 2
params_and_emas = []
for lora_down in self.lora_down:
params_and_emas.append((lora_down.weight, avg_ema))
for lora_up in self.lora_up:
params_and_emas.append((lora_up.weight, avg_ema))

for param, grad_ema in params_and_emas:
if param in self.optimizer.state and "exp_avg" in self.optimizer.state[param]:
# Momentum magnitude ||vₜ||₂, taken from the optimizer's exp_avg buffer
momentum = self.optimizer.state[param]["exp_avg"]
momentum_norm = torch.norm(momentum, p=2)

if momentum_norm > 1e-8 and grad_ema > 1e-8:
# MGPO scale: ρ · ||vₜ||₂ / ḡₗ⁽ᵗ⁾ (momentum magnitude, adaptively normalized)
adaptive_scale = 1.0 / grad_ema

perturbation_scale = self.mgpo_rho * momentum_norm * adaptive_scale
total_perturbation_scale += perturbation_scale.item()
valid_params += 1

if valid_params == 0:
return None

# Average perturbation scale across all valid parameters
avg_perturbation_scale = total_perturbation_scale / valid_params

with torch.no_grad():
# Generate random perturbation scaled by MGPO formula
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
perturbation.mul_(avg_perturbation_scale)
perturbation_output = x @ perturbation.T # shape: (batch, out_features)

return perturbation_output

def register_optimizer(self, optimizer):
self.optimizer = optimizer

@property
def device(self):
return next(self.parameters()).device
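Taken together, the scale that finally multiplies the random perturbation is ρ · ||exp_avg||₂ / ḡ, averaged over the LoRA factors. A minimal numeric sketch of that arithmetic (shapes and values are illustrative, not taken from this PR):

import torch

rho, grad_ema = 0.1, torch.tensor(2.0)      # mgpo_rho and the gradient-magnitude EMA
exp_avg = torch.randn(4, 10)                # stand-in for Adam momentum of one LoRA factor
scale = rho * exp_avg.norm(p=2) / grad_ema  # rho * ||v_t||_2 / g_bar
noise = torch.randn(20, 10) * scale         # random direction, MGPO-controlled magnitude
x = torch.randn(5, 10)
print((x @ noise.T).shape)                  # torch.Size([5, 20]), added to the layer output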
@@ -571,6 +696,15 @@ def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
if ggpo_sigma is not None:
ggpo_sigma = float(ggpo_sigma)

mgpo_beta = kwargs.get("mgpo_beta", None)
mgpo_rho = kwargs.get("mgpo_rho", None)

if mgpo_beta is not None:
mgpo_beta = float(mgpo_beta)

if mgpo_rho is not None:
mgpo_rho = float(mgpo_rho)

# train T5XXL
train_t5xxl = kwargs.get("train_t5xxl", False)
if train_t5xxl is not None:
@@ -639,6 +773,8 @@ def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]:
reg_dims=reg_dims,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
mgpo_rho=mgpo_rho,
mgpo_beta=mgpo_beta,
reg_lrs=reg_lrs,
verbose=verbose,
)
@@ -738,6 +874,8 @@ def __init__(
reg_dims: Optional[Dict[str, int]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
mgpo_rho: Optional[float] = None,
mgpo_beta: Optional[float] = None,
reg_lrs: Optional[Dict[str, float]] = None,
verbose: Optional[bool] = False,
) -> None:
@@ -783,6 +921,8 @@ def __init__(
if ggpo_beta is not None and ggpo_sigma is not None:
logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}")

if mgpo_beta is not None and mgpo_rho is not None:
logger.info(f"LoRA-MGPO training rho: {mgpo_rho} beta: {mgpo_beta}")
if self.split_qkv:
logger.info(f"split qkv for LoRA")
if self.train_blocks is not None:
@@ -842,7 +982,7 @@ def create_modules(
break

# if modules_dim is None, we use default lora_dim. if modules_dim is not None, we use the specified dim (no default)
if dim is None and modules_dim is None:
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha
@@ -917,6 +1057,8 @@ def create_modules(
split_dims=split_dims,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
mgpo_rho=mgpo_rho,
mgpo_beta=mgpo_beta,
)
loras.append(lora)

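Like the ggpo_* options above, the new flags are parsed from string-valued kwargs, so they can be switched on from the trainer command line via --network_args. A hedged sketch of that pathway (the key names come from this diff; the values are illustrative):

# e.g. --network_args "mgpo_rho=0.05" "mgpo_beta=0.9" arrives as string kwargs
kwargs = {"mgpo_rho": "0.05", "mgpo_beta": "0.9"}

mgpo_rho = kwargs.get("mgpo_rho", None)
mgpo_beta = kwargs.get("mgpo_beta", None)
if mgpo_rho is not None:
    mgpo_rho = float(mgpo_rho)   # same cast as in create_network above
if mgpo_beta is not None:
    mgpo_beta = float(mgpo_beta)

assert (mgpo_rho, mgpo_beta) == (0.05, 0.9)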
119 changes: 119 additions & 0 deletions tests/networks/test_lora_flux_mgpo.py
@@ -0,0 +1,119 @@
import pytest
import torch
from networks.lora_flux import LoRAModule


class MockLinear(torch.nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
self.in_features = in_features
self.out_features = out_features

def forward(self, x):
return torch.matmul(x, self.weight.t())

def state_dict(self):
return {"weight": self.weight}


class MockOptimizer:
def __init__(self, param):
self.state = {param: {"exp_avg": torch.randn_like(param)}}


@pytest.fixture
def lora_module():
org_module = MockLinear(10, 20)
lora_module = LoRAModule("lora_unet_test", org_module, multiplier=1.0, lora_dim=4, alpha=1.0, mgpo_rho=0.1, mgpo_beta=0.9)
# Manually set org_module_shape to match the original module's weight
lora_module.org_module_shape = org_module.weight.shape
return lora_module


def test_mgpo_parameter_initialization(lora_module):
"""Test MGPO-specific parameter initialization."""
# Check MGPO-specific attributes
assert hasattr(lora_module, "mgpo_rho")
assert hasattr(lora_module, "mgpo_beta")
assert lora_module.mgpo_rho == 0.1
assert lora_module.mgpo_beta == 0.9

# Check EMA parameters initialization
assert hasattr(lora_module, "_grad_magnitude_ema_down")
assert hasattr(lora_module, "_grad_magnitude_ema_up")
assert isinstance(lora_module._grad_magnitude_ema_down, torch.nn.Parameter)
assert isinstance(lora_module._grad_magnitude_ema_up, torch.nn.Parameter)
assert not lora_module._grad_magnitude_ema_down.requires_grad
assert not lora_module._grad_magnitude_ema_up.requires_grad
assert lora_module._grad_magnitude_ema_down.item() == 1.0
assert lora_module._grad_magnitude_ema_up.item() == 1.0


def test_update_gradient_ema(lora_module):
"""Test gradient EMA update method."""
# Ensure method works when mgpo_beta is set
lora_module.lora_down.weight.grad = torch.randn_like(lora_module.lora_down.weight)
lora_module.lora_up.weight.grad = torch.randn_like(lora_module.lora_up.weight)

# Store initial EMA values
initial_down_ema = lora_module._grad_magnitude_ema_down.clone()
initial_up_ema = lora_module._grad_magnitude_ema_up.clone()

# Update gradient EMA
lora_module.update_gradient_ema()

# Check EMA update logic
down_grad_norm = torch.norm(lora_module.lora_down.weight.grad, p=2)
up_grad_norm = torch.norm(lora_module.lora_up.weight.grad, p=2)

# Verify EMA calculation
expected_down_ema = lora_module.mgpo_beta * initial_down_ema + (1 - lora_module.mgpo_beta) * down_grad_norm
expected_up_ema = lora_module.mgpo_beta * initial_up_ema + (1 - lora_module.mgpo_beta) * up_grad_norm

assert torch.allclose(lora_module._grad_magnitude_ema_down, expected_down_ema, rtol=1e-5)
assert torch.allclose(lora_module._grad_magnitude_ema_up, expected_up_ema, rtol=1e-5)

# Test when mgpo_beta is None
lora_module.mgpo_beta = None
lora_module.update_gradient_ema() # Should not raise an exception


def test_get_mgpo_output_perturbation(lora_module):
"""Test MGPO perturbation generation."""
# Create a mock optimizer
mock_optimizer = MockOptimizer(lora_module.lora_down.weight)
lora_module.register_optimizer(mock_optimizer)

# Prepare input
x = torch.randn(5, 10) # batch × input_dim

# Ensure method works with valid conditions
perturbation = lora_module.get_mgpo_output_perturbation(x)

# Verify perturbation characteristics
assert perturbation is not None
assert isinstance(perturbation, torch.Tensor)
assert perturbation.shape == (x.shape[0], lora_module.org_module.out_features)

# Test when conditions are not met
lora_module.optimizer = None
lora_module.mgpo_rho = None
lora_module.mgpo_beta = None

no_perturbation = lora_module.get_mgpo_output_perturbation(x)
assert no_perturbation is None


def test_register_optimizer(lora_module):
"""Test optimizer registration method."""
# Create a mock optimizer
mock_optimizer = MockOptimizer(lora_module.lora_down.weight)

# Register optimizer
lora_module.register_optimizer(mock_optimizer)

# Verify optimizer is correctly registered
assert hasattr(lora_module, "optimizer")
assert lora_module.optimizer is mock_optimizer
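The suite is self-contained (MockLinear and MockOptimizer stand in for real modules and Adam state), so it can be run in isolation with: pytest tests/networks/test_lora_flux_mgpo.py -v. One behavior the tests above do not pin down: before the first optimizer.step() there is no exp_avg entry in the optimizer state, so get_mgpo_output_perturbation returns None and forward() falls through to the plain LoRA path. A small standalone sketch of that guard, using a real Adam instance instead of the mock:

import torch

w = torch.nn.Parameter(torch.randn(4, 10))
opt = torch.optim.Adam([w], lr=1e-3)

# Before any step, Adam holds no state for w -- MGPO's lookup fails closed.
assert not (w in opt.state and "exp_avg" in opt.state[w])

w.sum().backward()
opt.step()
assert "exp_avg" in opt.state[w]  # from here on the perturbation can be built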
8 changes: 6 additions & 2 deletions train_network.py
@@ -414,13 +414,12 @@ def process_batch(
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs


if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch['captions'])
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
@@ -748,6 +747,9 @@ def train(self, args):
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)

if hasattr(network, "register_optimizer"):
network.register_optimizer(optimizer)

# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
@@ -1430,6 +1432,8 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen
network.update_grad_norms()
if hasattr(network, "update_norms"):
network.update_norms()
if hasattr(network, "update_gradient_ema"):
network.update_gradient_ema()

optimizer.step()
lr_scheduler.step()
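The ordering of these calls matters: update_gradient_ema() runs after backward() but before optimizer.step() and zero_grad(), so the EMA sees the raw gradients of the current step, while exp_avg is refreshed by the step() that follows and is read at the next forward(). A runnable sketch of that ordering on a toy parameter (names are illustrative):

import torch

beta = 0.9
w = torch.nn.Parameter(torch.randn(4, 10))
opt = torch.optim.Adam([w], lr=1e-3)
grad_ema = torch.tensor(1.0)

for _ in range(3):
    loss = (w ** 2).sum()
    loss.backward()
    # EMA update must happen while w.grad is still populated
    grad_ema = beta * grad_ema + (1 - beta) * w.grad.norm(p=2)
    opt.step()
    opt.zero_grad()

print(grad_ema)  # the g_bar that would normalize the next perturbation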