Commit 68f9939

Add support for norm layer

StAlKeR7779 authored and hipsterusername committed
1 parent 7da6120 · commit 68f9939

File tree

1 file changed: +37 −1


invokeai/backend/lora.py

Lines changed: 37 additions & 1 deletion
@@ -378,7 +378,39 @@ def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype]
         self.on_input = self.on_input.to(device=device, dtype=dtype)
 
 
-AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
+class NormLayer(LoRALayerBase):
+    # bias handled in LoRALayerBase (calc_size, to)
+    # weight: torch.Tensor
+    # bias: Optional[torch.Tensor]
+
+    def __init__(
+        self,
+        layer_key: str,
+        values: Dict[str, torch.Tensor],
+    ):
+        super().__init__(layer_key, values)
+
+        self.weight = values["w_norm"]
+        self.bias = values.get("b_norm", None)
+
+        self.rank = None  # unscaled
+        self.check_keys(values, {"w_norm", "b_norm"})
+
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
+        return self.weight
+
+    def calc_size(self) -> int:
+        model_size = super().calc_size()
+        model_size += self.weight.nelement() * self.weight.element_size()
+        return model_size
+
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+        super().to(device=device, dtype=dtype)
+
+        self.weight = self.weight.to(device=device, dtype=dtype)
+
+
+AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
 
 
 class LoRAModelRaw(RawModel):  # (torch.nn.Module):
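For context, a small runnable sketch of the size accounting the new class performs: calc_size() sums nelement() * element_size() per tensor, with the base class (per the comment above) covering the bias and NormLayer adding its weight. The keys mirror the "w_norm"/"b_norm" pair this commit introduces; the shapes and dtype are invented for illustration.

import torch

# Hypothetical state-dict slice for one norm layer (shapes made up).
values = {
    "w_norm": torch.ones(320, dtype=torch.float16),
    "b_norm": torch.zeros(320, dtype=torch.float16),
}

# Same arithmetic as NormLayer.calc_size(): bytes = elements * bytes-per-element.
weight_bytes = values["w_norm"].nelement() * values["w_norm"].element_size()
bias_bytes = values["b_norm"].nelement() * values["b_norm"].element_size()
print(weight_bytes + bias_bytes)  # 1280: two fp16 tensors of 320 elements each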
@@ -519,6 +551,10 @@ def from_checkpoint(
             elif "on_input" in values:
                 layer = IA3Layer(layer_key, values)
 
+            # norms
+            elif "w_norm" in values:
+                layer = NormLayer(layer_key, values)
+
             else:
                 print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
                 raise Exception("Unknown lora format!")
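The detection is purely key-based: each branch of the elif chain tests for a signature key in the grouped state-dict values. A minimal sketch of that dispatch shape, reduced to the two branches visible in this diff (the earlier LoRA/LoHA/LoKR/Full branches are elided); classify_layer is a hypothetical stand-in for illustration, not the repository's API.

import torch
from typing import Dict

def classify_layer(layer_key: str, values: Dict[str, torch.Tensor]) -> str:
    # ia3: marked by an "on_input" tensor
    if "on_input" in values:
        return "IA3Layer"
    # norms: a "w_norm" weight, with "b_norm" optional
    elif "w_norm" in values:
        return "NormLayer"
    else:
        raise Exception(f"Unknown lora format! {layer_key}: {list(values.keys())}")

print(classify_layer("norm1", {"w_norm": torch.ones(4)}))  # NormLayer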

0 commit comments