@@ -465,7 +465,7 @@ class LoraBaseMixin:
465
465
"""Utility class for handling LoRAs."""
466
466
467
467
_lora_loadable_modules = []
468
- num_fused_loras = 0
468
+ _merged_adapters = set ()
469
469
470
470
def load_lora_weights (self , ** kwargs ):
471
471
raise NotImplementedError ("`load_lora_weights()` is not implemented." )
@@ -592,7 +592,6 @@ def fuse_lora(
592
592
if len (components ) == 0 :
593
593
raise ValueError ("`components` cannot be an empty list." )
594
594
595
- merged_adapters = set ()
596
595
for fuse_component in components :
597
596
if fuse_component not in self ._lora_loadable_modules :
598
597
raise ValueError (f"{ fuse_component } is not found in { self ._lora_loadable_modules = } ." )
@@ -604,17 +603,15 @@ def fuse_lora(
604
603
model .fuse_lora (lora_scale , safe_fusing = safe_fusing , adapter_names = adapter_names )
605
604
for module in model .modules ():
606
605
if isinstance (module , BaseTunerLayer ):
607
- merged_adapters .update (set (module .merged_adapters ))
606
+ self . _merged_adapters .update (set (module .merged_adapters ))
608
607
# handle transformers models.
609
608
if issubclass (model .__class__ , PreTrainedModel ):
610
609
fuse_text_encoder_lora (
611
610
model , lora_scale = lora_scale , safe_fusing = safe_fusing , adapter_names = adapter_names
612
611
)
613
612
for module in model .modules ():
614
613
if isinstance (module , BaseTunerLayer ):
615
- merged_adapters .update (set (module .merged_adapters ))
616
-
617
- self .num_fused_loras += len (merged_adapters )
614
+ self ._merged_adapters .update (set (module .merged_adapters ))
618
615
619
616
def unfuse_lora (self , components : List [str ] = [], ** kwargs ):
620
617
r"""
@@ -659,7 +656,6 @@ def unfuse_lora(self, components: List[str] = [], **kwargs):
659
656
if len (components ) == 0 :
660
657
raise ValueError ("`components` cannot be an empty list." )
661
658
662
- merged_adapters = set ()
663
659
for fuse_component in components :
664
660
if fuse_component not in self ._lora_loadable_modules :
665
661
raise ValueError (f"{ fuse_component } is not found in { self ._lora_loadable_modules = } ." )
@@ -670,9 +666,11 @@ def unfuse_lora(self, components: List[str] = [], **kwargs):
670
666
for module in model .modules ():
671
667
if isinstance (module , BaseTunerLayer ):
672
668
module .unmerge ()
673
- merged_adapters .update (set (module .merged_adapters ))
669
+ self . _merged_adapters .update (set (module .merged_adapters ))
674
670
675
- self .num_fused_loras = len (merged_adapters )
671
+ @property
672
+ def num_fused_loras (self ):
673
+ return len (self ._merged_adapters )
676
674
677
675
def set_adapters (
678
676
self ,
0 commit comments