@@ -124,7 +124,6 @@ def __init__(
         for degree, chan in fiber:
             self.transform[str(degree)] = nn.ParameterDict({
                 'scale': nn.Parameter(torch.ones(1, 1, chan)) if not gated_scale else None,
-                'bias': nn.Parameter(rand_uniform((1, 1, chan), -1e-3, 1e-3)),
                 'w_gate': nn.Parameter(rand_uniform((chan, chan), -1e-3, 1e-3)) if gated_scale else None
             })

@@ -137,14 +136,14 @@ def forward(self, features):

             # Transform on norms
             parameters = self.transform[degree]
-            gate_weights, bias, scale = parameters['w_gate'], parameters['bias'], parameters['scale']
+            gate_weights, scale = parameters['w_gate'], parameters['scale']

             transformed = rearrange(norm, '... () -> ...')

             if not exists(scale):
                 scale = einsum('b n d, d e -> b n e', transformed, gate_weights)

-            transformed = self.nonlin(transformed * scale + bias)
+            transformed = self.nonlin(transformed * scale)

             transformed = rearrange(transformed, '... -> ... ()')

             # Nonlinearity on norm
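For reference, here is a minimal standalone sketch of the norm-transform logic as it reads after this change. The `NormTransform` class name, its constructor signature, the `nn.GELU()` default, and the inlined `exists`/`rand_uniform` helpers are assumptions introduced for illustration; the actual module iterates over a fiber of degrees and keeps one such parameter set per degree in `self.transform`.

```python
# Hypothetical standalone reduction of the per-degree norm transform above.
import torch
from torch import nn, einsum
from einops import rearrange

def exists(val):
    return val is not None

def rand_uniform(size, min_val, max_val):
    # uniform init in [min_val, max_val], matching the parameter init above
    return torch.empty(size).uniform_(min_val, max_val)

class NormTransform(nn.Module):
    def __init__(self, chan, gated_scale = False, nonlin = nn.GELU()):
        super().__init__()
        self.nonlin = nonlin
        # fixed per-channel scale, or a gate matrix that derives the scale
        self.scale = nn.Parameter(torch.ones(1, 1, chan)) if not gated_scale else None
        self.w_gate = nn.Parameter(rand_uniform((chan, chan), -1e-3, 1e-3)) if gated_scale else None

    def forward(self, norm):
        # norm: (batch, n, chan, 1) - invariant feature norms
        transformed = rearrange(norm, '... () -> ...')

        scale = self.scale
        if not exists(scale):
            # gated scale: compute per-channel scales from the norms themselves
            scale = einsum('b n d, d e -> b n e', transformed, self.w_gate)

        # bias term removed by this commit - pure rescaling of the norms
        transformed = self.nonlin(transformed * scale)
        return rearrange(transformed, '... -> ... ()')

norms = torch.randn(2, 8, 16, 1).abs()
out = NormTransform(16)(norms)
assert out.shape == norms.shape
```

One consequence of dropping the bias: a zero input norm now maps through `nonlin(0 * scale) = nonlin(0)`, which is zero for GELU/ReLU-style nonlinearities, so the transform stays a pure rescaling of the invariant norms.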