# Parallelism layout for this run, written out as nested dict literals.
parallel = {
    "zero1": {"size": -1},
    "tensor": {
        "size": 1,
        "mode": "mtp",
        # Switch for tensor-parallel communication/GEMM overlap, plus the
        # per-feature flags consulted when the switch is on.
        "tp_overlap": False,
        "tp_overlap_cfg": {
            "tp_comm_overlap_ag": True,
            "tp_comm_overlap_rs": True,
            "tp_comm_bulk_wgrad": True,
            "tp_comm_bulk_dgrad": True,
        },
    },
    "pipeline": {"size": 1, "interleaved_overlap": True},
    "weight": {"size": 1, "overlap": True},
}
= False else: diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index 71c30d00d..afa31d412 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -19,6 +19,7 @@ from internlm.initialize.initialize_trainer import initialize_trainer from internlm.model.losses.ce_loss import InternLoss from internlm.model.metrics import AccPerplex +from internlm.model.modules.utils import is_te_min_version from internlm.monitor.monitor import send_alert_message from internlm.train.pipeline import ( get_scheduler_hooks, @@ -154,6 +155,9 @@ def __init__( scheduler_hooks=get_scheduler_hooks(self.metric, optimizer, isp_communicator), ) + if gpc.config.parallel["tensor"].get("tp_overlap", False): + self._initialize_tp_comm_ub() + # set attributes self._set_attributes( kwargs["profiling"], train_dl, val_dls, train_state, optimizer, beta2_scheduler, isp_communicator @@ -249,6 +253,35 @@ def _initialize_batch_skipper(self, train_state) -> BatchSkipper: skip_batches = streaming_simple_resume(train_state) return BatchSkipper(skip_batches) + def _initialize_tp_comm_ub(self): + """initializing the communicators with user buffers for high-performance tensor-model-parallel + communication overlap""" + try: + from transformer_engine.pytorch import module as te_module + + except ImportError: + raise RuntimeError( + "Tensor Parallel Communication/GEMM Overlap optimization needs 'transformer_engine' package" + ) + + input_shape = [gpc.config.data["seq_len"] * gpc.config.data["micro_bsz"], gpc.config.model["hidden_size"]] + + if is_te_min_version("1.9.0"): + # The process group with the target bootstrap backend is created in Transformer Engine. + te_module.base.initialize_ub( + shape=input_shape, + tp_size=gpc.config.parallel["tensor"]["size"], + use_fp8=False, + bootstrap_backend="nccl", + ) + else: + # Create a MPI process group to help with TP communication overlap bootstrap. 
+ torch.distributed.new_group(backend="mpi") + + te_module.base.initialize_ub( + shape=input_shape, tp_size=gpc.config.parallel["tensor"]["size"], use_fp8=False + ) + def _set_attributes(self, profiling, train_dl, val_dls, train_state, optimizer, beta2_scheduler, isp_communicator): self.profiling = profiling self.train_dl = train_dl diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index e1cb2f0d2..831fe3b0e 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -101,7 +101,9 @@ def args_sanity_check(): gpc.config.parallel.pipeline._add_item("mode", "1F1B") if "tensor" not in gpc.config.parallel: - gpc.config.parallel._add_item("tensor", dict(size=1, mode=TensorParallelMode.mtp.name)) + gpc.config.parallel._add_item( + "tensor", dict(size=1, mode=TensorParallelMode.mtp.name, tp_overlap=False, tp_overlap_cfg=None) + ) if "weight" not in gpc.config.parallel: gpc.config.parallel._add_item( @@ -398,7 +400,9 @@ def args_sanity_check(): # set default value for tensor parallel if isinstance(gpc.config.parallel["tensor"], int): - gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name) + gpc.config.parallel["tensor"] = dict( + size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name, tp_overlap=False, tp_overlap_cfg=None + ) if gpc.config.parallel["tensor"].get("mode", None) is None: gpc.config.parallel["tensor"]["mode"] = TensorParallelMode.mtp.name if gpc.config.parallel["tensor"]["mode"] == TensorParallelMode.isp.name: @@ -455,6 +459,22 @@ def args_sanity_check(): if gpc.config.model.get("parallel_output", False) is False: logger.warning("When enable sequence parallel, it recommend to enable parallel_output") + if gpc.config.parallel["tensor"].get("tp_overlap", None) is None: + gpc.config.parallel["tensor"]["tp_overlap"] = False + elif gpc.config.parallel["tensor"].get("tp_overlap", None) is True: + assert gpc.config.parallel["tensor"].get("mode", None) in [ 
+ TensorParallelMode.msp.name, + TensorParallelMode.fsp.name, + ], "tp_overlap can be set to true only in msp and fsp mode" + + if gpc.config.parallel["tensor"].get("tp_overlap_cfg", None) is None: + gpc.config.parallel["tensor"]["tp_overlap_cfg"] = dict( + tp_comm_overlap_ag=True, + tp_comm_overlap_rs=True, + tp_comm_bulk_wgrad=True, + tp_comm_bulk_dgrad=True, + ) + # set default value for weight parallel if gpc.config.parallel["weight"].get("overlap", None) is None: gpc.config.parallel["weight"]["overlap"] = False diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py index 0453b9dcb..bee7e4f98 100644 --- a/internlm/model/modeling_internlm2.py +++ b/internlm/model/modeling_internlm2.py @@ -535,11 +535,30 @@ def load_hf_weights(folder: str, model: nn.Module) -> None: dim=0, )[local_rank] else: - state_dict[f"layers.{i}.attention.wqkv.weight"] = torch.chunk( - state_dict.pop(f"model.layers.{layer_ids}.attention.wqkv.weight"), - split_size, - dim=0, - )[local_rank] + key = f"model.layers.{layer_ids}.attention.wqkv.weight" + if key in state_dict: + state_dict[f"layers.{i}.attention.wqkv.weight"] = torch.chunk( + state_dict.pop(key), + split_size, + dim=0, + )[local_rank] + else: + wq = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.attention.wq.weight"), + split_size, + dim=0, + )[local_rank] + wk = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.attention.wk.weight"), + split_size, + dim=0, + )[local_rank] + wv = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.attention.wv.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wqkv.weight"] = torch.cat([wq, wk, wv], dim=0) wo_name = "self_attn.o_proj" if is_internlm3 else "attention.wo" state_dict[f"layers.{i}.attention.wo.weight"] = torch.chunk( state_dict.pop(f"model.layers.{layer_ids}.{wo_name}.weight"), diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 6f1902689..a32fe3386 100644 --- 
if has_te:

    class TELinear(te.pytorch.Linear):
        """
        Wrapper for the Transformer-Engine's `Linear` layer.

        Args:
            in_features (int): size of each input sample.
            out_features (int): size of each output sample.
            bias (bool): whether the layer owns an additive bias.
            skip_bias_add (bool): when True (and ``bias`` is True) the bias is not applied by
                the layer; ``forward`` returns it alongside the output instead.
            is_expert (bool): must be False; MoE experts are rejected.
            split_mode (str): forwarded to TE as ``parallel_mode``. "column" and "row" also
                select which overlap flags are passed when tp_overlap is enabled.
            tp_comm_buffer_name (str): user-buffer name passed to TE as ``ub_name``.
                Required when ``tp_overlap`` is enabled.

        Raises:
            ValueError: if ``is_expert`` is True.
            NotImplementedError: if tp_overlap is enabled but the installed
                transformer_engine is not newer than 1.5.0.
            AssertionError: if tp_overlap is enabled and ``tp_comm_buffer_name`` is None.
        """

        def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool,
            skip_bias_add: bool,
            is_expert: bool,
            split_mode: str = "none",
            tp_comm_buffer_name: str = None,
        ):
            if is_expert:
                raise ValueError("Transformer Engine linear layers do not yet support MoE")

            # TE returns a zero length Tensor when bias=False and
            # return_bias=True. Here we need a single Tensor, so only ask TE
            # to return the bias when a bias actually exists.
            self.te_return_bias = skip_bias_add and bias
            self.is_first_microbatch = True

            extra_kwargs = {
                "params_dtype": gpc.config.model.dtype,
                "device": torch.cuda.current_device(),
            }

            tensor_cfg = gpc.config.parallel["tensor"]
            if tensor_cfg.get("tp_overlap", False):
                # The gate is strict: version 1.5.0 itself does not pass.
                if not is_te_min_version("1.5.0", check_equality=False):
                    raise NotImplementedError("tp overlap is supported only when transformer_engine version > 1.5.0")

                overlap_cfg = tensor_cfg["tp_overlap_cfg"]
                extra_kwargs["ub_overlap_ag"] = overlap_cfg.get("tp_comm_overlap_ag", True)
                if split_mode == "column":
                    extra_kwargs["ub_bulk_wgrad"] = overlap_cfg.get("tp_comm_bulk_wgrad", True)
                    extra_kwargs["ub_bulk_dgrad"] = overlap_cfg.get("tp_comm_bulk_dgrad", True)
                elif split_mode == "row":
                    extra_kwargs["ub_overlap_rs"] = overlap_cfg.get("tp_comm_overlap_rs", True)

                assert (
                    tp_comm_buffer_name is not None
                ), "Buffer name should be set to configure communication overlap settings"
                extra_kwargs["ub_name"] = tp_comm_buffer_name

            # Process group and world size of the tensor-split parallel mode for this layer.
            parallel_mode = get_tensor_split_parallel_mode(is_expert=is_expert)
            tp_size = gpc.get_world_size(parallel_mode)
            tp_group = gpc.get_group(parallel_mode)

            super().__init__(
                in_features=in_features,
                out_features=out_features,
                sequence_parallel=gpc.config.parallel.sequence_parallel,
                tp_group=tp_group,
                tp_size=tp_size,
                bias=bias,
                return_bias=self.te_return_bias,
                parallel_mode=split_mode,
                **extra_kwargs,
            )

        def forward(self, x):
            """Run the TE linear on ``x`` with its first two dims swapped, then swap them back.

            Returns:
                The output tensor, or ``(output, bias)`` when the layer was built with
                ``skip_bias_add=True`` and ``bias=True``.
            """
            _is_first_microbatch = self.is_first_microbatch
            # NOTE(review): presumably converts (batch, seq, hidden) into the
            # (seq, batch, hidden) layout TE works on — confirm against callers.
            out = super().forward(x.transpose(0, 1), is_first_microbatch=_is_first_microbatch)
            self.is_first_microbatch = False

            if self.te_return_bias:
                # With return_bias=True TE hands back (output, bias); only the
                # output carries the swapped leading dims.
                out, bias_tensor = out
                return out.transpose(0, 1), bias_tensor

            return out.transpose(0, 1)

    class TEColumnParallelLinear(TELinear):
        """
        `TELinear` preconfigured with ``split_mode="column"``.
        """

        def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool,
            skip_bias_add: bool,
            is_expert: bool = False,
            tp_comm_buffer_name: str = None,
        ):
            super().__init__(
                in_features,
                out_features,
                bias=bias,
                skip_bias_add=skip_bias_add,
                is_expert=is_expert,
                split_mode="column",
                tp_comm_buffer_name=tp_comm_buffer_name,
            )

    class TERowParallelLinear(TELinear):
        """
        `TELinear` preconfigured with ``split_mode="row"``.
        """

        def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool,
            skip_bias_add: bool,
            is_expert: bool = False,
            tp_comm_buffer_name: str = None,
        ):
            super().__init__(
                in_features,
                out_features,
                bias=bias,
                skip_bias_add=skip_bias_add,
                is_expert=is_expert,
                split_mode="row",
                tp_comm_buffer_name=tp_comm_buffer_name,
            )

