8 changes: 4 additions & 4 deletions internlm/core/parallel/comm/isp.py
@@ -145,7 +145,7 @@ class EmbeddingWeightParallelCommunicator:

    def __init__(self, parallel_mode: ParallelMode) -> None:
        self.parallel_mode = parallel_mode
-        self.emb_column = 1
+        self.vocab_dim = 0

Contributor: The tricky part here: if ISP models keep the original default split dimension while a vocab-split option is also added, how should the two be handled?

Contributor: Perhaps add a check when register_module_hook is called.

Contributor (Author): ok

        self._cur_micro_step = 0
        self._num_micro_step = gpc.config.data.micro_num
@@ -165,7 +165,7 @@ def forward(ctx, inputs: torch.Tensor): # pylint: disable=W0613
                if module.weight.evo_tensor is None:
                    module.weight.evo_tensor = module.weight.data

-                module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.emb_column)
+                module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.vocab_dim)
                inputs = inputs.detach()
                return inputs

@@ -188,7 +188,7 @@ def forward(ctx, output: torch.Tensor): # pylint: disable=W0613

            @staticmethod
            def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:  # pylint: disable=W0613
-                module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.emb_column)
+                module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.vocab_dim)
                return grad_output

        def _pre_forward_hook(module, inputs):  # pylint: disable=W0613
@@ -205,7 +205,7 @@ def _post_forward_hook(module, inputs, output): # pylint: disable=W0613
    def grad_reduce_hook(self, param: torch.Tensor):

        _grad, _ = reduce_scatter_raw(
-            param.grad, gpc.get_group(self.parallel_mode), op=dist.ReduceOp.AVG, reduce_dim=self.emb_column
+            param.grad, gpc.get_group(self.parallel_mode), op=dist.ReduceOp.AVG, reduce_dim=self.vocab_dim
        )
        if param.evo_tensor.grad is None:
            param.evo_tensor.grad = _grad
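One possible shape for the check suggested in the thread above, i.e. deciding the split dimension at register_module_hook time. This is only a sketch with hypothetical names (_WeightCommSketch, split_dim), not the PR's implementation; it assumes the communicator can read an optional vocab_parallel attribute off the module and fall back to the legacy embedding-dim split when the attribute is absent.

import torch.nn as nn


class _WeightCommSketch:
    """Toy stand-in for EmbeddingWeightParallelCommunicator (hypothetical, simplified)."""

    def register_module_hook(self, module: nn.Module) -> None:
        # Decide the gather/reduce-scatter dimension once, at registration time:
        # vocab-parallel embeddings are split along dim 0, legacy ones along dim 1.
        self.split_dim = 0 if getattr(module, "vocab_parallel", False) else 1


emb = nn.Embedding(8, 4)
emb.vocab_parallel = True  # the modeling code would set this when sharing weights with the head
comm = _WeightCommSketch()
comm.register_module_hook(emb)
print(comm.split_dim)  # 0; without the attribute it would stay 1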
15 changes: 13 additions & 2 deletions internlm/core/parallel/comm/tensor.py
@@ -19,6 +19,7 @@
    all_gather_raw,
    all_reduce_raw,
    gather_forward_split_backward,
+    reduce_forward,
    reduce_scatter_raw,
    split_forward_gather_backward,
)
@@ -341,7 +342,12 @@ def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tup
"""
_emb_dim = 2 # [bsz, seqlen, emb_dim]

return gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)
if module.vocab_parallel:
output = reduce_forward(output, self._parallel_mode)
else:
output = gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)

return output


class EmbeddingSequenceParallelCommunicator:
@@ -363,7 +369,12 @@ def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tup
"""
_emb_dim, _seq_dim = 2, 1 # [bsz, seqlen, emb_dim]

output = gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)
# tp:
if module.vocab_parallel:
output = reduce_forward(output, self._parallel_mode)
else:
output = gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)
# sp:
output = split_forward_gather_backward(output, self._parallel_mode, dim=_seq_dim)

return output
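Why the hooks branch: with the embedding-dim split every rank holds a slice of every embedding vector, so the full activation is recovered by concatenation (a gather), while with the vocab split every rank holds complete vectors for only its slice of ids and contributes zero rows elsewhere, so the full activation is the elementwise sum (an all-reduce). A minimal single-process illustration of that equivalence, with plain tensors and a Python loop standing in for the two ranks and the collectives:

import torch

torch.manual_seed(0)
full = torch.randn(6, 4)  # toy table: 6-token vocab, dim-4 embeddings
ids = torch.tensor([1, 4])

# Embedding-dim split across 2 "ranks": concatenating the partial outputs recovers the result.
dim_shards = full.chunk(2, dim=1)
gathered = torch.cat([shard[ids] for shard in dim_shards], dim=1)

# Vocab split across 2 "ranks": summing the partial outputs recovers the result,
# because out-of-range ids contribute zero rows on each rank.
rows = full.size(0) // 2
summed = torch.zeros(2, 4)
for rank, vocab_shard in enumerate(full.chunk(2, dim=0)):
    lo, hi = rank * rows, (rank + 1) * rows
    mask = (ids >= lo) & (ids < hi)
    partial = torch.zeros(2, 4)
    partial[mask] = vocab_shard[ids[mask] - lo]
    summed += partial  # what dist.all_reduce would do across real ranks

assert torch.equal(gathered, full[ids]) and torch.allclose(summed, full[ids])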
37 changes: 37 additions & 0 deletions internlm/core/parallel/comm/utils.py
@@ -117,6 +117,17 @@ def _gather(input_, parallel_mode, dim=-1):
    return output


+def _reduce(input_, parallel_mode):
+    # skip if only one rank involved
+    if gpc.get_world_size(parallel_mode) == 1:
+        return input_

+    group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
+    dist.all_reduce(input_, group=group)

+    return input_


class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather the input from model parallel region and concatenate.

@@ -174,6 +185,32 @@ def split_forward_gather_backward(input_, parallel_mode, dim):
    return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)


+class _ReduceForward(torch.autograd.Function):
+    """
+    All-reduce the input from the model parallel region.

+    Args:
+        input_: input matrix.
+        parallel_mode: parallel mode.
+    """

+    @staticmethod
+    def symbolic(input_):
+        return _reduce(input_, parallel_mode=None)

+    @staticmethod
+    def forward(ctx, input_, parallel_mode):  # pylint: disable=W0613
+        return _reduce(input_, parallel_mode)

+    @staticmethod
+    def backward(ctx, grad_output):  # pylint: disable=W0613
+        return grad_output, None


+def reduce_forward(input_, parallel_mode):
+    return _ReduceForward.apply(input_, parallel_mode)


def all_gather_raw(
    input_: Tensor,
    process_group: ProcessGroup,
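Note that _ReduceForward's backward is deliberately an identity (plus None for the parallel_mode argument): once the forward has summed the partial activations, every rank already needs the full upstream gradient, so nothing is communicated on the way back. Below is a toy single-process analogue of that contract, with summation standing in for dist.all_reduce; the names are illustrative, not part of the library.

import torch


class _ToyReduceForward(torch.autograd.Function):
    """Single-process stand-in: forward sums per-"rank" tensors, backward passes grads through."""

    @staticmethod
    def forward(ctx, *shards):
        return torch.stack(shards).sum(dim=0)  # plays the role of dist.all_reduce

    @staticmethod
    def backward(ctx, grad_output):
        # Identity for every input: each shard receives the full upstream gradient.
        return tuple(grad_output for _ in ctx.needs_input_grad)


a = torch.ones(2, 3, requires_grad=True)
b = torch.full((2, 3), 2.0, requires_grad=True)
out = _ToyReduceForward.apply(a, b)
out.sum().backward()
print(a.grad)  # all ones: the gradient reaches each shard unchanged

In the real code path, the Embedding1D output hooks above call reduce_forward(output, self._parallel_mode), which dispatches to _ReduceForward.apply.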
52 changes: 47 additions & 5 deletions internlm/model/modules/embedding.py
@@ -8,6 +8,7 @@
from einops import rearrange
from torch import Tensor, nn

+from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.ops.rotary_emb import apply_rotary_emb
from internlm.utils.parallel import is_using_isp
@@ -33,6 +34,7 @@ def __init__(
        *args,
        padding_idx: int = None,
        dtype: torch.dtype = None,
+        vocab_parallel: bool = False,
        **kwargs,
    ):
        super().__init__()
@@ -42,14 +44,54 @@ def __init__(
        self.padding_idx = padding_idx
        self.embed_args = args
        self.embed_kwargs = kwargs
+        self.vocab_parallel = vocab_parallel

+        if is_using_isp():
+            # isp: split vocab_size to support the sharing of parameters between embedding and head.
+            assert (
+                num_embeddings % gpc.weight_parallel_size == 0
+            ), f"{num_embeddings} is not divisible by {gpc.weight_parallel_size}"
+            self.num_embeddings_per_partition = num_embeddings // gpc.weight_parallel_size

Contributor: This assumes ISP mode always uses vocab_parallel. Could vocab_parallel instead be the single switch, tied only to whether the embedding and head share weights? If a user needs weight sharing, they set vocab_parallel=True in the modeling file; in every other case the previous embedding-dim split stays the default. That avoids breaking backward compatibility in existing code, especially the Llama models that load HF weights, which all rely on splitting the embedding dimension.

Contributor: For example, that CI failure.

Contributor (Author): ok

+            self.embed_dim_per_partition = embedding_dim
+        elif vocab_parallel:

+            assert (
+                num_embeddings % gpc.tensor_parallel_size == 0
+            ), f"{num_embeddings} is not divisible by {gpc.tensor_parallel_size}"

+            self.num_embeddings_per_partition = num_embeddings // gpc.tensor_parallel_size
+            self.embed_dim_per_partition = embedding_dim
+            self.vocab_start_index = gpc.get_local_rank(ParallelMode.TENSOR) * self.num_embeddings_per_partition
+            self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
+        else:
+            # mtp/msp/fsp: do not support the sharing of parameters between embedding and head,
+            # use VocabParallelEmbedding1D instead.
+            assert (
+                embedding_dim % gpc.tensor_parallel_size == 0
+            ), f"{embedding_dim} is not divisible by {gpc.tensor_parallel_size}"
+            self.num_embeddings_per_partition = num_embeddings
+            self.embed_dim_per_partition = embedding_dim // gpc.tensor_parallel_size

+        self.weight = nn.Parameter(
+            torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), dtype=dtype)
+        )

+    def forward(self, input_: Tensor) -> Tensor:
+        if self.vocab_parallel:
+            # Build the mask.
+            input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)

Contributor: The is_using_isp branch above splits with self.num_embeddings_per_partition = num_embeddings // gpc.weight_parallel_size, but it never sets up vocab_start_index and the related fields.

Contributor (Author): ISP gathers the full parameter, so it must not take the vocab_parallel code path; it has to keep the original logic.

+            # Mask the input.
+            masked_input = input_.clone() - self.vocab_start_index
+            masked_input[input_mask] = 0
+        else:
+            masked_input = input_

-        _parallel_size = gpc.weight_parallel_size if is_using_isp() else gpc.tensor_parallel_size
+        output = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

-        embed_dim_per_partition = embedding_dim // _parallel_size
-        self.weight = nn.Parameter(torch.empty((num_embeddings, embed_dim_per_partition), dtype=dtype))
+        if self.vocab_parallel:
+            output[input_mask, :] = 0.0

-    def forward(self, input_: Tensor) -> Tensor:
-        return F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
+        return output


class RotaryEmbedding(torch.nn.Module):
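For reference, a small single-process walk-through of the masking logic forward performs when a model opts in with vocab_parallel=True, simulating two tensor-parallel ranks by hand. The shapes and values are invented for the example, and the final Python-level sum stands in for the reduce_forward call issued by the output hook:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, dim, tp = 8, 4, 2
full_weight = torch.randn(vocab, dim)  # the unsharded embedding table
input_ = torch.tensor([[0, 5, 3, 7]])  # [bsz=1, seqlen=4]

partials = []
rows_per_rank = vocab // tp
for rank in range(tp):
    vocab_start = rank * rows_per_rank
    vocab_end = vocab_start + rows_per_rank
    weight = full_weight[vocab_start:vocab_end]                  # this rank's vocab shard
    input_mask = (input_ < vocab_start) | (input_ >= vocab_end)  # ids owned by other ranks
    masked_input = input_.clone() - vocab_start                  # shift to local row indices
    masked_input[input_mask] = 0                                 # any in-range dummy index
    output = F.embedding(masked_input, weight)
    output[input_mask, :] = 0.0                                  # zero out the dummy lookups
    partials.append(output)

# reduce_forward would all-reduce the partials across ranks; here we just sum them.
assert torch.allclose(sum(partials), F.embedding(input_, full_weight))

As the review threads note, the ISP branch instead gathers the full weight and keeps the original lookup, and mtp/msp/fsp models that do not set vocab_parallel keep the old embedding-dim split.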