diff --git a/.github/workflows/lint_check.yaml b/.github/workflows/lint_check.yaml index fe86bd05a..19f170462 100644 --- a/.github/workflows/lint_check.yaml +++ b/.github/workflows/lint_check.yaml @@ -19,24 +19,20 @@ jobs: pip install flake8==v3.8.4 FLAKE_DISABLE_LIST="F403,F405,W504,W503,E203" flake8 --max-line-length=120 --ignore=$FLAKE_DISABLE_LIST --exclude=./internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py ./internlm/* - flake8 --max-line-length=120 --ignore=$FLAKE_DISABLE_LIST ./train.py - name: lint-isort run: | pip install isort==5.12.0 isort --check --profile=black ./internlm/* - isort --check --profile=black ./train.py - name: lint-black run: | pip install black==22.8.0 BLACK_EXCLUDE_SETTINGS='\.venv/|\.local/|\.cache/|\.git/' black --line-length=120 --check --exclude $BLACK_EXCLUDE_SETTINGS ./internlm/* - black --line-length=120 --check --exclude $BLACK_EXCLUDE_SETTINGS ./train.py - name: lint-pylint run: | pip install pylint==v2.17.2 PYLINT_DISABLE_LIST="C0114,C0415,W0212,W0235,W0238,W0621,C0103,R1735,C2801,E0402,C0412,W0719,R1728,W1514,W0718,W0105,W0707,C0209,W0703,W1203" pylint --rcfile .pylintrc --disable=$PYLINT_DISABLE_LIST --ignore=./internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py ./internlm/* - pylint --rcfile .pylintrc --disable=$PYLINT_DISABLE_LIST ./train.py diff --git a/ci_scripts/train/generate_config.py b/ci_scripts/train/generate_config.py index 096334d06..d1a34940a 100644 --- a/ci_scripts/train/generate_config.py +++ b/ci_scripts/train/generate_config.py @@ -5,7 +5,7 @@ import os from ci_scripts.common import com_func -from internlm.core.context import Config +from internlm.core.context.parallel_context import Config def generate_new_config(config_py_file, test_config_json, case_name): diff --git a/configs/1.8B_MoE16_sft.py b/configs/1.8B_MoE16_sft.py index f85302778..eca10b045 100644 --- a/configs/1.8B_MoE16_sft.py +++ b/configs/1.8B_MoE16_sft.py @@ -170,7 +170,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], @@ -197,7 +196,7 @@ 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. """ parallel = dict( - zero1=dict(size=-1, fsdp=False), + zero1=dict(size=-1), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), weight=dict(size=1, overlap=True), diff --git a/configs/57B_qwen2_MoE.py b/configs/57B_qwen2_MoE.py index abfb0a5b8..27f63cc1d 100644 --- a/configs/57B_qwen2_MoE.py +++ b/configs/57B_qwen2_MoE.py @@ -175,7 +175,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. 
mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], @@ -202,7 +201,7 @@ 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. """ parallel = dict( - zero1=dict(size=-1, fsdp=False), + zero1=dict(size=-1), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), weight=dict(size=1, overlap=True), diff --git a/configs/7B_MoE4_sft.py b/configs/7B_MoE4_sft.py index 8d8acc406..74ebbcbb6 100644 --- a/configs/7B_MoE4_sft.py +++ b/configs/7B_MoE4_sft.py @@ -182,7 +182,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], @@ -217,7 +216,7 @@ 4. forward_overlap_per: str, all gather prefetch granularity, per 'module' or per 'layer', defaults to 'layer'. """ parallel = dict( - zero1=dict(size=-1, fsdp=False), + zero1=dict(size=-1), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), weight=dict(size=1, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"), diff --git a/configs/7B_baichuan2.py b/configs/7B_baichuan2.py index eaa26a867..9957d6819 100644 --- a/configs/7B_baichuan2.py +++ b/configs/7B_baichuan2.py @@ -165,7 +165,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/7B_gemma.py b/configs/7B_gemma.py index aff448232..643bcbdbf 100644 --- a/configs/7B_gemma.py +++ b/configs/7B_gemma.py @@ -172,7 +172,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/7B_internlm2.py b/configs/7B_internlm2.py index 7a670171c..3c7bb9f4f 100644 --- a/configs/7B_internlm2.py +++ b/configs/7B_internlm2.py @@ -174,7 +174,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. 
mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/7B_isp_sft.py b/configs/7B_isp_sft.py index 95049036d..e7dd47b04 100644 --- a/configs/7B_isp_sft.py +++ b/configs/7B_isp_sft.py @@ -187,7 +187,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/7B_llama2.py b/configs/7B_llama2.py index b0a173c8d..7783abaf7 100644 --- a/configs/7B_llama2.py +++ b/configs/7B_llama2.py @@ -164,7 +164,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/7B_qwen2.py b/configs/7B_qwen2.py index 09b536ccc..3622e12f1 100644 --- a/configs/7B_qwen2.py +++ b/configs/7B_qwen2.py @@ -172,7 +172,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 4799b5f35..43690c5e9 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -155,7 +155,7 @@ dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" norm_type="rmsnorm", layer_norm_epsilon=1e-5, - use_flash_attn=True, + use_flash_attn=False, # Whether the odd and even columns of the query and key in the model are normally interleaved. # If it's True, the model's odd and even columns are normally ordered; if it's False, # it means that the model has prematurely concatenated all odd columns and even columns in front @@ -174,7 +174,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. 
mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/8x22B_mixtral.py b/configs/8x22B_mixtral.py index debd423b0..f1f1b6e60 100644 --- a/configs/8x22B_mixtral.py +++ b/configs/8x22B_mixtral.py @@ -176,7 +176,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], @@ -203,7 +202,7 @@ 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. """ parallel = dict( - zero1=dict(size=-1, fsdp=False), + zero1=dict(size=-1), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), weight=dict(size=1, overlap=True), diff --git a/configs/8x7B_mixtral.py b/configs/8x7B_mixtral.py index 322342ea6..6db43f9c6 100644 --- a/configs/8x7B_mixtral.py +++ b/configs/8x7B_mixtral.py @@ -176,7 +176,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], @@ -203,7 +202,7 @@ 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. """ parallel = dict( - zero1=dict(size=-1, fsdp=False), + zero1=dict(size=-1), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), weight=dict(size=1, overlap=True), diff --git a/configs/_base_/models/internlm2_1B.py b/configs/_base_/models/internlm2_1B.py index 5d050da92..cc3f186ad 100644 --- a/configs/_base_/models/internlm2_1B.py +++ b/configs/_base_/models/internlm2_1B.py @@ -51,7 +51,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/_base_/models/internlm2_20B.py b/configs/_base_/models/internlm2_20B.py index 3b297e51f..dc461c0da 100644 --- a/configs/_base_/models/internlm2_20B.py +++ b/configs/_base_/models/internlm2_20B.py @@ -48,7 +48,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. 
tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/_base_/models/internlm2_7B.py b/configs/_base_/models/internlm2_7B.py index 37b99294b..cbdb03cb1 100644 --- a/configs/_base_/models/internlm2_7B.py +++ b/configs/_base_/models/internlm2_7B.py @@ -48,7 +48,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/_base_/models/internlm_20B.py b/configs/_base_/models/internlm_20B.py index b7f7d8a59..26f4ff7f8 100644 --- a/configs/_base_/models/internlm_20B.py +++ b/configs/_base_/models/internlm_20B.py @@ -43,7 +43,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/configs/_base_/models/internlm_7B.py b/configs/_base_/models/internlm_7B.py index e666c02ee..8dde6e4e4 100644 --- a/configs/_base_/models/internlm_7B.py +++ b/configs/_base_/models/internlm_7B.py @@ -43,7 +43,6 @@ * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], diff --git a/doc/code-docs/source/initialize.rst b/doc/code-docs/source/initialize.rst index bcfe67d1a..721eec006 100644 --- a/doc/code-docs/source/initialize.rst +++ b/doc/code-docs/source/initialize.rst @@ -43,7 +43,7 @@ InternEvo 使用 `argparse `_ 模型初始化 ------------------------- -.. autofunction:: internlm.train.initialize_model +.. autofunction:: internlm.train.initialize_model_and_parallel_communicator InternEvo 在配置文件中使用字段 ``model_type`` 和 ``model`` 来控制模型初始化过程。示例模型初始化配置定义如下: diff --git a/doc/code-docs/source/training.rst b/doc/code-docs/source/training.rst index a0b4c2288..f43bfe4af 100644 --- a/doc/code-docs/source/training.rst +++ b/doc/code-docs/source/training.rst @@ -27,7 +27,7 @@ - 初始化模型 .. 
code-block:: python - model = initialize_model() + model = initialize_model_and_parallel_communicator() 详细介绍请参考: `模型初始化 `_ diff --git a/doc/en/train_performance.md b/doc/en/train_performance.md index d6b572f7b..ea998f06e 100644 --- a/doc/en/train_performance.md +++ b/doc/en/train_performance.md @@ -121,7 +121,7 @@ model = dict( ) parallel = dict( - zero1=dict(size=8, fsdp=False), + zero1=dict(size=8), tensor=1, pipeline=dict(size=1, interleaved_overlap=True), sequence_parallel=False, diff --git a/doc/train_performance.md b/doc/train_performance.md index 98364a753..891fa2f33 100644 --- a/doc/train_performance.md +++ b/doc/train_performance.md @@ -117,7 +117,7 @@ model = dict( ) parallel = dict( - zero1=dict(size=8, fsdp=False), + zero1=dict(size=8), tensor=1, pipeline=dict(size=1, interleaved_overlap=True), sequence_parallel=False, diff --git a/doc/usage.md b/doc/usage.md index 67ae1edf5..7c28d6d3e 100644 --- a/doc/usage.md +++ b/doc/usage.md @@ -268,7 +268,6 @@ zero1 parallel (dict): * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. - 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. tensor parallel (dict): 1. size: int, the size of tensor parallel. 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], @@ -432,7 +431,6 @@ parallel = dict( - 当`zero1 <= 0`,则 zero1 进程组的大小等于数据并行进程组的大小,因此优化器状态参数将在数据并行范围内分配 - 当`zero1 == 1`,则不使用 zero1 ,所有数据并行组保留完整的优化器状态参数 - 当`zero1 > 1`且`zero1 <= data_parallel_world_size`,则 zero1 进程组是数据并行进程组的子集 - 2. fsdp: 布尔值,启用/禁用torch的完全分片数据并行,默认为False。 - tensor(字典): 1. size: 整数,张量并行的大小。 2. 
mode: 字符串,张量并行模式,应该是 ['mtp', 'msp', 'fsp', 'isp'] 中的一个, diff --git a/generate.py b/generate.py index 4ae760299..707d47763 100644 --- a/generate.py +++ b/generate.py @@ -16,12 +16,12 @@ from internlm.accelerator import get_accelerator from internlm.apis.inference import SequenceGenerator -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.data import build_generation_loader_with_data_type from internlm.initialize import initialize_distributed_env from internlm.monitor import initialize_monitor_manager from internlm.monitor.monitor import monitor_manager as mm -from internlm.train import initialize_model, initialize_parallel_communicator +from internlm.train import initialize_model_and_parallel_communicator from internlm.utils.common import ( enable_pytorch_expandable_segments, launch_time, @@ -106,8 +106,7 @@ def main(): raise e # initialize model - model = initialize_model() - _ = initialize_parallel_communicator(model) + model, _ = initialize_model_and_parallel_communicator() model = model.model state_dict = merge_pp_within_tp(generation_config.ckpt_folder, del_model_prefix=True) diff --git a/internlm/__init__.py b/internlm/__init__.py index dc34a3167..e69de29bb 100644 --- a/internlm/__init__.py +++ b/internlm/__init__.py @@ -1,9 +0,0 @@ -from .initialize.initialize_trainer import initialize_trainer -from .initialize.launch import get_default_parser, launch_from_slurm, launch_from_torch - -__all__ = [ - "get_default_parser", - "initialize_trainer", - "launch_from_slurm", - "launch_from_torch", -] diff --git a/internlm/apis/inference.py b/internlm/apis/inference.py index a45cf27dd..05e85a042 100644 --- a/internlm/apis/inference.py +++ b/internlm/apis/inference.py @@ -7,7 +7,7 @@ from internlm.apis import InferenceParams, process_parallel_output from internlm.core.context import ParallelMode # noqa: E402 -from internlm.core.context import global_context as gpc # noqa: E402 +from internlm.core.context.parallel_context import global_context as gpc # noqa: E402 from internlm.core.trainer import Trainer __all__ = ["SequenceGenerator"] diff --git a/internlm/apis/inference_utils.py b/internlm/apis/inference_utils.py index 423e7aafe..9a8cffa06 100644 --- a/internlm/apis/inference_utils.py +++ b/internlm/apis/inference_utils.py @@ -1,7 +1,7 @@ import torch from internlm.core.context import ParallelMode # noqa: E402 -from internlm.core.context import global_context as gpc # noqa: E402 +from internlm.core.context.parallel_context import global_context as gpc # noqa: E402 from internlm.core.parallel.comm.utils import _gather as gather diff --git a/internlm/checkpoint/checkpoint_manager.py b/internlm/checkpoint/checkpoint_manager.py index 2f7f5d4ed..83644d24f 100644 --- a/internlm/checkpoint/checkpoint_manager.py +++ b/internlm/checkpoint/checkpoint_manager.py @@ -9,7 +9,7 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.trainer import TrainState from internlm.initialize.launch import get_config_value from internlm.initialize.legacy.launch import ( @@ -23,6 +23,7 @@ from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer +from internlm.utils.parallel import is_using_fsdp, is_using_hf from 
internlm.utils.storage_manager import ( get_storage_manager, init_storage_manager, @@ -271,7 +272,7 @@ def __init__( self.storage_manager = get_storage_manager() self.snapshot_counter = -1 - if hasattr(model, "model"): + if hasattr(model, "model") and not is_using_fsdp(): model = model.model self.model = model @@ -575,6 +576,8 @@ def try_resume_training(self, train_state: TrainState, current_time=""): f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)}," f"dp={gpc.get_local_rank(ParallelMode.DATA)}===========" ) + elif is_using_fsdp() and is_using_hf() and not self.auto_resume: + pass else: load_path = self.load_ckpt_info["path"] load_content = self.load_ckpt_info["content"] diff --git a/internlm/checkpoint/components.py b/internlm/checkpoint/components.py index eee92c9c5..70fc4bdff 100644 --- a/internlm/checkpoint/components.py +++ b/internlm/checkpoint/components.py @@ -4,25 +4,34 @@ from collections import defaultdict import torch -from torch.distributed._shard.api import load_with_process_group from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.trainer import TrainState from internlm.model.moe import MoE from internlm.solver.optimizer import HybridZeroOptimizer, HybridZeroOptimizer_v2 from internlm.utils.common import get_current_device +from internlm.utils.lazy import LazyObject from internlm.utils.logger import get_logger -from internlm.utils.parallel import is_using_isp +from internlm.utils.parallel import is_using_fsdp, is_using_hf, is_using_isp from internlm.utils.storage_manager import get_fns, llm_load, llm_save -from .utils import ( - get_model_topology, - get_non_moe_state_dict, - get_shard_state_dict, - load_shard_state_dict, -) +from .utils import get_model_topology, get_non_moe_state_dict + +try: + import torch.distributed.checkpoint as dcp + from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + set_model_state_dict, + ) + + DCP_SUPPORTED = True +except (ImportError, ModuleNotFoundError): + DCP_SUPPORTED = False + +RESUME_HF_FORMAT = True logger = get_logger(__file__) internlm_accelerator = get_accelerator() @@ -99,6 +108,34 @@ def try_save_moe_checkpoint(folder, model, expert_mp_rank, pp_rank): moe_layer_id += 1 +def load_fsdp_model_checkpoint(folder, model): + if DCP_SUPPORTED: + assert folder.startswith("local:"), "Currently we only support DCP load and save locally." 
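+        # NOTE: "local:" is the path prefix InternEvo's storage manager uses for local
+        #       filesystem paths; the slice below strips that 6-character prefix to get
+        #       the plain directory used as the checkpoint location (the DCP checkpoint_id,
+        #       or its "hf" subfolder in the HF branch).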
+ local_folder = folder[6:] + + if is_using_hf() and RESUME_HF_FORMAT: + hf = gpc.config.hf + mod = LazyObject(hf.mod, hf.mod_cls) + mod = mod.build() + state_dict = mod.from_pretrained( + pretrained_model_name_or_path=os.path.join(local_folder, "hf"), use_safetensors=True + ).state_dict() + state_dict = {f"model.{key}": state_dict[key].clone().detach() for key in state_dict} + set_model_state_dict( + model=model, model_state_dict=state_dict, options=StateDictOptions(full_state_dict=True) + ) + else: + state_dict = get_model_state_dict(model=model) + state_dict = {key: state_dict[key].clone().detach() for key in state_dict} + dcp.load(state_dict=state_dict, checkpoint_id=local_folder) + set_model_state_dict(model=model, model_state_dict=state_dict) + + del state_dict + internlm_accelerator.empty_cache() + else: + raise RuntimeError("DCP is not supported in this version of PyTorch.") + + def load_model_checkpoint(folder, model): """ There should be weights with names similar to the following under the folder. @@ -109,43 +146,31 @@ def load_model_checkpoint(folder, model): - folder - model_wp{wp_rank}_pp{pp_rank}.pt - If fsdp is activated, the saved weight is named: - - folder - - model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}.pt - If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading. """ + if is_using_fsdp(): + return load_fsdp_model_checkpoint(folder, model) + tp_size = gpc.get_world_size(ParallelMode.TENSOR) wp_size = gpc.get_world_size(ParallelMode.WEIGHT) pp_size = gpc.get_world_size(ParallelMode.PIPELINE) - dp_size = gpc.get_world_size(ParallelMode.DATA) tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT) pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - dp_rank = gpc.get_local_rank(ParallelMode.DATA) fns = get_fns(folder) - # avoid ckpt misuse between FSDP and no-FSDP _start_with = "model_w" if is_using_isp() else "model_t" - test_fn = list([f for f in fns if f.startswith(_start_with) and not f.endswith(".md5")]).pop() - assert ("_dp" in test_fn and gpc.config.parallel.zero1.fsdp) or ( - "_dp" not in test_fn and not gpc.config.parallel.zero1.fsdp - ), "FSDP model wants to load no-FSDP ckpts or reverse" - max_pp, max_wp, max_tp, max_zo = 0, 0, 0, 0 + max_pp, max_wp, max_tp = 0, 0, 0 for fn in fns: if fn.startswith(_start_with) and not fn.endswith(".md5"): segements = os.path.splitext(fn)[0].split("_") if is_using_isp(): max_pp = max(max_pp, int(segements[-1][2:])) max_wp = max(max_wp, int(segements[-2][2:])) - elif gpc.config.parallel.zero1.fsdp: - max_zo = max(max_zo, int(segements[-1][2:])) - max_pp = max(max_pp, int(segements[-2][2:])) - max_tp = max(max_tp, int(segements[-3][2:])) else: max_pp = max(max_pp, int(segements[-1][2:])) max_tp = max(max_tp, int(segements[-2][2:])) @@ -160,23 +185,13 @@ def load_model_checkpoint(folder, model): assert ( tp_size == max_tp + 1 ), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism" - if gpc.config.parallel.zero1.fsdp: - assert ( - dp_size == max_zo + 1 - ), f"The weights are save for {max_zo+1} FSDP shards , while current has {dp_size} FSDP shards" - if is_using_isp(): should_load_name = f"model_wp{wp_rank}_pp{pp_rank}.pt" - elif gpc.config.parallel.zero1.fsdp: - should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_dp{dp_rank}.pt" else: should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt" fp = os.path.join(folder, should_load_name) - # for FSDP shards loading, we need to set process group - 
with load_with_process_group(gpc.get_group(ParallelMode.ZERO1)): - states = llm_load(fp, map_location=get_current_device()) - + states = llm_load(fp, map_location=get_current_device()) """ # need convert the gate parameters to float32 (to fit deepspeed style mechanism), it may cause round-off in # gate.weight. The conversion will also be done when doing forward. so we can just comment it out. this make @@ -193,10 +208,7 @@ def load_model_checkpoint(folder, model): expert_tp_rank = 0 if gpc.config.parallel.expert.no_tp else tp_rank try_load_moe_checkpoint(folder, model, states, expert_tp_rank, pp_rank) - if gpc.config.parallel.zero1.fsdp: - missing_k, unexpected_keys = load_shard_state_dict(model, states, strict=False) - else: - missing_k, unexpected_keys = model.load_state_dict(states, strict=False) + missing_k, unexpected_keys = model.load_state_dict(states, strict=False) if len(missing_k) != 0: logger.warning(f"Warning: missing keys {missing_k}") if len(unexpected_keys) != 0: @@ -207,6 +219,40 @@ def load_model_checkpoint(folder, model): internlm_accelerator.empty_cache() +def save_fsdp_model_checkpoint(folder, model): + def remove_model_prefix(state_dict): + new_state_dict = {} + for key in state_dict.keys(): + new_key = key.replace("model.", "", 1) + new_state_dict[new_key] = state_dict[key].clone().detach() + return new_state_dict + + if DCP_SUPPORTED: + assert folder.startswith("local:"), "Currently we only support DCP load and save locally." + local_folder = folder[6:] + + if is_using_hf() and RESUME_HF_FORMAT: + state_dict = remove_model_prefix( + get_model_state_dict(model, options=StateDictOptions(full_state_dict=True, cpu_offload=True)) + ) + if state_dict: + hf = gpc.config.hf + cfg = LazyObject(hf.cfg, hf.cfg_cls) + cfg = cfg.build() + mod = LazyObject(hf.mod, hf.mod_cls) + mod = mod.build() + with torch.device("meta"): + mod_to_save = mod(cfg(**hf.cfg_extra_kwargs)) + mod_to_save.load_state_dict(state_dict, strict=True, assign=True) + mod_to_save.save_pretrained(save_directory=os.path.join(local_folder, "hf"), safe_serialization=True) + else: + dcp.save(get_model_state_dict(model=model), checkpoint_id=local_folder) + + torch.distributed.barrier() + else: + raise RuntimeError("DCP is not supported in this version of PyTorch.") + + def save_model_checkpoint(folder, model): """ Save the model according to the relationship between tp and dp. The principle is that the data of each tp @@ -218,10 +264,6 @@ def save_model_checkpoint(folder, model): - folder - model_wp{wp_rank}_pp{pp_rank}.pt - If fsdp is activated, the saved weight is named: - - folder - - model_tp{tp_rank}_pp{pp_rank}_zo{zo_rank}.pt - If the tp is inconsistent with the saved one in the future use, the weight needs to be converted before loading. 
Args: @@ -229,10 +271,10 @@ def save_model_checkpoint(folder, model): model: The model to be saved """ - if gpc.config.parallel.zero1.fsdp: - states = get_shard_state_dict(model) - else: - states = model.state_dict() + if is_using_fsdp(): + return save_fsdp_model_checkpoint(folder, model) + + states = model.state_dict() # get non-expert parameters states = get_non_moe_state_dict(states) @@ -268,21 +310,15 @@ def save_model_checkpoint(folder, model): else: # for tensor parallel mode with mtp/msp/fsp for i in range(tp_size): - if gpc.config.parallel.zero1.fsdp: - for j in range(dp_size): - should_save_rank_pair.add((i, j)) - else: - should_save_rank_pair.add((i, i % dp_size)) + should_save_rank_pair.add((i, i % dp_size)) if (tp_rank, dp_rank) in should_save_rank_pair: - f_dp = f"_dp{dp_rank}" if gpc.config.parallel.zero1.fsdp else "" - fn = f"model_tp{tp_rank}_pp{pp_rank}{f_dp}.pt" + fn = f"model_tp{tp_rank}_pp{pp_rank}.pt" fp = os.path.join(folder, fn) llm_save(fp, saved_obj=states) - if not gpc.config.parallel.zero1.fsdp or dp_rank == tp_rank % dp_size: - topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json" - topo_fp = os.path.join(folder, topo_fn) - llm_save(topo_fp, saved_obj=topo) + topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json" + topo_fp = os.path.join(folder, topo_fn) + llm_save(topo_fp, saved_obj=topo) # try to save expert parameter to separate files if model have moe layer expert_dp_size = gpc.get_world_size(ParallelMode.EXPERT_DATA) @@ -310,19 +346,25 @@ def load_optimizer_checkpoint(folder, optim): fns = get_fns(folder) max_tp, max_wp, max_pp, max_zero = 0, 0, 0, 0 + max_fsdp = 0 for fn in fns: if fn.startswith("optimizer_") and not fn.endswith(".md5"): - if is_using_isp(): - _, wp, pp, zero = os.path.splitext(fn)[0].split("_") - max_zero = max(max_zero, int(zero[2:])) - max_wp = max(max_wp, int(wp[2:])) - max_pp = max(max_pp, int(pp[2:])) + if isinstance(optim, (HybridZeroOptimizer, HybridZeroOptimizer_v2)): + if is_using_isp(): + _, wp, pp, zero = os.path.splitext(fn)[0].split("_") + max_zero = max(max_zero, int(zero[2:])) + max_wp = max(max_wp, int(wp[2:])) + max_pp = max(max_pp, int(pp[2:])) + else: + _, tp, pp, zero = os.path.splitext(fn)[0].split("_") + max_zero = max(max_zero, int(zero[2:])) + max_tp = max(max_tp, int(tp[2:])) + max_pp = max(max_pp, int(pp[2:])) else: - _, tp, pp, zero = os.path.splitext(fn)[0].split("_") - max_zero = max(max_zero, int(zero[2:])) - max_tp = max(max_tp, int(tp[2:])) - max_pp = max(max_pp, int(pp[2:])) + _, fsdp = os.path.splitext(fn)[0].split("_") + max_fsdp = max(max_fsdp, int(fsdp[4:])) + fsdp_size = gpc.get_world_size(ParallelMode.GLOBAL) zero_size = gpc.get_world_size(ParallelMode.ZERO1) tp_size = gpc.get_world_size(ParallelMode.TENSOR) wp_size = gpc.get_world_size(ParallelMode.WEIGHT) @@ -343,14 +385,24 @@ def load_optimizer_checkpoint(folder, optim): wp_size == max_wp + 1 ), f"The optimizer states are save for {max_wp+1} parallelism, while current has {wp_size} weight parallelism" + if not isinstance(optim, (HybridZeroOptimizer, HybridZeroOptimizer_v2)): + assert ( + fsdp_size == max_fsdp + 1 + ), f"The optimizer states are save for {max_fsdp+1} parallelism, while current has {fsdp_size} fsdp parallelism" + + fsdp_rank = gpc.get_local_rank(ParallelMode.GLOBAL) zero_rank = gpc.get_local_rank(ParallelMode.ZERO1) tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT) pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - if is_using_isp(): - fp = f"optimizer_wp{wp_rank}_pp{pp_rank}_zo{zero_rank}.pt" + 
+ if isinstance(optim, (HybridZeroOptimizer, HybridZeroOptimizer_v2)): + if is_using_isp(): + fp = f"optimizer_wp{wp_rank}_pp{pp_rank}_zo{zero_rank}.pt" + else: + fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt" else: - fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt" + fp = f"optimizer_fsdp{fsdp_rank}.pt" states = llm_load(os.path.join(folder, fp), map_location=get_current_device()) @@ -392,6 +444,7 @@ def save_optimizer_checkpoint(optim, state_path): """ # TODO sanity check for optimizer type + fsdp_rank = gpc.get_local_rank(ParallelMode.GLOBAL) zero_rank = gpc.get_local_rank(ParallelMode.ZERO1) tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT) @@ -416,6 +469,7 @@ def save_optimizer_checkpoint(optim, state_path): fp_meta = os.path.join(state_path, optim.rank_unique_id) llm_save(fp_meta, params_per_rank_id_dict) else: + fp = f"optimizer_fsdp{fsdp_rank}.pt" llm_save(os.path.join(state_path, fp), states) diff --git a/internlm/checkpoint/utils.py b/internlm/checkpoint/utils.py index a63ddb948..6b2c61b88 100644 --- a/internlm/checkpoint/utils.py +++ b/internlm/checkpoint/utils.py @@ -1,31 +1,21 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import itertools + +import numpy as np +import torch from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import StateDictType -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.parallel.shard import split_data_for_sequence_parallel +from internlm.data.utils import packed_data_normalizer, unpack_data from internlm.utils.logger import get_logger +from internlm.utils.parallel import is_using_isp logger = get_logger(__file__) -def get_shard_state_dict(shard_model): - """ - Only used for FSDP module saving. - It's a warper of model.state_dict() and with the context of 'FSDP.state_dict_type', the sharded parameter - (saved as model.flat_param_xx in sharded FSDP module) will be gathered at every gpu. - 'offload_to_cpu' means that the model states are to be offloaded to cpu chunk by chunk, avoiding OOM in gpu - - """ - - # FSDP model can only save with sharded shape SHARDED_STATE_DICT when set use_orig_params=True - with FSDP.state_dict_type(shard_model, StateDictType.SHARDED_STATE_DICT): - shard_states = shard_model.state_dict() - - return shard_states - - def get_non_moe_state_dict(full_state_dict): """ Get the state dict of the non-moe layers @@ -37,18 +27,6 @@ def get_non_moe_state_dict(full_state_dict): return full_state_dict -def load_shard_state_dict(shard_model, shard_state, **kwargs): - """ - Only used for FSDP module loading. - - """ - - with FSDP.state_dict_type(shard_model, StateDictType.SHARDED_STATE_DICT): - missing_k, unexpected_keys = shard_model.load_state_dict(shard_state, kwargs) - - return (missing_k, unexpected_keys) - - def get_model_topology(model): """ Returns: @@ -75,3 +53,67 @@ def process_load_info(load_info): logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}") return load_content_str, load_ckpt_folder, load_content + + +def init_fsdp_v1(model: FSDP, device: torch.device) -> FSDP: + """ + Initialize Fully Sharded Data Parallel (FSDP) for the model. + This function is needed to properly initialize FSDP when resuming from a checkpoint. + It runs a forward pass with dummy inputs to ensure FSDP is fully initialized. 
+ + References: + https://github.com/pytorch/pytorch/issues/113496 + https://github.com/huggingface/transformers/pull/34032 + https://github.com/huggingface/transformers/issues/31892 + + Args: + model: The model to initialize with FSDP. + device: The device to run the model on. + + Returns: + The initialized FSDP model. + """ + model.train() + with torch.no_grad(): + # generate dummy packed sequence + seq_len = gpc.config.data.seq_len * gpc.config.data.micro_bsz + input_ids = [1] * seq_len + label = input_ids[1:] + [-100] + cu_seqlens = list(range(0, seq_len + gpc.config.data.seq_len, gpc.config.data.seq_len)) + + input_ids = torch.tensor(input_ids, device=device).unsqueeze(0) + label = torch.tensor(label, device=device).unsqueeze(0) + indexes = torch.tensor( + list(itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])])), + device=device, + ).unsqueeze(0) + cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32).unsqueeze(0) + + data = { + "input_ids": input_ids, + "cu_seqlens": cu_seqlens, + "indexes": indexes, + "max_seqlen": seq_len, + } + + data_fns = [] + + # default data process function + if gpc.config.data.use_packed_dataset: + data_fns.append(packed_data_normalizer) + else: + data_fns.append(unpack_data) + + # support sequence parallel for isp + if is_using_isp(): + data_fns.append(split_data_for_sequence_parallel) + + # generate dummy_input + _data, _label = data, label + for fn in data_fns: + _data, _label = fn(_data, _label) + dummy_input = _data + + # run a forward pass with dummy_input to initialize FSDP + _ = model(**dummy_input) + return model diff --git a/internlm/core/__init__.py b/internlm/core/__init__.py index d6b704899..998693984 100644 --- a/internlm/core/__init__.py +++ b/internlm/core/__init__.py @@ -1,9 +1,7 @@ -from .engine import Engine from .naive_amp import NaiveAMPModel from .trainer import Trainer __all__ = [ "NaiveAMPModel", - "Engine", "Trainer", ] diff --git a/internlm/core/context/__init__.py b/internlm/core/context/__init__.py index ae5f6a25f..ba47967f9 100644 --- a/internlm/core/context/__init__.py +++ b/internlm/core/context/__init__.py @@ -5,9 +5,7 @@ IS_TENSOR_ZERO_PARALLEL, IS_WEIGHT_EXPERT_DATA_PARALLEL, IS_WEIGHT_ZERO_PARALLEL, - Config, ParallelContext, - global_context, ) from .process_group_initializer import ( Initializer_Data, @@ -15,7 +13,6 @@ Initializer_Pipeline, Initializer_Tensor, Initializer_Zero1, - Initializer_Zero3_dp, ParallelMode, ProcessGroupInitializer, ) @@ -31,14 +28,12 @@ ) __all__ = [ - "Config", "IS_REPLICA_EXPERT_DATA_PARALLEL", "IS_TENSOR_ZERO_PARALLEL", "IS_REPLICA_ZERO_PARALLEL", "IS_WEIGHT_EXPERT_DATA_PARALLEL", "IS_WEIGHT_ZERO_PARALLEL", "IS_TENSOR_EXPERT_DATA_PARALLEL", - "global_context", "ParallelContext", "ParallelMode", "Initializer_Tensor", @@ -46,7 +41,6 @@ "Initializer_Data", "Initializer_Zero1", "Initializer_Nettest", - "Initializer_Zero3_dp", "ProcessGroupInitializer", "seed", "set_mode", diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index f4751f59a..1bd7ede38 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -3,12 +3,12 @@ # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context +from importlib.machinery import SourceFileLoader import inspect +from pathlib import Path import random import socket import sys -from importlib.machinery import SourceFileLoader -from pathlib import Path from typing import Union import numpy as np 
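The new load_fsdp_model_checkpoint / save_fsdp_model_checkpoint helpers above are built on torch.distributed.checkpoint (DCP) together with get_model_state_dict / set_model_state_dict. Below is a minimal, self-contained sketch of that save/load round trip; the toy nn.Linear, single-rank gloo group, and /tmp path are illustrative stand-ins rather than InternEvo code, and it assumes a PyTorch version recent enough that DCP_SUPPORTED would be True.

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(4, 4)   # stand-in for the (FSDP-wrapped) training model
ckpt_dir = "/tmp/dcp_demo"      # a "local:" folder with the prefix already stripped

# save: every rank contributes its shard of the model state dict to ckpt_dir
dcp.save(get_model_state_dict(model), checkpoint_id=ckpt_dir)

# load: fill an existing state dict in place, then push it back into the module
state_dict = get_model_state_dict(model)
dcp.load(state_dict=state_dict, checkpoint_id=ckpt_dir)
set_model_state_dict(model, model_state_dict=state_dict)

dist.destroy_process_group()

Per its docstring, the init_fsdp_v1 helper added in internlm/checkpoint/utils.py above exists so that an FSDP-wrapped model is fully initialized (via one dummy forward pass) before a resume performs a load like this.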
@@ -16,8 +16,9 @@ import torch.distributed as dist from internlm.accelerator import get_accelerator -from internlm.utils.common import SingletonMeta +from internlm.core.context.parallel_context import Config from internlm.utils.logger import get_logger +from internlm.utils.common import SingletonMeta from internlm.utils.timeout import LLM_NCCL_TIMEOUT from internlm.utils.utils import TensorParallelMode @@ -483,16 +484,6 @@ def check_sanity(self): assert self.zero1_parallel_size > 0 - # check for fsdp: - # if zo_size < dp_size, ckpts saving will introduce redundent storage for model weights - # because pytorch "ShardTensor" need to ensure current global rank equals to saved shard's global rank - # pytorch vision: 1.13.1+cu117 - if self.data_parallel_size > self.zero1_parallel_size and self.config.parallel.zero1.get("fsdp", False): - logger.warning( - f"zo size: {self.zero1_parallel_size} < dp size: {self.data_parallel_size}, " - "will introduce redundancy when saving fsdp model ckpts, recommend setting them to same value" - ) - def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str): if key in config: ele = config[key] @@ -518,7 +509,7 @@ def init_parallel_groups(self): if parallel_config is not None: # set default value for parallel size if "zero1" not in parallel_config: - parallel_config._add_item("zero1", dict(size=-1, fsdp=False)) + parallel_config._add_item("zero1", dict(size=-1)) if "pipeline" not in parallel_config: parallel_config._add_item("pipeline", dict(size=1, interleaved_overlap=False)) if "tensor" not in parallel_config: @@ -657,9 +648,7 @@ def init_parallel_groups(self): # process groups for parallelism. enable_moe = self.config.model.get("num_experts", 1) > 1 tp_mode = "mtp" if isinstance(parallel_config.tensor, int) else parallel_config.tensor.get("mode", "mtp") - is_fsdp = False if isinstance(parallel_config.zero1, int) else parallel_config.zero1.get("fsdp", False) - parallel_strategy = "fsdp" if is_fsdp else tp_mode - group_configs = generate_parallel_group_configs(parallel_strategy, parallel_sizes, enable_moe) + group_configs = generate_parallel_group_configs(tp_mode, parallel_sizes, enable_moe) group_results = create_parallel_process_groups(world_size, rank, group_configs, with_cpu_group=False) # process group for network test. diff --git a/internlm/core/context/process_group_initializer.py b/internlm/core/context/process_group_initializer.py index 1e8057383..014bdbfdd 100644 --- a/internlm/core/context/process_group_initializer.py +++ b/internlm/core/context/process_group_initializer.py @@ -42,11 +42,6 @@ class ParallelMode(Enum): # runntime network test NETTEST = "nettest" - # zero3-dp parallel - # if fsdp is activated and size of fsdp-parallel-size is less than dp-parallel-size - # then manual communication only happens between inter-fsdp-modules, while intra-modules reduction is done by fsdp - ZERO3_DP = "zero3_dp" - # expert parallel EXPERT = "expert" @@ -274,7 +269,6 @@ def create_single_process_group( ISP_SP_GROUP_ORDER = [ParallelMode.TENSOR, ParallelMode.DATA, ParallelMode.PIPELINE] ISP_WP_GROUP_ORDER = [ParallelMode.WEIGHT, ParallelMode.WEIGHT_DATA, ParallelMode.PIPELINE] ISP_MOE_GROUP_ORDER = [ParallelMode.EXPERT_WEIGHT, ParallelMode.EXPERT, ParallelMode.EXPERT_DATA, ParallelMode.PIPELINE] -FSDP_ORDER = [ParallelMode.DATA] # TODO: should we support moe for fsdp? 
SUBGROUP_SPEC = { "mtp": { @@ -283,9 +277,6 @@ def create_single_process_group( "isp": { ParallelMode.WEIGHT_DATA: [ParallelMode.ZERO1], }, # TODO: WEIGHT_ZERO1 - "fsdp": { - ParallelMode.DATA: [ParallelMode.ZERO3_DP, ParallelMode.ZERO1], - }, } @@ -321,8 +312,6 @@ def _recurse_generater(order: List[ParallelMode]): group_configs.append(("isp-wp", _recurse_generater(ISP_WP_GROUP_ORDER))) if enable_moe: group_configs.append(("isp-moe", _recurse_generater(ISP_MOE_GROUP_ORDER))) - elif parallel_strategy == "fsdp": - group_configs.append(("fsdp", _recurse_generater(FSDP_ORDER))) else: # 3d parallel: mtp, msp, fsp group_configs.append(("3d", _recurse_generater(MTP_GROUP_ORDER))) if enable_moe: @@ -1118,65 +1107,6 @@ def init_dist_group(self, use_cpu: bool = False): return groups -class Initializer_Zero3_dp(ProcessGroupInitializer): - """A ProcessGroupInitializer for data parallelism. - - Args: - rank (int): The rank of current process. - world_size (int): Size of whole communication world. - data_parallel_size (int): Size of data parallel. - pipeline_parallel_size (int): Size of pipeline parallel. - tensor_parallel_size (int): Size of tensor parallel. - zero1_parallel_size (int): Size of zero1 parallel. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - assert self.data_parallel_size % self.zero1_parallel_size == 0 - - # the only difference between this initializer and DP_initializer - # when FSDP is enabled, only corresponding pairs are in the same actual DP group due to parameter sharding - # eg: when zero=4 and dp=8 - # no fsdp: rank [0-7] share same model paramters, and [0-3], [4-7] are two separate zero group - # fsdp: params of (0, 4), (1, 5), (2, 6), (3, 7) are the same actually - - self.data_parallel_size //= self.zero1_parallel_size - self.rank_num_per_dp_group = self.world_size // self.data_parallel_size - - assert self.world_size % self.data_parallel_size == 0 - - def init_dist_group(self, use_cpu: bool = False): - """Initialize data parallel groups, and assign local_ranks and groups to each gpu. - - Returns: - Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode): - A Data parallelism's information tuple. - """ - local_rank = None - ranks_in_group = None - process_group = None - cpu_group = None - group_world_size = None - mode = ParallelMode.ZERO3_DP - - for i in range(self.rank_num_per_dp_group): - ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)] - group = dist.new_group(ranks) - if use_cpu: - group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group - else: - group_cpu = None - - if self.rank in ranks: - local_rank = ranks.index(self.rank) - group_world_size = len(ranks) - process_group = group - cpu_group = group_cpu - ranks_in_group = ranks - - return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode - - class Initializer_Weight(ProcessGroupInitializer): """A ProcessGroupInitializer for model weight parallelism. 
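With Initializer_Zero3_dp, FSDP_ORDER and the dedicated "fsdp" entry in SUBGROUP_SPEC removed, the group-config strategy is now chosen from the tensor-parallel mode (plus the parallel sizes and MoE flag) alone, and the zero1 block in the configs no longer carries an fsdp flag; FSDP usage is instead detected through the new is_using_fsdp() / is_using_hf() helpers from internlm.utils.parallel (how those are configured is not shown in this diff). The parallel section as it now appears in the sample configs touched above:

parallel = dict(
    zero1=dict(size=-1),                       # the fsdp flag is gone from zero1
    tensor=dict(size=1, mode="mtp"),
    pipeline=dict(size=1, interleaved_overlap=True),
    weight=dict(size=1, overlap=True),
)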
diff --git a/internlm/core/gradient_handler.py b/internlm/core/gradient_handler.py index c866be5b3..f267b65e5 100644 --- a/internlm/core/gradient_handler.py +++ b/internlm/core/gradient_handler.py @@ -7,7 +7,7 @@ import torch.distributed as dist from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.utils.common import get_current_device diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py index 23a92980c..eb3949f6e 100644 --- a/internlm/core/parallel/comm/isp.py +++ b/internlm/core/parallel/comm/isp.py @@ -13,7 +13,7 @@ from torch import nn from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.naive_amp import unwrap_naive_amp from internlm.core.parallel.comm.utils import ( DUMMY_HANDLE_CONST, @@ -25,7 +25,6 @@ expandKVPacked, reduce_scatter_raw, ) -from internlm.model.modules.embedding import Embedding1D from internlm.model.modules.linear import ParallelLinearWithCommExt from internlm.model.modules.utils import is_moe_param from internlm.utils.common import SchedulerHook, UniqueChainMap, get_current_device @@ -179,14 +178,20 @@ class EmbeddingWeightParallelCommunicator: """ def __init__(self, parallel_mode: ParallelMode) -> None: + from internlm.model.modules.embedding import Embedding1D + + self._embedding1d_class = Embedding1D + self.parallel_mode = parallel_mode self.gather_dim = 0 self._cur_micro_step = 0 self._num_micro_step = gpc.config.data.micro_num - def register_module_hook(self, module: Embedding1D) -> None: - assert isinstance(module, Embedding1D), "Embbeding weight parallel communicator is only support Embedding1D" + def register_module_hook(self, module: nn.Module) -> None: + assert isinstance( + module, self._embedding1d_class + ), "Embbeding weight parallel communicator is only support Embedding1D" module.weight.evo_tensor = None self.gather_dim = 0 if module.vocab_parallel else 1 diff --git a/internlm/core/parallel/comm/utils.py b/internlm/core/parallel/comm/utils.py index 5cd8cb79d..238d09c9c 100644 --- a/internlm/core/parallel/comm/utils.py +++ b/internlm/core/parallel/comm/utils.py @@ -8,7 +8,7 @@ from torch import Tensor from torch.distributed import ProcessGroup -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc class AsyncCommHandle(ABC): diff --git a/internlm/core/parallel/comm/zero.py b/internlm/core/parallel/comm/zero.py index 58929290f..31786cdad 100644 --- a/internlm/core/parallel/comm/zero.py +++ b/internlm/core/parallel/comm/zero.py @@ -9,7 +9,7 @@ from torch import nn from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.naive_amp import unwrap_naive_amp from internlm.core.parallel.comm.isp import ISPCommunicatorWrapper from internlm.model.modules.embedding import Embedding1D diff --git a/internlm/core/parallel/shard.py b/internlm/core/parallel/shard.py index 4dc6a1f5b..1c387117c 100644 --- a/internlm/core/parallel/shard.py +++ b/internlm/core/parallel/shard.py @@ -8,9 +8,10 @@ from torch import nn from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from 
internlm.core.context.parallel_context import global_context as gpc from internlm.core.parallel.comm.utils import _gather, _split from internlm.utils.logger import get_logger +from internlm.utils.parallel import is_using_hf from internlm.utils.utils import TensorParallelMode logger = get_logger(__file__) @@ -33,7 +34,7 @@ def _split_data_for_sequence_parallel(data, label): data["indexes"] = _split(data["indexes"], ParallelMode.TENSOR, dim=_indexes_seq_dim) # NOTICE: For compatibility where the shape of position_ids is [batch, seqlen, ...] - if "inject_info" in gpc.config.model and gpc.config.model.inject_info.get("data_helper", False): + if ("inject_info" in gpc.config.model and gpc.config.model.inject_info.get("data_helper", False)) or is_using_hf(): _position_ids_seq_dim = 1 data["position_ids"] = _split(data["position_ids"], ParallelMode.TENSOR, dim=_position_ids_seq_dim) diff --git a/internlm/core/scheduler/comm/p2p.py b/internlm/core/scheduler/comm/p2p.py index 54fb587c0..194adeb6e 100644 --- a/internlm/core/scheduler/comm/p2p.py +++ b/internlm/core/scheduler/comm/p2p.py @@ -12,7 +12,7 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.utils.common import get_current_device from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks diff --git a/internlm/core/scheduler/comm/utils.py b/internlm/core/scheduler/comm/utils.py index d9e6f7e85..a340bee2e 100644 --- a/internlm/core/scheduler/comm/utils.py +++ b/internlm/core/scheduler/comm/utils.py @@ -6,7 +6,7 @@ import torch.distributed as dist from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.utils.common import get_current_device TensorShape = Union[torch.Size, List[int], Tuple[int]] diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index 84b94dbfa..ca6fe7f22 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -9,7 +9,7 @@ import torch.distributed as dist from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.engine import Engine from internlm.utils.common import ( SchedulerHook, diff --git a/internlm/core/scheduler/pipeline_scheduler_1f1b.py b/internlm/core/scheduler/pipeline_scheduler_1f1b.py index 289bc37d3..585102a81 100644 --- a/internlm/core/scheduler/pipeline_scheduler_1f1b.py +++ b/internlm/core/scheduler/pipeline_scheduler_1f1b.py @@ -10,7 +10,7 @@ import torch.distributed as dist from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.engine import Engine from internlm.core.naive_amp import NaiveAMPModel from internlm.core.scheduler import comm diff --git a/internlm/core/scheduler/pipeline_scheduler_zb.py b/internlm/core/scheduler/pipeline_scheduler_zb.py index 75cf18448..a400a53e9 100644 --- a/internlm/core/scheduler/pipeline_scheduler_zb.py +++ b/internlm/core/scheduler/pipeline_scheduler_zb.py @@ -9,7 +9,7 @@ from torch.optim.optimizer import Optimizer from 
internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.engine import Engine from internlm.core.scheduler import comm from internlm.utils.common import SchedulerHook, get_current_device diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index 71c30d00d..f8c4cb6f2 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -9,7 +9,7 @@ from torch.utils.data import DataLoader from internlm.checkpoint.checkpoint_manager import CheckpointManager -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.context.process_group_initializer import ParallelMode from internlm.core.parallel.comm import initialize_offload_manager from internlm.core.trainer import Trainer @@ -24,11 +24,9 @@ get_scheduler_hooks, initialize_llm_profile, initialize_optimizer, - initialize_parallel_communicator, inject_model, load_new_batch, record_current_batch_training_metrics, - set_param_unique_tracking_name, ) from internlm.utils.common import ( BatchSkipper, @@ -101,11 +99,8 @@ def __init__( # load config_lines config_lines = self._read_config(kwargs["config"]) - # set tracking name for parameters - set_param_unique_tracking_name(model) - - # inject model for amp and parallel training - model = inject_model(model) + # inject model for amp, parallel setting, parameter syncing and others + model, isp_communicator = inject_model(model) # check cuda env check_cuda_env() @@ -116,9 +111,6 @@ def __init__( # initialize loss function criterion = self._initialize_criterion() - # initialize isp communicator - isp_communicator = initialize_parallel_communicator(model) - # initialize cpu offload manager for selective checkpoint initialize_offload_manager(gpc.config.get("selective_checkpoint_offload", False)) diff --git a/internlm/data/build_dataloader.py b/internlm/data/build_dataloader.py index e99bbfc70..cd0636622 100644 --- a/internlm/data/build_dataloader.py +++ b/internlm/data/build_dataloader.py @@ -8,7 +8,7 @@ from internlm.accelerator.abstract_accelerator import get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.data.megatron.collaters import megatron_collate_fn from internlm.data.megatron.dataset import build_megatron_dataset from internlm.data.mocked.batch_sampler import MockedSequentialBatchSampler diff --git a/internlm/data/megatron/dataset.py b/internlm/data/megatron/dataset.py index 88f4697bc..daa8caab2 100644 --- a/internlm/data/megatron/dataset.py +++ b/internlm/data/megatron/dataset.py @@ -12,7 +12,7 @@ import torch from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc dtypes = { 1: np.uint8, diff --git a/internlm/data/mocked/dataset.py b/internlm/data/mocked/dataset.py index 88020a78a..2b646eefe 100644 --- a/internlm/data/mocked/dataset.py +++ b/internlm/data/mocked/dataset.py @@ -7,7 +7,7 @@ from torch.utils.data import Dataset from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc def merge_tensors(fn_pattern: str) -> torch.Tensor: 
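A change repeated throughout this diff: global_context (and Config) are now imported from internlm.core.context.parallel_context instead of being re-exported by internlm.core.context, whose __init__ no longer exposes them. A short sketch of the convention new call sites are expected to follow (the get_local_rank call assumes an already-launched distributed environment):

# before (re-export removed from internlm/core/context/__init__.py):
# from internlm.core.context import Config, global_context as gpc

# after:
from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import Config, global_context as gpc

dp_rank = gpc.get_local_rank(ParallelMode.DATA)  # call sites themselves are unchanged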
diff --git a/internlm/data/streaming/batch_sampler.py b/internlm/data/streaming/batch_sampler.py index 11f9bb8b0..67db5d1fc 100644 --- a/internlm/data/streaming/batch_sampler.py +++ b/internlm/data/streaming/batch_sampler.py @@ -5,7 +5,7 @@ from typing import Optional from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.utils.logger import get_logger logger = get_logger(__file__) diff --git a/internlm/data/streaming/dataset.py b/internlm/data/streaming/dataset.py index 8b0755edf..ab40d6ce3 100644 --- a/internlm/data/streaming/dataset.py +++ b/internlm/data/streaming/dataset.py @@ -10,7 +10,7 @@ from torch.utils.data import Dataset from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from transformers import AutoTokenizer diff --git a/internlm/data/streaming/utils.py b/internlm/data/streaming/utils.py index adf63124e..f50f389c3 100644 --- a/internlm/data/streaming/utils.py +++ b/internlm/data/streaming/utils.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc # simple auto_resume for streaming dataloader diff --git a/internlm/data/tokenized/batch_sampler.py b/internlm/data/tokenized/batch_sampler.py index 282dad390..b7ffb88ee 100644 --- a/internlm/data/tokenized/batch_sampler.py +++ b/internlm/data/tokenized/batch_sampler.py @@ -10,7 +10,7 @@ from torch.utils.data import DataLoader, Dataset, Sampler from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.utils.logger import get_logger logger = get_logger(__file__) diff --git a/internlm/data/tokenized/packed_dataset.py b/internlm/data/tokenized/packed_dataset.py index 2225ee61c..b5e437ce5 100644 --- a/internlm/data/tokenized/packed_dataset.py +++ b/internlm/data/tokenized/packed_dataset.py @@ -15,7 +15,7 @@ from tqdm import tqdm from internlm.accelerator import get_accelerator -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.data.tokenized.single_dataset import JsonlDataset from internlm.data.utils import get_dataset_type_id, get_dataset_type_ids_map from internlm.utils.logger import get_logger diff --git a/internlm/data/tokenized/single_dataset.py b/internlm/data/tokenized/single_dataset.py index 2527dc0a9..4e5abcaa8 100644 --- a/internlm/data/tokenized/single_dataset.py +++ b/internlm/data/tokenized/single_dataset.py @@ -16,7 +16,7 @@ import torch from internlm.accelerator import get_accelerator -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.utils.logger import get_logger logger = get_logger(__file__) diff --git a/internlm/data/train_state.py b/internlm/data/train_state.py index 2c678b05b..34ff2132b 100644 --- a/internlm/data/train_state.py +++ b/internlm/data/train_state.py @@ -1,5 +1,5 @@ # Copyright (c) InternLM. All rights reserved. 
-from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.trainer import TrainState from internlm.utils.utils import DataType diff --git a/internlm/data/utils.py b/internlm/data/utils.py index 119f00f61..e4d9ca243 100644 --- a/internlm/data/utils.py +++ b/internlm/data/utils.py @@ -5,8 +5,9 @@ import torch -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.context.process_group_initializer import ParallelMode +from internlm.utils.parallel import is_using_hf def get_dataset_type_ids_map(path): @@ -64,7 +65,7 @@ def unpack_data(data, label): data["indexes"] = data["indexes"][0] # If model has inject_info and data_helper is enabled, we provide position_ids - if "inject_info" in gpc.config.model and gpc.config.model.inject_info.get("data_helper", False): + if ("inject_info" in gpc.config.model and gpc.config.model.inject_info.get("data_helper", False)) or is_using_hf(): data.pop("max_seqlen") data["position_ids"] = data.pop("indexes").unsqueeze(0) # [batch, seqlen] @@ -81,7 +82,7 @@ def packed_data_normalizer(data, label): data["max_seqlen"] = (data["cu_seqlens"][1:] - data["cu_seqlens"][:-1]).max().item() # If model has inject_info and data_helper is enabled, we provide position_ids, cu_seqlens, max_seqlen - if "inject_info" in gpc.config.model and gpc.config.model.inject_info.get("data_helper", False): + if ("inject_info" in gpc.config.model and gpc.config.model.inject_info.get("data_helper", False)) or is_using_hf(): gpc.config.data[f"cu_seqlens_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"] = data.pop("cu_seqlens") gpc.config.data[f"max_seqlen_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"] = data.pop("max_seqlen") data["position_ids"] = data.pop("indexes").unsqueeze(0) # [batch, seqlen] diff --git a/internlm/env.py b/internlm/env.py new file mode 100644 index 000000000..99bde449c --- /dev/null +++ b/internlm/env.py @@ -0,0 +1 @@ +VERSION = "0.5.3" diff --git a/internlm/eval/evaluation.py b/internlm/eval/evaluation.py index 862057a3d..18ca5b69b 100644 --- a/internlm/eval/evaluation.py +++ b/internlm/eval/evaluation.py @@ -6,7 +6,7 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.parallel.shard import split_data_for_sequence_parallel from internlm.core.scheduler.pipeline_scheduler_1f1b import get_tensor_shape from internlm.model.metrics import AccPerplex, SchedulerMetricHook diff --git a/internlm/initialize/__init__.py b/internlm/initialize/__init__.py index 14fe06bbb..87486f0bb 100644 --- a/internlm/initialize/__init__.py +++ b/internlm/initialize/__init__.py @@ -1,6 +1,5 @@ from .initialize_trainer import initialize_trainer from .launch import ( - get_default_parser, initialize_distributed_env, launch_from_slurm, launch_from_torch, @@ -8,7 +7,6 @@ ) __all__ = [ - "get_default_parser", "initialize_trainer", "launch_from_slurm", "launch_from_torch", diff --git a/internlm/initialize/initialize_trainer.py b/internlm/initialize/initialize_trainer.py index 48487c5fb..19f5a8656 100644 --- a/internlm/initialize/initialize_trainer.py +++ b/internlm/initialize/initialize_trainer.py @@ -12,7 +12,7 @@ from torch.utils.data import DataLoader from internlm.core.context import ParallelMode -from 
internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.engine import Engine from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler from internlm.core.parallel.shard import split_data_for_sequence_parallel diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index e1cb2f0d2..4247d99f0 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -2,7 +2,6 @@ # -*- encoding: utf-8 -*- # Copyright (c) InternLM. All rights reserved. -import argparse import os from pathlib import Path from typing import Dict, Union @@ -10,12 +9,14 @@ import torch from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.core.context import Config -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import Config +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.context.process_group_initializer import ParallelMode from internlm.utils.common import get_master_node from internlm.utils.gputest import warmup_process_group +from internlm.utils.lazy import LazyObject from internlm.utils.logger import get_logger +from internlm.utils.parallel import is_using_hf from internlm.utils.timeout import llm_timeout from internlm.utils.utils import DataType, ModelType, TensorParallelMode @@ -33,35 +34,26 @@ internlm_accelerator = get_accelerator() -def get_default_parser(): - """Reads user command line and uses an argument parser to parse the input arguments. - Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed. - - Returns: - Parser: Returns the parser with the default arguments, the user may add customized arguments into this parser. - """ - parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str, help="path to the config file") - parser.add_argument( - "--launcher", - type=str, - default="slurm", - choices=["slurm", "torch"], - help="launcher for launching distributed environment", - ) - parser.add_argument("--host", type=str, help="the master address for distributed training") - parser.add_argument("--port", type=int, default=8888, help="the master port for distributed training") - parser.add_argument("--world_size", type=int, help="world size for distributed training") - parser.add_argument("--rank", type=int, help="rank for the default process group") - parser.add_argument("--local_rank", type=int, help="local rank on the node") - parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication") - parser.add_argument("--seed", type=int, default=1024) - parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.") - parser.add_argument("--enable_ali_topology", default=False, action="store_true", help="enable ali switch topology.") - parser.add_argument( - "--disable_volc_topology", default=False, action="store_true", help="disable volc switch topology." 
- ) - return parser +def inject_hf_config_before_launch(hf: dict): + # get HuggingFace model config + cfg = LazyObject(hf.cfg, hf.cfg_cls) + cfg = cfg.build() + model_config = cfg(**hf.cfg_extra_kwargs) + # inject HuggingFace model config into InternTrain as much as we know + if hasattr(model_config, "vocab_size"): + gpc.config.model.vocab_size = gpc.config.VOCAB_SIZE = model_config.vocab_size + if hasattr(model_config, "num_hidden_layers"): + gpc.config.model.num_layers = gpc.config.NUM_LAYER = model_config.num_hidden_layers + if hasattr(model_config, "num_attention_heads"): + gpc.config.model.num_attention_heads = gpc.config.NUM_ATTENTION_HEAD = model_config.num_attention_heads + if hasattr(model_config, "num_key_value_heads"): + gpc.config.model.num_kv_attention_heads = gpc.config.NUM_KV_ATTENTION_HEAD = model_config.num_key_value_heads + if hasattr(model_config, "hidden_size"): + gpc.config.model.hidden_size = gpc.config.HIDDEN_SIZE = model_config.hidden_size + if hasattr(model_config, "intermediate_size"): + gpc.config.model.mlp_ratio = gpc.config.MLP_RATIO = model_config.intermediate_size / model_config.hidden_size + if hasattr(model_config, "num_experts"): + gpc.config.model.num_experts = model_config.num_experts def args_sanity_check(): @@ -76,6 +68,11 @@ def args_sanity_check(): if "model_type" not in gpc.config: gpc.config._add_item("model_type", ModelType.INTERNLM.name) + # inject HuggingFace model config into IntrainTrain + if is_using_hf(): + inject_hf_config_before_launch(gpc.config.hf) + gpc.config.model_type = "hf" + if gpc.config.model_type == "InternLM3_M": # TODO: need check for isp overlap num_layers = gpc.config.model.num_self_decoder_layers + gpc.config.model.num_cross_decoder_layers @@ -88,11 +85,11 @@ def args_sanity_check(): # procssing the parallel config in gpc if "zero1" not in gpc.config.parallel: - gpc.config.parallel._add_item("zero1", dict(size=-1, fsdp=False)) + gpc.config.parallel._add_item("zero1", dict(size=-1)) if isinstance(gpc.config.parallel.zero1, int): zero1_size = gpc.config.parallel.zero1 - gpc.config.parallel._add_item("zero1", dict(size=zero1_size, fsdp=False)) + gpc.config.parallel._add_item("zero1", dict(size=zero1_size)) if "pipeline" not in gpc.config.parallel: gpc.config.parallel._add_item("pipeline", dict(size=1, interleaved_overlap=False, mode="1F1B")) @@ -131,19 +128,6 @@ def args_sanity_check(): if gpc.config.parallel.pipeline["mode"] == "ZBV": gpc.v_shape = True - # check fsdp config - if "fsdp" not in gpc.config.parallel.zero1: - gpc.config.parallel.zero1._add_item("fsdp", False) - - assert not ( - gpc.config.parallel.zero1.fsdp and pp > 1 - ), "FSDP is not supportted when pipeline size > 1, please set pipeline size to 1 or disabled FSDP" - - if gpc.config.parallel.zero1.fsdp: - assert ( - torch.__version__ >= "2.0.1" - ), f"requires torch>=2.0.1 when using fsdp but current version is {torch.__version__}" - # processing the data config in gpc data = gpc.config.data @@ -401,11 +385,6 @@ def args_sanity_check(): gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name) if gpc.config.parallel["tensor"].get("mode", None) is None: gpc.config.parallel["tensor"]["mode"] = TensorParallelMode.mtp.name - if gpc.config.parallel["tensor"]["mode"] == TensorParallelMode.isp.name: - assert not gpc.config.parallel.zero1.fsdp, "FSDP does not support isp" - assert ( - torch.__version__ >= "2.1.0" - ), f"requires torch>=2.1.0 when using isp but current version is {torch.__version__}" assert ( 
gpc.config.model.vocab_size % gpc.config.parallel.weight.size == 0 @@ -563,7 +542,6 @@ def args_sanity_check(): # moe not support overlap and zero1.5 for now if gpc.config.model.get("num_experts", 1) > 1: - assert not gpc.config.parallel.zero1.fsdp, "FSDP does not support num_experts > 1" assert ( not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param ), "not support overlap and moe at the same time" diff --git a/internlm/model/builder.py b/internlm/model/builder.py index d6c3b20f1..94b4f8e0c 100644 --- a/internlm/model/builder.py +++ b/internlm/model/builder.py @@ -1,19 +1,34 @@ from typing import List, Union +import torch from torch import nn from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.parallel.shard import pipeline_parallel_sharding_wrapper from internlm.model.base_model import BaseModel +from internlm.model.modules.linear import ( + ParallelLinearWithCommExt, + ScaleColumnParallelLinear, +) from internlm.model.registry import model_initializer from internlm.utils.common import get_current_device +from internlm.utils.lazy import LazyObject from internlm.utils.logger import get_logger +from internlm.utils.parallel import is_using_fsdp, is_using_hf, is_using_isp logger = get_logger(__file__) -def create_model(model_type) -> Union[nn.Module, List[nn.Module]]: +def create_model() -> Union[nn.Module, List[nn.Module]]: + if is_using_hf(): + model = create_model_hf(hf=gpc.config.hf) + else: + model = create_model_builtin(model_type=gpc.config.model_type) + return model + + +def create_model_builtin(model_type) -> Union[nn.Module, List[nn.Module]]: kwargs = dict(gpc.config.model) @@ -44,3 +59,71 @@ def create_model(model_type) -> Union[nn.Module, List[nn.Module]]: logger.warning(f"To load/save huggingface ckpt, built-in model should inherited from {BaseModel.__name__}") return model + + +def create_model_hf(hf: dict) -> nn.Module: + cfg = LazyObject(hf.cfg, hf.cfg_cls) + cfg = cfg.build() + mod = LazyObject(hf.mod, hf.mod_cls) + mod = mod.build() + + assert is_using_fsdp(), "Curently HF models can only train with FSDP." + + fsdp_init_method = gpc.config.parallel.fsdp.get("init_method", "cuda") + if fsdp_init_method == "meta": + with torch.device("meta"): + model = mod(cfg(**hf.cfg_extra_kwargs)) + elif fsdp_init_method == "cuda": + # TODO: does HuggingFace models support directly initialized on cuda? + model = mod(cfg(**hf.cfg_extra_kwargs)).to(get_current_device()) + elif fsdp_init_method == "cpu": + model = mod(cfg(**hf.cfg_extra_kwargs)) + else: + raise ValueError(f"Unsupported fsdp init_method: {fsdp_init_method}") + + def traverse(module): + for name, child in module.named_children(): + if ( + isinstance(child, nn.Linear) + and not isinstance(child, ParallelLinearWithCommExt) + and child.weight.shape == (gpc.config.VOCAB_SIZE, gpc.config.HIDDEN_SIZE) + ): + child_new = ScaleColumnParallelLinear( + in_features=child.in_features, + out_features=child.out_features, + bias=child.bias is not None, + device=child.weight.device, + dtype=child.weight.dtype, + ) + setattr(module, name, child_new) + else: + traverse(child) + + # Do hack: lm_head or output layer should be replaced with ScaleColumnParallelLinear, + # to get ISP fwd gather / bwd split work normally. 
+ if is_using_isp(): + # traverse model might be slower than replacement module by name directly + if getattr(model, "lm_head", None) is not None: + lm_head = model.lm_head + lm_head_new = ScaleColumnParallelLinear( + in_features=lm_head.in_features, + out_features=lm_head.out_features, + bias=lm_head.bias is not None, + device=lm_head.weight.device, + dtype=lm_head.weight.dtype, + ) + setattr(model, "lm_head", lm_head_new) + elif getattr(model, "output", None) is not None: + output = model.output + output_new = ScaleColumnParallelLinear( + in_features=output.in_features, + out_features=output.out_features, + bias=output.bias is not None, + device=output.weight.device, + dtype=output.weight.dtype, + ) + setattr(model, "output", output_new) + else: + traverse(model) + + return model diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py index a7f6c9668..f01eab831 100644 --- a/internlm/model/metrics.py +++ b/internlm/model/metrics.py @@ -3,7 +3,7 @@ import torch from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.model.ops.cross_entropy import new_cross_entropy from internlm.utils.common import SchedulerHook, get_current_device from internlm.utils.logger import get_logger diff --git a/internlm/model/modules/embedding.py b/internlm/model/modules/embedding.py index 93fcd6b23..96599d4b4 100644 --- a/internlm/model/modules/embedding.py +++ b/internlm/model/modules/embedding.py @@ -9,7 +9,7 @@ from torch import Tensor, nn from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.model.ops.rotary_emb import apply_rotary_emb from internlm.utils.parallel import is_using_isp diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 6f1902689..8efa75df4 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -13,7 +13,7 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.parallel.shard import ( get_head_parallel_mode, get_parallel_strategies_split_mode, diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py index 42418a212..7843eddc1 100644 --- a/internlm/model/modules/mha.py +++ b/internlm/model/modules/mha.py @@ -10,7 +10,7 @@ from torch import nn from torch.nn import functional as F -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.model.modules.embedding import new_rotary_embedding from internlm.model.modules.linear import new_linear from internlm.model.modules.utils import update_kv_cache diff --git a/internlm/model/moe/base_layer.py b/internlm/model/moe/base_layer.py index fa02f1457..cfa41b412 100644 --- a/internlm/model/moe/base_layer.py +++ b/internlm/model/moe/base_layer.py @@ -4,7 +4,7 @@ from torch import Tensor from torch.nn import Module, ModuleList -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.model.moe.experts import Experts from internlm.utils.common import get_current_device diff --git 
a/internlm/model/moe/dropless_layer.py b/internlm/model/moe/dropless_layer.py index f5881dfbe..ca211be2f 100644 --- a/internlm/model/moe/dropless_layer.py +++ b/internlm/model/moe/dropless_layer.py @@ -14,7 +14,7 @@ from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.model.modules.mlp import new_feed_forward from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger diff --git a/internlm/model/moe/gshard_layer.py b/internlm/model/moe/gshard_layer.py index 3aba8d1a3..416929dfa 100644 --- a/internlm/model/moe/gshard_layer.py +++ b/internlm/model/moe/gshard_layer.py @@ -14,7 +14,7 @@ from torch.nn import Module from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.model.modules.mlp import new_feed_forward from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer diff --git a/internlm/model/moe/megablocks/megablock_dmoe.py b/internlm/model/moe/megablocks/megablock_dmoe.py index 46e1a81cd..47da5260a 100644 --- a/internlm/model/moe/megablocks/megablock_dmoe.py +++ b/internlm/model/moe/megablocks/megablock_dmoe.py @@ -4,7 +4,7 @@ import torch from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.model.moe.base_layer import BaseMoELayer from internlm.model.moe.megablocks.megablock_moe import MegaBlockMoE from internlm.model.moe.megablocks.mlp import MegaBlockGroupedFeedForward diff --git a/internlm/model/moe/megablocks/megablock_moe.py b/internlm/model/moe/megablocks/megablock_moe.py index 82fa3062c..996d7e0ac 100644 --- a/internlm/model/moe/megablocks/megablock_moe.py +++ b/internlm/model/moe/megablocks/megablock_moe.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.model.moe.base_layer import BaseMoELayer from internlm.model.moe.megablocks.mlp import MegaBlockFeedForward from internlm.model.moe.utils import all_to_all diff --git a/internlm/model/moe/megablocks/mlp.py b/internlm/model/moe/megablocks/mlp.py index 374793d6c..27d9b6f4a 100644 --- a/internlm/model/moe/megablocks/mlp.py +++ b/internlm/model/moe/megablocks/mlp.py @@ -2,7 +2,7 @@ from torch import nn from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.model.modules.utils import Silu from internlm.model.moe.megablocks.utils import ( act_fn, diff --git a/internlm/model/moe/moe.py b/internlm/model/moe/moe.py index 0bd35e5b2..ad8c471a3 100644 --- a/internlm/model/moe/moe.py +++ b/internlm/model/moe/moe.py @@ -2,7 +2,7 @@ import torch.nn.functional as F from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.naive_amp import set_fp32_attr_to_module from internlm.model.modules.mlp import new_feed_forward from 
internlm.model.moe.dropless_layer import DroplessMoELayer diff --git a/internlm/model/moe/utils.py b/internlm/model/moe/utils.py index b47459094..0a882b5d7 100644 --- a/internlm/model/moe/utils.py +++ b/internlm/model/moe/utils.py @@ -4,7 +4,7 @@ from torch import Tensor from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.utils.common import get_current_device diff --git a/internlm/model/ops/_flash_attn.py b/internlm/model/ops/_flash_attn.py index 87aac2eb8..fb152aa91 100644 --- a/internlm/model/ops/_flash_attn.py +++ b/internlm/model/ops/_flash_attn.py @@ -2,7 +2,7 @@ import torch from internlm.accelerator import get_accelerator -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.parallel.comm import get_offload_manager try: diff --git a/internlm/model/ops/attention.py b/internlm/model/ops/attention.py index 3aec51f55..41fa09f45 100644 --- a/internlm/model/ops/attention.py +++ b/internlm/model/ops/attention.py @@ -16,7 +16,7 @@ from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.parallel.comm.isp import ( auto_wrap_distributed_attention, auto_wrap_func_distributed_attention, diff --git a/internlm/model/ops/cross_entropy.py b/internlm/model/ops/cross_entropy.py index 99bf1e047..6120996e7 100644 --- a/internlm/model/ops/cross_entropy.py +++ b/internlm/model/ops/cross_entropy.py @@ -13,7 +13,7 @@ from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.model.ops.cross_entropy_ops import ( CrossEntropyApexVocabParallel, CrossEntropyLossApex, diff --git a/internlm/model/ops/cross_entropy_ops/sequence_parallel_loss.py b/internlm/model/ops/cross_entropy_ops/sequence_parallel_loss.py index 2072944f8..49200198f 100644 --- a/internlm/model/ops/cross_entropy_ops/sequence_parallel_loss.py +++ b/internlm/model/ops/cross_entropy_ops/sequence_parallel_loss.py @@ -2,7 +2,7 @@ from torch import nn from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc # Adapted from https://github.com/microsoft/Megatron-DeepSpeed/blob/main/megatron/core/ \ diff --git a/internlm/model/ops/fused_rmsnorm.py b/internlm/model/ops/fused_rmsnorm.py new file mode 100644 index 000000000..bbf9e5d97 --- /dev/null +++ b/internlm/model/ops/fused_rmsnorm.py @@ -0,0 +1,247 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
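The new internlm/model/ops/fused_rmsnorm.py file added below ports a Triton fused RMSNorm (forward and backward kernels wrapped in a TritonFusedRMSNorm autograd function) and exposes it as fused_rms_norm_fn(x, weight, eps=1e-6). For orientation, the forward kernel computes the standard RMSNorm row by row; a plain-PyTorch reference, runnable without a GPU or Triton, is sketched here (the fused path itself needs a CUDA device with Triton installed).

# Plain-PyTorch reference for what the Triton forward kernel below computes:
# y = x * rsqrt(mean(x**2, dim=-1) + eps) * weight.
import torch


def reference_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps) * weight.float()).to(x.dtype)


x = torch.randn(4, 16)
weight = torch.ones(16)
print(reference_rms_norm(x, weight).shape)  # torch.Size([4, 16])
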
+ +# flake8: noqa=E731 +# pylint: disable=C3001 + +import math +from functools import partial + +import torch +import triton +import triton.language as tl +from torch.distributed._tensor import Partial, Replicate, Shard +from torch.distributed._tensor.experimental import local_map + +# FusedRMSNorm in Triton + +# Credit +# Tri Dao's Triton LayerNorm: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py +# Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N"], +) +@triton.jit +def _rms_norm_fwd_kernel( + X, + stride_x, + Y, + stride_y, + W, + Rstd, + eps, + M, # num rows # pylint: disable=W0613 + N, # num cols + block_N: tl.constexpr, +): + row = tl.program_id(0) + cols = tl.arange(0, block_N) + + # Load input data and weights + mask = cols < N + x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) + w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32) + + # Compute mean and variance + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + + # Store the reciprocal standard deviation + tl.store(Rstd + row, rstd) + + # Normalize and apply linear transformation + x_hat = x * rstd + y = x_hat * w + + # Write output + tl.store(Y + row * stride_y + cols, y, mask=mask) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N"], +) +@triton.jit +def _rms_norm_bwd_kernel_sm( + X, + stride_x, + W, + DY, + stride_dy, + DX, + stride_dx, + Rstd, + DW, + eps, # pylint: disable=W0613 + M, # num rows + N, # num cols + rows_per_program, + block_N: tl.constexpr, +): + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, block_N) + mask = cols < N + + # Load weights + w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32) + + # Accumulate gradients for weights + dw = tl.zeros((block_N,), dtype=tl.float32) + + row_end = min(row_start + rows_per_program, M) + for row in range(row_start, row_end): + # Load input, output gradient, and reciprocal standard deviation + x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32) + rstd = tl.load(Rstd + row) + + # Compute normalized input and gradients + x_hat = x * rstd + wdy = w * dy + dw += dy * x_hat + c1 = tl.sum(x_hat * wdy, axis=0) / N + dx = (wdy - x_hat * c1) * rstd + + # Store input gradient + tl.store(DX + row * stride_dx + cols, dx, mask=mask) + + # Store weight gradients + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + + +class TritonFusedRMSNorm(torch.autograd.Function): + """ + Triton based Fused RMSNorm + """ + + @partial( + local_map, + out_placements=[Shard(1)], + in_placements=(None, [Shard(1)], [Replicate()], None), + ) + @staticmethod + def forward(ctx, x, weight, eps): + x_shape_start = x.shape + + # Flatten input + x = x.view(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if weight.stride(-1) != 1: + weight = weight.contiguous() + + M, N = x.shape + y = 
torch.empty_like(x) + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + + max_size = 65536 // x.element_size() + block_N = min(max_size, triton.next_power_of_2(N)) + + if N > block_N: + raise ValueError(f"N {N} must be <= {block_N=}") + + grid = lambda meta: (M,) + _rms_norm_fwd_kernel[grid]( + x, + x.stride(0), + y, + y.stride(0), + weight, + rstd, + eps, + M, + N, + block_N, + ) + + ctx.eps = eps + ctx.save_for_backward(x, weight, rstd) + ctx.x_shape_start = x_shape_start + + y = y.reshape(x_shape_start) + return y + + @partial( + local_map, + out_placements=([Shard(1)], [Partial()], None), + in_placements=(None, [Shard(1)]), + ) + @staticmethod + def backward(ctx, dy): + x, weight, rstd = ctx.saved_tensors + eps = ctx.eps + x_shape_start = ctx.x_shape_start + + # Flatten input and output gradients + dy = dy.view(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + + M, N = dy.shape + dx = torch.empty_like(x) + + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + + max_size = 65536 // x.element_size() + block_N = min(max_size, triton.next_power_of_2(N)) + rows_per_sm = math.ceil(M / sm_count) + + if N > block_N: + raise ValueError(f"N {N} must be <= {block_N=}") + + grid = lambda meta: (sm_count,) + _rms_norm_bwd_kernel_sm[grid]( + x, + x.stride(0), + weight, + dy, + dy.stride(0), + dx, + dx.stride(0), + rstd, + _dw, + eps, + M, + N, + rows_per_sm, + block_N, + ) + dw = _dw.sum(0).to(weight.dtype) + dx = dx.view(x_shape_start) + return dx, dw, None + + +# expose fusedRMSNorm as a function +def fused_rms_norm_fn( + x, + weight, + eps=1e-6, +): + return TritonFusedRMSNorm.apply( + x, + weight, + eps, + ) diff --git a/internlm/model/ops/linear.py b/internlm/model/ops/linear.py index fa4c93c6f..732a6fe19 100644 --- a/internlm/model/ops/linear.py +++ b/internlm/model/ops/linear.py @@ -12,7 +12,7 @@ from torch.nn.functional import linear as _torch_linear_forward_op from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc try: from fused_dense_lib import linear_bias_wgrad as _flash_linear_backward_op diff --git a/internlm/model/ops/rotary_emb.py b/internlm/model/ops/rotary_emb.py index 1f058c9d9..38122ff15 100644 --- a/internlm/model/ops/rotary_emb.py +++ b/internlm/model/ops/rotary_emb.py @@ -13,7 +13,7 @@ from torch import Tensor from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc try: from rotary_emb import apply_rotary as _flash_apply_rotary_func diff --git a/internlm/monitor/monitor.py b/internlm/monitor/monitor.py index fc33de62a..f9d94651b 100644 --- a/internlm/monitor/monitor.py +++ b/internlm/monitor/monitor.py @@ -11,7 +11,7 @@ from threading import Thread from internlm.accelerator.abstract_accelerator import get_accelerator -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.monitor.alert import send_feishu_msg_with_webhook from internlm.utils.common import SingletonMeta diff --git a/internlm/solver/activation_checkpoint.py b/internlm/solver/activation_checkpoint.py index 5aedd9a3d..2b5c9e4ed 100644 --- a/internlm/solver/activation_checkpoint.py +++ 
b/internlm/solver/activation_checkpoint.py @@ -4,6 +4,9 @@ import weakref import torch +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper as ptd_checkpoint_wrapper, +) from torch.utils.checkpoint import check_backward_validity, detach_variable from internlm.accelerator import get_accelerator @@ -273,3 +276,13 @@ def inner_unpack(packed): arg = arg.to(device="cpu") return output + + +def apply_ac_to_transformer_block(module: torch.nn.Module, checkpoint): + ac_freq = round(1 / checkpoint) + ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0) + ptd_checkpoint_wrapper._count += 1 + if ptd_checkpoint_wrapper._count % ac_freq == 0: + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + else: + return module diff --git a/internlm/solver/optimizer/compatible_adamw.py b/internlm/solver/optimizer/compatible_adamw.py index 055e04914..9b010e1f0 100644 --- a/internlm/solver/optimizer/compatible_adamw.py +++ b/internlm/solver/optimizer/compatible_adamw.py @@ -3,7 +3,7 @@ import torch from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.utils.logger import get_logger logger = get_logger(__file__) diff --git a/internlm/solver/optimizer/fsdp_optimizer.py b/internlm/solver/optimizer/fsdp_optimizer.py index 5676608fa..4bdb7aef0 100644 --- a/internlm/solver/optimizer/fsdp_optimizer.py +++ b/internlm/solver/optimizer/fsdp_optimizer.py @@ -1,26 +1,84 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import math +from typing import Iterable +import torch +import torch.distributed as dist from torch.optim import Optimizer -from internlm.core.context import Config, ParallelMode -from internlm.core.context import global_context as gpc +from internlm.accelerator import get_accelerator +from internlm.core.context.parallel_context import Config +from internlm.core.context.parallel_context import ParallelMode +from internlm.core.context.parallel_context import global_context as gpc +from internlm.solver.optimizer.base_optimizer import BaseOptimizer from internlm.solver.optimizer.utils import ( DynamicGradScaler, - reduce_tensor, + get_norm, release_param_grad, ) +from internlm.utils.common import get_tensor_norm, move_norm_to_cuda from internlm.utils.logger import get_logger -from .base_optimizer import BaseOptimizer -from .utils import compute_norm +try: + from torch.distributed.tensor import DTensor + + DTENSOR_SUPPORTED = True +except (ModuleNotFoundError, ImportError): + DTENSOR_SUPPORTED = False logger = get_logger(__file__) +inf = math.inf + +internlm_accelerator = get_accelerator() + + +def compute_norm( + gradients: Iterable[torch.Tensor], + parameters: Iterable[torch.Tensor], +) -> float: + """Get L2 norm + Arguments: + gradients (Iterable[Tensor]): The gradient value. + parameters (Iterable[Tensor]): The parameter each gradient corresponds to. + + Returns: + Total norm of the parameters, need total_norm**(1/norm) before using. + """ + + enable_cuda_kernels = gradients[0].device.type != "cpu" + + # Calculate norm. + tensor_parallel_grads = [g.data.float() for g, _ in zip(gradients, parameters)] + tensor_parallel_norm = get_norm(tensor_parallel_grads, float(2), enable_cuda_kernels) + # If norm is type of float, then we convert them into torch.Tensor. + total_norm = get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels) + # If grads are on CPU, the norms is also on CPU. 
Cast them to CUDA tensors + if not enable_cuda_kernels: + total_norm = move_norm_to_cuda(total_norm) + + if DTENSOR_SUPPORTED and isinstance(total_norm, DTensor): + total_norm = total_norm.full_tensor() + + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.GLOBAL)) + + if torch.is_tensor(total_norm): + total_norm = total_norm.item() + + # Scale. + if total_norm == float("inf") or total_norm == -float("inf"): + total_norm = -1 + + if math.isnan(total_norm): + total_norm = -2 + + return total_norm + class FSDPadaptOptimizer(BaseOptimizer): """ - optimizer for Pytorch FSDP if 'parallel.zero1.fsdp' is True in config file + optimizer for Pytorch FSDP if 'parallel.fsdp' is not None in config file reserve some necessary components of hybird-optim: grad_scaler; grad_clip and unscale; @@ -44,6 +102,7 @@ def __init__( growth_interval=grad_scal_cfg.fp16.growth_interval, hysteresis=grad_scal_cfg.hysteresis, max_scale=grad_scal_cfg.max_scale, + dtype=gpc.config.model.dtype, ) # clip gradient @@ -93,16 +152,6 @@ def zero_grad(self): param.grad = None def step(self): - # in case that fsdp-zero3 size is not equal to dp size - # FSDP module will only reduce gradient within FSDP process group - # so manually reduce grad is essential between two parallel FSDP process group - for group_idx in range(len(self.param_groups)): - params = self._fp16_param_groups[group_idx] - for param in params: - if param.requires_grad and param.grad is not None: - handle = reduce_tensor(tensor=param.grad, parallel_mode=ParallelMode.ZERO3_DP) - handle.wait() - # compute norm found_inf = False norm_groups = {} @@ -207,13 +256,6 @@ def load_state_dict(self, states): self.grad_scaler.load_state_dict(grad_scaler) optim_states = states["base_optim_states"] - if gpc.config.get("only_load_lr", False): - if gpc.is_rank_for_log(): - logger.info("Only load lr in param_groups, skip loading weights in optimizer...") - for pg1, pg2 in zip(self.optim.param_groups, optim_states["param_groups"]): - pg1["lr"] = pg2["lr"] - return - self.optim.load_state_dict(optim_states) # load fp32 optimizer weight diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 8d3ce3add..ba8d947a3 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -11,8 +11,7 @@ from torch.optim import Optimizer from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.core.context import Config, ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import Config from internlm.core.context.parallel_context import ( IS_REPLICA_EXPERT_DATA_PARALLEL, IS_REPLICA_ZERO_PARALLEL, @@ -20,7 +19,9 @@ IS_TENSOR_ZERO_PARALLEL, IS_WEIGHT_EXPERT_DATA_PARALLEL, IS_WEIGHT_ZERO_PARALLEL, + ParallelMode, ) +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.parallel.comm.isp import ISPCommunicatorWrapper from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler from internlm.model.modules.utils import is_gate_param, is_moe_param diff --git a/internlm/solver/optimizer/hybrid_zero_optim_v2.py b/internlm/solver/optimizer/hybrid_zero_optim_v2.py index 36e5f073f..42c3ab177 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim_v2.py +++ b/internlm/solver/optimizer/hybrid_zero_optim_v2.py @@ -7,14 +7,15 @@ import torch.distributed as dist from torch.optim import Optimizer -from internlm.core.context import Config, 
ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import Config from internlm.core.context.parallel_context import ( IS_REPLICA_ZERO_PARALLEL, IS_TENSOR_EXPERT_DATA_PARALLEL, IS_TENSOR_ZERO_PARALLEL, IS_WEIGHT_ZERO_PARALLEL, + ParallelMode, ) +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler from internlm.monitor import send_alert_message from internlm.solver.optimizer.store import ( diff --git a/internlm/solver/optimizer/store.py b/internlm/solver/optimizer/store.py index c785e41a9..fd52e0468 100644 --- a/internlm/solver/optimizer/store.py +++ b/internlm/solver/optimizer/store.py @@ -8,7 +8,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc class BaseStore: diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py index a0180a596..a9946ad9e 100644 --- a/internlm/solver/optimizer/utils.py +++ b/internlm/solver/optimizer/utils.py @@ -11,7 +11,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.utils.common import get_current_device, get_tensor_norm, move_norm_to_cuda from internlm.utils.logger import get_logger from internlm.utils.parallel import ( diff --git a/train.py b/internlm/train.py old mode 100755 new mode 100644 similarity index 91% rename from train.py rename to internlm/train.py index 437774b1d..8d46e2c31 --- a/train.py +++ b/internlm/train.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.trainer_builder import TrainerBuilder from internlm.data import ( build_train_loader_with_data_type, @@ -16,7 +16,7 @@ @internevo_monitor(feishu_alert=True, clean_run=True) def main(args): # initialize model - model = create_model(model_type=gpc.config.model_type) + model = create_model() # initialize train dataloader train_dl, dataset_types = build_train_loader_with_data_type() diff --git a/internlm/train/__init__.py b/internlm/train/__init__.py index 2ad60df09..f3c680da4 100644 --- a/internlm/train/__init__.py +++ b/internlm/train/__init__.py @@ -1,7 +1,7 @@ from .pipeline import ( get_scheduler_hooks, initialize_llm_profile, - initialize_model, + initialize_model_and_parallel_communicator, initialize_optimizer, initialize_parallel_communicator, load_new_batch, @@ -12,7 +12,7 @@ __all__ = [ "initialize_llm_profile", - "initialize_model", + "initialize_model_and_parallel_communicator", "initialize_parallel_communicator", "initialize_optimizer", "load_new_batch", diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 784a5305a..31a726a6f 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -1,15 +1,25 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import collections +import functools +import itertools import math import time -from typing import Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, 
TypeVar, Union import torch from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + BackwardPrefetch, + ShardingStrategy, +) +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.utils.data import DataLoader from internlm.accelerator import AcceleratorType, get_accelerator +from internlm.checkpoint.utils import init_fsdp_v1 from internlm.core.context import ( IS_REPLICA_EXPERT_DATA_PARALLEL, IS_REPLICA_ZERO_PARALLEL, @@ -19,7 +29,7 @@ IS_WEIGHT_ZERO_PARALLEL, ParallelMode, ) -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.context.random import set_mode from internlm.core.naive_amp import ( NaiveAMPModel, @@ -78,6 +88,7 @@ from internlm.solver.schedulers.lr_scheduler import FineTuneCosineAnnealingWarmupLR from internlm.train.utils import create_param_groups, map_param_block, timeout_input from internlm.utils.common import DummyProfile, SchedulerHook, get_current_device +from internlm.utils.lazy import LazyObject from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer from internlm.utils.parallel import ( @@ -85,6 +96,8 @@ is_replica_zero_parallel_parameter, is_tensor_expert_data_parallel_parameter, is_tensor_zero_parallel_parameter, + is_using_fsdp, + is_using_hf, is_using_isp, is_weight_expert_data_parallel_parameter, is_weight_zero_parallel_parameter, @@ -99,6 +112,25 @@ except (ImportError, ModuleNotFoundError): pass +try: + from torch.distributed._composable.fsdp import fully_shard + + FSDP2_SUPPORTED = True +except (ImportError, ModuleNotFoundError): + FSDP2_SUPPORTED = False + + +try: + from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + set_model_state_dict, + ) + + DCP_SUPPORTED = True +except (ImportError, ModuleNotFoundError): + DCP_SUPPORTED = False + + IS_INJECTED = "is_injected" LINEAR2NEWLINEAR_NAME_MAPPING = dict( @@ -220,44 +252,49 @@ def _check_module(name, module): setattr(param, IS_REPLICA_ZERO_PARALLEL, True) for _chunk in unwrap_naive_amp(model): - # special case for pure dp mode - if ( - isinstance(gpc.config.parallel["tensor"], dict) - and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.mtp.name - and gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) - ): - _check_module_func = _check_module_pure_dp - else: - _check_module_func = _check_module - # set param parallel attribute - for name, module in _chunk.named_modules(): - _check_module_func(name, module) - - for name, param in _chunk.named_parameters(): - assert ( - is_replica_zero_parallel_parameter(param) - or is_tensor_zero_parallel_parameter(param) - or is_weight_zero_parallel_parameter(param) - or is_tensor_expert_data_parallel_parameter(param) - or is_weight_expert_data_parallel_parameter(param) - or is_replica_expert_data_parallel_parameter(param) - ), f"parameter with name: {name} has no parallel attribution." 
- - -@llm_timeout(func_name="initialize_model") -def initialize_model(pre_process_func: Optional[Callable] = None, post_process_func: Optional[Callable] = None): + if not is_using_fsdp(): + # special case for pure dp mode + if ( + isinstance(gpc.config.parallel["tensor"], dict) + and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) + == TensorParallelMode.mtp.name + and gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) + ): + _check_module_func = _check_module_pure_dp + else: + _check_module_func = _check_module + # set param parallel attribute + for name, module in _chunk.named_modules(): + _check_module_func(name, module) + + for name, param in _chunk.named_parameters(): + assert ( + is_replica_zero_parallel_parameter(param) + or is_tensor_zero_parallel_parameter(param) + or is_weight_zero_parallel_parameter(param) + or is_tensor_expert_data_parallel_parameter(param) + or is_weight_expert_data_parallel_parameter(param) + or is_replica_expert_data_parallel_parameter(param) + ), f"parameter with name: {name} has no parallel attribution." + + +@llm_timeout(func_name="initialize_model_and_parallel_communicator") +def initialize_model_and_parallel_communicator( + pre_process_func: Optional[Callable] = None, post_process_func: Optional[Callable] = None +): """ Initialize model with Automatic Mixed Precision. Returns: torch.nn.Module: The neural network model to be trained or evaluated. + An isp communicator for managing comp/comm overlap. """ if pre_process_func: pre_process_output = pre_process_func() register_model_initializer() - model = create_model(model_type=gpc.config.model_type) + model = create_model() if post_process_func: post_process_func(pre_process_output) @@ -276,11 +313,18 @@ def inject_model(model): Returns: torch.nn.Module: The injected neural network model to be trained or evaluated. + An isp communicator for managing comp/comm overlap. """ if hasattr(model, IS_INJECTED) and getattr(model, IS_INJECTED): return model - inject_model_helper(model, inject_info=gpc.config.model.get("inject_info", None)) + # For non-HF cases, set tracking name for parameters + if not is_using_hf(): + set_param_unique_tracking_name(model) + + # For non-fsdp cases, set model inject helper + if not is_using_fsdp(): + inject_model_helper(model, inject_info=gpc.config.model.get("inject_info", None)) # should be set before NaiveAMPModel set_fp32_attr_for_model(model) @@ -310,7 +354,8 @@ def inject_model(model): # This sync is very important, cause the model weights kept in optimizer are copied # from the origin parameters in the memory, so we should make sure the dp sync # does not influence the model weights in optimizer be different with the origin parameters. - sync_model_param(model) + if not is_using_fsdp() or gpc.config.parallel.fsdp.get("init_method", "cuda") == "cuda": + sync_model_param(model) # This function is needed to make sure parameters that are not splitted by tensor parallelism are # the same across tensor parallelism. 
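The pipeline changes in this file (initialize_model_and_parallel_communicator above, wrap_FSDP_model and the meta-device materialization helpers below), together with create_model_hf earlier in the patch, read several new configuration keys: parallel.fsdp with mode ("v1" or "v2") and init_method ("cuda", "meta", or "cpu"), an optional fsdp_wrap_cls list resolved through LazyObject, and an hf block naming the HuggingFace config and model classes. A hedged sketch of what such a config section might look like follows; the transformers class names and the decoder-layer wrap class are illustrative assumptions, not values taken from this patch.

# Illustrative config fragment for the new HF + FSDP path; the class names below are
# assumptions for the sake of example (any transformers causal-LM pair would do).
hf = dict(
    cfg="transformers.models.llama.configuration_llama",  # module holding the HF config class
    cfg_cls="LlamaConfig",                                 # built via LazyObject(hf.cfg, hf.cfg_cls)
    mod="transformers.models.llama.modeling_llama",        # module holding the HF model class
    mod_cls="LlamaForCausalLM",                            # built via LazyObject(hf.mod, hf.mod_cls)
    cfg_extra_kwargs=dict(),                               # forwarded to the config constructor
)

# Each entry is resolved with LazyObject(entry["mod"], entry["mod_cls"]) and handed
# to the FSDP transformer auto-wrap policy in wrap_FSDP_model.
fsdp_wrap_cls = [
    dict(mod="transformers.models.llama.modeling_llama", mod_cls="LlamaDecoderLayer"),
]

parallel = dict(
    # other parallel fields (zero1/tensor/pipeline/weight) unchanged; the new key
    # read by wrap_FSDP_model and FSDPadaptOptimizer is parallel.fsdp:
    fsdp=dict(mode="v1", init_method="cuda"),
)

With a config along these lines, inject_model returns a (model, isp_communicator) pair and wraps the model with FSDP before training starts, matching the trainer_builder change earlier in this patch.
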
@@ -321,10 +366,15 @@ def inject_model(model): random_mode = ParallelMode.WEIGHT_DATA if is_using_isp() else ParallelMode.DATA set_mode(random_mode) + # initialize isp communicator + isp_communicator = initialize_parallel_communicator(model) + + model = wrap_FSDP_model(model) + # set is_injected flag setattr(model, "IS_INJECTED", True) - return model + return model, isp_communicator _T = TypeVar("_T") @@ -360,7 +410,7 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): get_current_device(), gpc.config.model.checkpoint, ), - gpc.config.parallel.weight.overlap, + gpc.config.parallel.weight.overlap and not is_using_fsdp(), gpc.get_group(ParallelMode.WEIGHT), is_moe=False, selective_ckpt_offload=gpc.config.get("selective_checkpoint_offload", False), @@ -495,8 +545,9 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): _embedding_communicator = EmbeddingSequenceParallelCommunicator(ParallelMode.TENSOR) # register communitorc for embedding layer. - for embedding in _submodule_filter(model, Embedding1D): - _embedding_communicator.register_module_hook(embedding) + if not is_using_fsdp(): + for embedding in _submodule_filter(model, Embedding1D): + _embedding_communicator.register_module_hook(embedding) # register communictor for head layer. ScaleColumnParallelLinear.register_cls_communicator(_head_communicator) @@ -554,7 +605,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicato else: param_bcast_sync_handler = None - if not gpc.config.parallel.zero1.fsdp: + if not is_using_fsdp(): if ( "use_split_tensor_optim" not in gpc.config.hybrid_zero_optimizer or not gpc.config.hybrid_zero_optimizer.use_split_tensor_optim @@ -975,6 +1026,122 @@ def inject_config(model: nn.Module) -> None: gpc.config.model.num_kv_attention_heads = gpc.config.NUM_KV_ATTENTION_HEAD = llm_cfg.num_key_value_heads +def _get_modules_to_materialize( + root_module: nn.Module, + ignored_modules: Set[nn.Module], +) -> List[nn.Module]: + # Run BFS to collect the modules to materialize via `reset_parameters()`, + # stopping at any module with FSDP already applied or at ignored modules. 
+ modules_to_materialize: List[nn.Module] = [] + queue = collections.deque([root_module]) + visited_modules: Set[nn.Module] = {root_module} + while queue: + module = queue.popleft() + modules_to_materialize.append(module) + for child_module in module.children(): + if child_module not in visited_modules and child_module not in ignored_modules: + visited_modules.add(child_module) + queue.append(child_module) + return modules_to_materialize + + +def _materialize_meta_module( + root_module: nn.Module, + ignored_modules: Set[nn.Module], + device_id: Optional[torch.device], +) -> None: + # Run default meta device initialization + modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules) + module = None + try: + # Assume that each module's `reset_parameters()` only initializes its + # own parameters and not those of its children + with torch.no_grad(): + for module in modules_to_materialize: + # As a contract to the user, only call `reset_parameters()` if + # the module has directly managed parameters/buffers + module_state_iter = itertools.chain(module.parameters(recurse=False), module.buffers(recurse=False)) + has_module_states = len(list(module_state_iter)) > 0 + if has_module_states: + module.to_empty(device=device_id, recurse=False) + module.reset_parameters() # type: ignore[operator] + except BaseException as e: + logger.warning( + "Unable to call `reset_parameters()` for module on meta " + f"device with error {str(e)}. Please ensure that your module of" + f"type {type(module)} implements a `reset_parameters()` method." # type: ignore[possibly-undefined] + ) + raise e + + +def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]): + if is_using_fsdp(): + assert isinstance(model, nn.Module), "Currently FSDP does not support pipeline parallel." + wrap_cls = tuple( + LazyObject(warp_cls["mod"], warp_cls["mod_cls"]).build() for warp_cls in gpc.config.get("fsdp_wrap_cls", []) + ) + fsdp_mode = gpc.config.parallel.fsdp.get("mode", "v1") + fsdp_init_method = gpc.config.parallel.fsdp.get("init_method", "cuda") + + if fsdp_mode == "v1": + model = FSDP( + module=model, + process_group=gpc.get_group(ParallelMode.GLOBAL), + sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO2: SHARD_GRAD_OP, ZeRO3: FULL_SHARD + auto_wrap_policy=functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=set(wrap_cls)), + sync_module_states=fsdp_init_method != "cuda", # sync model paramters + forward_prefetch=True, + backward_prefetch=BackwardPrefetch.BACKWARD_PRE, + limit_all_gathers=True, + use_orig_params=True, + device_id=None if fsdp_init_method == "cuda" else get_current_device(), # needed for sync_module_states + ) + # For FSDP v1, to get ckpt resuming work normally, we do dummy forward. + # This hack is needed due to FSDP v1 lazy initialization in model construction. 
+ # FYI: https://github.com/pytorch/pytorch/issues/113496 + model = init_fsdp_v1(model, get_current_device()) + elif FSDP2_SUPPORTED and fsdp_mode == "v2": + fsdp_kwargs = { + "reshard_after_forward": True, # ZeRO2: False, ZeRO3: True + } + for module in model.modules(): + if isinstance(module, wrap_cls): + fully_shard(module, **fsdp_kwargs) + fully_shard(model, **fsdp_kwargs) + if fsdp_init_method == "meta": + _materialize_meta_module(model, set(), get_current_device()) + elif fsdp_init_method == "cpu": + model.to(get_current_device()) + else: + raise ValueError(f"Unsupported FSDP mode: {fsdp_mode}") + + if is_using_hf() and not gpc.config.ckpt.get("auto_resume", False): + load_ckpt_info = gpc.config.ckpt.load_ckpt_info + load_ckpt_path = load_ckpt_info.get("path", None) + load_ckpt_content = load_ckpt_info.get("content", []) + if load_ckpt_path: + assert load_ckpt_content == ( + "model", + ), "If auto_resume=False and checkpoint path is given, only model can be loaded" + if DCP_SUPPORTED: + hf = gpc.config.hf + mod = LazyObject(hf.mod, hf.mod_cls) + mod = mod.build() + state_dict = mod.from_pretrained( + pretrained_model_name_or_path=load_ckpt_path, use_safetensors=True + ).state_dict() + state_dict = {f"model.{key}": state_dict[key].clone().detach() for key in state_dict} + set_model_state_dict( + model=model, model_state_dict=state_dict, options=StateDictOptions(full_state_dict=True) + ) + del state_dict + internlm_accelerator.empty_cache() + else: + raise RuntimeError("DCP is not supported in this version of PyTorch.") + + return model + + def inject_model_helper(model: Union[nn.Module, nn.ModuleList], inject_info: Optional[Dict] = None) -> None: """ Inject model helper functions. diff --git a/internlm/utils/common.py b/internlm/utils/common.py index 56ebcfbe6..15c984509 100644 --- a/internlm/utils/common.py +++ b/internlm/utils/common.py @@ -1,21 +1,21 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import argparse import bisect import inspect import os import random -import threading from abc import ABC, abstractmethod from collections import ChainMap from contextlib import contextmanager from datetime import datetime +import threading from typing import Union import numpy as np import torch -import internlm from internlm.accelerator import AcceleratorType, get_accelerator from internlm.utils.logger import get_logger @@ -24,8 +24,39 @@ internlm_accelerator = get_accelerator() +def get_default_parser(): + """Reads user command line and uses an argument parser to parse the input arguments. + Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed. + + Returns: + Parser: Returns the parser with the default arguments, the user may add customized arguments into this parser. 
+ """ + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, help="path to the config file") + parser.add_argument( + "--launcher", + type=str, + default="slurm", + choices=["slurm", "torch"], + help="launcher for launching distributed environment", + ) + parser.add_argument("--host", type=str, help="the master address for distributed training") + parser.add_argument("--port", type=int, default=8888, help="the master port for distributed training") + parser.add_argument("--world_size", type=int, help="world size for distributed training") + parser.add_argument("--rank", type=int, help="rank for the default process group") + parser.add_argument("--local_rank", type=int, help="local rank on the node") + parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication") + parser.add_argument("--seed", type=int, default=1024) + parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.") + parser.add_argument("--enable_ali_topology", default=False, action="store_true", help="enable ali switch topology.") + parser.add_argument( + "--disable_volc_topology", default=False, action="store_true", help="disable volc switch topology." + ) + return parser + + def parse_args(): - parser = internlm.get_default_parser() + parser = get_default_parser() args = parser.parse_args() return args diff --git a/internlm/utils/gputest.py b/internlm/utils/gputest.py index 0e75cc48b..5ce206b30 100644 --- a/internlm/utils/gputest.py +++ b/internlm/utils/gputest.py @@ -19,7 +19,7 @@ GPUtil, psutil = None, None from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc logger = get_logger(__file__) internlm_accelerator = get_accelerator() @@ -118,8 +118,6 @@ def warmup_process_group(): dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.ZERO1)) if gpc.is_initialized(ParallelMode.MODEL): dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.MODEL)) - if gpc.is_initialized(ParallelMode.ZERO3_DP): - dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.ZERO3_DP)) if gpc.is_initialized(ParallelMode.EXPERT_DATA): dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.EXPERT_DATA)) if gpc.is_initialized(ParallelMode.EXPERT): diff --git a/internlm/utils/lazy.py b/internlm/utils/lazy.py new file mode 100644 index 000000000..e67c63aa2 --- /dev/null +++ b/internlm/utils/lazy.py @@ -0,0 +1,265 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import abc +import importlib +from typing import Any, Optional, Type, Union + + +def is_seq_of(seq: Any, expected_type: Union[Type, tuple], seq_type: Type = None) -> bool: + """Check whether it is a sequence of some type. + + Args: + seq (Sequence): The sequence to be checked. + expected_type (type or tuple): Expected type of sequence items. + seq_type (type, optional): Expected sequence type. Defaults to None. + + Returns: + bool: Return True if ``seq`` is valid else False. 
+ + Examples: + >>> from mmengine.utils import is_seq_of + >>> seq = ['a', 'b', 'c'] + >>> is_seq_of(seq, str) + True + >>> is_seq_of(seq, int) + False + """ + if seq_type is None: + exp_seq_type = abc.Sequence + else: + assert isinstance(seq_type, type) + exp_seq_type = seq_type + if not isinstance(seq, exp_seq_type): + return False + for item in seq: + if not isinstance(item, expected_type): + return False + return True + + +class LazyObject: + """LazyObject is used to lazily initialize the imported module during + parsing the configuration file. + + During parsing process, the syntax like: + + Examples: + >>> import torch.nn as nn + >>> from mmdet.models import RetinaNet + >>> import mmcls.models + >>> import mmcls.datasets + >>> import mmcls + + Will be parsed as: + + Examples: + >>> # import torch.nn as nn + >>> nn = lazyObject('torch.nn') + >>> # from mmdet.models import RetinaNet + >>> RetinaNet = lazyObject('mmdet.models', 'RetinaNet') + >>> # import mmcls.models; import mmcls.datasets; import mmcls + >>> mmcls = lazyObject(['mmcls', 'mmcls.datasets', 'mmcls.models']) + + ``LazyObject`` records all module information and will be further + referenced by the configuration file. + + Args: + module (str or list or tuple): The module name to be imported. + imported (str, optional): The imported module name. Defaults to None. + location (str, optional): The filename and line number of the imported + module statement happened. + """ + + def __init__(self, module: Union[str, list, tuple], imported: Optional[str] = None, location: Optional[str] = None): + if not isinstance(module, str) and not is_seq_of(module, str): + raise TypeError( + "module should be `str`, `list`, or `tuple`" + f"but got {type(module)}, this might be " + "a bug of MMEngine, please report it to " + "https://github.com/open-mmlab/mmengine/issues" + ) + self._module: Union[str, list, tuple] = module + + if not isinstance(imported, str) and imported is not None: + raise TypeError( + "imported should be `str` or None, but got " + f"{type(imported)}, this might be " + "a bug of MMEngine, please report it to " + "https://github.com/open-mmlab/mmengine/issues" + ) + self._imported = imported + self.location = location + + def build(self) -> Any: + """Return imported object. + + Returns: + Any: Imported object + """ + if isinstance(self._module, str): + try: + module = importlib.import_module(self._module) + except Exception as e: + raise type(e)(f"Failed to import {self._module} " f"in {self.location} for {e}") + + if self._imported is not None: + if hasattr(module, self._imported): + module = getattr(module, self._imported) + else: + raise ImportError(f"Failed to import {self._imported} " f"from {self._module} in {self.location}") + + return module + else: + # import xxx.xxx + # import xxx.yyy + # import xxx.zzz + # return imported xxx + try: + for module in self._module: + importlib.import_module(module) # type: ignore + module_name = self._module[0].split(".")[0] + return importlib.import_module(module_name) + except Exception as e: + raise type(e)(f"Failed to import {self.module} " f"in {self.location} for {e}") + + @property + def module(self): + if isinstance(self._module, str): + return self._module + return self._module[0].split(".")[0] + + def __call__(self, *args, **kwargs): + raise RuntimeError() + + def __deepcopy__(self, memo): + return LazyObject(self._module, self._imported, self.location) + + def __getattr__(self, name): + # Cannot locate the line number of the getting attribute. 
+ # Therefore only record the filename. + if self.location is not None: + location = self.location.split(", line")[0] + else: + location = self.location + return LazyAttr(name, self, location) + + def __str__(self) -> str: + if self._imported is not None: + return self._imported + return self.module + + __repr__ = __str__ + + # `pickle.dump` will try to get the `__getstate__` and `__setstate__` + # methods of the dumped object. If these two methods are not defined, + # LazyObject will return a `__getstate__` LazyObject` or `__setstate__` + # LazyObject. + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state): + self.__dict__ = state + + +class LazyAttr: + """The attribute of the LazyObject. + + When parsing the configuration file, the imported syntax will be + parsed as the assignment ``LazyObject``. During the subsequent parsing + process, users may reference the attributes of the LazyObject. + To ensure that these attributes also contain information needed to + reconstruct the attribute itself, LazyAttr was introduced. + + Examples: + >>> models = LazyObject(['mmdet.models']) + >>> model = dict(type=models.RetinaNet) + >>> print(type(model['type'])) # + >>> print(model['type'].build()) # + """ # noqa: E501 + + def __init__(self, name: str, source: Union["LazyObject", "LazyAttr"], location=None): + self.name = name + self.source: Union[LazyAttr, LazyObject] = source + + if isinstance(self.source, LazyObject): + if isinstance(self.source._module, str): + if self.source._imported is None: + # source code: + # from xxx.yyy import zzz + # equivalent code: + # zzz = LazyObject('xxx.yyy', 'zzz') + # The source code of get attribute: + # eee = zzz.eee + # Then, `eee._module` should be "xxx.yyy.zzz" + self._module = self.source._module + else: + # source code: + # import xxx.yyy as zzz + # equivalent code: + # zzz = LazyObject('xxx.yyy') + # The source code of get attribute: + # eee = zzz.eee + # Then, `eee._module` should be "xxx.yyy" + self._module = f"{self.source._module}.{self.source}" + else: + # The source code of LazyObject should be + # 1. import xxx.yyy + # 2. import xxx.zzz + # Equivalent to + # xxx = LazyObject(['xxx.yyy', 'xxx.zzz']) + + # The source code of LazyAttr should be + # eee = xxx.eee + # Then, eee._module = xxx + self._module = str(self.source) + elif isinstance(self.source, LazyAttr): + # 1. import xxx + # 2. zzz = xxx.yyy.zzz + + # Equivalent to: + # xxx = LazyObject('xxx') + # zzz = xxx.yyy.zzz + # zzz._module = xxx.yyy._module + zzz.name + self._module = f"{self.source._module}.{self.source.name}" + self.location = location + + @property + def module(self): + return self._module + + def __call__(self, *args, **kwargs: Any) -> Any: + raise RuntimeError() + + def __getattr__(self, name: str) -> "LazyAttr": + return LazyAttr(name, self) + + def __deepcopy__(self, memo): + return LazyAttr(self.name, self.source) + + def build(self) -> Any: + """Return the attribute of the imported object. + + Returns: + Any: attribute of the imported object. + """ + obj = self.source.build() + try: + return getattr(obj, self.name) + except AttributeError: + raise ImportError(f"Failed to import {self.module}.{self.name} in " f"{self.location}") + except ImportError as e: + raise e + + def __str__(self) -> str: + return self.name + + __repr__ = __str__ + + # `pickle.dump` will try to get the `__getstate__` and `__setstate__` + # methods of the dumped object. 
If these two methods are not defined, + # LazyAttr will return a `__getstate__` LazyAttr` or `__setstate__` + # LazyAttr. + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state): + self.__dict__ = state diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py index 665353070..852c2d54c 100644 --- a/internlm/utils/parallel.py +++ b/internlm/utils/parallel.py @@ -12,10 +12,22 @@ IS_WEIGHT_ZERO_PARALLEL, ParallelMode, ) -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.utils.utils import TensorParallelMode +def is_using_hf(): + return "hf" in gpc.config + + +def is_using_fsdp(): + return ( + "fsdp" in gpc.config.parallel + and isinstance(gpc.config.parallel["fsdp"], dict) + and gpc.config.parallel["fsdp"].get("enable", False) + ) + + def is_using_sequence_parallel(): return ( isinstance(gpc.config.parallel["tensor"], dict) diff --git a/internlm/utils/singleton.py b/internlm/utils/singleton.py new file mode 100644 index 000000000..86a8aa5d8 --- /dev/null +++ b/internlm/utils/singleton.py @@ -0,0 +1,27 @@ +import threading + + +class SingletonMeta(type): + """ + Thread-safe Singleton Meta with double-checked locking. + Reference: https://en.wikipedia.org/wiki/Double-checked_locking + """ + + _instances = {} + _lock = threading.Lock() + + def __call__(cls, *args, **kwargs): + # First check (without locking) for performance reasons + if cls not in cls._instances: + # Acquire a lock before proceeding to the second check + with cls._lock: + # Second check with lock held to ensure thread safety + if cls not in cls._instances: + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance + else: + assert ( + len(args) == 0 and len(kwargs) == 0 + ), f"{cls.__name__} is a singleton class and an instance has been created." 
+ + return cls._instances[cls] diff --git a/internlm/utils/timeout.py b/internlm/utils/timeout.py index 55b354c4d..f7debc2f9 100644 --- a/internlm/utils/timeout.py +++ b/internlm/utils/timeout.py @@ -41,7 +41,7 @@ def __exit__(self, error_type, value, traceback): timeout_threshold_dict = { "initialize_distributed_env": 240, "nopp_forward_backward_step": 360, - "initialize_model": 60, + "initialize_model_and_parallel_communicator": 60, "initialize_optimizer": 60, "optim_step": 60, "build_train_loader_with_data_type": 600, @@ -63,7 +63,7 @@ def __exit__(self, error_type, value, traceback): def try_get_gpc_rank(): try: - from internlm.core.context import global_context as gpc + from internlm.core.context.parallel_context import global_context as gpc rank = gpc.get_global_rank() except: # noqa: E722 # pylint: disable=bare-except diff --git a/internlm/utils/writer.py b/internlm/utils/writer.py index 7abb8ddde..c19c2175a 100644 --- a/internlm/utils/writer.py +++ b/internlm/utils/writer.py @@ -9,7 +9,7 @@ from torch.utils.tensorboard import SummaryWriter from internlm.accelerator import get_accelerator -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc internlm_accelerator = get_accelerator() diff --git a/requirements/runtime.txt b/requirements.txt similarity index 51% rename from requirements/runtime.txt rename to requirements.txt index a545f766c..dd734ec8c 100644 --- a/requirements/runtime.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -transformers +transformers<4.47.0 sentencepiece datasets numpy @@ -6,15 +6,11 @@ tqdm einops psutil packaging -pre-commit ninja gputil -pytest boto3 botocore -torch-scatter pyecharts py-libnuma pynvml -tensorboard --f https://data.pyg.org/whl/torch-2.1.0+cu118.html +tensorboard \ No newline at end of file diff --git a/requirements/torch.txt b/requirements/torch.txt deleted file mode 100644 index c9a04b3d8..000000000 --- a/requirements/torch.txt +++ /dev/null @@ -1,4 +0,0 @@ ---extra-index-url https://download.pytorch.org/whl/cu118 -torch==2.1.0+cu118 -torchvision==0.16.0+cu118 -torchaudio==2.1.0+cu118 diff --git a/setup.py b/setup.py index f37599543..1b2f2358f 100644 --- a/setup.py +++ b/setup.py @@ -1,9 +1,7 @@ import os import re -import sys -import subprocess +from typing import List from setuptools import setup, find_packages -from setuptools.command.install import install pwd = os.path.dirname(__file__) @@ -12,32 +10,24 @@ def readme(): content = f.read() return content -def get_version(): - with open(os.path.join(pwd, 'version.txt'), 'r') as f: - content = f.read() - return content - -def has_nvcc(): - try: - subprocess.run(['nvcc', '--version'], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - return True - except (subprocess.CalledProcessError, FileNotFoundError): - return False +def get_version() -> str: + with open(os.path.join("internlm", "env.py"), encoding="utf-8") as f: + file_content = f.read() + pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION") + (version,) = re.findall(pattern, file_content) + return version -def fetch_requirements(path): - with open(path, 'r') as fd: - return [r.strip() for r in fd.readlines() if 'torch-scatter' not in r and not r.startswith('-f ')] +def get_requires() -> List[str]: + with open("requirements.txt", encoding="utf-8") as f: + file_content = f.read() + lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")] + return lines -if has_nvcc(): - install_requires = [ - 
fetch_requirements('requirements/runtime.txt'), - 'rotary_emb', - 'xentropy', - ] -else: - install_requires = [ - fetch_requirements('requirements/runtime.txt'), - ] +extra_require = { + "torch": ["torch>=1.13.1"], + "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3"], + "test": ["pre-commit", "pylint", "pytest"], +} setup( name='InternEvo', @@ -46,7 +36,8 @@ def fetch_requirements(path): long_description=readme(), long_description_content_type='text/markdown', packages=find_packages(), - install_requires=install_requires, + install_requires=get_requires(), + extras_require=extra_require, classifiers=[ 'Programming Language :: Python :: 3.10', 'Intended Audience :: Developers', diff --git a/tests/common_fixture.py b/tests/common_fixture.py index e5a8b9aa1..3362099f5 100644 --- a/tests/common_fixture.py +++ b/tests/common_fixture.py @@ -5,9 +5,9 @@ import numpy as np import torch -import internlm +from internlm.initialize.launch import launch_from_torch from internlm.accelerator import get_accelerator -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.context.parallel_context import Config from internlm.data.utils import unpack_type_ids from internlm.initialize.launch import args_sanity_check @@ -120,7 +120,7 @@ def build_environment(rank, world_size, free_port, config): os.environ["MASTER_PORT"] = str(free_port) internlm_accelerator.empty_cache() # launcher="torch" - internlm.launch_from_torch(config=config, seed=1024) + launch_from_torch(config=config, seed=1024) args_sanity_check() diff --git a/tests/test_core/test_pipeline.py b/tests/test_core/test_pipeline.py index 180fe4b71..28cb3959e 100644 --- a/tests/test_core/test_pipeline.py +++ b/tests/test_core/test_pipeline.py @@ -4,7 +4,7 @@ import torch from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.context.parallel_context import Config from internlm.solver.optimizer.compatible_adamw import new_compatible_adamw from internlm.utils.common import get_current_device @@ -21,7 +21,7 @@ dict( gradient_handler=[dict(type="PipelineSharedModuleGradientHandler")], parallel=dict( - zero1=dict(size=1, fsdp=False), + zero1=dict(size=1), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=8, interleaved_overlap=True), weight=dict(size=1, overlap=True), diff --git a/tests/test_core/utils.py b/tests/test_core/utils.py index 5ccaccaf3..6cce98314 100644 --- a/tests/test_core/utils.py +++ b/tests/test_core/utils.py @@ -5,10 +5,10 @@ from torch import nn from torch.testing import assert_close -import internlm +from internlm.initialize.launch import launch_from_torch from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.engine import Engine from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler from internlm.core.parallel.shard import partition_uniform @@ -156,7 +156,7 @@ def build_environment(rank, world_size, config): os.environ["MASTER_PORT"] = "33333" internlm_accelerator.empty_cache() # launcher="torch" - internlm.launch_from_torch(config=config, seed=1024) + launch_from_torch(config=config, seed=1024) def loose_close(a, b, dtype: torch.dtype = torch.float32): diff --git 
a/tests/test_data/test_batch_sampler.py b/tests/test_data/test_batch_sampler.py index 6beeb7a7f..e1dea8fa0 100644 --- a/tests/test_data/test_batch_sampler.py +++ b/tests/test_data/test_batch_sampler.py @@ -5,7 +5,7 @@ import torch from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc # from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import Config @@ -152,7 +152,7 @@ def test_warmup(use_flash_atten_case, group_case, micro_bsz_case): config = Config( dict( parallel=dict( - zero1=dict(size=1, fsdp=False), + zero1=dict(size=1), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), weight=dict(size=1, overlap=True), diff --git a/tests/test_infer/test_generate.py b/tests/test_infer/test_generate.py index bbb804a32..14741b494 100644 --- a/tests/test_infer/test_generate.py +++ b/tests/test_infer/test_generate.py @@ -6,7 +6,7 @@ from internlm.apis.inference import SequenceGenerator, batch_tokenize from internlm.initialize import initialize_distributed_env # noqa: E402 -from internlm.train import initialize_model, initialize_parallel_communicator +from internlm.train import initialize_model_and_parallel_communicator def set_seed(seed: int = 1024): @@ -30,7 +30,7 @@ def load_and_generate(path, model_type="INTERNLM2", tokenizer_path=""): model_type=model_type, model=model_config, parallel=dict( - zero1=dict(size=1, fsdp=False), + zero1=dict(size=1), pipeline=dict(size=1, interleaved_overlap=True), tensor=dict(size=1, mode="mtp"), sequence_parallel=0, @@ -50,8 +50,7 @@ def convert_to_str(output_ids): all_output_str.append(cur_sent) return all_output_str - model = initialize_model() - _ = initialize_parallel_communicator(model) + model, _ = initialize_model_and_parallel_communicator() # Directly get the origin model without NativeAMP wrapper. 
model = model.model diff --git a/tests/test_infer/test_trainer_generate.py b/tests/test_infer/test_trainer_generate.py index 537a40777..c0061e79c 100644 --- a/tests/test_infer/test_trainer_generate.py +++ b/tests/test_infer/test_trainer_generate.py @@ -3,27 +3,25 @@ import pytest from sentencepiece import SentencePieceProcessor -import internlm # noqa: E402 +from internlm.initialize.initialize_trainer import initialize_trainer # noqa: E402 from internlm.apis.inference import SequenceGenerator, batch_tokenize from internlm.checkpoint import CheckpointManager # noqa: E402 -from internlm.core.context import global_context as gpc # noqa: E402 +from internlm.core.context.parallel_context import global_context as gpc # noqa: E402 from internlm.core.trainer import TrainState, Trainer # noqa: E402 from internlm.data import build_train_loader_with_data_type # noqa: E402 from internlm.initialize import initialize_distributed_env # noqa: E402 from internlm.model.losses import InternLoss # noqa: E402 from internlm.train import ( # noqa: E402 get_scheduler_hooks, - initialize_model, + initialize_model_and_parallel_communicator, initialize_optimizer, - initialize_parallel_communicator, ) def setup_generator(config, tokenizer): initialize_distributed_env(config=config) - model = initialize_model() - isp_communicator = initialize_parallel_communicator(model) + model, isp_communicator = initialize_model_and_parallel_communicator() criterion = InternLoss() @@ -47,7 +45,7 @@ def setup_generator(config, tokenizer): ckpt_manager.try_resume_training(train_state) # initialize trainer - engine, scheduler = internlm.initialize_trainer( + engine, scheduler = initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, diff --git a/tests/test_model/test_model_internlm.py b/tests/test_model/test_model_internlm.py index 3ce6f530e..310d2f1d3 100644 --- a/tests/test_model/test_model_internlm.py +++ b/tests/test_model/test_model_internlm.py @@ -6,7 +6,7 @@ import torch from torch import nn -import internlm +from internlm.initialize.launch import launch_from_torch from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import Config @@ -33,7 +33,7 @@ config = Config( dict( parallel=dict( - zero1=dict(size=1, fsdp=False), + zero1=dict(size=1), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), weight=dict(size=1, overlap=True), @@ -84,7 +84,7 @@ def build_environment(rank, world_size, free_port): os.environ["MASTER_PORT"] = free_port internlm_accelerator.empty_cache() # launcher="torch" - internlm.launch_from_torch(config=config, seed=1024) + launch_from_torch(config=config, seed=1024) def seed_all(seed, cuda_deterministic=False): diff --git a/tests/test_model/test_npu_ops/test_flash_attention.py b/tests/test_model/test_npu_ops/test_flash_attention.py index a2a8b91b8..81166bf9f 100644 --- a/tests/test_model/test_npu_ops/test_flash_attention.py +++ b/tests/test_model/test_npu_ops/test_flash_attention.py @@ -12,8 +12,8 @@ from torch import nn from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.core.context import Config -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import Config +from internlm.core.context.parallel_context import global_context as gpc from internlm.model.ops.attention import SelfAttention from internlm.model.ops.utils import pack_output_after_attn, unpack_qkv_before_attn from internlm.utils.common 
import get_current_device, set_random_seed diff --git a/tests/test_solver/test_optimizer.py b/tests/test_solver/test_optimizer.py index ca470ffc9..11d0ebc7e 100644 --- a/tests/test_solver/test_optimizer.py +++ b/tests/test_solver/test_optimizer.py @@ -9,7 +9,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close -import internlm +from internlm.initialize.launch import launch_from_torch from internlm.accelerator import get_accelerator from internlm.core.context.parallel_context import Config, ParallelMode from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler @@ -96,7 +96,7 @@ def build_environment(rank, world_size): os.environ["MASTER_PORT"] = "12345" internlm_accelerator.empty_cache() # launcher="torch" - internlm.launch_from_torch(config=config, seed=1024) + launch_from_torch(config=config, seed=1024) def loose_close(a, b, dtype: torch.dtype = torch.float32): diff --git a/tests/test_training/test_forward_output_no_fa.py b/tests/test_training/test_forward_output_no_fa.py index d01b876c8..715a362a9 100644 --- a/tests/test_training/test_forward_output_no_fa.py +++ b/tests/test_training/test_forward_output_no_fa.py @@ -7,10 +7,11 @@ import pytest import torch -import internlm +from internlm.initialize import launch_from_torch +from internlm.initialize import initialize_trainer from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.context.parallel_context import Config from internlm.core.trainer import Trainer from internlm.data import build_train_loader_with_data_type @@ -18,9 +19,8 @@ from internlm.model.losses import InternLoss from internlm.model.metrics import AccPerplex, SchedulerMetricHook from internlm.train import ( - initialize_model, + initialize_model_and_parallel_communicator, initialize_optimizer, - initialize_parallel_communicator, ) from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger @@ -134,7 +134,7 @@ def build_environment(rank, world_size, free_port, config): os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = str(free_port) internlm_accelerator.empty_cache() - internlm.launch_from_torch(config=config, seed=1024) + launch_from_torch(config=config, seed=1024) args_sanity_check() @@ -171,8 +171,7 @@ def train_check_output(args): seed_all(1024) # initialize model - model = initialize_model() - _ = initialize_parallel_communicator(model) + model, _ = initialize_model_and_parallel_communicator() # initialize loss function criterion = InternLoss(parallel_output=False, label_smoothing=gpc.config.loss.label_smoothing) @@ -200,7 +199,7 @@ def train_check_output(args): ), ] - engine, scheduler = internlm.initialize_trainer( + engine, scheduler = initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, diff --git a/tests/test_training/test_load_ckpt_loss.py b/tests/test_training/test_load_ckpt_loss.py index ddbb24a08..2a37e6186 100644 --- a/tests/test_training/test_load_ckpt_loss.py +++ b/tests/test_training/test_load_ckpt_loss.py @@ -14,7 +14,8 @@ import torch # noqa: E402 #pylint: disable=wrong-import-position import torch.distributed as dist # noqa: E402 #pylint: disable=wrong-import-position -import internlm # noqa: E402 #pylint: disable=wrong-import-position +from internlm.initialize.initialize_trainer import initialize_trainer # noqa: E402 #pylint: 
disable=wrong-import-position +from internlm.initialize.launch import launch_from_torch from internlm.checkpoint import ( # noqa: E402 #pylint: disable=wrong-import-position CheckpointManager, ) @@ -45,9 +46,8 @@ SchedulerMetricHook, ) from internlm.train import ( # noqa: E402 #pylint: disable=wrong-import-position - initialize_model, + initialize_model_and_parallel_communicator, initialize_optimizer, - initialize_parallel_communicator, load_new_batch, ) from internlm.utils.common import ( # noqa: E402 #pylint: disable=wrong-import-position @@ -67,7 +67,7 @@ dict( VOCAB_SIZE=103168, parallel=dict( - zero1=dict(size=-1, fsdp=False), + zero1=dict(size=-1), pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=dict(size=1, mode="mtp"), @@ -175,7 +175,7 @@ def build_environment(rank, world_size, free_port, config): os.environ["MASTER_PORT"] = str(free_port) internlm_accelerator.empty_cache() # launcher="torch" - internlm.launch_from_torch(config=config, seed=1024) + launch_from_torch(config=config, seed=1024) args_sanity_check() @@ -220,8 +220,7 @@ def train_model(args): current_time = objs[0] # initialize model - model = initialize_model() - _ = initialize_parallel_communicator(model) + model, _ = initialize_model_and_parallel_communicator() # initialize loss function criterion = InternLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) @@ -267,7 +266,7 @@ def train_model(args): ), ] - engine, scheduler = internlm.initialize_trainer( + engine, scheduler = initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index 967398e17..1f61544fd 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -5,20 +5,20 @@ import torch import torch.distributed as dist -import internlm +from internlm.initialize.initialize_trainer import initialize_trainer from internlm.accelerator import AcceleratorType, get_accelerator from internlm.checkpoint import CheckpointManager -from internlm.core.context import Config, ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import Config +from internlm.core.context.parallel_context import ParallelMode +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.trainer import Trainer, TrainState from internlm.data import build_train_loader_with_data_type from internlm.initialize import initialize_distributed_env from internlm.model.losses import InternLoss from internlm.train import ( get_scheduler_hooks, - initialize_model, + initialize_model_and_parallel_communicator, initialize_optimizer, - initialize_parallel_communicator, load_new_batch, ) from internlm.utils.common import BatchSkipper, launch_time @@ -167,11 +167,8 @@ def train( dist.broadcast_object_list(objs, src=0) current_time = objs[0] - # initialize model - model = initialize_model() - - # initialize isp communicator - isp_communicator = initialize_parallel_communicator(model) + # initialize model and isp_communicator + model, isp_communicator = initialize_model_and_parallel_communicator() # initialize loss function criterion = InternLoss(parallel_output=gpc.config.model.parallel_output, label_smoothing=label_smoothing) @@ -204,7 +201,7 @@ def train( metric = None # initialize trainer - engine, scheduler = internlm.initialize_trainer( + engine, scheduler = initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, diff 
--git a/tests/test_training/test_no_fa_train_temp.py b/tests/test_training/test_no_fa_train_temp.py index 5f0782b4b..1da0e30ca 100644 --- a/tests/test_training/test_no_fa_train_temp.py +++ b/tests/test_training/test_no_fa_train_temp.py @@ -2,19 +2,18 @@ import pytest -import internlm +from internlm.initialize.initialize_trainer import initialize_trainer from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.trainer import Trainer from internlm.data import build_train_loader_with_data_type from internlm.model.losses import InternLoss from internlm.model.metrics import AccPerplex from internlm.train import ( get_scheduler_hooks, - initialize_model, + initialize_model_and_parallel_communicator, initialize_optimizer, - initialize_parallel_communicator, ) from internlm.utils.logger import get_logger from tests.common_fixture import ( @@ -51,11 +50,8 @@ def train_check(args): # set seed seed_all(1024) - # initialize model - model = initialize_model() - - # initialize isp communicator - isp_communicator = initialize_parallel_communicator(model) + # initialize model and isp communicator + model, isp_communicator = initialize_model_and_parallel_communicator() # initialize loss function criterion = InternLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) @@ -71,7 +67,7 @@ def train_check(args): dataset_types=dataset_types, ) - engine, scheduler = internlm.initialize_trainer( + engine, scheduler = initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, diff --git a/tests/test_training/test_norm_weight.py b/tests/test_training/test_norm_weight.py index 990b334a6..f104eb152 100644 --- a/tests/test_training/test_norm_weight.py +++ b/tests/test_training/test_norm_weight.py @@ -5,19 +5,18 @@ import pytest import torch -import internlm +from internlm.initialize.initialize_trainer import initialize_trainer from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.trainer import Trainer from internlm.data import build_train_loader_with_data_type from internlm.model.losses import InternLoss from internlm.model.metrics import AccPerplex from internlm.train import ( get_scheduler_hooks, - initialize_model, + initialize_model_and_parallel_communicator, initialize_optimizer, - initialize_parallel_communicator, ) from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger @@ -71,11 +70,8 @@ def train_check_norm_weight(args): # set seed seed_all(1024) - # initialize model - model = initialize_model() - - # initialize isp communicator - isp_communicator = initialize_parallel_communicator(model) + # initialize model and isp communicator + model, isp_communicator = initialize_model_and_parallel_communicator() # initialize loss function criterion = InternLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) @@ -91,7 +87,7 @@ def train_check_norm_weight(args): dataset_types=dataset_types, ) - engine, scheduler = internlm.initialize_trainer( + engine, scheduler = initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, diff --git a/tests/test_training/test_swap_nb_loss_and_gradnorm.py 
b/tests/test_training/test_swap_nb_loss_and_gradnorm.py index 13c01b1c5..5b6eeca30 100644 --- a/tests/test_training/test_swap_nb_loss_and_gradnorm.py +++ b/tests/test_training/test_swap_nb_loss_and_gradnorm.py @@ -9,10 +9,11 @@ import torch.distributed as dist from tqdm import tqdm -import internlm +from internlm.initialize.launch import launch_from_torch +from internlm.initialize.initialize_trainer import initialize_trainer from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.context.parallel_context import Config from internlm.core.trainer import Trainer from internlm.data import ( @@ -24,9 +25,8 @@ from internlm.model.losses import InternLoss from internlm.model.metrics import AccPerplex, SchedulerMetricHook from internlm.train import ( - initialize_model, + initialize_model_and_parallel_communicator, initialize_optimizer, - initialize_parallel_communicator, ) from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger @@ -138,7 +138,7 @@ def build_environment(rank, world_size, config): os.environ["MASTER_PORT"] = "33333" internlm_accelerator.empty_cache() # launcher="torch" - internlm.launch_from_torch(config=config, seed=1024) + launch_from_torch(config=config, seed=1024) args_sanity_check() @@ -271,8 +271,7 @@ def exam_loss(args): seed_all(1024) # initialize model - model = initialize_model() - _ = initialize_parallel_communicator(model) + model, _ = initialize_model_and_parallel_communicator() # initialize loss function criterion = InternLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) @@ -304,7 +303,7 @@ def exam_loss(args): ), ] - engine, scheduler = internlm.initialize_trainer( + engine, scheduler = initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py index c7da6f85c..1d0c83367 100644 --- a/tests/test_training/train_CI.py +++ b/tests/test_training/train_CI.py @@ -16,10 +16,10 @@ project_root = os.path.abspath(os.path.join(script_dir, "../../")) sys.path.append(project_root) -import internlm # noqa: E402 +from internlm.initialize.initialize_trainer import initialize_trainer # noqa: E402 from internlm.checkpoint import CheckpointManager # noqa: E402 from internlm.core.context import ParallelMode # noqa: E402 -from internlm.core.context import global_context as gpc # noqa: E402 +from internlm.core.context.parallel_context import global_context as gpc # noqa: E402 from internlm.core.trainer import Trainer, TrainState # noqa: E402 from internlm.data import ( # noqa: E402 build_train_loader_with_data_type, @@ -36,9 +36,8 @@ from internlm.monitor.monitor import monitor_manager as mm # noqa: E402 from internlm.train import ( # noqa: E402 initialize_llm_profile, - initialize_model, + initialize_model_and_parallel_communicator, initialize_optimizer, - initialize_parallel_communicator, record_current_batch_training_metrics, ) from internlm.utils.common import ( # noqa: E402 @@ -116,8 +115,7 @@ def main(args): current_time = objs[0] # initialize model - model = initialize_model() - _ = initialize_parallel_communicator(model) + model , _ = initialize_model_and_parallel_communicator() with open(args.config, "r") as f: config_lines = f.readlines() @@ -182,7 +180,7 @@ def main(args): ), ] - engine, scheduler = internlm.initialize_trainer( + 
engine, scheduler = initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py index f4b34ddee..8c51a161a 100644 --- a/tests/test_utils/common_fixture.py +++ b/tests/test_utils/common_fixture.py @@ -5,7 +5,7 @@ import pytest import torch -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.core.context.parallel_context import Config from internlm.core.naive_amp import NaiveAMPModel from internlm.model.builder import create_model @@ -46,7 +46,7 @@ init_config = Config( dict( parallel=dict( - zero1=dict(size=1, fsdp=False), + zero1=dict(size=1), tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), weight=dict(size=1, overlap=True), @@ -91,7 +91,7 @@ def init_naive_model(): register_model_initializer() - model = create_model(model_type=gpc.config.model_type) + model = create_model() model = NaiveAMPModel( model=model, output_to_fp32=False, diff --git a/tests/test_utils/test_model_checkpoint.py b/tests/test_utils/test_model_checkpoint.py index 5fe8b3c49..5cfdf10c7 100644 --- a/tests/test_utils/test_model_checkpoint.py +++ b/tests/test_utils/test_model_checkpoint.py @@ -201,7 +201,7 @@ def return_latest_save_path(save_ckpt_folder, total_step, snapshot_freq, ckpt_fr @pytest.mark.parametrize("step_info", step_info_list) @pytest.mark.parametrize("ckpt_config", ckpt_config_list) def test_ckpt_mm(step_info, ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-import - from internlm.core.context import global_context as gpc + from internlm.core.context.parallel_context import global_context as gpc from internlm.checkpoint.checkpoint_manager import CheckpointLoadMask ckpt_config = Config(ckpt_config) @@ -297,7 +297,7 @@ def test_ckpt_mm(step_info, ckpt_config, init_dist_and_model): # noqa # pylint: def query_quit_file(rank, world_size=2): - from internlm.core.context import global_context as gpc + from internlm.core.context.parallel_context import global_context as gpc from internlm.initialize import initialize_distributed_env from internlm.checkpoint.checkpoint_manager import CheckpointSaveType diff --git a/tools/load_internlm2_model.py b/tools/load_internlm2_model.py index 70900cad5..a470e4c0e 100644 --- a/tools/load_internlm2_model.py +++ b/tools/load_internlm2_model.py @@ -9,9 +9,9 @@ from internlm.apis.inference import SequenceGenerator from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import global_context as gpc from internlm.initialize.launch import initialize_distributed_env -from internlm.train import initialize_model, initialize_parallel_communicator +from internlm.train import initialize_model_and_parallel_communicator from internlm.utils.storage_manager import get_fns, init_storage_manager, llm_load from tools.interface import GenerationConfig @@ -185,7 +185,7 @@ def initialize_internlm_model( model_type=model_type, model=model_config, parallel=dict( - zero1=dict(size=1, fsdp=False), + zero1=dict(size=1), pipeline=dict(size=1, interleaved_overlap=True), tensor=dict(size=get_tp_world_size(), mode="mtp"), sequence_parallel=0, @@ -197,8 +197,7 @@ def initialize_internlm_model( args_check=False, ) # Directly get the origin model without NativeAMP wrapper. 
- model = initialize_model() - _ = initialize_parallel_communicator(model) + model, _ = initialize_model_and_parallel_communicator() model = model.model state_dict = merge_pp_within_tp(ckpt_dir, del_model_prefix=del_model_prefix) diff --git a/tools/moe_group_ckpt_converter.py b/tools/moe_group_ckpt_converter.py index d3fefb7c7..e07d6a273 100644 --- a/tools/moe_group_ckpt_converter.py +++ b/tools/moe_group_ckpt_converter.py @@ -8,7 +8,6 @@ from tqdm import tqdm sys.path.append(".") -import internlm # noqa: E402,F401 # pylint: disable=W0611,C0413 moe_str_prefix = None weight_key_suffix = ".weight" diff --git a/version.txt b/version.txt deleted file mode 100644 index be14282b7..000000000 --- a/version.txt +++ /dev/null @@ -1 +0,0 @@ -0.5.3
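For reference, a minimal config sketch (not part of this diff) of the fields read by the new is_using_fsdp() / wrap_FSDP_model() code path; the module path and class name under fsdp_wrap_cls are hypothetical placeholders for whatever transformer block a given model uses.

    parallel = dict(
        zero1=dict(size=-1),
        tensor=dict(size=1, mode="mtp"),
        pipeline=dict(size=1, interleaved_overlap=True),
        weight=dict(size=1, overlap=True),
        fsdp=dict(
            enable=True,         # checked by is_using_fsdp()
            mode="v2",           # "v1": torch FSDP wrapper; "v2": fully_shard (requires FSDP2_SUPPORTED)
            init_method="cuda",  # "cuda", "meta", or "cpu"; non-"cuda" enables sync_module_states in v1
        ),
    )

    # Each entry is resolved via LazyObject(mod, mod_cls).build() and used as an
    # auto-wrap target; the names below are illustrative assumptions only.
    fsdp_wrap_cls = [
        dict(mod="internlm.model.modeling_internlm2", mod_cls="InternLM2DecoderLayer"),
    ]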