@@ -104,7 +104,9 @@ def __init__(
104
104
super ().__init__ (num_channels , eps = eps , elementwise_affine = affine , ** kwargs )
105
105
106
106
def forward (self , x : torch .Tensor ) -> torch .Tensor :
107
- x = F .layer_norm (x .float (), self .normalized_shape , self .weight , self .bias , self .eps ).to (x .dtype )
107
+ weight = self .weight .float () if self .weight is not None else None
108
+ bias = self .bias .float () if self .bias is not None else None
109
+ x = F .layer_norm (x .float (), self .normalized_shape , weight , bias , self .eps ).to (x .dtype )
108
110
return x
109
111
110
112
@@ -146,7 +148,9 @@ def __init__(
146
148
147
149
def forward (self , x : torch .Tensor ) -> torch .Tensor :
148
150
x = x .permute (0 , 2 , 3 , 1 )
149
- x = F .layer_norm (x .float (), self .normalized_shape , self .weight , self .bias , self .eps ).to (x .dtype )
151
+ weight = self .weight .float () if self .weight is not None else None
152
+ bias = self .bias .float () if self .bias is not None else None
153
+ x = F .layer_norm (x .float (), self .normalized_shape , weight , bias , self .eps ).to (x .dtype )
150
154
x = x .permute (0 , 3 , 1 , 2 )
151
155
return x
152
156
@@ -282,7 +286,8 @@ def reset_parameters(self) -> None:
282
286
nn .init .ones_ (self .weight )
283
287
284
288
def forward (self , x : torch .Tensor ) -> torch .Tensor :
285
- x = rms_norm (x .float (), self .normalized_shape , self .weight , self .eps ).to (x .dtype )
289
+ weight = self .weight .float () if self .weight is not None else None
290
+ x = rms_norm (x .float (), self .normalized_shape , weight , self .eps ).to (x .dtype )
286
291
return x
287
292
288
293
@@ -381,7 +386,8 @@ def reset_parameters(self) -> None:
381
386
nn .init .ones_ (self .weight )
382
387
383
388
def forward (self , x : torch .Tensor ) -> torch .Tensor :
384
- x = rms_norm2d (x .float (), self .normalized_shape , self .weight , self .eps ).to (x .dtype )
389
+ weight = self .weight .float () if self .weight is not None else None
390
+ x = rms_norm2d (x .float (), self .normalized_shape , weight , self .eps ).to (x .dtype )
385
391
return x
386
392
387
393
@@ -470,7 +476,8 @@ def reset_parameters(self) -> None:
470
476
nn .init .ones_ (self .weight )
471
477
472
478
def forward (self , x : torch .Tensor ) -> torch .Tensor :
473
- x = simple_norm (x .float (), self .normalized_shape , self .weight , self .eps ).to (x .dtype )
479
+ weight = self .weight .float () if self .weight is not None else None
480
+ x = simple_norm (x .float (), self .normalized_shape , weight , self .eps ).to (x .dtype )
474
481
return x
475
482
476
483
@@ -562,6 +569,7 @@ def reset_parameters(self) -> None:
562
569
563
570
def forward (self , x : torch .Tensor ) -> torch .Tensor :
564
571
x = x .permute (0 , 2 , 3 , 1 )
565
- x = simple_norm (x .float (), self .normalized_shape , self .weight , self .eps ).to (x .dtype )
572
+ weight = self .weight .float () if self .weight is not None else None
573
+ x = simple_norm (x .float (), self .normalized_shape , weight , self .eps ).to (x .dtype )
566
574
x = x .permute (0 , 3 , 1 , 2 )
567
575
return x
0 commit comments