Hi, thanks for your wonderful work.
I have a question about the class EquivariantLayerNormV2 in /nets/layer_norm.py.
When computing the field mean with
field_mean = torch.mean(field, dim=1, keepdim=True)  # [batch, mul, 1]
should dim actually be -1? field_norm is computed with dim=-1 a few lines later.
Related code:
for mul, ir in self.irreps:  # mul is the multiplicity (number of copies) of some irrep type (ir)
    d = ir.dim
    field = node_input.narrow(1, ix, mul * d)
    ix += mul * d

    # [batch * sample, mul, repr]
    field = field.reshape(-1, mul, d)

    # For scalars first compute and subtract the mean
    if ir.l == 0 and ir.p == 1:
        # TODO: here the dim should be -1?
        field_mean = torch.mean(field, dim=1, keepdim=True)  # [batch, mul, 1]
        field = field - field_mean

    # Then compute the rescaling factor (norm of each feature vector)
    # Rescaling of the norms themselves based on the option "normalization"
    if self.normalization == 'norm':
        field_norm = field.pow(2).sum(-1)  # [batch * sample, mul]
    elif self.normalization == 'component':
        field_norm = field.pow(2).mean(-1)  # [batch * sample, mul]
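To make the shape question concrete, here is a minimal standalone sketch (not from the repo; the shapes follow the comments above, and the sizes are made up). For a scalar irrep (l=0) we have d == 1, so field has shape [batch * sample, mul, 1], and the two choices of dim reduce over different axes:

import torch

# Hypothetical sizes for illustration: batch * sample = 4, mul = 8, d = 1 (scalar irrep).
N, mul, d = 4, 8, 1
field = torch.randn(N, mul, d)

# Current code: average over the multiplicity (channel) dimension,
# giving one shared mean per sample -> shape [N, 1, 1].
mean_dim1 = torch.mean(field, dim=1, keepdim=True)
print(mean_dim1.shape)  # torch.Size([4, 1, 1])

# With dim=-1 (matching the field_norm lines): average over the irrep
# dimension, which for d == 1 returns the field unchanged -> shape [N, mul, 1].
mean_last = torch.mean(field, dim=-1, keepdim=True)
print(mean_last.shape)  # torch.Size([4, 8, 1])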