[QUESTION] about EquivariantLayerNormV2 #13

@kzhoa

Hi, thanks for your wonderful work.
I have a question about the class EquivariantLayerNormV2 in /nets/layer_norm.py.
The field mean is computed with
field_mean = torch.mean(field, dim=1, keepdim=True) # [batch, mul, 1]
Should dim here actually be -1?
I ask because field_norm is computed with dim=-1 a few lines later.

Related code:

for mul, ir in self.irreps:  # mul is the multiplicity (number of copies) of some irrep type (ir)
    d = ir.dim
    field = node_input.narrow(1, ix, mul*d)
    ix += mul * d

    # [batch * sample, mul, repr]
    field = field.reshape(-1, mul, d)

    # For scalars first compute and subtract the mean
    if ir.l == 0 and ir.p == 1:
        # TODO:  here the dim should be -1?
        field_mean = torch.mean(field, dim=1, keepdim=True) # [batch, mul, 1]
        field = field - field_mean

    # Then compute the rescaling factor (norm of each feature vector)
    # Rescaling of the norms themselves based on the option "normalization"
    if self.normalization == 'norm':
        field_norm = field.pow(2).sum(-1)  # [batch * sample, mul]
    elif self.normalization == 'component':
        field_norm = field.pow(2).mean(-1)  # [batch * sample, mul]
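
For reference, here is a minimal shape check of the two choices of dim. It is only a sketch, not the repository's code: it assumes field has shape [batch, mul, d] as in the reshape above, with d = 1 for scalar irreps (l = 0), and the sizes are made up for illustration.

```python
import torch

# Hypothetical sizes for illustration: 4 nodes, 8 copies of a scalar irrep (d = 1).
batch, mul, d = 4, 8, 1
torch.manual_seed(0)
field = torch.randn(batch, mul, d)

# dim=1 averages over the multiplicity (channel) axis.
mean_over_mul = torch.mean(field, dim=1, keepdim=True)    # shape [batch, 1, d]
# dim=-1 averages over the representation axis, which has size 1 for scalars.
mean_over_repr = torch.mean(field, dim=-1, keepdim=True)  # shape [batch, mul, 1]

centered_mul = field - mean_over_mul
centered_repr = field - mean_over_repr

print(mean_over_mul.shape)
print(mean_over_repr.shape)
# With d = 1, each element is its own mean along dim=-1, so centering gives zeros.
print(torch.allclose(centered_repr, torch.zeros_like(field)))
```

Under these assumptions, averaging over dim=-1 would zero out every scalar feature, while averaging over dim=1 removes the per-node mean across the mul channels. Note that the dim=1 result has shape [batch, 1, d], which differs from the [batch, mul, 1] written in the quoted comment.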
