From cd2798770c8b9f8cf63499442f998f0199d6cec0 Mon Sep 17 00:00:00 2001 From: sallyjunjun Date: Mon, 10 Feb 2025 11:00:33 +0800 Subject: [PATCH 1/5] add tp overlap feature --- configs/7B_sft.py | 10 +- internlm/core/parallel/shard.py | 10 +- .../core/scheduler/pipeline_scheduler_zb.py | 1 - internlm/core/trainer_builder.py | 27 +++ internlm/initialize/launch.py | 23 ++- internlm/model/modeling_internlm2.py | 29 ++- internlm/model/modules/linear.py | 187 ++++++++++++++++++ internlm/model/modules/mha.py | 25 ++- internlm/model/modules/mlp.py | 22 ++- internlm/model/modules/utils.py | 25 ++- internlm/train/pipeline.py | 4 +- 11 files changed, 338 insertions(+), 25 deletions(-) diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 4799b5f35..8cc8e4fd1 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -192,7 +192,15 @@ """ parallel = dict( zero1=dict(size=-1), - tensor=dict(size=1, mode="mtp"), + tensor=dict( + size=1, + mode="mtp", + tp_overlap=False, + tp_overlap_cfg=dict( + tp_comm_overlap_ag=True, + tp_comm_overlap_rs=True, + ), + ), pipeline=dict(size=1, interleaved_overlap=True), weight=dict(size=1, overlap=True), ) diff --git a/internlm/core/parallel/shard.py b/internlm/core/parallel/shard.py index 4dc6a1f5b..845db7908 100644 --- a/internlm/core/parallel/shard.py +++ b/internlm/core/parallel/shard.py @@ -161,13 +161,19 @@ def get_parallel_strategies_split_mode(linear_name: str) -> str: if linear_name in ("gate"): return "gate" # for MoE model elif linear_name in ("wqkv", "wq", "wk", "wv", "wkv", "w1", "w3", "w13"): - return "column" + if gpc.config.parallel.tensor.tp_overlap: + return "tecolumn" + else: + return "column" elif linear_name in ("fc1", "fc2", "linear_1", "linear_2"): # for vit model return "column" elif linear_name in ("wo", "out_proj", "w2") and tp_mode == TensorParallelMode.isp.name: return "column" elif linear_name in ("wo", "out_proj", "w2"): - return "row" + if gpc.config.parallel.tensor.tp_overlap: + return "terow" + else: + return "row" elif linear_name in ("grouped_w1", "grouped_w2", "grouped_w3") and tp_mode == "isp": return "grouped_wp" elif linear_name in ("grouped_w1", "grouped_w3"): diff --git a/internlm/core/scheduler/pipeline_scheduler_zb.py b/internlm/core/scheduler/pipeline_scheduler_zb.py index 75cf18448..9cd2041ad 100644 --- a/internlm/core/scheduler/pipeline_scheduler_zb.py +++ b/internlm/core/scheduler/pipeline_scheduler_zb.py @@ -901,7 +901,6 @@ def _run_steady_loop( else: next_unit_chunk_id = 1 - # import pdb; pdb.set_trace() if unit_step == num_units_stage1 - 1: chunk0_B_need_recv_prev_chunk0_output = False else: diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index 71c30d00d..b0f744f01 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -19,6 +19,7 @@ from internlm.initialize.initialize_trainer import initialize_trainer from internlm.model.losses.ce_loss import InternLoss from internlm.model.metrics import AccPerplex +from internlm.model.modules.utils import is_te_min_version from internlm.monitor.monitor import send_alert_message from internlm.train.pipeline import ( get_scheduler_hooks, @@ -154,6 +155,9 @@ def __init__( scheduler_hooks=get_scheduler_hooks(self.metric, optimizer, isp_communicator), ) + if gpc.config.parallel["tensor"]["tp_overlap"]: + self._initialize_tp_comm_ub() + # set attributes self._set_attributes( kwargs["profiling"], train_dl, val_dls, train_state, optimizer, beta2_scheduler, isp_communicator @@ -249,6 +253,29 @@ def 
_initialize_batch_skipper(self, train_state) -> BatchSkipper: skip_batches = streaming_simple_resume(train_state) return BatchSkipper(skip_batches) + def _initialize_tp_comm_ub(self): + """ initializing the communicators with user buffers for high-performance tensor-model-parallel + communication overlap """ + try: + import transformer_engine + from transformer_engine.pytorch import module as te_module + + except ImportError: + raise RuntimeError("Tensor Parallel Communication/GEMM Overlap optimization needs 'transformer_engine' package") + + input_shape = [gpc.config.data["seq_len"] * gpc.config.data["micro_bsz"], gpc.config.model["hidden_size"]] + + if is_te_min_version("1.9.0"): + # The process group with the target bootstrap backend is created in Transformer Engine. + te_module.base.initialize_ub(shape = input_shape, tp_size = gpc.config.parallel["tensor"]["size"], + use_fp8 = False, bootstrap_backend = 'nccl') + else: + # Create a MPI process group to help with TP communication overlap bootstrap. + torch.distributed.new_group(backend='mpi') + + te_module.base.initialize_ub(shape = input_shape, tp_size = gpc.config.parallel["tensor"]["size"], + use_fp8 = False) + def _set_attributes(self, profiling, train_dl, val_dls, train_state, optimizer, beta2_scheduler, isp_communicator): self.profiling = profiling self.train_dl = train_dl diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index e1cb2f0d2..6a049bab8 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -101,7 +101,7 @@ def args_sanity_check(): gpc.config.parallel.pipeline._add_item("mode", "1F1B") if "tensor" not in gpc.config.parallel: - gpc.config.parallel._add_item("tensor", dict(size=1, mode=TensorParallelMode.mtp.name)) + gpc.config.parallel._add_item("tensor", dict(size=1, mode=TensorParallelMode.mtp.name, tp_overlap=False, tp_overlap_cfg=None)) if "weight" not in gpc.config.parallel: gpc.config.parallel._add_item( @@ -398,7 +398,7 @@ def args_sanity_check(): # set default value for tensor parallel if isinstance(gpc.config.parallel["tensor"], int): - gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name) + gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name, tp_overlap=False, tp_overlap_cfg=None) if gpc.config.parallel["tensor"].get("mode", None) is None: gpc.config.parallel["tensor"]["mode"] = TensorParallelMode.mtp.name if gpc.config.parallel["tensor"]["mode"] == TensorParallelMode.isp.name: @@ -455,6 +455,25 @@ def args_sanity_check(): if gpc.config.model.get("parallel_output", False) is False: logger.warning("When enable sequence parallel, it recommend to enable parallel_output") + if gpc.config.parallel["tensor"].get("tp_overlap", None) is None: + gpc.config.parallel["tensor"]["tp_overlap"] = False + elif gpc.config.parallel["tensor"].get("tp_overlap", None) is True: + assert ( + gpc.config.parallel["tensor"].get("mode", None) in [ + TensorParallelMode.msp.name, + TensorParallelMode.fsp.name, + ] + ), "tp_overlap can be set to true only in msp and fsp mode" + + if gpc.config.parallel["tensor"].get("tp_overlap_cfg", None) is None: + gpc.config.parallel["tensor"]["tp_overlap_cfg"] = dict( + tp_comm_overlap_ag=True, + tp_comm_overlap_rs=True, + tp_comm_bulk_wgrad=True, + tp_comm_bulk_dgrad=True, + tp_comm_overlap_rs_dgrad=False + ) + # set default value for weight parallel if gpc.config.parallel["weight"].get("overlap", None) is None: 
gpc.config.parallel["weight"]["overlap"] = False diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py index 0453b9dcb..bee7e4f98 100644 --- a/internlm/model/modeling_internlm2.py +++ b/internlm/model/modeling_internlm2.py @@ -535,11 +535,30 @@ def load_hf_weights(folder: str, model: nn.Module) -> None: dim=0, )[local_rank] else: - state_dict[f"layers.{i}.attention.wqkv.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.attention.wqkv.weight"), - split_size, - dim=0, - )[local_rank] + key = f"model.layers.{layer_ids}.attention.wqkv.weight" + if key in state_dict: + state_dict[f"layers.{i}.attention.wqkv.weight"] = torch.chunk( + state_dict.pop(key), + split_size, + dim=0, + )[local_rank] + else: + wq = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.attention.wq.weight"), + split_size, + dim=0, + )[local_rank] + wk = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.attention.wk.weight"), + split_size, + dim=0, + )[local_rank] + wv = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.attention.wv.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wqkv.weight"] = torch.cat([wq, wk, wv], dim=0) wo_name = "self_attn.o_proj" if is_internlm3 else "attention.wo" state_dict[f"layers.{i}.attention.wo.weight"] = torch.chunk( state_dict.pop(f"model.layers.{layer_ids}.{wo_name}.weight"), diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 6f1902689..043825812 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -10,6 +10,7 @@ import torch import torch.distributed as dist from torch import nn +import transformer_engine as te from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode @@ -25,6 +26,7 @@ linear_backward_op, linear_forward_op, ) +from internlm.model.modules.utils import is_te_min_version from internlm.utils.logger import get_logger if TYPE_CHECKING: @@ -1009,6 +1011,172 @@ def __init__( self.full_weight_shape = torch.Size((num_groups, in_features, out_features)) +class TEColumnParallelLinear(te.pytorch.LayerNormLinear): + """ + Wrapper for the Transformer-Engine's `LayerNormLinear` layer. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + skip_bias_add: bool, + is_expert: bool, + tp_comm_buffer_name: str = None, + ): + if is_expert: + raise ValueError('Transformer Engine linear layers do not yet support MoE') + + # TE returns a zero length Tensor when bias=False and + # return_bias=True, but we prefer None. So in that case we + # tell TE to not return the bias, and return None + # ourselves. This way our forward always returns two values + # and we don't have to deal with the zero length Tensor. + self.te_return_bias = skip_bias_add and bias + self.is_first_microbatch = True + + extra_kwargs = {"params_dtype": gpc.config.model.dtype} + if is_te_min_version("0.12.0"): + extra_kwargs["device"] = torch.cuda.current_device() + + # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm` + # if is_te_min_version("0.11.0"): + # extra_kwargs["normalization"] = self.config.normalization + # elif self.config.normalization != "LayerNorm": + # te_version = get_te_version() + # raise ValueError( + # f"Transformer Engine v{te_version} does not support {self.config.normalization}." 
+ # ) + + if gpc.config.parallel["tensor"]["tp_overlap"]: + extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get("tp_comm_bulk_wgrad", True) + extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get("tp_comm_bulk_dgrad", True) + if is_te_min_version("1.5.0", check_equality=False): + extra_kwargs["ub_overlap_ag"] = ( + gpc.config.parallel["tensor"]["tp_overlap_cfg"].get("tp_comm_overlap_ag", True) + ) + if is_te_min_version("1.6.0.dev0", check_equality=False): + extra_kwargs["ub_overlap_rs_dgrad"] = ( + gpc.config.parallel["tensor"]["tp_overlap_cfg"].get("tp_comm_overlap_rs_dgrad", False) + ) + else: + raise NotImplementedError('tp overlap is supported only when transformer_engine version >= 1.5.0') + assert ( + tp_comm_buffer_name is not None + ), "Buffer name should be set to configure communication overlap settings" + extra_kwargs["ub_name"] = tp_comm_buffer_name + + parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert) + tp_size = gpc.get_world_size(parallel_mode) + tp_group = gpc.get_group(parallel_mode) + super().__init__( + in_features=in_features, + out_features=out_features, + sequence_parallel=gpc.config.parallel.sequence_parallel, + tp_group=tp_group, + tp_size=tp_size, + bias=bias, + return_bias=self.te_return_bias, + parallel_mode="column", + return_layernorm_output=False, + **extra_kwargs, + ) + + def forward(self, x): + """Forward.""" + _is_first_microbatch = self.is_first_microbatch + x = x.transpose(0, 1) + out = super().forward(x, is_first_microbatch=_is_first_microbatch) + out = out.transpose(0, 1) + + self.is_first_microbatch = False + + return out + + +class TERowParallelLinear(te.pytorch.Linear): + """ + Wrapper for the Transformer-Engine's `Linear` layer. + """ + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + skip_bias_add: bool, + is_expert: bool = False, + tp_comm_buffer_name: str = None, + ): + # TE returns a zero length Tensor when bias=False and + # return_bias=True. Here we need a single Tensor + self.te_return_bias = skip_bias_add and bias + self.is_first_microbatch = True + + extra_kwargs = {"params_dtype": gpc.config.model.dtype} + if is_te_min_version("0.12.0"): + extra_kwargs["device"] = torch.cuda.current_device() + + if gpc.config.parallel["tensor"]["tp_overlap"]: + if is_te_min_version("1.5.0"): + extra_kwargs["ub_overlap_ag"] = ( + gpc.config.parallel["tensor"]["tp_overlap_cfg"].get("tp_comm_overlap_ag", True) + ) + extra_kwargs["ub_overlap_rs"] = ( + gpc.config.parallel["tensor"]["tp_overlap_cfg"].get("tp_comm_overlap_rs", True) + ) + # Disable ub overlap for experts. + if is_expert: + extra_kwargs["ub_overlap_ag"] = False + extra_kwargs["ub_overlap_rs"] = False + else: + raise NotImplementedError('tp overlap is supported only when transformer_engine version >= 1.5.0') + assert ( + tp_comm_buffer_name is not None + ), "Buffer name should be set to configure communication overlap settings" + extra_kwargs["ub_name"] = tp_comm_buffer_name + + self.expert_parallel = gpc.config.parallel["expert"].get("size", 1) > 1 + parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert) + # Disable communications in TE when using TP or EP by making TE agnostic of model parallel. 
+ tp_size = gpc.get_world_size(parallel_mode) + tp_group = gpc.get_group(parallel_mode) + explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel) + + split_mode = "row" + if explicit_expert_comm: + assert in_features % tp_size == 0, "{} is not divisible by {}".format(in_features, tp_size) + in_features = in_features // tp_size + split_mode = None + tp_size = 1 + tp_group = None + + super().__init__( + in_features=in_features, + out_features=out_features, + sequence_parallel=gpc.config.parallel.sequence_parallel, + tp_group=tp_group, + tp_size=tp_size, + bias=bias, + return_bias=self.te_return_bias, + parallel_mode=split_mode, + **extra_kwargs, + ) + + for param in self.parameters(): + setattr(param, 'allreduce', not (is_expert and self.expert_parallel)) + + def forward(self, x): + """Forward.""" + _is_first_microbatch = self.is_first_microbatch + x = x.transpose(0, 1) + out = super().forward(x, is_first_microbatch=_is_first_microbatch) + out = out.transpose(0, 1) + self.is_first_microbatch = False + + return out + + def new_linear( name: str, in_features: int, @@ -1021,6 +1189,7 @@ def new_linear( weight_scale: int = 1, norm_head: bool = False, is_expert: bool = False, + tp_comm_buffer_name: str = None, **kwargs, ) -> nn.Linear: @@ -1067,6 +1236,15 @@ def new_linear( dtype, is_expert, ) + elif split_mode == "tecolumn": + return TEColumnParallelLinear( + in_features, + out_features, + bias, + False, + is_expert, + tp_comm_buffer_name, + ) elif split_mode == "row": return RowParallelLinear( in_features, @@ -1077,6 +1255,15 @@ def new_linear( dtype, is_expert, ) + elif split_mode == "terow": + return TERowParallelLinear( + in_features, + out_features, + bias, + False, + is_expert, + tp_comm_buffer_name, + ) elif split_mode == "grouped_wp": return GroupedWPLinear( in_features, diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py index 42418a212..f814944d7 100644 --- a/internlm/model/modules/mha.py +++ b/internlm/model/modules/mha.py @@ -157,8 +157,12 @@ def __init__( if self.enable_qkv_fusion: # bias=True is according to https://spaces.ac.cn/archives/9577 - self.wqkv = new_linear("wqkv", embed_dim, 3 * embed_dim, bias, **factory_kwargs) + if gpc.config.parallel["tensor"]["tp_overlap"]: + self.wqkv = new_linear("wqkv", embed_dim, 3 * embed_dim, bias, tp_comm_buffer_name="qkv", **factory_kwargs) + else: + self.wqkv = new_linear("wqkv", embed_dim, 3 * embed_dim, bias, **factory_kwargs) else: + assert gpc.config.parallel["tensor"]["tp_overlap"] is False, "tp overlap currently only support fused wqkv." 
self.wq = new_linear("wq", embed_dim, embed_dim, bias, **factory_kwargs) self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs) self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs) @@ -167,7 +171,10 @@ def __init__( self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) # output projection always have the bias (for now) (except for baichuan2 model) - self.out_proj = new_linear("out_proj", embed_dim, embed_dim, bias=out_bias, **factory_kwargs) + if gpc.config.parallel["tensor"]["tp_overlap"]: + self.out_proj = new_linear("out_proj", embed_dim, embed_dim, bias=out_bias, tp_comm_buffer_name="proj", **factory_kwargs) + else: + self.out_proj = new_linear("out_proj", embed_dim, embed_dim, bias=out_bias, **factory_kwargs) def register_checkpoint_compatibility_hooks( self, pre_load_hook: Optional[Callable] = None, pre_save_hook: Optional[Callable] = None @@ -193,7 +200,6 @@ def _training(self, x, **kwargs): if self.enable_qkv_fusion: qkv = self.wqkv(x) qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim) - q = qkv[:, :, 0].squeeze(2) k = qkv[:, :, 1].squeeze(2) v = qkv[:, :, 2].squeeze(2) @@ -214,8 +220,8 @@ def _training(self, x, **kwargs): if gpc.config.data.use_packed_dataset is False or self.training is False: kwargs.pop("max_seqlen_q", None) kwargs.pop("max_seqlen_k", None) - context = self.inner_attn(q, k, v, **kwargs) + context = self.inner_attn(q, k, v, **kwargs) # wo return self.out_proj(rearrange(context, "b s h d -> b s (h d)")) @@ -461,12 +467,16 @@ def __init__( if enable_qkv_fusion: assert bias is False, "Fuesd wqkv only support bias is False." - self.wqkv = new_linear("wqkv", embed_dim, q_dim + 2 * self.kv_dim, bias, **factory_kwargs) + if gpc.config.parallel["tensor"]["tp_overlap"]: + self.wqkv = new_linear("wqkv", embed_dim, q_dim + 2 * self.kv_dim, bias, tp_comm_buffer_name="qkv", **factory_kwargs) + else: + self.wqkv = new_linear("wqkv", embed_dim, q_dim + 2 * self.kv_dim, bias, **factory_kwargs) self._register_load_state_dict_pre_hook( partial(_qkv_pre_load_convert, q_dim=q_dim, kv_dim=self.kv_dim), with_module=True ) self._register_state_dict_hook(partial(_qkv_save_convert, q_dim=q_dim, kv_dim=self.kv_dim)) else: + assert gpc.config.parallel["tensor"]["tp_overlap"] is False, "tp overlap currently only support fused wqkv." 
self.wq = new_linear("wq", embed_dim, q_dim, bias, **factory_kwargs) self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs) self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs) @@ -478,7 +488,10 @@ def __init__( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout, layer_idx=layer_idx ) - self.wo = new_linear("wo", q_dim, embed_dim, bias, **factory_kwargs) + if gpc.config.parallel["tensor"]["tp_overlap"]: + self.wo = new_linear("wo", q_dim, embed_dim, bias, tp_comm_buffer_name="proj", **factory_kwargs) + else: + self.wo = new_linear("wo", q_dim, embed_dim, bias, **factory_kwargs) def register_checkpoint_compatibility_hooks( self, pre_load_hook: Optional[Callable] = None, pre_save_hook: Optional[Callable] = None diff --git a/internlm/model/modules/mlp.py b/internlm/model/modules/mlp.py index e51e5897f..74bc91bee 100644 --- a/internlm/model/modules/mlp.py +++ b/internlm/model/modules/mlp.py @@ -10,6 +10,7 @@ from internlm.model.modules.utils import Gelu, Silu from internlm.utils.logger import get_logger from internlm.utils.utils import ActivationType +from internlm.core.context import global_context as gpc logger = get_logger(__file__) @@ -86,16 +87,25 @@ def __init__( if self.mlp_layer_fusion: assert bias is False, "Fuesd FeedForward only support bias is False." - self.fused_w1_w3 = new_linear( - "w13", in_features, hidden_features * 2, bias, device=device, dtype=dtype, is_expert=is_expert - ) - self.w2 = new_linear( - "w2", hidden_features, out_features, bias, device=device, dtype=dtype, is_expert=is_expert - ) + if gpc.config.parallel["tensor"]["tp_overlap"]: + self.fused_w1_w3 = new_linear( + "w13", in_features, hidden_features * 2, bias, device=device, dtype=dtype, is_expert=is_expert, tp_comm_buffer_name="fc1" + ) + self.w2 = new_linear( + "w2", hidden_features, out_features, bias, device=device, dtype=dtype, is_expert=is_expert, tp_comm_buffer_name="fc2" + ) + else: + self.fused_w1_w3 = new_linear( + "w13", in_features, hidden_features * 2, bias, device=device, dtype=dtype, is_expert=is_expert + ) + self.w2 = new_linear( + "w2", hidden_features, out_features, bias, device=device, dtype=dtype, is_expert=is_expert + ) self._register_load_state_dict_pre_hook(_mlp_pre_load_convert, with_module=True) self._register_state_dict_hook(_mlp_save_convert) else: + assert gpc.config.parallel["tensor"]["tp_overlap"] is False, "tp overlap currently only support fused mlp." self.w1 = new_linear( "w1", in_features, hidden_features, bias, device=device, dtype=dtype, is_expert=is_expert ) diff --git a/internlm/model/modules/utils.py b/internlm/model/modules/utils.py index 9d43fbf00..5e7766762 100644 --- a/internlm/model/modules/utils.py +++ b/internlm/model/modules/utils.py @@ -4,7 +4,8 @@ import torch import torch.nn.functional as F from einops import rearrange - +from importlib.metadata import version +from packaging.version import Version as PkgVersion from internlm.utils.logger import get_logger logger = get_logger(__file__) @@ -91,3 +92,25 @@ def update_kv_cache(kv, inference_params, layer_idx): ) v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(kv[:, :, 1], "b s h d -> b h s d") return kv + + +def get_te_version(): + """Get TE version from __version__; if not available use pip's. 
Use caching.""" + + def get_te_version_str(): + import transformer_engine as te + + if hasattr(te, '__version__'): + return str(te.__version__) + else: + return version("transformer-engine") + + _te_version = PkgVersion(get_te_version_str()) + return _te_version + + +def is_te_min_version(version, check_equality=True): + """Check if minimum version of `transformer-engine` is installed.""" + if check_equality: + return get_te_version() >= PkgVersion(version) + return get_te_version() > PkgVersion(version) \ No newline at end of file diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 784a5305a..74d659ff0 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -59,6 +59,8 @@ RewardModelLinear, RowParallelLinear, ScaleColumnParallelLinear, + TERowParallelLinear, + TEColumnParallelLinear, new_linear, ) from internlm.model.modules.norm import new_layer_norm @@ -207,7 +209,7 @@ def _check_module(name, module): elif gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp(): setattr(param, IS_WEIGHT_EXPERT_DATA_PARALLEL, True) # for non-moe linear module - elif isinstance(module, ParallelLinearWithCommExt): + elif isinstance(module, (ParallelLinearWithCommExt, TERowParallelLinear, TEColumnParallelLinear)): for param in module.parameters(): if gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp(): setattr(param, IS_TENSOR_ZERO_PARALLEL, True) From 35f65019c82addcc67ed34310422756ae2914cd8 Mon Sep 17 00:00:00 2001 From: sallyjunjun Date: Wed, 19 Feb 2025 11:06:10 +0800 Subject: [PATCH 2/5] change te.pytorch.LayerNormLinear to te.pytorch.Linear in TEColumnParallelLinear, which requires transformer_engine v2.0 --- internlm/model/modules/linear.py | 18 ++---------------- internlm/model/modules/utils.py | 10 +++++++--- 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 043825812..a1c6cff87 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -1011,9 +1011,9 @@ def __init__( self.full_weight_shape = torch.Size((num_groups, in_features, out_features)) -class TEColumnParallelLinear(te.pytorch.LayerNormLinear): +class TEColumnParallelLinear(te.pytorch.Linear): """ - Wrapper for the Transformer-Engine's `LayerNormLinear` layer. + Wrapper for the Transformer-Engine's `Linear` layer. """ def __init__( @@ -1040,15 +1040,6 @@ def __init__( if is_te_min_version("0.12.0"): extra_kwargs["device"] = torch.cuda.current_device() - # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm` - # if is_te_min_version("0.11.0"): - # extra_kwargs["normalization"] = self.config.normalization - # elif self.config.normalization != "LayerNorm": - # te_version = get_te_version() - # raise ValueError( - # f"Transformer Engine v{te_version} does not support {self.config.normalization}." 
- # ) - if gpc.config.parallel["tensor"]["tp_overlap"]: extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get("tp_comm_bulk_wgrad", True) extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get("tp_comm_bulk_dgrad", True) @@ -1056,10 +1047,6 @@ def __init__( extra_kwargs["ub_overlap_ag"] = ( gpc.config.parallel["tensor"]["tp_overlap_cfg"].get("tp_comm_overlap_ag", True) ) - if is_te_min_version("1.6.0.dev0", check_equality=False): - extra_kwargs["ub_overlap_rs_dgrad"] = ( - gpc.config.parallel["tensor"]["tp_overlap_cfg"].get("tp_comm_overlap_rs_dgrad", False) - ) else: raise NotImplementedError('tp overlap is supported only when transformer_engine version >= 1.5.0') assert ( @@ -1079,7 +1066,6 @@ def __init__( bias=bias, return_bias=self.te_return_bias, parallel_mode="column", - return_layernorm_output=False, **extra_kwargs, ) diff --git a/internlm/model/modules/utils.py b/internlm/model/modules/utils.py index 5e7766762..2576b7b05 100644 --- a/internlm/model/modules/utils.py +++ b/internlm/model/modules/utils.py @@ -98,7 +98,10 @@ def get_te_version(): """Get TE version from __version__; if not available use pip's. Use caching.""" def get_te_version_str(): - import transformer_engine as te + try: + import transformer_engine as te + except (ModuleNotFoundError, ImportError): + return None if hasattr(te, '__version__'): return str(te.__version__) @@ -111,6 +114,7 @@ def get_te_version_str(): def is_te_min_version(version, check_equality=True): """Check if minimum version of `transformer-engine` is installed.""" + ver = get_te_version() if check_equality: - return get_te_version() >= PkgVersion(version) - return get_te_version() > PkgVersion(version) \ No newline at end of file + return ver is not None and ver >= PkgVersion(version) + return ver is not None and ver > PkgVersion(version) From a71091c76d06c4fcc075f117b846495f1c8d0211 Mon Sep 17 00:00:00 2001 From: sallyjunjun Date: Wed, 19 Feb 2025 11:48:55 +0800 Subject: [PATCH 3/5] fix lint check --- internlm/core/trainer_builder.py | 24 ++++++++++++++--------- internlm/initialize/launch.py | 22 +++++++++++---------- internlm/model/modules/linear.py | 33 ++++++++++++++++++-------------- internlm/model/modules/mha.py | 12 +++++++++--- internlm/model/modules/mlp.py | 20 ++++++++++++++++--- internlm/model/modules/utils.py | 6 ++++-- internlm/train/pipeline.py | 2 +- 7 files changed, 77 insertions(+), 42 deletions(-) diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index b0f744f01..98d13f6b1 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -254,27 +254,33 @@ def _initialize_batch_skipper(self, train_state) -> BatchSkipper: return BatchSkipper(skip_batches) def _initialize_tp_comm_ub(self): - """ initializing the communicators with user buffers for high-performance tensor-model-parallel - communication overlap """ + """initializing the communicators with user buffers for high-performance tensor-model-parallel + communication overlap""" try: - import transformer_engine from transformer_engine.pytorch import module as te_module except ImportError: - raise RuntimeError("Tensor Parallel Communication/GEMM Overlap optimization needs 'transformer_engine' package") + raise RuntimeError( + "Tensor Parallel Communication/GEMM Overlap optimization needs 'transformer_engine' package" + ) input_shape = [gpc.config.data["seq_len"] * gpc.config.data["micro_bsz"], gpc.config.model["hidden_size"]] if is_te_min_version("1.9.0"): # The 
process group with the target bootstrap backend is created in Transformer Engine. - te_module.base.initialize_ub(shape = input_shape, tp_size = gpc.config.parallel["tensor"]["size"], - use_fp8 = False, bootstrap_backend = 'nccl') + te_module.base.initialize_ub( + shape=input_shape, + tp_size=gpc.config.parallel["tensor"]["size"], + use_fp8=False, + bootstrap_backend="nccl", + ) else: # Create a MPI process group to help with TP communication overlap bootstrap. - torch.distributed.new_group(backend='mpi') + torch.distributed.new_group(backend="mpi") - te_module.base.initialize_ub(shape = input_shape, tp_size = gpc.config.parallel["tensor"]["size"], - use_fp8 = False) + te_module.base.initialize_ub( + shape=input_shape, tp_size=gpc.config.parallel["tensor"]["size"], use_fp8=False + ) def _set_attributes(self, profiling, train_dl, val_dls, train_state, optimizer, beta2_scheduler, isp_communicator): self.profiling = profiling diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 6a049bab8..f4e50d78c 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -101,7 +101,9 @@ def args_sanity_check(): gpc.config.parallel.pipeline._add_item("mode", "1F1B") if "tensor" not in gpc.config.parallel: - gpc.config.parallel._add_item("tensor", dict(size=1, mode=TensorParallelMode.mtp.name, tp_overlap=False, tp_overlap_cfg=None)) + gpc.config.parallel._add_item( + "tensor", dict(size=1, mode=TensorParallelMode.mtp.name, tp_overlap=False, tp_overlap_cfg=None) + ) if "weight" not in gpc.config.parallel: gpc.config.parallel._add_item( @@ -398,7 +400,9 @@ def args_sanity_check(): # set default value for tensor parallel if isinstance(gpc.config.parallel["tensor"], int): - gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name, tp_overlap=False, tp_overlap_cfg=None) + gpc.config.parallel["tensor"] = dict( + size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name, tp_overlap=False, tp_overlap_cfg=None + ) if gpc.config.parallel["tensor"].get("mode", None) is None: gpc.config.parallel["tensor"]["mode"] = TensorParallelMode.mtp.name if gpc.config.parallel["tensor"]["mode"] == TensorParallelMode.isp.name: @@ -458,20 +462,18 @@ def args_sanity_check(): if gpc.config.parallel["tensor"].get("tp_overlap", None) is None: gpc.config.parallel["tensor"]["tp_overlap"] = False elif gpc.config.parallel["tensor"].get("tp_overlap", None) is True: - assert ( - gpc.config.parallel["tensor"].get("mode", None) in [ - TensorParallelMode.msp.name, - TensorParallelMode.fsp.name, - ] - ), "tp_overlap can be set to true only in msp and fsp mode" - + assert gpc.config.parallel["tensor"].get("mode", None) in [ + TensorParallelMode.msp.name, + TensorParallelMode.fsp.name, + ], "tp_overlap can be set to true only in msp and fsp mode" + if gpc.config.parallel["tensor"].get("tp_overlap_cfg", None) is None: gpc.config.parallel["tensor"]["tp_overlap_cfg"] = dict( tp_comm_overlap_ag=True, tp_comm_overlap_rs=True, tp_comm_bulk_wgrad=True, tp_comm_bulk_dgrad=True, - tp_comm_overlap_rs_dgrad=False + tp_comm_overlap_rs_dgrad=False, ) # set default value for weight parallel diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index a1c6cff87..dd692505d 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -9,8 +9,8 @@ import torch import torch.distributed as dist -from torch import nn import transformer_engine as te +from torch import nn from internlm.accelerator import 
get_accelerator from internlm.core.context import ParallelMode @@ -20,13 +20,13 @@ get_parallel_strategies_split_mode, get_tensor_split_parallel_mode, ) +from internlm.model.modules.utils import is_te_min_version from internlm.model.ops.linear import ( gmm_backward_op, gmm_forward_op, linear_backward_op, linear_forward_op, ) -from internlm.model.modules.utils import is_te_min_version from internlm.utils.logger import get_logger if TYPE_CHECKING: @@ -1026,7 +1026,7 @@ def __init__( tp_comm_buffer_name: str = None, ): if is_expert: - raise ValueError('Transformer Engine linear layers do not yet support MoE') + raise ValueError("Transformer Engine linear layers do not yet support MoE") # TE returns a zero length Tensor when bias=False and # return_bias=True, but we prefer None. So in that case we @@ -1041,14 +1041,18 @@ def __init__( extra_kwargs["device"] = torch.cuda.current_device() if gpc.config.parallel["tensor"]["tp_overlap"]: - extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get("tp_comm_bulk_wgrad", True) - extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get("tp_comm_bulk_dgrad", True) + extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( + "tp_comm_bulk_wgrad", True + ) + extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( + "tp_comm_bulk_dgrad", True + ) if is_te_min_version("1.5.0", check_equality=False): - extra_kwargs["ub_overlap_ag"] = ( - gpc.config.parallel["tensor"]["tp_overlap_cfg"].get("tp_comm_overlap_ag", True) + extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( + "tp_comm_overlap_ag", True ) else: - raise NotImplementedError('tp overlap is supported only when transformer_engine version >= 1.5.0') + raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0") assert ( tp_comm_buffer_name is not None ), "Buffer name should be set to configure communication overlap settings" @@ -1085,6 +1089,7 @@ class TERowParallelLinear(te.pytorch.Linear): """ Wrapper for the Transformer-Engine's `Linear` layer. """ + def __init__( self, in_features: int, @@ -1105,18 +1110,18 @@ def __init__( if gpc.config.parallel["tensor"]["tp_overlap"]: if is_te_min_version("1.5.0"): - extra_kwargs["ub_overlap_ag"] = ( - gpc.config.parallel["tensor"]["tp_overlap_cfg"].get("tp_comm_overlap_ag", True) + extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( + "tp_comm_overlap_ag", True ) - extra_kwargs["ub_overlap_rs"] = ( - gpc.config.parallel["tensor"]["tp_overlap_cfg"].get("tp_comm_overlap_rs", True) + extra_kwargs["ub_overlap_rs"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( + "tp_comm_overlap_rs", True ) # Disable ub overlap for experts. 
if is_expert: extra_kwargs["ub_overlap_ag"] = False extra_kwargs["ub_overlap_rs"] = False else: - raise NotImplementedError('tp overlap is supported only when transformer_engine version >= 1.5.0') + raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0") assert ( tp_comm_buffer_name is not None ), "Buffer name should be set to configure communication overlap settings" @@ -1150,7 +1155,7 @@ def __init__( ) for param in self.parameters(): - setattr(param, 'allreduce', not (is_expert and self.expert_parallel)) + setattr(param, "allreduce", not (is_expert and self.expert_parallel)) def forward(self, x): """Forward.""" diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py index f814944d7..8212402c7 100644 --- a/internlm/model/modules/mha.py +++ b/internlm/model/modules/mha.py @@ -158,7 +158,9 @@ def __init__( if self.enable_qkv_fusion: # bias=True is according to https://spaces.ac.cn/archives/9577 if gpc.config.parallel["tensor"]["tp_overlap"]: - self.wqkv = new_linear("wqkv", embed_dim, 3 * embed_dim, bias, tp_comm_buffer_name="qkv", **factory_kwargs) + self.wqkv = new_linear( + "wqkv", embed_dim, 3 * embed_dim, bias, tp_comm_buffer_name="qkv", **factory_kwargs + ) else: self.wqkv = new_linear("wqkv", embed_dim, 3 * embed_dim, bias, **factory_kwargs) else: @@ -172,7 +174,9 @@ def __init__( # output projection always have the bias (for now) (except for baichuan2 model) if gpc.config.parallel["tensor"]["tp_overlap"]: - self.out_proj = new_linear("out_proj", embed_dim, embed_dim, bias=out_bias, tp_comm_buffer_name="proj", **factory_kwargs) + self.out_proj = new_linear( + "out_proj", embed_dim, embed_dim, bias=out_bias, tp_comm_buffer_name="proj", **factory_kwargs + ) else: self.out_proj = new_linear("out_proj", embed_dim, embed_dim, bias=out_bias, **factory_kwargs) @@ -468,7 +472,9 @@ def __init__( if enable_qkv_fusion: assert bias is False, "Fuesd wqkv only support bias is False." 
if gpc.config.parallel["tensor"]["tp_overlap"]: - self.wqkv = new_linear("wqkv", embed_dim, q_dim + 2 * self.kv_dim, bias, tp_comm_buffer_name="qkv", **factory_kwargs) + self.wqkv = new_linear( + "wqkv", embed_dim, q_dim + 2 * self.kv_dim, bias, tp_comm_buffer_name="qkv", **factory_kwargs + ) else: self.wqkv = new_linear("wqkv", embed_dim, q_dim + 2 * self.kv_dim, bias, **factory_kwargs) self._register_load_state_dict_pre_hook( diff --git a/internlm/model/modules/mlp.py b/internlm/model/modules/mlp.py index 74bc91bee..1476c6f9c 100644 --- a/internlm/model/modules/mlp.py +++ b/internlm/model/modules/mlp.py @@ -6,11 +6,11 @@ import torch from torch import nn +from internlm.core.context import global_context as gpc from internlm.model.modules.linear import new_linear from internlm.model.modules.utils import Gelu, Silu from internlm.utils.logger import get_logger from internlm.utils.utils import ActivationType -from internlm.core.context import global_context as gpc logger = get_logger(__file__) @@ -89,10 +89,24 @@ def __init__( if gpc.config.parallel["tensor"]["tp_overlap"]: self.fused_w1_w3 = new_linear( - "w13", in_features, hidden_features * 2, bias, device=device, dtype=dtype, is_expert=is_expert, tp_comm_buffer_name="fc1" + "w13", + in_features, + hidden_features * 2, + bias, + device=device, + dtype=dtype, + is_expert=is_expert, + tp_comm_buffer_name="fc1", ) self.w2 = new_linear( - "w2", hidden_features, out_features, bias, device=device, dtype=dtype, is_expert=is_expert, tp_comm_buffer_name="fc2" + "w2", + hidden_features, + out_features, + bias, + device=device, + dtype=dtype, + is_expert=is_expert, + tp_comm_buffer_name="fc2", ) else: self.fused_w1_w3 = new_linear( diff --git a/internlm/model/modules/utils.py b/internlm/model/modules/utils.py index 2576b7b05..9bc0eec17 100644 --- a/internlm/model/modules/utils.py +++ b/internlm/model/modules/utils.py @@ -1,11 +1,13 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +from importlib.metadata import version + import torch import torch.nn.functional as F from einops import rearrange -from importlib.metadata import version from packaging.version import Version as PkgVersion + from internlm.utils.logger import get_logger logger = get_logger(__file__) @@ -103,7 +105,7 @@ def get_te_version_str(): except (ModuleNotFoundError, ImportError): return None - if hasattr(te, '__version__'): + if hasattr(te, "__version__"): return str(te.__version__) else: return version("transformer-engine") diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 74d659ff0..9e7d55528 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -59,8 +59,8 @@ RewardModelLinear, RowParallelLinear, ScaleColumnParallelLinear, - TERowParallelLinear, TEColumnParallelLinear, + TERowParallelLinear, new_linear, ) from internlm.model.modules.norm import new_layer_norm From c392ba5bdfb91d04f45f93832b2fa07972cb410f Mon Sep 17 00:00:00 2001 From: sallyjunjun Date: Wed, 19 Feb 2025 15:13:01 +0800 Subject: [PATCH 4/5] fix transformer_engine import error and tp_overlap key error --- configs/7B_sft.py | 2 + internlm/core/parallel/shard.py | 4 +- internlm/core/trainer_builder.py | 2 +- internlm/initialize/launch.py | 1 - internlm/model/modules/linear.py | 305 ++++++++++++++++--------------- internlm/model/modules/mha.py | 16 +- internlm/model/modules/mlp.py | 6 +- 7 files changed, 176 insertions(+), 160 deletions(-) diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 8cc8e4fd1..231a20942 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py 
@@ -199,6 +199,8 @@ tp_overlap_cfg=dict( tp_comm_overlap_ag=True, tp_comm_overlap_rs=True, + tp_comm_bulk_wgrad=True, + tp_comm_bulk_dgrad=True, ), ), pipeline=dict(size=1, interleaved_overlap=True), diff --git a/internlm/core/parallel/shard.py b/internlm/core/parallel/shard.py index 845db7908..46cdde3ee 100644 --- a/internlm/core/parallel/shard.py +++ b/internlm/core/parallel/shard.py @@ -161,7 +161,7 @@ def get_parallel_strategies_split_mode(linear_name: str) -> str: if linear_name in ("gate"): return "gate" # for MoE model elif linear_name in ("wqkv", "wq", "wk", "wv", "wkv", "w1", "w3", "w13"): - if gpc.config.parallel.tensor.tp_overlap: + if gpc.config.parallel["tensor"].get("tp_overlap", False): return "tecolumn" else: return "column" @@ -170,7 +170,7 @@ def get_parallel_strategies_split_mode(linear_name: str) -> str: elif linear_name in ("wo", "out_proj", "w2") and tp_mode == TensorParallelMode.isp.name: return "column" elif linear_name in ("wo", "out_proj", "w2"): - if gpc.config.parallel.tensor.tp_overlap: + if gpc.config.parallel["tensor"].get("tp_overlap", False): return "terow" else: return "row" diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index 98d13f6b1..afa31d412 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -155,7 +155,7 @@ def __init__( scheduler_hooks=get_scheduler_hooks(self.metric, optimizer, isp_communicator), ) - if gpc.config.parallel["tensor"]["tp_overlap"]: + if gpc.config.parallel["tensor"].get("tp_overlap", False): self._initialize_tp_comm_ub() # set attributes diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index f4e50d78c..831fe3b0e 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -473,7 +473,6 @@ def args_sanity_check(): tp_comm_overlap_rs=True, tp_comm_bulk_wgrad=True, tp_comm_bulk_dgrad=True, - tp_comm_overlap_rs_dgrad=False, ) # set default value for weight parallel diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index dd692505d..60128e225 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -9,7 +9,14 @@ import torch import torch.distributed as dist -import transformer_engine as te + +try: + import transformer_engine as te + + has_te = True +except (ModuleNotFoundError, ImportError): + has_te = False + from torch import nn from internlm.accelerator import get_accelerator @@ -1011,161 +1018,163 @@ def __init__( self.full_weight_shape = torch.Size((num_groups, in_features, out_features)) -class TEColumnParallelLinear(te.pytorch.Linear): - """ - Wrapper for the Transformer-Engine's `Linear` layer. - """ +if has_te: - def __init__( - self, - in_features: int, - out_features: int, - bias: bool, - skip_bias_add: bool, - is_expert: bool, - tp_comm_buffer_name: str = None, - ): - if is_expert: - raise ValueError("Transformer Engine linear layers do not yet support MoE") - - # TE returns a zero length Tensor when bias=False and - # return_bias=True, but we prefer None. So in that case we - # tell TE to not return the bias, and return None - # ourselves. This way our forward always returns two values - # and we don't have to deal with the zero length Tensor. 
- self.te_return_bias = skip_bias_add and bias - self.is_first_microbatch = True - - extra_kwargs = {"params_dtype": gpc.config.model.dtype} - if is_te_min_version("0.12.0"): - extra_kwargs["device"] = torch.cuda.current_device() - - if gpc.config.parallel["tensor"]["tp_overlap"]: - extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( - "tp_comm_bulk_wgrad", True - ) - extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( - "tp_comm_bulk_dgrad", True - ) - if is_te_min_version("1.5.0", check_equality=False): - extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( - "tp_comm_overlap_ag", True - ) - else: - raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0") - assert ( - tp_comm_buffer_name is not None - ), "Buffer name should be set to configure communication overlap settings" - extra_kwargs["ub_name"] = tp_comm_buffer_name - - parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert) - tp_size = gpc.get_world_size(parallel_mode) - tp_group = gpc.get_group(parallel_mode) - super().__init__( - in_features=in_features, - out_features=out_features, - sequence_parallel=gpc.config.parallel.sequence_parallel, - tp_group=tp_group, - tp_size=tp_size, - bias=bias, - return_bias=self.te_return_bias, - parallel_mode="column", - **extra_kwargs, - ) - - def forward(self, x): - """Forward.""" - _is_first_microbatch = self.is_first_microbatch - x = x.transpose(0, 1) - out = super().forward(x, is_first_microbatch=_is_first_microbatch) - out = out.transpose(0, 1) - - self.is_first_microbatch = False + class TEColumnParallelLinear(te.pytorch.Linear): + """ + Wrapper for the Transformer-Engine's `Linear` layer. + """ - return out + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + skip_bias_add: bool, + is_expert: bool, + tp_comm_buffer_name: str = None, + ): + if is_expert: + raise ValueError("Transformer Engine linear layers do not yet support MoE") + + # TE returns a zero length Tensor when bias=False and + # return_bias=True, but we prefer None. So in that case we + # tell TE to not return the bias, and return None + # ourselves. This way our forward always returns two values + # and we don't have to deal with the zero length Tensor. 
+ self.te_return_bias = skip_bias_add and bias + self.is_first_microbatch = True + + extra_kwargs = {"params_dtype": gpc.config.model.dtype} + if is_te_min_version("0.12.0"): + extra_kwargs["device"] = torch.cuda.current_device() + + if gpc.config.parallel["tensor"].get("tp_overlap", False): + extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( + "tp_comm_bulk_wgrad", True + ) + extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( + "tp_comm_bulk_dgrad", True + ) + if is_te_min_version("1.5.0", check_equality=False): + extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( + "tp_comm_overlap_ag", True + ) + else: + raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0") + assert ( + tp_comm_buffer_name is not None + ), "Buffer name should be set to configure communication overlap settings" + extra_kwargs["ub_name"] = tp_comm_buffer_name + + parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert) + tp_size = gpc.get_world_size(parallel_mode) + tp_group = gpc.get_group(parallel_mode) + super().__init__( + in_features=in_features, + out_features=out_features, + sequence_parallel=gpc.config.parallel.sequence_parallel, + tp_group=tp_group, + tp_size=tp_size, + bias=bias, + return_bias=self.te_return_bias, + parallel_mode="column", + **extra_kwargs, + ) + def forward(self, x): + """Forward.""" + _is_first_microbatch = self.is_first_microbatch + x = x.transpose(0, 1) + out = super().forward(x, is_first_microbatch=_is_first_microbatch) + out = out.transpose(0, 1) -class TERowParallelLinear(te.pytorch.Linear): - """ - Wrapper for the Transformer-Engine's `Linear` layer. - """ + self.is_first_microbatch = False - def __init__( - self, - in_features: int, - out_features: int, - bias: bool, - skip_bias_add: bool, - is_expert: bool = False, - tp_comm_buffer_name: str = None, - ): - # TE returns a zero length Tensor when bias=False and - # return_bias=True. Here we need a single Tensor - self.te_return_bias = skip_bias_add and bias - self.is_first_microbatch = True - - extra_kwargs = {"params_dtype": gpc.config.model.dtype} - if is_te_min_version("0.12.0"): - extra_kwargs["device"] = torch.cuda.current_device() - - if gpc.config.parallel["tensor"]["tp_overlap"]: - if is_te_min_version("1.5.0"): - extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( - "tp_comm_overlap_ag", True - ) - extra_kwargs["ub_overlap_rs"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( - "tp_comm_overlap_rs", True - ) - # Disable ub overlap for experts. - if is_expert: - extra_kwargs["ub_overlap_ag"] = False - extra_kwargs["ub_overlap_rs"] = False - else: - raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0") - assert ( - tp_comm_buffer_name is not None - ), "Buffer name should be set to configure communication overlap settings" - extra_kwargs["ub_name"] = tp_comm_buffer_name + return out - self.expert_parallel = gpc.config.parallel["expert"].get("size", 1) > 1 - parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert) - # Disable communications in TE when using TP or EP by making TE agnostic of model parallel. 
- tp_size = gpc.get_world_size(parallel_mode) - tp_group = gpc.get_group(parallel_mode) - explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel) - - split_mode = "row" - if explicit_expert_comm: - assert in_features % tp_size == 0, "{} is not divisible by {}".format(in_features, tp_size) - in_features = in_features // tp_size - split_mode = None - tp_size = 1 - tp_group = None + class TERowParallelLinear(te.pytorch.Linear): + """ + Wrapper for the Transformer-Engine's `Linear` layer. + """ - super().__init__( - in_features=in_features, - out_features=out_features, - sequence_parallel=gpc.config.parallel.sequence_parallel, - tp_group=tp_group, - tp_size=tp_size, - bias=bias, - return_bias=self.te_return_bias, - parallel_mode=split_mode, - **extra_kwargs, - ) + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + skip_bias_add: bool, + is_expert: bool = False, + tp_comm_buffer_name: str = None, + ): + # TE returns a zero length Tensor when bias=False and + # return_bias=True. Here we need a single Tensor + self.te_return_bias = skip_bias_add and bias + self.is_first_microbatch = True + + extra_kwargs = {"params_dtype": gpc.config.model.dtype} + if is_te_min_version("0.12.0"): + extra_kwargs["device"] = torch.cuda.current_device() + + if gpc.config.parallel["tensor"].get("tp_overlap", False): + if is_te_min_version("1.5.0"): + extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( + "tp_comm_overlap_ag", True + ) + extra_kwargs["ub_overlap_rs"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( + "tp_comm_overlap_rs", True + ) + # Disable ub overlap for experts. + if is_expert: + extra_kwargs["ub_overlap_ag"] = False + extra_kwargs["ub_overlap_rs"] = False + else: + raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0") + assert ( + tp_comm_buffer_name is not None + ), "Buffer name should be set to configure communication overlap settings" + extra_kwargs["ub_name"] = tp_comm_buffer_name + + self.expert_parallel = gpc.config.parallel["expert"].get("size", 1) > 1 + parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert) + # Disable communications in TE when using TP or EP by making TE agnostic of model parallel. 
+ tp_size = gpc.get_world_size(parallel_mode) + tp_group = gpc.get_group(parallel_mode) + explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel) + + split_mode = "row" + if explicit_expert_comm: + assert in_features % tp_size == 0, "{} is not divisible by {}".format(in_features, tp_size) + in_features = in_features // tp_size + split_mode = None + tp_size = 1 + tp_group = None + + super().__init__( + in_features=in_features, + out_features=out_features, + sequence_parallel=gpc.config.parallel.sequence_parallel, + tp_group=tp_group, + tp_size=tp_size, + bias=bias, + return_bias=self.te_return_bias, + parallel_mode=split_mode, + **extra_kwargs, + ) - for param in self.parameters(): - setattr(param, "allreduce", not (is_expert and self.expert_parallel)) + def forward(self, x): + """Forward.""" + _is_first_microbatch = self.is_first_microbatch + x = x.transpose(0, 1) + out = super().forward(x, is_first_microbatch=_is_first_microbatch) + out = out.transpose(0, 1) + self.is_first_microbatch = False - def forward(self, x): - """Forward.""" - _is_first_microbatch = self.is_first_microbatch - x = x.transpose(0, 1) - out = super().forward(x, is_first_microbatch=_is_first_microbatch) - out = out.transpose(0, 1) - self.is_first_microbatch = False + return out - return out +else: + TEColumnParallelLinear = ColumnParallelLinear + TERowParallelLinear = RowParallelLinear def new_linear( @@ -1217,7 +1226,7 @@ def new_linear( weight_scale=weight_scale, norm_head=norm_head, ) - elif split_mode == "column": + elif split_mode == "column" or (split_mode == "tecolumn" and not has_te): return ColumnParallelLinear( in_features, out_features, @@ -1236,7 +1245,7 @@ def new_linear( is_expert, tp_comm_buffer_name, ) - elif split_mode == "row": + elif split_mode == "row" or (split_mode == "terow" and not has_te): return RowParallelLinear( in_features, out_features, diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py index 8212402c7..db2ae9c13 100644 --- a/internlm/model/modules/mha.py +++ b/internlm/model/modules/mha.py @@ -157,14 +157,16 @@ def __init__( if self.enable_qkv_fusion: # bias=True is according to https://spaces.ac.cn/archives/9577 - if gpc.config.parallel["tensor"]["tp_overlap"]: + if gpc.config.parallel["tensor"].get("tp_overlap", False): self.wqkv = new_linear( "wqkv", embed_dim, 3 * embed_dim, bias, tp_comm_buffer_name="qkv", **factory_kwargs ) else: self.wqkv = new_linear("wqkv", embed_dim, 3 * embed_dim, bias, **factory_kwargs) else: - assert gpc.config.parallel["tensor"]["tp_overlap"] is False, "tp overlap currently only support fused wqkv." + assert ( + gpc.config.parallel["tensor"].get("tp_overlap", False) is False + ), "tp overlap currently only support fused wqkv." self.wq = new_linear("wq", embed_dim, embed_dim, bias, **factory_kwargs) self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs) self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs) @@ -173,7 +175,7 @@ def __init__( self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) # output projection always have the bias (for now) (except for baichuan2 model) - if gpc.config.parallel["tensor"]["tp_overlap"]: + if gpc.config.parallel["tensor"].get("tp_overlap", False): self.out_proj = new_linear( "out_proj", embed_dim, embed_dim, bias=out_bias, tp_comm_buffer_name="proj", **factory_kwargs ) @@ -471,7 +473,7 @@ def __init__( if enable_qkv_fusion: assert bias is False, "Fuesd wqkv only support bias is False." 
- if gpc.config.parallel["tensor"]["tp_overlap"]: + if gpc.config.parallel["tensor"].get("tp_overlap", False): self.wqkv = new_linear( "wqkv", embed_dim, q_dim + 2 * self.kv_dim, bias, tp_comm_buffer_name="qkv", **factory_kwargs ) @@ -482,7 +484,9 @@ def __init__( ) self._register_state_dict_hook(partial(_qkv_save_convert, q_dim=q_dim, kv_dim=self.kv_dim)) else: - assert gpc.config.parallel["tensor"]["tp_overlap"] is False, "tp overlap currently only support fused wqkv." + assert ( + gpc.config.parallel["tensor"].get("tp_overlap", False) is False + ), "tp overlap currently only support fused wqkv." self.wq = new_linear("wq", embed_dim, q_dim, bias, **factory_kwargs) self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs) self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs) @@ -494,7 +498,7 @@ def __init__( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout, layer_idx=layer_idx ) - if gpc.config.parallel["tensor"]["tp_overlap"]: + if gpc.config.parallel["tensor"].get("tp_overlap", False): self.wo = new_linear("wo", q_dim, embed_dim, bias, tp_comm_buffer_name="proj", **factory_kwargs) else: self.wo = new_linear("wo", q_dim, embed_dim, bias, **factory_kwargs) diff --git a/internlm/model/modules/mlp.py b/internlm/model/modules/mlp.py index 1476c6f9c..835d649e4 100644 --- a/internlm/model/modules/mlp.py +++ b/internlm/model/modules/mlp.py @@ -87,7 +87,7 @@ def __init__( if self.mlp_layer_fusion: assert bias is False, "Fuesd FeedForward only support bias is False." - if gpc.config.parallel["tensor"]["tp_overlap"]: + if gpc.config.parallel["tensor"].get("tp_overlap", False): self.fused_w1_w3 = new_linear( "w13", in_features, @@ -119,7 +119,9 @@ def __init__( self._register_load_state_dict_pre_hook(_mlp_pre_load_convert, with_module=True) self._register_state_dict_hook(_mlp_save_convert) else: - assert gpc.config.parallel["tensor"]["tp_overlap"] is False, "tp overlap currently only support fused mlp." + assert ( + gpc.config.parallel["tensor"].get("tp_overlap", False) is False + ), "tp overlap currently only support fused mlp." self.w1 = new_linear( "w1", in_features, hidden_features, bias, device=device, dtype=dtype, is_expert=is_expert ) From 7ac997e5c57d4dd1a47ef91f7ce22268b2f40844 Mon Sep 17 00:00:00 2001 From: sallyjunjun Date: Wed, 19 Feb 2025 17:38:55 +0800 Subject: [PATCH 5/5] optimize code --- internlm/model/modules/linear.py | 124 ++++++++++++------------------- internlm/train/pipeline.py | 5 +- 2 files changed, 50 insertions(+), 79 deletions(-) diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 60128e225..a32fe3386 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -1020,7 +1020,7 @@ def __init__( if has_te: - class TEColumnParallelLinear(te.pytorch.Linear): + class TELinear(te.pytorch.Linear): """ Wrapper for the Transformer-Engine's `Linear` layer. """ @@ -1032,34 +1032,36 @@ def __init__( bias: bool, skip_bias_add: bool, is_expert: bool, + split_mode: str = "none", tp_comm_buffer_name: str = None, ): if is_expert: raise ValueError("Transformer Engine linear layers do not yet support MoE") # TE returns a zero length Tensor when bias=False and - # return_bias=True, but we prefer None. So in that case we - # tell TE to not return the bias, and return None - # ourselves. This way our forward always returns two values - # and we don't have to deal with the zero length Tensor. + # return_bias=True. 
Here we need a single Tensor self.te_return_bias = skip_bias_add and bias self.is_first_microbatch = True extra_kwargs = {"params_dtype": gpc.config.model.dtype} - if is_te_min_version("0.12.0"): - extra_kwargs["device"] = torch.cuda.current_device() + extra_kwargs["device"] = torch.cuda.current_device() if gpc.config.parallel["tensor"].get("tp_overlap", False): - extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( - "tp_comm_bulk_wgrad", True - ) - extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( - "tp_comm_bulk_dgrad", True - ) if is_te_min_version("1.5.0", check_equality=False): extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( "tp_comm_overlap_ag", True ) + if split_mode == "column": + extra_kwargs["ub_bulk_wgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( + "tp_comm_bulk_wgrad", True + ) + extra_kwargs["ub_bulk_dgrad"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( + "tp_comm_bulk_dgrad", True + ) + elif split_mode == "row": + extra_kwargs["ub_overlap_rs"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( + "tp_comm_overlap_rs", True + ) else: raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0") assert ( @@ -1070,6 +1072,7 @@ def __init__( parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert) tp_size = gpc.get_world_size(parallel_mode) tp_group = gpc.get_group(parallel_mode) + super().__init__( in_features=in_features, out_features=out_features, @@ -1078,7 +1081,7 @@ def __init__( tp_size=tp_size, bias=bias, return_bias=self.te_return_bias, - parallel_mode="column", + parallel_mode=split_mode, **extra_kwargs, ) @@ -1088,14 +1091,13 @@ def forward(self, x): x = x.transpose(0, 1) out = super().forward(x, is_first_microbatch=_is_first_microbatch) out = out.transpose(0, 1) - self.is_first_microbatch = False return out - class TERowParallelLinear(te.pytorch.Linear): + class TEColumnParallelLinear(TELinear): """ - Wrapper for the Transformer-Engine's `Linear` layer. + Wrapper for the TELinear layer. """ def __init__( @@ -1107,72 +1109,42 @@ def __init__( is_expert: bool = False, tp_comm_buffer_name: str = None, ): - # TE returns a zero length Tensor when bias=False and - # return_bias=True. Here we need a single Tensor - self.te_return_bias = skip_bias_add and bias - self.is_first_microbatch = True - - extra_kwargs = {"params_dtype": gpc.config.model.dtype} - if is_te_min_version("0.12.0"): - extra_kwargs["device"] = torch.cuda.current_device() - - if gpc.config.parallel["tensor"].get("tp_overlap", False): - if is_te_min_version("1.5.0"): - extra_kwargs["ub_overlap_ag"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( - "tp_comm_overlap_ag", True - ) - extra_kwargs["ub_overlap_rs"] = gpc.config.parallel["tensor"]["tp_overlap_cfg"].get( - "tp_comm_overlap_rs", True - ) - # Disable ub overlap for experts. 
- if is_expert: - extra_kwargs["ub_overlap_ag"] = False - extra_kwargs["ub_overlap_rs"] = False - else: - raise NotImplementedError("tp overlap is supported only when transformer_engine version >= 1.5.0") - assert ( - tp_comm_buffer_name is not None - ), "Buffer name should be set to configure communication overlap settings" - extra_kwargs["ub_name"] = tp_comm_buffer_name - - self.expert_parallel = gpc.config.parallel["expert"].get("size", 1) > 1 - parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert) - # Disable communications in TE when using TP or EP by making TE agnostic of model parallel. - tp_size = gpc.get_world_size(parallel_mode) - tp_group = gpc.get_group(parallel_mode) - explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel) - - split_mode = "row" - if explicit_expert_comm: - assert in_features % tp_size == 0, "{} is not divisible by {}".format(in_features, tp_size) - in_features = in_features // tp_size - split_mode = None - tp_size = 1 - tp_group = None - super().__init__( - in_features=in_features, - out_features=out_features, - sequence_parallel=gpc.config.parallel.sequence_parallel, - tp_group=tp_group, - tp_size=tp_size, + in_features, + out_features, bias=bias, - return_bias=self.te_return_bias, - parallel_mode=split_mode, - **extra_kwargs, + skip_bias_add=skip_bias_add, + is_expert=is_expert, + split_mode="column", + tp_comm_buffer_name=tp_comm_buffer_name, ) - def forward(self, x): - """Forward.""" - _is_first_microbatch = self.is_first_microbatch - x = x.transpose(0, 1) - out = super().forward(x, is_first_microbatch=_is_first_microbatch) - out = out.transpose(0, 1) - self.is_first_microbatch = False + class TERowParallelLinear(TELinear): + """ + Wrapper for the TELinear layer. + """ - return out + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + skip_bias_add: bool, + is_expert: bool = False, + tp_comm_buffer_name: str = None, + ): + super().__init__( + in_features, + out_features, + bias=bias, + skip_bias_add=skip_bias_add, + is_expert=is_expert, + split_mode="row", + tp_comm_buffer_name=tp_comm_buffer_name, + ) else: + TELinear = ParallelLinearWithCommExt TEColumnParallelLinear = ColumnParallelLinear TERowParallelLinear = RowParallelLinear diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 9e7d55528..814af9b54 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -59,8 +59,7 @@ RewardModelLinear, RowParallelLinear, ScaleColumnParallelLinear, - TEColumnParallelLinear, - TERowParallelLinear, + TELinear, new_linear, ) from internlm.model.modules.norm import new_layer_norm @@ -209,7 +208,7 @@ def _check_module(name, module): elif gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp(): setattr(param, IS_WEIGHT_EXPERT_DATA_PARALLEL, True) # for non-moe linear module - elif isinstance(module, (ParallelLinearWithCommExt, TERowParallelLinear, TEColumnParallelLinear)): + elif isinstance(module, (ParallelLinearWithCommExt, TELinear)): for param in module.parameters(): if gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp(): setattr(param, IS_TENSOR_ZERO_PARALLEL, True)
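
A quick way to see the fallback behaviour in new_linear and in the module-level else branch above: when transformer_engine cannot be imported, the "tecolumn"/"terow" split modes resolve to the plain ColumnParallelLinear/RowParallelLinear, and the TE* names are simply aliased to the non-TE classes. The sketch below reproduces only that dispatch shape; the stand-in class bodies, the has_te flag handling and new_linear_sketch itself are illustrative, not the real InternLM code.

has_te = False  # pretend transformer_engine could not be imported

class ColumnParallelLinear:
    """Stand-in for the real column-parallel linear."""
    def __init__(self, in_features, out_features):
        self.shape = (in_features, out_features)

class RowParallelLinear:
    """Stand-in for the real row-parallel linear."""
    def __init__(self, in_features, out_features):
        self.shape = (in_features, out_features)

if has_te:
    pass  # the real module defines TELinear / TEColumnParallelLinear / TERowParallelLinear here
else:
    TEColumnParallelLinear = ColumnParallelLinear
    TERowParallelLinear = RowParallelLinear

def new_linear_sketch(split_mode, in_features, out_features):
    # "tecolumn"/"terow" quietly degrade to the plain TP linears when TE is absent.
    if split_mode == "column" or (split_mode == "tecolumn" and not has_te):
        return ColumnParallelLinear(in_features, out_features)
    elif split_mode == "tecolumn":
        return TEColumnParallelLinear(in_features, out_features)
    elif split_mode == "row" or (split_mode == "terow" and not has_te):
        return RowParallelLinear(in_features, out_features)
    elif split_mode == "terow":
        return TERowParallelLinear(in_features, out_features)
    raise ValueError(f"unknown split mode {split_mode}")

print(type(new_linear_sketch("tecolumn", 4096, 12288)).__name__)  # -> ColumnParallelLinear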
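
On the model side, the mha.py and mlp.py hunks gate the fused projections on tp_overlap and pass a tp_comm_buffer_name ("qkv", "proj", ...) so each TE-backed linear can be tied to a userbuffer, while the unfused wq/wk/wv and w1/w2/w3 paths assert that overlap is off. Below is a minimal sketch of that constructor pattern; new_linear and parallel_cfg are stand-ins for the real helper and for gpc.config, not the actual call signatures.

def new_linear(name, in_features, out_features, bias, tp_comm_buffer_name=None):
    # the real helper dispatches to the (TE)ColumnParallelLinear / (TE)RowParallelLinear classes
    return {"name": name, "ub_name": tp_comm_buffer_name,
            "shape": (in_features, out_features), "bias": bias}

parallel_cfg = {"tensor": {"tp_overlap": True}}   # assumed layout of the parallel config

def build_attention_projections(embed_dim, enable_qkv_fusion=True, bias=True):
    tp_overlap = parallel_cfg["tensor"].get("tp_overlap", False)
    if enable_qkv_fusion:
        if tp_overlap:
            # fused qkv plus a buffer name so TE can match the layer to a userbuffer
            wqkv = new_linear("wqkv", embed_dim, 3 * embed_dim, bias, tp_comm_buffer_name="qkv")
            out_proj = new_linear("out_proj", embed_dim, embed_dim, bias, tp_comm_buffer_name="proj")
        else:
            wqkv = new_linear("wqkv", embed_dim, 3 * embed_dim, bias)
            out_proj = new_linear("out_proj", embed_dim, embed_dim, bias)
        return wqkv, out_proj
    # unfused wq/wk/wv is not supported together with overlap
    assert tp_overlap is False, "tp overlap currently only supports fused wqkv"
    return (new_linear("wq", embed_dim, embed_dim, bias),
            new_linear("wk", embed_dim, embed_dim, bias),
            new_linear("wv", embed_dim, embed_dim, bias))

wqkv, out_proj = build_attention_projections(4096)
print(wqkv["ub_name"], out_proj["ub_name"])   # qkv proj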
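
Inside the refactored TELinear, the user-facing tp_overlap_cfg is translated into the ub_* keyword arguments of te.pytorch.Linear, and the split mode decides which overlaps apply: the all-gather overlap is set for both modes, the bulk wgrad/dgrad overlaps only for column-parallel layers, and the reduce-scatter overlap only for row-parallel layers. The helper below is a standalone restatement of that mapping; build_ub_kwargs and the plain cfg dict are illustrative stand-ins for the gpc.config plumbing, with the ub_* key names following the TE >= 1.5 convention used in the hunk above.

def build_ub_kwargs(cfg, split_mode, tp_comm_buffer_name):
    # Map tp_overlap_cfg onto the keyword arguments expected by te.pytorch.Linear.
    assert tp_comm_buffer_name is not None, (
        "Buffer name should be set to configure communication overlap settings"
    )
    kwargs = {
        "ub_overlap_ag": cfg.get("tp_comm_overlap_ag", True),  # applies to both split modes
        "ub_name": tp_comm_buffer_name,
    }
    if split_mode == "column":
        kwargs["ub_bulk_wgrad"] = cfg.get("tp_comm_bulk_wgrad", True)
        kwargs["ub_bulk_dgrad"] = cfg.get("tp_comm_bulk_dgrad", True)
    elif split_mode == "row":
        kwargs["ub_overlap_rs"] = cfg.get("tp_comm_overlap_rs", True)
    return kwargs

cfg = dict(tp_comm_overlap_ag=True, tp_comm_overlap_rs=True,
           tp_comm_bulk_wgrad=True, tp_comm_bulk_dgrad=True)
print(build_ub_kwargs(cfg, "row", "proj"))
# {'ub_overlap_ag': True, 'ub_name': 'proj', 'ub_overlap_rs': True}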
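
Finally, TEColumnParallelLinear and TERowParallelLinear are now thin subclasses that only pin split_mode, while the shared forward transposes dims 0 and 1 around the TE GEMM (presumably so the dimension TE shards for sequence parallelism comes first) and forwards an is_first_microbatch flag that TE can use to cache weight-side work across micro-batches. The toy wrapper below mimics that calling pattern with torch.nn.Linear standing in for te.pytorch.Linear; the batch-first layout, shapes and class name are illustrative assumptions, and clearing the flag after the call mirrors the first revision of the wrapper.

import torch

class TELinearLayoutSketch(torch.nn.Linear):
    """Toy illustration of the TELinear.forward calling pattern.
    torch.nn.Linear stands in for te.pytorch.Linear, which additionally
    takes an is_first_microbatch= argument that the wrapper forwards."""

    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        self.is_first_microbatch = True

    def forward(self, x):
        _is_first = self.is_first_microbatch   # the real wrapper hands this to TE
        x = x.transpose(0, 1)                  # move the sharded dim to the front
        out = super().forward(x)               # TE would run the overlapped GEMM here
        out = out.transpose(0, 1)              # restore the caller's layout
        self.is_first_microbatch = False       # only the first micro-batch sees True
        return out

layer = TELinearLayoutSketch(16, 32)
x = torch.randn(4, 8, 16)                      # (batch, seq, hidden) in this sketch
print(layer(x).shape, layer.is_first_microbatch)   # torch.Size([4, 8, 32]) False

Because both subclasses inherit from TELinear, the parameter-tagging hook in internlm/train/pipeline.py only needs the base class in its isinstance check, which is exactly what the last hunk switches to.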