
Commit 88b789c

latest transformer research shows bias is not needed and could be harmful
1 parent f8e2f2d commit 88b789c

2 files changed: +3 -4 lines

se3_transformer_pytorch/se3_transformer_pytorch.py

Lines changed: 2 additions & 3 deletions
@@ -124,7 +124,6 @@ def __init__(
         for degree, chan in fiber:
             self.transform[str(degree)] = nn.ParameterDict({
                 'scale': nn.Parameter(torch.ones(1, 1, chan)) if not gated_scale else None,
-                'bias': nn.Parameter(rand_uniform((1, 1, chan), -1e-3, 1e-3)),
                 'w_gate': nn.Parameter(rand_uniform((chan, chan), -1e-3, 1e-3)) if gated_scale else None
             })
 
@@ -137,14 +136,14 @@ def forward(self, features):
 
             # Transform on norms
             parameters = self.transform[degree]
-            gate_weights, bias, scale = parameters['w_gate'], parameters['bias'], parameters['scale']
+            gate_weights, scale = parameters['w_gate'], parameters['scale']
 
             transformed = rearrange(norm, '... () -> ...')
 
             if not exists(scale):
                 scale = einsum('b n d, d e -> b n e', transformed, gate_weights)
 
-            transformed = self.nonlin(transformed * scale + bias)
+            transformed = self.nonlin(transformed * scale)
             transformed = rearrange(transformed, '... -> ... ()')
 
             # Nonlinearity on norm
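
The change is easiest to see in isolation. Below is a minimal, runnable sketch of the per-degree norm transform after this commit, with the bias parameter and the `+ bias` term gone. `NormTransformSketch`, the standalone `rand_uniform` helper, and the default `nn.ReLU()` nonlinearity are illustrative stand-ins for the library's `NormSE3` internals, not its actual API.

    # Hypothetical sketch of the norm transform after this commit.
    # Only 'scale' (or the gated 'w_gate' path) remains; 'bias' is removed.
    import torch
    from torch import nn, einsum
    from einops import rearrange

    def rand_uniform(size, lo, hi):
        # assumed helper, mirroring the diff: uniform init in [lo, hi)
        return torch.empty(size).uniform_(lo, hi)

    class NormTransformSketch(nn.Module):
        def __init__(self, chan, gated_scale = False, nonlin = nn.ReLU()):
            super().__init__()
            self.nonlin = nonlin
            # per the diff: a learned scale, or a data-dependent gate, never both
            self.scale = nn.Parameter(torch.ones(1, 1, chan)) if not gated_scale else None
            self.w_gate = nn.Parameter(rand_uniform((chan, chan), -1e-3, 1e-3)) if gated_scale else None

        def forward(self, norm):
            # norm: (batch, seq, chan, 1) - non-negative feature norms for one degree
            transformed = rearrange(norm, '... () -> ...')
            scale = self.scale
            if scale is None:
                scale = einsum('b n d, d e -> b n e', transformed, self.w_gate)
            # bias term removed here, per the commit message
            transformed = self.nonlin(transformed * scale)
            return rearrange(transformed, '... -> ... ()')

    # usage
    m = NormTransformSketch(chan = 16)
    out = m(torch.randn(2, 8, 16, 1).abs())  # -> shape (2, 8, 16, 1)

Since the norm inputs are non-negative, a multiplicative scale alone preserves sign information entering the nonlinearity, which is consistent with the commit's rationale that the additive bias is unnecessary.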

setup.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
   name = 'se3-transformer-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '0.8.13',
+  version = '0.9.0',
   license='MIT',
   description = 'SE3 Transformer - Pytorch',
   author = 'Phil Wang',
