add vacab parallel embedding #315
Changes from 2 commits
@@ -8,6 +8,7 @@
 from einops import rearrange
 from torch import Tensor, nn

+from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.ops.rotary_emb import apply_rotary_emb
 from internlm.utils.parallel import is_using_isp
@@ -33,6 +34,7 @@ def __init__(
         *args,
         padding_idx: int = None,
         dtype: torch.dtype = None,
+        vocab_parallel: bool = False,
         **kwargs,
     ):
         super().__init__()
@@ -42,14 +44,54 @@ def __init__(
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
+        self.vocab_parallel = vocab_parallel

-        _parallel_size = gpc.weight_parallel_size if is_using_isp() else gpc.tensor_parallel_size
-
-        embed_dim_per_partition = embedding_dim // _parallel_size
-        self.weight = nn.Parameter(torch.empty((num_embeddings, embed_dim_per_partition), dtype=dtype))
+        if is_using_isp():
+            # isp: split vocab_size to support the sharing of parameters between embedding and head.
+            assert (
+                num_embeddings % gpc.weight_parallel_size == 0
+            ), f"{num_embeddings} is not divisible by {gpc.weight_parallel_size}"
+            self.num_embeddings_per_partition = num_embeddings // gpc.weight_parallel_size
+            self.embed_dim_per_partition = embedding_dim
+        elif vocab_parallel:
+            assert (
+                num_embeddings % gpc.tensor_parallel_size == 0
+            ), f"{num_embeddings} is not divisible by {gpc.tensor_parallel_size}"
+            self.num_embeddings_per_partition = num_embeddings // gpc.tensor_parallel_size
+            self.embed_dim_per_partition = embedding_dim
+            self.vocab_start_index = gpc.get_local_rank(ParallelMode.TENSOR) * self.num_embeddings_per_partition
+            self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
+        else:
+            # mtp/msp/fsp: do not support the sharing of parameters between embedding and head,
+            # use VocabParallelEmbedding1D instead.
+            assert (
+                embedding_dim % gpc.tensor_parallel_size == 0
+            ), f"{embedding_dim} is not divisible by {gpc.tensor_parallel_size}"
+            self.num_embeddings_per_partition = num_embeddings
+            self.embed_dim_per_partition = embedding_dim // gpc.tensor_parallel_size

-    def forward(self, input_: Tensor) -> Tensor:
-        return F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
+        self.weight = nn.Parameter(
+            torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), dtype=dtype)
+        )

+    def forward(self, input_: Tensor) -> Tensor:
+        if self.vocab_parallel:
+            # Build the mask.
+            input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
Review comment: The is_using_isp() branch above splits the vocabulary as self.num_embeddings_per_partition = num_embeddings // gpc.weight_parallel_size, but it never builds vocab_start_index / vocab_end_index.

Reply: Under ISP the parameters are gathered before use, so the embedding cannot take the vocab_parallel code path; it has to keep the original logic.
+
+            # Mask the input.
+            masked_input = input_.clone() - self.vocab_start_index
+            masked_input[input_mask] = 0
+        else:
+            masked_input = input_
+
+        output = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
+
+        if self.vocab_parallel:
+            output[input_mask, :] = 0.0
+
+        return output


 class RotaryEmbedding(torch.nn.Module):
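For readers outside this thread: the forward pass above is the usual vocab-parallel lookup pattern. Each tensor-parallel rank owns a contiguous slice of the vocabulary; token ids outside that slice are remapped to local index 0, looked up, and then zeroed, so that summing the partial outputs across ranks recovers the full embedding. The diff does not show that cross-rank reduction, which is presumably handled elsewhere in the codebase. As a concrete shape example, with num_embeddings=50304, embedding_dim=4096 and tensor_parallel_size=4, the vocab_parallel branch gives each rank a (12576, 4096) weight shard, while the default mtp/msp/fsp branch gives (50304, 1024). Below is a minimal, self-contained sketch of the same pattern, assuming plain torch.distributed; the function name, the tp_group argument, and the explicit all_reduce are illustrative assumptions, not code from this PR.

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor


def vocab_parallel_embedding_lookup(
    input_: Tensor,        # token ids, any shape
    weight_shard: Tensor,  # (vocab_size // world_size, embedding_dim) slice owned by this rank
    rank: int,
    world_size: int,
    tp_group=None,         # tensor-parallel process group (assumption, not from the PR)
) -> Tensor:
    vocab_per_rank = weight_shard.shape[0]
    vocab_start = rank * vocab_per_rank
    vocab_end = vocab_start + vocab_per_rank

    # Tokens outside this rank's vocab slice are remapped to local index 0 ...
    input_mask = (input_ < vocab_start) | (input_ >= vocab_end)
    masked_input = input_.clone() - vocab_start
    masked_input[input_mask] = 0

    output = F.embedding(masked_input, weight_shard)
    # ... and their rows are zeroed so they contribute nothing to the sum below.
    output[input_mask, :] = 0.0

    # Exactly one rank produced a non-zero row for each token, so summing the partial
    # results across the tensor-parallel group recovers the full embedding. Real
    # implementations wrap this collective in an autograd-aware op; this sketch only
    # illustrates the forward combine.
    if dist.is_initialized() and world_size > 1:
        dist.all_reduce(output, op=dist.ReduceOp.SUM, group=tp_group)
    return output


if __name__ == "__main__":
    # Single-process check: rank 0 of world_size 2 owns ids 0-3.
    ids = torch.tensor([[1, 5, 3]])
    shard = torch.randn(4, 8)
    out = vocab_parallel_embedding_lookup(ids, shard, rank=0, world_size=2)
    # The row for id 5 is all zeros here, because id 5 belongs to rank 1.
    assert torch.all(out[0, 1] == 0)

With world_size=2 and an 8-token vocabulary, rank 0 owns ids 0-3 and rank 1 owns ids 4-7; token id 5 becomes local index 1 on rank 1 (where its row is actually read) and is masked and zeroed on rank 0, so the all_reduce restores the correct row for every token.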
Review comment: The awkward part here is what to do when an ISP model keeps the original default split dimension while the vocab-split option is also available.

Reply: We could probably add a check when register_module_hook is called.

Reply: ok
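A sketch of the kind of check discussed above, since under the current __init__ the is_using_isp() branch takes priority and vocab_parallel=True would be silently ignored. The helper name and its call site are assumptions for illustration only; register_module_hook's actual signature is not shown in this PR.

# Illustrative only: a guard that could run where embedding modules are registered,
# rejecting the vocab_parallel flag when ISP is active (ISP keeps its own split logic).
# `check_embedding_split_mode` is a hypothetical helper, not existing internlm API.
from internlm.utils.parallel import is_using_isp


def check_embedding_split_mode(embedding) -> None:
    if is_using_isp() and getattr(embedding, "vocab_parallel", False):
        raise ValueError(
            "vocab_parallel=True has no effect under ISP: the ISP branch gathers the "
            "full weight and keeps its original split, so drop the flag or disable ISP."
        )

Calling such a guard right before the hooks are registered for each Embedding1D instance would surface the conflict early instead of silently ignoring the flag.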