
Commit 44619fa

add sequence_parallel in layernorm init so that 3D parallelism can run successfully with DeepSpeed

1 parent 3e1da1f commit 44619fa

File tree

2 files changed (+22, -1 lines)


megatron/model/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
     from .fused_rmsnorm import RMSNorm
 else:
     from .rmsnorm import RMSNorm
-from torch.nn import LayerNorm
+from .layernorm import LayerNorm
 
 from .distributed import DistributedDataParallel
 from .bert_model import BertModel
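With this one-line change, code that does `from megatron.model import LayerNorm` now picks up the local class from megatron/model/layernorm.py instead of torch.nn.LayerNorm, so callers can pass the new keyword. A minimal call-site sketch (not part of the commit; the stock torch.nn.LayerNorm would raise a TypeError on the extra keyword):

    from megatron.model import LayerNorm  # now resolves to megatron/model/layernorm.py

    # torch.nn.LayerNorm has no such keyword; the local replacement accepts it.
    ln = LayerNorm(1024, eps=1e-5, sequence_parallel=True)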

megatron/model/layernorm.py
Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+import numbers
+
+import torch
+from torch.nn.parameter import Parameter
+
+class LayerNorm(torch.nn.Module):
+    def __init__(self, normalized_shape, eps: float = 1e-5, sequence_parallel=False):
+        super(LayerNorm, self).__init__()
+
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        self.normalized_shape = torch.Size(normalized_shape)
+        self.eps = eps
+        self.weight = Parameter(torch.ones(normalized_shape))
+        self.bias = Parameter(torch.zeros(normalized_shape))
+        self.sequence_parallel = sequence_parallel
+        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
+
+    def forward(self, x):
+        output = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        return output
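A minimal usage sketch (not part of the commit) of the new module, assuming it is importable as megatron.model.layernorm. It shows the practical effect of the change: the sequence_parallel flag is recorded on the weight parameter, which is the marker that sequence-parallel-aware code in Megatron/DeepSpeed typically checks when deciding which parameter gradients to all-reduce across the tensor-parallel group.

    import torch
    from megatron.model.layernorm import LayerNorm

    hidden_size = 1024
    ln = LayerNorm(hidden_size, eps=1e-5, sequence_parallel=True)

    # The commit tags only the weight (not the bias) with the flag;
    # downstream code can look it up with getattr.
    assert getattr(ln.weight, 'sequence_parallel', False)

    # Forward behaves like torch.nn.functional.layer_norm over the last dimension.
    x = torch.randn(2, 8, hidden_size)   # [batch, seq, hidden]
    y = ln(x)
    print(y.shape)                       # torch.Size([2, 8, 1024])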
