diff --git a/configs/7B_internlm2.py b/configs/7B_internlm2.py index 52bd70fc3..97758bba4 100644 --- a/configs/7B_internlm2.py +++ b/configs/7B_internlm2.py @@ -174,6 +174,7 @@ 1. size: int, the size of pipeline parallel. 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, defaults to False. + 3. mode: str, the pipeline parallel mode, should be in ['1f1b', 'zbh1', 'zbv']. The default is 1f1b. weight parallel (dict): 1. size: int, the size of weight parallel. 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. @@ -181,7 +182,7 @@ parallel = dict( zero1=dict(size=-1), tensor=dict(size=2, mode="isp"), - pipeline=dict(size=1, interleaved_overlap=True), + pipeline=dict(size=1, interleaved_overlap=True, mode="1f1b"), weight=dict(size=2, overlap=True), ) diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 44f94c1f3..4799b5f35 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -186,7 +186,6 @@ 1. size: int, the size of pipeline parallel (Default is 1F1B). 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, defaults to False. - 3. zero_bubble: bool, enable/disable zero bubble pipeline parallelism (ZB-H1), defaults to False. weight parallel (dict): 1. size: int, the size of weight parallel. 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. @@ -194,7 +193,7 @@ parallel = dict( zero1=dict(size=-1), tensor=dict(size=1, mode="mtp"), - pipeline=dict(size=1, interleaved_overlap=True, zero_bubble=False), + pipeline=dict(size=1, interleaved_overlap=True), weight=dict(size=1, overlap=True), ) diff --git a/doc/code-docs/locales/en/LC_MESSAGES/parallel.po b/doc/code-docs/locales/en/LC_MESSAGES/parallel.po index b948e4f9b..df6ff3f78 100644 --- a/doc/code-docs/locales/en/LC_MESSAGES/parallel.po +++ b/doc/code-docs/locales/en/LC_MESSAGES/parallel.po @@ -563,8 +563,8 @@ msgstr "" msgid "返回类型" msgstr "Return type" -#: internlm.core.scheduler.pipeline_scheduler.InterleavedPipelineScheduler.forward_backward_step:19 -#: internlm.core.scheduler.pipeline_scheduler.PipelineScheduler.forward_backward_step:19 +#: internlm.core.scheduler.pipeline_scheduler_1f1b.InterleavedPipelineScheduler.forward_backward_step:19 +#: internlm.core.scheduler.pipeline_scheduler_1f1b.PipelineScheduler.forward_backward_step:19 #: of msgid "Tuple[:class:`torch.Tensor`]" msgstr "" @@ -579,11 +579,11 @@ msgstr "" "To use interleaved pipeline scheduler, users need to set " "``model.num_chunks > 1`` in the config file." -#: internlm.core.scheduler.pipeline_scheduler.InterleavedPipelineScheduler:1 of +#: internlm.core.scheduler.pipeline_scheduler_1f1b.InterleavedPipelineScheduler:1 of msgid "Interleaved Pipeline Scheduler." msgstr "" -#: internlm.core.scheduler.pipeline_scheduler.InterleavedPipelineScheduler.forward_backward_step:1 +#: internlm.core.scheduler.pipeline_scheduler_1f1b.InterleavedPipelineScheduler.forward_backward_step:1 #: of msgid "" "Run interleaved 1F1B schedule (model split into model chunks), with " diff --git a/doc/code-docs/source/parallel.rst b/doc/code-docs/source/parallel.rst index 1ad9ff639..3f893af10 100644 --- a/doc/code-docs/source/parallel.rst +++ b/doc/code-docs/source/parallel.rst @@ -137,14 +137,14 @@ InternEvo 在流水线并行中使用 `1F1B 1`` 。 -.. 
autoclass:: internlm.core.scheduler.pipeline_scheduler_1f1b.InterleavedPipelineScheduler :members: 值得注意的是,在使用交错式流水线调度器时可启用通信优化功能,即在 1F1B 阶段启用异步通信,以充分利用上行/下行带宽并实现通信与计算重叠。 diff --git a/doc/en/structure.md b/doc/en/structure.md index 7a37ef30c..8d9d726c4 100644 --- a/doc/en/structure.md +++ b/doc/en/structure.md @@ -12,7 +12,8 @@ The system code file structure is shown below: │ │ │ └── process_group_initializer.py │ │ ├── scheduler # Scheduling module, which manages schedulers for parallel training, including non-pipeline and pipeline parallel schedulers │ │ │ ├── no_pipeline_scheduler.py -│ │ │ └── pipeline_scheduler.py +│ │ │ ├── pipeline_scheduler_1f1b.py +│ │ │ └── pipeline_scheduler_zb.py │ │ ├── engine.py # Responsible for managing the training and evaluation process of the model │ │ └── trainer.py # Responsible for managing the training engine and scheduler │ ├── data # Data module, responsible for managing dataset generation and processing diff --git a/doc/structure.md b/doc/structure.md index 7c0c6f134..9de1016b1 100644 --- a/doc/structure.md +++ b/doc/structure.md @@ -12,7 +12,8 @@ │ │ │ └── process_group_initializer.py │ │ ├── scheduler # 调度模块,管理并行训练的调度器,包括非流水线并行调度器和流水线并行调度器 │ │ │ ├── no_pipeline_scheduler.py -│ │ │ └── pipeline_scheduler.py +│ │ │ ├── pipeline_scheduler_1f1b.py +│ │ │ └── pipeline_scheduler_zb.py │ │ ├── engine.py # 负责管理模型的训练和评估过程 │ │ └── trainer.py # 负责管理训练引擎和调度器 │ ├── data # 数据模块,负责管理数据集生成和处理 diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index 2f34785ac..989b1c00f 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -163,6 +163,7 @@ def __init__(self): self.virtual_pipeline_parallel_rank = None self._expert_parallel_group_names = [] self.is_evaluating = False + self.v_shape = False @property def config(self): @@ -292,8 +293,13 @@ def is_rank_for_log(self): and self.is_first_rank(ParallelMode.WEIGHT) and self.is_first_rank(ParallelMode.DATA) and self.is_first_rank(ParallelMode.WEIGHT_DATA) - and self.is_last_rank(ParallelMode.PIPELINE) ) + + if not self.v_shape: + is_log_rank = is_log_rank and self.is_last_rank(ParallelMode.PIPELINE) + else: + is_log_rank = is_log_rank and self.is_first_rank(ParallelMode.PIPELINE) + return is_log_rank def is_last_rank(self, parallel_mode: ParallelMode): @@ -327,11 +333,17 @@ def is_pipeline_last_stage(self, ignore_virtual=False): and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1 ): return False - return self.is_last_rank(ParallelMode.PIPELINE) + if not self.v_shape: + return self.is_last_rank(ParallelMode.PIPELINE) + else: + return self.is_first_rank(ParallelMode.PIPELINE) def is_no_pp_or_last_stage(self): # NOTICE!!!, this will ignore virutal stage - return not self.is_initialized(ParallelMode.PIPELINE) or self.is_last_rank(ParallelMode.PIPELINE) + if not self.v_shape: + return not self.is_initialized(ParallelMode.PIPELINE) or self.is_last_rank(ParallelMode.PIPELINE) + else: + return not self.is_initialized(ParallelMode.PIPELINE) or self.is_first_rank(ParallelMode.PIPELINE) def get_world_size(self, parallel_mode: ParallelMode): """Returns the world size for `parallel_mode`. 
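[Editor's note, not part of the patch] The gpc.v_shape branches above reflect ZB-V's V-shaped chunk placement: chunk 0 is laid out across pipeline ranks 0..P-1 and chunk 1 across ranks P-1..0, so the final layers, and therefore the loss, end up on pipeline rank 0. That is why the "last stage" and logging-rank checks switch to is_first_rank(ParallelMode.PIPELINE) when v_shape is set. Below is a minimal, self-contained sketch of that layout under assumed values (32 layers, 4 pipeline ranks, 2 chunks); the helper name v_shape_layer_ranges is hypothetical, and the real partitioning is done by partition_uniform in internlm/core/parallel/shard.py (see the next file in this diff).

# Illustrative sketch only (not patch code): V-shaped layer placement for ZB-V.
def v_shape_layer_ranges(num_layers: int = 32, pp_size: int = 4):
    """Return {pp_rank: [(start, end) layer slice per chunk]} for a V-shaped layout."""
    chunk_size = num_layers // (2 * pp_size)  # assume layers divide evenly for simplicity
    ranges = {rank: [] for rank in range(pp_size)}
    for chunk in range(2):
        for idx in range(pp_size):
            # chunk 0 ascends ranks 0..P-1; chunk 1 descends P-1..0, forming the "V"
            rank = idx if chunk == 0 else pp_size - 1 - idx
            start = (chunk * pp_size + idx) * chunk_size
            ranges[rank].append((start, start + chunk_size))
    return ranges

if __name__ == "__main__":
    for rank, parts in v_shape_layer_ranges().items():
        print(f"pp rank {rank}: layer slices {parts}")
    # pp rank 0 owns both the first and the last slice, so with gpc.v_shape the
    # loss (and the log rank) lives on the first pipeline rank, not the last.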
diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py index 41d286b59..97b9800c6 100644 --- a/internlm/core/parallel/comm/isp.py +++ b/internlm/core/parallel/comm/isp.py @@ -692,6 +692,12 @@ def after_backward(self, scheduler, inputs_grad) -> None: # pylint: disable=W06 if self._isp_communicator and self._isp_communicator.overlap: self._zero_optim.accumulate_left_grads_after_backward() + if ( + getattr(gpc.config.parallel["pipeline"], "mode", "1F1B").upper() in ["ZBV", "ZBH1"] + and not self._zero_optim.skip_grad_reduce + ): + self._zero_optim.reduce_left_grads_after_backward() + def post_helper_func(self, scheduler, outputs, label) -> None: # pylint: disable=W0613 pass diff --git a/internlm/core/parallel/shard.py b/internlm/core/parallel/shard.py index 3c3f3fb3b..f4ac37bcb 100644 --- a/internlm/core/parallel/shard.py +++ b/internlm/core/parallel/shard.py @@ -186,10 +186,16 @@ def partition_uniform(num_items: int, pipeline_parallel_size: int, num_chunks: i if chunk_size == 0: raise ValueError("Some nodes in Pipeline have no requests") - for p in range(pipeline_parallel_size): - st = base_idx - base_idx += chunk_size + (p >= left) - parts[p].append((st, base_idx)) + if getattr(gpc.config.parallel["pipeline"], "mode", "1F1B").upper() == "ZBV" and idx == 1: + for p in range(pipeline_parallel_size - 1, -1, -1): + st = base_idx + base_idx += chunk_size + ((pipeline_parallel_size - p - 1) >= left) + parts[p].append((st, base_idx)) + else: + for p in range(pipeline_parallel_size): + st = base_idx + base_idx += chunk_size + (p >= left) + parts[p].append((st, base_idx)) indexes = [] for _parts in parts: diff --git a/internlm/core/scheduler/__init__.py b/internlm/core/scheduler/__init__.py index 2ddd85299..042fa5960 100644 --- a/internlm/core/scheduler/__init__.py +++ b/internlm/core/scheduler/__init__.py @@ -1,9 +1,9 @@ from .base_scheduler import BaseScheduler from .no_pipeline_scheduler import NonPipelineScheduler -from .pipeline_scheduler import ( - InterleavedPipelineScheduler, - PipelineScheduler, +from .pipeline_scheduler_1f1b import InterleavedPipelineScheduler, PipelineScheduler +from .pipeline_scheduler_zb import ( ZeroBubblePipelineScheduler, + ZeroBubblePipelineVShapeScheduler, ) __all__ = [ @@ -12,4 +12,5 @@ "InterleavedPipelineScheduler", "PipelineScheduler", "ZeroBubblePipelineScheduler", + "ZeroBubblePipelineVShapeScheduler", ] diff --git a/internlm/core/scheduler/comm/__init__.py b/internlm/core/scheduler/comm/__init__.py index a42b9ea16..0037c097a 100644 --- a/internlm/core/scheduler/comm/__init__.py +++ b/internlm/core/scheduler/comm/__init__.py @@ -1,13 +1,12 @@ from .p2p import ( AsynCommunicator, + fused_send_recv_tensor, recv_backward, recv_forward, send_backward, - send_backward_and_recv_next_backward_async, send_backward_recv_backward, send_backward_recv_forward, send_forward, - send_forward_and_recv_next_forward_async, send_forward_backward_recv_forward_backward, send_forward_recv_backward, send_forward_recv_forward, @@ -26,7 +25,6 @@ "recv_forward", "send_obj_meta", "recv_obj_meta", - "send_backward_and_recv_next_backward_async", - "send_forward_and_recv_next_forward_async", "AsynCommunicator", + "fused_send_recv_tensor", ] diff --git a/internlm/core/scheduler/comm/p2p.py b/internlm/core/scheduler/comm/p2p.py index 56dcd553a..54fb587c0 100644 --- a/internlm/core/scheduler/comm/p2p.py +++ b/internlm/core/scheduler/comm/p2p.py @@ -167,13 +167,12 @@ def _communicate( if object_send_next is not None: filling_ops_queue(object_send_next, 
dist.isend, next_rank, ops) - if len(ops) > 0: reqs = dist.batch_isend_irecv(ops) for req in reqs: req.wait() - # To protect against race condition when using batch_isend_irecv(). - internlm_accelerator.synchronize() + # To protect against race condition when using batch_isend_irecv(). + internlm_accelerator.synchronize() if recv_prev and recv_prev_split: if isinstance(tensor_recv_prev, torch.Tensor): @@ -196,6 +195,121 @@ return tensor_recv_prev, tensor_recv_next + +def _communicate_async( + object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None, + object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None, + recv_prev: bool = False, + recv_next: bool = False, + recv_prev_shape: Union[torch.Size, List[torch.Size]] = None, + recv_next_shape: Union[torch.Size, List[torch.Size]] = None, + prev_rank: int = None, + next_rank: int = None, + dtype: torch.dtype = None, + scatter_gather_tensors: bool = False, +): + """ + Adapted from megatron.p2p_communication. + Communicate tensors between stages. Used as helper method in other + communication methods that are used in pipeline schedule. + Takes the following arguments: + object_send_next (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to next rank + (no tensor sent if set to None). + object_send_prev (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to prev rank + (no tensor sent if set to None). + recv_prev (bool): boolean for whether tensor should be received from + previous rank. + recv_next (bool): boolean for whether tensor should be received from + next rank. + recv_prev_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received + from the previous stage, defaults to None. + recv_next_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received + from the next stage, defaults to None. + prev_rank (int): the rank of the previous pipeline stage, defaults to None, + next_rank (int): the rank of the next pipeline stage, defaults to None, + dtype (torch.dtype): data type of intermediate buffers, defaults to None + scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False + + Returns: + Tuple[Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]]: returns tensor_recv_prev, tensor_recv_next + """ + + # Create placeholder tensors for receive in forward and backward directions + # if needed. 
+ tensor_recv_prev = None + tensor_recv_next = None + + if recv_prev: + assert recv_prev_shape is not None + tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes( + recv_prev_shape, dtype, scatter_gather_tensors + ) + + if recv_next: + assert recv_next_shape is not None + tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes( + recv_next_shape, dtype, scatter_gather_tensors + ) + + if object_send_prev is not None or recv_prev: + if prev_rank is None: + prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) + + if object_send_next is not None or recv_next: + if next_rank is None: + next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) + + if object_send_prev is not None: + object_send_prev = process_object_to_send(object_send_prev, scatter_gather_tensors) + + if object_send_next is not None: + object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors) + + ops = [] + if object_send_prev is not None: + filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops) + + if tensor_recv_prev is not None: + filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops) + + if tensor_recv_next is not None: + filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops) + + if object_send_next is not None: + filling_ops_queue(object_send_next, dist.isend, next_rank, ops) + + if len(ops) > 0: + reqs = dist.batch_isend_irecv(ops) + + # return and do other things + yield + + if len(ops) > 0: + for req in reqs: # pylint: disable=E0601 + req.wait() + # To protect against race condition when using batch_isend_irecv(). + internlm_accelerator.synchronize() + + if recv_prev and recv_prev_split: + if isinstance(tensor_recv_prev, torch.Tensor): + tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_() + else: + for index in range(len(tensor_recv_prev)): + tensor_recv_prev[index] = ( + gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_() + ) + + if recv_next and recv_next_split: + if isinstance(tensor_recv_next, torch.Tensor): + tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_() + else: + for index in range(len(tensor_recv_next)): + tensor_recv_next[index] = ( + gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_() + ) + + yield tensor_recv_prev, tensor_recv_next + + def recv_forward( input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False ) -> Union[torch.Tensor, List[torch.Tensor]]: @@ -426,126 +540,52 @@ def send_forward_backward_recv_forward_backward( return input_tensor, output_tensor_grad -def send_forward_and_recv_next_forward_async( - output_tensor, +def fused_send_recv_tensor( + object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None, + object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None, recv_prev_shape: Union[torch.Size, List[torch.Size]] = None, - dtype: torch.dtype = None, - scatter_gather_tensors=False, -): - """send forward output to next rank and recv forward input from prev rank""" - - reqs = [] - tensor_recv_prev = None - - # prepare send opreations - if output_tensor is not None: - next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) - - output_tensor = process_object_to_send(output_tensor, scatter_gather_tensors) - - if isinstance(output_tensor, torch.Tensor): - reqs.append(dist.P2POp(dist.isend, output_tensor, next_rank)) - else: - for tensor_to_comm in output_tensor: - 
reqs.append(dist.P2POp(dist.isend, tensor_to_comm, next_rank)) - - # prepare receive opreations - if recv_prev_shape is not None: - prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) - # create receive buffer - tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes( - recv_prev_shape, dtype, scatter_gather_tensors - ) - # generate async receive opterations - if isinstance(tensor_recv_prev, torch.Tensor): - reqs.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank)) - else: - for tensor_to_comm in tensor_recv_prev: - reqs.append(dist.P2POp(dist.irecv, tensor_to_comm, prev_rank)) - - if len(reqs) > 0: - reqs = dist.batch_isend_irecv(reqs) - - # return and do other things - yield - - # check communication completed - for req in reqs: - req.wait() - # To protect against race condition when using batch_isend_irecv() - internlm_accelerator.synchronize() - - # Process received data - if recv_prev_shape is not None and recv_prev_split: - if isinstance(tensor_recv_prev, torch.Tensor): - tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_() - else: - for index in range(len(tensor_recv_prev)): - tensor_recv_prev[index] = ( - gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_() - ) - - yield tensor_recv_prev - - -def send_backward_and_recv_next_backward_async( - input_tensor, recv_next_shape: Union[torch.Size, List[torch.Size]] = None, + prev_rank: int = None, + next_rank: int = None, dtype: torch.dtype = None, - scatter_gather_tensors=False, -): - reqs = [] - tensor_recv_next = None - - # prepare send opreations - if input_tensor is not None: - prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) - - input_tensor = process_object_to_send(input_tensor, scatter_gather_tensors) - - if isinstance(input_tensor, torch.Tensor): - reqs.append(dist.P2POp(dist.isend, input_tensor, prev_rank)) - else: - for tensor_to_comm in input_tensor: - reqs.append(dist.P2POp(dist.isend, tensor_to_comm, prev_rank)) - - # prepare receive opreations - if recv_next_shape is not None: - next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) - # create receive buffer - tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes( - recv_next_shape, dtype, scatter_gather_tensors - ) - # generate async receive opreations - if isinstance(tensor_recv_next, torch.Tensor): - reqs.append(dist.P2POp(dist.irecv, tensor_recv_next, next_rank)) - else: - for tensor_to_comm in tensor_recv_next: - reqs.append(dist.P2POp(dist.irecv, tensor_to_comm, next_rank)) - - if len(reqs) > 0: - reqs = dist.batch_isend_irecv(reqs) - - # return and do other things - yield - - # check communication completed - for req in reqs: - req.wait() - # To protect against race condition when using batch_isend_irecv() - internlm_accelerator.synchronize() + scatter_gather_tensors: bool = False, +) -> Union[torch.Tensor, List[torch.Tensor]]: + """Fused communication operation. send and recv tensor from next rank or prev rank. 
- # Process received data - if recv_next_shape is not None and recv_next_split: - if isinstance(tensor_recv_next, torch.Tensor): - tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_() - else: - for index in range(len(tensor_recv_next)): - tensor_recv_next[index] = ( - gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_() - ) + Args: + object_send_next (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to next rank + (no tensor sent if set to None). + object_send_prev (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to prev rank + (no tensor sent if set to None). + recv_prev (bool): boolean for whether tensor should be received from + previous rank. + recv_next (bool): boolean for whether tensor should be received from + next rank. + recv_prev_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received + from the previous stage, defualts to None. + recv_next_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received + from the next stage, defualts to None. + prev_rank (int): the rank of the previous pipeline stage, defualts to None, + next_rank (int): the rank of the next pipeline stage, defualts to None, + dtype (torch.dtype): data type of intermediate buffers, defaults to None + scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False - yield tensor_recv_next + Returns: + Tuple[Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]]: returns tensor_recv_prev, tensor_recv_next + """ + tensor_recv_prev, tensor_recv_next = _communicate( + object_send_next=object_send_next, + object_send_prev=object_send_prev, + recv_prev=recv_prev_shape is not None, + recv_next=recv_next_shape is not None, + recv_prev_shape=recv_prev_shape, + recv_next_shape=recv_next_shape, + prev_rank=prev_rank, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) + return tensor_recv_prev, tensor_recv_next class AsynCommunicator: @@ -553,22 +593,28 @@ class AsynCommunicator: def __init__( self, - tensor_to_send: Union[torch.Tensor, List[torch.Tensor]], - recv_shape: Union[torch.Size, List[torch.Size]], + object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None, + object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None, + recv_prev_shape: Union[torch.Size, List[torch.Size]] = None, + recv_next_shape: Union[torch.Size, List[torch.Size]] = None, + prev_rank: int = None, + next_rank: int = None, dtype: torch.dtype = None, - scatter_gather_tensors=False, - forward: bool = True, + scatter_gather_tensors: bool = False, ) -> None: - self._need_receive = recv_shape is not None - - if forward: - self._coroutine = send_forward_and_recv_next_forward_async( - tensor_to_send, recv_shape, dtype, scatter_gather_tensors - ) - else: - self._coroutine = send_backward_and_recv_next_backward_async( - tensor_to_send, recv_shape, dtype, scatter_gather_tensors - ) + self._need_receive = recv_prev_shape is not None or recv_next_shape is not None + self._coroutine = _communicate_async( + object_send_prev=object_send_prev, + object_send_next=object_send_next, + recv_prev=recv_prev_shape is not None, + recv_next=recv_next_shape is not None, + recv_prev_shape=recv_prev_shape, + recv_next_shape=recv_next_shape, + prev_rank=prev_rank, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) 
@property def need_receive(self) -> bool: diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler_1f1b.py similarity index 82% rename from internlm/core/scheduler/pipeline_scheduler.py rename to internlm/core/scheduler/pipeline_scheduler_1f1b.py index a2e089143..4864c77f1 100644 --- a/internlm/core/scheduler/pipeline_scheduler.py +++ b/internlm/core/scheduler/pipeline_scheduler_1f1b.py @@ -3,7 +3,6 @@ # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine -import queue from contextlib import contextmanager from typing import Callable, List, Optional, Tuple, Union @@ -22,7 +21,6 @@ move_to_device, ) from internlm.utils.logger import get_logger -from internlm.utils.parallel import is_using_isp from internlm.utils.timeout import llm_timeout from .base_scheduler import BaseScheduler @@ -110,53 +108,6 @@ def switch_optimizer_grad_sync_skip_mode(optimizer, skip: bool = True): optimizer.skip_grad_reduce = prev_mode -class WeightGradStore: - """ - When using zero bubble pp, WeightGradStore is used to store the args and func for computating weight grad. - """ - - cache = [] - weight_grad_queue = queue.Queue() - - @classmethod - def size(cls): - return cls.weight_grad_queue.qsize() - - @classmethod - def put(cls, weight, bias, input_tensor, grad_output, has_d_bias, grad_compute_func, *args): - assert not gpc.is_first_rank(ParallelMode.PIPELINE), "pp rank 0 should not arrive here" - # Store the weight gradient computation of linear layers. - cls.cache.append((weight, bias, input_tensor, grad_output, has_d_bias, grad_compute_func, *args)) - - @classmethod - def flush(cls): - if gpc.is_first_rank(ParallelMode.PIPELINE): - return - # Collect all stored computations during backward as a W for each micro batch. - cls.weight_grad_queue.put(cls.cache) - cls.cache = [] - - @classmethod - def pop(cls): - if gpc.is_first_rank(ParallelMode.PIPELINE): - return - assert cls.weight_grad_queue.qsize() > 0 - stored_w_grad_computation = cls.weight_grad_queue.get() - # Run computation for a single W. - for weight, bias, input_tensor, grad_output, has_d_bias, grad_compute_func, *args in stored_w_grad_computation: - grad_weight, grad_bias = grad_compute_func(input_tensor, grad_output, has_d_bias) - if is_using_isp(): - isp_grad_hook = args[0] - grad_weight, _ = isp_grad_hook(grad_weight, async_op=False, is_bias=False) - if grad_bias is not None: - grad_bias, _ = isp_grad_hook(grad_bias, async_op=False, is_bias=True) - - # Gradient Accumulation - weight.grad = weight.grad + grad_weight if weight.grad is not None else grad_weight - if has_d_bias: - bias.grad = bias.grad + grad_bias if bias.grad is not None else grad_bias - - class PipelineScheduler(BaseScheduler): """ A helper schedule class for pipeline parallelism running environment. 
@@ -230,7 +181,6 @@ def _call_engine(engine, data): # pylint: disable=W0237 return engine(*data) elif isinstance(data, dict): stage_output = data.pop("stage_output", None) - if stage_output is None: return engine(**data) elif isinstance(stage_output, torch.Tensor): @@ -280,6 +230,7 @@ def load_micro_batch(self): def _get_data_label_for_current_step(self, stage_output, micro_batch_data): if isinstance(micro_batch_data, (tuple, list)): + assert not self._config.parallel["pipeline"].get("mode", "1F1B") == "ZBV" if gpc.is_first_rank(ParallelMode.PIPELINE): # for the first stage, we use the data from the # dataloader output by default @@ -289,6 +240,7 @@ def _get_data_label_for_current_step(self, stage_output, micro_batch_data): # by the previous as the model input data = stage_output _, label = micro_batch_data + # normally this way elif isinstance(micro_batch_data, dict): label = micro_batch_data.pop("label", None) data = {"stage_output": stage_output, **micro_batch_data} @@ -753,260 +705,6 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo return output, label, accum_loss -class ZeroBubblePipelineScheduler(PipelineScheduler): - """ - A helper schedule class for pipeline parallelism running environment. - It uses non-interleaved 1F1B strategy. Other properties are similar as - :class:`NonPipelineSchedule`. - - Args: - num_microbatches (int): The number of microbatches. - dtype (torch.dtype): Type of data. torch.float by default. - data_process_func (Callable, optional): - The post processing function which receives a micro batch of data, and it will be executed - in `load_micro_batch`. - tensor_shape (torch.Size, optional): Specified shape in pipeline communication. - scatter_gather_tensors (bool, optional): - If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. - scheduler_hooks (Optional[List[SchedulerHook]], optional): List of scheduler hooks. - """ - - def __init__( - self, - num_microbatches: int, - dtype: torch.dtype = torch.float, - data_process_func: Callable = None, - tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None, - scatter_gather_tensors: bool = False, - scheduler_hooks: Optional[List[SchedulerHook]] = None, - ): - super().__init__( - num_microbatches, - dtype=dtype, - data_process_func=data_process_func, - tensor_shape=tensor_shape, - scatter_gather_tensors=scatter_gather_tensors, - scheduler_hooks=scheduler_hooks, - ) - - def _forward_backward_step(self, engine, return_loss=True, return_output_label=True): - """ - This function schedules the forward and backward computation of microbatches in the pipeline in a 1F1B manner. - It consists of three stages: warmup, 1F1B, and cooldown. - - 1. Warmup Stage: - The warmup stage performs num_warmup forward microsteps. The calculation of num_warmup is the pipeline length - minus the rank of the current pipeline minus 1. For each microstep, it receives data as input from the previous - stage, performs the forward computation, and then sends the result to the next stage. - - 2. 1F1B Stage: - The 1F1B stage consists of pairs of forward and backward microsteps. It performs num_1f1b_micropairs iterations, - where num_1f1b_micropairs is calculated as the total number of microbatches minus the number of microbatches in - the warmup stage. 
In each iteration, it first performs a forward computation, sends the result to the next - stage, receives input for the backward computation, performs the backward computation, and finally sends the - result to the previous stage to receive input for the next forward computation. - - 3. Cooldown Stage: - The cooldown stage performs the same number of iterations as the warmup stage. In each iteration, it receives - input for the backward computation, performs the backward computation, and finally sends the result to the - previous stage. - - There are two special cases to consider: - 1. The first stage of the pipeline does not need to receive forward input or send backward output. The last - stage does not need to send forward output or receive backward input. - 2. Pay attention to the communication between stages and use additional communication to bridge the gap. - - Args: - engine (Engine): The engine used for computation. - return_loss (bool, optional): Whether to return the accumulated loss. - return_output_label (bool, optional): Whether to return outputs and labels. - - Returns: - Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None], Union[torch.Tensor, None]]: - The output, label, and accumulated loss. - """ - - num_warmup_microsteps = ( - gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1 - ) - num_warmup_microsteps = min(num_warmup_microsteps, self.num_microbatches) - num_1f1b_micropairs = self.num_microbatches - num_warmup_microsteps - - # Input, output tensors only need to be saved when doing backward passes - input_objs = [] - output_objs = [] - moe_losses = [] - return_tensors = [] - accum_loss = ( - torch.zeros(1, device=get_current_device()) - if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True) - else None - ) - accum_moe_loss = torch.zeros(1, device=get_current_device()) - - # Used for tensor meta information communication - forward_recv_shapes = self.tensor_shape - backward_recv_shapes = None - need_forward_meta = self.tensor_shape is None - - f_times = 0 - # Run warmup forward passes. - for i in range(num_warmup_microsteps): - # Receive the input from the previous stage - if not gpc.is_first_rank(ParallelMode.PIPELINE): - if forward_recv_shapes is None: - forward_recv_shapes = comm.recv_obj_meta() - input_obj = comm.recv_forward( - forward_recv_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors, - ) - else: - input_obj = None - - # Perform forward computation - output_obj, moe_loss = self._forward_step( - engine, - input_obj, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss, - accum_moe_loss=accum_moe_loss, - ) - f_times += 1 - - if not gpc.is_last_rank(ParallelMode.PIPELINE): - if isinstance(output_obj, torch.Tensor): - backward_recv_shapes = output_obj.shape - else: - backward_recv_shapes = [out_tensor.shape for out_tensor in output_obj] - - if need_forward_meta: - comm.send_obj_meta(output_obj) - need_forward_meta = False # send only once. - - # Send the output of forward computation of this pipeline stage to the next pipeline stage as input for - # forward computation - if not gpc.is_last_rank(ParallelMode.PIPELINE): - assert output_obj.dtype == self.dtype - comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors) - - input_objs.append(input_obj) - output_objs.append(output_obj) - moe_losses.append(moe_loss) - # Before running 1F1B, need to receive first forward tensor. 
- # If all microbatches are run in warmup / cooldown phase, then no need to - # receive this tensor here. - if num_1f1b_micropairs > 0: - if not gpc.is_first_rank(ParallelMode.PIPELINE): - if forward_recv_shapes is None: - forward_recv_shapes = comm.recv_obj_meta() - input_obj = comm.recv_forward( - forward_recv_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors, - ) - else: - input_obj = None - - # Run 1F1B in steady state. - for i in range(num_1f1b_micropairs): - # Perform forward computation - output_obj, moe_loss = self._forward_step( - engine, - input_obj, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss, - accum_moe_loss=accum_moe_loss, - ) - f_times += 1 - - if gpc.is_last_rank(ParallelMode.PIPELINE): - output_obj_grad = None - else: - assert output_obj.dtype == self.dtype - output_obj_grad = comm.send_forward_recv_backward( - output_obj, - backward_recv_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors, - ) - - # Add input_obj and output_obj to end of list. - input_objs.append(input_obj) - output_objs.append(output_obj) - moe_losses.append(moe_loss) - - # Pop output_obj and output_obj from the start of the list for - # the backward pass. - input_obj = input_objs.pop(0) - output_obj = output_objs.pop(0) - moe_loss = moe_losses.pop(0) - - input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad, moe_loss) - - if i == (num_1f1b_micropairs - 1): - input_obj = None - if not gpc.is_first_rank(ParallelMode.PIPELINE): - comm.send_backward( - input_obj_grad, - scatter_gather_tensors=self.scatter_gather_tensors, - ) - else: - if gpc.is_first_rank(ParallelMode.PIPELINE): - input_obj = None - else: - input_obj = comm.send_backward_recv_forward( - input_obj_grad, - forward_recv_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors, - ) - - WeightGradStore.flush() - if i >= gpc.get_local_rank(ParallelMode.PIPELINE): - WeightGradStore.pop() - - # Run cooldown backward passes. - for i in range(num_warmup_microsteps): - input_obj = input_objs.pop(0) - output_obj = output_objs.pop(0) - moe_loss = moe_losses.pop(0) - - if not gpc.is_last_rank(ParallelMode.PIPELINE): - output_obj_grad = comm.recv_backward( - backward_recv_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors, - ) - else: - output_obj_grad = None - - input_obj_grad = self._backward_step( - engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad, moe_loss - ) - - if not gpc.is_first_rank(ParallelMode.PIPELINE): - comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors) - - WeightGradStore.flush() - WeightGradStore.pop() - - while WeightGradStore.size() > 0: - WeightGradStore.pop() - - output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None) - - if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: - dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE)) - - if accum_loss is not None: - accum_loss += accum_moe_loss - - return output, label, accum_loss, accum_moe_loss - - class InterleavedPipelineScheduler(PipelineScheduler): """ Interleaved Pipeline Scheduler. 
@@ -1119,12 +817,13 @@ def load_micro_batch(self, model_chunk_id): ) if self.data_process_func: micro_batch_data, micro_batch_label = self.data_process_func(micro_batch_data, micro_batch_label) - micro_batch_data["label"] = micro_batch_label self.microbatch_offset[model_chunk_id] += self.bsz_stride - return move_to_device(micro_batch_data) - def _forward_step(self, engine, chunk_id): + result = move_to_device(micro_batch_data) + return result + + def _forward_step(self, engine, chunk_id, input_obj=None): """Forward step for passed-in model. If it is the first stage, the input tensor is obtained from data_iterator, otherwise the passed-in input_obj is used. Returns output tensor. This is a helper function and can be ignored by users. @@ -1140,8 +839,12 @@ def _forward_step(self, engine, chunk_id): if gpc.is_pipeline_first_stage() and len(self._input_objs[chunk_id]) == len(self._output_objs[chunk_id]): self._input_objs[chunk_id].append(None) - input_obj = self._input_objs[chunk_id][-1] + if input_obj is None: + input_obj = self._input_objs[chunk_id][-1] + + if not gpc.is_pipeline_first_stage(): + assert input_obj is not None, f"{gpc.get_global_rank()} input is None" micro_batch_data = self.load_micro_batch(chunk_id) data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data) @@ -1185,6 +888,8 @@ def _forward_step(self, engine, chunk_id): self._output_objs[chunk_id].append(output_obj) self._moe_losses[chunk_id].append(moe_loss) + assert output_obj is not None, f"{gpc.get_global_rank()} chunk{chunk_id} output is None" + return output_obj def _backward_step(self, engine, chunk_id, step_id): @@ -1397,7 +1102,7 @@ def _run_1f1b_loop_with_overlap( # 2. Check if the backward input is ready. if backward_async_communicator is not None: - output_obj_grad = backward_async_communicator.wait_and_receive() + _, output_obj_grad = backward_async_communicator.wait_and_receive() if backward_async_communicator.need_receive: self._output_obj_grads[backward_chunk_id].append(output_obj_grad) @@ -1419,11 +1124,10 @@ def _run_1f1b_loop_with_overlap( assert output_obj is None or output_obj.dtype == self.dtype forward_async_communicator = comm.AsynCommunicator( - output_obj, - input_obj_shape, - self.dtype, - self.scatter_gather_tensors, - forward=True, + object_send_next=output_obj, + recv_prev_shape=input_obj_shape, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, ) forward_async_communicator.start() @@ -1431,7 +1135,7 @@ def _run_1f1b_loop_with_overlap( input_obj_grad = self._backward_step(engine, backward_chunk_id, backward_microstep_id) - input_obj = forward_async_communicator.wait_and_receive() + input_obj, _ = forward_async_communicator.wait_and_receive() if forward_async_communicator.need_receive: self._input_objs[next_forward_chunk_id].append(input_obj) @@ -1448,11 +1152,10 @@ def _run_1f1b_loop_with_overlap( output_obj_shape = self._output_obj_shapes[next_backward_chunk_id] backward_async_communicator = comm.AsynCommunicator( - input_obj_grad, - output_obj_shape, - self.dtype, - self.scatter_gather_tensors, - forward=False, + object_send_prev=input_obj_grad, + recv_next_shape=output_obj_shape, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, ) backward_async_communicator.start() @@ -1468,7 +1171,7 @@ def _run_1f1b_loop_with_overlap( else: self._output_obj_grads[self._num_chunks - 1].append(None) else: - output_obj_grad = backward_async_communicator.wait_and_receive() + _, output_obj_grad = backward_async_communicator.wait_and_receive() 
if backward_async_communicator.need_receive: backward_chunk_id = self._get_chunk_by_microbatch(num_1f1b_micropairs, backward=True) self._output_obj_grads[backward_chunk_id].append(output_obj_grad) diff --git a/internlm/core/scheduler/pipeline_scheduler_zb.py b/internlm/core/scheduler/pipeline_scheduler_zb.py new file mode 100644 index 000000000..75cf18448 --- /dev/null +++ b/internlm/core/scheduler/pipeline_scheduler_zb.py @@ -0,0 +1,1062 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import queue +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +from torch.optim.optimizer import Optimizer + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.core.engine import Engine +from internlm.core.scheduler import comm +from internlm.utils.common import SchedulerHook, get_current_device +from internlm.utils.logger import get_logger +from internlm.utils.parallel import is_using_isp + +from .pipeline_scheduler_1f1b import ( + InterleavedPipelineScheduler, + PipelineScheduler, + pack_return_tensors, +) + +logger = get_logger(__file__) + + +class WeightGradStore: + """ + When using zero bubble pp, WeightGradStore is used to store the args and func for computating weight grad. + """ + + _cache = [] + _weight_grad_queue = queue.Queue() + _hooks = {} + pp_mode = None + optim = None + temp = [] + + @classmethod + def set_pp_mode(cls, mode): + cls.pp_mode = mode + + @classmethod + def set_optim(cls, optim): + cls.optim = optim + + @classmethod + def size(cls): + return cls._weight_grad_queue.qsize() + + @classmethod + def put(cls, weight, bias, input_tensor, grad_output, has_d_bias, grad_compute_func, *args): + if cls.pp_mode == "ZBH1": + assert not gpc.is_first_rank(ParallelMode.PIPELINE), "pp rank 0 should not arrive here" + # Store the weight gradient computation of linear layers. + cls._cache.append((weight, bias, input_tensor, grad_output, has_d_bias, grad_compute_func, *args)) + + @classmethod + def flush(cls): + if cls.pp_mode == "ZBH1" and gpc.is_first_rank(ParallelMode.PIPELINE): + return + # Collect all stored computations during backward as a W for each micro batch. + cls._weight_grad_queue.put(cls._cache) + cls._cache = [] + + @classmethod + def pop(cls): + if cls.pp_mode == "ZBH1" and gpc.is_first_rank(ParallelMode.PIPELINE): + return + assert cls._weight_grad_queue.qsize() > 0 + stored_w_grad_computation = cls._weight_grad_queue.get() + # Run computation for a single W. 
+ for weight, bias, input_tensor, grad_output, has_d_bias, grad_compute_func, *args in stored_w_grad_computation: + assert weight.requires_grad + grad_weight, grad_bias = grad_compute_func(input_tensor, grad_output, has_d_bias) + + if is_using_isp(): + isp_grad_hook = args[0] + module = args[1] + grad_weight, handle_weight = isp_grad_hook(grad_weight, async_op=True, is_bias=False, module=module) + handle_weight.wait() + if grad_bias is not None: + grad_bias, handle_bias = isp_grad_hook(grad_bias, async_op=True, is_bias=True, module=module) + handle_bias.wait() + + # Gradient Accumulation + weight.grad = weight.grad.data + grad_weight if weight.grad is not None else grad_weight + if has_d_bias: + bias.grad = bias.grad.data + grad_bias if bias.grad is not None else grad_bias + + # overlap hook + if weight in cls._hooks: + for hook in cls._hooks[weight]: + hook() + if has_d_bias: + for hook in cls._hooks[bias]: + hook() + + @classmethod + def register_hook(cls, param, hooks): + cls._hooks[param] = hooks + + +class ZeroBubblePipelineScheduler(PipelineScheduler): + """ + A helper schedule class for pipeline parallelism running environment. + It uses non-interleaved 1F1B strategy. Other properties are similar as + :class:`NonPipelineSchedule`. + + Args: + num_microbatches (int): The number of microbatches. + dtype (torch.dtype): Type of data. torch.float by default. + data_process_func (Callable, optional): + The post processing function which receives a micro batch of data, and it will be executed + in `load_micro_batch`. + tensor_shape (torch.Size, optional): Specified shape in pipeline communication. + scatter_gather_tensors (bool, optional): + If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. + scheduler_hooks (Optional[List[SchedulerHook]], optional): List of scheduler hooks. + """ + + def __init__( + self, + num_microbatches: int, + dtype: torch.dtype = torch.float, + data_process_func: Callable = None, + tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None, + scatter_gather_tensors: bool = False, + scheduler_hooks: Optional[List[SchedulerHook]] = None, + optimizer: Optimizer = None, + ): + super().__init__( + num_microbatches, + dtype=dtype, + data_process_func=data_process_func, + tensor_shape=tensor_shape, + scatter_gather_tensors=scatter_gather_tensors, + scheduler_hooks=scheduler_hooks, + ) + WeightGradStore.set_pp_mode("ZBH1") + WeightGradStore.set_optim(optimizer) + + def _forward_backward_step(self, engine, return_loss=True, return_output_label=True): + """ + This function schedules the forward and backward computation of microbatches in the pipeline in a 1F1B manner. + It consists of three stages: warmup, 1F1B, and cooldown. + + 1. Warmup Stage: + The warmup stage performs num_warmup forward microsteps. The calculation of num_warmup is the pipeline length + minus the rank of the current pipeline minus 1. For each microstep, it receives data as input from the previous + stage, performs the forward computation, and then sends the result to the next stage. + + 2. 1F1B Stage: + The 1F1B stage consists of pairs of forward and backward microsteps. It performs num_1f1b_micropairs iterations, + where num_1f1b_micropairs is calculated as the total number of microbatches minus the number of microbatches in + the warmup stage. 
In each iteration, it first performs a forward computation, sends the result to the next + stage, receives input for the backward computation, performs the backward computation, and finally sends the + result to the previous stage to receive input for the next forward computation. + + 3. Cooldown Stage: + The cooldown stage performs the same number of iterations as the warmup stage. In each iteration, it receives + input for the backward computation, performs the backward computation, and finally sends the result to the + previous stage. + + There are two special cases to consider: + 1. The first stage of the pipeline does not need to receive forward input or send backward output. The last + stage does not need to send forward output or receive backward input. + 2. Pay attention to the communication between stages and use additional communication to bridge the gap. + + Args: + engine (Engine): The engine used for computation. + return_loss (bool, optional): Whether to return the accumulated loss. + return_output_label (bool, optional): Whether to return outputs and labels. + + Returns: + Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None], Union[torch.Tensor, None]]: + The output, label, and accumulated loss. + """ + + num_warmup_microsteps = ( + gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1 + ) + num_warmup_microsteps = min(num_warmup_microsteps, self.num_microbatches) + num_1f1b_micropairs = self.num_microbatches - num_warmup_microsteps + + # Input, output tensors only need to be saved when doing backward passes + input_objs = [] + output_objs = [] + moe_losses = [] + return_tensors = [] + accum_loss = ( + torch.zeros(1, device=get_current_device()) + if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True) + else None + ) + accum_moe_loss = torch.zeros(1, device=get_current_device()) + + # Used for tensor meta information communication + forward_recv_shapes = self.tensor_shape + backward_recv_shapes = None + need_forward_meta = self.tensor_shape is None + + f_times = 0 + # Run warmup forward passes. + for i in range(num_warmup_microsteps): + # Receive the input from the previous stage + if not gpc.is_first_rank(ParallelMode.PIPELINE): + if forward_recv_shapes is None: + forward_recv_shapes = comm.recv_obj_meta() + input_obj = comm.recv_forward( + forward_recv_shapes, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + else: + input_obj = None + + # Perform forward computation + output_obj, moe_loss = self._forward_step( + engine, + input_obj, + return_tensors, + return_output_label=return_output_label, + accum_loss=accum_loss, + accum_moe_loss=accum_moe_loss, + ) + f_times += 1 + + if not gpc.is_last_rank(ParallelMode.PIPELINE): + if isinstance(output_obj, torch.Tensor): + backward_recv_shapes = output_obj.shape + else: + backward_recv_shapes = [out_tensor.shape for out_tensor in output_obj] + + if need_forward_meta: + comm.send_obj_meta(output_obj) + need_forward_meta = False # send only once. + + # Send the output of forward computation of this pipeline stage to the next pipeline stage as input for + # forward computation + if not gpc.is_last_rank(ParallelMode.PIPELINE): + assert output_obj.dtype == self.dtype + comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors) + + input_objs.append(input_obj) + output_objs.append(output_obj) + moe_losses.append(moe_loss) + # Before running 1F1B, need to receive first forward tensor. 
+ # If all microbatches are run in warmup / cooldown phase, then no need to + # receive this tensor here. + if num_1f1b_micropairs > 0: + if not gpc.is_first_rank(ParallelMode.PIPELINE): + if forward_recv_shapes is None: + forward_recv_shapes = comm.recv_obj_meta() + input_obj = comm.recv_forward( + forward_recv_shapes, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + else: + input_obj = None + + # Run 1F1B in steady state. + for i in range(num_1f1b_micropairs): + # Perform forward computation + output_obj, moe_loss = self._forward_step( + engine, + input_obj, + return_tensors, + return_output_label=return_output_label, + accum_loss=accum_loss, + accum_moe_loss=accum_moe_loss, + ) + f_times += 1 + + if gpc.is_last_rank(ParallelMode.PIPELINE): + output_obj_grad = None + else: + assert output_obj.dtype == self.dtype + output_obj_grad = comm.send_forward_recv_backward( + output_obj, + backward_recv_shapes, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + + # Add input_obj and output_obj to end of list. + input_objs.append(input_obj) + output_objs.append(output_obj) + moe_losses.append(moe_loss) + + # Pop output_obj and output_obj from the start of the list for + # the backward pass. + input_obj = input_objs.pop(0) + output_obj = output_objs.pop(0) + moe_loss = moe_losses.pop(0) + + input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad, moe_loss) + + if i == (num_1f1b_micropairs - 1): + input_obj = None + if not gpc.is_first_rank(ParallelMode.PIPELINE): + comm.send_backward( + input_obj_grad, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + else: + if gpc.is_first_rank(ParallelMode.PIPELINE): + input_obj = None + else: + input_obj = comm.send_backward_recv_forward( + input_obj_grad, + forward_recv_shapes, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + + WeightGradStore.flush() + if i >= gpc.get_local_rank(ParallelMode.PIPELINE): + WeightGradStore.pop() + + # Run cooldown backward passes. + for i in range(num_warmup_microsteps): + input_obj = input_objs.pop(0) + output_obj = output_objs.pop(0) + moe_loss = moe_losses.pop(0) + + if not gpc.is_last_rank(ParallelMode.PIPELINE): + output_obj_grad = comm.recv_backward( + backward_recv_shapes, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + else: + output_obj_grad = None + + input_obj_grad = self._backward_step( + engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad, moe_loss + ) + + if not gpc.is_first_rank(ParallelMode.PIPELINE): + comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors) + + WeightGradStore.flush() + WeightGradStore.pop() + + while WeightGradStore.size() > 0: + WeightGradStore.pop() + + output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None) + + if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1: + dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE)) + + if accum_loss is not None: + accum_loss += accum_moe_loss + + return output, label, accum_loss, accum_moe_loss + + +class ZeroBubblePipelineVShapeScheduler(InterleavedPipelineScheduler): + """ + ZB-V Scheduler. + + Args: + num_microbatches (int): The number of microbatches. + num_chunks (int): The number of model chunks. + dtype (torch.dtype, optional): The data type of the tensors. Default is torch.float. 
+ data_process_func (Callable, optional): + The preprocessing function which receives a batch of data, and it will be executed in `load_batch`. + tensor_shape (torch.Size, optional): Specified shape in pipeline communication. + scatter_gather_tensors (bool, optional): + If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. + scheduler_hooks (List[SchedulerHook], optional): List of scheduler hooks. Default is None. + optimizer (Optimizer): The optimizer to do param update. + """ + + def __init__( + self, + num_microbatches: int, + num_chunks: int, + dtype: torch.dtype = torch.float, + data_process_func: Callable = None, + tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None, + scatter_gather_tensors: bool = False, + scheduler_hooks: Optional[List[SchedulerHook]] = None, + optimizer: Optimizer = None, + ): + """A helper schedule class for pipeline parallelism running environment. + It uses ZB-V strategy. Other properties are similar as + :class:`NonPipelineSchedule`. + + Args: + num_microbatches (int): The number of microbatches. + num_chunks (int): The number of model chunks. + dtype (torch.dtype, optional): The data type of the tensors. Default is torch.float. + data_process_func (Callable, optional): + The preprocessing function which receives a batch of data, and it will be executed in `load_batch`. + tensor_shape (torch.Size, optional): Specified shape in pipeline communication. + scatter_gather_tensors (bool, optional): + If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. + scheduler_hooks (List[SchedulerHook], optional): List of scheduler hooks. Default is None. + """ + + assert ( + isinstance(num_chunks, int) and num_chunks == 2 + ), f"expect num_chunks to be an integer and equal to 2 for ZBV, but got {num_chunks}." + + assert num_microbatches >= 2 * gpc.get_world_size( + ParallelMode.PIPELINE + ), "For ZBV, num_microbatches must be greater than or equal to twice pp size." + + assert gpc.v_shape + + super().__init__( + num_microbatches, + num_chunks=num_chunks, + dtype=dtype, + data_process_func=data_process_func, + tensor_shape=tensor_shape, + scatter_gather_tensors=scatter_gather_tensors, + scheduler_hooks=scheduler_hooks, + ) + + del self._run_1f1b_loop + + WeightGradStore.set_pp_mode("ZBV") + WeightGradStore.set_optim(optimizer) + + self._special_chunk0_forward = True + self._chunk1_need_recv_prev_chunk1_grad = True + self._backward_step_num = [0, 0] + self._num_microbatches = num_microbatches + + def _clear_state(self) -> None: + super()._clear_state() + self._special_chunk0_forward = True + self._chunk1_need_recv_prev_chunk1_grad = True + self._backward_step_num = [0, 0] + + def _backward_step(self, engine, input_obj, output_obj, output_obj_grad, skip_grad_sync=True, moe_loss=None): + """ + Backward step through the passed-in output tensor. If it is the last stage, the + output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor. + Returns the gradients with respect to the input tensor (None if first stage). + This is a helper function and can be ignored by users. + + Args: + engine (colossalai.engine.Engine): Colossalai engine for training and inference. + input_obj (Union[torch.Tensor, List[torch.Tensor]]): Input tensor for this stage. + output_obj (Union[torch.Tensor, List[torch.Tensor]]): Output tensor for this stage. + output_obj_grad (Union[torch.Tensor, List[torch.Tensor]]): Gradient of output tensor for this stage. 
+ skip_grad_sync (bool): Whether skip grad sync or not. + + Returns: + Union[torch.Tensor, List[torch.Tensor]]: Gradient of input tensor. + """ + + # Retain the grad on the input_obj. + if input_obj is not None: + assert input_obj.requires_grad + if isinstance(input_obj, torch.Tensor): + input_obj.retain_grad() + else: + for in_tensor in input_obj: + if in_tensor is not None: + in_tensor.retain_grad() + + # Only the last microbatch does syncing grad. + engine.optimizer.skip_grad_reduce = skip_grad_sync + self._call_hooks("before_backward", output_obj, output_obj_grad) + # with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip_grad_sync): + if moe_loss is None or moe_loss.item() == 0.0: + if output_obj_grad is None: + engine.backward(output_obj) + else: + engine.backward_by_grad(output_obj, output_obj_grad) + else: + if output_obj_grad is None: + engine.backward(output_obj + moe_loss) + else: + # scale the latent loss + moe_loss = moe_loss * engine.optimizer.loss_scale + # we perform chain rule here by projecting the grad to the direction of + # [output_obj_grad, 1], Because moe_loss have no relation with subsequent + # layer, we set it to None (will be ragarded as 1). + engine.backward_by_grad([output_obj, moe_loss], [output_obj_grad, None]) + + # Collect the grad of the input_obj. + input_obj_grad = None + if input_obj is not None: + assert input_obj.grad is not None + if isinstance(input_obj, torch.Tensor): + input_obj_grad = input_obj.grad + else: + input_obj_grad = [] + for in_tensor in input_obj: + input_obj_grad.append(in_tensor.grad) + + return input_obj_grad + + def _schedule_backward(self, engine, chunk_id): + """ + Backward step for passed-in model. If it is the last stage, the input tensor + is obtained from the previous forward step, otherwise the passed-in input_obj is used. + Returns input tensor gradient. This is a helper function and can be ignored by users. + + Args: + engine (colossalai.engine.Engine): Colossalai engine for training and inference. + chunk_id (int): The id of model chunks. + step_id (int): The current step id. + + Returns: + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: input tensor gradient. 
+ """ + gpc.set_virtual_pipeline_parallel_rank(chunk_id) + + self._backward_step_num[chunk_id] += 1 + if self._backward_step_num[chunk_id] == self._num_microbatches: + skip_grad_sync = False + else: + skip_grad_sync = True + + if gpc.is_pipeline_last_stage() and len(self._output_obj_grads[chunk_id]) == 0: + self._output_obj_grads[chunk_id].append(None) + + input_obj = self._input_objs[chunk_id].pop(0) + output_obj = self._output_objs[chunk_id].pop(0) + output_obj_grad = self._output_obj_grads[chunk_id].pop(0) + moe_loss = self._moe_losses[chunk_id].pop(0) + + if not gpc.is_pipeline_last_stage(): + assert output_obj_grad is not None + if not gpc.is_pipeline_first_stage(): + assert input_obj is not None + + input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad, skip_grad_sync, moe_loss) + + WeightGradStore.flush() + + return input_obj_grad + + def _schedule_1f1b_F(self, engine, chunk_id): + output_obj = self._forward_step(engine, chunk_id) + + object_send_next = None + object_send_prev = None + recv_next_shape = None + recv_prev_shape = None + + if chunk_id == 1: + if not gpc.is_first_rank(ParallelMode.PIPELINE): + object_send_prev = output_obj + if self._chunk1_need_recv_prev_chunk1_grad: + recv_prev_shape = self._output_obj_shapes[chunk_id] + else: + self._chunk1_need_recv_prev_chunk1_grad = False + if gpc.is_last_rank(ParallelMode.PIPELINE): + # For last rank, chunk0 output does not need to be sent but is directly used for chunk1; + input_obj = output_obj.clone().detach() + input_obj.requires_grad_() + self._input_objs[1].append(input_obj) + else: + object_send_next = output_obj + recv_next_shape = self._output_obj_shapes[chunk_id] + + # chunk1 send output prev, recv output_grad prev + # chunk0 send output next, recv output_grad next + tensor_recv_prev, tensor_recv_next = comm.fused_send_recv_tensor( + object_send_next=object_send_next, + object_send_prev=object_send_prev, + recv_next_shape=recv_next_shape, + recv_prev_shape=recv_prev_shape, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + + if chunk_id == 1 and not self._chunk1_need_recv_prev_chunk1_grad: + assert tensor_recv_prev is None + + if tensor_recv_prev is not None: + self._output_obj_grads[1].append(tensor_recv_prev) + + if tensor_recv_next is not None: + self._output_obj_grads[0].append(tensor_recv_next) + + def _schedule_1f1b_B_W(self, engine, chunk_id, next_unit_chunk_id, need_recv_chunk0_output=True): + + # 1B + input_obj_grad = self._schedule_backward(engine, chunk_id) + + object_send_next = None + object_send_prev = None + recv_next_shape = None + recv_prev_shape = [] + chunk0_B_need_recv_prev_chunk0_output = need_recv_chunk0_output + + if chunk_id == 1: + if gpc.is_last_rank(ParallelMode.PIPELINE): + # For last rank, chunk1 input_grad does not need to be sent but is directly used for chunk0. + self._output_obj_grads[0].append(input_obj_grad) + else: + object_send_next = input_obj_grad + + if next_unit_chunk_id == 1: + if gpc.is_last_rank(ParallelMode.PIPELINE): + assert False, "The last pp rank can never have two consecutive unit1 of the same chunk." + recv_next_shape = self._input_obj_shapes[next_unit_chunk_id] + else: + assert next_unit_chunk_id != 0, "There will never be two consecutive chunk0 unit1." 
+ + if not gpc.is_first_rank(ParallelMode.PIPELINE): + object_send_prev = input_obj_grad + # pre receive chunk1 grad + recv_prev_shape.append(self._output_obj_shapes[1]) + # pre receive chunk0 input + if chunk0_B_need_recv_prev_chunk0_output: + recv_prev_shape.append(self._input_obj_shapes[0]) + + if not gpc.is_last_rank(ParallelMode.PIPELINE): + recv_next_shape = self._input_obj_shapes[next_unit_chunk_id] + + if len(recv_prev_shape) == 0: + recv_prev_shape = None + + # chunk1 send input_grad next, chunk0 send input_grad prev + # if chunk_id == 1 and next_unit_chunk_id == 1, recv chunk1 input next + # if chunk_id == 0 and next_unit_chunk_id == 1, pre-recv chunk1 grad recv; + # pre-recv chunk0 input prev and recv chunk1 input next + async_communicator = comm.AsynCommunicator( + object_send_prev=object_send_prev, + object_send_next=object_send_next, + recv_prev_shape=recv_prev_shape, + recv_next_shape=recv_next_shape, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + async_communicator.start() + + # 1W + WeightGradStore.pop() + self._call_hooks("after_backward", input_obj_grad) + + tensor_recv_prev, tensor_recv_next = async_communicator.wait_and_receive() + + # for the special case, input_obj has already been received and appended at the end of warmup. + if next_unit_chunk_id == 0 and self._special_chunk0_forward: + self._special_chunk0_forward = False + else: + if chunk_id == 0: + # For chunk0, it's necessary to pre-fetch the output_grad of the next chunk1 + # to prevent the sender from being blocked due to the absence of a receiving op. + # Except for the stage1 last chunk0 or stage2, the chunk0 BW also needs to pre-fetch + # the input of the next chunk0 unit to prevent the sender from being blocked. + + if gpc.is_first_rank(ParallelMode.PIPELINE): + # first_rank only receive chunk1 input from next rank + self._input_objs[1].append(tensor_recv_next) + elif gpc.is_last_rank(ParallelMode.PIPELINE): + # For last rank, chunk1 input does not need to be received + self._output_obj_grads[1].append(tensor_recv_prev[0]) + if chunk0_B_need_recv_prev_chunk0_output: + self._input_objs[0].append(tensor_recv_prev[1]) + else: + self._output_obj_grads[1].append(tensor_recv_prev[0]) + if chunk0_B_need_recv_prev_chunk0_output: + self._input_objs[0].append(tensor_recv_prev[1]) + self._input_objs[1].append(tensor_recv_next) + else: + if next_unit_chunk_id == 1: + self._input_objs[1].append(tensor_recv_next) + + def _1f1b_unit_1(self, engine, chunk_id, next_unit_chunk_id, need_recv_chunk0_output): + """ + unit1 consists of: 1F + 1B + 1W, all are chunk0 or chunk1 + """ + # 1F + self._schedule_1f1b_F(engine, chunk_id) + + # 1B + 1W + self._schedule_1f1b_B_W(engine, chunk_id, next_unit_chunk_id, need_recv_chunk0_output) + + def _1f1b_unit_2(self, engine, chunk_id): + """ + unit2 consists of: chunk1 (1F + 1B + 1W) + chunk0 (1B + 1W) + """ + assert not gpc.is_last_rank(ParallelMode.PIPELINE) + + # 1F (chunk1) + self._schedule_1f1b_F(engine, chunk_id) + + # 1B + 1W (chunk1) + input_obj_grad = self._schedule_backward(engine, chunk_id) + + # chunk1 send input_grad next, chunk0 recv output_grad next + async_communicator = comm.AsynCommunicator( + object_send_next=input_obj_grad, + recv_next_shape=self._output_obj_shapes[1 - chunk_id], + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + async_communicator.start() + + WeightGradStore.pop() + self._call_hooks("after_backward", input_obj_grad) + + _, output_obj_grad = async_communicator.wait_and_receive() + 
self._output_obj_grads[1 - chunk_id].append(output_obj_grad) + + # 1B + 1W (chunk0) + self._schedule_1f1b_B_W(engine, 1 - chunk_id, chunk_id, need_recv_chunk0_output=False) + + def _schedule_warmup_F(self, engine, chunk_id, input_obj=None, forward_only=False): + output_obj = self._forward_step(engine, chunk_id, input_obj) + + if forward_only: + # when forward-only, no need to save tensors for a backward pass + self._input_objs[chunk_id].pop() + self._output_objs[chunk_id].pop() + self._moe_losses[chunk_id].pop() + + if not gpc.is_pipeline_last_stage(): + if isinstance(output_obj, torch.Tensor): + self._output_obj_shapes[chunk_id] = output_obj.shape + else: + self._output_obj_shapes[chunk_id] = [out_tensor.shape for out_tensor in output_obj] + + assert self._output_obj_shapes[chunk_id] == self._input_obj_shapes[chunk_id] + + if self._send_tensor_shape_flags[chunk_id]: + comm.send_obj_meta(output_obj) + self._send_tensor_shape_flags[chunk_id] = False # send only once for each chunk. + + if not gpc.is_pipeline_first_stage() and self._input_obj_shapes[chunk_id] is None: + self._input_obj_shapes[chunk_id] = comm.recv_obj_meta() + + return output_obj + + def _run_warmup_loop( + self, + engine: Engine, + num_warmup_microsteps: int, + forward_only: bool = False, + ) -> None: + """ + Run the warm-up loop and prepare data for the steady stage. + + Args: + engine (Engine): The engine to run the warm-up loop. + num_warmup_microsteps (int): The number of warm-up microsteps. + forward_only (bool, optional): Whether to only perform forward pass. Default is False. + """ + + # For each rank, the warmup stage will be divided into two sub-phases for scheduling. + num_warmup_microsteps_phase_1 = min(self.num_microbatches, (self._pp_size - self._pp_rank) * 2 - 1) + num_warmup_microsteps_phase_2 = num_warmup_microsteps - num_warmup_microsteps_phase_1 + + if gpc.is_first_rank(ParallelMode.PIPELINE): + assert num_warmup_microsteps_phase_2 == 0 + if gpc.is_last_rank(ParallelMode.PIPELINE): + assert num_warmup_microsteps_phase_1 == 1 + + # get first forward input + chunk_id = 0 + if not gpc.is_pipeline_first_stage(): + if self._input_obj_shapes[chunk_id] is None: + self._input_obj_shapes[chunk_id] = comm.recv_obj_meta() + self._input_objs[chunk_id].append( + comm.recv_forward( + self._input_obj_shapes[chunk_id], + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + ) + else: + self._input_objs[chunk_id].append(None) + + # Phase1 will only do chunk0 forward + for micro_step in range(num_warmup_microsteps_phase_1): + # forward + output_obj = self._schedule_warmup_F(engine, chunk_id, forward_only=forward_only) + + object_send_next = None + recv_prev_shape = None + recv_next_shape = None + + # For stage1, the last chunk0 unit needs to do recv op to prevent the sender from being blocked. + if not gpc.is_first_rank(ParallelMode.PIPELINE): + recv_prev_shape = self._input_obj_shapes[0] + + # For last rank, chunk0 output does not need to be sent but is directly used for chunk1. 
+ if not gpc.is_last_rank(ParallelMode.PIPELINE): + object_send_next = output_obj + else: + input_obj = output_obj.clone().detach() + input_obj.requires_grad_() + self._input_objs[1].append(input_obj) + + if micro_step == num_warmup_microsteps_phase_1 - 1: + if not gpc.is_last_rank(ParallelMode.PIPELINE): + recv_next_shape = self._input_obj_shapes[1] + + tensor_recv_prev, tensor_recv_next = comm.fused_send_recv_tensor( + object_send_next=object_send_next, + recv_prev_shape=recv_prev_shape, + recv_next_shape=recv_next_shape, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + + self._input_objs[0].append(tensor_recv_prev) + + if micro_step == num_warmup_microsteps_phase_1 - 1: + if not gpc.is_last_rank(ParallelMode.PIPELINE): + self._input_objs[1].append(tensor_recv_next) + + # Phase2 will execute chunk1 and chunk0 forward alternately + for micro_step in range(num_warmup_microsteps_phase_2): + chunk_id = 1 - chunk_id + next_chunk_id = 1 - chunk_id + + if chunk_id == 0: + input_obj = self._input_objs[chunk_id][-2] + else: + input_obj = self._input_objs[chunk_id][-1] + + output_obj = self._schedule_warmup_F(engine, chunk_id, input_obj=input_obj, forward_only=forward_only) + + object_send_next = None + object_send_prev = None + recv_next_shape = None + recv_prev_shape = None + + if chunk_id == 1: + assert micro_step < num_warmup_microsteps_phase_2 - 1 + object_send_prev = output_obj + recv_prev_shape = self._input_obj_shapes[next_chunk_id] + else: + if not gpc.is_last_rank(ParallelMode.PIPELINE): + object_send_next = output_obj + recv_next_shape = self._input_obj_shapes[next_chunk_id] + + # chunk1 send output prev, chunk0 recv input prev + # chunk0 send output next, chunk1 recv input next + tensor_recv_prev, tensor_recv_next = comm.fused_send_recv_tensor( + object_send_next=object_send_next, + object_send_prev=object_send_prev, + recv_next_shape=recv_next_shape, + recv_prev_shape=recv_prev_shape, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + + # For last rank, chunk0 output does not need to be sent but is directly used for chunk1 + if chunk_id == 0 and gpc.is_last_rank(ParallelMode.PIPELINE): + input_obj = output_obj.clone().detach() + input_obj.requires_grad_() + else: + input_obj = tensor_recv_prev if tensor_recv_prev is not None else tensor_recv_next + + self._input_objs[next_chunk_id].append(input_obj) + + def _run_steady_loop( + self, + engine: Engine, + num_1f1b_units: int, + ) -> None: + """ + 1F1B unit schedule: + stage1: (pp_size + 1 + pp_rank + 2 * (micro_num - 2 * pp_size)) * unit1 + stage2: (pp_size - 1 - pp_rank) * unit2 + stage3: 1 * special chunk1 unit1 + + Args: + engine (Engine): The engine to use for computation. + num_1f1b_units (int): The number of 1F1B units. 
+ """ + # unit schedule + num_units_stage1 = 2 * self.num_microbatches - 3 * self._pp_size + 1 + self._pp_rank + num_units_stage2 = self._pp_size - 1 - self._pp_rank + assert num_units_stage1 + num_units_stage2 + 1 == num_1f1b_units + + # chunk schedule: stage1 + stage2 + stage1 + # stage1: chunk1 + # stage2: chunk0 and chunk1 alternately + stage1_length = self._pp_size - self._pp_rank + stage2_length = 2 * self._pp_rank + 1 + 2 * (self.num_microbatches - 2 * self._pp_size) + stage2_list = list(range(stage1_length, stage1_length + stage2_length)) + chunk0_units = [stage2_list[i] for i in range(len(stage2_list)) if i % 2 == 0] + + # unit stage1 + for unit_step in range(num_units_stage1): + if unit_step in chunk0_units: + chunk_id = 0 + else: + chunk_id = 1 + + if unit_step + 1 in chunk0_units: + next_unit_chunk_id = 0 + else: + next_unit_chunk_id = 1 + + # import pdb; pdb.set_trace() + if unit_step == num_units_stage1 - 1: + chunk0_B_need_recv_prev_chunk0_output = False + else: + chunk0_B_need_recv_prev_chunk0_output = True + + self._1f1b_unit_1( + engine, chunk_id, next_unit_chunk_id, need_recv_chunk0_output=chunk0_B_need_recv_prev_chunk0_output + ) + + # unit stage2 + for unit_step in range(num_units_stage2): + assert unit_step + num_units_stage1 not in chunk0_units + self._1f1b_unit_2(engine, 1) + + # unit stage3 + assert num_1f1b_units - 1 not in chunk0_units + self._schedule_1f1b_F(engine, 1) + origin_skip = engine.optimizer.skip_grad_reduce + input_obj_grad = self._schedule_backward(engine, 1) + if gpc.is_last_rank(ParallelMode.PIPELINE): + # For last rank, chunk1 input_grad does not need to be sent but is directly used for chunk0. + self._output_obj_grads[0].append(input_obj_grad) + tensor_to_send = None + recv_shape = None + else: + tensor_to_send = input_obj_grad + recv_shape = self._output_obj_shapes[0] + + # chunk1 send input_grad next, chunk0 recv output_grad next + async_communicator = comm.AsynCommunicator( + object_send_next=tensor_to_send, + recv_next_shape=recv_shape, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + async_communicator.start() + + WeightGradStore.pop() + self._call_hooks("after_backward", input_obj_grad) + engine.optimizer.skip_grad_reduce = origin_skip + + _, output_obj_grad = async_communicator.wait_and_receive() + if not gpc.is_last_rank(ParallelMode.PIPELINE): + self._output_obj_grads[0].append(output_obj_grad) + + def _run_cooldown_loop(self, engine): + """ + Cooldown unit schedule: + Unit: 1B + 1W + Schedule unit chunk0 and unit chunk1 alternatively + Each pp rank has pp_size chunk0, but only pp_rank chunk1 + """ + chunk0_length = self._pp_size + chunk1_length = self._pp_rank + num_cooldown_units = chunk0_length + chunk1_length + total_list = list(range(chunk1_length * 2)) + chunk1_units = [total_list[i] for i in range(chunk1_length * 2) if i % 2 != 0] + + cool_down = [0, 0] + + for unit_step in range(num_cooldown_units): + if unit_step in chunk1_units: + chunk_id = 1 + else: + chunk_id = 0 + + cool_down[chunk_id] += 1 + + if unit_step + 1 in chunk1_units: + next_unit_chunk_id = 1 + else: + next_unit_chunk_id = 0 + + origin_skip = engine.optimizer.skip_grad_reduce + input_obj_grad = self._schedule_backward(engine, chunk_id) + + object_send_next = None + object_send_prev = None + recv_next_shape = None + recv_prev_shape = None + + if chunk_id == 1: + assert not gpc.is_first_rank(ParallelMode.PIPELINE) + if gpc.is_last_rank(ParallelMode.PIPELINE): + # For last rank, chunk1 input_grad does not need to be sent but is directly 
used for chunk0. + self._output_obj_grads[0].append(input_obj_grad) + else: + object_send_next = input_obj_grad + # next unit should be chunk0 + recv_next_shape = self._output_obj_shapes[0] + else: + if not gpc.is_first_rank(ParallelMode.PIPELINE): + object_send_prev = input_obj_grad + + if unit_step != num_cooldown_units - 1: + if next_unit_chunk_id == 1: + assert not gpc.is_first_rank(ParallelMode.PIPELINE) + recv_prev_shape = self._output_obj_shapes[next_unit_chunk_id] + else: + assert not gpc.is_last_rank(ParallelMode.PIPELINE) + recv_next_shape = self._output_obj_shapes[next_unit_chunk_id] + + # chunk1 send input_grad next, chunk0 send input_grad prev + # if next_unit_chunk_id == 1, recv output_grad prev + # if next_unit_chunk_id == 0, recv output_grad next + async_communicator = comm.AsynCommunicator( + object_send_prev=object_send_prev, + object_send_next=object_send_next, + recv_prev_shape=recv_prev_shape, + recv_next_shape=recv_next_shape, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + async_communicator.start() + + # 1W + WeightGradStore.pop() + self._call_hooks("after_backward", input_obj_grad) + engine.optimizer.skip_grad_reduce = origin_skip + + tensor_recv_prev, tensor_recv_next = async_communicator.wait_and_receive() + output_obj_grad = tensor_recv_prev if tensor_recv_prev is not None else tensor_recv_next + + if output_obj_grad is not None: + self._output_obj_grads[next_unit_chunk_id].append(output_obj_grad) + + def _forward_only_step(self, engine: Engine): + num_warmup_steps = self.num_microbatches * self._num_chunks + + self._run_warmup_loop( + engine, + num_warmup_steps, + forward_only=True, + ) + + def _forward_backward_step(self, engine: Engine): + assert self.num_microbatches > self._pp_size + + # Compute number of warmup microbatches. + num_warmup_steps = self._pp_size * 2 - 1 + + # Compute number of 1F1B unit. + num_1f1b_units = 2 * self.num_microbatches - num_warmup_steps + + # 1. Warmup + self._run_warmup_loop( + engine, + num_warmup_steps, + ) + + # 2. 1F1B + self._run_steady_loop( + engine, + num_1f1b_units, + ) + + # 3. 
cooldown + self._run_cooldown_loop(engine) diff --git a/internlm/core/trainer.py b/internlm/core/trainer.py index b76601408..3b01d3afd 100644 --- a/internlm/core/trainer.py +++ b/internlm/core/trainer.py @@ -9,12 +9,7 @@ from typing import Iterable, Optional from internlm.core.engine import Engine -from internlm.core.scheduler import ( - BaseScheduler, - InterleavedPipelineScheduler, - NonPipelineScheduler, - PipelineScheduler, -) +from internlm.core.scheduler import BaseScheduler, NonPipelineScheduler class TrainState: @@ -181,7 +176,7 @@ def schedule(self): @property def uses_pipeline(self): """Returns whether the pipeline parallel is used or not.""" - return isinstance(self._schedule, (PipelineScheduler, InterleavedPipelineScheduler)) + return not isinstance(self._schedule, (NonPipelineScheduler, BaseScheduler)) def train(self): """Sets the model to training mode.""" diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index b3e48f9d2..d0ef284d4 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -27,6 +27,7 @@ inject_model, load_new_batch, record_current_batch_training_metrics, + set_param_unique_tracking_name, ) from internlm.utils.common import ( BatchSkipper, @@ -99,6 +100,9 @@ def __init__( # load config_lines config_lines = self._read_config(kwargs["config"]) + # set tracking name for parameters + set_param_unique_tracking_name(model) + # inject model for amp and parallel training model = inject_model(model) diff --git a/internlm/eval/evaluation.py b/internlm/eval/evaluation.py index 50d17c01a..862057a3d 100644 --- a/internlm/eval/evaluation.py +++ b/internlm/eval/evaluation.py @@ -8,7 +8,7 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.parallel.shard import split_data_for_sequence_parallel -from internlm.core.scheduler.pipeline_scheduler import get_tensor_shape +from internlm.core.scheduler.pipeline_scheduler_1f1b import get_tensor_shape from internlm.model.metrics import AccPerplex, SchedulerMetricHook from internlm.utils.common import get_current_device from internlm.utils.parallel import is_using_isp diff --git a/internlm/initialize/initialize_trainer.py b/internlm/initialize/initialize_trainer.py index 45bd3f4d3..48487c5fb 100644 --- a/internlm/initialize/initialize_trainer.py +++ b/internlm/initialize/initialize_trainer.py @@ -21,8 +21,9 @@ NonPipelineScheduler, PipelineScheduler, ZeroBubblePipelineScheduler, + ZeroBubblePipelineVShapeScheduler, ) -from internlm.core.scheduler.pipeline_scheduler import get_tensor_shape +from internlm.core.scheduler.pipeline_scheduler_1f1b import get_tensor_shape from internlm.core.trainer import Trainer from internlm.data.utils import packed_data_normalizer, unpack_data from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer @@ -94,11 +95,16 @@ def _data_preparation_func(_data, _label): return _data, _label + pp_mode = getattr(gpc.config.parallel["pipeline"], "mode", "1F1B").upper() + if gpc.is_using_parallel_mode(ParallelMode.PIPELINE): gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num tensor_shape = get_tensor_shape() use_interleaved = ( - hasattr(gpc.config, "model") and hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1 + hasattr(gpc.config, "model") + and hasattr(gpc.config.model, "num_chunks") + and gpc.config.model.num_chunks > 1 + and pp_mode == "1F1B" ) scatter_gather = gpc.is_initialized(ParallelMode.TENSOR) if use_interleaved: @@ -116,7 +122,7 @@ 
def _data_preparation_func(_data, _label): scheduler_hooks=scheduler_hooks, communication_overlap=communication_overlap, ) - elif gpc.config.parallel["pipeline"].get("zero_bubble", False): + elif pp_mode == "ZBH1": scheduler = ZeroBubblePipelineScheduler( data_process_func=_data_preparation_func, num_microbatches=gpc.config.NUM_MICRO_BATCHES, @@ -124,6 +130,18 @@ def _data_preparation_func(_data, _label): tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather, scheduler_hooks=scheduler_hooks, + optimizer=optimizer, + ) + elif pp_mode == "ZBV": + scheduler = ZeroBubblePipelineVShapeScheduler( + num_microbatches=gpc.config.NUM_MICRO_BATCHES, + num_chunks=gpc.config.model.num_chunks, + dtype=gpc.config.model["dtype"], + data_process_func=_data_preparation_func, + tensor_shape=tensor_shape, + scatter_gather_tensors=scatter_gather, + scheduler_hooks=scheduler_hooks, + optimizer=optimizer, ) else: scheduler = PipelineScheduler( diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index c85568ec4..fc63b8a23 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -85,7 +85,10 @@ def args_sanity_check(): gpc.config.parallel._add_item("zero1", dict(size=zero1_size, fsdp=False)) if "pipeline" not in gpc.config.parallel: - gpc.config.parallel._add_item("pipeline", dict(size=1, interleaved_overlap=False, zero_bubble=False)) + gpc.config.parallel._add_item("pipeline", dict(size=1, interleaved_overlap=False, mode="1F1B")) + + if isinstance(gpc.config.parallel.pipeline, dict) and "mode" not in gpc.config.parallel.pipeline: + gpc.config.parallel.pipeline._add_item("mode", "1F1B") if "tensor" not in gpc.config.parallel: gpc.config.parallel._add_item("tensor", dict(size=1, mode=TensorParallelMode.mtp.name)) @@ -104,6 +107,16 @@ def args_sanity_check(): else: pp = gpc.config.parallel.pipeline.size + if isinstance(gpc.config.parallel.pipeline, dict): + gpc.config.parallel.pipeline["mode"] = gpc.config.parallel.pipeline["mode"].upper() + assert gpc.config.parallel.pipeline["mode"] in [ + "1F1B", + "ZBH1", + "ZBV", + ], f"unsupported pp mode {gpc.config.parallel.pipeline['mode']}" + if gpc.config.parallel.pipeline["mode"] == "ZBV": + gpc.v_shape = True + # check fsdp config if "fsdp" not in gpc.config.parallel.zero1: gpc.config.parallel.zero1._add_item("fsdp", False) @@ -442,6 +455,11 @@ def args_sanity_check(): gpc.config.parallel["pipeline"].get("interleaved_overlap", False) is True ), "only support interleaved pipeline scheduler with overlap" + if gpc.config.parallel["pipeline"]["mode"] == "ZBV": + gpc.config.model.num_chunks = 2 + if gpc.is_rank_for_log(): + logger.info("Using zero_bubble_v, num_chunks is set to 2.") + # monitoring default config monitor_default_config = { "alert_address": None, # compatible with old alert config @@ -480,7 +498,7 @@ def args_sanity_check(): elif optim_ckpt.use_split_tensor_optim and "all_gather_size" not in optim_ckpt: optim_ckpt._add_item("all_gather_size", 512 * 1024 * 1024) - if gpc.config.parallel["pipeline"].get("zero_bubble", False): + if gpc.config.parallel["pipeline"]["mode"] == "ZBH1": assert ( not optim_ckpt.overlap_sync_grad ), "When using zero_bubble pipeline parallelism, overlap_sync_grad must be false" diff --git a/internlm/model/modules/embedding.py b/internlm/model/modules/embedding.py index 365ee46a0..93fcd6b23 100644 --- a/internlm/model/modules/embedding.py +++ b/internlm/model/modules/embedding.py @@ -67,6 +67,8 @@ def __init__( torch.empty((self.num_embeddings_per_partition, 
self.embed_dim_per_partition), dtype=dtype) ) + setattr(self.weight, "is_embedding_param", True) + def forward(self, input_: Tensor) -> Tensor: if self.vocab_parallel and not is_using_isp(): # Build the mask. diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 2426ab8a6..64d79796a 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -98,6 +98,7 @@ def backward(ctx, grad_output, *args): (grad_input,) = args grad_input = grad_input.contiguous() + # print(f"ctx rank: {gpc.get_global_rank()}, {len(ctx.saved_tensors)}", flush=True) x, weight, bias = ctx.saved_tensors # parallel strategy-specific communication callback 1-2. @@ -133,12 +134,16 @@ def backward(ctx, grad_output, *args): handle_x.wait() x = x.reshape(batch_dim, x.shape[-1]) - if ( - gpc.is_using_parallel_mode(ParallelMode.PIPELINE) - and gpc.config.parallel["pipeline"].get("zero_bubble", False) - and not gpc.is_first_rank(ParallelMode.PIPELINE) + if gpc.is_using_parallel_mode(ParallelMode.PIPELINE) and ( + ( + gpc.config.parallel["pipeline"].get("mode", "1F1B") == "ZBH1" + and not gpc.is_first_rank(ParallelMode.PIPELINE) + ) + or gpc.config.parallel["pipeline"].get("mode", "1F1B") == "ZBV" ): - from internlm.core.scheduler.pipeline_scheduler import WeightGradStore + from internlm.core.scheduler.pipeline_scheduler_zb import ( + WeightGradStore, + ) WeightGradStore.put(weight, bias, x, grad_output, ctx.needs_input_grad[2], linear_backward_op) grad_weight, grad_bias = None, None @@ -236,10 +241,12 @@ def backward(ctx, grad_output, *args): total_weight = communicator.weight_hook(weight, module=module) - is_using_ZB = ( - gpc.is_using_parallel_mode(ParallelMode.PIPELINE) - and gpc.config.parallel["pipeline"].get("zero_bubble", False) - and not gpc.is_first_rank(ParallelMode.PIPELINE) + is_using_ZB = gpc.is_using_parallel_mode(ParallelMode.PIPELINE) and ( + ( + gpc.config.parallel["pipeline"].get("mode", "1F1B") == "ZBH1" + and not gpc.is_first_rank(ParallelMode.PIPELINE) + ) + or gpc.config.parallel["pipeline"].get("mode", "1F1B") == "ZBV" ) # compute weight grad @@ -247,10 +254,19 @@ def backward(ctx, grad_output, *args): assert ctx.compute_weight_gradient x = x.reshape(batch_dim, x.shape[-1]) if is_using_ZB: - from internlm.core.scheduler.pipeline_scheduler import WeightGradStore + from internlm.core.scheduler.pipeline_scheduler_zb import ( + WeightGradStore, + ) WeightGradStore.put( - weight, bias, x, grad_output, ctx.needs_input_grad[2], linear_backward_op, communicator.grad_hook + weight, + bias, + x, + grad_output, + ctx.needs_input_grad[2], + linear_backward_op, + communicator.grad_hook, + module, ) grad_weight, grad_bias = None, None else: diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 36f1c0531..49f3fbcf0 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -298,6 +298,9 @@ def _is_moe_group(self, param_group): # TODO check expert dp is correct when enable moe and overlap both def _attach_reduction_hook(self): + from internlm.core.scheduler.pipeline_scheduler_zb import WeightGradStore + + is_using_ZB = gpc.config.parallel["pipeline"].get("mode", "1F1B") != "1F1B" # we iterate over the fp16 params # on each param, we register a hook to its AccumulateGrad object for group_id in range(self.num_param_groups): @@ -307,6 +310,8 @@ def _attach_reduction_hook(self): if not param.requires_grad: continue + hooks = [] + reduce_rank = None def 
_define_and_attach(param, reduce_rank=None): @@ -347,6 +352,7 @@ def reduction_layernorm_func(): # NOT IMPORTANT BUT GOOD TO KNOW: # args here is not grad, but allow_unreacable and accumulate_grad def reduce_grad_hook(*args): # pylint: disable=W0613 + # assert self.skip_grad_reduce if self.skip_grad_reduce is False: reduction_func() @@ -389,18 +395,27 @@ def extra_layernorm_reduce_grad_hook(*args): # pylint: disable=W0613 and gpc.config.parallel.weight.overlap ) ): - if hasattr(param, "evo_tensor"): - param.register_post_accumulate_grad_hook(accum_grad_hook) + if is_using_ZB and not hasattr(param, "is_embedding_param"): + hooks.append(accum_grad_hook) # pylint: disable=W0640 else: - accum_grad_obj.register_hook(accum_grad_hook) + if hasattr(param, "evo_tensor"): + param.register_post_accumulate_grad_hook(accum_grad_hook) + else: + accum_grad_obj.register_hook(accum_grad_hook) if self._overlap_sync_grad: - if hasattr(param, "evo_tensor"): - param.register_post_accumulate_grad_hook(reduce_grad_hook) + if is_using_ZB and not hasattr(param, "is_embedding_param"): + hooks.append(reduce_grad_hook) # pylint: disable=W0640 else: - accum_grad_obj.register_hook(reduce_grad_hook) + if hasattr(param, "evo_tensor"): + param.register_post_accumulate_grad_hook(reduce_grad_hook) + else: + accum_grad_obj.register_hook(reduce_grad_hook) _define_and_attach(param, reduce_rank) + if len(hooks) > 0: + assert is_using_ZB + WeightGradStore.register_hook(param, hooks) def accumulate_left_grads_after_backward(self): if self._isp_communicator is None: @@ -409,6 +424,10 @@ def accumulate_left_grads_after_backward(self): for group_id in range(self.num_param_groups): self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id]) + def reduce_left_grads_after_backward(self): + for group_id in range(self.num_param_groups): + self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None) + def belongs_to_current_rank(self, param) -> bool: """ Check whether a parameter is supposed to be updated by the process of the current rank @@ -478,6 +497,7 @@ def _store_and_try_reduce_grads_by_bucket(self, param, reduce_rank=None): raise RuntimeError(msg) # the param must have grad for reduction + assert param.requires_grad assert param.grad is not None, f"Parameter of size ({param.size()}) has None grad, cannot be reduced" current_bucket.add_num_elements_in_bucket(param_size, reduce_rank) @@ -718,8 +738,7 @@ def step(self, closure=None): self._store_and_try_reduce_grads_by_bucket(param) # we need to reduce the gradients left in the communication bucket - for group_id in range(self.num_param_groups): - self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None) + self.reduce_left_grads_after_backward() if internlm_accelerator.get_accelerator_backend() in [ AcceleratorType.NPU, diff --git a/internlm/solver/optimizer/hybrid_zero_optim_v2.py b/internlm/solver/optimizer/hybrid_zero_optim_v2.py index 9fafaab99..36e5f073f 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim_v2.py +++ b/internlm/solver/optimizer/hybrid_zero_optim_v2.py @@ -851,6 +851,9 @@ def reload_zero_fp32_buff(self): ################ def _attach_reduction_hook(self): + from internlm.core.scheduler.pipeline_scheduler_zb import WeightGradStore + + is_using_ZB = gpc.config.parallel["pipeline"].get("mode", "1F1B") != "1F1B" # we iterate over the fp16 params # on each param, we register a hook to its AccumulateGrad object for group_id in range(self.num_param_groups): @@ -860,6 +863,8 @@ def _attach_reduction_hook(self): 
if not param.requires_grad: continue + hooks = [] + reduce_rank = None def _define_and_attach(param, reduce_rank=None): @@ -910,11 +915,19 @@ def extra_layernorm_reduce_grad_hook(*args): # pylint: disable=W0613 and self._isp_communicator.overlap and gpc.config.parallel.weight.size > 1 ): - param.register_post_accumulate_grad_hook(accum_grad_hook) + if is_using_ZB and not hasattr(param, "is_embedding_param"): + hooks.append(accum_grad_hook) # pylint: disable=W0640 + else: + param.register_post_accumulate_grad_hook(accum_grad_hook) if self._overlap_sync_grad: - param.register_post_accumulate_grad_hook( - partial(grad_handler, group_id) - ) # pylint: disable=W0640 + if is_using_ZB and not hasattr(param, "is_embedding_param"): + hooks.append(partial(grad_handler, group_id)) # pylint: disable=W0640 + else: + param.register_post_accumulate_grad_hook( + partial(grad_handler, group_id) + ) # pylint: disable=W0640 _define_and_attach(param, reduce_rank) + if len(hooks) > 0: + WeightGradStore.register_hook(param, hooks) diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 5547a9fb2..1d832d6bf 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -114,6 +114,42 @@ internlm_accelerator = get_accelerator() +def set_param_unique_tracking_name(model): + for chunk_id, chunk in enumerate(unwrap_naive_amp(model)): + # Important: only works for llama-class models + childrens = chunk.named_children() + for _, children in childrens: + if isinstance(children, nn.ModuleList): + for idx, block in enumerate(children): + for name, child in block.named_modules(): + if isinstance(child, (ParallelLinearWithCommExt)): + full_name = f"{chunk_id}.{idx}.{name}" + setattr( + child.weight, + "tracking_name", + f"{full_name}.weight", + ) + if child.bias is not None: + setattr( + child.bias, + "tracking_name", + f"{full_name}.bias", + ) + else: + if isinstance(children, Embedding1D): + setattr( + children.weight, + "tracking_name", + f"{chunk_id}_embedding.weight", + ) + else: + setattr( + children.weight, + "tracking_name", + f"{chunk_id}_head.weight", + ) + + def set_fp32_attr_for_model(model: Union[nn.Module, nn.ModuleList]): if not isinstance(model, nn.ModuleList): model = [model] diff --git a/tests/test_core/test_pipeline.py b/tests/test_core/test_pipeline.py index 41fc62faa..180fe4b71 100644 --- a/tests/test_core/test_pipeline.py +++ b/tests/test_core/test_pipeline.py @@ -153,6 +153,10 @@ def exam_pipeline_parallel(args): first_output = output_list[0] for i in range(1, 10): assert torch.equal(first_output, output_list[i]) + print( + f"idx {i} pass: micro_num={micro_num}, num_chunks={num_chunks}, overlap={interleaved_overlap}", + flush=True, + ) # check output torch_output = torch_model(input_ids=torch_xs) # pylint: disable=E1102 @@ -167,7 +171,7 @@ def exam_pipeline_parallel(args): loose_close(torch_loss, loss[0], dtype=dtype) -@pytest.mark.parametrize("micro_num", [4, 8, 16]) +@pytest.mark.parametrize("micro_num", [8, 16]) @pytest.mark.parametrize("num_chunks", [1, 2, 4]) @pytest.mark.parametrize("interleaved_overlap", [True, False]) def test_pipeline_parallel(micro_num, num_chunks, interleaved_overlap): diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index 3f67b0871..4094c5822 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -61,7 +61,7 @@ def train( load_ckpt: bool = False, model_type: str = "INTERNLM", optimizer_ver: str = "v1", - zero_bubble: bool = False, + pp_mode: str = "1F1B", ): # initialize 
distributed environment config = Config.from_file(CONFIG_FILE_PATH) @@ -97,14 +97,13 @@ def train( # update parallel config config.parallel.tensor = dict(size=tp_size, mode=tp_mode) - if zero_bubble: + if pp_mode == "ZBH1": config.hybrid_zero_optimizer.overlap_sync_grad = False - config.parallel.pipeline = dict(size=pp_size, zero_bubble=True) - else: - config.parallel.pipeline = dict(size=pp_size) + + config.parallel.pipeline = dict(size=pp_size, mode=pp_mode) config.parallel.weight = dict(size=wp_size, overlap=True) if interleaved is True: - config.parallel.pipeline = dict(size=pp_size, interleaved_overlap=True) + config.parallel.pipeline = dict(size=pp_size, interleaved_overlap=True, mode=pp_mode) config.model.num_chunks = num_chunks if "use_packed_dataset" not in config.data: @@ -379,7 +378,7 @@ def test_training_loss_with_dp4_pp2(): @pytest.mark.training_8GPU_4DP2PP_ZB def test_training_loss_with_dp4_pp2_zero_bubble(): # model training - train(dp_size=4, pp_size=2, zero_bubble=True) + train(dp_size=4, pp_size=2, pp_mode="ZBH1") # print loss value print(f"cur_loss_list: {cur_loss_list}", flush=True)
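
Note on usage (illustrative only, not part of the patch): the sketch below shows how the new pipeline "mode" option introduced in this diff could be set in a user training config, following the conventions of configs/7B_sft.py and configs/7B_internlm2.py. The pipeline size of 4 and micro_num of 8 are assumed example values. The patch itself only requires that mode be one of "1f1b", "zbh1", "zbv" (it is upper-cased in args_sanity_check), that micro_num be at least twice the pipeline size for ZBV (model.num_chunks is forced to 2 automatically in that case), and that hybrid_zero_optimizer.overlap_sync_grad be False for ZBH1.

# Hypothetical excerpt from a training config (e.g. a copy of configs/7B_sft.py).
data = dict(
    micro_num=8,  # ZeroBubblePipelineVShapeScheduler asserts micro_num >= 2 * pipeline size
    # other data fields left as in the original config
)

parallel = dict(
    zero1=dict(size=-1),
    tensor=dict(size=1, mode="mtp"),
    # "zbv" selects ZeroBubblePipelineVShapeScheduler (num_chunks is set to 2 by args_sanity_check),
    # "zbh1" selects ZeroBubblePipelineScheduler and requires overlap_sync_grad=False,
    # "1f1b" keeps the existing PipelineScheduler / InterleavedPipelineScheduler behavior.
    pipeline=dict(size=4, interleaved_overlap=True, mode="zbv"),
    weight=dict(size=1, overlap=True),
)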