Skip to content

Commit 2e0bf0a

Browse files
num_fused_loras as a property
Co-authored-by: BenjaminBossan <benjamin.bossan@gmail.com>
1 parent 4779750 commit 2e0bf0a

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

src/diffusers/loaders/lora_base.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ class LoraBaseMixin:
465465
"""Utility class for handling LoRAs."""
466466

467467
_lora_loadable_modules = []
468-
num_fused_loras = 0
468+
_merged_adapters = set()
469469

470470
def load_lora_weights(self, **kwargs):
471471
raise NotImplementedError("`load_lora_weights()` is not implemented.")
@@ -592,7 +592,6 @@ def fuse_lora(
592592
if len(components) == 0:
593593
raise ValueError("`components` cannot be an empty list.")
594594

595-
merged_adapters = set()
596595
for fuse_component in components:
597596
if fuse_component not in self._lora_loadable_modules:
598597
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
@@ -604,17 +603,15 @@ def fuse_lora(
604603
model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
605604
for module in model.modules():
606605
if isinstance(module, BaseTunerLayer):
607-
merged_adapters.update(set(module.merged_adapters))
606+
self._merged_adapters.update(set(module.merged_adapters))
608607
# handle transformers models.
609608
if issubclass(model.__class__, PreTrainedModel):
610609
fuse_text_encoder_lora(
611610
model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
612611
)
613612
for module in model.modules():
614613
if isinstance(module, BaseTunerLayer):
615-
merged_adapters.update(set(module.merged_adapters))
616-
617-
self.num_fused_loras += len(merged_adapters)
614+
self._merged_adapters.update(set(module.merged_adapters))
618615

619616
def unfuse_lora(self, components: List[str] = [], **kwargs):
620617
r"""
@@ -659,7 +656,6 @@ def unfuse_lora(self, components: List[str] = [], **kwargs):
659656
if len(components) == 0:
660657
raise ValueError("`components` cannot be an empty list.")
661658

662-
merged_adapters = set()
663659
for fuse_component in components:
664660
if fuse_component not in self._lora_loadable_modules:
665661
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
@@ -670,9 +666,11 @@ def unfuse_lora(self, components: List[str] = [], **kwargs):
670666
for module in model.modules():
671667
if isinstance(module, BaseTunerLayer):
672668
module.unmerge()
673-
merged_adapters.update(set(module.merged_adapters))
669+
self._merged_adapters.update(set(module.merged_adapters))
674670

675-
self.num_fused_loras = len(merged_adapters)
671+
@property
672+
def num_fused_loras(self):
673+
return len(self._merged_adapters)
676674

677675
def set_adapters(
678676
self,

0 commit comments

Comments
 (0)