diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py
index 64158ef99a..fcd1253e0f 100644
--- a/megatron/model/__init__.py
+++ b/megatron/model/__init__.py
@@ -13,7 +13,7 @@
     from .fused_rmsnorm import RMSNorm
 else:
     from .rmsnorm import RMSNorm
-    from torch.nn import LayerNorm
+    from .layernorm import LayerNorm
 
 from .distributed import DistributedDataParallel
 from .bert_model import BertModel
diff --git a/megatron/model/layernorm.py b/megatron/model/layernorm.py
new file mode 100644
index 0000000000..f560e6bf84
--- /dev/null
+++ b/megatron/model/layernorm.py
@@ -0,0 +1,21 @@
+import numbers
+
+import torch
+from torch.nn.parameter import Parameter
+
+class LayerNorm(torch.nn.Module):
+    def __init__(self, normalized_shape, eps: float = 1e-5, sequence_parallel=False):
+        super(LayerNorm, self).__init__()
+
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        self.normalized_shape = torch.Size(normalized_shape)
+        self.eps = eps
+        self.weight = Parameter(torch.ones(normalized_shape))
+        self.bias = Parameter(torch.zeros(normalized_shape))
+        self.sequence_parallel = sequence_parallel
+        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
+
+    def forward(self, x):
+        output = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        return output
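
For reference, a minimal usage sketch of the module added above (not part of the patch), assuming PyTorch is installed and the megatron package is importable:

# Illustrative only: the new LayerNorm wraps torch.nn.functional.layer_norm
# and additionally tags its weight with a `sequence_parallel` attribute.
import torch
from megatron.model.layernorm import LayerNorm  # file added by this patch

ln = LayerNorm(1024, eps=1e-5, sequence_parallel=True)
x = torch.randn(2, 8, 1024)
y = ln(x)

# Matches the functional layer_norm it delegates to.
ref = torch.nn.functional.layer_norm(x, ln.normalized_shape, ln.weight, ln.bias, ln.eps)
assert torch.equal(y, ref)

# Downstream code that handles sequence parallelism can detect the flag on the weight.
assert getattr(ln.weight, 'sequence_parallel', False) is True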