
Commit e9fda39

remove F.rms_norm for now (#11126)
up
1 parent 2c1ed50 commit e9fda39

File tree

1 file changed: +0 -10 lines changed

src/diffusers/models/normalization.py

Lines changed: 0 additions & 10 deletions

@@ -550,16 +550,6 @@ def forward(self, hidden_states):
             hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
             if self.bias is not None:
                 hidden_states = hidden_states + self.bias
-        elif is_torch_version(">=", "2.4"):
-            if self.weight is not None:
-                # convert into half-precision if necessary
-                if self.weight.dtype in [torch.float16, torch.bfloat16]:
-                    hidden_states = hidden_states.to(self.weight.dtype)
-            hidden_states = nn.functional.rms_norm(
-                hidden_states, normalized_shape=(hidden_states.shape[-1],), weight=self.weight, eps=self.eps
-            )
-            if self.bias is not None:
-                hidden_states = hidden_states + self.bias
         else:
             input_dtype = hidden_states.dtype
             variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
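
For context on what the deleted branch did: on PyTorch >= 2.4 it dispatched RMS normalization to nn.functional.rms_norm, while the surviving else branch computes the statistics manually in float32. The sketch below contrasts the two computations in isolation. It is illustrative only, not the diffusers module itself: the tensor shapes, the eps value, and the tail of the manual path (which this diff truncates) are assumptions.

import torch
import torch.nn.functional as F

# Toy inputs; shapes, dtype, and eps are illustrative assumptions.
hidden_states = torch.randn(2, 8, 64, dtype=torch.float32)
weight = torch.ones(64)
eps = 1e-6

# Path removed by this commit: the fused op, available in torch >= 2.4.
out_fused = F.rms_norm(
    hidden_states, normalized_shape=(hidden_states.shape[-1],), weight=weight, eps=eps
)

# Generic sketch of the manual fallback: compute the variance in float32,
# scale by the reciprocal RMS, then apply the weight and restore the dtype.
# (The diff only shows the first two lines of the real else branch.)
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
out_manual = hidden_states.to(torch.float32) * torch.rsqrt(variance + eps)
out_manual = (out_manual * weight.to(torch.float32)).to(input_dtype)

# The two paths are meant to agree up to floating-point error.
print(torch.allclose(out_fused, out_manual, atol=1e-5))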
