
Commit 37d9440

Fixup casting issues for weights/bias in fp32 norm layers
1 parent ea4f940 commit 37d9440

File tree

timm/layers/norm.py
timm/layers/norm_act.py
timm/models/efficientnet.py

3 files changed: +40 −13 lines

timm/layers/norm.py

Lines changed: 14 additions & 6 deletions
@@ -104,7 +104,9 @@ def __init__(
         super().__init__(num_channels, eps=eps, elementwise_affine=affine, **kwargs)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        bias = self.bias.float() if self.bias is not None else None
+        x = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).to(x.dtype)
         return x
 
 
@@ -146,7 +148,9 @@ def __init__(
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)
-        x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        bias = self.bias.float() if self.bias is not None else None
+        x = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).to(x.dtype)
         x = x.permute(0, 3, 1, 2)
         return x
 
@@ -282,7 +286,8 @@ def reset_parameters(self) -> None:
         nn.init.ones_(self.weight)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = rms_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        x = rms_norm(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
         return x
 
 
@@ -381,7 +386,8 @@ def reset_parameters(self) -> None:
         nn.init.ones_(self.weight)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = rms_norm2d(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        x = rms_norm2d(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
         return x
 
 
@@ -470,7 +476,8 @@ def reset_parameters(self) -> None:
         nn.init.ones_(self.weight)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = simple_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        x = simple_norm(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
         return x
 
 
@@ -562,6 +569,7 @@ def reset_parameters(self) -> None:
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.permute(0, 2, 3, 1)
-        x = simple_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        x = simple_norm(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
         x = x.permute(0, 3, 1, 2)
         return x
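Why this matters: these fp32-forced norm layers upcast the input with x.float() but previously handed the still low-precision self.weight / self.bias straight to F.layer_norm. With a module cast to bf16/fp16, that mixes an fp32 input with non-fp32 affine parameters, which layer_norm rejects on most backends; the None guards also keep elementwise_affine=False working. A minimal repro sketch of the failure and the fix, assuming a plain nn.LayerNorm cast to bf16:

import torch
import torch.nn.functional as F

ln = torch.nn.LayerNorm(64).to(torch.bfloat16)   # affine params now bf16
x = torch.randn(2, 8, 64, dtype=torch.bfloat16)

# Old behavior: fp32 input + bf16 weight/bias -> dtype mismatch RuntimeError
# on most backends:
# F.layer_norm(x.float(), ln.normalized_shape, ln.weight, ln.bias, ln.eps)

# Fixed behavior: upcast the parameters alongside the input, guarding for
# the no-affine case where weight/bias are None.
weight = ln.weight.float() if ln.weight is not None else None
bias = ln.bias.float() if ln.bias is not None else None
y = F.layer_norm(x.float(), ln.normalized_shape, weight, bias, ln.eps).to(x.dtype)
assert y.dtype == torch.bfloat16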

timm/layers/norm_act.py

Lines changed: 10 additions & 4 deletions
@@ -482,7 +482,9 @@ def __init__(
         self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
 
     def forward(self, x):
-        x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        bias = self.bias.float() if self.bias is not None else None
+        x = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).to(x.dtype)
         x = self.drop(x)
         x = self.act(x)
         return x
@@ -540,7 +542,9 @@ def __init__(
 
     def forward(self, x):
         x = x.permute(0, 2, 3, 1)
-        x = F.layer_norm(x.float(), self.normalized_shape, self.weight, self.bias, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        bias = self.bias.float() if self.bias is not None else None
+        x = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).to(x.dtype)
         x = x.permute(0, 3, 1, 2)
         x = self.drop(x)
         x = self.act(x)
@@ -605,7 +609,8 @@ def __init__(
         self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = rms_norm(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        x = rms_norm(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
         x = self.drop(x)
         x = self.act(x)
         return x
@@ -667,7 +672,8 @@ def __init__(
         self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = rms_norm2d(x.float(), self.normalized_shape, self.weight, self.eps).to(x.dtype)
+        weight = self.weight.float() if self.weight is not None else None
+        x = rms_norm2d(x.float(), self.normalized_shape, weight, self.eps).to(x.dtype)
         x = self.drop(x)
         x = self.act(x)
         return x
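The RMS-norm act layers need the same guard, minus the bias. A weight-only sketch, using torch.nn.functional.rms_norm as a stand-in for timm's rms_norm helper (an assumption: F.rms_norm exists in recent PyTorch, 2.4+, and the timm helper takes the same arguments here):

import torch
import torch.nn.functional as F

w = torch.ones(64, dtype=torch.bfloat16)      # stand-in for self.weight
x = torch.randn(2, 64, dtype=torch.bfloat16)

# Upcast the optional scale alongside the input, then restore the input dtype.
weight = w.float() if w is not None else None
y = F.rms_norm(x.float(), (64,), weight, eps=1e-6).to(x.dtype)
assert y.dtype == torch.bfloat16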

timm/models/efficientnet.py

Lines changed: 16 additions & 3 deletions
@@ -2881,22 +2881,35 @@ def test_efficientnet(pretrained=False, **kwargs) -> EfficientNet:
 
 @register_model
 def test_efficientnet_gn(pretrained=False, **kwargs) -> EfficientNet:
+
     model = _gen_test_efficientnet(
-        'test_efficientnet_gn', pretrained=pretrained, norm_layer=partial(GroupNormAct, group_size=8), **kwargs)
+        'test_efficientnet_gn',
+        pretrained=pretrained,
+        norm_layer=kwargs.pop('norm_layer', partial(GroupNormAct, group_size=8)),
+        **kwargs
+    )
     return model
 
 
 @register_model
 def test_efficientnet_ln(pretrained=False, **kwargs) -> EfficientNet:
     model = _gen_test_efficientnet(
-        'test_efficientnet_ln', pretrained=pretrained, norm_layer=LayerNormAct2d, **kwargs)
+        'test_efficientnet_ln',
+        pretrained=pretrained,
+        norm_layer=kwargs.pop('norm_layer', LayerNormAct2d),
+        **kwargs
+    )
     return model
 
 
 @register_model
 def test_efficientnet_evos(pretrained=False, **kwargs) -> EfficientNet:
     model = _gen_test_efficientnet(
-        'test_efficientnet_evos', pretrained=pretrained, norm_layer=partial(EvoNorm2dS0, group_size=8), **kwargs)
+        'test_efficientnet_evos',
+        pretrained=pretrained,
+        norm_layer=kwargs.pop('norm_layer', partial(EvoNorm2dS0, group_size=8)),
+        **kwargs
+    )
     return model
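The efficientnet.py changes are companion plumbing: these test-model registrations hard-coded norm_layer, so a caller-supplied norm_layer (e.g. passed through create_model kwargs) collided with the fixed keyword. kwargs.pop('norm_layer', default) keeps the default while letting callers override it. A toy illustration of the pattern, with hypothetical names (_gen and build are not timm API):

from functools import partial

def _gen(name, pretrained, norm_layer, **kwargs):
    return {'name': name, 'norm_layer': norm_layer, **kwargs}

def build(pretrained=False, **kwargs):
    # Old style, norm_layer fixed in the call, raised:
    #   TypeError: _gen() got multiple values for keyword argument 'norm_layer'
    # whenever the caller also passed norm_layer. Popping it from kwargs
    # applies the default only when no override was given.
    return _gen('demo', pretrained, norm_layer=kwargs.pop('norm_layer', 'gn'), **kwargs)

print(build())                   # {'name': 'demo', 'norm_layer': 'gn'}
print(build(norm_layer='ln'))    # caller override now works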