else:
    # transformer_engine is not importable: keep the TE* names defined by
    # aliasing them to the native implementations.
    TELinear = ParallelLinearWithCommExt
    TEColumnParallelLinear = ColumnParallelLinear
    TERowParallelLinear = RowParallelLinear
split_mode == "grouped_wp": return GroupedWPLinear( in_features, diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py index 42418a212..db2ae9c13 100644 --- a/internlm/model/modules/mha.py +++ b/internlm/model/modules/mha.py @@ -157,8 +157,16 @@ def __init__( if self.enable_qkv_fusion: # bias=True is according to https://spaces.ac.cn/archives/9577 - self.wqkv = new_linear("wqkv", embed_dim, 3 * embed_dim, bias, **factory_kwargs) + if gpc.config.parallel["tensor"].get("tp_overlap", False): + self.wqkv = new_linear( + "wqkv", embed_dim, 3 * embed_dim, bias, tp_comm_buffer_name="qkv", **factory_kwargs + ) + else: + self.wqkv = new_linear("wqkv", embed_dim, 3 * embed_dim, bias, **factory_kwargs) else: + assert ( + gpc.config.parallel["tensor"].get("tp_overlap", False) is False + ), "tp overlap currently only support fused wqkv." self.wq = new_linear("wq", embed_dim, embed_dim, bias, **factory_kwargs) self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs) self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs) @@ -167,7 +175,12 @@ def __init__( self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) # output projection always have the bias (for now) (except for baichuan2 model) - self.out_proj = new_linear("out_proj", embed_dim, embed_dim, bias=out_bias, **factory_kwargs) + if gpc.config.parallel["tensor"].get("tp_overlap", False): + self.out_proj = new_linear( + "out_proj", embed_dim, embed_dim, bias=out_bias, tp_comm_buffer_name="proj", **factory_kwargs + ) + else: + self.out_proj = new_linear("out_proj", embed_dim, embed_dim, bias=out_bias, **factory_kwargs) def register_checkpoint_compatibility_hooks( self, pre_load_hook: Optional[Callable] = None, pre_save_hook: Optional[Callable] = None @@ -193,7 +206,6 @@ def _training(self, x, **kwargs): if self.enable_qkv_fusion: qkv = self.wqkv(x) qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, 
d=self.head_dim) - q = qkv[:, :, 0].squeeze(2) k = qkv[:, :, 1].squeeze(2) v = qkv[:, :, 2].squeeze(2) @@ -214,8 +226,8 @@ def _training(self, x, **kwargs): if gpc.config.data.use_packed_dataset is False or self.training is False: kwargs.pop("max_seqlen_q", None) kwargs.pop("max_seqlen_k", None) - context = self.inner_attn(q, k, v, **kwargs) + context = self.inner_attn(q, k, v, **kwargs) # wo return self.out_proj(rearrange(context, "b s h d -> b s (h d)")) @@ -461,12 +473,20 @@ def __init__( if enable_qkv_fusion: assert bias is False, "Fuesd wqkv only support bias is False." - self.wqkv = new_linear("wqkv", embed_dim, q_dim + 2 * self.kv_dim, bias, **factory_kwargs) + if gpc.config.parallel["tensor"].get("tp_overlap", False): + self.wqkv = new_linear( + "wqkv", embed_dim, q_dim + 2 * self.kv_dim, bias, tp_comm_buffer_name="qkv", **factory_kwargs + ) + else: + self.wqkv = new_linear("wqkv", embed_dim, q_dim + 2 * self.kv_dim, bias, **factory_kwargs) self._register_load_state_dict_pre_hook( partial(_qkv_pre_load_convert, q_dim=q_dim, kv_dim=self.kv_dim), with_module=True ) self._register_state_dict_hook(partial(_qkv_save_convert, q_dim=q_dim, kv_dim=self.kv_dim)) else: + assert ( + gpc.config.parallel["tensor"].get("tp_overlap", False) is False + ), "tp overlap currently only support fused wqkv." 
self.wq = new_linear("wq", embed_dim, q_dim, bias, **factory_kwargs) self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs) self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs) @@ -478,7 +498,10 @@ def __init__( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout, layer_idx=layer_idx ) - self.wo = new_linear("wo", q_dim, embed_dim, bias, **factory_kwargs) + if gpc.config.parallel["tensor"].get("tp_overlap", False): + self.wo = new_linear("wo", q_dim, embed_dim, bias, tp_comm_buffer_name="proj", **factory_kwargs) + else: + self.wo = new_linear("wo", q_dim, embed_dim, bias, **factory_kwargs) def register_checkpoint_compatibility_hooks( self, pre_load_hook: Optional[Callable] = None, pre_save_hook: Optional[Callable] = None diff --git a/internlm/model/modules/mlp.py b/internlm/model/modules/mlp.py index e51e5897f..835d649e4 100644 --- a/internlm/model/modules/mlp.py +++ b/internlm/model/modules/mlp.py @@ -6,6 +6,7 @@ import torch from torch import nn +from internlm.core.context import global_context as gpc from internlm.model.modules.linear import new_linear from internlm.model.modules.utils import Gelu, Silu from internlm.utils.logger import get_logger @@ -86,16 +87,41 @@ def __init__( if self.mlp_layer_fusion: assert bias is False, "Fuesd FeedForward only support bias is False." 
- self.fused_w1_w3 = new_linear( - "w13", in_features, hidden_features * 2, bias, device=device, dtype=dtype, is_expert=is_expert - ) - self.w2 = new_linear( - "w2", hidden_features, out_features, bias, device=device, dtype=dtype, is_expert=is_expert - ) + if gpc.config.parallel["tensor"].get("tp_overlap", False): + self.fused_w1_w3 = new_linear( + "w13", + in_features, + hidden_features * 2, + bias, + device=device, + dtype=dtype, + is_expert=is_expert, + tp_comm_buffer_name="fc1", + ) + self.w2 = new_linear( + "w2", + hidden_features, + out_features, + bias, + device=device, + dtype=dtype, + is_expert=is_expert, + tp_comm_buffer_name="fc2", + ) + else: + self.fused_w1_w3 = new_linear( + "w13", in_features, hidden_features * 2, bias, device=device, dtype=dtype, is_expert=is_expert + ) + self.w2 = new_linear( + "w2", hidden_features, out_features, bias, device=device, dtype=dtype, is_expert=is_expert + ) self._register_load_state_dict_pre_hook(_mlp_pre_load_convert, with_module=True) self._register_state_dict_hook(_mlp_save_convert) else: + assert ( + gpc.config.parallel["tensor"].get("tp_overlap", False) is False + ), "tp overlap currently only support fused mlp." 
# Parsed transformer-engine version; filled in by the first successful lookup.
_TE_VERSION = None


def get_te_version():
    """Return the installed Transformer Engine version, or ``None`` when it is unavailable.

    The version is read from ``transformer_engine.__version__`` when present, otherwise
    from the installed ``transformer-engine`` distribution metadata. A successfully
    parsed version is cached at module level, so later calls skip the lookup.

    Returns:
        A ``packaging.version.Version`` instance, or ``None`` if ``transformer_engine``
        cannot be imported.
    """
    global _TE_VERSION
    if _TE_VERSION is not None:
        return _TE_VERSION

    try:
        import transformer_engine as te
    except ImportError:  # also covers ModuleNotFoundError, which subclasses it
        # Do not feed None into PkgVersion: it raises TypeError instead of
        # signalling "not installed".
        return None

    if hasattr(te, "__version__"):
        version_str = str(te.__version__)
    else:
        version_str = version("transformer-engine")

    _TE_VERSION = PkgVersion(version_str)
    return _TE_VERSION


def is_te_min_version(version, check_equality=True):
    """Check if minimum version of `transformer-engine` is installed.

    Args:
        version (str): version string to compare the installed version against.
        check_equality (bool): when True the comparison is ``>=``; when False it is a
            strict ``>``.

    Returns:
        bool: False when transformer-engine is not installed, otherwise the result of
        the comparison.
    """
    installed = get_te_version()
    if installed is None:
        return False

    required = PkgVersion(version)
    if check_equality:
        return installed >= required
    return installed > required
module - elif isinstance(module, ParallelLinearWithCommExt): + elif isinstance(module, (ParallelLinearWithCommExt, TELinear)): for param in module.parameters(): if gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp(): setattr(param, IS_TENSOR_ZERO_PARALLEL, True)