@@ -378,7 +378,39 @@ def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype]
378
378
self .on_input = self .on_input .to (device = device , dtype = dtype )
379
379
380
380
381
- AnyLoRALayer = Union [LoRALayer , LoHALayer , LoKRLayer , FullLayer , IA3Layer ]
381
class NormLayer(LoRALayerBase):
    """Full replacement weights for a normalization layer (LyCORIS norm module).

    Stores the checkpoint's ``w_norm`` tensor as ``weight`` and an optional
    ``b_norm`` tensor as ``bias``.  Bias bookkeeping (size accounting and
    device/dtype moves) is handled by ``LoRALayerBase``; this class only
    manages ``weight``.
    """

    # weight: torch.Tensor — replacement norm weight from "w_norm"
    # bias: Optional[torch.Tensor] — replacement norm bias from "b_norm", if present

    def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
        super().__init__(layer_key, values)

        self.weight = values["w_norm"]
        self.bias = values.get("b_norm")

        # Norm layers carry whole tensors, not low-rank factors, so there is
        # no rank and no scaling to apply.
        self.rank = None  # unscaled
        self.check_keys(values, {"w_norm", "b_norm"})

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        # The stored tensor replaces the original outright; no delta is computed.
        return self.weight

    def calc_size(self) -> int:
        # Base class accounts for the bias; add the weight tensor's bytes.
        return super().calc_size() + self.weight.nelement() * self.weight.element_size()

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        # Base class moves the bias; move the weight alongside it.
        super().to(device=device, dtype=dtype)
        self.weight = self.weight.to(device=device, dtype=dtype)
411
+
412
# Union of every concrete LoRA-family layer type produced by checkpoint parsing.
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
382
414
383
415
384
416
class LoRAModelRaw (RawModel ): # (torch.nn.Module):
@@ -519,6 +551,10 @@ def from_checkpoint(
519
551
elif "on_input" in values :
520
552
layer = IA3Layer (layer_key , values )
521
553
554
+ # norms
555
+ elif "w_norm" in values :
556
+ layer = NormLayer (layer_key , values )
557
+
522
558
else :
523
559
print (f">> Encountered unknown lora layer module in { model .name } : { layer_key } - { list (values .keys ())} " )
524
560
raise Exception ("Unknown lora format!" )
0 commit comments