diff --git a/.github/workflows/demo_in_readme.yaml b/.github/workflows/demo_in_readme.yaml
index a764a39f6..5a1fa3a85 100644
--- a/.github/workflows/demo_in_readme.yaml
+++ b/.github/workflows/demo_in_readme.yaml
@@ -63,6 +63,7 @@ jobs:
export GITHUB_WORKSPACE=$GITHUB_WORKSPACE
export SLURM_PARTITION=$SLURM_PARTITION
source activate ${evo_env_torch21_flash2}
+ export PYTHONPATH=$PWD:$PYTHONPATH
sh ./ci_scripts/train/slurm_train.sh ${GITHUB_RUN_ID}-${GITHUB_JOB}
EOF
@@ -97,6 +98,7 @@ jobs:
export GITHUB_WORKSPACE=$GITHUB_WORKSPACE
export SLURM_PARTITION=$SLURM_PARTITION
source activate ${evo_env_torch21_flash2}
+ export PYTHONPATH=$PWD:$PYTHONPATH
sh ./ci_scripts/train/torchrun.sh ${GITHUB_RUN_ID}-${GITHUB_JOB}
rm -rf $GITHUB_WORKSPACE/llm_ckpts
EOF
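The `export PYTHONPATH=$PWD:$PYTHONPATH` lines added above put the checked-out repo root on the module search path, so the `internlm` package (and the `internlm.launcher.launch` entry point the CI scripts now invoke) can be imported without installing the package. A quick illustrative check that can be run from the same working directory (this snippet is an assumption for illustration, not part of the CI scripts):

```python
# Illustrative only: with the repo root on PYTHONPATH, the internlm package
# and its launcher module should be resolvable.
import importlib.util

for name in ("internlm", "internlm.launcher.launch"):
    try:
        found = importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        found = False
    print(f"{name}: {'found' if found else 'NOT found'}")
```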
diff --git a/.github/workflows/lint_check.yaml b/.github/workflows/lint_check.yaml
index fe86bd05a..1d881cd2b 100644
--- a/.github/workflows/lint_check.yaml
+++ b/.github/workflows/lint_check.yaml
@@ -18,25 +18,21 @@ jobs:
run: |
pip install flake8==v3.8.4
FLAKE_DISABLE_LIST="F403,F405,W504,W503,E203"
- flake8 --max-line-length=120 --ignore=$FLAKE_DISABLE_LIST --exclude=./internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py ./internlm/*
- flake8 --max-line-length=120 --ignore=$FLAKE_DISABLE_LIST ./train.py
+ flake8 --max-line-length=120 --ignore=$FLAKE_DISABLE_LIST --exclude=./internlm/model/model_ops/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py ./internlm/*
- name: lint-isort
run: |
pip install isort==5.12.0
isort --check --profile=black ./internlm/*
- isort --check --profile=black ./train.py
- name: lint-black
run: |
pip install black==22.8.0
BLACK_EXCLUDE_SETTINGS='\.venv/|\.local/|\.cache/|\.git/'
black --line-length=120 --check --exclude $BLACK_EXCLUDE_SETTINGS ./internlm/*
- black --line-length=120 --check --exclude $BLACK_EXCLUDE_SETTINGS ./train.py
- name: lint-pylint
run: |
pip install pylint==v2.17.2
PYLINT_DISABLE_LIST="C0114,C0415,W0212,W0235,W0238,W0621,C0103,R1735,C2801,E0402,C0412,W0719,R1728,W1514,W0718,W0105,W0707,C0209,W0703,W1203"
- pylint --rcfile .pylintrc --disable=$PYLINT_DISABLE_LIST --ignore=./internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py ./internlm/*
- pylint --rcfile .pylintrc --disable=$PYLINT_DISABLE_LIST ./train.py
+ pylint --rcfile .pylintrc --disable=$PYLINT_DISABLE_LIST --ignore=./internlm/model/model_ops/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py ./internlm/*
diff --git a/README-ja-JP.md b/README-ja-JP.md
index bb4c9c201..18db395f3 100644
--- a/README-ja-JP.md
+++ b/README-ja-JP.md
@@ -99,7 +99,7 @@ data = dict(
Slurm環境で2ノード16カードを使用する場合、コマンドは以下の通りです:
```bash
-$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./configs/7B_sft.py
+$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python -m internlm.launcher.launch --config ./configs/7B_sft.py
```
torchを使用し、1ノード8カードで実行する場合、コマンドは以下の通りです:
@@ -166,8 +166,8 @@ $ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py -
diff --git a/README-zh-Hans.md b/README-zh-Hans.md
index 98a9caab0..6a5503077 100644
--- a/README-zh-Hans.md
+++ b/README-zh-Hans.md
@@ -99,7 +99,7 @@ data = dict(
slurm环境,双机16卡,启动训练命令如下:
```bash
-$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./configs/7B_sft.py
+$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python -m internlm.launcher.launch --config ./configs/7B_sft.py
```
torch环境,单机8卡,启动训练命令如下:
@@ -166,8 +166,8 @@ $ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py -
|
diff --git a/README.md b/README.md
index 8a9b96612..90c700bcd 100644
--- a/README.md
+++ b/README.md
@@ -99,7 +99,7 @@ Training can be started on slurm or torch distributed environment.
On slurm, using 2 nodes and 16 cards, the command is as follows:
```bash
-$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./configs/7B_sft.py
+$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python -m internlm.launcher.launch --config ./configs/7B_sft.py
```
On torch, using 1 node and 8 cards, the command is as follows:
@@ -166,8 +166,8 @@ Please refer to the [System Architecture document](./doc/en/structure.md) for ar
|
diff --git a/ci_scripts/model/convert_to_hf.sh b/ci_scripts/model/convert_to_hf.sh
index 3bf381c74..c0280be5d 100644
--- a/ci_scripts/model/convert_to_hf.sh
+++ b/ci_scripts/model/convert_to_hf.sh
@@ -25,7 +25,7 @@ if [[ -d ${CKPTS_OUTPUT} ]]; then
fi
fi
-python ./transformers/convert2hf_internlm.py --src ${CKPTS_INPUT} --tgt ${CKPTS_OUTPUT} --tokenizer ./tools/tokenizer_internlm.model
+python ./huggingface_models/convert2hf_internlm.py --src ${CKPTS_INPUT} --tgt ${CKPTS_OUTPUT} --tokenizer ./tools/tokenizer_internlm.model
[[ $? -ne 0 ]] && { echo "test convert2hf_internlm.py failed."; exit_code=$(($exit_code + 1)); }
# assert that the model exists
diff --git a/ci_scripts/train/ci_7B_sft.py b/ci_scripts/train/ci_7B_sft.py
index fea45e124..591faf36c 100644
--- a/ci_scripts/train/ci_7B_sft.py
+++ b/ci_scripts/train/ci_7B_sft.py
@@ -101,14 +101,12 @@
model = dict(
checkpoint=False,
num_attention_heads=NUM_ATTENTION_HEAD,
- embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
diff --git a/ci_scripts/train/generate_config.py b/ci_scripts/train/generate_config.py
index 096334d06..a2a0aaf0d 100644
--- a/ci_scripts/train/generate_config.py
+++ b/ci_scripts/train/generate_config.py
@@ -5,7 +5,7 @@
import os
from ci_scripts.common import com_func
-from internlm.core.context import Config
+from internlm.utils.config import Config
def generate_new_config(config_py_file, test_config_json, case_name):
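The `Config` class is now imported from `internlm.utils.config` instead of `internlm.core.context`; it is still used as a dict-like configuration container. A hedged compatibility sketch (the fallback import and the sample values are illustrations, not part of this patch):

```python
# Illustrative shim, assuming Config keeps its dict-like behaviour in both locations.
try:
    from internlm.utils.config import Config  # new location used by this patch
except ImportError:
    from internlm.core.context import Config  # old location

cfg = Config({"model": {"num_layers": 32, "hidden_size": 4096}})
print(cfg["model"]["num_layers"])  # nested dicts stay accessible by key
```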
diff --git a/ci_scripts/train/load_ckpt.sh b/ci_scripts/train/load_ckpt.sh
index 3b447bcf1..b9119cc69 100644
--- a/ci_scripts/train/load_ckpt.sh
+++ b/ci_scripts/train/load_ckpt.sh
@@ -22,7 +22,7 @@ if [[ ! -f ${file} ]]; then
exit_code=$(($exit_code + 1))
fi
-srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --exclusive --job-name=$2 -n 8 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ${file}
+srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --exclusive --job-name=$2 -n 8 --ntasks-per-node=8 --gpus-per-task=1 python internlm/launcher/launch.py --config ${file}
[[ $? -ne 0 ]] && { echo "test slurm training failed."; exit_code=$(($exit_code + 1)); }
diff --git a/ci_scripts/train/slurm_train.sh b/ci_scripts/train/slurm_train.sh
index ca5e840b9..5da69e19f 100644
--- a/ci_scripts/train/slurm_train.sh
+++ b/ci_scripts/train/slurm_train.sh
@@ -22,7 +22,7 @@ if [[ -d ${CKPTS20_PATH} ]]; then
fi
fi
-srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --exclusive --job-name=$1 -n 8 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./ci_scripts/train/ci_7B_sft.py
+srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --exclusive --job-name=$1 -n 8 --ntasks-per-node=8 --gpus-per-task=1 python internlm/launcher/launch.py --config ./ci_scripts/train/ci_7B_sft.py
[[ $? -ne 0 ]] && { echo "test slurm training failed."; exit_code=$(($exit_code + 1)); }
num=$(num_files "${CKPTS20_OUTPUT}")
diff --git a/ci_scripts/train/torchrun.sh b/ci_scripts/train/torchrun.sh
index 27c815725..4f6ba33bf 100644
--- a/ci_scripts/train/torchrun.sh
+++ b/ci_scripts/train/torchrun.sh
@@ -22,7 +22,7 @@ if [[ -d ${CKPTS20_PATH} ]]; then
fi
fi
-srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --exclusive --job-name=$1 -N 1 torchrun --nnodes=1 --nproc_per_node=8 --master_port=29501 train.py --config ./ci_scripts/train/ci_7B_sft.py --launcher torch
+srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --exclusive --job-name=$1 -N 1 torchrun --nnodes=1 --nproc_per_node=8 --master_port=29501 internlm/launcher/launch.py --config ./ci_scripts/train/ci_7B_sft.py --launcher torch
[[ $? -ne 0 ]] && { echo "test torch training failed."; exit_code=$(($exit_code + 1)); }
num=$(num_files "${CKPTS_OUTPUT}")
diff --git a/configs/1.8B_MoE16_sft.py b/configs/1.8B_MoE16_sft.py
index eca10b045..a8a58dc6f 100644
--- a/configs/1.8B_MoE16_sft.py
+++ b/configs/1.8B_MoE16_sft.py
@@ -136,14 +136,12 @@
model = dict(
checkpoint=False, # The proportion of layers for activation checkpointing; the optional values are True/False/[0-1]
num_attention_heads=NUM_ATTENTION_HEAD,
- embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=False,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
diff --git a/configs/57B_qwen2_MoE.py b/configs/57B_qwen2_MoE.py
deleted file mode 100644
index 27f63cc1d..000000000
--- a/configs/57B_qwen2_MoE.py
+++ /dev/null
@@ -1,226 +0,0 @@
-JOB_NAME = "57b_qwen2_moe"
-model_type = "QWEN2MOE"
-DO_ALERT = False
-
-SEQ_LEN = 4096
-HIDDEN_SIZE = 3584
-NUM_ATTENTION_HEAD = 28
-NUM_KV_ATTENTION_HEAD = 4
-MLP_RATIO = 5 / 7
-NUM_LAYER = 28
-VOCAB_SIZE = 151936
-
-MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
-# Ckpt folder format:
-# fs: 'local:/mnt/nfs/XXX'
-SAVE_CKPT_FOLDER = "local:llm_ckpts"
-LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
-
-# boto3 Ckpt folder format:
-# import os
-# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
-# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
-# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
-CHECKPOINT_EVERY = 50
-ckpt = dict(
- enable_save_ckpt=False, # enable ckpt save.
- save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
- # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"),
- load_ckpt_folder="local:llm_ckpts/",
- # 'load_ckpt_info' setting guide:
- # 1. the 'path' indicate ckpt path,
- # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
- # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined
- # load function such as "llama"
- load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"),
- # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
- # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
- # with an automatic restart mechanism upon training reboot.
- # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint
- # path specified in `load_ckpt_info` by default.
- # If you want to initialize your model weights from another model, you must set `auto_resume` to False.
- # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None.
- auto_resume=True,
- checkpoint_every=CHECKPOINT_EVERY,
- async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
- async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload.
- oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
-)
-
-TRAIN_FOLDER = None # "/path/to/dataset"
-VALID_FOLDER = None # "/path/to/dataset"
-data = dict(
- seq_len=SEQ_LEN,
- # micro_num means the number of micro_batch contained in one gradient update
- micro_num=4,
- # packed_length = micro_bsz * SEQ_LEN
- micro_bsz=2,
- # defaults to the value of micro_num
- valid_micro_num=4,
- # defaults to 0, means disable evaluate
- valid_every=50,
- pack_sample_into_one=False,
- total_steps=50000,
- skip_batches="",
- # rampup_batch_size (str): A string with three space-separated integers representing the
- # starting batch size, the increment, and the number of steps between
- # each increment. For example, "192 24 8" means that the batch size (micro_num)
- # starts at 192 and increases by 24 every 8 steps. Defaults to None.
- # (IMPORTANT): The interval step size is 'micro_bsz'.
- rampup_batch_size="",
- # Datasets with less than 50 rows will be discarded
- min_length=50,
- train_folder=TRAIN_FOLDER,
- valid_folder=VALID_FOLDER,
- empty_cache_and_diag_interval=200,
- diag_outlier_ratio=1.1,
-)
-
-grad_scaler = dict(
- fp16=dict(
- # the initial loss scale, defaults to 2**16
- initial_scale=2**16,
- # the minimum loss scale, defaults to None
- min_scale=1,
- # the number of steps to increase loss scale when no overflow occurs
- growth_interval=1000,
- ),
- # the multiplication factor for increasing loss scale, defaults to 2
- growth_factor=2,
- # the multiplication factor for decreasing loss scale, defaults to 0.5
- backoff_factor=0.5,
- # the maximum loss scale, defaults to None
- max_scale=2**24,
- # the number of overflows before decreasing loss scale, defaults to 2
- hysteresis=2,
-)
-
-hybrid_zero_optimizer = dict(
- # Enable low_level_optimzer overlap_communication
- overlap_sync_grad=False,
- overlap_sync_param=False,
- # bucket size for nccl communication params
- reduce_bucket_size=512 * 1024 * 1024,
- # grad clipping
- clip_grad_norm=1.0,
-)
-
-loss = dict(
- label_smoothing=0,
- moe_loss_coeff=0.001,
-)
-
-adam = dict(
- lr=1e-4,
- adam_beta1=0.9,
- adam_beta2=0.95,
- adam_beta2_c=0,
- adam_eps=1e-8,
- weight_decay=0.01,
-)
-
-lr_scheduler = dict(
- total_steps=data["total_steps"],
- init_steps=0, # optimizer_warmup_step
- warmup_ratio=0.01,
- eta_min=1e-5,
- last_epoch=-1,
-)
-
-beta2_scheduler = dict(
- init_beta2=adam["adam_beta2"],
- c=adam["adam_beta2_c"],
- cur_iter=-1,
-)
-
-use_fp32_norm = False
-model = dict(
- checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1]
- num_attention_heads=NUM_ATTENTION_HEAD,
- num_kv_attention_heads=NUM_KV_ATTENTION_HEAD,
- max_position_embeddings=131072,
- embed_split_hidden=True,
- vocab_size=VOCAB_SIZE,
- embed_grad_scale=1,
- parallel_output=True,
- hidden_size=HIDDEN_SIZE,
- num_layers=NUM_LAYER,
- mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
- dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
- norm_type="rmsnorm",
- layer_norm_epsilon=1e-6,
- use_flash_attn=True,
- # Whether the odd and even columns of the query and key in the model are normally interleaved.
- # If it's True, the model's odd and even columns are normally ordered; if it's False,
- # it means that the model has prematurely concatenated all odd columns and even columns in front
- # and back, in order to improve the RoPE's computational efficiency.
- # Example:
- # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...]
- # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
- qk_interleaved=False,
- use_sliding_window=False,
- rope_base=1000000,
- num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
- moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-D", "Dropless"
- num_experts=64,
- num_shared_experts=8,
- top_k=8,
-)
-"""
-zero1 parallel (dict):
- 1. size: int
- * if size <= 0, the size of the zero process group is equal to the size of the dp process group,
- so parameters will be divided within the range of dp.
- * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
- * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
- For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-tensor parallel (dict):
- 1. size: int, the size of tensor parallel.
- 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],
- defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel.
- msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size.
- fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size.
- isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel.
-pipeline parallel (dict):
- 1. size: int, the size of pipeline parallel.
- 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
- defaults to False.
-weight parallel (dict):
- 1. size: int, the size of weight parallel.
- 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
-expert parallel (dict):
- 1. size: int
- * if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size
- to be the number of experts to make sure each device has one expert.
- * if size == 1, all experts are placed in each device, running as dp-only.
- * if size > 1, all experts are placed in k devices and each device has n/k experts, where n is the total
- number of experts and k = size.
-expert weight parallel (dict):
- 1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size.
- 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
-"""
-parallel = dict(
- zero1=dict(size=-1),
- tensor=dict(size=1, mode="mtp"),
- pipeline=dict(size=1, interleaved_overlap=True),
- weight=dict(size=1, overlap=True),
- expert=dict(size=-1, no_tp=False),
- expert_weight=dict(size=1, overlap=True),
-)
-
-cudnn_deterministic = False
-cudnn_benchmark = False
-
-monitor = dict(
- # feishu alert configs
- alert=dict(
- enable_feishu_alert=DO_ALERT,
- feishu_alert_address=None, # feishu webhook to send alert message
- light_monitor_address=None, # light_monitor address to send heartbeat
- alert_file_path=f"llm_alter/{JOB_NAME}_alert.log",
- ),
- tensorboard=dict(
- queue_max_length=10,
- ),
-)
\ No newline at end of file
diff --git a/configs/7B_MoE4_sft.py b/configs/7B_MoE4_sft.py
index 74ebbcbb6..4b494d9f5 100644
--- a/configs/7B_MoE4_sft.py
+++ b/configs/7B_MoE4_sft.py
@@ -149,14 +149,12 @@
model = dict(
checkpoint=False, # The proportion of layers for activation checkpointing; the optional values are True/False/[0-1]
num_attention_heads=NUM_ATTENTION_HEAD,
- embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
diff --git a/configs/7B_gemma.py b/configs/7B_gemma.py
deleted file mode 100644
index 643bcbdbf..000000000
--- a/configs/7B_gemma.py
+++ /dev/null
@@ -1,230 +0,0 @@
-JOB_NAME = "7b_gemma_train"
-model_type = "GEMMA"
-DO_ALERT = False
-
-VOCAB_SIZE = 256000
-SEQ_LEN = 2048
-HIDDEN_SIZE = 3072
-NUM_ATTENTION_HEAD = 16
-NUM_KV_ATTENTION_HEAD = 16
-HEAD_DIM = 256
-MLP_RATIO = 8
-NUM_LAYER = 28
-
-
-MODEL_ONLY_FOLDER = "local:llm_ckpts_gemma/xxxx"
-# Ckpt folder format:
-# fs: 'local:/mnt/nfs/XXX'
-SAVE_CKPT_FOLDER = "local:llm_ckpts_gemma"
-
-# boto3 Ckpt folder format:
-# import os
-# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
-# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
-CHECKPOINT_EVERY = 50
-ckpt = dict(
- enable_save_ckpt=False, # enable ckpt save.
- enable_internevo2hf_ckpt=False, # enable ckpt save for huggingface format.
- save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
- # 'load_ckpt_info' setting guide:
- # 1. the 'path' indicate ckpt path,
- # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
- # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined
- # load function such as "llama"
- load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="hf"),
- # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
- # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
- # with an automatic restart mechanism upon training reboot.
- # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint
- # path specified in `load_ckpt_info` by default.
- # If you want to initialize your model weights from another model, you must set `auto_resume` to False.
- # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None.
- auto_resume=False,
- checkpoint_every=CHECKPOINT_EVERY,
- async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
- async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload.
- oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
-)
-
-TRAIN_FOLDER = None
-VALID_FOLDER = None # "/path/to/dataset"
-data = dict(
- seq_len=SEQ_LEN,
- # micro_num means the number of micro_batch contained in one gradient update
- micro_num=4,
- # packed_length = micro_bsz * SEQ_LEN
- micro_bsz=1,
- # defaults to the value of micro_num
- valid_micro_num=4,
- # defaults to 0, means disable evaluate
- valid_every=0,
- pack_sample_into_one=False,
- total_steps=20,
- skip_batches="",
- # rampup_batch_size (str): A string with three space-separated integers representing the
- # starting batch size, the increment, and the number of steps between
- # each increment. For example, "192 24 8" means that the batch size (micro_num)
- # starts at 192 and increases by 24 every 8 steps. Defaults to None.
- # (IMPORTANT): The interval step size is 'micro_bsz'.
- rampup_batch_size="",
- # Datasets with less than 50 rows will be discarded
- min_length=50,
- train_folder=TRAIN_FOLDER,
- valid_folder=VALID_FOLDER,
- empty_cache_and_diag_interval=200,
- diag_outlier_ratio=1.1,
-)
-
-grad_scaler = dict(
- fp16=dict(
- # the initial loss scale, defaults to 2**16
- initial_scale=2**16,
- # the minimum loss scale, defaults to None
- min_scale=1,
- # the number of steps to increase loss scale when no overflow occurs
- growth_interval=1000,
- ),
- # the multiplication factor for increasing loss scale, defaults to 2
- growth_factor=2,
- # the multiplication factor for decreasing loss scale, defaults to 0.5
- backoff_factor=0.5,
- # the maximum loss scale, defaults to None
- max_scale=2**24,
- # the number of overflows before decreasing loss scale, defaults to 2
- hysteresis=2,
-)
-
-hybrid_zero_optimizer = dict(
- # Enable low_level_optimzer overlap_communication
- overlap_sync_grad=True,
- overlap_sync_param=False,
- # bucket size for nccl communication params
- reduce_bucket_size=512 * 1024 * 1024,
- # grad clipping
- clip_grad_norm=1.0,
-)
-
-loss = dict(
- label_smoothing=0,
-)
-
-adam = dict(
- lr=1e-4,
- adam_beta1=0.9,
- adam_beta2=0.95,
- adam_beta2_c=0,
- adam_eps=1e-8,
- weight_decay=0.01,
-)
-
-lr_scheduler = dict(
- total_steps=data["total_steps"],
- init_steps=0, # optimizer_warmup_step
- warmup_ratio=0.01,
- eta_min=1e-5,
- last_epoch=-1,
-)
-
-beta2_scheduler = dict(
- init_beta2=adam["adam_beta2"],
- c=adam["adam_beta2_c"],
- cur_iter=-1,
-)
-
-use_fp32_norm = False
-model = dict(
- checkpoint=False,
- num_chunks=1,
- num_attention_heads=NUM_ATTENTION_HEAD,
- num_kv_attention_heads=NUM_KV_ATTENTION_HEAD,
- max_position_embeddings=8192,
- embed_split_hidden=True,
- vocab_size=VOCAB_SIZE,
- embed_grad_scale=1,
- parallel_output=True,
- hidden_size=HIDDEN_SIZE,
- num_layers=NUM_LAYER,
- no_bias=True,
- mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
- dtype="torch.bfloat16",
- add_unit_offset=True,
- norm_type="rmsnorm",
- layer_norm_epsilon=1e-6,
- head_dim=HEAD_DIM,
- use_flash_attn=True,
- # Whether the odd and even columns of the query and key in the model are normally interleaved.
- # If it's True, the model's odd and even columns are normally ordered; if it's False,
- # it means that the model has prematurely concatenated all odd columns and even columns in front
- # and back, in order to improve the RoPE's computational efficiency.
- # Example:
- # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...]
- # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
- qk_interleaved=False,
- use_swiglu=False,
-)
-
-"""
-zero1 parallel (dict):
- 1. size: int
- * if size <= 0, the size of the zero process group is equal to the size of the dp process group,
- so parameters will be divided within the range of dp.
- * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
- * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
- For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-tensor parallel (dict):
- 1. size: int, the size of tensor parallel.
- 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],
- defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel.
- msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size.
- fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size.
- isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel.
-pipeline parallel (dict):
- 1. size: int, the size of pipeline parallel.
- 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
- defaults to False.
-weight parallel (dict):
- 1. size: int, the size of weight parallel.
- 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
-"""
-parallel = dict(
- zero1=dict(size=-1),
- tensor=dict(size=1, mode="mtp"),
- pipeline=dict(size=1, interleaved_overlap=True),
- weight=dict(size=1, overlap=True),
-)
-
-cudnn_deterministic = False
-cudnn_benchmark = False
-
-monitor = dict(
- # feishu alert configs
- alert=dict(
- enable_feishu_alert=DO_ALERT,
- feishu_alert_address=None, # feishu webhook to send alert message
- light_monitor_address=None, # light_monitor address to send heartbeat
- alert_file_path=f"llm_alter/{JOB_NAME}_alert.log",
- ),
- tensorboard=dict(
- queue_max_length=10,
- ),
-)
-
-# metric_dtype can be "fp32" or other string
-# only when set to "fp32" will use fp32 to calc in metrics
-# metric_dtype = "fp32"
-
-generation = dict(
- ckpt_folder="/path/to/saved/ckpt",
- output_folder="/path/to/save/generation",
- batch_size=1,
- eos_id=[2, 0],
- bos_id=1,
- max_length=100,
- do_sample=True,
- temperature=1.0,
- top_k=50,
- top_p=1.0,
- repetition_penalty=1,
- length_penalty=1.0,
-)
diff --git a/configs/7B_internlm2.py b/configs/7B_internlm2.py
index 3c7bb9f4f..2126d7470 100644
--- a/configs/7B_internlm2.py
+++ b/configs/7B_internlm2.py
@@ -142,7 +142,6 @@
checkpoint=False,
num_chunks=1,
num_attention_heads=NUM_ATTENTION_HEAD,
- embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
@@ -150,7 +149,6 @@
num_layers=NUM_LAYER,
no_bias=True,
mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
diff --git a/configs/7B_baichuan2.py b/configs/7B_internlm2_hf.py
similarity index 70%
rename from configs/7B_baichuan2.py
rename to configs/7B_internlm2_hf.py
index 9957d6819..4e6ed9042 100644
--- a/configs/7B_baichuan2.py
+++ b/configs/7B_internlm2_hf.py
@@ -1,35 +1,34 @@
-JOB_NAME = "7b_baichuan2_train"
-model_type = "BAICHUAN2"
+JOB_NAME = "7b_internlm2_train"
DO_ALERT = False
-VOCAB_SIZE = 125696
+
+VOCAB_SIZE = 92544
SEQ_LEN = 2048
-HIDDEN_SIZE = 4096
-NUM_ATTENTION_HEAD = 32
-MLP_RATIO = 8 / 3
-NUM_LAYER = 32
-MODEL_ONLY_FOLDER = "local:llm_ckpts_baichuan2/xxxx"
+MODEL_ONLY_FOLDER = None
# Ckpt folder format:
# fs: 'local:/mnt/nfs/XXX'
-SAVE_CKPT_FOLDER = "local:llm_ckpts_baichuan2"
+SAVE_CKPT_FOLDER = "local:llm_ckpts"
+LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
# boto3 Ckpt folder format:
# import os
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
+# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
CHECKPOINT_EVERY = 50
ckpt = dict(
enable_save_ckpt=False, # enable ckpt save.
- enable_internevo2hf_ckpt=False, # enable ckpt save for huggingface format.
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
+ # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"),
+ load_ckpt_folder="local:llm_ckpts/",
# 'load_ckpt_info' setting guide:
# 1. the 'path' indicates the ckpt path,
# 2. the 'content' means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
# 3. the 'ckpt_type' means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined
# load function such as "llama"
- load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="hf"),
+ load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"),
# 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
# training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
# with an automatic restart mechanism upon training reboot.
@@ -37,7 +36,7 @@
# path specified in `load_ckpt_info` by default.
# If you want to initialize your model weights from another model, you must set `auto_resume` to False.
# If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None.
- auto_resume=False,
+ auto_resume=True,
checkpoint_every=CHECKPOINT_EVERY,
async_upload=True, # async ckpt upload. (only works for boto3 ckpt)
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
@@ -57,7 +56,7 @@
# defaults to 0, means disable evaluate
valid_every=0,
pack_sample_into_one=False,
- total_steps=20,
+ total_steps=20000,
skip_batches="",
# rampup_batch_size (str): A string with three space-separated integers representing the
# starting batch size, the increment, and the number of steps between
@@ -102,9 +101,21 @@
clip_grad_norm=1.0,
)
-loss = dict(
- label_smoothing=0,
-)
+
+# loss config (dict):
+# 1. label_smoothing
+# 2. op_type: cross_entropy operator type, we support five types for loss computing,
+# including ["torch_naive", "apex_naive", "py_naive", "flash_vocab_parallel", "py_vocab_parallel"]
+# default is "py_vocab_parallel".
+# "torch_naive": cross_entropy imported from torch, i.e. torch.nn.CrossEntropyLoss
+# "apex_naive": cross_entropy from apex
+# "py_naive": self-implemented cross_entropy
+# "flash_vocab_parallel": vocab parallel cross_entropy imported from flash_attn
+# "py_vocab_parallel": self-implemented vocab parallel cross_entropy
+
+# * op_types that end with "naive" only support parallel_output=False;
+# * in a no-GPU env, only "torch_naive" and "py_vocab_parallel" are supported.
+loss = dict(label_smoothing=0, op_type="py_vocab_parallel")
adam = dict(
lr=1e-4,
@@ -130,33 +141,51 @@
)
use_fp32_norm = False
+
model = dict(
- checkpoint=False,
- num_chunks=1,
- num_attention_heads=NUM_ATTENTION_HEAD,
- embed_split_hidden=True,
- vocab_size=VOCAB_SIZE,
- embed_grad_scale=1,
- parallel_output=True,
- hidden_size=HIDDEN_SIZE,
- num_layers=NUM_LAYER,
- no_bias=True,
- mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
dtype="torch.bfloat16",
- norm_type="rmsnorm",
- layer_norm_epsilon=1e-6,
- use_flash_attn=True,
- # Whether the odd and even columns of the query and key in the model are normally interleaved.
- # If it's True, the model's odd and even columns are normally ordered; if it's False,
- # it means that the model has prematurely concatenated all odd columns and even columns in front
- # and back, in order to improve the RoPE's computational efficiency.
- # Example:
- # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...]
- # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
- qk_interleaved=False,
+ checkpoint=0,
+ parallel_output=True,
)
+hf = dict(
+ cfg="huggingface_models.internlm2_model.configuration_internlm2",
+ cfg_cls="InternLM2Config",
+ cfg_extra_kwargs=dict(
+ vocab_size=VOCAB_SIZE,
+ hidden_size=4096,
+ intermediate_size=14336,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=8,
+ hidden_act="silu",
+ max_position_embeddings=4096,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=False,
+ pad_token_id=None, # We don't actually use pad_token_id in this framework
+ # bos_token_id=1,
+ # eos_token_id=2,
+ # pretraining_tp=1,
+ tie_word_embeddings=False,
+ bias=False,
+ rope_theta=1000000,
+ rope_scaling=None,
+ attn_implementation="flash_attention_2",
+ dtype=model["dtype"],
+ return_dict=False,
+ ),
+ mod="huggingface_models.internlm2_model.modeling_internlm2",
+ mod_cls="InternLM2ForCausalLM",
+)
+
+fsdp_wrap_cls = [
+ dict(
+ mod=hf["mod"],
+ mod_cls="InternLM2DecoderLayer",
+ ),
+]
+
"""
zero1 parallel (dict):
1. size: int
@@ -176,14 +205,16 @@
1. size: int, the size of pipeline parallel.
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
defaults to False.
+ 3. mode: str, the pipeline parallel mode, should be in ['1f1b', 'zbh1', 'zbv']. The default is '1f1b'.
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
"""
parallel = dict(
- zero1=dict(size=-1),
+ fsdp=dict(enable=True, mode="v1", init_method="meta"),
+ zero1=dict(size=1),
tensor=dict(size=1, mode="mtp"),
- pipeline=dict(size=1, interleaved_overlap=True),
+ pipeline=dict(size=1, interleaved_overlap=True, mode="1f1b"),
weight=dict(size=1, overlap=True),
)
@@ -221,3 +252,11 @@
repetition_penalty=1,
length_penalty=1.0,
)
+
+
+# fp8 = dict(
+# margin=0,
+# fp8_format="HYBRID",
+# amax_history_len=1024,
+# amax_compute_algo="max",
+# )
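The new `hf` block in `7B_internlm2_hf.py` names the HuggingFace-style config and model classes by dotted module path (`cfg`/`cfg_cls` and `mod`/`mod_cls`), with `cfg_extra_kwargs` forwarded to the config constructor. A rough sketch of how such a block can be resolved with `importlib`; this illustrates the pattern only and is not InternEvo's actual loader:

```python
import importlib


def resolve_hf_block(hf: dict):
    """Illustrative resolver for an `hf`-style block (not the framework's own loader)."""
    cfg_cls = getattr(importlib.import_module(hf["cfg"]), hf["cfg_cls"])
    model_cls = getattr(importlib.import_module(hf["mod"]), hf["mod_cls"])
    config = cfg_cls(**hf.get("cfg_extra_kwargs", {}))
    return config, model_cls


# Usage sketch (requires the repo's huggingface_models package to be importable):
# config, model_cls = resolve_hf_block(hf)
# model = model_cls(config)
```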
diff --git a/configs/7B_isp_sft.py b/configs/7B_isp_sft.py
index f269ab4e2..746dbafcc 100644
--- a/configs/7B_isp_sft.py
+++ b/configs/7B_isp_sft.py
@@ -163,7 +163,6 @@
checkpoint=False, # The proportion of layers for activation checkpointing; the optional values are True/False/[0-1]
num_attention_heads=NUM_ATTENTION_HEAD,
num_kv_attention_heads=NUM_KV_ATTENTION_HEAD,
- embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
@@ -171,7 +170,6 @@
num_layers=NUM_LAYER,
no_bias=True,
mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
diff --git a/configs/7B_llama2.py b/configs/7B_llama2.py
index 7783abaf7..0161d78e0 100644
--- a/configs/7B_llama2.py
+++ b/configs/7B_llama2.py
@@ -130,7 +130,6 @@
checkpoint=False,
num_chunks=1,
num_attention_heads=NUM_ATTENTION_HEAD,
- embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
@@ -138,7 +137,6 @@
num_layers=NUM_LAYER,
no_bias=True,
mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
@@ -152,7 +150,7 @@
# qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...]
# qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
qk_interleaved=False,
- mlp_layer_fusion=True,
+ mlp_layer_fusion=False,
enable_qkv_fusion=True,
)
diff --git a/configs/7B_qwen2.py b/configs/7B_qwen2.py
deleted file mode 100644
index 3622e12f1..000000000
--- a/configs/7B_qwen2.py
+++ /dev/null
@@ -1,230 +0,0 @@
-JOB_NAME = "7b_qwen2_train"
-model_type = "QWEN2"
-DO_ALERT = False
-
-VOCAB_SIZE = 152064
-SEQ_LEN = 2048
-HIDDEN_SIZE = 3584
-NUM_ATTENTION_HEAD = 28
-NUM_KV_ATTENTION_HEAD = 4
-MLP_RATIO = 5.25
-NUM_LAYER = 28
-
-
-MODEL_ONLY_FOLDER = "local:llm_ckpts_qwen2/xxxx/"
-# Ckpt folder format:
-# fs: 'local:/mnt/nfs/XXX'
-SAVE_CKPT_FOLDER = "local:llm_ckpts_qwen2"
-
-# boto3 Ckpt folder format:
-# import os
-# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
-# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
-CHECKPOINT_EVERY = 50
-ckpt = dict(
- enable_save_ckpt=False, # enable ckpt save.
- enable_internevo2hf_ckpt=False, # enable ckpt save for huggingface format.
- save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
- # 'load_ckpt_info' setting guide:
- # 1. the 'path' indicate ckpt path,
- # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
- # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined
- # load function such as "llama"
- load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="hf"),
- # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
- # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
- # with an automatic restart mechanism upon training reboot.
- # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint
- # path specified in `load_ckpt_info` by default.
- # If you want to initialize your model weights from another model, you must set `auto_resume` to False.
- # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None.
- auto_resume=False,
- checkpoint_every=CHECKPOINT_EVERY,
- async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
- async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload.
- oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
-)
-
-TRAIN_FOLDER = None
-VALID_FOLDER = None # "/path/to/dataset"
-data = dict(
- seq_len=SEQ_LEN,
- # micro_num means the number of micro_batch contained in one gradient update
- micro_num=4,
- # packed_length = micro_bsz * SEQ_LEN
- micro_bsz=1,
- # defaults to the value of micro_num
- valid_micro_num=4,
- # defaults to 0, means disable evaluate
- valid_every=0,
- pack_sample_into_one=False,
- total_steps=20,
- skip_batches="",
- # rampup_batch_size (str): A string with three space-separated integers representing the
- # starting batch size, the increment, and the number of steps between
- # each increment. For example, "192 24 8" means that the batch size (micro_num)
- # starts at 192 and increases by 24 every 8 steps. Defaults to None.
- # (IMPORTANT): The interval step size is 'micro_bsz'.
- rampup_batch_size="",
- # Datasets with less than 50 rows will be discarded
- min_length=50,
- train_folder=TRAIN_FOLDER,
- valid_folder=VALID_FOLDER,
- empty_cache_and_diag_interval=200,
- diag_outlier_ratio=1.1,
-)
-
-grad_scaler = dict(
- fp16=dict(
- # the initial loss scale, defaults to 2**16
- initial_scale=2**16,
- # the minimum loss scale, defaults to None
- min_scale=1,
- # the number of steps to increase loss scale when no overflow occurs
- growth_interval=1000,
- ),
- # the multiplication factor for increasing loss scale, defaults to 2
- growth_factor=2,
- # the multiplication factor for decreasing loss scale, defaults to 0.5
- backoff_factor=0.5,
- # the maximum loss scale, defaults to None
- max_scale=2**24,
- # the number of overflows before decreasing loss scale, defaults to 2
- hysteresis=2,
-)
-
-hybrid_zero_optimizer = dict(
- # Enable low_level_optimzer overlap_communication
- overlap_sync_grad=True,
- overlap_sync_param=False,
- # bucket size for nccl communication params
- reduce_bucket_size=512 * 1024 * 1024,
- # grad clipping
- clip_grad_norm=1.0,
-)
-
-loss = dict(
- label_smoothing=0,
-)
-
-adam = dict(
- lr=1e-4,
- adam_beta1=0.9,
- adam_beta2=0.95,
- adam_beta2_c=0,
- adam_eps=1e-8,
- weight_decay=0.01,
-)
-
-lr_scheduler = dict(
- total_steps=data["total_steps"],
- init_steps=0, # optimizer_warmup_step
- warmup_ratio=0.01,
- eta_min=1e-5,
- last_epoch=-1,
-)
-
-beta2_scheduler = dict(
- init_beta2=adam["adam_beta2"],
- c=adam["adam_beta2_c"],
- cur_iter=-1,
-)
-
-use_fp32_norm = False
-model = dict(
- checkpoint=False,
- num_chunks=1,
- num_attention_heads=NUM_ATTENTION_HEAD,
- num_kv_attention_heads=NUM_KV_ATTENTION_HEAD,
- embed_split_hidden=True,
- vocab_size=VOCAB_SIZE,
- embed_grad_scale=1,
- parallel_output=True,
- hidden_size=HIDDEN_SIZE,
- num_layers=NUM_LAYER,
- qkv_bias=True,
- o_bias=False,
- mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
- dtype="torch.bfloat16",
- norm_type="rmsnorm",
- layer_norm_epsilon=1e-6,
- use_flash_attn=True,
- # Whether the odd and even columns of the query and key in the model are normally interleaved.
- # If it's True, the model's odd and even columns are normally ordered; if it's False,
- # it means that the model has prematurely concatenated all odd columns and even columns in front
- # and back, in order to improve the RoPE's computational efficiency.
- # Example:
- # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...]
- # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
- qk_interleaved=False,
- rope_base=1000000,
- use_sliding_window=False,
- sliding_window=32768,
- max_window_layers=28,
-)
-
-"""
-zero1 parallel (dict):
- 1. size: int
- * if size <= 0, the size of the zero process group is equal to the size of the dp process group,
- so parameters will be divided within the range of dp.
- * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
- * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
- For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-tensor parallel (dict):
- 1. size: int, the size of tensor parallel.
- 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],
- defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel.
- msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size.
- fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size.
- isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel.
-pipeline parallel (dict):
- 1. size: int, the size of pipeline parallel.
- 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
- defaults to False.
-weight parallel (dict):
- 1. size: int, the size of weight parallel.
- 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
-"""
-parallel = dict(
- zero1=dict(size=-1),
- tensor=dict(size=1, mode="mtp"),
- pipeline=dict(size=1, interleaved_overlap=True),
- weight=dict(size=1, overlap=True),
-)
-
-cudnn_deterministic = False
-cudnn_benchmark = False
-
-monitor = dict(
- # feishu alert configs
- alert=dict(
- enable_feishu_alert=DO_ALERT,
- feishu_alert_address=None, # feishu webhook to send alert message
- light_monitor_address=None, # light_monitor address to send heartbeat
- alert_file_path=f"llm_alter/{JOB_NAME}_alert.log",
- ),
- tensorboard=dict(
- queue_max_length=10,
- ),
-)
-
-# metric_dtype can be "fp32" or other string
-# only when set to "fp32" will use fp32 to calc in metrics
-# metric_dtype = "fp32"
-
-generation = dict(
- ckpt_folder="/path/to/saved/ckpt",
- output_folder="/path/to/save/generation",
- batch_size=1,
- eos_id=[2, 0],
- bos_id=1,
- max_length=100,
- do_sample=True,
- temperature=1.0,
- top_k=50,
- top_p=1.0,
- repetition_penalty=1,
- length_penalty=1.0,
-)
diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 27847a5e8..0a19f137c 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -22,16 +22,7 @@
CHECKPOINT_EVERY = 50
ckpt = dict(
enable_save_ckpt=False, # enable ckpt save.
- enable_internevo2hf_ckpt=False, # enable ckpt save for huggingface format.
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
- # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"),
- load_ckpt_folder="local:llm_ckpts/",
- # 'load_ckpt_info' setting guide:
- # 1. the 'path' indicate ckpt path,
- # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
- # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined
- # load function such as "llama"
- load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"),
# 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
# training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
# with an automatic restart mechanism upon training reboot.
@@ -39,7 +30,7 @@
# path specified in `load_ckpt_info` by default.
# If you want to initialize your model weights from another model, you must set `auto_resume` to False.
# If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None.
- auto_resume=True,
+ auto_resume=False,
checkpoint_every=CHECKPOINT_EVERY,
async_upload=True, # async ckpt upload. (only works for boto3 ckpt)
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
@@ -144,14 +135,12 @@
model = dict(
checkpoint=False, # The proportion of layers for activation checkpointing; the optional values are True/False/[0-1]
num_attention_heads=NUM_ATTENTION_HEAD,
- embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
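With `load_ckpt_folder`/`load_ckpt_info` removed and `auto_resume=False`, the default `7B_sft.py` now starts training from scratch. Following the guidance in the comments kept in this config, two illustrative `ckpt` variants (field values are placeholders taken from this file; the snippet is a sketch, not an addition of this patch):

```python
# Train from scratch: no load info, no auto resume (what 7B_sft.py now defaults to).
ckpt_scratch = dict(
    enable_save_ckpt=False,
    save_ckpt_folder="local:llm_ckpts",
    auto_resume=False,
    checkpoint_every=50,
)

# Initialize weights from another model: point load_ckpt_info at it and keep
# auto_resume=False so a snapshot under save_ckpt_folder is not loaded instead.
ckpt_finetune = dict(
    enable_save_ckpt=True,
    save_ckpt_folder="local:llm_ckpts",
    load_ckpt_info=dict(path="local:llm_ckpts/xxxx", content=("model",), ckpt_type="internevo"),
    auto_resume=False,
    checkpoint_every=50,
)
```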
diff --git a/configs/8B_internlm3.py b/configs/8B_internlm3.py
index acb04d446..9f5840c05 100644
--- a/configs/8B_internlm3.py
+++ b/configs/8B_internlm3.py
@@ -153,7 +153,6 @@
num_chunks=1,
num_attention_heads=NUM_ATTENTION_HEAD,
num_kv_attention_heads=NUM_KV_ATTENTION_HEAD,
- embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
@@ -161,7 +160,6 @@
num_layers=NUM_LAYER,
no_bias=True,
mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
diff --git a/configs/8x22B_mixtral.py b/configs/8x22B_mixtral.py
deleted file mode 100644
index f1f1b6e60..000000000
--- a/configs/8x22B_mixtral.py
+++ /dev/null
@@ -1,227 +0,0 @@
-JOB_NAME = "22b_moe_mixtral"
-model_type = "MIXTRALMOE"
-DO_ALERT = False
-
-SEQ_LEN = 4096
-HIDDEN_SIZE = 6144
-NUM_ATTENTION_HEAD = 48
-NUM_KV_ATTENTION_HEAD = 8
-MLP_RATIO = 8 / 3
-NUM_LAYER = 56
-VOCAB_SIZE = 32000
-
-MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
-# Ckpt folder format:
-# fs: 'local:/mnt/nfs/XXX'
-SAVE_CKPT_FOLDER = "local:llm_ckpts"
-LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
-
-# boto3 Ckpt folder format:
-# import os
-# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
-# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
-# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
-CHECKPOINT_EVERY = 50
-ckpt = dict(
- enable_save_ckpt=False, # enable ckpt save.
- save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
- # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"),
- load_ckpt_folder="local:llm_ckpts/",
- # 'load_ckpt_info' setting guide:
- # 1. the 'path' indicate ckpt path,
- # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
- # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined
- # load function such as "llama"
- load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"),
- # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
- # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
- # with an automatic restart mechanism upon training reboot.
- # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint
- # path specified in `load_ckpt_info` by default.
- # If you want to initialize your model weights from another model, you must set `auto_resume` to False.
- # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None.
- auto_resume=True,
- checkpoint_every=CHECKPOINT_EVERY,
- async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
- async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload.
- oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
-)
-
-TRAIN_FOLDER = None # "/path/to/dataset"
-VALID_FOLDER = None # "/path/to/dataset"
-data = dict(
- seq_len=SEQ_LEN,
- # micro_num means the number of micro_batch contained in one gradient update
- micro_num=4,
- # packed_length = micro_bsz * SEQ_LEN
- micro_bsz=2,
- # defaults to the value of micro_num
- valid_micro_num=4,
- # defaults to 0, means disable evaluate
- valid_every=50,
- pack_sample_into_one=False,
- total_steps=50000,
- skip_batches="",
- # rampup_batch_size (str): A string with three space-separated integers representing the
- # starting batch size, the increment, and the number of steps between
- # each increment. For example, "192 24 8" means that the batch size (micro_num)
- # starts at 192 and increases by 24 every 8 steps. Defaults to None.
- # (IMPORTANT): The interval step size is 'micro_bsz'.
- rampup_batch_size="",
- # Datasets with less than 50 rows will be discarded
- min_length=50,
- train_folder=TRAIN_FOLDER,
- valid_folder=VALID_FOLDER,
- empty_cache_and_diag_interval=200,
- diag_outlier_ratio=1.1,
-)
-
-grad_scaler = dict(
- fp16=dict(
- # the initial loss scale, defaults to 2**16
- initial_scale=2**16,
- # the minimum loss scale, defaults to None
- min_scale=1,
- # the number of steps to increase loss scale when no overflow occurs
- growth_interval=1000,
- ),
- # the multiplication factor for increasing loss scale, defaults to 2
- growth_factor=2,
- # the multiplication factor for decreasing loss scale, defaults to 0.5
- backoff_factor=0.5,
- # the maximum loss scale, defaults to None
- max_scale=2**24,
- # the number of overflows before decreasing loss scale, defaults to 2
- hysteresis=2,
-)
-
-hybrid_zero_optimizer = dict(
- # Enable low_level_optimzer overlap_communication
- overlap_sync_grad=False,
- overlap_sync_param=False,
- # bucket size for nccl communication params
- reduce_bucket_size=512 * 1024 * 1024,
- # grad clipping
- clip_grad_norm=1.0,
-)
-
-loss = dict(
- label_smoothing=0,
- moe_loss_coeff=0.001,
-)
-
-adam = dict(
- lr=1e-4,
- adam_beta1=0.9,
- adam_beta2=0.95,
- adam_beta2_c=0,
- adam_eps=1e-8,
- weight_decay=0.01,
-)
-
-lr_scheduler = dict(
- total_steps=data["total_steps"],
- init_steps=0, # optimizer_warmup_step
- warmup_ratio=0.01,
- eta_min=1e-5,
- last_epoch=-1,
-)
-
-beta2_scheduler = dict(
- init_beta2=adam["adam_beta2"],
- c=adam["adam_beta2_c"],
- cur_iter=-1,
-)
-
-use_fp32_norm = False
-model = dict(
- checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1]
- num_attention_heads=NUM_ATTENTION_HEAD,
- num_kv_attention_heads=NUM_KV_ATTENTION_HEAD,
- max_position_embeddings=65536,
- embed_split_hidden=True,
- vocab_size=VOCAB_SIZE,
- embed_grad_scale=1,
- parallel_output=True,
- hidden_size=HIDDEN_SIZE,
- num_layers=NUM_LAYER,
- qkv_bias=False,
- o_bias=False,
- mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
- dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
- norm_type="rmsnorm",
- layer_norm_epsilon=1e-5,
- use_flash_attn=True,
- # Whether the odd and even columns of the query and key in the model are normally interleaved.
- # If it's True, the model's odd and even columns are normally ordered; if it's False,
- # it means that the model has prematurely concatenated all odd columns and even columns in front
- # and back, in order to improve the RoPE's computational efficiency.
- # Example:
- # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...]
- # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
- qk_interleaved=False,
- use_sliding_window=False,
- rope_base=1000000,
- num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
- moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-D", "Dropless"
- num_experts=8,
- top_k=2,
-)
-"""
-zero1 parallel (dict):
- 1. size: int
- * if size <= 0, the size of the zero process group is equal to the size of the dp process group,
- so parameters will be divided within the range of dp.
- * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
- * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
- For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-tensor parallel (dict):
- 1. size: int, the size of tensor parallel.
- 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],
- defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel.
- msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size.
- fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size.
- isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel.
-pipeline parallel (dict):
- 1. size: int, the size of pipeline parallel.
- 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
- defaults to False.
-weight parallel (dict):
- 1. size: int, the size of weight parallel.
- 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
-expert parallel (dict):
- 1. size: int
- * if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size
- to be the number of experts to make sure each device has one expert.
- * if size == 1, all experts are placed in each device, running as dp-only.
- * if size > 1, all experts are placed in k devices and each device has n/k experts, where n is the total
- number of experts and k = size.
-expert weight parallel (dict):
- 1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size.
- 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
-"""
-parallel = dict(
- zero1=dict(size=-1),
- tensor=dict(size=1, mode="mtp"),
- pipeline=dict(size=1, interleaved_overlap=True),
- weight=dict(size=1, overlap=True),
- expert=dict(size=-1, no_tp=False),
- expert_weight=dict(size=1, overlap=True),
-)
-
-cudnn_deterministic = False
-cudnn_benchmark = False
-
-monitor = dict(
- # feishu alert configs
- alert=dict(
- enable_feishu_alert=DO_ALERT,
- feishu_alert_address=None, # feishu webhook to send alert message
- light_monitor_address=None, # light_monitor address to send heartbeat
- alert_file_path=f"llm_alter/{JOB_NAME}_alert.log",
- ),
- tensorboard=dict(
- queue_max_length=10,
- ),
-)
diff --git a/configs/8x7B_mixtral.py b/configs/8x7B_mixtral.py
deleted file mode 100644
index 6db43f9c6..000000000
--- a/configs/8x7B_mixtral.py
+++ /dev/null
@@ -1,227 +0,0 @@
-JOB_NAME = "7b_moe_mixtral"
-model_type = "MIXTRALMOE"
-DO_ALERT = False
-
-SEQ_LEN = 4096
-HIDDEN_SIZE = 4096
-NUM_ATTENTION_HEAD = 32
-NUM_KV_ATTENTION_HEAD = 8
-MLP_RATIO = 3.5
-NUM_LAYER = 32
-VOCAB_SIZE = 32000
-
-MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
-# Ckpt folder format:
-# fs: 'local:/mnt/nfs/XXX'
-SAVE_CKPT_FOLDER = "local:llm_ckpts"
-LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
-
-# boto3 Ckpt folder format:
-# import os
-# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
-# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
-# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
-CHECKPOINT_EVERY = 50
-ckpt = dict(
- enable_save_ckpt=False, # enable ckpt save.
- save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
- # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"),
- load_ckpt_folder="local:llm_ckpts/",
- # 'load_ckpt_info' setting guide:
- # 1. the 'path' indicates the ckpt path,
- # 2. the 'content' means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
- # 3. the 'ckpt_type' means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined
- # load function such as "llama"
- load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"),
- # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
- # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
- # with an automatic restart mechanism upon training reboot.
- # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint
- # path specified in `load_ckpt_info` by default.
- # If you want to initialize your model weights from another model, you must set `auto_resume` to False.
- # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None.
- auto_resume=True,
- checkpoint_every=CHECKPOINT_EVERY,
- async_upload=True, # async ckpt upload. (only works for boto3 ckpt)
- async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
- oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
-)
-
-TRAIN_FOLDER = None # "/path/to/dataset"
-VALID_FOLDER = None # "/path/to/dataset"
-data = dict(
- seq_len=SEQ_LEN,
- # micro_num means the number of micro_batch contained in one gradient update
- micro_num=4,
- # packed_length = micro_bsz * SEQ_LEN
- micro_bsz=2,
- # defaults to the value of micro_num
- valid_micro_num=4,
- # defaults to 0, which disables evaluation
- valid_every=50,
- pack_sample_into_one=False,
- total_steps=50000,
- skip_batches="",
- # rampup_batch_size (str): A string with three space-separated integers representing the
- # starting batch size, the increment, and the number of steps between
- # each increment. For example, "192 24 8" means that the batch size (micro_num)
- # starts at 192 and increases by 24 every 8 steps. Defaults to None.
- # (IMPORTANT): The interval step size is 'micro_bsz'.
- rampup_batch_size="",
- # Datasets with less than 50 rows will be discarded
- min_length=50,
- train_folder=TRAIN_FOLDER,
- valid_folder=VALID_FOLDER,
- empty_cache_and_diag_interval=200,
- diag_outlier_ratio=1.1,
-)
-
-grad_scaler = dict(
- fp16=dict(
- # the initial loss scale, defaults to 2**16
- initial_scale=2**16,
- # the minimum loss scale, defaults to None
- min_scale=1,
- # the number of steps to increase loss scale when no overflow occurs
- growth_interval=1000,
- ),
- # the multiplication factor for increasing loss scale, defaults to 2
- growth_factor=2,
- # the multiplication factor for decreasing loss scale, defaults to 0.5
- backoff_factor=0.5,
- # the maximum loss scale, defaults to None
- max_scale=2**24,
- # the number of overflows before decreasing loss scale, defaults to 2
- hysteresis=2,
-)
-
-hybrid_zero_optimizer = dict(
- # Enable low_level_optimizer overlap_communication
- overlap_sync_grad=False,
- overlap_sync_param=False,
- # bucket size for nccl communication params
- reduce_bucket_size=512 * 1024 * 1024,
- # grad clipping
- clip_grad_norm=1.0,
-)
-
-loss = dict(
- label_smoothing=0,
- moe_loss_coeff=0.02,
-)
-
-adam = dict(
- lr=1e-4,
- adam_beta1=0.9,
- adam_beta2=0.95,
- adam_beta2_c=0,
- adam_eps=1e-8,
- weight_decay=0.01,
-)
-
-lr_scheduler = dict(
- total_steps=data["total_steps"],
- init_steps=0, # optimizer_warmup_step
- warmup_ratio=0.01,
- eta_min=1e-5,
- last_epoch=-1,
-)
-
-beta2_scheduler = dict(
- init_beta2=adam["adam_beta2"],
- c=adam["adam_beta2_c"],
- cur_iter=-1,
-)
-
-use_fp32_norm = False
-model = dict(
- checkpoint=False, # The proportion of layers for activation checkpointing, the optional values are True/False/[0-1]
- num_attention_heads=NUM_ATTENTION_HEAD,
- num_kv_attention_heads=NUM_KV_ATTENTION_HEAD,
- max_position_embeddings=32768,
- embed_split_hidden=True,
- vocab_size=VOCAB_SIZE,
- embed_grad_scale=1,
- parallel_output=True,
- hidden_size=HIDDEN_SIZE,
- num_layers=NUM_LAYER,
- qkv_bias=False,
- o_bias=False,
- mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
- dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
- norm_type="rmsnorm",
- layer_norm_epsilon=1e-5,
- use_flash_attn=True,
- # Whether the odd and even columns of the query and key in the model are normally interleaved.
- # If True, the odd and even columns keep their normal interleaved order; if False,
- # the model has already regrouped all odd columns in front and all even columns behind them,
- # in order to improve the RoPE's computational efficiency.
- # Example:
- # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...]
- # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
- qk_interleaved=False,
- use_sliding_window=False,
- rope_base=1000000,
- num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
- moe_type="GShard", # Support: "GShard", "MegaBlock", "MegaBlock-D", "Dropless"
- num_experts=8,
- top_k=2,
-)
-"""
-zero1 parallel (dict):
- 1. size: int
- * if size <= 0, the size of the zero process group is equal to the size of the dp process group,
- so parameters will be divided within the range of dp.
- * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
- * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
- For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-tensor parallel (dict):
- 1. size: int, the size of tensor parallel.
- 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],
- defaults to 'mtp', which means pure megatron tensor parallel without sequence parallel.
- msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size.
- fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size.
- isp: custom intern sequence parallel without tensor parallel, can be used with weight parallel.
-pipeline parallel (dict):
- 1. size: int, the size of pipeline parallel.
- 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
- defaults to False.
-weight parallel (dict):
- 1. size: int, the size of weight parallel.
- 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
-expert parallel (dict):
- 1. size: int
- * if size <= 0, ep size equals dp size, but if the number of experts is smaller than dp size, set ep size
- to the number of experts to make sure each device has one expert.
- * if size == 1, all experts are placed on each device, running as dp-only.
- * if size > 1, all experts are placed on k devices and each device has n/k experts, where n is the total
- number of experts and k = size.
-expert weight parallel (dict):
- 1. size: int, the size of weight parallel for the expert module, distinct from the global weight parallel size.
- 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
-"""
-parallel = dict(
- zero1=dict(size=-1),
- tensor=dict(size=1, mode="mtp"),
- pipeline=dict(size=1, interleaved_overlap=True),
- weight=dict(size=1, overlap=True),
- expert=dict(size=-1, no_tp=False),
- expert_weight=dict(size=1, overlap=True),
-)
-
-cudnn_deterministic = False
-cudnn_benchmark = False
-
-monitor = dict(
- # feishu alert configs
- alert=dict(
- enable_feishu_alert=DO_ALERT,
- feishu_alert_address=None, # feishu webhook to send alert message
- light_monitor_address=None, # light_monitor address to send heartbeat
- alert_file_path=f"llm_alter/{JOB_NAME}_alert.log",
- ),
- tensorboard=dict(
- queue_max_length=10,
- ),
-)
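As a quick aid to the `rampup_batch_size` comment in the config above, here is a tiny sketch of the "start increment interval" semantics it describes; the function name is hypothetical, and the real schedule presumably also caps at the configured `micro_num`.

    def rampup_micro_num(step: int, rampup: str = "192 24 8") -> int:
        # "192 24 8": start at 192 and increase by 24 every 8 steps
        start, increment, interval = map(int, rampup.split())
        return start + (step // interval) * increment

    print(rampup_micro_num(0))   # 192
    print(rampup_micro_num(16))  # 240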
diff --git a/configs/_base_/models/internlm2_1B.py b/configs/_base_/models/internlm2_1B.py
index cc3f186ad..f4cfef8aa 100644
--- a/configs/_base_/models/internlm2_1B.py
+++ b/configs/_base_/models/internlm2_1B.py
@@ -14,7 +14,6 @@
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
 checkpoint=0.2, # The proportion of layers for activation checkpointing, the optional values are True/False/[0-1]
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
- embed_split_hidden=True,
num_layers=NUM_LAYER,
hidden_size=HIDDEN_SIZE,
vocab_size=VOCAB_SIZE,
@@ -26,7 +25,6 @@
multiple_of=MULTIPLE_OF,
norm_type="rmsnorm",
qk_interleaved=False,
- apply_post_layer_norm=False,
no_bias=True,
layer_norm_epsilon=1e-5,
rope_base=1000000,
diff --git a/configs/_base_/models/internlm2_20B.py b/configs/_base_/models/internlm2_20B.py
index dc461c0da..f0fea954e 100644
--- a/configs/_base_/models/internlm2_20B.py
+++ b/configs/_base_/models/internlm2_20B.py
@@ -13,7 +13,6 @@
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
 checkpoint=1.0, # The proportion of layers for activation checkpointing, the optional values are True/False/[0-1]
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
- embed_split_hidden=True,
num_layers=NUM_LAYER,
hidden_size=HIDDEN_SIZE,
vocab_size=VOCAB_SIZE,
@@ -24,7 +23,6 @@
mlp_ratio=MLP_RATIO,
norm_type="rmsnorm",
qk_interleaved=False,
- apply_post_layer_norm=False,
no_bias=True,
layer_norm_epsilon=1e-5,
rope_base=1000000,
diff --git a/configs/_base_/models/internlm2_7B.py b/configs/_base_/models/internlm2_7B.py
index cbdb03cb1..06b27693b 100644
--- a/configs/_base_/models/internlm2_7B.py
+++ b/configs/_base_/models/internlm2_7B.py
@@ -13,7 +13,6 @@
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
 checkpoint=0.2, # The proportion of layers for activation checkpointing, the optional values are True/False/[0-1]
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
- embed_split_hidden=True,
num_layers=NUM_LAYER,
hidden_size=HIDDEN_SIZE,
vocab_size=VOCAB_SIZE,
@@ -24,7 +23,6 @@
mlp_ratio=MLP_RATIO,
norm_type="rmsnorm",
qk_interleaved=True,
- apply_post_layer_norm=False,
no_bias=True,
layer_norm_epsilon=1e-5,
rope_base=1000000,
diff --git a/configs/_base_/models/internlm_20B.py b/configs/_base_/models/internlm_20B.py
index 26f4ff7f8..2f7ff0c8c 100644
--- a/configs/_base_/models/internlm_20B.py
+++ b/configs/_base_/models/internlm_20B.py
@@ -12,7 +12,6 @@
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
 checkpoint=False, # The proportion of layers for activation checkpointing, the optional values are True/False/[0-1]
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
- embed_split_hidden=True,
num_layers=NUM_LAYER,
hidden_size=HIDDEN_SIZE,
vocab_size=VOCAB_SIZE,
@@ -21,7 +20,6 @@
num_attention_heads=NUM_ATTENTION_HEAD,
mlp_ratio=MLP_RATIO,
norm_type="rmsnorm",
- apply_post_layer_norm=False,
layer_norm_epsilon=1e-5,
)
diff --git a/configs/_base_/models/internlm_7B.py b/configs/_base_/models/internlm_7B.py
index 8dde6e4e4..4b63c7ded 100644
--- a/configs/_base_/models/internlm_7B.py
+++ b/configs/_base_/models/internlm_7B.py
@@ -12,7 +12,6 @@
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
 checkpoint=False, # The proportion of layers for activation checkpointing, the optional values are True/False/[0-1]
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
- embed_split_hidden=True,
num_layers=NUM_LAYER,
hidden_size=HIDDEN_SIZE,
vocab_size=VOCAB_SIZE,
@@ -21,7 +20,6 @@
num_attention_heads=NUM_ATTENTION_HEAD,
mlp_ratio=MLP_RATIO,
norm_type="rmsnorm",
- apply_post_layer_norm=False,
layer_norm_epsilon=1e-5,
)
diff --git a/configs/demo_llava.py b/configs/demo_llava.py
deleted file mode 100644
index e138e886a..000000000
--- a/configs/demo_llava.py
+++ /dev/null
@@ -1,191 +0,0 @@
-JOB_NAME = "llava_train"
-model_type = "LLAVA"
-DO_ALERT = False
-
-VOCAB_SIZE = 32000
-SEQ_LEN = 2048
-HIDDEN_SIZE = 4096
-NUM_ATTENTION_HEAD = 32
-NUM_KV_ATTENTION_HEAD = 8
-MLP_RATIO = 3.5
-NUM_LAYER = 32
-
-
-MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
-# Ckpt folder format:
-# fs: 'local:/mnt/nfs/XXX'
-SAVE_CKPT_FOLDER = "local:llm_ckpts"
-LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
-
-# boto3 Ckpt folder format:
-# import os
-# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
-# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
-# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
-CHECKPOINT_EVERY = 50
-ckpt = dict(
- enable_save_ckpt=False, # enable ckpt save.
- save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
- # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
- # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
- # with an automatic restart mechanism upon training reboot.
- # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint
- # path specified in `load_ckpt_info` by default.
- # If you want to initialize your model weights from another model, you must set `auto_resume` to False.
- # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None.
- auto_resume=False,
- checkpoint_every=CHECKPOINT_EVERY,
- async_upload=True, # async ckpt upload. (only works for boto3 ckpt)
- async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
- oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
-)
-
-TRAIN_FOLDER = None
-VALID_FOLDER = None # "/path/to/dataset"
-data = dict(
- is_multimodal=True,
- seq_len=SEQ_LEN,
- # micro_num means the number of micro_batch contained in one gradient update
- micro_num=4,
- packed_length=SEQ_LEN,
- micro_bsz=1,
- # defaults to the value of micro_num
- valid_micro_num=4,
- # defaults to 0, which disables evaluation
- valid_every=0,
- pack_sample_into_one=False,
- total_steps=200,
- skip_batches="",
- # rampup_batch_size (str): A string with three space-separated integers representing the
- # starting batch size, the increment, and the number of steps between
- # each increment. For example, "192 24 8" means that the batch size (micro_num)
- # starts at 192 and increases by 24 every 8 steps. Defaults to None.
- # (IMPORTANT): The interval step size is 'micro_bsz'.
- rampup_batch_size="",
- # Datasets with less than 50 rows will be discarded
- min_length=50,
- train_folder=TRAIN_FOLDER,
- valid_folder=VALID_FOLDER,
- empty_cache_and_diag_interval=200,
- diag_outlier_ratio=1.1,
- image_size=336,
- patch_size=14,
-)
-
-grad_scaler = dict(
- fp16=dict(
- # the initial loss scale, defaults to 2**16
- initial_scale=2**16,
- # the minimum loss scale, defaults to None
- min_scale=1,
- # the number of steps to increase loss scale when no overflow occurs
- growth_interval=1000,
- ),
- # the multiplication factor for increasing loss scale, defaults to 2
- growth_factor=2,
- # the multiplication factor for decreasing loss scale, defaults to 0.5
- backoff_factor=0.5,
- # the maximum loss scale, defaults to None
- max_scale=2**24,
- # the number of overflows before decreasing loss scale, defaults to 2
- hysteresis=2,
-)
-
-hybrid_zero_optimizer = dict(
- # Enable low_level_optimizer overlap_communication
- overlap_sync_grad=True,
- overlap_sync_param=False,
- # bucket size for nccl communication params
- reduce_bucket_size=512 * 1024 * 1024,
- # grad clipping
- clip_grad_norm=1.0,
-)
-
-loss = dict(
- label_smoothing=0,
-)
-
-adam = dict(
- lr=1e-4,
- adam_beta1=0.9,
- adam_beta2=0.95,
- adam_beta2_c=0,
- adam_eps=1e-8,
- weight_decay=0.01,
-)
-
-lr_scheduler = dict(
- total_steps=data["total_steps"],
- init_steps=0, # optimizer_warmup_step
- warmup_ratio=0.01,
- eta_min=1e-5,
- last_epoch=-1,
-)
-
-beta2_scheduler = dict(
- init_beta2=adam["adam_beta2"],
- c=adam["adam_beta2_c"],
- cur_iter=-1,
-)
-
-use_fp32_norm = False
-model = dict(
- checkpoint=False,
- num_chunks=1,
- num_attention_heads=NUM_ATTENTION_HEAD,
- embed_split_hidden=True,
- vocab_size=VOCAB_SIZE,
- embed_grad_scale=1,
- parallel_output=True,
- hidden_size=HIDDEN_SIZE,
- num_layers=NUM_LAYER,
- no_bias=True,
- mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
- dtype="torch.bfloat16",
- norm_type="rmsnorm",
- layer_norm_epsilon=1e-5,
- num_kv_attention_heads=NUM_KV_ATTENTION_HEAD,
- use_flash_attn=True,
- image_token_id=200000,
- vit_cfg=dict(
- mm_projector_type="mlp2x_gelu",
- mm_use_im_patch_token=True,
- mm_use_im_start_end=True,
- mm_vision_select_feature="patch",
- mm_vision_select_layer=-2,
- mm_vision_tower="openai/clip-vit-large-patch14-336",
- ),
- vision_proj_cfg=dict(
- mm_projector_type="mlp2x_gelu",
- mm_hidden_size=1024, # vit hidden_size
- hidden_size=HIDDEN_SIZE, # llm hidden_size
- ),
-)
-
-parallel = dict(
- zero1=dict(size=-1),
- tensor=dict(size=1, mode="mtp"),
- pipeline=dict(size=1, interleaved_overlap=True),
- weight=dict(size=1, overlap=True),
-)
-
-cudnn_deterministic = False
-cudnn_benchmark = False
-
-monitor = dict(
- # feishu alert configs
- alert=dict(
- enable_feishu_alert=DO_ALERT,
- feishu_alert_address=None, # feishu webhook to send alert message
- light_monitor_address=None, # light_monitor address to send heartbeat
- alert_file_path=f"llm_alter/{JOB_NAME}_alert.log",
- ),
- tensorboard=dict(
- queue_max_length=10,
- ),
-)
-
-# metric_dtype can be "fp32" or another string
-# only when set to "fp32" will fp32 be used to compute metrics
-# metric_dtype = "fp32"
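The `vision_proj_cfg` block in the config above bridges ViT features (`mm_hidden_size=1024`) to the LLM width (`hidden_size=4096`). As a hedged sketch, an "mlp2x_gelu" projector is commonly two linear layers with a GELU in between; the exact module used by the codebase may differ.

    import torch.nn as nn

    def build_mlp2x_gelu_projector(mm_hidden_size: int = 1024, hidden_size: int = 4096) -> nn.Module:
        # Maps ViT patch features into the LLM hidden size with a 2-layer MLP + GELU.
        return nn.Sequential(
            nn.Linear(mm_hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )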
diff --git a/doc/code-docs/locales/en/LC_MESSAGES/training.po b/doc/code-docs/locales/en/LC_MESSAGES/training.po
index 25b4a4927..fc59d8c13 100644
--- a/doc/code-docs/locales/en/LC_MESSAGES/training.po
+++ b/doc/code-docs/locales/en/LC_MESSAGES/training.po
@@ -68,10 +68,10 @@ msgstr "Initialize Distributed Training Environment"
#: ../../source/training.rst:23
msgid ""
-"调用 ``initialize_distributed_env`` 函数,支持通过 slurm 或 torch "
+"调用 ``init_distributed`` 函数,支持通过 slurm 或 torch "
"方式启动训练脚本,并传入配置文件、端口号、进程随机种子等信息。函数详细说明如下:"
msgstr ""
-"Call the initialize_distributed_env function, which supports launching "
+"Call the init_distributed function, which supports launching "
"the training script through Slurm or Torch, and pass in information such "
"as the configuration file, port number, and process random seed. Detailed"
" description of the function is as follows:"
diff --git a/doc/code-docs/source/example/20B_demo.rst b/doc/code-docs/source/example/20B_demo.rst
index da7f1d2df..232d810b2 100644
--- a/doc/code-docs/source/example/20B_demo.rst
+++ b/doc/code-docs/source/example/20B_demo.rst
@@ -123,14 +123,12 @@
model = dict(
 checkpoint=False, # The proportion of layers for activation checkpointing, the optional values are True/False/[0-1]
num_attention_heads=NUM_ATTENTION_HEAD,
- embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
@@ -167,7 +165,7 @@
.. code-block:: bash
- srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./configs/20B_sft.py
+ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python -m internlm.launcher.launch --config ./configs/20B_sft.py
训练结果
----------------
diff --git a/doc/code-docs/source/example/7B_demo.rst b/doc/code-docs/source/example/7B_demo.rst
index 78154175e..92f9b0307 100644
--- a/doc/code-docs/source/example/7B_demo.rst
+++ b/doc/code-docs/source/example/7B_demo.rst
@@ -123,14 +123,12 @@
model = dict(
 checkpoint=False, # The proportion of layers for activation checkpointing, the optional values are True/False/[0-1]
num_attention_heads=NUM_ATTENTION_HEAD,
- embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
@@ -165,7 +163,7 @@
.. code-block:: bash
- srun -p internllm -N 1 -n 8 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./configs/7B_sft.py
+ srun -p internllm -N 1 -n 8 --ntasks-per-node=8 --gpus-per-task=1 python -m internlm.launcher.launch --config ./configs/7B_sft.py
训练结果
----------------
diff --git a/doc/code-docs/source/initialize.rst b/doc/code-docs/source/initialize.rst
index 721eec006..4c938c30b 100644
--- a/doc/code-docs/source/initialize.rst
+++ b/doc/code-docs/source/initialize.rst
@@ -43,7 +43,7 @@ InternEvo 使用 `argparse `_
模型初始化
-------------------------
-.. autofunction:: internlm.train.initialize_model_and_parallel_communicator
+.. autofunction:: internlm.initialize.initialize_model.initialize_model_and_parallel_communicator
InternEvo 在配置文件中使用字段 ``model_type`` 和 ``model`` 来控制模型初始化过程。示例模型初始化配置定义如下:
@@ -58,14 +58,12 @@ InternEvo 在配置文件中使用字段 ``model_type`` 和 ``model`` 来控制
model = dict(
 checkpoint=False, # The proportion of layers for activation checkpointing, the optional values are True/False/[0-1]
num_attention_heads=NUM_ATTENTION_HEAD,
- embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
diff --git a/doc/code-docs/source/mixed_precision.rst b/doc/code-docs/source/mixed_precision.rst
index bbada7f77..774c620f7 100644
--- a/doc/code-docs/source/mixed_precision.rst
+++ b/doc/code-docs/source/mixed_precision.rst
@@ -63,14 +63,12 @@ InternEvo支持使用TF32训练模型,允许用户在config文件中将 ``dtyp
model = dict(
 checkpoint=False, # The proportion of layers for activation checkpointing, the optional values are True/False/[0-1]
num_attention_heads=NUM_ATTENTION_HEAD,
- embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
dtype="torch.tf32", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
diff --git a/doc/code-docs/source/training.rst b/doc/code-docs/source/training.rst
index f43bfe4af..22b0ed2ba 100644
--- a/doc/code-docs/source/training.rst
+++ b/doc/code-docs/source/training.rst
@@ -18,11 +18,11 @@
- 初始化分布式训练环境
.. code-block:: python
- initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
+ init_distributed(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
-调用 ``initialize_distributed_env`` 函数,支持通过 slurm 或 torch 方式启动训练脚本,并传入配置文件、端口号、进程随机种子等信息。函数详细说明如下:
+调用 ``init_distributed`` 函数,支持通过 slurm 或 torch 方式启动训练脚本,并传入配置文件、端口号、进程随机种子等信息。函数详细说明如下:
-.. autofunction:: internlm.initialize.initialize_distributed_env
+.. autofunction:: internlm.initialize.init_distributed
- 初始化模型
.. code-block:: python
diff --git a/doc/en/usage.md b/doc/en/usage.md
index 8e1670c2f..17cf88e98 100644
--- a/doc/en/usage.md
+++ b/doc/en/usage.md
@@ -229,14 +229,12 @@ beta2_scheduler = dict(
model = dict(
 checkpoint=False, # The proportion of layers for activation checkpointing, the optional values are True/False/[0-1]
num_attention_heads=NUM_ATTENTION_HEAD,
- embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
@@ -359,14 +357,12 @@ MLP_RATIO = 8 / 3
model = dict(
 checkpoint=False, # The proportion of layers for activation checkpointing, the optional values are True/False/[0-1]
num_attention_heads=NUM_ATTENTION_HEAD,
- embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
@@ -407,7 +403,7 @@ After completing the data preparation and relevant training configurations menti
If you want to start distributed training on slurm with 16 GPUs across multiple nodes, use the following command:
```bash
-$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./configs/7B_sft.py
+$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python -m internlm.launcher.launch --config ./configs/7B_sft.py
```
If you want to start distributed training on torch with 8 GPUs on a single node, use the following command:
@@ -455,14 +451,12 @@ MLP_RATIO = 8 / 3
model = dict(
checkpoint=False, # 进行重计算的模型层数比例,可选值为 True/False/[0-1]
num_attention_heads=NUM_ATTENTION_HEAD,
- embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
diff --git a/doc/usage.md b/doc/usage.md
index 7c28d6d3e..b28144cca 100644
--- a/doc/usage.md
+++ b/doc/usage.md
@@ -238,14 +238,12 @@ use_fp32_norm = False
model = dict(
 checkpoint=False, # The proportion of layers for activation checkpointing, the optional values are True/False/[0-1]
num_attention_heads=NUM_ATTENTION_HEAD,
- embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
@@ -398,14 +396,12 @@ MLP_RATIO = 8 / 3
model = dict(
checkpoint=False, # 进行重计算的模型层数比例,可选值为 True/False/[0-1]
num_attention_heads=NUM_ATTENTION_HEAD,
- embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
@@ -453,7 +449,7 @@ parallel = dict(
若在 slurm 上启动分布式运行环境,多节点 16 卡的运行命令如下所示:
```bash
-$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python train.py --config ./configs/7B_sft.py
+$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python -m internlm.launcher.launch --config ./configs/7B_sft.py
```
若在 torch 上启动分布式运行环境,单节点 8 卡的运行命令如下所示:
diff --git a/generate.py b/generate.py
index 48efa8b3f..69d4f1c51 100644
--- a/generate.py
+++ b/generate.py
@@ -18,10 +18,12 @@
from internlm.apis.inference import SequenceGenerator
from internlm.core.context import global_context as gpc
from internlm.data import build_generation_loader_with_data_type
-from internlm.initialize import initialize_distributed_env
+from internlm.initialize import initialize_launcher
+from internlm.initialize.initialize_model import (
+ initialize_model_and_parallel_communicator,
+)
from internlm.monitor import initialize_monitor_manager
-from internlm.monitor.monitor import monitor_manager as mm
-from internlm.train import initialize_model_and_parallel_communicator
+from internlm.monitor import monitor_manager as mm
from internlm.utils.common import (
enable_pytorch_expandable_segments,
launch_time,
@@ -219,7 +221,7 @@ def main():
hostname = socket.gethostname()
# initialize distributed environment
- initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
+ initialize_launcher(config=args.config, launcher=args.launcher, distributed_port=args.port, seed=args.seed)
assert hasattr(gpc, "config") and gpc.config is not None
assert "generation" in gpc.config, f"Please set `generation` config in `{args.config}` file"
assert (
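A condensed sketch of the initialization flow that generate.py now uses; the `initialize_launcher` call mirrors the signature shown in the hunk above, while the placeholder values and the unpacked return of `initialize_model_and_parallel_communicator` are assumptions.

    from internlm.initialize import initialize_launcher
    from internlm.initialize.initialize_model import (
        initialize_model_and_parallel_communicator,
    )

    # launcher may be "slurm" or "torch"; config path, port and seed are placeholders
    initialize_launcher(config="./configs/7B_sft.py", launcher="slurm", distributed_port=8888, seed=1024)
    # return values assumed here for illustration
    model, isp_communicator = initialize_model_and_parallel_communicator()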
diff --git a/transformers/README-zh-Hans.md b/huggingface_models/README-zh-Hans.md
similarity index 100%
rename from transformers/README-zh-Hans.md
rename to huggingface_models/README-zh-Hans.md
diff --git a/transformers/README.md b/huggingface_models/README.md
similarity index 100%
rename from transformers/README.md
rename to huggingface_models/README.md
diff --git a/transformers/convert2hf_internlm.py b/huggingface_models/convert2hf_internlm.py
similarity index 100%
rename from transformers/convert2hf_internlm.py
rename to huggingface_models/convert2hf_internlm.py
diff --git a/transformers/convert2hf_internlm2.py b/huggingface_models/convert2hf_internlm2.py
similarity index 100%
rename from transformers/convert2hf_internlm2.py
rename to huggingface_models/convert2hf_internlm2.py
diff --git a/transformers/convert2hf_internlm_moe.py b/huggingface_models/convert2hf_internlm_moe.py
similarity index 100%
rename from transformers/convert2hf_internlm_moe.py
rename to huggingface_models/convert2hf_internlm_moe.py
diff --git a/transformers/internlm2_model/__init__.py b/huggingface_models/internlm2_model/__init__.py
similarity index 100%
rename from transformers/internlm2_model/__init__.py
rename to huggingface_models/internlm2_model/__init__.py
diff --git a/transformers/internlm2_model/configuration_internlm2.py b/huggingface_models/internlm2_model/configuration_internlm2.py
similarity index 100%
rename from transformers/internlm2_model/configuration_internlm2.py
rename to huggingface_models/internlm2_model/configuration_internlm2.py
diff --git a/transformers/internlm2_model/modeling_internlm2.py b/huggingface_models/internlm2_model/modeling_internlm2.py
similarity index 95%
rename from transformers/internlm2_model/modeling_internlm2.py
rename to huggingface_models/internlm2_model/modeling_internlm2.py
index f026e5f9d..18eaaa1c4 100644
--- a/transformers/internlm2_model/modeling_internlm2.py
+++ b/huggingface_models/internlm2_model/modeling_internlm2.py
@@ -40,6 +40,15 @@
replace_return_docstrings,
)
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+from internlm.model.model_ops.ops.attention import (
+ isp_flash_attn_func,
+ isp_flash_attn_varlen_func,
+)
+from internlm.model.model_ops.ops.fused_rmsnorm import fused_rms_norm_fn
+from internlm.solver.activation_checkpoint import apply_ac_to_transformer_block
+
try:
from transformers.generation.streamers import BaseStreamer
except: # noqa # pylint: disable=bare-except
@@ -53,17 +62,24 @@
flash_attn_func, flash_attn_varlen_func = None, None
pad_input, index_first_axis, unpad_input = None, None, None
+
+
def _import_flash_attn():
global flash_attn_func, flash_attn_varlen_func
global pad_input, index_first_axis, unpad_input
try:
- from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
- from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
+ from flash_attn import flash_attn_func as _flash_attn_func
+ from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
+ from flash_attn.bert_padding import index_first_axis as _index_first_axis
+ from flash_attn.bert_padding import pad_input as _pad_input
+ from flash_attn.bert_padding import unpad_input as _unpad_input
+
flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
except ImportError:
raise ImportError("flash_attn is not installed.")
+
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
@@ -121,11 +137,15 @@ def __init__(self, hidden_size, eps=1e-6):
self.variance_epsilon = eps
def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
+ # input_dtype = hidden_states.dtype
+ # hidden_states = hidden_states.to(torch.float32)
+ # variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ # hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ # return self.weight * hidden_states.to(input_dtype)
+ return fused_rms_norm_fn(hidden_states, self.weight, self.variance_epsilon)
+
+ def reset_parameters(self):
+ torch.nn.init.ones_(self.weight)
# Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2
@@ -164,6 +184,13 @@ def forward(self, x, seq_len=None):
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
+ def reset_parameters(self):
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(self.inv_freq.device) / self.dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self._set_cos_sin_cache(
+ seq_len=self.max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+ )
+
# Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2
class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
@@ -443,6 +470,12 @@ def forward(
bsz, q_len, _ = hidden_states.size()
+ use_packed_dataset = gpc.config.data.get("use_packed_dataset", False)
+ if use_packed_dataset:
+ assert bsz == 1, "hidden_states should be packed into bsz=1 when use_packed_dataset=True"
+ cu_seqlens = gpc.config.data[f"cu_seqlens_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"]
+ max_seqlen = gpc.config.data[f"max_seqlen_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"]
+
qkv_states = self.wqkv(hidden_states)
qkv_states = rearrange(
@@ -480,9 +513,31 @@ def forward(
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
- attn_output = self._flash_attention_forward(
- query_states, key_states, value_states, attention_mask, q_len
- )
+ # attn_output = self._flash_attention_forward(
+ # query_states, key_states, value_states, attention_mask, q_len
+ # )
+ if use_packed_dataset:
+ attn_output = isp_flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens,
+ cu_seqlens,
+ max_seqlen,
+ max_seqlen,
+ causal=False,
+ softmax_scale=None,
+ attention_dropout=0.0,
+ )
+ else:
+ attn_output = isp_flash_attn_func(
+ query_states,
+ key_states,
+ value_states,
+ causal=False,
+ softmax_scale=None,
+ attention_dropout=0.0,
+ )
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.wo(attn_output)
@@ -584,6 +639,7 @@ def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, quer
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
+
INTERNLM2_ATTENTION_CLASSES = {
"eager": InternLM2Attention,
"flash_attention_2": InternLM2FlashAttention2,
@@ -794,6 +850,11 @@ def __init__(self, config: InternLM2Config):
self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
+ for layer_id, transformer_block in self.layers.named_children():
+ checkpoint = gpc.config.model.checkpoint
+ if checkpoint > 0:
+ transformer_block = apply_ac_to_transformer_block(transformer_block, checkpoint)
+ self.layers.register_module(layer_id, transformer_block)
self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
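The packed-dataset branch above pulls `cu_seqlens`/`max_seqlen` from `gpc.config.data` and feeds them to the varlen attention op. In the usual flash-attn varlen convention, a pack of two sequences of lengths 5 and 3 flattened into a single bsz=1 row would look like this sketch:

    import torch

    # cumulative sequence lengths over the packed row: [0, 5, 5 + 3]
    cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32)
    max_seqlen = 5  # length of the longest individual sequence in the pack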
diff --git a/transformers/internlm2_model/tokenization_internlm2.py b/huggingface_models/internlm2_model/tokenization_internlm2.py
similarity index 100%
rename from transformers/internlm2_model/tokenization_internlm2.py
rename to huggingface_models/internlm2_model/tokenization_internlm2.py
diff --git a/transformers/internlm2_model/tokenization_internlm2_fast.py b/huggingface_models/internlm2_model/tokenization_internlm2_fast.py
similarity index 100%
rename from transformers/internlm2_model/tokenization_internlm2_fast.py
rename to huggingface_models/internlm2_model/tokenization_internlm2_fast.py
diff --git a/transformers/internlm_model/__init__.py b/huggingface_models/internlm_model/__init__.py
similarity index 100%
rename from transformers/internlm_model/__init__.py
rename to huggingface_models/internlm_model/__init__.py
diff --git a/transformers/internlm_model/configuration_internlm.py b/huggingface_models/internlm_model/configuration_internlm.py
similarity index 100%
rename from transformers/internlm_model/configuration_internlm.py
rename to huggingface_models/internlm_model/configuration_internlm.py
diff --git a/transformers/internlm_model/modeling_internlm.py b/huggingface_models/internlm_model/modeling_internlm.py
similarity index 100%
rename from transformers/internlm_model/modeling_internlm.py
rename to huggingface_models/internlm_model/modeling_internlm.py
diff --git a/transformers/internlm_model/tokenization_internlm.py b/huggingface_models/internlm_model/tokenization_internlm.py
similarity index 100%
rename from transformers/internlm_model/tokenization_internlm.py
rename to huggingface_models/internlm_model/tokenization_internlm.py
diff --git a/transformers/internlm_moe_model/__init__.py b/huggingface_models/internlm_moe_model/__init__.py
similarity index 100%
rename from transformers/internlm_moe_model/__init__.py
rename to huggingface_models/internlm_moe_model/__init__.py
diff --git a/transformers/internlm_moe_model/configuration_internlm_moe.py b/huggingface_models/internlm_moe_model/configuration_internlm_moe.py
similarity index 100%
rename from transformers/internlm_moe_model/configuration_internlm_moe.py
rename to huggingface_models/internlm_moe_model/configuration_internlm_moe.py
diff --git a/transformers/internlm_moe_model/modeling_internlm_moe.py b/huggingface_models/internlm_moe_model/modeling_internlm_moe.py
similarity index 100%
rename from transformers/internlm_moe_model/modeling_internlm_moe.py
rename to huggingface_models/internlm_moe_model/modeling_internlm_moe.py
diff --git a/transformers/internlm_moe_model/tokenization_internlm.py b/huggingface_models/internlm_moe_model/tokenization_internlm.py
similarity index 100%
rename from transformers/internlm_moe_model/tokenization_internlm.py
rename to huggingface_models/internlm_moe_model/tokenization_internlm.py
diff --git a/transformers/revert_internlm.py b/huggingface_models/revert_internlm.py
similarity index 100%
rename from transformers/revert_internlm.py
rename to huggingface_models/revert_internlm.py
diff --git a/transformers/revert_internlm2.py b/huggingface_models/revert_internlm2.py
similarity index 100%
rename from transformers/revert_internlm2.py
rename to huggingface_models/revert_internlm2.py
diff --git a/internlm/__init__.py b/internlm/__init__.py
index dc34a3167..e69de29bb 100644
--- a/internlm/__init__.py
+++ b/internlm/__init__.py
@@ -1,9 +0,0 @@
-from .initialize.initialize_trainer import initialize_trainer
-from .initialize.launch import get_default_parser, launch_from_slurm, launch_from_torch
-
-__all__ = [
- "get_default_parser",
- "initialize_trainer",
- "launch_from_slurm",
- "launch_from_torch",
-]
diff --git a/internlm/accelerator/abstract_accelerator.py b/internlm/accelerator/abstract_accelerator.py
index ffedaed59..734d45da5 100644
--- a/internlm/accelerator/abstract_accelerator.py
+++ b/internlm/accelerator/abstract_accelerator.py
@@ -1,8 +1,10 @@
"""
Universal accelerator interface implementation, inspired by DeepSpeed.
"""
+import abc
import enum
import os
+from abc import ABC
class AcceleratorType(enum.Enum):
@@ -17,57 +19,72 @@ class AcceleratorType(enum.Enum):
internlm_accelerator = None
-class Accelerator:
+class Accelerator(ABC):
"""
Abstract base class for accelerator
"""
def __init__(self) -> None:
- pass
+ self._name_str = None
+ self._communication_backend_name = None
+ @abc.abstractmethod
def get_backend_name(self):
"""
Return the name of the accelerator.
"""
raise NotImplementedError
+ @abc.abstractmethod
def get_accelerator_backend(self):
"""
- Return the name of the backend.
+ Return the name of the accelerator backend.
"""
raise NotImplementedError
- # Device APIs
+ @abc.abstractmethod
+ def communication_backend_name(self):
+ """
+ Return the name of the communication backend.
+ """
+ raise NotImplementedError
+
+ @abc.abstractmethod
def device_name(self, device_index=None):
"""
Return the name of the device.
"""
raise NotImplementedError
+ @abc.abstractmethod
def set_device(self, device_index):
"""
Bind the current process to a device.
"""
raise NotImplementedError
+ @abc.abstractmethod
def get_device_id(self):
"""
Return the current device index.
"""
raise NotImplementedError
+ @abc.abstractmethod
def current_device_name(self):
"""
Return the name of the current device.
"""
raise NotImplementedError
+ @abc.abstractmethod
def device_count(self):
"""
Return the number of devices on the machine.
"""
raise NotImplementedError
+ @abc.abstractmethod
def synchronize(self, device_index=None):
"""
Synchronize the current process.
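Making `Accelerator` an `ABC` with `@abc.abstractmethod` means an incomplete backend now fails at construction time instead of when a missing method is first called. A standalone illustration of that behavior (names here are hypothetical):

    import abc
    from abc import ABC

    class _Backend(ABC):
        @abc.abstractmethod
        def get_backend_name(self):
            raise NotImplementedError

    class _Incomplete(_Backend):
        pass  # forgets to override get_backend_name

    try:
        _Incomplete()
    except TypeError as err:
        print(err)  # can't instantiate abstract class _Incomplete ...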
diff --git a/internlm/accelerator/cuda_accelerator.py b/internlm/accelerator/cuda_accelerator.py
index 48a471657..d5986077c 100644
--- a/internlm/accelerator/cuda_accelerator.py
+++ b/internlm/accelerator/cuda_accelerator.py
@@ -14,6 +14,7 @@ class CUDA_Accelerator(Accelerator):
"""
def __init__(self) -> None:
+ super().__init__()
self._name_str = "cuda"
self._communication_backend_name = "nccl"
self.amp = self.get_amp()
diff --git a/internlm/accelerator/dipu_accelerator.py b/internlm/accelerator/dipu_accelerator.py
index 7943b4c7f..b5383eded 100644
--- a/internlm/accelerator/dipu_accelerator.py
+++ b/internlm/accelerator/dipu_accelerator.py
@@ -14,6 +14,7 @@ class DIPU_Accelerator(Accelerator):
"""
def __init__(self) -> None:
+ super().__init__()
self._name_str = "cuda"
self._communication_backend_name = "nccl"
self.amp = self.get_amp()
diff --git a/internlm/accelerator/ditorch_accelerator.py b/internlm/accelerator/ditorch_accelerator.py
index 528b858e2..e4a2fca54 100644
--- a/internlm/accelerator/ditorch_accelerator.py
+++ b/internlm/accelerator/ditorch_accelerator.py
@@ -14,6 +14,7 @@ class DITORCH_Accelerator(Accelerator):
"""
def __init__(self) -> None:
+ super().__init__()
self._name_str = "cuda"
self._communication_backend_name = "nccl"
self.amp = self.get_amp()
diff --git a/internlm/accelerator/npu_accelerator.py b/internlm/accelerator/npu_accelerator.py
index e1bd3549d..e078014e6 100644
--- a/internlm/accelerator/npu_accelerator.py
+++ b/internlm/accelerator/npu_accelerator.py
@@ -14,6 +14,7 @@ class ASCEND_Accelerator(Accelerator):
"""
def __init__(self) -> None:
+ super().__init__()
self._name_str = "npu"
self._communication_backend_name = "hccl"
self.amp = self.get_amp()
diff --git a/internlm/apis/inference_utils.py b/internlm/apis/inference_utils.py
index 423e7aafe..931d10537 100644
--- a/internlm/apis/inference_utils.py
+++ b/internlm/apis/inference_utils.py
@@ -2,7 +2,7 @@
from internlm.core.context import ParallelMode # noqa: E402
from internlm.core.context import global_context as gpc # noqa: E402
-from internlm.core.parallel.comm.utils import _gather as gather
+from internlm.core.parallel.comm.utils import _gather
class InferenceParams:
@@ -64,6 +64,6 @@ def process_parallel_output(model_output):
# gather tp parallel output
if gpc.config.model.parallel_output and gpc.is_initialized(ParallelMode.TENSOR):
- return gather(model_output, ParallelMode.TENSOR, -1)
+ return _gather(model_output, ParallelMode.TENSOR, -1)
else:
return model_output
diff --git a/internlm/checkpoint/checkpoint_manager.py b/internlm/checkpoint/checkpoint_manager.py
index 8ba3f4e8f..8e36b7745 100644
--- a/internlm/checkpoint/checkpoint_manager.py
+++ b/internlm/checkpoint/checkpoint_manager.py
@@ -11,16 +11,14 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.trainer import TrainState
-from internlm.initialize.launch import get_config_value
-from internlm.initialize.legacy.launch import (
- auto_resume_sanity_check,
- ckpt_info_sanity_check,
+from internlm.model.model_implementations.registry import model_initializer
+from internlm.model.model_implementations.transformers.base_model import (
+ BaseTransformerModel,
)
-from internlm.model.base_model import BaseModel
-from internlm.model.registry import model_initializer
from internlm.monitor import send_alert_message
from internlm.solver.optimizer import HybridZeroOptimizer, HybridZeroOptimizer_v2
from internlm.utils.common import get_current_device
+from internlm.utils.config import get_config_value
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import is_using_fsdp, is_using_hf
@@ -289,7 +287,7 @@ def __init__(
k: partial(try_load_internlm_ckpt_func, func=v) for k, v in LOAD_FUNC_DICT.items()
}
# Register huggingface ckpt load type
- if isinstance(model, BaseModel):
+ if isinstance(model, BaseTransformerModel):
self.defalut_load_type_func.update(
{
"hf": partial(
@@ -311,14 +309,10 @@ def __init__(
f.write("0")
self.load_ckpt_info = get_config_value(ckpt_config, "load_ckpt_info", None)
- if self.load_ckpt_info is None: # (legacy): Try Compatible with old interfaces
- self.load_ckpt_info = ckpt_info_sanity_check(ckpt_config)
# Auto-reload latest checkpoint, it will overwrite the setting of 'load_ckpt_info'.
- self.auto_resume = get_config_value(ckpt_config, "auto_resume", None)
- if self.auto_resume is None: # (legacy): Try Compatible with old interfaces
- self.auto_resume = auto_resume_sanity_check(ckpt_config)
- if self.auto_resume:
+ self.auto_resume = get_config_value(ckpt_config, "auto_resume", False)
+ if self.auto_resume and self.save_ckpt_folder and self.has_available_ckpt(self.save_ckpt_folder):
self.load_ckpt_info = self.query_lastest_ckpt()
if self.stop_file_path is None and gpc.is_rank_for_log():
@@ -393,6 +387,16 @@ def quit_signal_handler(self, train_state) -> bool:
return now_break, now_save_ckpt, save_type
+ def has_available_ckpt(self, folder) -> bool:
+ """Check if there is an available ckpt in the folder."""
+ folder = folder.split(":")[-1]
+ for _, _, files in os.walk(folder, followlinks=True):
+ for fn in files:
+ fn = fn.strip("/")
+ if fn.endswith(".step"):
+ return True
+ return False
+
def is_now_to_save_ckpt(self, train_state, force=False) -> (bool, CheckpointSaveType, bool):
save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False
if force:
@@ -446,7 +450,7 @@ def try_save_checkpoint(self, train_state, force=False):
)
if (
- isinstance(self.model, BaseModel)
+ isinstance(self.model, BaseTransformerModel)
and self.enable_internevo2hf_ckpt
and save_type == CheckpointSaveType.NORMAL_CHECKPOINT
and gpc.is_rank_for_log()
@@ -578,7 +582,7 @@ def try_resume_training(self, train_state: TrainState, current_time=""):
f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
)
- elif is_using_fsdp() and is_using_hf() and not self.auto_resume:
+ elif is_using_fsdp() and not self.auto_resume:
pass
else:
load_path = self.load_ckpt_info["path"]
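With the change above, `auto_resume` now defaults to False and only overrides `load_ckpt_info` when the save folder already contains a `*.step` marker. A hedged config sketch of the resulting behavior (paths and values are placeholders):

    ckpt = dict(
        enable_save_ckpt=True,
        save_ckpt_folder="local:llm_ckpts",
        # Used for cold starts / fine-tuning from another model; once a "*.step"
        # file exists under save_ckpt_folder, auto_resume loads the latest
        # checkpoint from there instead.
        load_ckpt_info=dict(path="local:llm_ckpts/49", content=("model",), ckpt_type="internevo"),
        auto_resume=True,
        checkpoint_every=50,
    )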
diff --git a/internlm/checkpoint/components.py b/internlm/checkpoint/components.py
index d96bb65c5..a94237ac5 100644
--- a/internlm/checkpoint/components.py
+++ b/internlm/checkpoint/components.py
@@ -9,7 +9,7 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.trainer import TrainState
-from internlm.model.moe import MoE
+from internlm.model.model_ops.moe import MoE
from internlm.solver.optimizer import HybridZeroOptimizer, HybridZeroOptimizer_v2
from internlm.utils.common import get_current_device
from internlm.utils.lazy import LazyObject
diff --git a/internlm/checkpoint/load_funcs.py b/internlm/checkpoint/load_funcs.py
index dde4bc523..13342afcb 100644
--- a/internlm/checkpoint/load_funcs.py
+++ b/internlm/checkpoint/load_funcs.py
@@ -1,14 +1,7 @@
# Copyright (c) InternLM. All rights reserved.
-from internlm.model.modeling_internlm import InternLM1
-from internlm.model.modeling_internlm2 import InternLM2
-from internlm.model.modeling_llama import Llama2
from internlm.utils.logger import get_logger
logger = get_logger(__file__)
-LOAD_FUNC_DICT = {
- "llama": Llama2.load_llama_pretrained_weights,
- "internlm_test": InternLM1.load_internlm_with_dynamic_parallel_size,
- "internlm2_test": InternLM2.load_internlm2_with_dynamic_parallel_size,
-}
+LOAD_FUNC_DICT = {}
diff --git a/internlm/checkpoint/utils.py b/internlm/checkpoint/utils.py
index cd8fae4bf..401bd54ec 100644
--- a/internlm/checkpoint/utils.py
+++ b/internlm/checkpoint/utils.py
@@ -1,17 +1,8 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
-import itertools
-
-import numpy as np
-import torch
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-
from internlm.core.context import global_context as gpc
-from internlm.core.parallel.shard import split_data_for_sequence_parallel
-from internlm.data.utils import packed_data_normalizer, unpack_data
from internlm.utils.logger import get_logger
-from internlm.utils.parallel import is_using_isp
logger = get_logger(__file__)
@@ -53,67 +44,3 @@ def process_load_info(load_info):
logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")
return load_content_str, load_ckpt_folder, load_content
-
-
-def init_fsdp_v1(model: FSDP, device: torch.device) -> FSDP:
- """
- Initialize Fully Sharded Data Parallel (FSDP) for the model.
- This function is needed to properly initialize FSDP when resuming from a checkpoint.
- It runs a forward pass with dummy inputs to ensure FSDP is fully initialized.
-
- References:
- https://github.com/pytorch/pytorch/issues/113496
- https://github.com/huggingface/transformers/pull/34032
- https://github.com/huggingface/transformers/issues/31892
-
- Args:
- model: The model to initialize with FSDP.
- device: The device to run the model on.
-
- Returns:
- The initialized FSDP model.
- """
- model.train()
- with torch.no_grad():
- # generate dummy packed sequence
- seq_len = gpc.config.data.seq_len * gpc.config.data.micro_bsz
- input_ids = [1] * seq_len
- label = input_ids[1:] + [-100]
- cu_seqlens = list(range(0, seq_len + gpc.config.data.seq_len, gpc.config.data.seq_len))
-
- input_ids = torch.tensor(input_ids, device=device).unsqueeze(0)
- label = torch.tensor(label, device=device).unsqueeze(0)
- indexes = torch.tensor(
- list(itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])])),
- device=device,
- ).unsqueeze(0)
- cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32).unsqueeze(0)
-
- data = {
- "input_ids": input_ids,
- "cu_seqlens": cu_seqlens,
- "indexes": indexes,
- "max_seqlen": seq_len,
- }
-
- data_fns = []
-
- # default data process function
- if gpc.config.data.use_packed_dataset:
- data_fns.append(packed_data_normalizer)
- else:
- data_fns.append(unpack_data)
-
- # support sequence parallel for isp
- if is_using_isp():
- data_fns.append(split_data_for_sequence_parallel)
-
- # generate dummy_input
- _data, _label = data, label
- for fn in data_fns:
- _data, _label = fn(_data, _label)
- dummy_input = _data
-
- # run a forward pass with dummy_input to initialize FSDP
- _ = model(**dummy_input)
- return model
diff --git a/internlm/core/context/__init__.py b/internlm/core/context/__init__.py
index b2fc95cc9..be444ea30 100644
--- a/internlm/core/context/__init__.py
+++ b/internlm/core/context/__init__.py
@@ -5,7 +5,6 @@
IS_TENSOR_ZERO_PARALLEL,
IS_WEIGHT_EXPERT_DATA_PARALLEL,
IS_WEIGHT_ZERO_PARALLEL,
- Config,
ParallelContext,
global_context,
)
@@ -19,6 +18,7 @@
ProcessGroupInitializer,
)
from .random import (
+ _SEED_MANAGER,
add_seed,
get_current_mode,
get_seeds,
@@ -30,7 +30,6 @@
)
__all__ = [
- "Config",
"IS_REPLICA_EXPERT_DATA_PARALLEL",
"IS_TENSOR_ZERO_PARALLEL",
"IS_REPLICA_ZERO_PARALLEL",
@@ -54,4 +53,5 @@
"get_current_mode",
"set_seed_states",
"sync_states",
+ "_SEED_MANAGER",
]
diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index 5278426ed..7e83129c8 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -3,12 +3,8 @@
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context
-import inspect
import random
import socket
-import sys
-from importlib.machinery import SourceFileLoader
-from pathlib import Path
from typing import Union
import numpy as np
@@ -17,6 +13,7 @@
from internlm.accelerator import get_accelerator
from internlm.utils.common import SingletonMeta
+from internlm.utils.config import Config
from internlm.utils.logger import get_logger
from internlm.utils.timeout import LLM_NCCL_TIMEOUT
from internlm.utils.utils import TensorParallelMode
@@ -46,97 +43,6 @@
internlm_accelerator = get_accelerator()
-class Config(dict):
- """This is a wrapper class for dict objects so that values of which can be
- accessed as attributes.
-
- Args:
- config (dict): The dict object to be wrapped.
- """
-
- def __init__(self, config: dict = None): # pylint: disable=W0231
- if config is not None:
- for k, v in config.items():
- self._add_item(k, v)
-
- def __missing__(self, key):
- raise KeyError(key)
-
- def __getattr__(self, key):
- try:
- value = super().__getitem__(key)
- return value
- except KeyError:
- raise AttributeError(key)
-
- def __setattr__(self, key, value):
- super().__setitem__(key, value)
-
- def _add_item(self, key, value):
- if isinstance(value, dict):
- self.__setattr__(key, Config(value))
- else:
- self.__setattr__(key, value)
-
- def update(self, config):
- assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects."
- for k, v in config.items():
- self._add_item(k, v)
- return self
-
- @staticmethod
- def from_file(filename: str):
- """Reads a python file and constructs a corresponding :class:`Config` object.
-
- Args:
- filename (str): Name of the file to construct the return object.
-
- Returns:
- :class:`Config`: A :class:`Config` object constructed with information in the file.
-
- Raises:
- AssertionError: Raises an AssertionError if the file does not exist, or the file is not a .py file
- """
-
- # check config path
- if isinstance(filename, str):
- filepath = Path(filename).absolute()
- elif isinstance(filename, Path):
- filepath = filename.absolute()
-
- assert filepath.exists(), f"{filename} is not found, please check your configuration path"
-
- # check extension
- extension = filepath.suffix
- assert extension == ".py", "only .py files are supported"
-
- # import the config as module
- remove_path = False
- if filepath.parent not in sys.path:
- sys.path.insert(0, (filepath))
- remove_path = True
-
- module_name = filepath.stem
- source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath))
- module = source_file.load_module() # pylint: disable=W4902,E1120,W1505
-
- # load into config
- config = Config()
-
- for k, v in module.__dict__.items():
- if k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v):
- continue
- else:
- config._add_item(k, v)
-
- # remove module
- del sys.modules[module_name]
- if remove_path:
- sys.path.pop(0)
-
- return config
-
-
class ParallelContext(metaclass=SingletonMeta):
"""This class provides interface functions for users to get the parallel context,
such as the global rank, the local rank, the world size, etc. of each device.
@@ -403,7 +309,7 @@ def init_global_dist(self, rank: int, world_size: int, backend: str, host: str,
use_cpu (bool): whether to set up cpu process group.
"""
# initialize the default process group
- init_method = f"tcp://[{host}]:{port}"
+ init_method = f"tcp://{host}:{port}"
dist.init_process_group(
rank=rank,
world_size=world_size,
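
Editor's note: the `Config` wrapper removed above now lives in `internlm.utils.config`; the other change in this hunk simply drops the square brackets around the host in the `tcp://` init method. For readers unfamiliar with the pattern, here is a minimal standalone sketch of the attribute-style dict access the deleted class provided — an illustration mirroring the removed code, not the relocated module.

```python
# Minimal sketch of the attribute-style dict wrapper removed above (now provided by
# internlm.utils.config.Config); this is an illustration, not the project's class.
class AttrDict(dict):
    def __init__(self, config: dict = None):
        super().__init__()
        if config is not None:
            for k, v in config.items():
                # nested dicts are converted recursively so chained attribute access works
                self[k] = AttrDict(v) if isinstance(v, dict) else v

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc

    def __setattr__(self, key, value):
        self[key] = value


cfg = AttrDict({"parallel": {"tensor": {"size": 2, "mode": "isp"}}})
assert cfg.parallel.tensor.size == 2                  # attribute-style access
assert cfg["parallel"]["tensor"]["mode"] == "isp"     # plain dict access still works
```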
diff --git a/internlm/core/engine.py b/internlm/core/engine.py
index 5989536dc..cfa3ac6a1 100644
--- a/internlm/core/engine.py
+++ b/internlm/core/engine.py
@@ -3,6 +3,7 @@
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
+from contextlib import nullcontext
from typing import List, Optional
import torch
@@ -10,11 +11,20 @@
from torch.nn.modules.loss import _Loss
from torch.optim.lr_scheduler import _LRScheduler
+from internlm.core.context import global_context as gpc
from internlm.core.gradient_handler import BaseGradientHandler
-from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
-from internlm.solver.schedulers.beta2_scheduler import Beta2Scheduler
+from internlm.solver.optimizer import BaseOptimizer
+from internlm.solver.schedulers import Beta2Scheduler
from internlm.utils.common import get_batch_size, move_to_device
+try:
+ import transformer_engine.pytorch as te
+ from transformer_engine.common.recipe import DelayedScaling, Format
+
+ HAS_TE = True
+except (ModuleNotFoundError, ImportError):
+ HAS_TE = False
+
class Engine:
"""
@@ -62,6 +72,7 @@ def __init__(
lr_scheduler: Optional[_LRScheduler] = None,
beta2_scheduler: Optional[Beta2Scheduler] = None,
criterion: Optional[_Loss] = None,
+ mtp_criterions: Optional[List[_Loss]] = None,
gradient_handlers: Optional[List[BaseGradientHandler]] = None,
clip_grad_norm: float = 0.0,
):
@@ -70,6 +81,7 @@ def __init__(
self._lr_scheduler = lr_scheduler
self._beta2_scheduler = beta2_scheduler
self._criterion = criterion
+ self._mtp_criterions = mtp_criterions
self._clip_grad_norm = clip_grad_norm
# state
@@ -78,6 +90,33 @@ def __init__(
# build gradient handler
self._gradient_handlers = gradient_handlers if gradient_handlers else []
+ # FP8 GEMM
+ fp8_cfg = gpc.config.get("fp8", None)
+ self.use_fp8 = HAS_TE and fp8_cfg is not None
+ if self.use_fp8:
+ self.fp8_recipe = DelayedScaling(
+ margin=fp8_cfg.get("margin", 0), # int, default = 0. Margin for scaling factor computation
+ fp8_format=Format[
+ fp8_cfg.get("fp8_format", "HYBRID")
+ ], # {Format.E4M3, Format.HYBRID}, default = Format.HYBRID. FP8 Data format
+ amax_history_len=fp8_cfg.get(
+ "amax_history_len", 1024
+ ), # int, default = 1024. Amax history window used for scaling factor computation
+ amax_compute_algo=fp8_cfg.get(
+ "amax_compute_algo", "max"
+ ), # {'max', 'most_recent'}, default = "max". Algorithm used for choosing amax
+ )
+
+ @property
+ def mtp_criterions(self):
+ """Returns the criterion (loss function) attached to the engine."""
+ return self._mtp_criterions
+
+ @mtp_criterions.setter
+ def mtp_criterions(self, mtp_criterions):
+ """Set the criterion (loss function) attached to the engine."""
+ self._mtp_criterions = mtp_criterions
+
@property
def model(self):
"""Returns the model attached to the engine."""
@@ -166,7 +205,9 @@ def __call__(self, *args, **kwargs):
Returns:
torch.Tensor: The output of the model.
"""
- return self.model(*args, **kwargs)
+ with te.fp8_autocast(enabled=self.use_fp8, fp8_recipe=self.fp8_recipe) if self.use_fp8 else nullcontext():
+ output = self.model(*args, **kwargs)
+ return output
def load_batch(self, data_iter, to_gpu=True):
"""
diff --git a/internlm/core/fsdp.py b/internlm/core/fsdp.py
new file mode 100644
index 000000000..904bf7de4
--- /dev/null
+++ b/internlm/core/fsdp.py
@@ -0,0 +1,270 @@
+import collections
+import itertools
+from typing import List, Optional, Set, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+ BackwardPrefetch,
+ ShardingStrategy,
+)
+from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+
+from internlm.accelerator.abstract_accelerator import get_accelerator
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+from internlm.core.parallel.shard import split_data_for_sequence_parallel
+from internlm.data.utils import packed_data_normalizer, unpack_data
+from internlm.utils.common import get_current_device
+from internlm.utils.lazy import LazyObject
+from internlm.utils.logger import get_logger
+from internlm.utils.parallel import is_using_fsdp, is_using_hf, is_using_isp
+
+internlm_accelerator = get_accelerator()
+logger = get_logger(__file__)
+
+try:
+ from torch.distributed._composable.fsdp import fully_shard
+ from torch.distributed.tensor import DeviceMesh
+ FSDP2_SUPPORTED = True
+except (ImportError, ModuleNotFoundError):
+ FSDP2_SUPPORTED = False
+
+try:
+ import torch.distributed.checkpoint as dcp
+ from torch.distributed.checkpoint.state_dict import (
+ StateDictOptions,
+ get_model_state_dict,
+ set_model_state_dict,
+ )
+
+ DCP_SUPPORTED = True
+except (ImportError, ModuleNotFoundError):
+ DCP_SUPPORTED = False
+
+RESUME_HF_FORMAT = False
+
+
+def _get_modules_to_materialize(
+ root_module: nn.Module,
+ ignored_modules: Set[nn.Module],
+) -> List[nn.Module]:
+ # Run BFS to collect the modules to materialize via `reset_parameters()`,
+ # stopping at any module with FSDP already applied or at ignored modules.
+ modules_to_materialize: List[nn.Module] = []
+ queue = collections.deque([root_module])
+ visited_modules: Set[nn.Module] = {root_module}
+ while queue:
+ module = queue.popleft()
+ modules_to_materialize.append(module)
+ for child_module in module.children():
+ if child_module not in visited_modules and child_module not in ignored_modules:
+ visited_modules.add(child_module)
+ queue.append(child_module)
+ return modules_to_materialize
+
+
+def _materialize_meta_module(
+ root_module: nn.Module,
+ ignored_modules: Set[nn.Module],
+ device_id: Optional[torch.device],
+) -> None:
+ # Run default meta device initialization
+ modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules)
+ module = None
+ try:
+ # Assume that each module's `reset_parameters()` only initializes its
+ # own parameters and not those of its children
+ with torch.no_grad():
+ for module in modules_to_materialize:
+ # As a contract to the user, only call `reset_parameters()` if
+ # the module has directly managed parameters/buffers
+ module_state_iter = itertools.chain(module.parameters(recurse=False), module.buffers(recurse=False))
+ has_module_states = len(list(module_state_iter)) > 0
+ if has_module_states:
+ module.to_empty(device=device_id, recurse=False)
+ module.reset_parameters() # type: ignore[operator]
+ except BaseException as e:
+ logger.warning(
+ "Unable to call `reset_parameters()` for module on meta "
+ f"device with error {str(e)}. Please ensure that your module of"
+ f"type {type(module)} implements a `reset_parameters()` method." # type: ignore[possibly-undefined]
+ )
+ raise e
+
+
+def _init_fsdp_v1(model: FSDP, device: torch.device) -> FSDP:
+ """
+ Initialize Fully Sharded Data Parallel (FSDP) for the model.
+ This function is needed to properly initialize FSDP when resuming from a checkpoint.
+ It runs a forward pass with dummy inputs to ensure FSDP is fully initialized.
+
+ References:
+ https://github.com/pytorch/pytorch/issues/113496
+ https://github.com/huggingface/transformers/pull/34032
+ https://github.com/huggingface/transformers/issues/31892
+
+ Args:
+ model: The model to initialize with FSDP.
+ device: The device to run the model on.
+
+ Returns:
+ The initialized FSDP model.
+ """
+ model.train()
+ with torch.no_grad():
+ # generate dummy packed sequence
+ seq_len = gpc.config.data.seq_len * gpc.config.data.micro_bsz
+ input_ids = [1] * seq_len
+ label = input_ids[1:] + [-100]
+ cu_seqlens = list(range(0, seq_len + gpc.config.data.seq_len, gpc.config.data.seq_len))
+
+ input_ids = torch.tensor(input_ids, device=device).unsqueeze(0)
+ label = torch.tensor(label, device=device).unsqueeze(0)
+ indexes = torch.tensor(
+ list(itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])])),
+ device=device,
+ ).unsqueeze(0)
+ cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32).unsqueeze(0)
+
+ data = {
+ "input_ids": input_ids,
+ "cu_seqlens": cu_seqlens,
+ "indexes": indexes,
+ "max_seqlen": seq_len,
+ }
+
+ data_fns = []
+
+ # default data process function
+ if gpc.config.data.use_packed_dataset:
+ data_fns.append(packed_data_normalizer)
+ else:
+ data_fns.append(unpack_data)
+
+ # support sequence parallel for isp
+ if is_using_isp():
+ data_fns.append(split_data_for_sequence_parallel)
+
+ # generate dummy_input
+ _data, _label = data, label
+ for fn in data_fns:
+ _data, _label = fn(_data, _label)
+ dummy_input = _data
+
+ # run a forward pass with dummy_input to initialize FSDP
+ _ = model(**dummy_input)
+ return model
+
+
+def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
+ if is_using_fsdp():
+ assert isinstance(model, nn.Module), "Currently FSDP does not support pipeline parallelism."
+ wrap_cls = tuple(
+ LazyObject(warp_cls["mod"], warp_cls["mod_cls"]).build() for warp_cls in gpc.config.get("fsdp_wrap_cls", [])
+ )
+ fsdp_mode = gpc.config.parallel.fsdp.get("mode", "v1")
+ fsdp_init_method = gpc.config.parallel.fsdp.get("init_method", "cuda")
+ if gpc.is_using_parallel_mode(ParallelMode.EXPERT):
+ assert (
+ gpc.get_world_size(ParallelMode.EXPERT_DATA) * gpc.get_world_size(ParallelMode.EXPERT)
+ == gpc.get_world_size(ParallelMode.GLOBAL)
+ )
+
+ if fsdp_mode == "v1":
+ ignored_mod = []
+ if gpc.is_using_parallel_mode(ParallelMode.EXPERT):
+ for layer_id, layer in enumerate(model.model.model.layers if is_using_hf() else model.model.layers):
+ if layer_id >= gpc.config.model.first_k_dense_replace:
+ # The expert module must follow this modeling pattern when EP is enabled.
+ # Change the expert module name here if your model uses a different one.
+ # TODO: keep this hard-coded, or make it config-driven?
+ layer.mlp.experts = FSDP(
+ layer.mlp.experts,
+ process_group=gpc.get_group(ParallelMode.EXPERT_DATA),
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
+ sync_module_states=fsdp_init_method != "cuda", # sync model parameters
+ forward_prefetch=True,
+ backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+ limit_all_gathers=True,
+ use_orig_params=True,
+ device_id=None if fsdp_init_method == "cuda" else get_current_device(), # needed for sync_module_states
+ )
+ ignored_mod.append(layer.mlp.experts)
+ model = FSDP(
+ module=model,
+ process_group=gpc.get_group(ParallelMode.GLOBAL),
+ sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO2: SHARD_GRAD_OP, ZeRO3: FULL_SHARD
+ auto_wrap_policy=ModuleWrapPolicy(wrap_cls),
+ sync_module_states=fsdp_init_method != "cuda", # sync model parameters
+ forward_prefetch=True,
+ backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+ limit_all_gathers=True,
+ use_orig_params=True,
+ device_id=None if fsdp_init_method == "cuda" else get_current_device(), # needed for sync_module_states
+ ignored_modules=ignored_mod,
+ )
+ # For FSDP v1, run a dummy forward pass so that checkpoint resuming works correctly.
+ # This workaround is needed because FSDP v1 initializes lazily during model construction.
+ # FYI: https://github.com/pytorch/pytorch/issues/113496
+ model = _init_fsdp_v1(model, get_current_device())
+ elif FSDP2_SUPPORTED and fsdp_mode == "v2":
+ fsdp_kwargs = {
+ "reshard_after_forward": True, # ZeRO2: False, ZeRO3: True
+ }
+ if gpc.is_using_parallel_mode(ParallelMode.EXPERT):
+ device_mesh = DeviceMesh.from_group(
+ group=[gpc.get_group(ParallelMode.EXPERT), gpc.get_group(ParallelMode.EXPERT_DATA)],
+ device_type="cuda",
+ mesh=torch.arange(
+ gpc.get_world_size(ParallelMode.GLOBAL),
+ ).view((gpc.get_world_size(ParallelMode.EXPERT), gpc.get_world_size(ParallelMode.EXPERT_DATA))),
+ mesh_dim_names=("ep", "edp"),
+ )
+ for layer_id, layer in enumerate(model.model.model.layers if is_using_hf() else model.model.layers):
+ if layer_id >= gpc.config.model.first_k_dense_replace:
+ # The expert module must follow this modeling pattern when EP is enabled.
+ # Change the expert module name here if your model uses a different one.
+ # TODO: keep this hard-coded, or make it config-driven?
+ fully_shard(layer.mlp.experts, mesh=device_mesh["edp"], **fsdp_kwargs)
+ for module in model.modules():
+ if isinstance(module, wrap_cls):
+ fully_shard(module, **fsdp_kwargs)
+ fully_shard(model, **fsdp_kwargs)
+ if fsdp_init_method == "meta":
+ _materialize_meta_module(model, set(), get_current_device())
+ elif fsdp_init_method == "cpu":
+ model.to(get_current_device())
+ else:
+ raise ValueError(f"Unsupported FSDP mode: {fsdp_mode}")
+
+ if not gpc.config.ckpt.get("auto_resume", False):
+ load_ckpt_info = gpc.config.ckpt.load_ckpt_info
+ load_ckpt_path = load_ckpt_info.get("path", None)
+ load_ckpt_content = load_ckpt_info.get("content", [])
+ if load_ckpt_path:
+ assert load_ckpt_content == (
+ "model",
+ ), "If auto_resume=False and checkpoint path is given, only model can be loaded"
+ if DCP_SUPPORTED:
+ if is_using_hf() and RESUME_HF_FORMAT:
+ hf = gpc.config.hf
+ mod = LazyObject(hf.mod, hf.mod_cls)
+ mod = mod.build()
+ state_dict = mod.from_pretrained(
+ pretrained_model_name_or_path=load_ckpt_path, use_safetensors=True
+ ).state_dict()
+ state_dict = {f"model.{key}": state_dict[key].clone().detach() for key in state_dict}
+ set_model_state_dict(
+ model=model, model_state_dict=state_dict, options=StateDictOptions(full_state_dict=True)
+ )
+ else:
+ state_dict = get_model_state_dict(model=model)
+ state_dict = {key: state_dict[key].clone().detach() for key in state_dict}
+ dcp.load(state_dict=state_dict, checkpoint_id=load_ckpt_path)
+ set_model_state_dict(model=model, model_state_dict=state_dict)
+ del state_dict
+ internlm_accelerator.empty_cache()
+ else:
+ raise RuntimeError("DCP is not supported in this version of PyTorch.")
+
+ return model
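
Editor's note: the new `wrap_FSDP_model` reads several config fields (`parallel.fsdp.mode`, `parallel.fsdp.init_method`, the top-level `fsdp_wrap_cls` list, and `ckpt.load_ckpt_info`). A hypothetical config fragment showing those knobs is sketched below; the block class path inside `fsdp_wrap_cls` is illustrative, and any additional switch checked by `is_using_fsdp()` is not shown in this hunk.

```python
# Hypothetical config fragment for the FSDP path above. Field names are taken from the
# code in this file; the concrete block class path is an illustrative placeholder.
parallel = dict(
    fsdp=dict(
        mode="v1",           # "v1" -> torch FSDP, "v2" -> fully_shard (when available)
        init_method="cuda",  # "cuda" | "cpu" | "meta"
    ),
)

# classes to auto-wrap, resolved lazily via LazyObject(mod, mod_cls)
fsdp_wrap_cls = [
    dict(mod="internlm.model.model_ops.modules.block", mod_cls="TransformerBlock"),  # illustrative
]

ckpt = dict(
    auto_resume=False,
    # with auto_resume=False and a path set, only the model weights may be loaded
    load_ckpt_info=dict(path="/path/to/checkpoint", content=("model",)),
)
```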
diff --git a/internlm/core/naive_amp.py b/internlm/core/naive_amp.py
index 7cac640da..006c1a2ce 100644
--- a/internlm/core/naive_amp.py
+++ b/internlm/core/naive_amp.py
@@ -14,7 +14,7 @@
from internlm.accelerator import AcceleratorType, get_accelerator
from internlm.core.context import ParallelMode
-from internlm.core.context.parallel_context import global_context as gpc
+from internlm.core.context import global_context as gpc
internlm_accelerator = get_accelerator()
@@ -94,6 +94,8 @@ def _convert_to_fp32(self, input_: Any):
"""Converts the input to fp32 if it is a Tensor of dtype float16."""
if isinstance(input_, Tensor) and input_.dtype in (torch.float16, torch.bfloat16):
input_ = input_.float()
+ elif isinstance(input_, (tuple, list)):
+ input_ = [self._convert_to_fp32(val) for val in input_]
return input_
def convert_to_fp32(self, out):
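
Editor's note: the naive-AMP change above makes fp32 conversion recursive over tuples and lists. A tiny self-contained sketch of the same behavior:

```python
# Minimal sketch of the recursive fp32 conversion added above: half-precision tensors
# are upcast, and tuples/lists are now walked element by element.
import torch


def convert_to_fp32(x):
    if isinstance(x, torch.Tensor) and x.dtype in (torch.float16, torch.bfloat16):
        return x.float()
    if isinstance(x, (tuple, list)):
        return [convert_to_fp32(v) for v in x]
    return x


outs = (torch.ones(2, dtype=torch.bfloat16), [torch.ones(2, dtype=torch.float16), 3])
converted = convert_to_fp32(outs)
assert converted[0].dtype == torch.float32 and converted[1][0].dtype == torch.float32
```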
diff --git a/internlm/core/parallel/comm/__init__.py b/internlm/core/parallel/comm/__init__.py
index be170f286..422578ed1 100644
--- a/internlm/core/parallel/comm/__init__.py
+++ b/internlm/core/parallel/comm/__init__.py
@@ -1,3 +1,48 @@
from .attn_offload import get_offload_manager, initialize_offload_manager
+from .isp import (
+ EmbeddingWeightParallelCommunicator,
+ HeadWeightParallelCommunicator,
+ ISPCommModelConfig,
+ ISPCommunicator,
+ ISPCommunicatorSchedulerHook,
+ ISPCommunicatorWrapper,
+ WPCommunicator,
+ auto_wrap_distributed_attention,
+ auto_wrap_func_distributed_attention,
+)
+from .tensor import (
+ EmbeddingSequenceParallelCommunicator,
+ EmbeddingTensorParallelCommunicator,
+ HeadSequenceParallelCommunicator,
+ HeadTensorParallelCommunicator,
+ LinearRole,
+ MoESequenceParallelCommunicator,
+ SequenceParallelCommunicator,
+ TensorParallelCommunicator,
+ TPCommunicator,
+)
+from .zero import ParamAsyncBcastHandler
-__all__ = ["initialize_offload_manager", "get_offload_manager"]
+__all__ = [
+ "initialize_offload_manager",
+ "get_offload_manager",
+ "EmbeddingWeightParallelCommunicator",
+ "HeadWeightParallelCommunicator",
+ "ISPCommModelConfig",
+ "ISPCommunicator",
+ "ISPCommunicatorWrapper",
+ "ISPCommunicatorSchedulerHook",
+ "WPCommunicator",
+ "auto_wrap_distributed_attention",
+ "auto_wrap_func_distributed_attention",
+ "EmbeddingSequenceParallelCommunicator",
+ "EmbeddingTensorParallelCommunicator",
+ "HeadSequenceParallelCommunicator",
+ "HeadTensorParallelCommunicator",
+ "LinearRole",
+ "MoESequenceParallelCommunicator",
+ "SequenceParallelCommunicator",
+ "TensorParallelCommunicator",
+ "TPCommunicator",
+ "ParamAsyncBcastHandler",
+]
diff --git a/internlm/core/parallel/comm/attn_offload.py b/internlm/core/parallel/comm/attn_offload.py
index da23f3ae8..02f1cd15d 100644
--- a/internlm/core/parallel/comm/attn_offload.py
+++ b/internlm/core/parallel/comm/attn_offload.py
@@ -1,8 +1,10 @@
import torch
+from internlm.accelerator import AcceleratorType, get_accelerator
from internlm.utils.common import get_current_device
global_attn_offload = None
+internlm_accelerator = get_accelerator()
class AttnOffloadManager:
@@ -117,7 +119,8 @@ def preload_fa_output_with_layer(self, layer_idx):
def initialize_offload_manager(enable_cpu_offload: bool = False):
global global_attn_offload
if global_attn_offload is None:
- global_attn_offload = AttnOffloadManager(enable_cpu_offload)
+ if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU:
+ global_attn_offload = AttnOffloadManager(enable_cpu_offload)
return global_attn_offload
diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py
index 7f4c5d7e9..81424c5bd 100644
--- a/internlm/core/parallel/comm/isp.py
+++ b/internlm/core/parallel/comm/isp.py
@@ -29,9 +29,8 @@
expandKVPacked,
reduce_scatter_raw,
)
-from internlm.model.modules.embedding import Embedding1D
-from internlm.model.modules.linear import ParallelLinearWithCommExt
-from internlm.model.modules.utils import is_moe_param
+from internlm.model.model_ops.modules.linear import ParallelLinearWithCommExt
+from internlm.model.model_ops.modules.utils import is_moe_param
from internlm.utils.common import SchedulerHook, get_current_device
from internlm.utils.utils import (
CuSeqlenType,
@@ -183,14 +182,19 @@ class EmbeddingWeightParallelCommunicator:
"""
def __init__(self, parallel_mode: ParallelMode) -> None:
+ from internlm.model.model_ops.modules.embedding import Embedding1D
+
+ self.embedding1d_cls = Embedding1D
self.parallel_mode = parallel_mode
self.gather_dim = 0
self._cur_micro_step = 0
self._num_micro_step = gpc.config.data.micro_num
- def register_module_hook(self, module: Embedding1D) -> None:
- assert isinstance(module, Embedding1D), "Embbeding weight parallel communicator is only support Embedding1D"
+ def register_module_hook(self, module: nn.Module) -> None:
+ assert isinstance(
+ module, self.embedding1d_cls
+ ), "Embbeding weight parallel communicator is only support Embedding1D"
module.weight.evo_tensor = None
self.gather_dim = 0 if module.vocab_parallel else 1
@@ -1501,6 +1505,17 @@ def _q_k_v(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwa
Returns:
* output (Tensor): context output
"""
+ # if there are not enough kv heads to be split across the sp group,
+ # replicate the kv heads so that every rank still gets one
+ num_head_k = k.shape[2]
+ if self.sp_size > num_head_k:
+ assert self.sp_size % num_head_k == 0, "the sp size must be divisible by num_head_k."
+ k = expandKVPacked(k, self.sp_size // num_head_k, 2)
+ num_head_v = v.shape[2]
+ if self.sp_size > num_head_v:
+ assert self.sp_size % num_head_v == 0, "the sp size must be divisible by num_head_v."
+ v = expandKVPacked(v, self.sp_size // num_head_v, 2)
+
# self._scatter_gather_idx["q"] = [1, 0] # q/k/v shape: [sequence, head, head_dim]
# q shpae: [1, packlen, n_head, head_dim] or [batch, seqlen, n_head, head_dim]
# scatter in n_head and gather in seqlen(packlen)
@@ -1559,8 +1574,10 @@ def _attetion_constructor(*args, attn_impl: type, **kwargs) -> Callable:
if tp_mode != TensorParallelMode.isp.name:
return attn_impl(*args, **kwargs)
else:
- return DistributedAttention(
- local_attention=attn_impl, sequence_process_group=gpc.get_group(ParallelMode.TENSOR)
- )(*args, **kwargs)
+ if gpc.config.parallel.sequence_2D.enable is True:
+ spg = gpc.get_group(ParallelMode.HEAD)
+ else:
+ spg = gpc.get_group(ParallelMode.TENSOR)
+ return DistributedAttention(local_attention=attn_impl, sequence_process_group=spg)(*args, **kwargs)
return partial(_attetion_constructor, attn_impl=attn_impl)
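
Editor's note: the new kv-head handling in the distributed attention path replicates kv heads when the sequence-parallel world size exceeds the kv head count. `expandKVPacked` is the repo's helper; the sketch below uses `torch.repeat_interleave` as a generic stand-in and may not match the exact layout the project produces.

```python
# Hedged sketch of kv-head replication: when sp_size exceeds the number of kv heads,
# each kv head is duplicated along the head dimension so it can still be scattered
# across sequence-parallel ranks.
import torch


def expand_kv_heads(kv: torch.Tensor, sp_size: int, head_dim: int = 2) -> torch.Tensor:
    num_heads = kv.shape[head_dim]
    if sp_size <= num_heads:
        return kv
    assert sp_size % num_heads == 0, "the sp size must be divisible by the kv head count"
    return torch.repeat_interleave(kv, sp_size // num_heads, dim=head_dim)


k = torch.randn(1, 128, 2, 64)              # [batch, seqlen, num_kv_heads, head_dim]
k_expanded = expand_kv_heads(k, sp_size=8)
assert k_expanded.shape == (1, 128, 8, 64)
```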
diff --git a/internlm/core/parallel/comm/tensor.py b/internlm/core/parallel/comm/tensor.py
index 229f9a9b0..a9c8b1f44 100644
--- a/internlm/core/parallel/comm/tensor.py
+++ b/internlm/core/parallel/comm/tensor.py
@@ -10,7 +10,7 @@
from torch import distributed as dist
from internlm.core.context import ParallelMode
-from internlm.core.context.parallel_context import global_context as gpc
+from internlm.core.context import global_context as gpc
from internlm.core.parallel.comm.utils import (
DUMMY_HANDLE_CONST,
AsyncCommHandle,
@@ -23,8 +23,7 @@
reduce_scatter_raw,
split_forward_gather_backward,
)
-from internlm.model.modules.embedding import Embedding1D
-from internlm.model.moe.moe import MoE
+from internlm.model.model_ops.moe.moe import MoE
# input gather dim
_GATHER_DIM = 1 # shape: [batch, seqlen, dim] or [1, packlen, dim]
@@ -339,14 +338,21 @@ class EmbeddingTensorParallelCommunicator:
"""
def __init__(self, parallel_mode: ParallelMode) -> None:
+ from internlm.model.model_ops.modules.embedding import Embedding1D
+
+ self.embedding1d_class = Embedding1D
self._parallel_mode = parallel_mode
- def register_module_hook(self, module: Embedding1D) -> None:
- assert isinstance(module, Embedding1D), "Embbeding tensor parallel communicator is only support Embedding1D"
+ def register_module_hook(self, module: torch.nn.Module) -> None:
+ assert isinstance(
+ module, self.embedding1d_class
+ ), "Embbeding tensor parallel communicator is only support Embedding1D"
module.register_forward_hook(self.output_hook)
- def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tuple[Any]: # pylint: disable=W0613
+ def output_hook(
+ self, module: torch.nn.Module, args: Any, output: Tuple[Any] # pylint: disable=W0613
+ ) -> Tuple[Any]:
"""
split output after forward and allgather grad_output before backward.
"""
@@ -366,14 +372,21 @@ class EmbeddingSequenceParallelCommunicator:
"""
def __init__(self, parallel_mode: ParallelMode) -> None:
+ from internlm.model.model_ops.modules.embedding import Embedding1D
+
+ self.embedding1d_class = Embedding1D
self._parallel_mode = parallel_mode
- def register_module_hook(self, module: Embedding1D) -> None:
- assert isinstance(module, Embedding1D), "Embbeding sequence parallel communicator is only support Embedding1D"
+ def register_module_hook(self, module: torch.nn.Module) -> None:
+ assert isinstance(
+ module, self.embedding1d_class
+ ), "Embbeding sequence parallel communicator is only support Embedding1D"
module.register_forward_hook(self.output_hook)
- def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tuple[Any]: # pylint: disable=W0613
+ def output_hook(
+ self, module: torch.nn.Module, args: Any, output: Tuple[Any] # pylint: disable=W0613
+ ) -> Tuple[Any]:
"""
split output after forward and allgather grad_output before backward.
"""
diff --git a/internlm/core/parallel/comm/zero.py b/internlm/core/parallel/comm/zero.py
index 58929290f..3aa8f8a35 100644
--- a/internlm/core/parallel/comm/zero.py
+++ b/internlm/core/parallel/comm/zero.py
@@ -7,14 +7,13 @@
from torch import distributed as dist
from torch import nn
+from torch._utils import _flatten_dense_tensors
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import unwrap_naive_amp
-from internlm.core.parallel.comm.isp import ISPCommunicatorWrapper
-from internlm.model.modules.embedding import Embedding1D
-from internlm.model.modules.linear import ScaleColumnParallelLinear
-from internlm.solver.optimizer.utils import flatten
+from internlm.core.parallel.comm import ISPCommunicatorWrapper
+from internlm.model.model_ops.modules.linear import ScaleColumnParallelLinear
class ParamAsyncBcastHandler:
@@ -28,6 +27,10 @@ def __init__(
model: Union[nn.Module, nn.ModuleList],
isp_communicator: ISPCommunicatorWrapper = None,
) -> None:
+ from internlm.model.model_ops.modules.embedding import Embedding1D
+
+ self.embedding1d_cls = Embedding1D
+
self._block_to_param: Dict[nn.Module, List[nn.Parameter]] = OrderedDict()
self._param_to_rank: Dict[nn.Parameter, int] = {}
self._block_to_rank: Dict[nn.Module, int] = {}
@@ -121,7 +124,7 @@ def _pre_forward_hook(model: nn.Module, *args, **kwargs): # pylint: disable=W06
# NOTE: Although the layernorm layer does not have explicit processing,
# both ISPCommunicator and ParamAsyncBcastHandler handle transformer blocks as granularity,
# so everything is fine.
- if isp_communicator is None or isinstance(block, (Embedding1D, ScaleColumnParallelLinear)):
+ if isp_communicator is None or isinstance(block, (self.embedding1d_cls, ScaleColumnParallelLinear)):
block.register_forward_pre_hook(_pre_forward_hook)
if isp_communicator:
isp_communicator.register_prerequisite_for_forward_prefetch_hooks(_pre_forward_hook)
@@ -156,7 +159,9 @@ def _pre_forward_hook(model: nn.Module, *args, **kwargs): # pylint: disable=W06
for working_param, all_splited_param in zip(
self._block_working_params[block_name], all_splited_param_list
):
- working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].view_as(working_param))
+ working_param.data.copy_(
+ _flatten_dense_tensors(all_splited_param)[: working_param.numel()].view_as(working_param)
+ )
self._block_allgather_handles[block_name] = None
self._block_gathered_params[block_name] = []
@@ -170,7 +175,7 @@ def _pre_forward_hook(model: nn.Module, *args, **kwargs): # pylint: disable=W06
# NOTE: Although the layernorm layer does not have explicit processing,
# both ISPCommunicator and ParamAsyncBcastHandler handle transformer blocks as granularity,
# so everything is fine.
- if isp_communicator is None or isinstance(block, (Embedding1D, ScaleColumnParallelLinear)):
+ if isp_communicator is None or isinstance(block, (self.embedding1d_cls, ScaleColumnParallelLinear)):
block.register_forward_pre_hook(_pre_forward_hook)
if isp_communicator:
isp_communicator.register_prerequisite_for_forward_prefetch_hooks(_pre_forward_hook)
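
Editor's note: the zero.py change swaps the optimizer-local `flatten` helper for `torch._utils._flatten_dense_tensors`. A small sketch of what that call does with the gathered shards:

```python
# _flatten_dense_tensors concatenates a list of gathered shards into one contiguous
# 1-D buffer, from which the working parameter copies back its own slice.
import torch
from torch._utils import _flatten_dense_tensors

shards = [torch.arange(4, dtype=torch.float32), torch.arange(4, 8, dtype=torch.float32)]
flat = _flatten_dense_tensors(shards)            # shape: [8]

working_param = torch.empty(2, 3)                # only needs the first 6 elements
working_param.data.copy_(flat[: working_param.numel()].view_as(working_param))
print(working_param)
```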
diff --git a/internlm/core/parallel/shard.py b/internlm/core/parallel/shard.py
index 979f5cf23..d23b2e3de 100644
--- a/internlm/core/parallel/shard.py
+++ b/internlm/core/parallel/shard.py
@@ -34,7 +34,7 @@ def _split_data_for_sequence_parallel(data, label):
data["indexes"] = _split(data["indexes"], ParallelMode.TENSOR, dim=_indexes_seq_dim)
# NOTICE: For compatibility where the shape of position_ids is [batch, seqlen, ...]
- if ("inject_info" in gpc.config.model and gpc.config.model.inject_info.get("data_helper", False)) or is_using_hf():
+ if is_using_hf():
_position_ids_seq_dim = 1
data["position_ids"] = _split(data["position_ids"], ParallelMode.TENSOR, dim=_position_ids_seq_dim)
diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py
index 7e309beb6..ff7b86e64 100644
--- a/internlm/core/scheduler/no_pipeline_scheduler.py
+++ b/internlm/core/scheduler/no_pipeline_scheduler.py
@@ -59,6 +59,44 @@ def __init__(
super().__init__(data_process_func)
+ def _call_engine_mtp_criterion(self, engine: Engine, outputs: Any, labels: Any):
+ """Calls the engine's criterion with the given outputs and labels.
+ Args:
+ engine (internlm.core.Engine): InternLM engine for training and inference.
+ outputs (Any): The outputs from the model, can be of type torch.Tensor, list, tuple, or dict.
+ labels (Any): The labels for the outputs, can be of type torch.Tensor, list, tuple, or dict.
+ """
+ assert isinstance(
+ outputs, (torch.Tensor, list, tuple, dict)
+ ), f"Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}"
+
+ mtp_losses = []
+ for i, (output, label) in enumerate(zip(outputs, labels)):
+ if isinstance(output, torch.Tensor):
+ output = (output,)
+ if isinstance(label, torch.Tensor):
+ label = (label,)
+
+ self._call_hooks("before_criterion", output, label)
+ if isinstance(output, (tuple, list)) and isinstance(label, (tuple, list)):
+ mtp_loss = engine.mtp_criterions[i](*output, *label)
+ elif isinstance(output, (tuple, list)) and isinstance(label, dict):
+ mtp_loss = engine.mtp_criterions[i](*output, **label)
+ elif isinstance(output, dict) and isinstance(label, dict):
+ mtp_loss = engine.mtp_criterions[i](**output, **label)
+ elif isinstance(output, dict) and isinstance(label, (list, tuple)):
+ raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(label)}")
+ else:
+ raise TypeError(
+ f"Expected model outputs and labels to be of type torch.Tensor ' \
+ '(which is auto-converted to tuple), list, tuple, or dict, ' \
+ 'but got {type(output)} (model outputs) and {type(label)} (labels)"
+ )
+ self._call_hooks("after_criterion", mtp_loss)
+ mtp_losses.append(mtp_loss)
+
+ return mtp_losses
+
def pre_processing(self, engine: Engine):
"""Performs actions before running the schedule.
@@ -116,8 +154,10 @@ def _train_one_batch(
with conditional_context(torch.no_grad(), enable=forward_only):
self._call_hooks("before_forward", data)
if hasattr(gpc.config.model, "num_experts"):
- # moe is used
- output, moe_losses = self._call_engine(engine, data)
+ if hasattr(gpc.config.model, "num_mtp_layers") and gpc.config.model.num_mtp_layers > 0:
+ output, moe_losses, mtp_outputs = self._call_engine(engine, data)
+ else:
+ output, moe_losses = self._call_engine(engine, data)
else:
output = self._call_engine(engine, data)
self._call_hooks("after_forward", output)
@@ -128,6 +168,26 @@ def _train_one_batch(
self._call_hooks("before_criterion", output, label)
loss = self._call_engine_criterion(engine, output, label)
self._call_hooks("after_criterion", loss)
+
+ if hasattr(gpc.config.model, "num_mtp_layers") and gpc.config.model.num_mtp_layers > 0:
+ mtp_labels = []
+ for i in range(gpc.config.model.num_mtp_layers):
+ mtp_labels.append(
+ torch.cat(
+ [
+ label[:, i + 1 :],
+ torch.full((label.size(0), i + 1), -100, dtype=label.dtype, device=label.device),
+ ],
+ dim=1,
+ )
+ )
+ mtp_losses = self._call_engine_mtp_criterion(engine, mtp_outputs, mtp_labels)
+ mtp_loss = sum(mtp_losses) * gpc.config.loss.mtp_loss_coeff
+ mtp_loss /= scale_loss
+ loss += mtp_loss
+ else:
+ mtp_loss = None
+
moe_loss = (
sum(moe_losses) * gpc.config.loss.moe_loss_coeff # pylint: disable=E0606
if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
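
Editor's note: the MTP labels built above are just left-shifted copies of the main label tensor, padded with the ignore index -100; depth `i` is shifted by `i + 1` positions. A minimal reproduction of that transform:

```python
# Reproduces the MTP label construction from the scheduler above on a toy tensor.
import torch

label = torch.tensor([[10, 11, 12, 13, 14]])
num_mtp_layers = 2

mtp_labels = []
for i in range(num_mtp_layers):
    shifted = torch.cat(
        [
            label[:, i + 1 :],
            torch.full((label.size(0), i + 1), -100, dtype=label.dtype, device=label.device),
        ],
        dim=1,
    )
    mtp_labels.append(shifted)

assert mtp_labels[0].tolist() == [[11, 12, 13, 14, -100]]
assert mtp_labels[1].tolist() == [[12, 13, 14, -100, -100]]
```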
diff --git a/internlm/core/trainer.py b/internlm/core/trainer.py
index 3b01d3afd..0579c2e8c 100644
--- a/internlm/core/trainer.py
+++ b/internlm/core/trainer.py
@@ -4,12 +4,25 @@
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
import json
+import math
import os
+import time
from collections import deque
-from typing import Iterable, Optional
+from typing import Iterable, List, Optional
+from torch.utils.data import DataLoader
+
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
from internlm.core.engine import Engine
+from internlm.core.parallel.comm import ISPCommunicatorSchedulerHook
from internlm.core.scheduler import BaseScheduler, NonPipelineScheduler
+from internlm.data.utils import unpack_type_ids
+from internlm.model.model_ops.metrics import SchedulerMetricHook
+from internlm.monitor import monitor_manager as mm
+from internlm.utils.common import SchedulerHook, set_env_var
+from internlm.utils.megatron_timers import megatron_timer as timer
+from internlm.utils.timeout import llm_timeout
class TrainState:
@@ -206,3 +219,220 @@ def execute_schedule(self, data_iter: Iterable, **kwargs):
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss).
"""
return self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
+
+
+def get_scheduler_hooks(metric, zero_optim, isp_communicator_wrapper) -> List[SchedulerHook]:
+ scheduler_hooks: List[SchedulerHook] = []
+
+ if metric is not None:
+ scheduler_hooks.append(
+ SchedulerMetricHook(
+ metric=metric,
+ skip=(
+ gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
+ and hasattr(gpc.config.model, "num_chunks")
+ and gpc.config.model.num_chunks > 1
+ and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
+ ),
+ ),
+ )
+
+ if isp_communicator_wrapper is not None:
+ for isp_communicator in isp_communicator_wrapper.isp_communicators:
+ if isp_communicator is not None and isp_communicator.overlap:
+ scheduler_hooks.append(ISPCommunicatorSchedulerHook(isp_communicator, zero_optim))
+
+ return scheduler_hooks
+
+
+@llm_timeout(func_name="load_new_batch")
+def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
+ """
+ Load and return the new batch data based on training data loader.
+
+ Args:
+ train_dl (torch.utils.data.DataLoader): Dataloader for training.
+ train_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
+ train_state (TrainState): Current training state.
+
+ Returns: A batch of data and the updated train_iter.
+ """
+
+ timer("batch-gen").start()
+ try:
+ batch = next(train_iter) # structure is ({'input_ids': Tensor, 'cu_seqlens': Tensor}, Tensor)
+ if hasattr(train_state, "batch_sampler_iter"):
+ next(train_state.batch_sampler_iter)
+ except StopIteration:
+ train_iter = iter(train_dl)
+ batch = next(train_iter)
+ train_state.num_consumed_samples_in_epoch = 0
+ if hasattr(train_state, "batch_sampler"):
+ train_state.batch_sampler.batch_count = 0
+ train_state.batch_sampler.num_consumed_samples_in_epoch = 0
+ train_state.batch_sampler_iter = iter(train_state.batch_sampler)
+ next(train_state.batch_sampler_iter)
+ timer("batch-gen").stop()
+
+ if batch[0].get("type_ids", None) is not None:
+ # if use_packed_dataset is False, we need to unpack type_ids
+ if not gpc.config.data.use_packed_dataset:
+ batch[0]["type_ids"] = unpack_type_ids(batch[0]["type_ids"], batch[0]["cu_seqlens"])
+
+ return batch, train_iter
+
+
+@llm_timeout(func_name="record_current_batch_training_metrics")
+def record_current_batch_training_metrics(
+ get_tflops_func,
+ logger,
+ writer,
+ success_update,
+ batch_count,
+ batch,
+ train_state,
+ optimizer,
+ beta2_scheduler,
+ engine,
+ start_time,
+ very_begining_time,
+ loss,
+ moe_loss,
+ grad_norm,
+ metric,
+):
+ """
+ Print some training metrics of current batch.
+ """
+
+ set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time()))
+
+ timer.store_last_timers()
+ if success_update in (0, True):
+ train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
+ if gpc.is_no_pp_or_last_stage():
+ acc_perplex = metric.get_metric()
+
+ if success_update and gpc.is_rank_for_log():
+ lr = optimizer.param_groups[0]["lr"]
+ if hasattr(engine.optimizer, "grad_scaler"):
+ scaler = engine.optimizer.grad_scaler._scale.item()
+ elif hasattr(engine.optimizer.optim, "grad_scaler"):
+ scaler = engine.optimizer.optim.grad_scaler._scale.item()
+
+ num_tokens_in_batch = batch[1].nelement()
+ real_num_tokens = math.ceil(acc_perplex.pop("real_token_num") / gpc.get_world_size(ParallelMode.GLOBAL))
+ num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]])
+ max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]])
+ max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]])
+ min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]])
+ time_cost = time.time() - start_time
+ tk_per_gpu = round(
+ num_tokens_in_batch * gpc.get_world_size(ParallelMode.DATA) / gpc.get_world_size(ParallelMode.GLOBAL),
+ 4,
+ )
+ tgs_statistic = train_state.tgs_statistic
+ tgs_statistic["sum_step"] += 1
+ tgs_statistic["sum_tg"] += tk_per_gpu
+ tgs_statistic["total_time"] = time.time() - very_begining_time
+ tgs_statistic["sum_last_tg_10"] += tk_per_gpu
+ tgs_statistic["sum_last_time_10"] += time_cost
+ tgs_statistic["sum_last_tg_50"] += tk_per_gpu
+ tgs_statistic["sum_last_time_50"] += time_cost
+ tgs_statistic["SMA_tg_50"] += tk_per_gpu
+ tgs_statistic["SMA_time_50"] += time_cost
+ tgs_statistic["SMA_tg_50_list"].append(tk_per_gpu)
+ tgs_statistic["SMA_time_50_list"].append(time_cost)
+ if tgs_statistic["sum_step"] > 50:
+ tgs_statistic["SMA_tg_50"] -= tgs_statistic["SMA_tg_50_list"][0]
+ tgs_statistic["SMA_time_50"] -= tgs_statistic["SMA_time_50_list"][0]
+ tgs_statistic["SMA_tg_50_list"].popleft()
+ tgs_statistic["SMA_time_50_list"].popleft()
+
+ last_tgs_1 = round(tk_per_gpu / time_cost, 2)
+ tgs_statistic["sum_tgs"] += last_tgs_1
+
+ if tgs_statistic["sum_step"] % 10 == 0:
+ tgs_statistic["last_tgs_10"] = round(tgs_statistic["sum_last_tg_10"] / tgs_statistic["sum_last_time_10"], 2)
+ tgs_statistic["sum_last_tg_10"] = 0
+ tgs_statistic["sum_last_time_10"] = 0
+
+ if tgs_statistic["sum_step"] % 50 == 0:
+ tgs_statistic["last_tgs_50"] = round(tgs_statistic["sum_last_tg_50"] / tgs_statistic["sum_last_time_50"], 2)
+ tgs_statistic["sum_last_tg_50"] = 0
+ tgs_statistic["sum_last_time_50"] = 0
+
+ last_tgs_10 = tgs_statistic["last_tgs_10"]
+ last_tgs_50 = tgs_statistic["last_tgs_50"]
+
+ tgs_all = round(tgs_statistic["sum_tg"] / tgs_statistic["total_time"], 2)
+ tgs_avg = round(tgs_statistic["sum_tgs"] / tgs_statistic["sum_step"], 2)
+ tgs_SMA = round(tgs_statistic["SMA_tg_50"] / tgs_statistic["SMA_time_50"], 2)
+
+ tflops = get_tflops_func(time_cost)
+
+ tgs_origin = round(
+ num_tokens_in_batch
+ * gpc.get_world_size(ParallelMode.DATA)
+ / gpc.get_world_size(ParallelMode.GLOBAL)
+ / time_cost,
+ 2,
+ )
+
+ real_tgs = round(
+ real_num_tokens / time_cost,
+ 2,
+ )
+
+ infos = {
+ "tflops": tflops,
+ "step": batch_count,
+ "loss": loss - moe_loss if moe_loss is not None else loss,
+ "real_tgs": real_tgs,
+ "tgs (tokens/gpu/second)": tgs_origin,
+ "tgs/last_tgs_1": last_tgs_1,
+ "tgs/tgs_all": tgs_all,
+ "tgs/tgs_avg": tgs_avg,
+ "tgs/tgs_SMA": tgs_SMA,
+ "tgs/last_tgs_10": last_tgs_10,
+ "tgs/last_tgs_50": last_tgs_50,
+ "lr": lr,
+ "loss_scale": scaler,
+ "grad_norm": grad_norm,
+ }
+ if moe_loss is not None:
+ infos["moe_loss"] = moe_loss
+
+ infos["micro_num"] = len(batch[1])
+ infos["num_consumed_tokens"] = train_state.num_consumed_tokens
+ infos["inf_nan_skip_batches"] = train_state.inf_nan_skip_batches
+ infos["num_samples_in_batch"] = num_samples_in_batch # the number of batches which have the most samples
+ infos["largest_length"] = max_length_in_batch # the longest input
+ infos["largest_batch"] = max_samples_in_batch # the batch with the most samples
+ infos["smallest_batch"] = min_samples_in_batch
+ infos["adam_beta2"] = beta2_scheduler.get_beta2()
+
+ fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2)
+ infos["fwd_bwd_time"] = fwd_bwd_time
+ bwd_time = round(timer("bwd").elapsed(), 2)
+ infos["bwd_time"] = bwd_time
+
+ for key, value in acc_perplex.items():
+ infos[key] = value
+
+ line = ""
+ for key, value in infos.items():
+ line += f"{key}={value} "
+ if isinstance(value, dict):
+ writer.add_scalars(key=key, value=value, step=train_state.step_count)
+ else:
+ writer.add_scalar(key=key, value=value, step=train_state.step_count)
+
+ logger.info(line)
+
+ # if loss spike occurs, send alert info to feishu
+ mm.monitor_loss_spike(
+ alert_address=gpc.config.monitor.alert.feishu_alert_address,
+ step_count=batch_count,
+ cur_step_loss=loss,
+ )
diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py
index 7df72c442..53440979b 100644
--- a/internlm/core/trainer_builder.py
+++ b/internlm/core/trainer_builder.py
@@ -9,26 +9,28 @@
from torch.utils.data import DataLoader
from internlm.checkpoint.checkpoint_manager import CheckpointManager
+from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.core.context.process_group_initializer import ParallelMode
from internlm.core.parallel.comm import initialize_offload_manager
-from internlm.core.trainer import Trainer
-from internlm.data.streaming.utils import streaming_simple_resume
-from internlm.data.train_state import get_train_state
-from internlm.eval.evaluation import evaluate_on_val_dls
-from internlm.initialize.initialize_trainer import initialize_trainer
-from internlm.model.losses.ce_loss import InternLoss
-from internlm.model.metrics import AccPerplex
-from internlm.monitor.monitor import send_alert_message
-from internlm.train.pipeline import (
- generate_meta_data,
+from internlm.core.trainer import (
+ Trainer,
get_scheduler_hooks,
- initialize_llm_profile,
- initialize_optimizer,
- inject_model,
load_new_batch,
record_current_batch_training_metrics,
)
+from internlm.data.streaming.utils import streaming_simple_resume
+from internlm.data.train_state import get_train_state
+from internlm.eval import evaluate_on_val_dls
+from internlm.initialize import initialize_trainer
+from internlm.initialize.initialize_model import (
+ generate_meta_data,
+ initialize_model_and_parallel_communicator,
+)
+from internlm.initialize.initialize_optimizer import initialize_optimizer
+from internlm.initialize.initialize_profiler import initialize_llm_profile
+from internlm.model.model_ops.losses.ce_loss import InternLoss
+from internlm.model.model_ops.metrics import AccPerplex
+from internlm.monitor import send_alert_message
from internlm.utils.common import (
BatchSkipper,
check_cuda_env,
@@ -100,8 +102,8 @@ def __init__(
# load config_lines
config_lines = self._read_config(kwargs["config"])
- # inject model for amp, parallel setting, parameter syncing and others
- model, isp_communicator = inject_model(model)
+ # initialize model and communicators
+ model, isp_communicator = initialize_model_and_parallel_communicator(model)
# check cuda env
check_cuda_env()
@@ -112,6 +114,9 @@ def __init__(
# initialize loss function
criterion = self._initialize_criterion()
+ # initialize mtp loss function
+ mtp_criterions = self._initialize_mtp_criterion()
+
# initialize cpu offload manager for selective checkpoint
initialize_offload_manager(gpc.config.get("selective_checkpoint_offload", False))
@@ -147,6 +152,7 @@ def __init__(
model=model,
optimizer=optimizer,
criterion=criterion,
+ mtp_criterions=mtp_criterions,
lr_scheduler=lr_scheduler,
beta2_scheduler=beta2_scheduler,
scheduler_hooks=get_scheduler_hooks(self.metric, optimizer, isp_communicator),
@@ -159,6 +165,20 @@ def __init__(
super().__init__(engine, scheduler)
+ def _initialize_mtp_criterion(self) -> InternLoss:
+ if hasattr(gpc.config.model, "num_mtp_layers") and gpc.config.model.num_mtp_layers > 0:
+ mtp_criterions = []
+ for _ in range(gpc.config.model.num_mtp_layers):
+ mtp_criterion = InternLoss(
+ parallel_output=gpc.config.model.parallel_output,
+ label_smoothing=gpc.config.loss.label_smoothing,
+ op_type=gpc.config.loss.op_type,
+ )
+ mtp_criterions.append(mtp_criterion)
+ else:
+ mtp_criterions = []
+ return mtp_criterions
+
def _setup_time_and_logging(self) -> str:
current_time = launch_time()
objs = [current_time]
@@ -363,8 +383,8 @@ def _record_metrics(self, batch_count: int, batch, start_time, loss, moe_loss, s
engine=self.engine,
start_time=start_time,
very_begining_time=self.very_beginning_time,
- loss=loss,
- moe_loss=moe_loss,
+ loss=loss.item() if isinstance(loss, torch.Tensor) else loss,
+ moe_loss=moe_loss.item() if isinstance(moe_loss, torch.Tensor) else moe_loss,
grad_norm=grad_norm_groups,
metric=self.metric,
)
diff --git a/internlm/data/streaming/dataset.py b/internlm/data/streaming/dataset.py
index 8b0755edf..5d7f22445 100644
--- a/internlm/data/streaming/dataset.py
+++ b/internlm/data/streaming/dataset.py
@@ -8,10 +8,10 @@
from datasets.distributed import split_dataset_by_node
from PIL import Image
from torch.utils.data import Dataset
+from transformers import AutoTokenizer
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-from transformers import AutoTokenizer
class StreamingDataset(Dataset):
diff --git a/internlm/data/tokenized/dummy_dataset.py b/internlm/data/tokenized/dummy_dataset.py
index dcb6c027d..f057941bc 100644
--- a/internlm/data/tokenized/dummy_dataset.py
+++ b/internlm/data/tokenized/dummy_dataset.py
@@ -4,7 +4,7 @@
import numpy as np
from torch.utils.data import Dataset
-# from internlm.core.context.parallel_context import global_context as gpc
+# from internlm.core.context import global_context as gpc
class RandomDataset(Dataset):
diff --git a/internlm/data/utils.py b/internlm/data/utils.py
index 352273c79..74e860997 100644
--- a/internlm/data/utils.py
+++ b/internlm/data/utils.py
@@ -5,8 +5,8 @@
import torch
+from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.core.context.process_group_initializer import ParallelMode
from internlm.utils.parallel import is_using_hf
@@ -64,8 +64,7 @@ def unpack_data(data, label):
# per batch's index should be equal, so we select first batch
data["indexes"] = data["indexes"][0]
- # If model has inject_info and data_helper is enabled, we provide position_ids
- if ("inject_info" in gpc.config.model and gpc.config.model.inject_info.get("data_helper", False)) or is_using_hf():
+ if is_using_hf():
data.pop("max_seqlen")
data["position_ids"] = data.pop("indexes").unsqueeze(0) # [batch, seqlen]
@@ -81,8 +80,7 @@ def packed_data_normalizer(data, label):
data["cu_seqlens"] = data["cu_seqlens"][0].squeeze(0)
data["max_seqlen"] = (data["cu_seqlens"][1:] - data["cu_seqlens"][:-1]).max().item()
- # If model has inject_info and data_helper is enabled, we provide position_ids, cu_seqlens, max_seqlen
- if ("inject_info" in gpc.config.model and gpc.config.model.inject_info.get("data_helper", False)) or is_using_hf():
+ if is_using_hf():
gpc.config.data[f"cu_seqlens_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"] = data.pop("cu_seqlens")
gpc.config.data[f"max_seqlen_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"] = data.pop("max_seqlen")
data["position_ids"] = data.pop("indexes").unsqueeze(0) # [batch, seqlen]
diff --git a/internlm/eval/__init__.py b/internlm/eval/__init__.py
index dc70e4d45..208779157 100644
--- a/internlm/eval/__init__.py
+++ b/internlm/eval/__init__.py
@@ -1,5 +1,11 @@
-from .evaluation import evaluate_on_val_dls
+from .evaluation import (
+ evaluate_on_val_dls,
+ switch_evaluation_mode,
+ switch_evaluation_pipeline_scheduler,
+)
__all__ = [
"evaluate_on_val_dls",
+ "switch_evaluation_mode",
+ "switch_evaluation_pipeline_scheduler",
]
diff --git a/internlm/eval/evaluation.py b/internlm/eval/evaluation.py
index 862057a3d..2b8dd08a3 100644
--- a/internlm/eval/evaluation.py
+++ b/internlm/eval/evaluation.py
@@ -9,7 +9,7 @@
from internlm.core.context import global_context as gpc
from internlm.core.parallel.shard import split_data_for_sequence_parallel
from internlm.core.scheduler.pipeline_scheduler_1f1b import get_tensor_shape
-from internlm.model.metrics import AccPerplex, SchedulerMetricHook
+from internlm.model.model_ops.metrics import AccPerplex, SchedulerMetricHook
from internlm.utils.common import get_current_device
from internlm.utils.parallel import is_using_isp
diff --git a/internlm/initialize/__init__.py b/internlm/initialize/__init__.py
index 14fe06bbb..c7d474957 100644
--- a/internlm/initialize/__init__.py
+++ b/internlm/initialize/__init__.py
@@ -1,17 +1,7 @@
+from .initialize_launcher import initialize_launcher
from .initialize_trainer import initialize_trainer
-from .launch import (
- get_default_parser,
- initialize_distributed_env,
- launch_from_slurm,
- launch_from_torch,
- try_bind_numa,
-)
__all__ = [
- "get_default_parser",
+ "initialize_launcher",
"initialize_trainer",
- "launch_from_slurm",
- "launch_from_torch",
- "initialize_distributed_env",
- "try_bind_numa",
]
diff --git a/internlm/initialize/constants.py b/internlm/initialize/constants.py
new file mode 100644
index 000000000..28474d075
--- /dev/null
+++ b/internlm/initialize/constants.py
@@ -0,0 +1,9 @@
+#############################################
+# Default Distributed Master Port #
+#############################################
+DEFAULT_DISTRIBUTED_PORT = 8888
+
+#############################################
+# Default Universal Random Seed #
+#############################################
+DEFAULT_RANDOM_SEED = 1024
diff --git a/internlm/initialize/initialize_communicator.py b/internlm/initialize/initialize_communicator.py
new file mode 100644
index 000000000..c28fdfb86
--- /dev/null
+++ b/internlm/initialize/initialize_communicator.py
@@ -0,0 +1,216 @@
+from typing import Iterable, Tuple, TypeVar, Union
+
+from torch import nn
+
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+from internlm.core.naive_amp import unwrap_naive_amp
+from internlm.core.parallel.comm import (
+ EmbeddingSequenceParallelCommunicator,
+ EmbeddingTensorParallelCommunicator,
+ EmbeddingWeightParallelCommunicator,
+ HeadSequenceParallelCommunicator,
+ HeadTensorParallelCommunicator,
+ HeadWeightParallelCommunicator,
+ ISPCommModelConfig,
+ ISPCommunicator,
+ ISPCommunicatorWrapper,
+ LinearRole,
+ MoESequenceParallelCommunicator,
+ SequenceParallelCommunicator,
+ TensorParallelCommunicator,
+)
+from internlm.model.model_ops.modules.embedding import Embedding1D
+from internlm.model.model_ops.modules.linear import (
+ ColumnParallelLinear,
+ GroupedColumnLinear,
+ GroupedRowLinear,
+ GroupedWPLinear,
+ RewardModelLinear,
+ RowParallelLinear,
+ ScaleColumnParallelLinear,
+)
+from internlm.model.model_ops.moe import Experts, MoE
+from internlm.utils.common import get_current_device
+from internlm.utils.parallel import is_using_fsdp, is_using_isp
+from internlm.utils.utils import TensorParallelMode
+
+_T = TypeVar("_T")
+
+
+def submodule_filter(model: Union[nn.Module, nn.ModuleList], target_cls: Union[_T, Tuple[_T]]) -> Iterable[_T]:
+ for _chunk in unwrap_naive_amp(model):
+ for _module in _chunk.modules():
+ if not isinstance(_module, target_cls):
+ continue
+
+ yield _module
+
+
+def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]):
+ """
+ Initialize communicator for isp tensor parallel mode.
+
+ Args:
+ model (:class:`torch.nn.Module`): Your model instance to be trained or evaluated.
+
+ Returns:
+ An isp communicator for managing comp/comm overlap.
+ """
+ isp_communicator_wrapper = None
+ _retain_out_sharded = gpc.config.model.get("parallel_output", True)
+
+ if is_using_isp():
+ isp_communicator = ISPCommunicator(
+ model,
+ ISPCommModelConfig(
+ gpc.config.model.dtype,
+ get_current_device(),
+ gpc.config.model.checkpoint,
+ ),
+ gpc.config.parallel.weight.overlap and not is_using_fsdp(),
+ gpc.get_group(ParallelMode.WEIGHT),
+ is_moe=False,
+ selective_ckpt_offload=gpc.config.get("selective_checkpoint_offload", False),
+ early_reduce_scatter_release=gpc.config.parallel.weight.early_reduce_scatter_release,
+ enable_layer_fuse_isp_comm=gpc.config.parallel.weight.get("layer_fuse_isp_comm", False),
+ )
+ # register communicator for isp column parallel linear.
+ ColumnParallelLinear.register_cls_communicator(isp_communicator)
+ # row parallel linear will not be used.
+ RowParallelLinear.register_cls_communicator(None)
+ _head_communicator = HeadWeightParallelCommunicator(
+ weight_process_group=gpc.get_group(ParallelMode.WEIGHT),
+ seq_process_group=gpc.get_group(ParallelMode.TENSOR),
+ retain_out_sharded=_retain_out_sharded,
+ )
+ _embedding_communicator = EmbeddingWeightParallelCommunicator(ParallelMode.WEIGHT)
+
+ if gpc.config.model.get("num_experts", 1) > 1:
+ # register communicator for moe isp column parallel linear.
+ # NOTE: this will overwrite the previously registered communicator
+ moe_isp_communicator = ISPCommunicator(
+ model,
+ ISPCommModelConfig(
+ gpc.config.model.dtype,
+ get_current_device(),
+ gpc.config.model.checkpoint,
+ ),
+ gpc.config.parallel.expert_weight.overlap,
+ gpc.get_group(ParallelMode.EXPERT_WEIGHT),
+ is_moe=True,
+ early_reduce_scatter_release=gpc.config.parallel.expert_weight.early_reduce_scatter_release,
+ enable_layer_fuse_isp_comm=gpc.config.parallel.expert_weight.get("layer_fuse_isp_comm", False),
+ )
+ for moe in submodule_filter(model, Experts):
+ for column_linear in submodule_filter(moe, (ColumnParallelLinear, GroupedWPLinear)):
+ column_linear.register_communicator(moe_isp_communicator)
+ for row_linear in submodule_filter(moe, RowParallelLinear):
+ row_linear.register_communicator(None)
+
+ isp_communicator_wrapper = ISPCommunicatorWrapper([isp_communicator, moe_isp_communicator])
+ else:
+ isp_communicator_wrapper = ISPCommunicatorWrapper([isp_communicator])
+
+ # register communicators for mtp/msp/fsp linears.
+
+ # tensor parallel
+ if gpc.config.parallel.tensor.mode == TensorParallelMode.mtp.name:
+ ColumnParallelLinear.register_cls_communicator(
+ TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.COLUMN)
+ )
+ RowParallelLinear.register_cls_communicator(
+ TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.ROW)
+ )
+
+ if gpc.config.model.get("num_experts", 1) > 1:
+ GroupedColumnLinear.register_cls_communicator(
+ TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.COLUMN)
+ )
+ GroupedRowLinear.register_cls_communicator(
+ TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.ROW)
+ )
+ GroupedWPLinear.register_cls_communicator(None)
+ # treat as sequence parallel if no_tp
+ if gpc.config.parallel.expert.no_tp:
+ _column_communicator = TensorParallelCommunicator(
+ process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.COLUMN
+ )
+ _row_communicator = TensorParallelCommunicator(
+ process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.ROW
+ )
+ for moe in submodule_filter(model, MoE):
+ # 1. the linear layers inside MoE fall back to the no-tp communication pattern
+ for column_linear in submodule_filter(moe, ColumnParallelLinear):
+ column_linear.register_communicator(_column_communicator)
+ for row_linear in submodule_filter(moe, RowParallelLinear):
+ row_linear.register_communicator(_row_communicator)
+ # 2. register MoESequenceParallelCommunicator for MoE layer
+ MoESequenceParallelCommunicator(ParallelMode.TENSOR, reverse=True).register_module_hook(moe)
+
+ _head_communicator = HeadTensorParallelCommunicator(ParallelMode.TENSOR, _retain_out_sharded)
+ _embedding_communicator = EmbeddingTensorParallelCommunicator(ParallelMode.TENSOR)
+ # sequence parallel
+ if gpc.config.parallel.tensor.mode in (TensorParallelMode.msp.name, TensorParallelMode.fsp.name):
+ save_total_input_as_activation = gpc.config.parallel.tensor.mode == TensorParallelMode.msp.name
+
+ ColumnParallelLinear.register_cls_communicator(
+ SequenceParallelCommunicator(
+ process_group=gpc.get_group(ParallelMode.TENSOR),
+ role=LinearRole.COLUMN,
+ save_total_input_as_activation=save_total_input_as_activation,
+ )
+ )
+ RowParallelLinear.register_cls_communicator(
+ SequenceParallelCommunicator(
+ gpc.get_group(ParallelMode.TENSOR),
+ role=LinearRole.ROW,
+ save_total_input_as_activation=save_total_input_as_activation,
+ )
+ )
+ if gpc.config.model.get("num_experts", 1) > 1:
+ GroupedColumnLinear.register_cls_communicator(
+ SequenceParallelCommunicator(
+ process_group=gpc.get_group(ParallelMode.TENSOR),
+ role=LinearRole.COLUMN,
+ save_total_input_as_activation=save_total_input_as_activation,
+ )
+ )
+ GroupedRowLinear.register_cls_communicator(
+ SequenceParallelCommunicator(
+ gpc.get_group(ParallelMode.TENSOR),
+ role=LinearRole.ROW,
+ save_total_input_as_activation=save_total_input_as_activation,
+ )
+ )
+ GroupedWPLinear.register_cls_communicator(None)
+ if gpc.config.parallel.expert.no_tp:
+ _column_communicator = TensorParallelCommunicator(
+ process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.COLUMN
+ )
+ _row_communicator = TensorParallelCommunicator(
+ process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.ROW
+ )
+ for moe in submodule_filter(model, MoE):
+ # 1. the linear layers inside MoE fall back to the no-tp communication pattern
+ for column_linear in submodule_filter(moe, ColumnParallelLinear):
+ column_linear.register_communicator(_column_communicator)
+ for row_linear in submodule_filter(moe, RowParallelLinear):
+ row_linear.register_communicator(_row_communicator)
+
+ _head_communicator = HeadSequenceParallelCommunicator(
+ ParallelMode.TENSOR, _retain_out_sharded, save_total_input_as_activation
+ )
+
+ _embedding_communicator = EmbeddingSequenceParallelCommunicator(ParallelMode.TENSOR)
+
+ # register communicator for the embedding layer.
+ if not is_using_fsdp():
+ for embedding in submodule_filter(model, Embedding1D):
+ _embedding_communicator.register_module_hook(embedding)
+
+ # register communicator for the head layer.
+ ScaleColumnParallelLinear.register_cls_communicator(_head_communicator)
+ RewardModelLinear.register_cls_communicator(_head_communicator)
+
+ return isp_communicator_wrapper
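The branch taken by the registration logic above is selected entirely by the parallel config; a minimal sketch of the two relevant fragments follows (field names come from the code above, the size values are placeholders, not recommendations).

```python
# Illustrative only; field names follow the code above, values are placeholders.

# msp / fsp tensor-parallel modes take the SequenceParallelCommunicator branch:
parallel = dict(tensor=dict(size=4, mode="msp"), expert=dict(no_tp=False))

# expert.no_tp=True makes linears inside MoE blocks fall back to plain
# TensorParallelCommunicator instances over the EXPERT_TENSOR group:
parallel = dict(tensor=dict(size=4, mode="fsp"), expert=dict(no_tp=True))
```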
diff --git a/internlm/initialize/launch.py b/internlm/initialize/initialize_launcher.py
similarity index 88%
rename from internlm/initialize/launch.py
rename to internlm/initialize/initialize_launcher.py
index 7e16ae1c3..b80b261c1 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/initialize_launcher.py
@@ -2,7 +2,6 @@
# -*- encoding: utf-8 -*-
# Copyright (c) InternLM. All rights reserved.
-import argparse
import os
from pathlib import Path
from typing import Dict, Union
@@ -10,14 +9,15 @@
import torch
from internlm.accelerator import AcceleratorType, get_accelerator
-from internlm.core.context import Config
+from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.core.context.process_group_initializer import ParallelMode
+from internlm.initialize.constants import DEFAULT_DISTRIBUTED_PORT, DEFAULT_RANDOM_SEED
from internlm.utils.common import get_master_node
+from internlm.utils.config import Config
from internlm.utils.gputest import warmup_process_group
from internlm.utils.lazy import LazyObject
from internlm.utils.logger import get_logger
-from internlm.utils.parallel import is_using_hf
+from internlm.utils.parallel import is_using_fsdp, is_using_hf
from internlm.utils.timeout import llm_timeout
from internlm.utils.utils import DataType, ModelType, TensorParallelMode
@@ -35,43 +35,12 @@
internlm_accelerator = get_accelerator()
-def get_default_parser():
- """Reads user command line and uses an argument parser to parse the input arguments.
- Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.
-
- Returns:
- Parser: Returns the parser with the default arguments, the user may add customized arguments into this parser.
- """
- parser = argparse.ArgumentParser()
- parser.add_argument("--config", type=str, help="path to the config file")
- parser.add_argument(
- "--launcher",
- type=str,
- default="slurm",
- choices=["slurm", "torch"],
- help="launcher for launching distributed environment",
- )
- parser.add_argument("--host", type=str, help="the master address for distributed training")
- parser.add_argument("--port", type=int, default=8888, help="the master port for distributed training")
- parser.add_argument("--world_size", type=int, help="world size for distributed training")
- parser.add_argument("--rank", type=int, help="rank for the default process group")
- parser.add_argument("--local_rank", type=int, help="local rank on the node")
- parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
- parser.add_argument("--seed", type=int, default=1024)
- parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.")
- parser.add_argument("--enable_ali_topology", default=False, action="store_true", help="enable ali switch topology.")
- parser.add_argument(
- "--disable_volc_topology", default=False, action="store_true", help="disable volc switch topology."
- )
- return parser
-
-
-def inject_hf_config_before_launch(hf: dict):
+def dispatch_hf_config_before_launch(hf: dict) -> None:
# get HuggingFace model config
cfg = LazyObject(hf.cfg, hf.cfg_cls)
cfg = cfg.build()
model_config = cfg(**hf.cfg_extra_kwargs)
- # inject HuggingFace model config into InternTrain as much as we know
+ # dispatch HuggingFace model config into InternEvo model config as much as we know
if hasattr(model_config, "vocab_size"):
gpc.config.model.vocab_size = gpc.config.VOCAB_SIZE = model_config.vocab_size
if hasattr(model_config, "num_hidden_layers"):
@@ -86,6 +55,14 @@ def inject_hf_config_before_launch(hf: dict):
gpc.config.model.mlp_ratio = gpc.config.MLP_RATIO = model_config.intermediate_size / model_config.hidden_size
if hasattr(model_config, "num_experts"):
gpc.config.model.num_experts = model_config.num_experts
+ elif hasattr(model_config, "n_routed_experts"):
+ gpc.config.model.num_experts = model_config.n_routed_experts
+ if hasattr(model_config, "first_k_dense_replace"):
+ gpc.config.model.first_k_dense_replace = model_config.first_k_dense_replace
+ if hasattr(model_config, "num_mtp_layers"):
+ gpc.config.model.num_mtp_layers = model_config.num_mtp_layers
+ elif hasattr(model_config, "num_nextn_predict_layers"):
+ gpc.config.model.num_mtp_layers = model_config.num_nextn_predict_layers
def args_sanity_check():
@@ -100,11 +77,6 @@ def args_sanity_check():
if "model_type" not in gpc.config:
gpc.config._add_item("model_type", ModelType.INTERNLM.name)
- # inject HuggingFace model config into IntrainTrain
- if is_using_hf():
- inject_hf_config_before_launch(gpc.config.hf)
- gpc.config.model_type = "hf"
-
if gpc.config.model_type == "InternLM3_M":
# TODO: need check for isp overlap
num_layers = gpc.config.model.num_self_decoder_layers + gpc.config.model.num_cross_decoder_layers
@@ -165,6 +137,9 @@ def args_sanity_check():
if gpc.config.parallel.pipeline["mode"] == "ZBV":
gpc.v_shape = True
+ if "fsdp" not in gpc.config.parallel:
+ gpc.config.parallel._add_item("fsdp", dict(enable=False))
+
# processing the data config in gpc
data = gpc.config.data
@@ -337,8 +312,9 @@ def args_sanity_check():
logger.info(f"clip_grad_norm: {clip_grad_norm}")
model = gpc.config.model
- if "enable_qkv_fusion" not in model:
- model._add_item("enable_qkv_fusion", True)
+ # TODO: should we set default value for enable_qkv_fusion?
+ # if "enable_qkv_fusion" not in model:
+ # model._add_item("enable_qkv_fusion", True)
if "dtype" not in model:
logger.warning("dtype is not set, use torch.float16 by defalut!")
@@ -610,9 +586,13 @@ def args_sanity_check():
assert (
not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param
), "not support overlap and moe at the same time"
- assert gpc.config.parallel.zero1.size in (
- -1,
- gpc.get_world_size(ParallelMode.DATA),
+ assert (
+ gpc.config.parallel.zero1.size
+ in (
+ -1,
+ gpc.get_world_size(ParallelMode.DATA),
+ )
+ or is_using_fsdp()
), "moe only support zero1, set zero1=dict(size=-1,...) can fix this"
if gpc.config.parallel.tensor.mode != "isp":
@@ -653,6 +633,49 @@ def args_sanity_check():
gpc.config.data.use_packed_dataset is False
), "only unpacked data is supported when using 2D sequence parallel."
+ # fsdp checks
+ if is_using_fsdp():
+ assert (
+ gpc.config.parallel.pipeline.size == 1
+ ), f"fsdp only compatible with pp size = 1, but get pipeline size = {gpc.config.parallel.pipeline.size}"
+ assert gpc.config.parallel.tensor.size == 1 or gpc.config.parallel.tensor.get("mode", "mtp") == "isp", (
+ f"fsdp only compatible with tp size > 1 in isp mode, but get tp size = "
+ f"{gpc.config.parallel.tensor.size} and tp mode = {gpc.config.parallel.tensor.mode}"
+ )
+ assert (
+ gpc.config.parallel.zero1.size == 1
+ ), f"fsdp only compatible with zero1 size = 1, but get zero1 size = {gpc.config.parallel.zero1.size}"
+ assert (
+ gpc.config.parallel.weight.size == 1
+ ), f"fsdp only compatible with weight size = 1, but get weight size = {gpc.config.parallel.weight.size}"
+ if "expert_zero1" in gpc.config.parallel:
+ assert gpc.config.parallel.expert_zero1.size == 1, (
+ f"fsdp only compatible with expert_zero1 size = 1, "
+ f"but get expert_zero1 size = {gpc.config.parallel.expert_zero1.size}"
+ )
+ if "expert_weight" in gpc.config.parallel:
+ assert gpc.config.parallel.expert_weight.size == 1, (
+ f"fsdp only compatible with expert_weight size = 1, "
+ f"but get expert_weight size = {gpc.config.parallel.expert_weight.size}"
+ )
+ assert "mode" in gpc.config.parallel.fsdp, "mode must be specified in fsdp when enabled"
+ fsdp_mode = gpc.config.parallel.fsdp.mode
+ assert "init_method" in gpc.config.parallel.fsdp, "init_method must be specified in fsdp when enabled"
+ fsdp_init_method = gpc.config.parallel.fsdp.init_method
+ if fsdp_mode == "v1":
+ fsdp_v1_min_version = "1.13.0"
+ assert (
+ torch.__version__ >= fsdp_v1_min_version
+ ), f"requires torch>={fsdp_v1_min_version} when using fsdp v1 but current version is {torch.__version__}"
+ elif fsdp_mode == "v2":
+ fsdp_v2_min_version = "2.6.0"
+ assert (
+ torch.__version__ >= fsdp_v2_min_version
+ ), f"requires torch>={fsdp_v2_min_version} when using fsdp v2 but current version is {torch.__version__}"
+ else:
+ raise ValueError(f"fsdp mode {fsdp_mode} not supported")
+ assert fsdp_init_method in ["cuda", "cpu", "meta"], f"fsdp init_method {fsdp_init_method} not supported"
+
# loss operator type
loss_cfg = gpc.config.loss
if loss_cfg.get("op_type", None) is None:
@@ -701,6 +724,10 @@ def launch(
# init default process group
gpc.init_global_dist(rank, world_size, backend, host, port)
+ # dispatch HuggingFace model config into InternEvo
+ if is_using_hf():
+ dispatch_hf_config_before_launch(gpc.config.hf)
+
# init process groups for different parallel modes from config
gpc.init_parallel_groups()
@@ -799,14 +826,14 @@ def launch_from_torch(
)
-@llm_timeout(func_name="initialize_distributed_env")
-def initialize_distributed_env(
+@llm_timeout(func_name="init_distributed")
+def initialize_launcher(
config: str,
launcher: str = "slurm",
- master_port: int = 8888,
- seed: int = 1024,
- args_check=True,
- backend: str = "nccl",
+ distributed_port: int = DEFAULT_DISTRIBUTED_PORT,
+ seed: int = DEFAULT_RANDOM_SEED,
+ args_check: bool = True,
+ dist_backend: str = "nccl",
):
"""
Initialize distributed environment for distributed training.
@@ -814,18 +841,18 @@ def initialize_distributed_env(
Args:
config (str): Config file path.
launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default.
- master_port (str): The master port for distributed training. 8888 by default.
+ distributed_port (int): Distributed backend port. 8888 by default.
seed (int, optional): Specified random seed for every process. 1024 by default.
"""
- backend = internlm_accelerator._communication_backend_name
+ dist_backend = internlm_accelerator.communication_backend_name()
if launcher == "torch":
- launch_from_torch(config=config, seed=seed, backend=backend)
+ launch_from_torch(config=config, seed=seed, backend=dist_backend)
elif launcher == "slurm":
launch_from_slurm(
config=config,
host=get_master_node(),
- port=master_port,
+ port=distributed_port,
seed=seed,
)
else:
@@ -835,14 +862,6 @@ def initialize_distributed_env(
args_sanity_check()
-def get_config_value(config, key, defalut):
- try:
- value = config[key]
- except KeyError:
- value = defalut
- return value
-
-
def try_bind_numa(global_rank, world_size, local_rank=None):
# Early return if numa module not available
if not get_numa:
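For reference, a sketch of a config fragment and launch call that would satisfy the new fsdp sanity checks and the renamed entry point; the size values and the config path are placeholders, and an actual Slurm or torchrun environment is assumed.

```python
# In the training config file: illustrative fragment (placeholder sizes)
# matching the fsdp checks above. pp / zero1 / weight sizes must be 1,
# tp must be 1 unless its mode is "isp", fsdp.mode must be "v1" (torch>=1.13)
# or "v2" (torch>=2.6), and fsdp.init_method must be "cuda", "cpu", or "meta".
parallel = dict(
    zero1=dict(size=1),
    tensor=dict(size=1, mode="mtp"),
    pipeline=dict(size=1),
    weight=dict(size=1),
    fsdp=dict(enable=True, mode="v2", init_method="meta"),
)

# In the launching script, the former initialize_distributed_env(...) call becomes:
from internlm.initialize import initialize_launcher

initialize_launcher(
    config="./configs/7B_sft.py",
    launcher="slurm",
    distributed_port=8888,
    seed=1024,
)
```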
diff --git a/internlm/initialize/initialize_model.py b/internlm/initialize/initialize_model.py
new file mode 100644
index 000000000..94541c529
--- /dev/null
+++ b/internlm/initialize/initialize_model.py
@@ -0,0 +1,360 @@
+import os
+from typing import Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch import nn
+
+from internlm.core.context import (
+ IS_REPLICA_EXPERT_DATA_PARALLEL,
+ IS_REPLICA_ZERO_PARALLEL,
+ IS_TENSOR_EXPERT_DATA_PARALLEL,
+ IS_TENSOR_ZERO_PARALLEL,
+ IS_WEIGHT_EXPERT_DATA_PARALLEL,
+ IS_WEIGHT_ZERO_PARALLEL,
+ ParallelMode,
+)
+from internlm.core.context import global_context as gpc
+from internlm.core.context import set_mode
+from internlm.core.fsdp import wrap_FSDP_model
+from internlm.core.naive_amp import (
+ NaiveAMPModel,
+ set_fp32_attr_to_module,
+ unwrap_naive_amp,
+)
+from internlm.initialize.initialize_communicator import initialize_parallel_communicator
+from internlm.model.model_implementations.builder import create_model
+from internlm.model.model_implementations.registry import register_model_initializer
+from internlm.model.model_ops.modules.embedding import Embedding1D
+from internlm.model.model_ops.modules.linear import (
+ ParallelLinearWithCommExt,
+ ScaleColumnParallelLinear,
+)
+from internlm.model.model_ops.moe import Experts, MoE
+from internlm.model.model_ops.ops.norm import RMSNorm
+from internlm.utils.logger import get_logger
+from internlm.utils.parallel import (
+ is_replica_expert_data_parallel_parameter,
+ is_replica_zero_parallel_parameter,
+ is_tensor_expert_data_parallel_parameter,
+ is_tensor_zero_parallel_parameter,
+ is_using_fsdp,
+ is_using_hf,
+ is_using_isp,
+ is_weight_expert_data_parallel_parameter,
+ is_weight_zero_parallel_parameter,
+ sync_model_param,
+ sync_model_replica_param_group,
+)
+from internlm.utils.timeout import llm_timeout
+
+logger = get_logger(__file__)
+
+
+# For universal checkpoint
+# record offset and complete_size of param in each layer
+map_layer_attr = {}
+map_fqn_local_to_global = {}
+map_fqn_global_to_local = {}
+
+
+def set_param_unique_tracking_name(model):
+ for chunk_id, chunk in enumerate(unwrap_naive_amp(model)):
+ # Important: only works for llama-class models
+ childrens = chunk.named_children()
+ for children_name, children in childrens:
+ if isinstance(children, nn.ModuleList):
+ for idx, block in enumerate(children):
+ for name, child in block.named_modules():
+ if name == "":
+ continue
+
+ full_name = f"{chunk_id}.{idx}.{name}"
+ name_parts = f"{full_name}.weight".split(".", 2)
+ # global_id for pipeline parallel case
+ global_id = model.first_layer + idx
+ local_fqn = f"{children_name}." + ".".join(name_parts[1:])
+ global_fqn = f"{children_name}.{global_id}." + ".".join(name_parts[2:])
+
+ if isinstance(child, (ParallelLinearWithCommExt)):
+ setattr(
+ child.weight,
+ "tracking_name",
+ f"{full_name}.weight",
+ )
+ if child.bias is not None:
+ setattr(
+ child.bias,
+ "tracking_name",
+ f"{full_name}.bias",
+ )
+
+ setattr(
+ child.weight,
+ "fqn",
+ f"{local_fqn}",
+ )
+ if child.bias is not None:
+ setattr(
+ child.bias,
+ "fqn",
+ f"{local_fqn}",
+ )
+
+ assert hasattr(child, "offset"), f"{child}"
+ map_fqn_local_to_global[local_fqn] = global_fqn
+ map_fqn_global_to_local[global_fqn] = local_fqn
+
+ assert global_fqn not in map_layer_attr, f"{map_layer_attr} exists"
+ map_layer_attr[global_fqn] = {
+ "offset": getattr(child, "offset", [0] * len(child.weight.size())),
+ "complete_size": getattr(child, "complete_size", list(child.weight.size())),
+ }
+
+ elif isinstance(child, (RMSNorm)):
+ map_fqn_local_to_global[local_fqn] = global_fqn
+ map_fqn_global_to_local[global_fqn] = local_fqn
+ setattr(
+ child.weight,
+ "fqn",
+ f"{local_fqn}",
+ )
+ map_layer_attr[global_fqn] = {
+ "offset": getattr(child, "offset", [0] * len(child.weight.size())),
+ "complete_size": getattr(child, "complete_size", list(child.weight.size())),
+ }
+
+ else:
+ full_name = f"{chunk_id}.{children_name}"
+ local_fqn = f"{children_name}.weight"
+ assert getattr(children, "bias", None) is None
+ if isinstance(children, Embedding1D):
+ setattr(
+ children.weight,
+ "tracking_name",
+ f"{chunk_id}_embeddings.weight",
+ )
+ assert local_fqn not in map_layer_attr, f"{map_layer_attr} exists"
+ else:
+ setattr(
+ children.weight,
+ "tracking_name",
+ f"{full_name}.weight",
+ )
+ assert local_fqn not in map_layer_attr, f"{map_layer_attr} exists"
+
+ setattr(
+ children.weight,
+ "fqn",
+ f"{local_fqn}",
+ )
+ if getattr(children, "bias", None) is not None:
+ if children.bias is not None:
+ setattr(
+ children.bias,
+ "fqn",
+ f"{local_fqn}",
+ )
+
+ map_layer_attr[local_fqn] = {
+ "offset": getattr(children, "offset", [0] * len(children.weight.size())),
+ "complete_size": getattr(children, "complete_size", list(children.weight.size())),
+ }
+
+
+def generate_meta_data(optimizer):
+ if not gpc.config.ckpt.need_metadata:
+ return
+
+ if gpc.get_world_size(ParallelMode.PIPELINE) > 1:
+ assert optimizer.meta_for_zero is not None
+ dst = gpc.get_ranks_in_group(ParallelMode.PIPELINE)[0]
+ if gpc.get_global_rank() == dst:
+ output = [None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
+ else:
+ output = None
+
+ dist.gather_object(optimizer.meta_for_zero, output, dst=dst, group=gpc.get_group(ParallelMode.PIPELINE))
+ pp_gather_output = output
+
+ else:
+ pp_gather_output = [optimizer.meta_for_zero]
+
+ tp_parallel = ParallelMode.WEIGHT if is_using_isp() else ParallelMode.TENSOR
+ if gpc.get_world_size(tp_parallel) > 1:
+ dst = gpc.get_ranks_in_group(tp_parallel)[0]
+ if gpc.get_global_rank() == dst:
+ output = [None for _ in range(gpc.get_world_size(tp_parallel))]
+ else:
+ output = None
+
+ dist.gather_object(pp_gather_output, output, dst=dst, group=gpc.get_group(tp_parallel))
+ final_output = output
+ else:
+ final_output = [pp_gather_output]
+
+ if gpc.get_global_rank() == 0:
+ assert len(final_output) == gpc.get_world_size(tp_parallel)
+ assert len(final_output[0]) == gpc.get_world_size(ParallelMode.PIPELINE)
+ assert len(final_output[0][0]) == gpc.get_world_size(ParallelMode.ZERO1)
+ tp_mode = "wp_size" if is_using_isp() else "tp_size"
+ final_meta = {
+ "parallel_setting": {
+ tp_mode: gpc.get_world_size(tp_parallel),
+ "pp_size": gpc.get_world_size(ParallelMode.PIPELINE),
+ "zero1_size": gpc.get_world_size(ParallelMode.ZERO1),
+ },
+ "metaData": final_output,
+ }
+
+ if gpc.config.ckpt.generate_meta_data.enable:
+ save_path = os.path.join(gpc.config.ckpt.generate_meta_data.path, "metadata.pt")
+ torch.save(final_meta, save_path)
+ logger.info(f"Successfully generate metadata.pt in {gpc.config.ckpt.generate_meta_data.path}")
+
+ return final_meta
+ return None
+
+
+def set_fp32_attr_for_model(model: Union[nn.Module, nn.ModuleList]):
+ if not isinstance(model, nn.ModuleList):
+ model = [model]
+
+ for _chunk in model:
+ for _, module in _chunk.named_modules():
+ if isinstance(module, (RMSNorm, nn.LayerNorm)) and gpc.config.get("use_fp32_norm", False):
+ set_fp32_attr_to_module(module)
+
+
+def set_parallel_attr_for_param_groups(model: Union[nn.Module, nn.ModuleList]):
+ def _check_module(module):
+ # layer_norm
+ if isinstance(module, (RMSNorm, nn.LayerNorm)):
+ for param in module.parameters():
+ setattr(param, IS_REPLICA_ZERO_PARALLEL, True)
+
+ if isinstance(module, MoE):
+ for param in module.moe_layer.gate.parameters():
+ setattr(param, IS_REPLICA_ZERO_PARALLEL, True)
+ if hasattr(module, "coefficient"):
+ for param in module.coefficient.parameters():
+ setattr(param, IS_REPLICA_ZERO_PARALLEL, True)
+
+ # embedding and head
+ if isinstance(module, (Embedding1D, ScaleColumnParallelLinear)):
+ for param in module.parameters():
+ if gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp():
+ setattr(param, IS_WEIGHT_ZERO_PARALLEL, True)
+ elif gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp():
+ setattr(param, IS_TENSOR_ZERO_PARALLEL, True)
+
+ # for moe linear module
+ if isinstance(module, nn.Linear) and not isinstance(module, ParallelLinearWithCommExt):
+ for param in module.parameters():
+ setattr(param, IS_REPLICA_ZERO_PARALLEL, True)
+
+ if isinstance(module, Experts):
+ for param in module.parameters():
+ if (
+ gpc.is_initialized(ParallelMode.TENSOR)
+ and not is_using_isp()
+ and getattr(gpc.config.parallel.expert, "no_tp", False)
+ ):
+ setattr(param, IS_REPLICA_EXPERT_DATA_PARALLEL, True)
+ elif gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp():
+ setattr(param, IS_TENSOR_EXPERT_DATA_PARALLEL, True)
+ elif gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp():
+ setattr(param, IS_WEIGHT_EXPERT_DATA_PARALLEL, True)
+ # for non-moe linear module
+ elif isinstance(module, ParallelLinearWithCommExt):
+ for param in module.parameters():
+ if gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp():
+ setattr(param, IS_TENSOR_ZERO_PARALLEL, True)
+ elif gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp():
+ setattr(param, IS_WEIGHT_ZERO_PARALLEL, True)
+
+ for _chunk in unwrap_naive_amp(model):
+ if not is_using_fsdp():
+ # set param parallel attribute
+ for _, module in _chunk.named_modules():
+ _check_module(module)
+
+ for name, param in _chunk.named_parameters():
+ assert (
+ is_replica_zero_parallel_parameter(param)
+ or is_tensor_zero_parallel_parameter(param)
+ or is_weight_zero_parallel_parameter(param)
+ or is_tensor_expert_data_parallel_parameter(param)
+ or is_weight_expert_data_parallel_parameter(param)
+ or is_replica_expert_data_parallel_parameter(param)
+ ), f"parameter with name: {name} has no parallel attribution."
+
+
+@llm_timeout(func_name="initialize_model_and_parallel_communicator")
+def initialize_model_and_parallel_communicator(model: Optional[Union[nn.Module, nn.ModuleList]] = None):
+ """
+ Initialize model with Automatic Mixed Precision.
+
+ Returns:
+ A tuple of (model, isp_communicator):
+ the neural network model (wrapped for AMP) to be trained or evaluated, and
+ an ISP communicator wrapper for managing computation/communication overlap.
+ """
+ if model is None:
+ register_model_initializer()
+ model = create_model()
+
+ # For non-HF or non-FSDP cases, set tracking name for parameters
+ if not is_using_hf() and not is_using_fsdp():
+ set_param_unique_tracking_name(model)
+
+ # should be set before NaiveAMPModel
+ set_fp32_attr_for_model(model)
+
+ if isinstance(model, nn.ModuleList):
+ model = nn.ModuleList(
+ [
+ NaiveAMPModel(
+ model=_m,
+ output_to_fp32=False, # manually controlled by interleaved pipeline scheduler
+ dtype=gpc.config.model.get("dtype", torch.half),
+ sync_buffer=False,
+ )
+ for _m in model
+ ]
+ )
+ else:
+ model = NaiveAMPModel(
+ model=model,
+ output_to_fp32=gpc.is_no_pp_or_last_stage(),
+ dtype=gpc.config.model.get("dtype", torch.half),
+ sync_buffer=False,
+ )
+
+ set_parallel_attr_for_param_groups(model)
+
+ # This sync is very important: the model weights kept in the optimizer are copied
+ # from the original parameters in memory, so we must make sure the dp sync does not
+ # make the optimizer's copies diverge from the original parameters.
+ if not is_using_fsdp() or gpc.config.parallel.fsdp.get("init_method", "cuda") == "cuda":
+ sync_model_param(model)
+
+ # This function is needed to make sure parameters that are not split by tensor parallelism
+ # stay identical across tensor-parallel ranks.
+ sync_model_replica_param_group(model)
+
+ # Change random state mode to ParallelMode.DATA after model is built, guaranteeing the random
+ # state in the same dp group are all the same.
+ random_mode = ParallelMode.WEIGHT_DATA if is_using_isp() else ParallelMode.DATA
+ set_mode(random_mode)
+
+ # initialize isp communicator
+ isp_communicator = initialize_parallel_communicator(model)
+
+ model = wrap_FSDP_model(model)
+
+ if gpc.is_rank_for_log():
+ logger.info(f"show model: {model}")
+ logger.info(f"model params: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B")
+
+ return model, isp_communicator
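A short usage sketch of the new model initializer; it assumes the distributed environment has already been brought up with initialize_launcher and that a valid config (placeholder path below) is loaded into gpc.

```python
from internlm.initialize import initialize_launcher
from internlm.initialize.initialize_model import (
    initialize_model_and_parallel_communicator,
)

# Placeholder config path; requires a Slurm or torchrun environment.
initialize_launcher(config="./configs/7B_sft.py", launcher="slurm")

# With no model argument, the registry builds the model, it is wrapped with
# NaiveAMPModel (and FSDP when parallel.fsdp.enable is set), and the matching
# ISP communicator wrapper is returned alongside it.
model, isp_communicator = initialize_model_and_parallel_communicator()
```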
diff --git a/internlm/initialize/initialize_optimizer.py b/internlm/initialize/initialize_optimizer.py
new file mode 100644
index 000000000..4303cf5ca
--- /dev/null
+++ b/internlm/initialize/initialize_optimizer.py
@@ -0,0 +1,208 @@
+from typing import Dict, Tuple, Union
+
+import torch
+from torch import nn
+
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+from internlm.core.naive_amp import unwrap_naive_amp
+from internlm.core.parallel.comm import ISPCommunicatorWrapper, ParamAsyncBcastHandler
+from internlm.model.model_ops.modules.utils import is_moe_param
+from internlm.solver.optimizer import (
+ FSDPadaptOptimizer,
+ HybridZeroOptimizer,
+ HybridZeroOptimizer_v2,
+)
+from internlm.solver.optimizer.compatible_adamw import new_compatible_adamw
+from internlm.solver.schedulers import Beta2Scheduler, FineTuneCosineAnnealingWarmupLR
+from internlm.utils.parallel import is_using_fsdp, is_using_hf
+from internlm.utils.timeout import llm_timeout
+
+
+def split_params_into_different_groups_for_optimizer(
+ param_groups: Tuple[Dict],
+) -> Tuple[Dict]:
+ """Split parameters into different groups for optimizer
+
+ Args:
+ param_groups (Tuple[Dict]): The list of parameter groups to split
+ Input Example:
+ >>> (
+ >>> {'name': 'default', 'params': [tensor], 'weight_decay' :xxx},
+ >>> )
+
+ Returns:
+ Tuple[Dict]: list of params groups for optimizer
+ Output Example:
+ >>> (
+ >>> {'name': 'default', 'params': [tensor], 'weight_decay' :xxx},
+ >>> {'name': 'embed_head', 'params': [tensor], 'weight_decay' :xxx},
+ >>> {'name': 'fp32', 'params': [tensor], 'weight_decay' :xxx},
+ >>> )
+ """
+
+ if isinstance(param_groups, tuple):
+ param_groups = list(param_groups) # Tuple cannot be modified
+ elif isinstance(param_groups, dict):
+ param_groups = [param_groups]
+ elif not isinstance(param_groups, list):
+ raise ValueError(f"Unknown param group type of {type(param_groups)}")
+
+ if is_using_fsdp():
+ optimizer_mode = ParallelMode.GLOBAL
+ optimizer_mode_expert = ParallelMode.EXPERT_DATA
+ expert_group_name = f"moe_ep_size_{gpc.get_world_size(ParallelMode.EXPERT)}"
+ expert_parallel_group_names = [expert_group_name]
+ else:
+ optimizer_mode = ParallelMode.ZERO1
+ optimizer_mode_expert = ParallelMode.EXPERT_DATA
+ expert_parallel_group_names = gpc.expert_parallel_group_names
+
+ new_groups = {}
+ # create new groups for fp32 parameter group
+ new_groups["fp32"] = {"name": "fp32", "params": [], "optimizer_mode": optimizer_mode}
+
+ if gpc.config.model.get("num_experts", 1) > 1:
+ for key in expert_parallel_group_names:
+ new_groups[key] = {"name": key, "moe": True, "params": [], "optimizer_mode": optimizer_mode_expert}
+
+ for pgroup in param_groups:
+ # copy attributes from the original group; we assume the input param_groups only
+ # has one group, so the attributes will not be copied multiple times.
+ for ori_key in pgroup.keys():
+ if ori_key not in ("name", "params"):
+ for _, group in new_groups.items():
+ group[ori_key] = pgroup[ori_key]
+ # assign param
+ origin_params = []
+ for named_param in pgroup["params"]:
+ # moe param means MoE is enabled
+ name, param = named_param
+ # NOTICE: param attributes may get lost with PretrainedModel + FSDP
+ # Workaround: additionally identify expert params by name
+ if is_moe_param(param) or "experts" in name:
+ if is_using_fsdp():
+ if gpc.is_using_parallel_mode(ParallelMode.EXPERT) or not is_using_hf():
+ new_groups[expert_group_name]["params"].append(param)
+ else:
+ origin_params.append(param)
+ else:
+ new_groups[param.group_name]["params"].append(param)
+ elif param.dtype == torch.float32 and gpc.config.model.dtype != torch.float32:
+ new_groups["fp32"]["params"].append(param)
+ else:
+ origin_params.append(param)
+
+ # default param group, which is the first group in the param groups
+ pgroup["params"] = origin_params
+ pgroup["optimizer_mode"] = optimizer_mode
+
+ # param groups may contain empty groups, such as fp32
+ param_groups.extend(new_groups.values())
+
+ return list(param_groups)
+
+
+def create_param_groups(model, weight_decay):
+ parameters = {
+ "params": [(name, param) for name, param in model.named_parameters() if param.requires_grad],
+ "name": "default",
+ "weight_decay": weight_decay,
+ }
+ return split_params_into_different_groups_for_optimizer(parameters)
+
+
+def map_param_block(model):
+ for _chunk in unwrap_naive_amp(model):
+ for name, children in _chunk.named_children():
+ if isinstance(children, nn.ModuleList):
+ for idx, block in enumerate(children):
+ block_name = name + f"_{idx}"
+ for param in block.parameters():
+ setattr(param, "block_name", block_name)
+ else:
+ for param in children.parameters():
+ setattr(param, "block_name", name)
+
+
+@llm_timeout(func_name="initialize_optimizer")
+def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicator: ISPCommunicatorWrapper = None):
+ """
+ Initialize optimizer.
+
+ Args:
+ model (:class:`torch.nn.Module`): Your model instance to be trained or evaluated.
+
+ Returns:
+ A tuple of (optimizer, beta2_scheduler, lr_scheduler).
+ """
+
+ adam_cfg = gpc.config.adam
+ zero_cfg = gpc.config.hybrid_zero_optimizer
+ grad_scal_cfg = gpc.config.grad_scaler
+ use_apex_adam = getattr(gpc.config, "use_apex_adam", False)
+
+ if "use_split_tensor_optim" in zero_cfg and zero_cfg.use_split_tensor_optim:
+ map_param_block(model)
+
+ params = create_param_groups(model, adam_cfg.weight_decay)
+
+ naive_optimizer = new_compatible_adamw(
+ params=params,
+ lr=adam_cfg.lr,
+ betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
+ eps=adam_cfg.adam_eps,
+ use_apex_adam=use_apex_adam,
+ )
+
+ if (
+ zero_cfg.overlap_sync_grad
+ and gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
+ and gpc.is_pipeline_first_stage() is False
+ ):
+ # When pipeline parallelism is enabled, we prefer to only enable optimizer
+ # gradient communication overlap in the first stage, to avoid amplifying
+ # the communication overhead stage by stage in cases where the optimizer
+ # communication overhead is greater than the compute overhead.
+ # For pipeline stages except the first, even if overlap is not enabled,
+ # their gradient synchronization overhead can be well hidden by
+ # the inherent bubbles of pipeline parallelism.
+ zero_cfg.overlap_sync_grad = False
+
+ if zero_cfg.overlap_sync_param:
+ param_bcast_sync_handler = ParamAsyncBcastHandler(ParallelMode.ZERO1, model, isp_communicator)
+ else:
+ param_bcast_sync_handler = None
+
+ if not is_using_fsdp():
+ if (
+ "use_split_tensor_optim" not in gpc.config.hybrid_zero_optimizer
+ or not gpc.config.hybrid_zero_optimizer.use_split_tensor_optim
+ ):
+ optimizer = HybridZeroOptimizer(
+ naive_optimizer,
+ grad_scal_cfg=grad_scal_cfg,
+ zero_cfg=zero_cfg,
+ param_bcast_sync_handler=param_bcast_sync_handler,
+ isp_communicator=isp_communicator,
+ )
+ else:
+ optimizer = HybridZeroOptimizer_v2(
+ naive_optimizer,
+ grad_scal_cfg=grad_scal_cfg,
+ zero_cfg=zero_cfg,
+ param_bcast_sync_handler=param_bcast_sync_handler,
+ isp_communicator=isp_communicator,
+ )
+ else:
+ optimizer = FSDPadaptOptimizer(
+ naive_optimizer,
+ grad_scal_cfg=grad_scal_cfg,
+ zero_cfg=zero_cfg,
+ )
+
+ beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
+
+ lr_scheduler = FineTuneCosineAnnealingWarmupLR(optimizer, **gpc.config.lr_scheduler)
+
+ return optimizer, beta2_scheduler, lr_scheduler
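A minimal sketch of how this initializer slots in after the model step; the gpc.config blocks it reads (adam, hybrid_zero_optimizer, grad_scaler, beta2_scheduler, lr_scheduler) are assumed to be present in the loaded config.

```python
from internlm.initialize.initialize_model import (
    initialize_model_and_parallel_communicator,
)
from internlm.initialize.initialize_optimizer import initialize_optimizer

# Assumes initialize_launcher(...) has already set up gpc and its config.
model, isp_communicator = initialize_model_and_parallel_communicator()
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator)
```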
diff --git a/internlm/initialize/initialize_profiler.py b/internlm/initialize/initialize_profiler.py
new file mode 100644
index 000000000..eb9b41a19
--- /dev/null
+++ b/internlm/initialize/initialize_profiler.py
@@ -0,0 +1,61 @@
+import torch
+
+from internlm.accelerator import AcceleratorType
+from internlm.accelerator.abstract_accelerator import get_accelerator
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+from internlm.utils.common import DummyProfile
+from internlm.utils.logger import get_logger
+
+logger = get_logger(__file__)
+internlm_accelerator = get_accelerator()
+
+try:
+ import torch_npu
+except (ModuleNotFoundError, ImportError):
+ pass
+
+
+def initialize_llm_profile(profiling: bool = False, start_time: str = None):
+ """Initialize and return the profiler context manager instance."""
+
+ if profiling and gpc.get_local_rank(ParallelMode.DATA) == 0 and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+ schedule_config = {"wait": 1, "warmup": 1, "active": 1, "repeat": 1, "skip_first": 3}
+ trace_path = (
+ f"RUN/{gpc.config.JOB_NAME}/{start_time}/traces/rank{gpc.get_global_rank()}_"
+ f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
+ f"wp{gpc.get_local_rank(ParallelMode.WEIGHT)}_"
+ f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}"
+ )
+ if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU:
+ experimental_config = torch_npu.profiler._ExperimentalConfig(
+ aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
+ profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
+ l2_cache=False,
+ )
+ llm_profile = torch_npu.profiler.profile(
+ activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU],
+ schedule=torch_npu.profiler.schedule(**schedule_config),
+ on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(trace_path),
+ record_shapes=True,
+ profile_memory=True,
+ with_stack=False,
+ with_flops=False,
+ with_modules=False,
+ experimental_config=experimental_config,
+ )
+ logger.info(f"Do profiling for NPU on rank {gpc.get_global_rank()}!")
+ else:
+ llm_profile = torch.profiler.profile(
+ activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
+ schedule=torch.profiler.schedule(**schedule_config),
+ on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_path),
+ with_stack=True,
+ with_modules=True,
+ profile_memory=True,
+ )
+ logger.info(f"Do profiling for GPU on rank {gpc.get_global_rank()}!")
+ else:
+ llm_profile = DummyProfile()
+
+ return llm_profile
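The returned object is either a torch / torch_npu profiler or a no-op DummyProfile, so callers can drive it uniformly; a sketch of the intended training-loop usage, where the step count and the train_one_step helper are placeholders.

```python
from internlm.initialize.initialize_profiler import initialize_llm_profile

# Placeholder loop; assumes gpc is initialized and train_one_step is defined elsewhere.
with initialize_llm_profile(profiling=True, start_time="2024-01-01-00-00-00") as prof:
    for _ in range(8):
        train_one_step()  # hypothetical stand-in for the real training step
        prof.step()       # advance the wait/warmup/active schedule
```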
diff --git a/internlm/initialize/initialize_trainer.py b/internlm/initialize/initialize_trainer.py
index 48487c5fb..71e974899 100644
--- a/internlm/initialize/initialize_trainer.py
+++ b/internlm/initialize/initialize_trainer.py
@@ -26,8 +26,8 @@
from internlm.core.scheduler.pipeline_scheduler_1f1b import get_tensor_shape
from internlm.core.trainer import Trainer
from internlm.data.utils import packed_data_normalizer, unpack_data
-from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
-from internlm.solver.schedulers.beta2_scheduler import Beta2Scheduler
+from internlm.solver.optimizer import BaseOptimizer
+from internlm.solver.schedulers import Beta2Scheduler
from internlm.utils.common import SchedulerHook, get_current_device
from internlm.utils.parallel import is_using_isp
@@ -36,6 +36,7 @@ def initialize_trainer(
model: nn.Module,
optimizer: Optimizer,
criterion: Optional[_Loss] = None,
+ mtp_criterions: Optional[List[_Loss]] = None,
lr_scheduler: Optional[_LRScheduler] = None,
beta2_scheduler: Optional[Beta2Scheduler] = None,
scheduler_hooks: Optional[List[SchedulerHook]] = None,
@@ -166,6 +167,7 @@ def _data_preparation_func(_data, _label):
lr_scheduler=lr_scheduler,
beta2_scheduler=beta2_scheduler,
criterion=criterion,
+ mtp_criterions=mtp_criterions,
gradient_handlers=gradient_handlers,
clip_grad_norm=clip_grad_norm,
)
diff --git a/internlm/initialize/legacy/launch.py b/internlm/initialize/legacy/launch.py
deleted file mode 100644
index 3a8ccedee..000000000
--- a/internlm/initialize/legacy/launch.py
+++ /dev/null
@@ -1,40 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from internlm.initialize.launch import get_config_value
-from internlm.utils.logger import get_logger
-
-logger = get_logger(__file__)
-
-
-def auto_resume_sanity_check(ckpt_config):
- load_given_ckpt = get_config_value(ckpt_config, "load_given_ckpt", None)
- if load_given_ckpt is None:
- return True # default value is True
- else:
- return not load_given_ckpt
-
-
-def ckpt_info_sanity_check(ckpt_config):
- load_ckpt_folder = get_config_value(ckpt_config, "load_ckpt_folder", None)
-
- load_model_only_folder = get_config_value(ckpt_config, "load_model_only_folder", None)
-
- if load_model_only_folder is not None:
- assert (
- load_ckpt_folder is None
- ), "Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \
-# and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
- return dict(path=load_model_only_folder, content=("model",), ckpt_type="internevo")
- else:
- load_optimizer = get_config_value(ckpt_config, "load_optimizer", True)
-
- if isinstance(load_ckpt_folder, str):
- if load_optimizer:
- return dict(path=load_ckpt_folder, content=("model", "sampler", "optimizer"), ckpt_type="internevo")
- else:
- return dict(path=load_ckpt_folder, content=("model", "sampler"), ckpt_type="internevo")
- elif load_ckpt_folder is None:
- return None
- else:
- assert f"Unsupport data type:'{type(load_ckpt_folder)}' for config.ckpt arg: 'load_ckpt_folder'"
diff --git a/internlm/initialize/legacy/__init__.py b/internlm/launcher/__init__.py
similarity index 100%
rename from internlm/initialize/legacy/__init__.py
rename to internlm/launcher/__init__.py
diff --git a/internlm/launcher/launch.py b/internlm/launcher/launch.py
new file mode 100644
index 000000000..16eec6c68
--- /dev/null
+++ b/internlm/launcher/launch.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from internlm.core.context import global_context as gpc
+from internlm.core.trainer_builder import TrainerBuilder
+from internlm.data import (
+ build_train_loader_with_data_type,
+ build_valid_loader_with_data_type,
+)
+from internlm.initialize import initialize_launcher
+from internlm.model.model_implementations.builder import create_model
+from internlm.model.model_implementations.registry import register_model_initializer
+from internlm.monitor import internevo_monitor
+from internlm.utils.common import parse_args
+
+
+@internevo_monitor(feishu_alert=True, clean_run=True)
+def main(args):
+ # initialize model
+ register_model_initializer()
+ model = create_model()
+
+ # initialize train dataloader
+ train_dl, dataset_types = build_train_loader_with_data_type()
+
+ # initialize validation dataloader
+ val_dls = build_valid_loader_with_data_type()
+
+ # build trainer
+ merged_args = {**vars(args), "dataset_types": dataset_types}
+ trainer = TrainerBuilder(model, train_dl, val_dls, **merged_args)
+
+ # training
+ trainer.fit()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+
+ # Initialize distributed environment
+ initialize_launcher(config=args.config, launcher=args.launcher, distributed_port=args.port, seed=args.seed)
+ assert hasattr(gpc, "config") and gpc.config is not None
+
+ # Run the main function with parsed arguments
+ main(args)
diff --git a/internlm/model/llava/clip_builder.py b/internlm/model/llava/clip_builder.py
deleted file mode 100644
index 78cc3fa0e..000000000
--- a/internlm/model/llava/clip_builder.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import os
-
-from .clip_encoder import CLIPVisionTower
-
-
-def build_vision_tower(vision_tower_cfg, **kwargs):
- vision_tower = vision_tower_cfg.get("mm_vision_tower", None)
- is_absolute_path_exists = os.path.exists(vision_tower)
- if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
- model = CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
- return model
-
- raise ValueError(f"Unknown vision tower: {vision_tower}")
diff --git a/internlm/model/llava/clip_encoder.py b/internlm/model/llava/clip_encoder.py
deleted file mode 100644
index e1d982f72..000000000
--- a/internlm/model/llava/clip_encoder.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import torch
-from torch import nn
-
-from transformers import CLIPVisionConfig, CLIPVisionModel
-
-
-class CLIPVisionTower(nn.Module): # pylint: disable=C0115
- def __init__(self, vision_tower, args, delay_load=False):
- super().__init__()
-
- self.is_loaded = False
-
- self.vision_tower_name = vision_tower
- self.select_layer = args.get("mm_vision_select_layer", -2)
- self.select_feature = args.get("mm_vision_select_feature", "patch")
-
- if not delay_load:
- self.load_model()
- self.image_size = self.config.image_size
- else:
- self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
-
- def load_model(self):
- self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
- self.vision_tower.requires_grad_(False)
-
- self.is_loaded = True
-
- def feature_select(self, image_forward_outs):
- image_features = image_forward_outs.hidden_states[self.select_layer]
- if self.select_feature == "patch":
- image_features = image_features[:, 1:]
- elif self.select_feature == "cls_patch":
- pass
- else:
- raise ValueError(f"Unexpected select feature: {self.select_feature}")
- return image_features
-
- @torch.no_grad()
- def forward(self, images):
- if isinstance(images, list):
- image_features = []
- for image in images:
- image_forward_out = self.vision_tower(
- image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True
- )
- image_feature = self.feature_select(image_forward_out).to(image.dtype)
- image_features.append(image_feature)
- else:
- image_forward_outs = self.vision_tower(
- images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
- )
- image_features = self.feature_select(image_forward_outs).to(images.dtype)
-
- return image_features
-
- @property
- def dummy_feature(self):
- return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
-
- @property
- def dtype(self):
- return self.vision_tower.dtype
-
- @property
- def device(self):
- return self.vision_tower.device
-
- @property
- def config(self):
- if self.is_loaded:
- return self.vision_tower.config
- else:
- return self.cfg_only
-
- @property
- def hidden_size(self):
- return self.config.hidden_size
-
- @property
- def num_patches(self):
- return (self.config.image_size // self.config.patch_size) ** 2
diff --git a/internlm/model/llava/projector_builder.py b/internlm/model/llava/projector_builder.py
deleted file mode 100644
index 2b1a701e3..000000000
--- a/internlm/model/llava/projector_builder.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import re
-
-from torch import nn
-
-
-class IdentityMap(nn.Module):
- def __init__(self):
- super().__init__()
-
- def forward(self, x):
- return x
-
- @property
- def config(self):
- return {"mm_projector_type": "identity"}
-
-
-class SimpleResBlock(nn.Module):
- def __init__(self, channels):
- super().__init__()
- self.pre_norm = nn.LayerNorm(channels)
-
- self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))
-
- def forward(self, x):
- x = self.pre_norm(x)
- return x + self.proj(x)
-
-
-def build_vision_projector(config):
- projector_type = config.get("mm_projector_type", "linear")
-
- if projector_type == "linear":
- return nn.Linear(config.get("mm_hidden_size", 1024), config.get("hidden_size", 4096))
-
- mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
- if mlp_gelu_match:
- mlp_depth = int(mlp_gelu_match.group(1))
- modules = [nn.Linear(config.get("mm_hidden_size", 1024), config.get("hidden_size", 4096))]
- for _ in range(1, mlp_depth):
- modules.append(nn.GELU())
- modules.append(nn.Linear(config.get("hidden_size", 4096), config.get("hidden_size", 4096)))
- return nn.Sequential(*modules)
-
- if projector_type == "identity":
- return IdentityMap()
-
- raise ValueError(f"Unknown projector type: {projector_type}")
diff --git a/internlm/model/llava/__init__.py b/internlm/model/model_implementations/__init__.py
similarity index 100%
rename from internlm/model/llava/__init__.py
rename to internlm/model/model_implementations/__init__.py
diff --git a/internlm/model/builder.py b/internlm/model/model_implementations/builder.py
similarity index 71%
rename from internlm/model/builder.py
rename to internlm/model/model_implementations/builder.py
index e8d3f11b9..63bbe468f 100644
--- a/internlm/model/builder.py
+++ b/internlm/model/model_implementations/builder.py
@@ -6,20 +6,55 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.parallel.shard import pipeline_parallel_sharding_wrapper
-from internlm.model.base_model import BaseModel
-from internlm.model.modules.linear import (
+from internlm.model.model_implementations.registry import model_initializer
+from internlm.model.model_implementations.transformers.base_model import (
+ BaseTransformerModel,
+)
+from internlm.model.model_ops.modules.linear import (
ParallelLinearWithCommExt,
ScaleColumnParallelLinear,
)
-from internlm.model.registry import model_initializer
from internlm.utils.common import get_current_device
from internlm.utils.lazy import LazyObject
from internlm.utils.logger import get_logger
from internlm.utils.parallel import is_using_fsdp, is_using_hf, is_using_isp
+try:
+ import transformer_engine.pytorch as te
+
+ HAS_TE = True
+except (ModuleNotFoundError, ImportError):
+ HAS_TE = False
+
+
logger = get_logger(__file__)
+def simple_swap(model, device):
+ for submodule_name, submodule in model.named_modules():
+ if isinstance(submodule, torch.nn.Linear):
+ path_in_state_dict = submodule_name.split(".")
+ current_module = model
+
+ # traverse to the parent module of the leaf
+ leaf_path = path_in_state_dict[:-1]
+ leaf_name = path_in_state_dict[-1]
+ for child_name in leaf_path:
+ current_module = getattr(current_module, child_name)
+
+ # perform a swap
+ old_leaf = getattr(current_module, leaf_name)
+ new_leaf = te.Linear(old_leaf.in_features, old_leaf.out_features, old_leaf.bias is not None, device=device)
+ with torch.no_grad():
+ new_leaf.weight.copy_(old_leaf.weight)
+ assert torch.equal(new_leaf.weight, old_leaf.weight)
+ if old_leaf.bias is not None:
+ new_leaf.bias.copy_(old_leaf.bias)
+ assert torch.equal(new_leaf.bias, old_leaf.bias)
+
+ setattr(current_module, leaf_name, new_leaf)
+
+
def create_model() -> Union[nn.Module, List[nn.Module]]:
if is_using_hf():
model = create_model_hf(hf=gpc.config.hf)
@@ -37,14 +72,13 @@ def create_model_builtin(model_type) -> Union[nn.Module, List[nn.Module]]:
# TODO: fix use_flash_attn parameter config
kwargs.pop("use_flash_attn", False)
- kwargs.pop("apply_post_layer_norm")
- kwargs.pop("embed_split_hidden", True)
kwargs["checkpoint"] = float(kwargs.get("checkpoint", False))
kwargs["device"] = get_current_device()
model_buidler = model_initializer.get_module(module_name=model_type)
+
if not gpc.is_using_parallel_mode(ParallelMode.PIPELINE):
kwargs["first"] = kwargs["last"] = True
kwargs["start_layer_idx"] = 0
@@ -55,8 +89,10 @@ def create_model_builtin(model_type) -> Union[nn.Module, List[nn.Module]]:
else:
model = pipeline_parallel_sharding_wrapper(num_layers, num_chunks, model_buidler, **kwargs)
- if not isinstance(model, BaseModel) and gpc.is_rank_for_log():
- logger.warning(f"To load/save huggingface ckpt, built-in model should inherited from {BaseModel.__name__}")
+ if not isinstance(model, BaseTransformerModel) and gpc.is_rank_for_log():
+ logger.warning(
+ f"To load/save huggingface ckpt, built-in model should inherited from {BaseTransformerModel.__name__}"
+ )
return model
@@ -126,4 +162,7 @@ def traverse(module):
else:
traverse(model)
+ if HAS_TE and gpc.config.get("fp8", None) is not None:
+ simple_swap(model=model, device=fsdp_init_method)
+
return model
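simple_swap above replaces every nn.Linear with a transformer_engine te.Linear when HAS_TE is true and an fp8 section exists in the config; the parent-lookup-and-setattr pattern it uses is generic. A standalone sketch of the same traversal using only torch, with a hypothetical Wrapped class standing in for te.Linear:

```python
import torch
from torch import nn


class Wrapped(nn.Linear):
    """Hypothetical stand-in for te.Linear, used only to illustrate the swap."""


def swap_linears(model: nn.Module) -> None:
    # Mirror of the traversal in simple_swap: walk named modules, locate the
    # parent of each nn.Linear, and replace the leaf attribute in place.
    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear) and not isinstance(module, Wrapped):
            *parent_path, leaf_name = name.split(".")
            parent = model
            for child_name in parent_path:
                parent = getattr(parent, child_name)
            new_leaf = Wrapped(module.in_features, module.out_features, module.bias is not None)
            with torch.no_grad():
                new_leaf.weight.copy_(module.weight)
                if module.bias is not None:
                    new_leaf.bias.copy_(module.bias)
            setattr(parent, leaf_name, new_leaf)


toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
swap_linears(toy)
assert all(isinstance(m, Wrapped) for m in toy.modules() if isinstance(m, nn.Linear))
```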
diff --git a/internlm/model/registry.py b/internlm/model/model_implementations/registry.py
similarity index 73%
rename from internlm/model/registry.py
rename to internlm/model/model_implementations/registry.py
index 68013d268..6a21a79ff 100644
--- a/internlm/model/registry.py
+++ b/internlm/model/model_implementations/registry.py
@@ -4,16 +4,14 @@
from typing import Callable
-from internlm.model.modeling_baichuan2 import Baichuan2
-from internlm.model.modeling_gemma import Gemma
-from internlm.model.modeling_internlm import InternLM1
-from internlm.model.modeling_internlm2 import InternLM2
-from internlm.model.modeling_llama import Llama2
-from internlm.model.modeling_llava import Llava
-from internlm.model.modeling_mixtral import MixtralMoE
-from internlm.model.modeling_moe import Internlm1MoE
-from internlm.model.modeling_qwen2 import Qwen2
-from internlm.model.modeling_qwen2_moe import Qwen2Moe
+from internlm.model.model_implementations.transformers.modeling_internlm import (
+ InternLM1,
+)
+from internlm.model.model_implementations.transformers.modeling_internlm2 import (
+ InternLM2,
+)
+from internlm.model.model_implementations.transformers.modeling_llama import Llama2
+from internlm.model.model_implementations.transformers.modeling_moe import Internlm1MoE
from internlm.utils.common import SingletonMeta
from internlm.utils.utils import ModelType
@@ -89,12 +87,3 @@ def register_model_initializer() -> None:
model_initializer.register_module(ModelType.INTERNLM3.name, InternLM2)
model_initializer.register_module(ModelType.LLAMA2.name, Llama2)
model_initializer.register_module(ModelType.INTERNLM_MoE.name, Internlm1MoE)
- model_initializer.register_module(ModelType.LLAVA.name, Llava)
- model_initializer.register_module(ModelType.QWEN2.name, Qwen2)
- model_initializer.register_module(ModelType.BAICHUAN2.name, Baichuan2)
- model_initializer.register_module(ModelType.GEMMA.name, Gemma)
- model_initializer.register_module(ModelType.QWEN2MOE.name, Qwen2Moe)
- model_initializer.register_module(ModelType.MIXTRALMOE.name, MixtralMoE)
-
-
-register_model_initializer()
diff --git a/internlm/model/modules/__init__.py b/internlm/model/model_implementations/transformers/__init__.py
similarity index 100%
rename from internlm/model/modules/__init__.py
rename to internlm/model/model_implementations/transformers/__init__.py
diff --git a/internlm/model/base_model.py b/internlm/model/model_implementations/transformers/base_model.py
similarity index 74%
rename from internlm/model/base_model.py
rename to internlm/model/model_implementations/transformers/base_model.py
index cdbd04d6e..17bb0155e 100644
--- a/internlm/model/base_model.py
+++ b/internlm/model/model_implementations/transformers/base_model.py
@@ -2,12 +2,12 @@
from torch import nn
-from internlm.model.utils import load_src_states, merge_pp_src_states
+from internlm.model.model_ops.utils import load_src_states, merge_pp_src_states
-class BaseModel(nn.Module, metaclass=ABCMeta):
+class BaseTransformerModel(nn.Module, metaclass=ABCMeta):
"""
- Base class for all models.
+ Base class for InternEvo transformer models.
"""
@staticmethod
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/model_implementations/transformers/modeling_internlm.py
similarity index 81%
rename from internlm/model/modeling_internlm.py
rename to internlm/model/model_implementations/transformers/modeling_internlm.py
index 367ba524a..a201abe3b 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/model_implementations/transformers/modeling_internlm.py
@@ -8,20 +8,29 @@
import torch
from torch import nn
from tqdm import tqdm
+from transformers.modeling_utils import (
+ SAFE_WEIGHTS_INDEX_NAME,
+ SAFE_WEIGHTS_NAME,
+ shard_checkpoint,
+)
from internlm.accelerator import get_accelerator
from internlm.core.context import ParallelMode
-from internlm.core.context.parallel_context import global_context as gpc
+from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import set_output_attr_to_module
-from internlm.core.parallel.shard import partition_uniform
-from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
-from internlm.model.base_model import BaseModel
-from internlm.model.modules.embedding import Embedding1D
-from internlm.model.modules.linear import new_linear
-from internlm.model.modules.mha import MHA
-from internlm.model.modules.mlp import new_feed_forward
-from internlm.model.modules.norm import new_layer_norm
-from internlm.model.utils import (
+from internlm.model.model_implementations.transformers.base_model import (
+ BaseTransformerModel,
+)
+from internlm.model.model_implementations.transformers.utils import (
+ normal_,
+ scaled_init_method_normal,
+)
+from internlm.model.model_ops.modules.embedding import Embedding1D
+from internlm.model.model_ops.modules.linear import new_linear
+from internlm.model.model_ops.modules.mha import MHA
+from internlm.model.model_ops.modules.mlp import new_feed_forward
+from internlm.model.model_ops.modules.norm import new_layer_norm
+from internlm.model.model_ops.utils import (
convert_attn_args_to_kwargs,
convert_attn_kwargs_to_args,
internlm1_mha_pre_load_convert,
@@ -30,11 +39,6 @@
from internlm.solver.activation_checkpoint import activation_checkpoint
from internlm.utils.logger import get_logger
from internlm.utils.storage_manager import get_fns, llm_load, llm_save
-from transformers.modeling_utils import (
- SAFE_WEIGHTS_INDEX_NAME,
- SAFE_WEIGHTS_NAME,
- shard_checkpoint,
-)
internlm_accelerator = get_accelerator()
logger = get_logger(__file__)
@@ -230,7 +234,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states):
return hidden_states + residual
-class InternLM1(BaseModel):
+class InternLM1(BaseTransformerModel):
"""
1D Packed Flash InternLm.
@@ -517,125 +521,6 @@ def load_hf_weights(folder: str, model: nn.Module) -> None:
internlm_accelerator.empty_cache()
- @staticmethod
- def load_internlm_with_dynamic_parallel_size(folder: str, model: nn.Module):
-
- assert folder is not None, "Please specify the folder of the pretrained model"
- if gpc.is_rank_for_log():
- logger.info(f"Loading pretrained model from {folder}")
-
- fns = get_fns(folder)
- model_fns = []
- for fn in fns:
- # filter with `_t` is for avoiding conflict with model_config.py
- if fn.startswith("model_t") and not fn.endswith("md5"):
- model_fns.append(fn)
-
- old_tp, old_pp = -1, -1
- for fn in model_fns:
- _, tp, pp = os.path.splitext(fn)[0].split("_")
- old_tp = max(old_tp, int(tp[2:]) + 1)
- old_pp = max(old_pp, int(pp[2:]) + 1)
-
- assert old_tp > 0 and old_pp > 0, f"ckpt with tp:{old_tp} and pp:{old_pp} is illegal"
-
- tp = gpc.get_world_size(ParallelMode.TENSOR)
- tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
- assert old_tp % tp == 0 or tp % old_tp == 0, (
- f"Expected TP size in loaded checkpoint to be fit with TP size in current config, but got {old_tp} in "
- f"checkpoint and {tp} in current config"
- )
-
- correspond_tps = []
-
- if old_tp <= tp:
- correspond_tps.append(tp_rank // (tp // old_tp))
- ratio = tp // old_tp
- rank = tp_rank % ratio
- else:
- for i in range(old_tp // tp):
- correspond_tps.append(tp_rank * (old_tp // tp) + i)
- rank = 0
- ratio = 1
-
- current_states = {}
-
- pp = gpc.get_world_size(ParallelMode.PIPELINE)
-
- assert gpc.config.model.num_chunks == 1, "May cause future collisions, ignore this if necessary"
-
- old_pp_partition = partition_uniform(gpc.config.model.num_layers, old_pp, 1)
-
- for idx, parts in enumerate(old_pp_partition):
- start, end = parts[0]
- if model.last_layer <= start or model.first_layer >= end:
- continue
-
- tmp_states = {}
-
- for correspond_tp in correspond_tps:
- model_name = f"model_tp{correspond_tp}_pp{idx}.pt"
- states = llm_load(os.path.join(folder, model_name), map_location="cpu")
- for i in range(start, end):
- if i >= model.last_layer:
- break
- if i < model.first_layer:
- continue
- for name in list(states.keys()):
- if f".{i-start}." in name:
- to_name = name.replace(f".{i-start}.", f".{i-model.first_layer}.")
- if "norm" in name:
- tmp_states[to_name] = [states.pop(name)]
- elif any(x in name for x in ("out_proj", "w2")):
- if "bias" not in name:
- tmp_states[to_name] = tmp_states.get(to_name, [])
- tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=-1)[rank])
- else:
- tmp_states[to_name] = [states.pop(name)]
- elif any(x in name for x in ("w1", "w3")):
- tmp_states[to_name] = tmp_states.get(to_name, [])
- tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=0)[rank])
- elif any(x in name for x in ("Wqkv",)):
- tmp_states[to_name] = tmp_states.get(to_name, [])
- _wqkv = states.pop(name).chunk(3, dim=0)
- _wq_splits = _wqkv[0].chunk(ratio, dim=0)
- _wk_splits = _wqkv[1].chunk(ratio, dim=0)
- _wv_splits = _wqkv[2].chunk(ratio, dim=0)
- new_wqkv = torch.concat([_wq_splits[rank], _wk_splits[rank], _wv_splits[rank]], dim=0)
- tmp_states[to_name].append(new_wqkv)
- else:
- raise KeyError(f"Unknown key {name}.")
-
- if "embedding.weight" in states and model.first_layer == 0:
- tmp_states["embedding.weight"] = tmp_states.get("embedding.weight", [])
- tmp_states["embedding.weight"].append(states["embedding.weight"].chunk(ratio, dim=1)[rank])
- if "head.weight" in states and model.last_layer == gpc.config.model.num_layers:
- tmp_states["norm.weight"] = [states["norm.weight"]]
- tmp_states["head.weight"] = tmp_states.get("head.weight", [])
- tmp_states["head.weight"].append(states["head.weight"].chunk(ratio, dim=0)[rank])
-
- states = {}
-
- for name in list(tmp_states.keys()):
- data = tmp_states.pop(name)
- if len(data) == 1:
- current_states[name] = data[0]
- else:
- current_states[name] = torch.concat(
- data, dim=1 if name == "embedding.weight" or any(x in name for x in ("out_proj", "w2")) else 0
- )
-
- missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)
-
- if gpc.get_local_rank(ParallelMode.DATA) == 0:
- pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
- logger.info(
- f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
- f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
- )
-
- internlm_accelerator.empty_cache()
-
@staticmethod
def convert_internevo2hf_weights(src: str, tgt: str) -> None:
model_config = gpc.config.model
diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/model_implementations/transformers/modeling_internlm2.py
similarity index 74%
rename from internlm/model/modeling_internlm2.py
rename to internlm/model/model_implementations/transformers/modeling_internlm2.py
index e15b9979a..e22e51a76 100644
--- a/internlm/model/modeling_internlm2.py
+++ b/internlm/model/model_implementations/transformers/modeling_internlm2.py
@@ -2,43 +2,42 @@
import math
import os
from contextlib import nullcontext
-from functools import reduce
from typing import Optional
import torch
from torch import nn
from tqdm import tqdm
+from transformers.modeling_utils import (
+ SAFE_WEIGHTS_INDEX_NAME,
+ SAFE_WEIGHTS_NAME,
+ shard_checkpoint,
+)
from internlm.accelerator import get_accelerator
from internlm.core.context import ParallelMode
-from internlm.core.context.parallel_context import global_context as gpc
+from internlm.core.context import global_context as gpc
from internlm.core.parallel.comm.cpu_offload import get_cpu_offload_context
-from internlm.core.parallel.shard import partition_uniform
-from internlm.initialize.initialize_tensor import (
+from internlm.model.model_implementations.transformers.base_model import (
+ BaseTransformerModel,
+)
+from internlm.model.model_implementations.transformers.utils import (
normal_,
scaled_init_method_normal,
scaled_init_method_uniform,
uniform_,
)
-from internlm.model.base_model import BaseModel
-from internlm.model.modules.embedding import Embedding1D
-from internlm.model.modules.linear import new_linear
-from internlm.model.modules.mha import GQA
-from internlm.model.modules.mlp import new_feed_forward
-from internlm.model.modules.norm import new_layer_norm
-from internlm.model.utils import (
+from internlm.model.model_ops.modules.embedding import Embedding1D
+from internlm.model.model_ops.modules.linear import new_linear
+from internlm.model.model_ops.modules.mha import GQA
+from internlm.model.model_ops.modules.mlp import new_feed_forward
+from internlm.model.model_ops.modules.norm import new_layer_norm
+from internlm.model.model_ops.utils import (
convert_attn_args_to_kwargs,
convert_attn_kwargs_to_args,
- get_parallel_size_from_file,
)
from internlm.solver.activation_checkpoint import activation_checkpoint
from internlm.utils.logger import get_logger
from internlm.utils.storage_manager import get_fns, llm_load, llm_save
-from transformers.modeling_utils import (
- SAFE_WEIGHTS_INDEX_NAME,
- SAFE_WEIGHTS_NAME,
- shard_checkpoint,
-)
internlm_accelerator = get_accelerator()
logger = get_logger(__file__)
@@ -295,7 +294,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states):
return hidden_states
-class InternLM2(BaseModel):
+class InternLM2(BaseTransformerModel):
"""
InternLM2 Model.
@@ -634,196 +633,6 @@ def load_hf_weights(folder: str, model: nn.Module) -> None:
internlm_accelerator.empty_cache()
- @staticmethod
- def load_internlm2_with_dynamic_parallel_size(folder, model):
- """Load InternLM2 with dynamic parallel size."""
- assert folder is not None, "Please specify the folder of the pretrained model"
- assert gpc.config.model_type in ["INTERNLM2"], "dynamic_parallel is only for INTERNLM2"
-
- fns = get_fns(folder)
- if gpc.is_rank_for_log():
- logger.info(f"Loading pretrained model from {folder}")
- model_fns, old_tp, old_pp = get_parallel_size_from_file(fns) # pylint: disable=W0612
-
- tp = gpc.get_world_size(ParallelMode.TENSOR)
- tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
- assert old_tp % tp == 0 or tp % old_tp == 0, (
- f"Expected TP size in loaded checkpoint to be fit with TP size in current config, but got {old_tp} in "
- f"checkpoint and {tp} in current config"
- )
-
- correspond_tps = []
-
- if old_tp <= tp:
- correspond_tps.append(tp_rank // (tp // old_tp))
- ratio = tp // old_tp
- rank = tp_rank % ratio
- else:
- for i in range(old_tp // tp):
- correspond_tps.append(tp_rank * (old_tp // tp) + i)
- rank = 0
- ratio = 1
-
- current_states = {}
-
- pp = gpc.get_world_size(ParallelMode.PIPELINE) # noqa: F841 # pylint: disable=W0612
-
- assert gpc.config.model.num_chunks == 1, "May cause future collisions, ignore this if necessary"
-
- old_pp_partition = partition_uniform(gpc.config.model.num_layers, old_pp, 1)
-
- for idx, parts in enumerate(old_pp_partition):
- start, end = parts[0]
- if model.last_layer <= start or model.first_layer >= end:
- continue
- tmp_states = {}
-
- for correspond_tp in correspond_tps:
- model_name = f"model_tp{correspond_tp}_pp{idx}.pt"
- states = llm_load(os.path.join(folder, model_name), map_location="cpu")
- states = {k.replace("model.", ""): v for k, v in states.items()}
- for i in range(start, end):
- if i >= model.last_layer:
- break
- if i < model.first_layer:
- continue
-
- for name in list(states.keys()):
- if f".{i-start}." in name:
- to_name = name.replace(f".{i-start}.", f".{i-model.first_layer}.")
-
- if gpc.config.model_type == "INTERNLM2":
- if "norm" in name:
- tmp_states[to_name] = [states.pop(name)]
- elif any(x in name for x in ("wo", "w2")):
- tmp_states[to_name] = tmp_states.get(to_name, [])
- tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=1)[rank])
- elif any(x in name for x in ("w1", "w3")):
- tmp_states[to_name] = tmp_states.get(to_name, [])
- tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=0)[rank])
- elif any(x in name for x in ("wqkv",)):
- tmp_states[to_name] = tmp_states.get(to_name, [])
- if tp > gpc.config.model.num_kv_attention_heads:
- assert old_tp <= gpc.config.model.num_kv_attention_heads, (
- f"`old_tp ({old_tp}) => tp ({tp})` is not supported. "
- "At least one of `tp` and `old_tp` should be less than or "
- "equal to `num_kv_attention_heads`"
- )
- # Suitable for cases where the num_kv_attention_head is small,
- # but you want to have a large TP Size
- q_per_kv = (
- gpc.config.model.num_attention_heads
- // gpc.config.model.num_kv_attention_heads
- )
- head_dim = gpc.config.model.hidden_size // gpc.config.model.num_attention_heads
- index = torch.concat(
- (
- torch.arange(q_per_kv).chunk(ratio, dim=0)[tp_rank % ratio],
- torch.tensor([q_per_kv, q_per_kv + 1]),
- )
- )
- index = index + (q_per_kv + 2) * (tp_rank // ratio)
- index = index % (
- (q_per_kv + 2) * (gpc.config.model.num_kv_attention_heads / old_tp)
- )
- index = index * head_dim
- index = index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(
- index.shape[0]
- )
- tmp_states[to_name].append(
- torch.index_select(states.pop(name), 0, index.to(torch.int32))
- )
- else:
- tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=0)[rank])
- else:
- raise KeyError(f"Unknown key {name}.")
-
- else:
- assert False, "unsupported model type"
-
- if "tok_embeddings.weight" in states and model.first_layer == 0:
- tmp_states["tok_embeddings.weight"] = tmp_states.get("tok_embeddings.weight", [])
- tmp_states["tok_embeddings.weight"].append(
- states["tok_embeddings.weight"].chunk(ratio, dim=1)[rank]
- )
- if "output.weight" in states and model.last_layer == gpc.config.model.num_layers:
- tmp_states["norm.weight"] = [states["norm.weight"]]
- tmp_states["output.weight"] = tmp_states.get("output.weight", [])
- tmp_states["output.weight"].append(states["output.weight"].chunk(ratio, dim=0)[rank])
-
- states = {}
-
- for name in list(tmp_states.keys()):
- data = tmp_states.pop(name)
- if len(data) == 1:
- current_states[name] = data[0]
- else:
- current_states[name] = torch.concat(
- data, dim=1 if name == "tok_embeddings.weight" or any(x in name for x in ("wo", "w2")) else 0
- )
- # Merge copied kv heads
- if "wqkv" in name and old_tp > gpc.config.model.num_kv_attention_heads:
- assert (
- tp <= gpc.config.model.num_kv_attention_heads
- ), "new_tp should be less than or equal to num_kv_attention_heads"
- head_dim = gpc.config.model.hidden_size // gpc.config.model.num_attention_heads
- q_per_kv = gpc.config.model.num_attention_heads // gpc.config.model.num_kv_attention_heads
- copied_times = old_tp // gpc.config.model.num_kv_attention_heads
- cur_q_per_kv = q_per_kv // copied_times
-
- # pylint: disable=all
- def duplicate_kv_index(i):
- if i % (cur_q_per_kv + 2) >= cur_q_per_kv:
- return i
- else:
- return -100
-
- def unique_kv_index(i):
- if i // (cur_q_per_kv + 2) == copied_times - 1 or i % (cur_q_per_kv + 2) < cur_q_per_kv:
- return i
- else:
- return -100
-
- # pylint: enable=all
-
- # Verify
- duplicate_index = [duplicate_kv_index(i) for i in range((cur_q_per_kv + 2) * copied_times)]
- duplicate_index = [i for i in duplicate_index if i != -100]
- duplicate_index = _duplicate_index = torch.tensor(duplicate_index)
- for i in range(gpc.config.model.num_kv_attention_heads // tp - 1):
- duplicate_index = torch.concat(
- (duplicate_index, _duplicate_index + duplicate_index.max() + 1), dim=0
- )
- duplicate_kv = []
- for index in duplicate_index.reshape(-1, copied_times * 2).chunk(copied_times, dim=-1):
- index = index.reshape(-1) * head_dim
- index = index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(index.shape[0])
- duplicate_kv.append(torch.index_select(current_states[name], 0, index))
- assert reduce(
- lambda x, y: x and y,
- [torch.allclose(duplicate_kv[0], x, atol=1e-5) for x in duplicate_kv[1:]],
- ), "Copied kv heads are not equal after training!"
-
- # Merge
- unique_index = [unique_kv_index(i) for i in range((cur_q_per_kv + 2) * copied_times)]
- unique_index = [i for i in unique_index if i != -100]
- unique_index = _unique_index = torch.tensor(unique_index)
- for i in range(gpc.config.model.num_kv_attention_heads // tp - 1):
- unique_index = torch.concat((unique_index, _unique_index + unique_index.max() + 1), dim=0)
- unique_index = unique_index * head_dim
- unique_index = unique_index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(
- unique_index.shape[0]
- )
- current_states[name] = torch.index_select(current_states[name], 0, unique_index)
- missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)
-
- if gpc.get_local_rank(ParallelMode.DATA) == 0:
- pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
- logger.info(
- f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
- f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
- )
-
@staticmethod
def convert_internevo2hf_weights(src: str, tgt: str) -> None:
model_config = gpc.config.model
diff --git a/internlm/model/modeling_llama.py b/internlm/model/model_implementations/transformers/modeling_llama.py
similarity index 90%
rename from internlm/model/modeling_llama.py
rename to internlm/model/model_implementations/transformers/modeling_llama.py
index 56b88e83e..b2e1aef9e 100644
--- a/internlm/model/modeling_llama.py
+++ b/internlm/model/model_implementations/transformers/modeling_llama.py
@@ -5,35 +5,37 @@
import torch
from torch import nn
from tqdm import tqdm
+from transformers.modeling_utils import (
+ SAFE_WEIGHTS_INDEX_NAME,
+ SAFE_WEIGHTS_NAME,
+ shard_checkpoint,
+)
from internlm.accelerator import get_accelerator
from internlm.core.context import ParallelMode
-from internlm.core.context.parallel_context import global_context as gpc
+from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import set_output_attr_to_module
-from internlm.initialize.initialize_tensor import (
+from internlm.model.model_implementations.transformers.base_model import (
+ BaseTransformerModel,
+)
+from internlm.model.model_implementations.transformers.utils import (
normal_,
scaled_init_method_normal,
scaled_init_method_uniform,
uniform_,
)
-from internlm.model.base_model import BaseModel
-from internlm.model.modules.embedding import Embedding1D
-from internlm.model.modules.linear import new_linear
-from internlm.model.modules.mha import GQA
-from internlm.model.modules.mlp import new_feed_forward
-from internlm.model.modules.norm import new_layer_norm
-from internlm.model.utils import (
+from internlm.model.model_ops.modules.embedding import Embedding1D
+from internlm.model.model_ops.modules.linear import new_linear
+from internlm.model.model_ops.modules.mha import GQA
+from internlm.model.model_ops.modules.mlp import new_feed_forward
+from internlm.model.model_ops.modules.norm import new_layer_norm
+from internlm.model.model_ops.utils import (
convert_attn_args_to_kwargs,
convert_attn_kwargs_to_args,
)
from internlm.solver.activation_checkpoint import activation_checkpoint
from internlm.utils.logger import get_logger
from internlm.utils.storage_manager import get_fns, llm_load, llm_save
-from transformers.modeling_utils import (
- SAFE_WEIGHTS_INDEX_NAME,
- SAFE_WEIGHTS_NAME,
- shard_checkpoint,
-)
internlm_accelerator = get_accelerator()
logger = get_logger(__file__)
@@ -281,7 +283,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states):
return hidden_states
-class Llama2(BaseModel):
+class Llama2(BaseTransformerModel):
"""
Llama2 Model.
@@ -584,63 +586,6 @@ def load_hf_weights(folder: str, model: nn.Module):
internlm_accelerator.empty_cache()
- @staticmethod
- def load_llama_pretrained_weights(folder: str, model: nn.Module) -> None:
- """NOTE: when loading huggingface's llama pretrained weights, you should set `adapt_hf=True` in your config."""
- """NOTE: specified for meta-llama/Llama-2-7b"""
- assert folder is not None, "Please specify the folder of the pretrained model"
- if gpc.is_rank_for_log():
- logger.info(f"Loading pretrained model from {folder}")
-
- fns = get_fns(folder)
- model_fns = []
- for fn in fns:
- if fn.startswith("model_t") and not fn.endswith("md5"):
- model_fns.append(os.path.join(folder, fn))
-
- if len(model_fns) == 0:
- model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".pth") or fn.endswith(".pt")]
-
- if len(model_fns) == 0:
- raise FileNotFoundError(f"No checkpoint file found in {folder}")
-
- model_fns.sort()
-
- old_tp = len(model_fns)
- cur_tp = gpc.get_world_size(ParallelMode.TENSOR)
- # If the two tp are inconsistent, you need to consider the merge before splitting
- if old_tp != cur_tp:
- raise RuntimeError(
- f"Your current tp is `{cur_tp}`, but the tp in folder:`{folder}` is `{old_tp}`, use `` to convert first"
- )
-
- states = llm_load(model_fns[gpc.get_local_rank(ParallelMode.TENSOR)], map_location="cpu")
-
- current_states = {}
- for idx, i in enumerate(range(model.first_layer, model.last_layer)):
- for name in list(states.keys()):
- if f".{i}." in name:
- current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name)
-
- model_state_keys = set(list(model.state_dict().keys()))
-
- if "tok_embeddings.weight" in model_state_keys:
- current_states["tok_embeddings.weight"] = states["tok_embeddings.weight"]
- assert model.first_layer == 0, f"Expect model.NaiveAMPModel to be 0, but got {model.first_layer}"
- if "output.weight" in model_state_keys:
- current_states["norm.weight"] = states["norm.weight"]
- current_states["output.weight"] = states["output.weight"]
- missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)
-
- if gpc.get_local_rank(ParallelMode.DATA) == 0:
- pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
- logger.info(
- f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
- f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
- )
-
- internlm_accelerator.empty_cache()
-
@staticmethod
def convert_internevo2hf_weights(src: str, tgt: str) -> None:
model_config = gpc.config.model
diff --git a/internlm/model/modeling_moe.py b/internlm/model/model_implementations/transformers/modeling_moe.py
similarity index 96%
rename from internlm/model/modeling_moe.py
rename to internlm/model/model_implementations/transformers/modeling_moe.py
index ed32ca03c..fd3609af4 100644
--- a/internlm/model/modeling_moe.py
+++ b/internlm/model/model_implementations/transformers/modeling_moe.py
@@ -8,17 +8,22 @@
from torch import nn
from internlm.core.context import ParallelMode
-from internlm.core.context.parallel_context import global_context as gpc
+from internlm.core.context import global_context as gpc
from internlm.core.parallel.comm.cpu_offload import get_cpu_offload_context
-from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
-from internlm.model.base_model import BaseModel
-from internlm.model.modules.embedding import Embedding1D
-from internlm.model.modules.linear import new_linear
-from internlm.model.modules.mha import MHA
-from internlm.model.modules.mlp import new_feed_forward
-from internlm.model.modules.norm import new_layer_norm
-from internlm.model.moe.moe import MoE
-from internlm.model.utils import (
+from internlm.model.model_implementations.transformers.base_model import (
+ BaseTransformerModel,
+)
+from internlm.model.model_implementations.transformers.utils import (
+ normal_,
+ scaled_init_method_normal,
+)
+from internlm.model.model_ops.modules.embedding import Embedding1D
+from internlm.model.model_ops.modules.linear import new_linear
+from internlm.model.model_ops.modules.mha import MHA
+from internlm.model.model_ops.modules.mlp import new_feed_forward
+from internlm.model.model_ops.modules.norm import new_layer_norm
+from internlm.model.model_ops.moe.moe import MoE
+from internlm.model.model_ops.utils import (
convert_attn_args_to_kwargs,
convert_attn_kwargs_to_args,
internlm1_mha_pre_load_convert,
@@ -246,7 +251,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states):
return hidden_states + residual, moe_loss
-class Internlm1MoE(BaseModel):
+class Internlm1MoE(BaseTransformerModel):
"""
InternLM1 MoE.
diff --git a/internlm/initialize/initialize_tensor.py b/internlm/model/model_implementations/transformers/utils.py
similarity index 100%
rename from internlm/initialize/initialize_tensor.py
rename to internlm/model/model_implementations/transformers/utils.py
diff --git a/internlm/model/ops/__init__.py b/internlm/model/model_ops/__init__.py
similarity index 100%
rename from internlm/model/ops/__init__.py
rename to internlm/model/model_ops/__init__.py
diff --git a/internlm/model/losses/__init__.py b/internlm/model/model_ops/losses/__init__.py
similarity index 100%
rename from internlm/model/losses/__init__.py
rename to internlm/model/model_ops/losses/__init__.py
diff --git a/internlm/model/losses/ce_loss.py b/internlm/model/model_ops/losses/ce_loss.py
similarity index 97%
rename from internlm/model/losses/ce_loss.py
rename to internlm/model/model_ops/losses/ce_loss.py
index 5b2a380e8..e5645aba4 100644
--- a/internlm/model/losses/ce_loss.py
+++ b/internlm/model/model_ops/losses/ce_loss.py
@@ -2,7 +2,7 @@
from torch import nn
from internlm.accelerator import get_accelerator
-from internlm.model.ops.cross_entropy import new_cross_entropy
+from internlm.model.model_ops.ops.cross_entropy import new_cross_entropy
internlm_accelerator = get_accelerator()
diff --git a/internlm/model/metrics.py b/internlm/model/model_ops/metrics.py
similarity index 99%
rename from internlm/model/metrics.py
rename to internlm/model/model_ops/metrics.py
index a7f6c9668..e67079534 100644
--- a/internlm/model/metrics.py
+++ b/internlm/model/model_ops/metrics.py
@@ -4,7 +4,7 @@
from internlm.accelerator import AcceleratorType, get_accelerator
from internlm.core.context import global_context as gpc
-from internlm.model.ops.cross_entropy import new_cross_entropy
+from internlm.model.model_ops.ops.cross_entropy import new_cross_entropy
from internlm.utils.common import SchedulerHook, get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
diff --git a/internlm/model/model_ops/modules/__init__.py b/internlm/model/model_ops/modules/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/internlm/model/modules/embedding.py b/internlm/model/model_ops/modules/embedding.py
similarity index 99%
rename from internlm/model/modules/embedding.py
rename to internlm/model/model_ops/modules/embedding.py
index 164686cc0..e3b81fac5 100644
--- a/internlm/model/modules/embedding.py
+++ b/internlm/model/model_ops/modules/embedding.py
@@ -10,7 +10,7 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.model.ops.rotary_emb import apply_rotary_emb
+from internlm.model.model_ops.ops.rotary_emb import apply_rotary_emb
from internlm.utils.parallel import is_using_isp
diff --git a/internlm/model/modules/linear.py b/internlm/model/model_ops/modules/linear.py
similarity index 99%
rename from internlm/model/modules/linear.py
rename to internlm/model/model_ops/modules/linear.py
index 0dc9d3072..ce8778578 100644
--- a/internlm/model/modules/linear.py
+++ b/internlm/model/model_ops/modules/linear.py
@@ -19,7 +19,7 @@
get_parallel_strategies_split_mode,
get_tensor_split_parallel_mode,
)
-from internlm.model.ops.linear import (
+from internlm.model.model_ops.ops.linear import (
gmm_backward_op,
gmm_forward_op,
linear_backward_op,
@@ -28,8 +28,7 @@
from internlm.utils.logger import get_logger
if TYPE_CHECKING:
- from internlm.core.parallel.comm.isp import WPCommunicator
- from internlm.core.parallel.comm.tensor import TPCommunicator
+ from internlm.core.parallel.comm import TPCommunicator, WPCommunicator
logger = get_logger(__file__)
internlm_accelerator = get_accelerator()
diff --git a/internlm/model/modules/mha.py b/internlm/model/model_ops/modules/mha.py
similarity index 99%
rename from internlm/model/modules/mha.py
rename to internlm/model/model_ops/modules/mha.py
index 5c8a60b3b..00528ad6c 100644
--- a/internlm/model/modules/mha.py
+++ b/internlm/model/model_ops/modules/mha.py
@@ -11,10 +11,10 @@
from torch.nn import functional as F
from internlm.core.context import global_context as gpc
-from internlm.model.modules.embedding import new_rotary_embedding
-from internlm.model.modules.linear import new_linear
-from internlm.model.modules.utils import update_kv_cache
-from internlm.model.ops.attention import CrossAttention, SelfAttention
+from internlm.model.model_ops.modules.embedding import new_rotary_embedding
+from internlm.model.model_ops.modules.linear import new_linear
+from internlm.model.model_ops.modules.utils import update_kv_cache
+from internlm.model.model_ops.ops.attention import CrossAttention, SelfAttention
from internlm.utils.logger import get_logger
logger = get_logger(__file__)
diff --git a/internlm/model/modules/mlp.py b/internlm/model/model_ops/modules/mlp.py
similarity index 98%
rename from internlm/model/modules/mlp.py
rename to internlm/model/model_ops/modules/mlp.py
index cf97bbdc6..8051c38be 100644
--- a/internlm/model/modules/mlp.py
+++ b/internlm/model/model_ops/modules/mlp.py
@@ -6,8 +6,8 @@
import torch
from torch import nn
-from internlm.model.modules.linear import new_linear
-from internlm.model.modules.utils import Gelu, Silu
+from internlm.model.model_ops.modules.linear import new_linear
+from internlm.model.model_ops.modules.utils import Gelu, Silu
from internlm.utils.logger import get_logger
from internlm.utils.utils import ActivationType
diff --git a/internlm/model/modules/norm.py b/internlm/model/model_ops/modules/norm.py
similarity index 91%
rename from internlm/model/modules/norm.py
rename to internlm/model/model_ops/modules/norm.py
index 2a9700f8d..cab90e0f5 100644
--- a/internlm/model/modules/norm.py
+++ b/internlm/model/model_ops/modules/norm.py
@@ -8,7 +8,7 @@
import torch
from torch import nn
-from internlm.model.ops.norm import RMSNorm
+from internlm.model.model_ops.ops.norm import RMSNorm
Shape = Union[int, List[int], torch.Size]
diff --git a/internlm/model/modules/utils.py b/internlm/model/model_ops/modules/utils.py
similarity index 100%
rename from internlm/model/modules/utils.py
rename to internlm/model/model_ops/modules/utils.py
diff --git a/internlm/model/moe/__init__.py b/internlm/model/model_ops/moe/__init__.py
similarity index 100%
rename from internlm/model/moe/__init__.py
rename to internlm/model/model_ops/moe/__init__.py
diff --git a/internlm/model/moe/base_layer.py b/internlm/model/model_ops/moe/base_layer.py
similarity index 95%
rename from internlm/model/moe/base_layer.py
rename to internlm/model/model_ops/moe/base_layer.py
index 7811e056d..a99a7b3b6 100644
--- a/internlm/model/moe/base_layer.py
+++ b/internlm/model/model_ops/moe/base_layer.py
@@ -4,7 +4,7 @@
from torch.nn import Module, ModuleList
from internlm.core.context import global_context as gpc
-from internlm.model.moe.experts import Experts
+from internlm.model.model_ops.moe.experts import Experts
if TYPE_CHECKING:
Base = Module[Tensor]
diff --git a/internlm/model/moe/dropless_layer.py b/internlm/model/model_ops/moe/dropless_layer.py
similarity index 99%
rename from internlm/model/moe/dropless_layer.py
rename to internlm/model/model_ops/moe/dropless_layer.py
index 031c23065..c2868d7bc 100644
--- a/internlm/model/moe/dropless_layer.py
+++ b/internlm/model/model_ops/moe/dropless_layer.py
@@ -15,7 +15,7 @@
from internlm.accelerator import AcceleratorType, get_accelerator
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.model.modules.mlp import new_feed_forward
+from internlm.model.model_ops.modules.mlp import new_feed_forward
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
diff --git a/internlm/model/moe/experts.py b/internlm/model/model_ops/moe/experts.py
similarity index 100%
rename from internlm/model/moe/experts.py
rename to internlm/model/model_ops/moe/experts.py
diff --git a/internlm/model/moe/gshard_layer.py b/internlm/model/model_ops/moe/gshard_layer.py
similarity index 99%
rename from internlm/model/moe/gshard_layer.py
rename to internlm/model/model_ops/moe/gshard_layer.py
index a102b8c9e..c15810070 100644
--- a/internlm/model/moe/gshard_layer.py
+++ b/internlm/model/model_ops/moe/gshard_layer.py
@@ -15,7 +15,7 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.model.modules.mlp import new_feed_forward
+from internlm.model.model_ops.modules.mlp import new_feed_forward
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
diff --git a/internlm/model/moe/megablocks/__init__.py b/internlm/model/model_ops/moe/megablocks/__init__.py
similarity index 100%
rename from internlm/model/moe/megablocks/__init__.py
rename to internlm/model/model_ops/moe/megablocks/__init__.py
diff --git a/internlm/model/moe/megablocks/megablock_dmoe.py b/internlm/model/model_ops/moe/megablocks/megablock_dmoe.py
similarity index 96%
rename from internlm/model/moe/megablocks/megablock_dmoe.py
rename to internlm/model/model_ops/moe/megablocks/megablock_dmoe.py
index 46e1a81cd..ee80a07d8 100644
--- a/internlm/model/moe/megablocks/megablock_dmoe.py
+++ b/internlm/model/model_ops/moe/megablocks/megablock_dmoe.py
@@ -5,10 +5,10 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.model.moe.base_layer import BaseMoELayer
-from internlm.model.moe.megablocks.megablock_moe import MegaBlockMoE
-from internlm.model.moe.megablocks.mlp import MegaBlockGroupedFeedForward
-from internlm.model.moe.megablocks.utils import promote_scalar
+from internlm.model.model_ops.moe.base_layer import BaseMoELayer
+from internlm.model.model_ops.moe.megablocks.megablock_moe import MegaBlockMoE
+from internlm.model.model_ops.moe.megablocks.mlp import MegaBlockGroupedFeedForward
+from internlm.model.model_ops.moe.megablocks.utils import promote_scalar
try:
import stk
diff --git a/internlm/model/moe/megablocks/megablock_moe.py b/internlm/model/model_ops/moe/megablocks/megablock_moe.py
similarity index 98%
rename from internlm/model/moe/megablocks/megablock_moe.py
rename to internlm/model/model_ops/moe/megablocks/megablock_moe.py
index 257585da0..86a87fff6 100644
--- a/internlm/model/moe/megablocks/megablock_moe.py
+++ b/internlm/model/model_ops/moe/megablocks/megablock_moe.py
@@ -6,9 +6,9 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.model.moe.base_layer import BaseMoELayer
-from internlm.model.moe.megablocks.mlp import MegaBlockFeedForward
-from internlm.model.moe.utils import all_to_all
+from internlm.model.model_ops.moe.base_layer import BaseMoELayer
+from internlm.model.model_ops.moe.megablocks.mlp import MegaBlockFeedForward
+from internlm.model.model_ops.moe.utils import all_to_all
try:
from megablocks import ops
diff --git a/internlm/model/moe/megablocks/mlp.py b/internlm/model/model_ops/moe/megablocks/mlp.py
similarity index 95%
rename from internlm/model/moe/megablocks/mlp.py
rename to internlm/model/model_ops/moe/megablocks/mlp.py
index 374793d6c..91519a890 100644
--- a/internlm/model/moe/megablocks/mlp.py
+++ b/internlm/model/model_ops/moe/megablocks/mlp.py
@@ -3,8 +3,8 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.model.modules.utils import Silu
-from internlm.model.moe.megablocks.utils import (
+from internlm.model.model_ops.modules.utils import Silu
+from internlm.model.model_ops.moe.megablocks.utils import (
act_fn,
dsd_nn,
sdd_nt,
diff --git a/internlm/model/moe/megablocks/utils.py b/internlm/model/model_ops/moe/megablocks/utils.py
similarity index 99%
rename from internlm/model/moe/megablocks/utils.py
rename to internlm/model/model_ops/moe/megablocks/utils.py
index 857dd8b73..5c40dd619 100644
--- a/internlm/model/moe/megablocks/utils.py
+++ b/internlm/model/model_ops/moe/megablocks/utils.py
@@ -1,7 +1,7 @@
import torch
from internlm.accelerator import get_accelerator
-from internlm.model.modules.utils import Silu
+from internlm.model.model_ops.modules.utils import Silu
try:
import stk
diff --git a/internlm/model/moe/moe.py b/internlm/model/model_ops/moe/moe.py
similarity index 96%
rename from internlm/model/moe/moe.py
rename to internlm/model/model_ops/moe/moe.py
index 67fc40b56..ba96ecbca 100644
--- a/internlm/model/moe/moe.py
+++ b/internlm/model/model_ops/moe/moe.py
@@ -4,11 +4,11 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import set_fp32_attr_to_module
-from internlm.model.modules.mlp import new_feed_forward
-from internlm.model.moe.dropless_layer import DroplessMoELayer
-from internlm.model.moe.gshard_layer import GShardMoELayer
-from internlm.model.moe.megablocks.megablock_dmoe import MegaBlockdMoE
-from internlm.model.moe.megablocks.megablock_moe import MegaBlockMoE
+from internlm.model.model_ops.modules.mlp import new_feed_forward
+from internlm.model.model_ops.moe.dropless_layer import DroplessMoELayer
+from internlm.model.model_ops.moe.gshard_layer import GShardMoELayer
+from internlm.model.model_ops.moe.megablocks.megablock_dmoe import MegaBlockdMoE
+from internlm.model.model_ops.moe.megablocks.megablock_moe import MegaBlockMoE
from internlm.utils.logger import get_logger
# global llm logger
diff --git a/internlm/model/moe/utils.py b/internlm/model/model_ops/moe/utils.py
similarity index 100%
rename from internlm/model/moe/utils.py
rename to internlm/model/model_ops/moe/utils.py
diff --git a/internlm/model/model_ops/ops/__init__.py b/internlm/model/model_ops/ops/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/internlm/model/ops/_flash_attn.py b/internlm/model/model_ops/ops/_flash_attn.py
similarity index 100%
rename from internlm/model/ops/_flash_attn.py
rename to internlm/model/model_ops/ops/_flash_attn.py
diff --git a/internlm/model/ops/attention.py b/internlm/model/model_ops/ops/attention.py
similarity index 98%
rename from internlm/model/ops/attention.py
rename to internlm/model/model_ops/ops/attention.py
index 3aec51f55..5beccba9e 100644
--- a/internlm/model/ops/attention.py
+++ b/internlm/model/model_ops/ops/attention.py
@@ -17,11 +17,14 @@
from internlm.accelerator import AcceleratorType, get_accelerator
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.core.parallel.comm.isp import (
+from internlm.core.parallel.comm import (
auto_wrap_distributed_attention,
auto_wrap_func_distributed_attention,
)
-from internlm.model.ops.utils import pack_output_after_attn, unpack_qkv_before_attn
+from internlm.model.model_ops.ops.utils import (
+ pack_output_after_attn,
+ unpack_qkv_before_attn,
+)
from internlm.utils.common import get_current_device
from internlm.utils.utils import (
CuSeqlenType,
@@ -41,7 +44,7 @@
pass
else:
try:
- from internlm.model.ops.ring_flash_attn import (
+ from internlm.model.model_ops.ops.ring_flash_attn import (
zigzag_ring_flash_attn_kvpacked_func_with_sliding_window,
zigzag_ring_flash_attn_qkvpacked_func_with_sliding_window,
zigzag_ring_flash_attn_qkvsplited_func_with_sliding_window,
@@ -1185,12 +1188,9 @@ def isp_flash_attn_varlen_func(
causal=False,
softmax_scale=None,
attention_dropout=0.0,
- return_attn_probs=False,
):
- assert (
- device_backend == AcceleratorType.GPU and gpu_flash_attn_impl
- ), "isp_flash_attn_varlen_func currently only support GPU."
- return _flash_varlen_qkvsplited_func(
+ _, op = _select_attn_op(AttnOpType.VarLenQKVSplited)
+ return op(
q.flatten(0, 1),
k.flatten(0, 1),
v.flatten(0, 1),
@@ -1201,7 +1201,6 @@ def isp_flash_attn_varlen_func(
dropout_p=attention_dropout,
softmax_scale=softmax_scale,
causal=causal,
- return_attn_probs=return_attn_probs,
).unsqueeze(0)
@@ -1213,17 +1212,13 @@ def isp_flash_attn_func(
causal=False,
softmax_scale=None,
attention_dropout=0.0,
- return_attn_probs=False,
):
- assert (
- device_backend == AcceleratorType.GPU and gpu_flash_attn_impl
- ), "isp_flash_attn_func currently only support GPU."
- return _flash_fixedlen_qkvsplited_func(
+ _, op = _select_attn_op(AttnOpType.FixedLenQKVSplited)
+ return op(
q,
k,
v,
dropout_p=attention_dropout,
softmax_scale=softmax_scale,
causal=causal,
- return_attn_probs=return_attn_probs,
)
diff --git a/internlm/model/ops/cross_entropy.py b/internlm/model/model_ops/ops/cross_entropy.py
similarity index 91%
rename from internlm/model/ops/cross_entropy.py
rename to internlm/model/model_ops/ops/cross_entropy.py
index 99bf1e047..17b9f8c05 100644
--- a/internlm/model/ops/cross_entropy.py
+++ b/internlm/model/model_ops/ops/cross_entropy.py
@@ -14,10 +14,11 @@
from internlm.accelerator import AcceleratorType, get_accelerator
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.model.ops.cross_entropy_ops import (
+from internlm.model.model_ops.ops.cross_entropy_ops import (
CrossEntropyApexVocabParallel,
CrossEntropyLossApex,
+ CrossEntropyLossFlash,
CrossEntropyPython,
)
from internlm.utils.logger import get_logger
@@ -86,17 +87,8 @@ def new_cross_entropy(
assert gpc.get_group(ParallelMode.TENSOR) is not None, "The process group should not be None."
- try:
- from flash_attn.losses.cross_entropy import (
- CrossEntropyLoss as FlashCrossEntropyLoss,
- )
-
- flash_cross_entropy_impl = True
- except (ModuleNotFoundError, ImportError):
- flash_cross_entropy_impl = False
-
assert (
- gpc.config.model.get("use_flash_attn", False) and flash_cross_entropy_impl
+ gpc.config.model.get("use_flash_attn", False)
), "Only flash cross entropy support parallel_output"
assert (
@@ -108,7 +100,7 @@ def new_cross_entropy(
which may result loss divergency in long sequence."
)
- return FlashCrossEntropyLoss(
+ return CrossEntropyLossFlash(
ignore_index=ignore_index,
reduction=reduction,
label_smoothing=label_smoothing,
diff --git a/internlm/model/ops/cross_entropy_ops/__init__.py b/internlm/model/model_ops/ops/cross_entropy_ops/__init__.py
similarity index 83%
rename from internlm/model/ops/cross_entropy_ops/__init__.py
rename to internlm/model/model_ops/ops/cross_entropy_ops/__init__.py
index 1f4b6630d..ad8c208b0 100644
--- a/internlm/model/ops/cross_entropy_ops/__init__.py
+++ b/internlm/model/model_ops/ops/cross_entropy_ops/__init__.py
@@ -2,10 +2,12 @@
+from .flash_loss import CrossEntropyLossFlash
from .py_naive_loss import CrossEntropyPython
from .py_vocab_parallel_loss import CrossEntropyApexVocabParallel
from .sequence_parallel_loss import VocabSequenceParallelCrossEntropyLoss
__all__ = [
"CrossEntropyLossApex",
"CrossEntropyPython",
"CrossEntropyApexVocabParallel",
"VocabSequenceParallelCrossEntropyLoss",
+ "CrossEntropyLossFlash",
]
diff --git a/internlm/model/ops/cross_entropy_ops/apex_naive_loss.py b/internlm/model/model_ops/ops/cross_entropy_ops/apex_naive_loss.py
similarity index 100%
rename from internlm/model/ops/cross_entropy_ops/apex_naive_loss.py
rename to internlm/model/model_ops/ops/cross_entropy_ops/apex_naive_loss.py
diff --git a/internlm/model/model_ops/ops/cross_entropy_ops/flash_loss.py b/internlm/model/model_ops/ops/cross_entropy_ops/flash_loss.py
new file mode 100644
index 000000000..baab79e54
--- /dev/null
+++ b/internlm/model/model_ops/ops/cross_entropy_ops/flash_loss.py
@@ -0,0 +1,412 @@
+# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/losses/cross_entropy.py
+# Copyright (c) 2024, Tri Dao.
+
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import triton
+import triton.language as tl
+
+# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
+# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
+# version of PyTorch. The following 2 lines are for backward compatibility with
+# older PyTorch.
+if "all_gather_into_tensor" not in dir(torch.distributed):
+ torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
+
+
+@triton.heuristics(
+ {
+ "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
+ }
+)
+@triton.jit
+def cross_entropy_fwd_kernel(
+ loss_ptr, # data ptrs
+ lse_ptr,
+ z_loss_ptr,
+ logits_ptr,
+ labels_ptr,
+ smoothing,
+ logit_scale,
+ lse_square_scale,
+ ignore_index,
+ total_classes,
+ class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
+ n_cols, # shapes
+ logits_row_stride, # strides
+ BLOCK_SIZE: tl.constexpr,
+ HAS_SMOOTHING: tl.constexpr,
+ # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
+ SPLIT: tl.constexpr,
+ PRECOMPUTED_LSE: tl.constexpr, # If LSE is already computed (also no smoothing and logit_scale == 1.0)
+):
+ row_idx = tl.program_id(0)
+ logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
+ sum_logits = 0.0 # For smoothing
+ if not PRECOMPUTED_LSE:
+ # Statistics for online softmax
+ m_i = -float("inf")
+ l_i = 0.0
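+ # Streaming logsumexp: m_i tracks the running max and l_i the running sum of
+ # exp(logits - m_i), so after the loop lse = log(l_i) + m_i without ever
+ # materializing the full row of probabilities.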
+ for col_offset in range(0, n_cols, BLOCK_SIZE):
+ cols = col_offset + tl.arange(0, BLOCK_SIZE)
+ logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to(
+ tl.float32
+ ) * logit_scale
+ if HAS_SMOOTHING:
+ sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0))
+ m_i_new = tl.maximum(m_i, tl.max(logits))
+ l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new))
+ m_i = m_i_new
+ lse = tl.log(l_i) + m_i
+ tl.store(lse_ptr + row_idx, lse)
+ else:
+ lse = tl.load(lse_ptr + row_idx)
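+ # Per-row cross entropy is logsumexp(scaled logits) minus the scaled logit of the target
+ # class (plus optional label smoothing). Under SPLIT (vocab-parallel), the local lse is
+ # only partial, so it is omitted here and the global lse is added after the all-reduce
+ # in forward().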
+ label_idx = tl.load(labels_ptr + row_idx)
+ if label_idx == ignore_index:
+ loss = 0.0
+ z_loss = 0.0
+ else:
+ label_idx -= class_start_idx
+ if label_idx >= 0 and label_idx < n_cols:
+ logits_label = tl.load(logits_ptr + label_idx) * logit_scale
+ if HAS_SMOOTHING:
+ loss = (
+ (lse if not SPLIT else 0.0)
+ - smoothing * sum_logits / total_classes
+ - (1 - smoothing) * logits_label
+ )
+ else:
+ loss = (lse if not SPLIT else 0.0) - logits_label
+ else:
+ # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
+ if HAS_SMOOTHING:
+ loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
+ else:
+ loss = 0.0
+ if not SPLIT:
+ z_loss = lse_square_scale * lse * lse
+ loss += z_loss
+ else:
+ z_loss = 0.0
+ tl.store(loss_ptr + row_idx, loss)
+ if not SPLIT:
+ tl.store(z_loss_ptr + row_idx, z_loss)
+
+
+@triton.heuristics(
+ {
+ "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
+ }
+)
+@triton.jit
+def cross_entropy_bwd_kernel(
+ dlogits_ptr, # data ptrs
+ dloss_ptr,
+ logits_ptr,
+ lse_ptr,
+ labels_ptr,
+ smoothing,
+ logit_scale,
+ lse_square_scale,
+ ignore_index,
+ total_classes,
+ class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes
+ n_cols, # shapes
+ logits_row_stride, # strides
+ dlogits_row_stride,
+ dloss_row_stride,
+ BLOCK_SIZE: tl.constexpr,
+ HAS_SMOOTHING: tl.constexpr,
+):
+ row_idx = tl.program_id(0)
+ col_block_idx = tl.program_id(1)
+ logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
+ dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
+ col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ label_idx = tl.load(labels_ptr + row_idx)
+ if label_idx != ignore_index:
+ dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
+ else:
+ dloss = 0.0
+ logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
+ tl.float32
+ ) * logit_scale
+ lse = tl.load(lse_ptr + row_idx)
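+ # dCE/dlogits = softmax(logits) - one_hot(label) (shifted by the smoothing terms when
+ # smoothing > 0); the extra 2 * lse_square_scale * lse * probs term below is the z-loss
+ # gradient. Everything is scaled by the incoming dloss and by logit_scale (chain rule).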
+ probs = tl.exp(logits - lse)
+ probs += 2.0 * lse_square_scale * lse * probs
+ label_idx -= class_start_idx
+ if HAS_SMOOTHING:
+ smooth_positive = 1.0 - smoothing
+ smooth_negative = smoothing / total_classes
+ probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative
+ else:
+ probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
+ tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
+
+
+class CrossEntropyLoss(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx,
+ logits,
+ labels,
+ precomputed_lse=None,
+ smoothing=0.0,
+ logit_scale=1.0,
+ lse_square_scale=0.0,
+ ignore_index=-100,
+ inplace_backward=False,
+ process_group=None,
+ ):
+ # For some reason Triton generates wrong code when labels has dtype long and its address
+ # is not aligned to 16 bytes. The ld.global.b64 seems to load the wrong label index.
+ if labels.dtype == torch.long and labels.data_ptr() % 16 != 0:
+ labels = F.pad(labels, (0, 1))[..., :-1]
+ assert labels.data_ptr() % 16 == 0
+ assert logit_scale > 0.0
+ n_rows, n_cols = logits.shape
+ assert labels.shape == (n_rows,)
+ world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
+ total_classes = world_size * n_cols
+ rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
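+ # Vocab-parallel layout: each rank holds a contiguous shard of n_cols classes, so the
+ # global class id of local column j on this rank is class_start_idx + j.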
+ class_start_idx = rank * n_cols
+ use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 and smoothing == 0.0
+
+ if logits.stride(-1) != 1:
+ logits = logits.contiguous()
+ MAX_BLOCK_SIZE = 16 * 1024
+ BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
+ num_warps = (
+ 4
+ if BLOCK_SIZE < 2048
+ else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
+ )
+ losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
+ if use_precomputed_lse:
+ assert precomputed_lse.shape == (n_rows,)
+ lse = precomputed_lse.contiguous()
+ else:
+ lse = torch.empty(n_rows, dtype=torch.float, device=logits.device)
+ z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device)
+ # Need this, otherwise Triton tries to launch from cuda:0 and we get
+ # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
+ with torch.cuda.device(logits.device.index):
+ cross_entropy_fwd_kernel[(n_rows,)](
+ losses, # data ptrs
+ lse,
+ z_losses,
+ logits,
+ labels,
+ smoothing,
+ logit_scale,
+ lse_square_scale,
+ ignore_index,
+ total_classes,
+ class_start_idx,
+ n_cols, # shapes
+ logits.stride(0), # strides
+ BLOCK_SIZE=BLOCK_SIZE, # constants
+ SPLIT=world_size > 1,
+ PRECOMPUTED_LSE=use_precomputed_lse,
+ num_warps=num_warps,
+ )
+
+ if world_size > 1:
+ # If there's no smoothing, if labels are in the vocab of this partition, losses contains
+ # - predicted logit, and 0 otherwise.
+ # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
+ # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
+ # For labels not in the vocab of this partition, losses contains
+ # -0.1 * sum logit / total_classes.
+ lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
+ torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
+ handle_losses = torch.distributed.all_reduce(
+ losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
+ )
+ lse = torch.logsumexp(lse_allgather, dim=0)
+ handle_losses.wait()
+ # After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
+ # we just have to add the (global) lse.
+ # If there's smoothing=0.1, the total losses are
+ # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
+ # Again, we just have to add the (global) lse.
+ losses += lse
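+ # The kernel skipped the z-loss under SPLIT because the local lse was partial,
+ # so recompute it here from the global lse.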
+ if lse_square_scale != 0.0:
+ z_losses = lse_square_scale * lse.square()
+ z_losses.masked_fill_(labels == ignore_index, 0.0)
+ losses += z_losses
+ else:
+ z_losses = torch.zeros_like(losses)
+ losses.masked_fill_(labels == ignore_index, 0.0)
+
+ ctx.save_for_backward(logits, lse, labels)
+ ctx.mark_non_differentiable(z_losses)
+ ctx.smoothing = smoothing
+ ctx.logit_scale = logit_scale
+ ctx.lse_square_scale = lse_square_scale
+ ctx.ignore_index = ignore_index
+ ctx.total_classes = total_classes
+ ctx.class_start_idx = class_start_idx
+ ctx.inplace_backward = inplace_backward
+ return losses, z_losses
+
+ @staticmethod
+ def backward(ctx, grad_losses, grad_z_losses):
+ del grad_z_losses # z_losses are only for logging.
+
+ logits, lse, labels = ctx.saved_tensors
+ dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
+ n_rows, n_cols = logits.shape
+ BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
+ num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
+ grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa
+ # Need this, otherwise Triton tries to launch from cuda:0 and we get
+ # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
+ with torch.cuda.device(logits.device.index):
+ cross_entropy_bwd_kernel[grid](
+ dlogits, # data ptrs
+ grad_losses,
+ logits,
+ lse,
+ labels,
+ ctx.smoothing,
+ ctx.logit_scale,
+ ctx.lse_square_scale,
+ ctx.ignore_index,
+ ctx.total_classes,
+ ctx.class_start_idx,
+ n_cols, # shapes
+ logits.stride(0), # strides
+ dlogits.stride(0),
+ grad_losses.stride(0),
+ BLOCK_SIZE=BLOCK_SIZE, # constants
+ num_warps=num_warps,
+ )
+ return dlogits, None, None, None, None, None, None, None, None, None
+
+
+def cross_entropy_loss(
+ logits: torch.Tensor,
+ labels: torch.Tensor,
+ precomputed_lse: Optional[torch.Tensor] = None,
+ label_smoothing: float = 0.0,
+ logit_scale: float = 1.0,
+ lse_square_scale: float = 0.0,
+ ignore_index=-100,
+ inplace_backward: bool = False,
+ process_group=None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Arguments:
+ logits: (batch, vocab_size)
+ labels: (batch,)
+ label_smoothing: float
+ logit_scale: float. Multiply logits by this scale before calculating the loss.
+ lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
+ This is also referred to as "z-loss".
+ ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
+ inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
+ This saves memory.
+ process_group: if not None, we're doing Tensor Parallel: each process is responsible for
+ one part of the vocab. The loss will be aggregated across processes.
+ Returns:
+ losses: (batch,), float
+ z_losses: (batch,), float
+ """
+ return CrossEntropyLoss.apply(
+ logits,
+ labels,
+ precomputed_lse,
+ label_smoothing,
+ logit_scale,
+ lse_square_scale,
+ ignore_index,
+ inplace_backward,
+ process_group,
+ )
+
+
+class CrossEntropyLossFlash(nn.Module):
+ def __init__(
+ self,
+ ignore_index=-100,
+ reduction="mean",
+ label_smoothing=0.0,
+ logit_scale=1.0,
+ lse_square_scale=0.0,
+ inplace_backward=False,
+ process_group=None,
+ return_z_loss=False,
+ ):
+ """
+ Arguments:
+ ignore_index: int. If labels == ignore_index, the loss is set to 0.0.
+ label_smoothing: float
+ lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
+ This is also referred to as "z-loss".
+ inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits.
+ This saves memory.
+ process_group: if not None, we're doing Tensor Parallel: each process is responsible for
+ one part of the vocab. The loss will be aggregated across processes.
+ return_z_loss: bool. If True, we return the component of the loss contributed by
+ the lse_square_scale value. This value is only for logging and does not support
+ backprop.
+ """
+ super().__init__()
+ if reduction not in ["mean", "none", "sum"]:
+ raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'")
+ self.ignore_index = ignore_index
+ self.reduction = reduction
+ self.label_smoothing = label_smoothing
+ self.logit_scale = logit_scale
+ self.lse_square_scale = lse_square_scale
+ self.inplace_backward = inplace_backward
+ self.process_group = process_group
+ self.return_z_loss = return_z_loss
+
+ def forward(self, input, target, precomputed_lse=None):
+ """
+ Arguments:
+ input: (batch, vocab_size)
+ target: (batch,)
+ Returns:
+ losses: (batch,) if reduction is 'none', else (1,), dtype float
+ z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss)
+ """
+ assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
+ loss, z_loss = cross_entropy_loss(
+ input,
+ target,
+ precomputed_lse=precomputed_lse,
+ label_smoothing=self.label_smoothing,
+ logit_scale=self.logit_scale,
+ lse_square_scale=self.lse_square_scale,
+ ignore_index=self.ignore_index,
+ inplace_backward=self.inplace_backward,
+ process_group=self.process_group,
+ )
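+ # Note: "mean" averages over the non-ignored targets only, which matches the fact
+ # that ignored positions contribute zero loss.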
+ if self.reduction == "mean":
+ loss = loss.sum() / (target != self.ignore_index).sum()
+ elif self.reduction == "sum":
+ loss = loss.sum()
+ else:
+ loss = loss
+
+ if not self.return_z_loss:
+ return loss
+
+ if self.reduction == "mean":
+ z_loss = z_loss.sum() / (target != self.ignore_index).sum()
+ elif self.reduction == "sum":
+ z_loss = z_loss.sum()
+ else:
+ z_loss = z_loss
+
+ return loss, z_loss
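+
+
+# Minimal usage sketch (assumes a CUDA device; `batch` and `vocab_size` below are
+# illustrative names, not part of this module):
+# loss_fn = CrossEntropyLossFlash(ignore_index=-100, reduction="mean")
+# logits = torch.randn(batch, vocab_size, device="cuda", requires_grad=True)
+# labels = torch.randint(0, vocab_size, (batch,), device="cuda")
+# loss = loss_fn(logits, labels)
+# loss.backward()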
diff --git a/internlm/model/ops/cross_entropy_ops/py_naive_loss.py b/internlm/model/model_ops/ops/cross_entropy_ops/py_naive_loss.py
similarity index 100%
rename from internlm/model/ops/cross_entropy_ops/py_naive_loss.py
rename to internlm/model/model_ops/ops/cross_entropy_ops/py_naive_loss.py
diff --git a/internlm/model/ops/cross_entropy_ops/py_vocab_parallel_loss.py b/internlm/model/model_ops/ops/cross_entropy_ops/py_vocab_parallel_loss.py
similarity index 100%
rename from internlm/model/ops/cross_entropy_ops/py_vocab_parallel_loss.py
rename to internlm/model/model_ops/ops/cross_entropy_ops/py_vocab_parallel_loss.py
diff --git a/internlm/model/ops/cross_entropy_ops/sequence_parallel_loss.py b/internlm/model/model_ops/ops/cross_entropy_ops/sequence_parallel_loss.py
similarity index 100%
rename from internlm/model/ops/cross_entropy_ops/sequence_parallel_loss.py
rename to internlm/model/model_ops/ops/cross_entropy_ops/sequence_parallel_loss.py
diff --git a/internlm/model/ops/fused_rmsnorm.py b/internlm/model/model_ops/ops/fused_rmsnorm.py
similarity index 100%
rename from internlm/model/ops/fused_rmsnorm.py
rename to internlm/model/model_ops/ops/fused_rmsnorm.py
diff --git a/internlm/model/ops/linear.py b/internlm/model/model_ops/ops/linear.py
similarity index 100%
rename from internlm/model/ops/linear.py
rename to internlm/model/model_ops/ops/linear.py
diff --git a/internlm/model/ops/norm.py b/internlm/model/model_ops/ops/norm.py
similarity index 100%
rename from internlm/model/ops/norm.py
rename to internlm/model/model_ops/ops/norm.py
diff --git a/internlm/model/ops/ring_flash_attn/__init__.py b/internlm/model/model_ops/ops/ring_flash_attn/__init__.py
similarity index 100%
rename from internlm/model/ops/ring_flash_attn/__init__.py
rename to internlm/model/model_ops/ops/ring_flash_attn/__init__.py
diff --git a/internlm/model/ops/ring_flash_attn/utils.py b/internlm/model/model_ops/ops/ring_flash_attn/utils.py
similarity index 100%
rename from internlm/model/ops/ring_flash_attn/utils.py
rename to internlm/model/model_ops/ops/ring_flash_attn/utils.py
diff --git a/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py b/internlm/model/model_ops/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py
similarity index 99%
rename from internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py
rename to internlm/model/model_ops/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py
index 9de9a1dd6..be536e5f2 100644
--- a/internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py
+++ b/internlm/model/model_ops/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py
@@ -4,7 +4,7 @@
import torch.distributed
from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward
-from internlm.core.context.parallel_context import global_context as gpc
+from internlm.core.context import global_context as gpc
from internlm.core.parallel.comm import get_offload_manager
from .utils import RingComm, update_out_and_lse
diff --git a/internlm/model/ops/rotary_emb.py b/internlm/model/model_ops/ops/rotary_emb.py
similarity index 100%
rename from internlm/model/ops/rotary_emb.py
rename to internlm/model/model_ops/ops/rotary_emb.py
diff --git a/internlm/model/ops/utils.py b/internlm/model/model_ops/ops/utils.py
similarity index 100%
rename from internlm/model/ops/utils.py
rename to internlm/model/model_ops/ops/utils.py
diff --git a/internlm/model/utils.py b/internlm/model/model_ops/utils.py
similarity index 98%
rename from internlm/model/utils.py
rename to internlm/model/model_ops/utils.py
index 7c974abeb..e3035a102 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/model_ops/utils.py
@@ -4,8 +4,8 @@
from tqdm import tqdm
-from internlm.core.context.parallel_context import global_context as gpc
-from internlm.model.modules.mha import MHA
+from internlm.core.context import global_context as gpc
+from internlm.model.model_ops.modules.mha import MHA
from internlm.utils.logger import get_logger
from internlm.utils.storage_manager import get_fns, llm_load
from internlm.utils.utils import TensorParallelMode
diff --git a/internlm/model/modeling_baichuan2.py b/internlm/model/modeling_baichuan2.py
deleted file mode 100644
index 7dd632351..000000000
--- a/internlm/model/modeling_baichuan2.py
+++ /dev/null
@@ -1,637 +0,0 @@
-# Copyright (c) InternLM. All rights reserved.
-import math
-import os
-from typing import Optional
-
-import torch
-from einops import rearrange
-from torch import nn
-from tqdm import tqdm
-
-from internlm.accelerator import get_accelerator
-from internlm.core.context import ParallelMode
-from internlm.core.context.parallel_context import global_context as gpc
-from internlm.initialize.initialize_tensor import (
- normal_,
- scaled_init_method_normal,
- scaled_init_method_uniform,
- uniform_,
-)
-from internlm.model.base_model import BaseModel
-from internlm.model.modules.embedding import Embedding1D
-from internlm.model.modules.linear import new_linear
-from internlm.model.modules.mha import MHA
-from internlm.model.modules.mlp import new_feed_forward
-from internlm.model.modules.norm import new_layer_norm
-from internlm.model.utils import (
- convert_attn_args_to_kwargs,
- convert_attn_kwargs_to_args,
-)
-from internlm.solver.activation_checkpoint import activation_checkpoint
-from internlm.utils.logger import get_logger
-from internlm.utils.storage_manager import get_fns, llm_load, llm_save
-from transformers.modeling_utils import (
- SAFE_WEIGHTS_INDEX_NAME,
- SAFE_WEIGHTS_NAME,
- shard_checkpoint,
-)
-
-internlm_accelerator = get_accelerator()
-logger = get_logger(__file__)
-
-
-class Baichuan2Decoder(nn.Module):
- """
- 1D Packed Flash Llama Layer.
-
- Args:
- hidden_size (int): The hidden size of model. 768 by default.
- num_attention_heads (int): The number of attention heads. 12 by default.
- mlp_ratio (int): The ratio of MLP layers. 4 by default.
- attn_drop_rate (float): The dropout rate of attention module. 0 by default.
- drop_rate (float): The dropout rate of the input hidden state. 0.0 by default.
- dtype (torch.dtype): Type of data. torch.float by default.
- layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
- checkpoint (bool): Whether to use checkpointing to save VRAM. True by default.
- layer_idx (int): The index of current layer. 0 by default.
- residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
- device (Optional[Union[str, torch.device]]): The device will be used.
- norm_type (str): Use RMS norm or layernorm."rmsnorm" by default.
- attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.006 by default,
- attn_other_init_std (float): std used to init attn_other weight. 0.0015 by default,
- ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu
- otherwise init fc1 weight in ffn. 0.006 by default,
- ffn_other_init_std (float): std used to init ffn_other weight. 0.0015 by default,
- init_type (str): Initialization type. Use uniform or normal. "normal" by default,
- rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
- multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2.
- """
-
- def __init__(
- self,
- hidden_size: int = 768,
- num_attention_heads: int = 12,
- mlp_ratio: int = 4,
- attn_drop_rate: float = 0,
- drop_rate: float = 0.0,
- dtype: torch.dtype = torch.float,
- layer_norm_epsilon: float = 1e-6,
- checkpoint: bool = False,
- layer_idx: int = 0,
- use_dynamic_ntk_rope: bool = False,
- residual_in_fp32: bool = False,
- device: Optional[torch.device] = None,
- apply_post_layer_norm: bool = False,
- fused_dropout_add_ln: bool = True,
- no_bias: bool = False,
- norm_type: str = "rmsnorm",
- qk_interleaved: bool = False,
- dropout_selective_checkpoint: bool = True,
- use_scaled_init: bool = True,
- use_swiglu: bool = True,
- attn_wqkv_init_std: float = 0.006,
- attn_other_init_std: float = 0.0015,
- ffn_uplayer_init_std: float = 0.006,
- ffn_other_init_std: float = 0.0015,
- init_type: str = "normal",
- rope_base: int = 10000,
- mlp_layer_fusion: bool = False,
- multiple_of: int = 256,
- max_position_embeddings: int = 2048,
- ):
- super().__init__()
- self.checkpoint = checkpoint
- # dropout selective checkpoint can only be enabled when checkpoint is disabled.
- self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
- self.layer_idx = layer_idx
- self.prenorm = not apply_post_layer_norm
- assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here"
- self.fused_dropout_add_ln = fused_dropout_add_ln
- self.attn_wqkv_init_std = attn_wqkv_init_std
- self.attn_other_init_std = attn_other_init_std
- self.ffn_uplayer_init_std = ffn_uplayer_init_std
- self.ffn_other_init_std = ffn_other_init_std
-
- head_dim = hidden_size // num_attention_heads
-
- self.attention = MHA(
- embed_dim=hidden_size,
- num_heads=num_attention_heads,
- max_position_embeddings=max_position_embeddings,
- bias=not no_bias,
- dropout=attn_drop_rate,
- softmax_scale=1 / math.sqrt(head_dim),
- causal=True,
- layer_idx=layer_idx,
- use_dynamic_ntk_rope=use_dynamic_ntk_rope,
- rope_base=rope_base,
- rotary_emb_dim=head_dim,
- rotary_emb_scale_base=0,
- device=device,
- dtype=dtype,
- qk_interleaved=qk_interleaved,
- enable_qkv_fusion=True,
- out_bias=False,
- )
-
- self.dropout1 = nn.Dropout(drop_rate)
- self.dropout2 = nn.Dropout(drop_rate)
- self.attention_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon)
- self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon)
-
- self.feed_forward = new_feed_forward(
- hidden_size,
- int(hidden_size * mlp_ratio),
- out_features=hidden_size,
- bias=False,
- device=device,
- dtype=dtype,
- mlp_layer_fusion=mlp_layer_fusion,
- multiple_of=multiple_of,
- # TODO: to support more activation functions
- activation_type="swiglu" if use_swiglu else "gelu",
- )
-
- self.use_swiglu = use_swiglu
- self.use_scaled_init = use_scaled_init
-        self.residual_in_fp32 = residual_in_fp32  # only makes sense when using prenorm
- self.return_residual = False
-
- if init_type == "normal":
- self.init_func = normal_
- self.scaled_init_func = scaled_init_method_normal
- else:
- self.init_func = uniform_
- self.scaled_init_func = scaled_init_method_uniform
-
- self.reset_parameters()
-
- def reset_parameters(self):
- with torch.no_grad():
- for name, param in self.attention.named_parameters():
- if param.ndim == 1:
- param.data.zero_()
- elif "wq" in name or "wk" in name or "wv" in name:
- self.init_func(std=self.attn_wqkv_init_std)(param.data)
- elif self.use_scaled_init: # wo
- self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
- else:
- self.init_func(std=self.attn_other_init_std)(param.data)
-
- for name, param in self.feed_forward.named_parameters():
- if self.use_swiglu:
- if self.use_scaled_init and "w2" in name:
- self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
- else:
- # candidate: w1, w3, fused_w1_w3
- self.init_func(
- std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std
- )(param.data)
- else:
- if self.use_scaled_init and "fc1" not in name:
- self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
- else:
- self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)(
- param.data
- )
-
- def forward(self, hidden_states, residual=None, **kwargs):
- if self.checkpoint and self.training:
- args = convert_attn_kwargs_to_args(kwargs)
- return activation_checkpoint(self._forward, False, hidden_states, residual, *args)
- else:
- return self._forward(hidden_states, residual, **kwargs)
-
- def _forward(self, hidden_states, residual, *args, **kwargs):
- r"""Pass the input through the encoder layer.
-
- Args:
- hidden_states: the sequence to the encoder layer (required).
- residual: hidden_states = Attn/MLP(LN(residual))
- cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1
-            indexes: the length of indexes is the same as hidden_states, standing for the current position
- """
- if self.prenorm:
-
- def _dropout_and_norm_attn(_residual, _hidden_states):
- _dropped = self.dropout1(_hidden_states)
- _residual = (_dropped + _residual) if _residual is not None else _dropped
- _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype))
-
- return _residual, _hidden_states
-
- if self.dropout_selective_checkpoint:
- residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states)
- else:
- residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states)
-
- if self.residual_in_fp32:
- residual = residual.to(torch.float32)
- mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs)
- hidden_states = self.attention(hidden_states, **mixer_kwargs)
-
- if not isinstance(self.feed_forward, nn.Identity):
- if not self.fused_dropout_add_ln:
-
- def _dropout_and_norm_ffn(_residual, _hidden_states):
- _dropped = self.dropout2(_hidden_states)
- _residual = (_dropped + _residual) if _residual is not None else _dropped
- _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype))
-
- return _residual, _hidden_states
-
- if self.dropout_selective_checkpoint:
- residual, hidden_states = activation_checkpoint(
- _dropout_and_norm_ffn, False, residual, hidden_states
- )
- else:
- residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states)
-
- if self.residual_in_fp32:
- residual = residual.to(torch.float32)
- hidden_states = self.feed_forward(hidden_states)
-
- return hidden_states + residual
- else:
- assert residual is None
-
- mixer_out = self.attention(hidden_states, **kwargs)
- if self.return_residual: # mixer out is actually a pair here
- mixer_out, hidden_states = mixer_out
- hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to(
- dtype=self.attention_norm.weight.dtype
- )
- if not isinstance(self.feed_forward, nn.Identity):
- mlp_out = self.feed_forward(hidden_states)
- if self.return_residual: # mlp out is actually a pair here
- mlp_out, hidden_states = mlp_out
- hidden_states = self.ffn_norm((self.dropout2(mlp_out)) + hidden_states).to(
- dtype=self.ffn_norm.weight.dtype
- )
- return hidden_states
-
-
-class Baichuan2(BaseModel):
- """
-    1D Packed Flash Baichuan2.
-
- Args:
- num_layers (int): The number of layer. 12 by default.
- hidden_size (int): The size of hidden state. 768 by default.
- num_attention_heads (int): The number of attention head. 12 by default.
- vocab_size (int): The size of vocabulary. 50304 by default.
- mlp_ratio (int): The ratio of MLP layers. 4 by default.
- attn_drop_rate (float): The dropout rate of attention module. 0.0 by default.
- drop_rate (float): The dropout rate of input hidden state. 0.0 by default.
- dtype (torch.dtype): The type of data. torch.float by default.
-        checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number
-            of layers. 1.0 by default.
- layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default.
- first (bool): Whether input embedding layer or not. False by default.
- last (bool): Whether output embedding layer or not. False by default.
- embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
- parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
- start_layer_idx (int): The index of start layer in the pipeline. 0 by default.
-        device (Optional[Union[str, torch.device]]): The device to be used. None by default.
- residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
- norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
- qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved.
- embedding_init_std (float): std used to init embedding weight. 0.0052 by default,
- attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.006 by default,
- attn_other_init_std (float): std used to init attn_other weight. 0.0015 by default,
- ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu
- otherwise init fc1 weight in ffn. 0.006 by default,
- ffn_other_init_std (float): std used to init ffn_other weight. 0.0015 by default,
- out_head_init_std (float): std used to init output lmhead weight. 0.0052 by default,
- init_type (str): Initialization type. Use uniform or normal. "normal" by default,
- rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
- multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2.
- """
-
- def __init__(
- self,
- num_layers: int = 12,
- hidden_size: int = 768,
- num_attention_heads: int = 12,
- vocab_size: int = 50304,
- mlp_ratio: int = 4,
- attn_drop_rate: float = 0.0,
- drop_rate: float = 0.0,
- max_position_embeddings: int = 2048,
- dtype: torch.dtype = torch.float,
- checkpoint: float = 1.0,
- layer_norm_epsilon: float = 1e-5,
- first: bool = False,
- last: bool = False,
- embed_grad_scale: float = 0.1,
- parallel_output: bool = True,
- start_layer_idx: int = 0,
- use_dynamic_ntk_rope: bool = False,
- device: Optional[torch.device] = None,
- apply_post_layer_norm=False,
- no_bias=False,
- residual_in_fp32: bool = False,
- norm_type: str = "rmsnorm",
- qk_interleaved: bool = False,
- is_reward: bool = False,
- dropout_selective_checkpoint: bool = True,
- use_scaled_init: bool = True,
- use_swiglu: bool = True,
- embedding_init_std: float = 0.0052,
- attn_wqkv_init_std: float = 0.006,
- attn_other_init_std: float = 0.0015,
- ffn_uplayer_init_std: float = 0.006,
- ffn_other_init_std: float = 0.0015,
- out_head_init_std: float = 0.0052,
- init_type: str = "normal",
- norm_head: bool = False,
- rope_base: int = 10000,
- mlp_layer_fusion: bool = False,
- multiple_of: int = 256,
- ):
- super().__init__()
-
- checkpoint_layer_num = int(num_layers * checkpoint)
- self.embed_grad_scale = embed_grad_scale
- self.parallel_output = parallel_output
-
- if first:
- self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
-
- for _, param in self.tok_embeddings.named_parameters():
- if init_type == "normal":
- normal_(std=embedding_init_std)(param)
- else:
- uniform_(std=embedding_init_std)(param)
-
- self.layers = nn.ModuleList(
- [
- Baichuan2Decoder(
- hidden_size=hidden_size,
- num_attention_heads=num_attention_heads,
- mlp_ratio=mlp_ratio,
- attn_drop_rate=attn_drop_rate,
- drop_rate=drop_rate,
- max_position_embeddings=max_position_embeddings,
- dtype=dtype,
- layer_norm_epsilon=layer_norm_epsilon,
- checkpoint=lid < checkpoint_layer_num,
- layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation
- use_dynamic_ntk_rope=use_dynamic_ntk_rope,
- residual_in_fp32=residual_in_fp32,
- device=device,
- apply_post_layer_norm=apply_post_layer_norm,
- fused_dropout_add_ln=False,
- no_bias=no_bias,
- norm_type=norm_type,
- dropout_selective_checkpoint=dropout_selective_checkpoint,
- use_scaled_init=use_scaled_init,
- use_swiglu=use_swiglu,
- qk_interleaved=qk_interleaved,
- attn_wqkv_init_std=attn_wqkv_init_std,
- attn_other_init_std=attn_other_init_std,
- ffn_uplayer_init_std=ffn_uplayer_init_std,
- ffn_other_init_std=ffn_other_init_std,
- init_type=init_type,
- rope_base=rope_base,
- mlp_layer_fusion=mlp_layer_fusion,
- multiple_of=multiple_of,
- )
- for lid in range(num_layers)
- ]
- )
-
- if last:
- if not apply_post_layer_norm:
- self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon)
-
- self.output = new_linear(
- name="output",
- in_features=hidden_size,
- out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
- bias=False,
- device=device,
- dtype=dtype,
- is_reward=is_reward,
- weight_scale=embed_grad_scale,
- norm_head=norm_head,
- )
-
- for _, param in self.output.named_parameters():
- if init_type == "normal":
- normal_(std=out_head_init_std)(param)
- else:
- uniform_(std=out_head_init_std)(param)
-
- def forward(self, hidden_states=None, input_ids=None, **kwargs):
- # attention_mask: compute attention on the places where the value is 1
- if hasattr(self, "tok_embeddings") and input_ids is not None:
- hidden_states = self.tok_embeddings(input_ids)
- if self.embed_grad_scale != 1:
- hidden_states = (
- self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach()
- )
-
- for _, block in enumerate(self.layers):
- hidden_states = block(hidden_states, residual=None, **kwargs)
-
- if hasattr(self, "norm"):
- hidden_states = self.norm(hidden_states.to(self.norm.weight.dtype))
- if hasattr(self, "output"):
- hidden_states = self.output(hidden_states)
-
- return hidden_states
-
- @staticmethod
- def load_hf_weights(folder: str, model: nn.Module) -> None:
- assert folder is not None, "Please specify the folder of the pretrained model"
- if gpc.is_rank_for_log():
- logger.info(f"Loading pretrained model from {folder}")
-
- fns = get_fns(folder)
- model_fns = [
- os.path.join(folder, fn)
- for fn in fns
- if (fn.endswith(".bin") and fn.startswith("pytorch_model"))
- or (fn.endswith(".safetensors") and fn.startswith("model"))
- ]
- model_fns.sort()
-
- state_dict = {}
- for model_fn in model_fns:
- state_dict.update(llm_load(model_fn, map_location="cpu"))
-
- tp_size = gpc.get_world_size(ParallelMode.TENSOR)
- tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
- wp_size = gpc.get_world_size(ParallelMode.WEIGHT)
- wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
- tp_mode = gpc.config.parallel.tensor["mode"]
- split_size = wp_size if tp_mode == "isp" else tp_size
- local_rank = wp_rank if tp_mode == "isp" else tp_rank
- row_dim = 0 if tp_mode == "isp" else 1
- if gpc.config.model.get("embed_split_hidden", True):
- embed_concat_dim = 1
- else:
- embed_concat_dim = 0
-
- new_state_dict = {}
-
- # embedding
- if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)):
- new_state_dict["tok_embeddings.weight"] = torch.chunk(
- state_dict.pop("model.embed_tokens.weight"),
- split_size,
- dim=embed_concat_dim,
- )[local_rank]
-
- for idx, i in enumerate(range(model.first_layer, model.last_layer)):
- layer_ids = i
-
- # attn
- state_dict[f"layers.{i}.attention.wqkv.weight"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.self_attn.W_pack.weight"),
- split_size,
- dim=0,
- )[local_rank]
- state_dict[f"layers.{i}.attention.out_proj.weight"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"),
- split_size,
- dim=row_dim,
- )[local_rank]
-
- # ffn
- state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"),
- split_size,
- dim=0,
- )[local_rank]
- state_dict[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"),
- split_size,
- dim=0,
- )[local_rank]
- state_dict[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"),
- split_size,
- dim=row_dim,
- )[local_rank]
-
- # attn norm
- state_dict[f"layers.{i}.attention_norm.weight"] = state_dict.pop(
- f"model.layers.{layer_ids}.input_layernorm.weight"
- )
- # ffn norm
- state_dict[f"layers.{i}.ffn_norm.weight"] = state_dict.pop(
- f"model.layers.{layer_ids}.post_attention_layernorm.weight"
- )
-
- # replace value within decoder layer
- for name in list(state_dict.keys()):
- if name.startswith(f"layers.{i}"):
- new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name)
-
- # output
- if gpc.is_last_rank(ParallelMode.PIPELINE):
- new_state_dict["output.weight"] = torch.chunk(
- state_dict.pop("lm_head.weight"),
- split_size,
- dim=0,
- )[local_rank]
- new_state_dict["norm.weight"] = state_dict.pop("model.norm.weight")
-
- missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
-
- if gpc.get_local_rank(ParallelMode.DATA) == 0:
- pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
- logger.info(
- f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
- f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
- )
-
- internlm_accelerator.empty_cache()
-
- @staticmethod
- def convert_internevo2hf_weights(src: str, tgt: str) -> None:
- def permute(qkv, num_heads, num_kv_heads, head_dim, qk_interleaved=False):
- if not qk_interleaved:
- return qkv
- q_per_kv = num_heads // num_kv_heads
- qkv = rearrange(qkv.T, "o (g n i) -> o g n i", n=q_per_kv + 2, i=head_dim)
- q, k, v = qkv[..., :q_per_kv, :], qkv[..., -2:-1, :], qkv[..., -1:, :]
- q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1)
- k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1)
- qkv = torch.cat((q, k, v), dim=2)
- qkv = rearrange(qkv, "o g n i -> o (g n i)").T
- return qkv
-
- model_config = gpc.config.model
- tp_mode = gpc.config.parallel.tensor["mode"]
- row_dim = 0 if tp_mode == "isp" else 1
- if model_config["embed_split_hidden"]:
- embed_concat_dim = 1
- else:
- embed_concat_dim = 0
-
- # load states
- states, num_shards = Baichuan2.load_sharded_states(src)
-
- # convert state_dict
- state_dict = {}
- embedding_key_list = ["tok_embeddings.weight", "embed_tokens.weight", None]
- for layer_i in tqdm(range(model_config["num_layers"])):
- # attn norm, ffn norm
- state_dict.update(
- {
- f"model.layers.{layer_i}.input_layernorm.weight": states[0][
- f"layers.{layer_i}.attention_norm.weight"
- ].clone(),
- f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][
- f"layers.{layer_i}.ffn_norm.weight"
- ].clone(),
- }
- )
- # attn
- state_dict[f"model.layers.{layer_i}.self_attn.W_pack.weight"] = permute(
- torch.cat([states[i][f"layers.{layer_i}.attention.wqkv.weight"] for i in range(num_shards)], dim=0),
- num_heads=model_config["num_attention_heads"],
- # num_kv_attention_heads equals to num_attention_heads in MHA
- num_kv_heads=model_config["num_attention_heads"],
- head_dim=model_config["hidden_size"] // model_config["num_attention_heads"],
- qk_interleaved=model_config.get("qk_interleaved", False),
- )
- state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
- [states[i][f"layers.{layer_i}.attention.out_proj.weight"] for i in range(num_shards)], dim=row_dim
- )
- # ffn
- state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
- [states[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
- )
- state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
- [states[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=row_dim
- )
- state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
- [states[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
- )
- # embedding, output
- for embedding_key in embedding_key_list:
- if embedding_key in states[0]:
- break
- if embedding_key is None:
- raise KeyError("Cannot find embedding key!")
- state_dict.update(
- {
- "model.norm.weight": states[0]["norm.weight"],
- "model.embed_tokens.weight": torch.cat(
- [states[i][embedding_key] for i in range(num_shards)], dim=embed_concat_dim
- ),
- "lm_head.weight": torch.cat([states[i]["output.weight"] for i in range(num_shards)], dim=0),
- },
- )
-
- # save state_dict to hf format
- shards, index = shard_checkpoint(state_dict, weights_name=SAFE_WEIGHTS_NAME)
- for shard_file, shard in shards.items():
- llm_save(save_path=os.path.join(tgt, shard_file), saved_obj=shard, metadata={"format": "pt"})
- if index is not None:
- llm_save(save_path=os.path.join(tgt, SAFE_WEIGHTS_INDEX_NAME), saved_obj=index)
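Reviewer note, not part of the patch: the deleted `Baichuan2.load_hf_weights` above shards HF weights across tensor-parallel ranks with `torch.chunk`. The minimal sketch below reproduces that pattern on toy tensors (all sizes and the world size are illustrative): column-parallel weights such as the fused `W_pack` split along dim 0, row-parallel weights such as `o_proj` along dim 1.

```python
import torch

hidden_size = 16   # toy sizes; the real values come from the model config
tp_size = 2        # pretend tensor-parallel world size

# Baichuan2 checkpoints store q, k and v fused in `W_pack` along dim 0.
w_pack = torch.randn(3 * hidden_size, hidden_size)

# Column-parallel weights: each rank keeps one chunk of the output dimension,
# mirroring `torch.chunk(..., split_size, dim=0)[local_rank]` in the loader above.
for local_rank, shard in enumerate(torch.chunk(w_pack, tp_size, dim=0)):
    assert shard.shape == (3 * hidden_size // tp_size, hidden_size)
    print(f"rank {local_rank}: wqkv shard {tuple(shard.shape)}")

# Row-parallel weights such as `o_proj` are split along the input dimension instead.
o_proj = torch.randn(hidden_size, hidden_size)
assert torch.chunk(o_proj, tp_size, dim=1)[0].shape == (hidden_size, hidden_size // tp_size)
```

Concatenating the per-rank shards back along the same dimensions is what `convert_internevo2hf_weights` above does in the opposite direction.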
diff --git a/internlm/model/modeling_gemma.py b/internlm/model/modeling_gemma.py
deleted file mode 100644
index 74d71796e..000000000
--- a/internlm/model/modeling_gemma.py
+++ /dev/null
@@ -1,750 +0,0 @@
-# Copyright (c) InternLM. All rights reserved.
-import math
-import os
-from typing import Optional
-
-import torch
-from torch import nn
-from tqdm import tqdm
-
-from internlm.accelerator import get_accelerator
-from internlm.core.context import ParallelMode
-from internlm.core.context.parallel_context import global_context as gpc
-from internlm.initialize.initialize_tensor import (
- normal_,
- scaled_init_method_normal,
- scaled_init_method_uniform,
- uniform_,
-)
-from internlm.model.base_model import BaseModel
-from internlm.model.modules.embedding import Embedding1D
-from internlm.model.modules.linear import new_linear
-from internlm.model.modules.mha import GQA
-from internlm.model.modules.mlp import new_feed_forward
-from internlm.model.modules.norm import new_layer_norm
-from internlm.model.utils import (
- convert_attn_args_to_kwargs,
- convert_attn_kwargs_to_args,
-)
-from internlm.solver.activation_checkpoint import activation_checkpoint
-from internlm.utils.logger import get_logger
-from internlm.utils.storage_manager import get_fns, llm_load, llm_save
-from transformers.modeling_utils import (
- SAFE_WEIGHTS_INDEX_NAME,
- SAFE_WEIGHTS_NAME,
- shard_checkpoint,
-)
-
-try:
- from flash_attn.modules.mlp import ParallelFusedMLP
-except ImportError:
- pass
-
-internlm_accelerator = get_accelerator()
-logger = get_logger(__file__)
-
-
-class GemmaDecoder(nn.Module):
- """
-    1D Packed Flash Gemma Layer.
-
- Args:
- hidden_size (int): The hidden size of model. 768 by default.
- num_attention_heads (int): The number of attention heads. 12 by default.
-        head_dim (int): The dimension of each attention head. hidden_size divided by num_attention_heads by default.
- mlp_ratio (int): The ratio of MLP layers. 4 by default.
- attn_drop_rate (float): The dropout rate of attention module. 0 by default.
- drop_rate (float): The dropout rate of the input hidden state. 0.0 by default.
- dtype (torch.dtype): Type of data. torch.float by default.
- layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
-        checkpoint (bool): Whether to use checkpointing to save VRAM. False by default.
- layer_idx (int): The index of current layer. 0 by default.
- residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
-        device (Optional[Union[str, torch.device]]): The device to be used.
-        add_unit_offset (bool): Whether to add one to the RMSNorm weight before multiplying by the normed input. False by default.
- use_glu (bool): Whether to use glu. True by default.
- use_swiglu (bool): Whether to use swiglu. True by default.
- attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default,
- attn_other_init_std (float): std used to init attn_other weight. 0.02 by default,
- ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu
- otherwise init fc1 weight in ffn. 0.02 by default,
- ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default,
- init_type (str): Initialization type. Use uniform or normal. "normal" by default,
- rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
- multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2.
- tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"],
- "mtp" by default.
- """
-
- def __init__(
- self,
- hidden_size: int = 768,
- num_attention_heads: int = 12,
- num_kv_attention_heads: int = 12,
- head_dim: int = None,
- mlp_ratio: int = 4,
- attn_drop_rate: float = 0,
- drop_rate: float = 0.0,
- max_position_embeddings: int = 2048,
- dtype: torch.dtype = torch.float,
- layer_norm_epsilon: float = 1e-6,
- checkpoint: bool = False,
- layer_idx: int = 0,
- use_dynamic_ntk_rope: bool = False,
- residual_in_fp32: bool = False,
- device: Optional[torch.device] = None,
- apply_post_layer_norm: bool = False,
- fused_dropout_add_ln: bool = True,
- no_bias: bool = False,
- norm_type: str = "rmsnorm",
- qk_interleaved: bool = False,
- add_unit_offset: bool = False,
- dropout_selective_checkpoint: bool = True,
- use_scaled_init: bool = True,
- use_glu: bool = True,
- use_swiglu: bool = True,
- attn_wqkv_init_std: float = 0.02,
- attn_other_init_std: float = 0.02,
- ffn_uplayer_init_std: float = 0.02,
- ffn_other_init_std: float = 0.02,
- init_type: str = "normal",
- rope_base: int = 10000,
- mlp_layer_fusion: bool = False,
- multiple_of: int = 256,
- tp_mode: str = "mtp",
- ):
- super().__init__()
- self.checkpoint = checkpoint
- # dropout selective checkpoint can only be enabled when checkpoint is disabled.
- self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
- self.layer_idx = layer_idx
- self.prenorm = not apply_post_layer_norm
- assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here"
- self.fused_dropout_add_ln = fused_dropout_add_ln
- self.attn_wqkv_init_std = attn_wqkv_init_std
- self.attn_other_init_std = attn_other_init_std
- self.ffn_uplayer_init_std = ffn_uplayer_init_std
- self.ffn_other_init_std = ffn_other_init_std
-
- if not head_dim:
- head_dim = hidden_size // num_attention_heads
-
- self.attention = GQA(
- embed_dim=hidden_size,
- num_heads=num_attention_heads,
- num_kv_heads=num_kv_attention_heads,
- head_dim=head_dim,
- dropout=attn_drop_rate,
- max_position_embeddings=max_position_embeddings,
- softmax_scale=1 / math.sqrt(head_dim),
- causal=True,
- layer_idx=layer_idx,
- use_dynamic_ntk_rope=use_dynamic_ntk_rope,
- rotary_emb_dim=head_dim,
- rotary_emb_scale_base=0,
- device=device,
- dtype=dtype,
- qk_interleaved=qk_interleaved,
- bias=not no_bias,
- rope_base=rope_base,
- enable_qkv_fusion=False,
- )
-
- self.dropout1 = nn.Dropout(drop_rate)
- self.dropout2 = nn.Dropout(drop_rate)
- self.attention_norm = new_layer_norm(
- norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=add_unit_offset
- )
- self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=add_unit_offset)
-
- sequence_parallel = gpc.config.parallel.get("sequence_parallel", False)
- parallel_mode = ParallelMode.WEIGHT if tp_mode == "isp" else ParallelMode.TENSOR
-
- if use_glu:
- self.feed_forward = new_feed_forward(
- hidden_size,
- int(hidden_size * mlp_ratio),
- out_features=hidden_size,
- bias=False,
- device=device,
- dtype=dtype,
- mlp_layer_fusion=mlp_layer_fusion,
- multiple_of=multiple_of,
- activation_type="swiglu" if use_swiglu else "gelu",
- )
- else:
- self.feed_forward = ParallelFusedMLP(
- hidden_size,
- int(hidden_size * mlp_ratio),
- out_features=hidden_size,
- activation="gelu_approx",
- process_group=gpc.get_group(parallel_mode),
- bias1=False,
- bias2=False,
- sequence_parallel=sequence_parallel,
- checkpoint_lvl=0,
- heuristic="auto",
- device=device,
- dtype=dtype,
- )
-
- self.use_glu = use_glu
- self.use_swiglu = use_swiglu
- self.use_scaled_init = use_scaled_init
-        self.residual_in_fp32 = residual_in_fp32  # only makes sense when using prenorm
- self.return_residual = False
-
- if init_type == "normal":
- self.init_func = normal_
- self.scaled_init_func = scaled_init_method_normal
- else:
- self.init_func = uniform_
- self.scaled_init_func = scaled_init_method_uniform
-
- self.reset_parameters()
-
- def reset_parameters(self):
- with torch.no_grad():
- for name, param in self.attention.named_parameters():
- if param.ndim == 1:
- param.data.zero_()
- elif "wq" in name or "wk" in name or "wv" in name:
- self.init_func(std=self.attn_wqkv_init_std)(param.data)
- elif self.use_scaled_init: # wo
- self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
- else:
- self.init_func(std=self.attn_other_init_std)(param.data)
-
- for name, param in self.feed_forward.named_parameters():
- if self.use_glu:
- if self.use_scaled_init and "w2" in name:
- self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
- else:
- self.init_func(
- std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std
- )(param.data)
- else:
- if self.use_scaled_init and "fc1" not in name:
- self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
- else:
- self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)(
- param.data
- )
-
- def forward(self, hidden_states, residual=None, **kwargs):
- if self.checkpoint and self.training:
- args = convert_attn_kwargs_to_args(kwargs)
- return activation_checkpoint(self._forward, False, hidden_states, residual, *args)
- else:
- return self._forward(hidden_states, residual, **kwargs)
-
- def _forward(self, hidden_states, residual, *args, **kwargs):
- r"""Pass the input through the encoder layer.
-
- Args:
- hidden_states: the sequence to the encoder layer (required).
- residual: hidden_states = Attn/MLP(LN(residual))
- cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1
-            indexes: the length of indexes is the same as hidden_states, standing for the current position
- """
- if self.prenorm:
-
- def _dropout_and_norm_attn(_residual, _hidden_states):
- _dropped = self.dropout1(_hidden_states)
- _residual = (_dropped + _residual) if _residual is not None else _dropped
- _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype))
-
- return _residual, _hidden_states
-
- if self.dropout_selective_checkpoint:
- residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states)
- else:
- residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states)
-
- if self.residual_in_fp32:
- residual = residual.to(torch.float32)
-
- mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs)
- hidden_states = self.attention(hidden_states, **mixer_kwargs)
-
- if not isinstance(self.feed_forward, nn.Identity):
- if not self.fused_dropout_add_ln:
-
- def _dropout_and_norm_ffn(_residual, _hidden_states):
- _dropped = self.dropout2(_hidden_states)
- _residual = (_dropped + _residual) if _residual is not None else _dropped
- _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype))
-
- return _residual, _hidden_states
-
- if self.dropout_selective_checkpoint:
- residual, hidden_states = activation_checkpoint(
- _dropout_and_norm_ffn, False, residual, hidden_states
- )
- else:
- residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states)
-
- if self.residual_in_fp32:
- residual = residual.to(torch.float32)
- hidden_states = self.feed_forward(hidden_states)
-
- return hidden_states + residual
- else:
- assert residual is None
-
- mixer_out = self.attention(hidden_states, **kwargs)
- if self.return_residual: # mixer out is actually a pair here
- mixer_out, hidden_states = mixer_out
- hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to(
- dtype=self.attention_norm.weight.dtype
- )
- if not isinstance(self.feed_forward, nn.Identity):
- mlp_out = self.feed_forward(hidden_states)
- if self.return_residual: # mlp out is actually a pair here
- mlp_out, hidden_states = mlp_out
- hidden_states = self.ffn_norm((self.dropout2(mlp_out)) + hidden_states).to(
- dtype=self.ffn_norm.weight.dtype
- )
- return hidden_states
-
-
-class Gemma(BaseModel):
- """
-    1D Packed Flash Gemma.
-
- Args:
- num_layers (int): The number of layer. 12 by default.
- hidden_size (int): The size of hidden state. 768 by default.
- num_attention_heads (int): The number of attention head. 12 by default.
-        head_dim (int): The dimension of each attention head. hidden_size divided by num_attention_heads by default.
- vocab_size (int): The size of vocabulary. 50304 by default.
- mlp_ratio (int): The ratio of MLP layers. 4 by default.
- attn_drop_rate (float): The dropout rate of attention module. 0.0 by default.
- drop_rate (float): The dropout rate of input hidden state. 0.0 by default.
- dtype (torch.dtype): The type of data. torch.float by default.
-        checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number
-            of layers. 1.0 by default.
- layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default.
- first (bool): Whether input embedding layer or not. False by default.
- last (bool): Whether output embedding layer or not. False by default.
- embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
- parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
- start_layer_idx (int): The index of start layer in the pipeline. 0 by default.
-        device (Optional[Union[str, torch.device]]): The device to be used. None by default.
- residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
-        add_unit_offset (bool): Whether to add one to the RMSNorm weight before multiplying by the normed input. False by default.
- use_glu (bool): Whether to use glu. True by default.
- use_swiglu (bool): Whether to use swiglu. True by default.
- embedding_init_std (float): std used to init embedding weight. 0.02 by default,
- attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default,
- attn_other_init_std (float): std used to init attn_other weight. 0.02 by default,
- ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu
- otherwise init fc1 weight in ffn. 0.02 by default,
- ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default,
- out_head_init_std (float): std used to init output lmhead weight. 0.02 by default,
- init_type (str): Initialization type. Use uniform or normal. "normal" by default,
- extra_pred_tokens (int): The number of extra output head for multi-token-prediction. 0 by default.
- rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
- multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2.
- """
-
- def __init__(
- self,
- num_layers: int = 12,
- hidden_size: int = 768,
- num_attention_heads: int = 12,
- num_kv_attention_heads: int = 12,
- head_dim: int = None,
- vocab_size: int = 50304,
- mlp_ratio: int = 4,
- attn_drop_rate: float = 0.0,
- drop_rate: float = 0.0,
- max_position_embeddings: int = 2048,
- dtype: torch.dtype = torch.float,
- checkpoint: float = 1.0,
- layer_norm_epsilon: float = 1e-5,
- first: bool = False,
- last: bool = False,
- embed_grad_scale: float = 0.1,
- parallel_output: bool = True,
- start_layer_idx: int = 0,
- use_dynamic_ntk_rope: bool = False,
- device: Optional[torch.device] = None,
- apply_post_layer_norm=False,
- no_bias=False,
- residual_in_fp32: bool = False,
- norm_type: str = "rmsnorm",
- qk_interleaved: bool = False,
- add_unit_offset: bool = False,
- is_reward: bool = False,
- dropout_selective_checkpoint: bool = True,
- use_scaled_init: bool = True,
- use_glu: bool = True,
- use_swiglu: bool = False,
- embedding_init_std: float = 0.02,
- attn_wqkv_init_std: float = 0.02,
- attn_other_init_std: float = 0.02,
- ffn_uplayer_init_std: float = 0.02,
- ffn_other_init_std: float = 0.02,
- out_head_init_std: float = 0.02,
- init_type: str = "normal",
- extra_pred_tokens: int = 0,
- rope_base: int = 10000,
- norm_head: bool = False,
- mlp_layer_fusion: bool = False,
- multiple_of: int = 256,
- ):
- super().__init__()
-
- checkpoint_layer_num = int(num_layers * checkpoint)
- self.hidden_size = hidden_size
- self.embed_grad_scale = embed_grad_scale
- self.parallel_output = parallel_output
- self.tp_mode = "mtp"
- if isinstance(gpc.config.parallel["tensor"], dict):
- self.tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp")
-
- if first:
- self.embed_tokens = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
- for _, param in self.embed_tokens.named_parameters():
- if init_type == "normal":
- normal_(std=embedding_init_std)(param)
- else:
- uniform_(std=embedding_init_std)(param)
-
- self.layers = nn.ModuleList(
- [
- GemmaDecoder(
- hidden_size=hidden_size,
- num_attention_heads=num_attention_heads,
- num_kv_attention_heads=num_kv_attention_heads,
- head_dim=head_dim,
- mlp_ratio=mlp_ratio,
- attn_drop_rate=attn_drop_rate,
- drop_rate=drop_rate,
- max_position_embeddings=max_position_embeddings,
- dtype=dtype,
- layer_norm_epsilon=layer_norm_epsilon,
- checkpoint=lid < checkpoint_layer_num,
- layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation
- use_dynamic_ntk_rope=use_dynamic_ntk_rope,
- residual_in_fp32=residual_in_fp32,
- device=device,
- apply_post_layer_norm=apply_post_layer_norm,
- fused_dropout_add_ln=False,
- no_bias=no_bias,
- norm_type=norm_type,
- add_unit_offset=add_unit_offset,
- dropout_selective_checkpoint=dropout_selective_checkpoint,
- use_scaled_init=use_scaled_init,
- use_glu=use_glu,
- use_swiglu=use_swiglu,
- qk_interleaved=qk_interleaved,
- attn_wqkv_init_std=attn_wqkv_init_std,
- attn_other_init_std=attn_other_init_std,
- ffn_uplayer_init_std=ffn_uplayer_init_std,
- ffn_other_init_std=ffn_other_init_std,
- init_type=init_type,
- rope_base=rope_base,
- mlp_layer_fusion=mlp_layer_fusion,
- multiple_of=multiple_of,
- tp_mode=self.tp_mode,
- )
- for lid in range(num_layers)
- ]
- )
-
- if last:
- if not apply_post_layer_norm:
- self.norm = new_layer_norm(
- norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=add_unit_offset
- )
-
- self.output = new_linear(
- name="output",
- in_features=hidden_size,
- out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
- bias=False,
- device=device,
- is_reward=is_reward,
- dtype=dtype,
- weight_scale=embed_grad_scale,
- norm_head=norm_head,
- )
- for _, param in self.output.named_parameters():
- if init_type == "normal":
- normal_(std=out_head_init_std)(param)
- else:
- uniform_(std=out_head_init_std)(param)
-
- if extra_pred_tokens > 0:
- self.extra_pred_tokens = extra_pred_tokens
-            assert not is_reward, "extra_pred_tokens > 0 means using multi token prediction, not implemented for RLHF"
- self.extra_outputs = nn.ModuleList(
- [
- new_linear(
- name="output",
- in_features=hidden_size,
- out_features=vocab_size,
- bias=False,
- device=device,
- is_reward=is_reward,
- dtype=dtype,
- weight_scale=embed_grad_scale,
- norm_head=norm_head,
- )
- for _ in range(self.extra_pred_tokens)
- ]
- )
- for _, param in self.extra_outputs.named_parameters():
- if init_type == "normal":
- normal_(std=out_head_init_std)(param)
- else:
- uniform_(std=out_head_init_std)(param)
-
- def forward(self, hidden_states=None, input_ids=None, **kwargs):
- # attention_mask: compute attention on the places where the value is 1
- if hasattr(self, "embed_tokens"):
- hidden_states = self.embed_tokens(input_ids)
- if self.embed_grad_scale != 1:
- hidden_states = (
- self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach()
- )
- hidden_states = hidden_states * (self.hidden_size**0.5)
-
- for _, block in enumerate(self.layers):
- hidden_states = block(hidden_states, residual=None, **kwargs)
-
- if hasattr(self, "norm"):
- hidden_states = self.norm(hidden_states.to(self.norm.weight.dtype))
- if hasattr(self, "extra_pred_tokens") and self.extra_pred_tokens > 0:
- extra_hidden_states_list = [self.extra_outputs[i](hidden_states) for i in range(self.extra_pred_tokens)]
- else:
- extra_hidden_states_list = None
- if hasattr(self, "output"):
- hidden_states = self.output(hidden_states)
-
- if extra_hidden_states_list is not None:
- return (hidden_states, extra_hidden_states_list)
-
- return hidden_states
-
- @staticmethod
- def load_hf_weights(folder: str, model: nn.Module) -> None:
- assert folder is not None, "Please specify the folder of the pretrained model"
- if gpc.is_rank_for_log():
- logger.info(f"Loading pretrained model from {folder}")
-
- fns = get_fns(folder)
- model_fns = [
- os.path.join(folder, fn)
- for fn in fns
- if (fn.endswith(".bin") and fn.startswith("pytorch_model"))
- or (fn.endswith(".safetensors") and fn.startswith("model"))
- ]
- model_fns.sort()
-
- state_dict = {}
- for model_fn in model_fns:
- state_dict.update(llm_load(model_fn, map_location="cpu"))
-
- tp_size = gpc.get_world_size(ParallelMode.TENSOR)
- tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
- wp_size = gpc.get_world_size(ParallelMode.WEIGHT)
- wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
- tp_mode = gpc.config.parallel.tensor["mode"]
- split_size = wp_size if tp_mode == "isp" else tp_size
- local_rank = wp_rank if tp_mode == "isp" else tp_rank
- row_dim = 0 if tp_mode == "isp" else 1
- if gpc.config.model.get("embed_split_hidden", True):
- embed_concat_dim = 1
- else:
- embed_concat_dim = 0
-
- new_state_dict = {}
-
- # embedding
- if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)):
- new_state_dict["embed_tokens.weight"] = torch.chunk(
- state_dict.get("model.embed_tokens.weight"),
- split_size,
- dim=embed_concat_dim,
- )[local_rank]
-
- for idx, i in enumerate(range(model.first_layer, model.last_layer)):
- layer_ids = i
-
- # attn
- state_dict[f"layers.{i}.attention.wq.weight"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"),
- split_size,
- dim=0,
- )[local_rank]
- state_dict[f"layers.{i}.attention.wk.weight"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"),
- split_size,
- dim=0,
- )[local_rank]
- state_dict[f"layers.{i}.attention.wv.weight"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"),
- split_size,
- dim=0,
- )[local_rank]
- state_dict[f"layers.{i}.attention.wo.weight"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"),
- split_size,
- dim=row_dim,
- )[local_rank]
-
- # ffn
- state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"),
- split_size,
- dim=0,
- )[local_rank]
- state_dict[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"),
- split_size,
- dim=0,
- )[local_rank]
- state_dict[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"),
- split_size,
- dim=row_dim,
- )[local_rank]
-
- # attn norm
- state_dict[f"layers.{i}.attention_norm.weight"] = state_dict.pop(
- f"model.layers.{layer_ids}.input_layernorm.weight"
- )
- # ffn norm
- state_dict[f"layers.{i}.ffn_norm.weight"] = state_dict.pop(
- f"model.layers.{layer_ids}.post_attention_layernorm.weight"
- )
-
- # replace value within decoder layer
- for name in list(state_dict.keys()):
- if name.startswith(f"layers.{i}"):
- new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name)
-
- # output
- if gpc.is_last_rank(ParallelMode.PIPELINE):
- if "lm_head.weight" in state_dict:
- new_state_dict["output.weight"] = torch.chunk(
- state_dict.pop("lm_head.weight"), # we do not tie lm head with embedding
- split_size,
- dim=0,
- )[local_rank]
- state_dict.pop("model.embed_tokens.weight")
- else:
- new_state_dict["output.weight"] = torch.chunk(
- # gemma model ties lm head with embedding in transformers implementation
- state_dict.pop("model.embed_tokens.weight"),
- split_size,
- dim=0,
- )[local_rank]
- new_state_dict["norm.weight"] = state_dict.pop("model.norm.weight")
-
- missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
-
- if gpc.get_local_rank(ParallelMode.DATA) == 0:
- pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
- logger.info(
- f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
- f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
- )
-
- internlm_accelerator.empty_cache()
-
- @staticmethod
- def convert_internevo2hf_weights(src: str, tgt: str) -> None:
- model_config = gpc.config.model
- tp_mode = gpc.config.parallel.tensor["mode"]
- row_dim = 0 if tp_mode == "isp" else 1
-
- # load states
- states, num_shards = Gemma.load_sharded_states(src)
-
- # convert state_dict
- state_dict = {}
- embedding_key_list = ["tok_embeddings.weight", "embed_tokens.weight", None]
- for layer_i in tqdm(range(model_config["num_layers"])):
- # attn norm, mlp norm
- state_dict.update(
- {
- f"model.layers.{layer_i}.input_layernorm.weight": states[0][
- f"layers.{layer_i}.attention_norm.weight"
- ].clone(),
- f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][
- f"layers.{layer_i}.ffn_norm.weight"
- ].clone(),
- }
- )
- # attn wqkv weight and bias
- state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = torch.cat(
- [states[i][f"layers.{layer_i}.attention.wq.weight"] for i in range(num_shards)],
- dim=0,
- )
- state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = torch.cat(
- [states[i][f"layers.{layer_i}.attention.wk.weight"] for i in range(num_shards)],
- dim=0,
- )
- state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
- [states[i][f"layers.{layer_i}.attention.wv.weight"] for i in range(num_shards)],
- dim=0,
- )
- # attn wo weight
- state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
- [states[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=row_dim
- )
-
- # mlp
- state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
- [states[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
- )
- state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
- [states[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=row_dim
- )
- state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
- [states[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
- )
-
- # embedding, head
- for embedding_key in embedding_key_list:
- if embedding_key in states[0]:
- break
- if embedding_key is None:
- raise KeyError("Cannot find embedding key!")
- if model_config["embed_split_hidden"]:
- embed_concat_dim = 1
- tok_emb_list = [states[i][embedding_key] for i in range(num_shards)]
- else:
- embed_concat_dim = 0
- _, size_1 = states[0][embedding_key].shape
- embdim_pertp = size_1 // num_shards
- tok_emb_list = [
- torch.concat(
- [
- states[tp][embedding_key][:, embdim_pertp * local_rank : embdim_pertp * (local_rank + 1)]
- for tp in range(num_shards)
- ],
- dim=0,
- )
- for local_rank in range(num_shards)
- ]
- state_dict.update(
- {
- "model.norm.weight": states[0]["norm.weight"],
- "model.embed_tokens.weight": torch.cat(tok_emb_list, dim=embed_concat_dim),
- "lm_head.weight": torch.cat([states[i]["output.weight"] for i in range(num_shards)], dim=0),
- },
- )
-
- # save state_dict to hf format
- shards, index = shard_checkpoint(state_dict, weights_name=SAFE_WEIGHTS_NAME)
- for shard_file, shard in shards.items():
- llm_save(save_path=os.path.join(tgt, shard_file), saved_obj=shard, metadata={"format": "pt"})
- if index is not None:
- # Save the index as well
- llm_save(save_path=os.path.join(tgt, SAFE_WEIGHTS_INDEX_NAME), saved_obj=index)
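Reviewer note, not part of the patch: the deleted `Gemma.load_hf_weights` above has to handle checkpoints whose LM head is tied to the token embedding. A minimal sketch of that fallback, with illustrative tensor sizes:

```python
import torch

vocab_size, hidden_size = 32, 8
# Toy HF-style state dict; real Gemma checkpoints usually omit `lm_head.weight`.
state_dict = {"model.embed_tokens.weight": torch.randn(vocab_size, hidden_size)}

if "lm_head.weight" in state_dict:
    # Untied checkpoint: take the explicit head and discard the embedding copy,
    # which the loader above has already consumed via `.get(...)`.
    output_weight = state_dict.pop("lm_head.weight")
    state_dict.pop("model.embed_tokens.weight")
else:
    # Gemma ties the LM head to the embedding in the transformers implementation,
    # so the embedding weight doubles as the output projection.
    output_weight = state_dict.pop("model.embed_tokens.weight")

print(output_weight.shape)  # torch.Size([32, 8])
```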
diff --git a/internlm/model/modeling_llava.py b/internlm/model/modeling_llava.py
deleted file mode 100644
index 4c2bb1745..000000000
--- a/internlm/model/modeling_llava.py
+++ /dev/null
@@ -1,244 +0,0 @@
-from typing import Optional
-
-import torch
-from torch import nn
-
-from internlm.core.context import ParallelMode
-from internlm.core.context.parallel_context import global_context as gpc
-from internlm.core.naive_amp import set_output_attr_to_module
-from internlm.initialize.initialize_tensor import normal_, uniform_
-from internlm.model.base_model import BaseModel
-from internlm.model.llava.clip_builder import build_vision_tower
-from internlm.model.llava.projector_builder import build_vision_projector
-from internlm.model.modeling_llama import Llama2Decoder
-from internlm.model.modules.embedding import Embedding1D
-from internlm.model.modules.linear import new_linear
-from internlm.model.modules.norm import new_layer_norm
-from internlm.utils.logger import get_logger
-
-logger = get_logger(__file__)
-
-
-class Llava(BaseModel):
- """
- 1D Packed Flash Llava.
-
- Args:
- num_layers (int): The number of layer. 48 by default.
- hidden_size (int): The size of hidden state. 2048 by default.
- num_attention_heads (int): The number of attention head. 32 by default.
- num_kv_attention_heads (int): The number of key/value attention heads. Defaults to 32.
- vocab_size (int): The size of vocabulary. 50304 by default.
- mlp_ratio (int): The ratio of MLP layers. 4 by default.
- attn_drop_rate (float): The dropout rate of attention module. 0.0 by default.
- drop_rate (float): The dropout rate of input hidden state. 0.0 by default.
- dtype (torch.dtype): The type of data. torch.float by default.
-        checkpoint (bool): Whether to use checkpointing to save VRAM. False by default.
- layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default.
- first (bool): Whether input embedding layer or not. False by default.
- last (bool): Whether output embedding layer or not. False by default.
- embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
- parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
- start_layer_idx (int): The index of start layer in the pipeline. 0 by default.
-        device (Optional[Union[str, torch.device]]): The device to be used. None by default.
- residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
- norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
- qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved.
- embedding_init_std (float): std used to init embedding weight. 0.02 by default,
- attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default,
- attn_other_init_std (float): std used to init attn_other weight. 0.02 by default,
- ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu
- otherwise init fc1 weight in ffn. 0.02 by default,
- ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default,
- out_head_init_std (float): std used to init output lmhead weight. 0.02 by default,
- init_type (str): Initialization type. Use uniform or normal. "normal" by default,
- rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
- image_token_id (int): image token id. 200000 by default.
- vit_cfg (dict): The config of vision tower. None by default.
- vision_proj_cfg (dict): The config of vision projector. None by default.
- """
-
- def __init__(
- self,
- num_layers: int = 48,
- hidden_size: int = 2048,
- num_attention_heads: int = 32,
- num_kv_attention_heads: int = 32,
- vocab_size: int = 50304,
- mlp_ratio: int = 4,
- attn_drop_rate: float = 0.0,
- drop_rate: float = 0.0,
- dtype: torch.dtype = torch.float,
- checkpoint: bool = False,
- layer_norm_epsilon: float = 1e-5,
- first: bool = False,
- last: bool = False,
- embed_grad_scale: float = 0.1,
- parallel_output: bool = True,
- start_layer_idx: int = 0,
- device: Optional[torch.device] = None,
- apply_post_layer_norm=False,
- no_bias=False,
- residual_in_fp32: bool = False,
- norm_type: str = "rmsnorm",
- qk_interleaved: bool = False,
- is_reward: bool = False,
- dropout_selective_checkpoint: bool = True,
- use_scaled_init: bool = True,
- use_swiglu: bool = True,
- embedding_init_std: float = 0.02,
- attn_wqkv_init_std: float = 0.02,
- attn_other_init_std: float = 0.02,
- ffn_uplayer_init_std: float = 0.02,
- ffn_other_init_std: float = 0.02,
- out_head_init_std: float = 0.02,
- init_type: str = "normal",
- rope_base: int = 10000,
- mlp_layer_fusion: bool = False,
- multiple_of: int = 256,
- image_token_id: int = 200000,
- vit_cfg=None,
- vision_proj_cfg=None,
- ):
- super().__init__()
-
- checkpoint_layer_num = num_layers * checkpoint
-
- self.dtype = dtype
- self.image_token_id = image_token_id
- self.embed_grad_scale = embed_grad_scale
- self.parallel_output = parallel_output
-
- if first:
- self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
-
- for _, param in self.tok_embeddings.named_parameters():
- if init_type == "normal":
- normal_(std=embedding_init_std)(param)
- else:
- uniform_(std=embedding_init_std)(param)
-
- self.layers = nn.ModuleList(
- [
- Llama2Decoder(
- hidden_size=hidden_size,
- num_attention_heads=num_attention_heads,
- num_kv_attention_heads=num_kv_attention_heads,
- mlp_ratio=mlp_ratio,
- attn_drop_rate=attn_drop_rate,
- drop_rate=drop_rate,
- dtype=dtype,
- layer_norm_epsilon=layer_norm_epsilon,
- checkpoint=lid < checkpoint_layer_num,
- layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation
- residual_in_fp32=residual_in_fp32,
- device=device,
- apply_post_layer_norm=apply_post_layer_norm,
- fused_dropout_add_ln=False,
- no_bias=no_bias,
- norm_type=norm_type,
- dropout_selective_checkpoint=dropout_selective_checkpoint,
- use_scaled_init=use_scaled_init,
- use_swiglu=use_swiglu,
- qk_interleaved=qk_interleaved,
- attn_wqkv_init_std=attn_wqkv_init_std,
- attn_other_init_std=attn_other_init_std,
- ffn_uplayer_init_std=ffn_uplayer_init_std,
- ffn_other_init_std=ffn_other_init_std,
- init_type=init_type,
- rope_base=rope_base,
- mlp_layer_fusion=mlp_layer_fusion,
- multiple_of=multiple_of,
- )
- for lid in range(num_layers)
- ]
- )
-
- if last:
- if not apply_post_layer_norm:
- self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon)
-
- self.output = new_linear(
- name="output",
- in_features=hidden_size,
- out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
- bias=False,
- device=device,
- dtype=dtype,
- is_reward=is_reward,
- weight_scale=embed_grad_scale,
- )
- set_output_attr_to_module(self.output)
- for _, param in self.output.named_parameters():
- if init_type == "normal":
- normal_(std=out_head_init_std)(param)
- else:
- uniform_(std=out_head_init_std)(param)
-
- if first:
- assert vit_cfg is not None
- self.vit = build_vision_tower(vit_cfg)
- self.vit.requires_grad_(False)
-
- assert vision_proj_cfg is not None
- self.vision_proj = build_vision_projector(vision_proj_cfg)
- # self.vision_proj.requires_grad_(False)
-
- def forward(self, hidden_states=None, images=None, input_ids=None, **kwargs):
- xs = []
- pure_text = False
- images = [] if images is None else images
-
- if hasattr(self, "vit") and hasattr(self, "vision_proj") and hasattr(self, "tok_embeddings"):
- # vit
- if len(images) == 1 and len(images[0]) == 0: # make sure grad in Qformer for update
- images = [torch.rand(1, 3, self.vit.image_size, self.vit.image_size).cuda().to(self.dtype)]
- pure_text = True
-
- for image in images:
- assert len(image) > 0
- if len(image) == 0:
- x = []
- else:
- assert not isinstance(image, list), image
- x = image.to(torch.cuda.current_device()).to(self.dtype)
- x = self.vit(x)
- x = self.vision_proj(x)
- xs.append(x)
-
- # tok embeddings
- org_ids = input_ids.clone()
- input_ids[input_ids == self.image_token_id] = 0
- hidden_states = self.tok_embeddings(input_ids).clone()
-
- if pure_text and len(xs) > 0:
- hidden_states = hidden_states + 0 * xs[0].sum()
- else:
- for i in range(len(xs)):
- hidden_states[i, org_ids[i] == self.image_token_id] = (xs[i].reshape((-1, xs[i].shape[-1]))).to(
- hidden_states.dtype
- )
-
- if self.embed_grad_scale != 1:
- hidden_states = (
- self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach()
- )
-
- for _, block in enumerate(self.layers):
- hidden_states = block(hidden_states, residual=None, **kwargs)
-
- if hasattr(self, "norm"):
- hidden_states = self.norm(hidden_states.float())
-
- if hasattr(self, "output"):
- hidden_states = self.output(hidden_states)
-
- return hidden_states
-
- @staticmethod
- def load_hf_weights(folder: str, model: nn.Module) -> None:
- raise NotImplementedError
-
- @staticmethod
- def convert_internevo2hf_weights(src: str, tgt: str) -> None:
- raise NotImplementedError
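The forward pass of the vision-language model deleted above rescales the token embeddings with `embed_grad_scale` before they enter the decoder stack: the forward value is unchanged, but the gradient reaching the embedding weights is multiplied by the scale factor (the GLM-130B stability trick referenced in the docstrings). A minimal standalone sketch of that trick, using a plain `torch.nn.Embedding` and toy sizes as assumptions rather than the repository's `Embedding1D`:

```python
import torch
import torch.nn as nn

embed_grad_scale = 0.1
emb = nn.Embedding(100, 16)            # toy stand-in for Embedding1D
tokens = torch.randint(0, 100, (2, 8))

h = emb(tokens)
# s*h + (1-s)*h.detach() leaves the forward value unchanged, but only the first
# term carries gradient, so the embedding weights see gradients scaled by s.
h = embed_grad_scale * h + (1 - embed_grad_scale) * h.detach()

h.sum().backward()
print(emb.weight.grad.abs().max())     # exactly 0.1x the unscaled gradient
```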
diff --git a/internlm/model/modeling_mixtral.py b/internlm/model/modeling_mixtral.py
deleted file mode 100644
index 8e8767ced..000000000
--- a/internlm/model/modeling_mixtral.py
+++ /dev/null
@@ -1,429 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import math
-from typing import Optional
-
-import torch
-from torch import nn
-
-from internlm.core.context import ParallelMode
-from internlm.core.context.parallel_context import global_context as gpc
-from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
-from internlm.model.base_model import BaseModel
-from internlm.model.modules.embedding import Embedding1D
-from internlm.model.modules.linear import new_linear
-from internlm.model.modules.mha import SWA
-from internlm.model.modules.mlp import new_feed_forward
-from internlm.model.modules.norm import new_layer_norm
-from internlm.model.moe.moe import MoE
-from internlm.model.utils import (
- convert_attn_args_to_kwargs,
- convert_attn_kwargs_to_args,
-)
-from internlm.solver.activation_checkpoint import activation_checkpoint
-from internlm.utils.logger import get_logger
-
-logger = get_logger(__file__)
-
-
-class MixtralMoEDecoder(nn.Module):
- """
- Mixtral MoE Decoder Layer.
-
- Args:
- hidden_size (int): The hidden size of model. 768 by default.
- num_attention_heads (int): The number of attention heads. 12 by default.
- mlp_ratio (int): The ratio of MLP layers. 4 by default.
- attn_drop_rate (float): The dropout rate of attention module. 0 by default.
- drop_rate (float): The dropout rate of the input hidden state. 0.0 by default.
- max_position_embeddings (int): The maximum position embeddings. 2048 by default.
- dtype (torch.dtype): Type of data. torch.float by default.
- layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
- checkpoint (bool): Whether to use checkpointing to save VRAM. True by default.
- layer_idx (int): The index of current layer. 0 by default.
- use_dynamic_ntk_rope (bool): Whether to use dynamic ntk rope. False by default.
- residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
- device (Optional[Union[str, torch.device]]): The device will be used.
- norm_type (str): Use RMS norm or layernorm. "rmsnorm" by default.
- qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved.
- dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout layers only.
- use_scaled_init (bool): Whether to use scaled initialization for weights.
- use_swiglu (bool): Whether to use SwiGLU activation in the mlp module.
- rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
- mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization.
- multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization.
- """
-
- def __init__(
- self,
- hidden_size: int = 768,
- num_attention_heads: int = 12,
- num_kv_attention_heads: int = 12,
- mlp_ratio: int = 4,
- attn_drop_rate: float = 0,
- drop_rate: float = 0.0,
- max_position_embeddings: int = 2048,
- dtype: torch.dtype = torch.float,
- layer_norm_epsilon: float = 1e-6,
- checkpoint: bool = False,
- layer_idx: int = 0,
- use_dynamic_ntk_rope: bool = False,
- residual_in_fp32: bool = False,
- device: Optional[torch.device] = None,
- qkv_bias=True,
- o_bias=False,
- norm_type: str = "rmsnorm",
- rope_base: int = 10000,
- rope_scaling_factor: float = 1.0,
- use_sliding_window: bool = False,
- sliding_window: int = None,
- qk_interleaved: bool = False,
- dropout_selective_checkpoint: bool = True,
- use_scaled_init: bool = True,
- use_swiglu: bool = True,
- mlp_layer_fusion: bool = False,
- multiple_of: int = 256,
- num_experts: int = 1,
- top_k: int = 1,
- num_shared_experts: int = 0,
- moe_layer_kwargs: dict = None,
- ):
- super().__init__()
- self.checkpoint = checkpoint
- # dropout selective checkpoint can only be enabled when checkpoint is disabled.
- self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
- self.layer_idx = layer_idx
-
- head_dim = hidden_size // num_attention_heads
- softmax_scale = 1 / math.sqrt(head_dim)
-
- self.mixer = SWA(
- embed_dim=hidden_size,
- num_heads=num_attention_heads,
- num_kv_heads=num_kv_attention_heads,
- dropout=attn_drop_rate,
- max_position_embeddings=max_position_embeddings,
- softmax_scale=softmax_scale,
- causal=True,
- layer_idx=layer_idx,
- use_dynamic_ntk_rope=use_dynamic_ntk_rope,
- rotary_emb_dim=head_dim,
- rotary_emb_scale_base=0,
- device=device,
- dtype=dtype,
- qk_interleaved=qk_interleaved,
- qkv_bias=qkv_bias,
- o_bias=o_bias,
- rope_base=rope_base,
- rope_scaling_factor=rope_scaling_factor,
- use_sliding_window=use_sliding_window,
- sliding_window=sliding_window,
- )
-
- self.dropout1 = nn.Dropout(drop_rate)
- self.dropout2 = nn.Dropout(drop_rate)
- self.norm1 = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon)
- self.norm2 = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon)
-
- self.num_experts = num_experts
- if num_experts <= 1: # dense, not MoE
- self.mlp = new_feed_forward(
- hidden_size,
- int(hidden_size * mlp_ratio),
- out_features=hidden_size,
- bias=False,
- device=device,
- dtype=dtype,
- mlp_layer_fusion=mlp_layer_fusion,
- multiple_of=multiple_of,
- # TODO: to support more activation functions
- activation_type="swiglu" if use_swiglu else "gelu",
- )
- else:
- # replace mlp by MoE module. The expert in MoE is a FeedForward module.
- # mlp_cls = get_mlp_cls(self.tp_mode)
- self.mlp = MoE(
- hidden_size,
- int(hidden_size * mlp_ratio),
- out_features=hidden_size,
- num_experts=num_experts,
- top_k=top_k,
- num_shared_experts=num_shared_experts,
- moe_layer_kwargs=moe_layer_kwargs,
- device=device,
- dtype=dtype,
- mlp_layer_fusion=mlp_layer_fusion,
- multiple_of=multiple_of,
- # TODO: to support more activation functions
- activation_type="swiglu" if use_swiglu else "gelu",
- )
-
- self.use_swiglu = use_swiglu
- self.use_scaled_init = use_scaled_init
- self.residual_in_fp32 = residual_in_fp32 # only makes sense when using prenorm
- self.return_residual = False
- self.reset_parameters() # TODO: check this should be changed when moe is added
-
- def reset_parameters(self):
- with torch.no_grad():
- for name, param in self.mixer.named_parameters():
- if param.ndim == 1:
- param.data.zero_()
- elif "wqkv" in name:
- normal_(std=0.006)(param.data)
- elif self.use_scaled_init:
- scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
- else:
- normal_(std=0.0015)(param.data)
-
- for name, param in self.mlp.named_parameters():
- if param.ndim == 1 and "bias" in name:
- param.data.zero_()
- elif self.use_swiglu:
- if self.use_scaled_init and "w2" in name:
- scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
- else:
- # candidate: w1, w3, fused_w1_w3
- normal_(std=0.006 if "w1" in name or "w3" in name else 0.0015)(param.data)
- else:
- if self.use_scaled_init and "fc1" not in name:
- scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
- else:
- normal_(std=0.006 if "fc1" in name else 0.0015)(param.data)
-
- def forward(self, hidden_states, **kwargs):
- if self.checkpoint and self.training:
- # TODO: check whether this will be affected by moe
- # NOTICE: activation_checkpoint does not support kwargs when use_reentrant=True.
- args = convert_attn_kwargs_to_args(kwargs)
- return activation_checkpoint(self._forward, False, hidden_states, *args)
- else:
- return self._forward(hidden_states, **kwargs)
-
- def _forward(self, hidden_states, *args, **kwargs):
- r"""Pass the input through the encoder layer.
-
- Args:
- hidden_states: the sequence to the encoder layer (required).
- residual: hidden_states = Attn/MLP(LN(residual))
- cu_seqlens: 1d LongTensor, len(cu_seqlens) = number of packed sequences + 1
- indexes: same length as hidden_states, standing for the position of each token in its sequence
- """
-
- def _dropout_and_norm_attn(_hidden_states):
- _dropped = self.dropout1(_hidden_states)
- _residual = _dropped
- _hidden_states = self.norm1(_residual.to(self.norm1.weight.dtype))
- return _residual, _hidden_states
-
- if self.dropout_selective_checkpoint:
- residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, hidden_states)
- else:
- residual, hidden_states = _dropout_and_norm_attn(hidden_states)
-
- if self.residual_in_fp32:
- residual = residual.to(torch.float32)
-
- mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs)
- hidden_states = self.mixer(hidden_states, **mixer_kwargs)
-
- def _dropout_and_norm_ffn(_residual, _hidden_states):
- _dropped = self.dropout2(_hidden_states)
- _residual = (_dropped + _residual) if _residual is not None else _dropped
- _hidden_states = self.norm2(_residual.to(self.norm2.weight.dtype))
- return _residual, _hidden_states
-
- if self.dropout_selective_checkpoint:
- residual, hidden_states = activation_checkpoint(_dropout_and_norm_ffn, False, residual, hidden_states)
- else:
- residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states)
-
- if self.residual_in_fp32:
- residual = residual.to(torch.float32)
-
- # MLP.
- if self.num_experts <= 1: # dense mlp output
- hidden_states = self.mlp(hidden_states)
- moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype)
- else: # MoE output
- hidden_states, moe_loss, _ = self.mlp(hidden_states)
-
- return hidden_states + residual, moe_loss
-
-
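The `_forward` above follows the pre-norm convention: dropout the block output, add it to the running residual, then normalize before the next sub-layer, optionally carrying the residual stream in fp32. A minimal sketch of one such step with plain `torch` modules (`nn.LayerNorm` stands in for the repository's `new_layer_norm`; dtypes and shapes are assumptions):

```python
import torch
import torch.nn as nn

dim = 16
norm1 = nn.LayerNorm(dim)            # stand-in for new_layer_norm(...)
dropout1 = nn.Dropout(0.0)
residual_in_fp32 = True

hidden_states = torch.randn(2, 8, dim, dtype=torch.bfloat16)
residual = None

# dropout -> accumulate residual -> normalize, as in _dropout_and_norm_attn above
dropped = dropout1(hidden_states)
residual = dropped if residual is None else dropped + residual
hidden_states = norm1(residual.to(norm1.weight.dtype))

# the residual stream itself can be kept in fp32 for numerical stability
if residual_in_fp32:
    residual = residual.to(torch.float32)
```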
-class MixtralMoE(BaseModel):
- """
- Mixtral MoE.
-
- Args:
- num_layers (int): The number of layer. 12 by default.
- hidden_size (int): The size of hidden state. 768 by default.
- num_attention_heads (int): The number of attention head. 12 by default.
- vocab_size (int): The size of vocabulary. 50304 by default.
- mlp_ratio (int): The ratio of MLP layers. 4 by default.
- attn_drop_rate (float): The dropout rate of attention module. 0.0 by default.
- drop_rate (float): The dropout rate of input hidden state. 0.0 by default.
- max_position_embeddings (int): The maximum position embeddings. 2048 by default.
- dtype (torch.dtype): The type of data. torch.float by default.
- checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number
- of layers. 0.0 by default.
- layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
- first (bool): Whether input embedding layer or not. False by default.
- last (bool): Whether output embedding layer or not. False by default.
- embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
- parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
- start_layer_idx (int): The index of start layer in the pipeline. 0 by default.
- use_dynamic_ntk_rope (bool): Whether to use dynamic ntk rope. False by default.
- device (Optional[Union[str, torch.device]]): The device will be used. None by default.
- residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
- norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
- qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved.
- dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout and norm layers.
- use_scaled_init (bool): Whether to use scaled initialization for weights.
- use_swiglu (bool): Whether to use SwiGLU activation in the mlp module.
- num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
- moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
- (https://arxiv.org/abs/2201.05596) layer.
- moe_type (str): determines which MoE implementation will be used. GShardMoE by default.
- mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization.
- multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization.
- """
-
- def __init__(
- self,
- num_layers: int = 48,
- hidden_size: int = 2048,
- num_attention_heads: int = 32,
- num_kv_attention_heads: int = 12,
- vocab_size: int = 50304,
- mlp_ratio: float = 4.0,
- attn_drop_rate: float = 0.0,
- drop_rate: float = 0.0,
- max_position_embeddings: int = 2048,
- dtype: torch.dtype = torch.float,
- checkpoint: float = 0.0,
- layer_norm_epsilon: float = 1e-5,
- first: bool = False,
- last: bool = False,
- embed_grad_scale: float = 0.1,
- parallel_output: bool = True,
- start_layer_idx: int = 0,
- use_dynamic_ntk_rope: bool = False,
- device: Optional[torch.device] = None,
- qkv_bias=True,
- o_bias=False,
- residual_in_fp32: bool = False,
- norm_type: str = "rmsnorm",
- qk_interleaved: bool = False,
- is_reward: bool = False,
- dropout_selective_checkpoint: bool = True,
- use_scaled_init: bool = True,
- use_swiglu: bool = True,
- rope_base: int = 10000,
- rope_scaling_factor: float = 1.0,
- use_sliding_window: bool = False,
- max_window_layers: int = 0,
- sliding_window: int = None,
- mlp_layer_fusion: bool = False,
- multiple_of: int = 256,
- moe_type: str = None, # pylint: disable=W0613
- num_experts: int = 1,
- top_k: int = 1,
- num_shared_experts: int = 0,
- moe_layer_kwargs: dict = None,
- ):
- super().__init__()
-
- checkpoint_layer_num = int(num_layers * checkpoint)
-
- if first:
- self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
-
- for _, param in self.embedding.named_parameters():
- normal_(std=0.0052)(param)
- self.embed_grad_scale = embed_grad_scale
- self.blocks = nn.ModuleList(
- [
- MixtralMoEDecoder(
- hidden_size=hidden_size,
- num_attention_heads=num_attention_heads,
- num_kv_attention_heads=num_kv_attention_heads,
- mlp_ratio=mlp_ratio,
- attn_drop_rate=attn_drop_rate,
- drop_rate=drop_rate,
- max_position_embeddings=max_position_embeddings,
- dtype=dtype,
- layer_norm_epsilon=layer_norm_epsilon,
- checkpoint=lid < checkpoint_layer_num,
- layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation
- use_dynamic_ntk_rope=use_dynamic_ntk_rope,
- residual_in_fp32=residual_in_fp32,
- device=device,
- qkv_bias=qkv_bias,
- o_bias=o_bias,
- norm_type=norm_type,
- dropout_selective_checkpoint=dropout_selective_checkpoint,
- use_scaled_init=use_scaled_init,
- use_swiglu=use_swiglu,
- qk_interleaved=qk_interleaved,
- rope_base=rope_base,
- rope_scaling_factor=rope_scaling_factor,
- use_sliding_window=use_sliding_window and lid >= max_window_layers,
- sliding_window=sliding_window,
- mlp_layer_fusion=mlp_layer_fusion,
- multiple_of=multiple_of,
- num_experts=num_experts,
- top_k=top_k,
- num_shared_experts=num_shared_experts,
- moe_layer_kwargs=moe_layer_kwargs,
- )
- for lid in range(num_layers)
- ]
- )
- if last:
- self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon)
- self.head = new_linear(
- name="head",
- in_features=hidden_size,
- out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
- bias=False,
- device=device,
- dtype=dtype,
- is_reward=is_reward,
- weight_scale=embed_grad_scale,
- )
- for _, param in self.head.named_parameters():
- normal_(std=0.0052)(param)
-
- self.parallel_output = parallel_output
-
- def forward(self, hidden_states=None, input_ids=None, **kwargs):
- # attention_mask: compute attention on the places where the value is 1
- # the old condition may fail when using a shared embedding
- if gpc.is_pipeline_first_stage() and input_ids is not None:
- hidden_states = self.embedding(input_ids)
- if self.embed_grad_scale != 1:
- hidden_states = (
- self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach()
- )
-
- moe_losses = []
- for _, block in enumerate(self.blocks):
- hidden_states, moe_loss = block(hidden_states, **kwargs)
- moe_losses.append(moe_loss)
-
- if hasattr(self, "norm"):
- hidden_states = self.norm(hidden_states.float())
- if hasattr(self, "head"):
- hidden_states = self.head(hidden_states)
-
- return hidden_states, moe_losses
-
- @staticmethod
- def load_hf_weights(folder: str, model: nn.Module) -> None:
- raise NotImplementedError
-
- @staticmethod
- def convert_internevo2hf_weights(src: str, tgt: str) -> None:
- raise NotImplementedError
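`MixtralMoE.forward` above collects one auxiliary loss per decoder block, with dense blocks contributing a zero tensor so the list always has `num_layers` entries. A minimal sketch of that accumulation pattern (the `ToyBlock` module and its variance-based stand-in for the router's load-balancing loss are assumptions, not the repository's `MoE` implementation):

```python
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, dim, is_moe):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.is_moe = is_moe

    def forward(self, x):
        x = x + self.proj(x)
        if self.is_moe:
            moe_loss = x.var() * 1e-2   # placeholder for the router's aux loss
        else:
            moe_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)
        return x, moe_loss

blocks = nn.ModuleList([ToyBlock(16, is_moe=(i % 2 == 1)) for i in range(4)])
h = torch.randn(2, 8, 16)
moe_losses = []
for block in blocks:
    h, moe_loss = block(h)        # every block returns (hidden_states, moe_loss)
    moe_losses.append(moe_loss)

total_aux_loss = torch.stack(moe_losses).sum()  # added to the LM loss with a small coefficient
```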
diff --git a/internlm/model/modeling_qwen2.py b/internlm/model/modeling_qwen2.py
deleted file mode 100644
index 5a4bde534..000000000
--- a/internlm/model/modeling_qwen2.py
+++ /dev/null
@@ -1,750 +0,0 @@
-# Copyright (c) InternLM. All rights reserved.
-import math
-import os
-from typing import Optional
-
-import torch
-from torch import nn
-from tqdm import tqdm
-
-from internlm.accelerator import get_accelerator
-from internlm.core.context import ParallelMode
-from internlm.core.context.parallel_context import global_context as gpc
-from internlm.initialize.initialize_tensor import (
- normal_,
- scaled_init_method_normal,
- scaled_init_method_uniform,
- uniform_,
-)
-from internlm.model.base_model import BaseModel
-from internlm.model.modules.embedding import Embedding1D
-from internlm.model.modules.linear import new_linear
-from internlm.model.modules.mha import SWA
-from internlm.model.modules.mlp import new_feed_forward
-from internlm.model.modules.norm import new_layer_norm
-from internlm.model.utils import (
- convert_attn_args_to_kwargs,
- convert_attn_kwargs_to_args,
-)
-from internlm.solver.activation_checkpoint import activation_checkpoint
-from internlm.utils.logger import get_logger
-from internlm.utils.storage_manager import get_fns, llm_load, llm_save
-from transformers.modeling_utils import (
- SAFE_WEIGHTS_INDEX_NAME,
- SAFE_WEIGHTS_NAME,
- shard_checkpoint,
-)
-
-internlm_accelerator = get_accelerator()
-logger = get_logger(__file__)
-
-
-class Qwen2Decoder(nn.Module):
- """
- 1D Packed Flash Qwen Layer.
-
- Args:
- hidden_size (int): The hidden size of model. 768 by default.
- num_attention_heads (int): The number of attention heads. 12 by default.
- mlp_ratio (int): The ratio of MLP layers. 4 by default.
- attn_drop_rate (float): The dropout rate of attention module. 0 by default.
- drop_rate (float): The dropout rate of the input hidden state. 0.0 by default.
- dtype (torch.dtype): Type of data. torch.float by default.
- layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
- checkpoint (bool): Whether to use checkpointing to save VRAM. True by default.
- layer_idx (int): The index of current layer. 0 by default.
- residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
- device (Optional[Union[str, torch.device]]): The device will be used.
- norm_type (str): Use RMS norm or layernorm. "rmsnorm" by default.
- attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default,
- attn_other_init_std (float): std used to init attn_other weight. 0.02 by default,
- ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu
- otherwise init fc1 weight in ffn. 0.02 by default,
- ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default,
- init_type (str): Initialization type. Use uniform or normal. "normal" by default,
- rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
- multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2.
- """
-
- def __init__(
- self,
- hidden_size: int = 768,
- num_attention_heads: int = 12,
- num_kv_attention_heads: int = 12,
- mlp_ratio: int = 4,
- attn_drop_rate: float = 0,
- drop_rate: float = 0.0,
- max_position_embeddings: int = 2048,
- dtype: torch.dtype = torch.float,
- layer_norm_epsilon: float = 1e-6,
- checkpoint: bool = False,
- layer_idx: int = 0,
- use_dynamic_ntk_rope: bool = False,
- residual_in_fp32: bool = False,
- device: Optional[torch.device] = None,
- apply_post_layer_norm: bool = False,
- fused_dropout_add_ln: bool = True,
- qkv_bias=True,
- o_bias=False,
- mlp_bias=False,
- norm_type: str = "rmsnorm",
- qk_interleaved: bool = False,
- dropout_selective_checkpoint: bool = True,
- use_scaled_init: bool = True,
- use_swiglu: bool = True,
- attn_wqkv_init_std: float = 0.02,
- attn_other_init_std: float = 0.02,
- ffn_uplayer_init_std: float = 0.02,
- ffn_other_init_std: float = 0.02,
- init_type: str = "normal",
- rope_type: str = "normal",
- rope_base: int = 10000,
- rope_scaling_factor: float = 1.0,
- use_sliding_window: bool = False,
- sliding_window: int = None,
- mlp_layer_fusion: bool = False,
- multiple_of: int = 256,
- scale_attn_weights: bool = False, # Qwen1
- use_logn_attn: bool = False, # Qwen1
- ):
- super().__init__()
- self.checkpoint = checkpoint
- # dropout selective checkpoint can only be enabled when checkpoint is disabled.
- self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
- self.layer_idx = layer_idx
- self.prenorm = not apply_post_layer_norm
- assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here"
- self.fused_dropout_add_ln = fused_dropout_add_ln
- self.attn_wqkv_init_std = attn_wqkv_init_std
- self.attn_other_init_std = attn_other_init_std
- self.ffn_uplayer_init_std = ffn_uplayer_init_std
- self.ffn_other_init_std = ffn_other_init_std
-
- head_dim = hidden_size // num_attention_heads
-
- if scale_attn_weights:
- softmax_scale = None
- else:
- softmax_scale = 1 / math.sqrt(head_dim)
- self.attention = SWA(
- embed_dim=hidden_size,
- num_heads=num_attention_heads,
- num_kv_heads=num_kv_attention_heads,
- dropout=attn_drop_rate,
- max_position_embeddings=max_position_embeddings,
- softmax_scale=softmax_scale,
- causal=True,
- layer_idx=layer_idx,
- use_dynamic_ntk_rope=use_dynamic_ntk_rope,
- rotary_emb_dim=head_dim,
- rotary_emb_scale_base=0,
- device=device,
- dtype=dtype,
- qk_interleaved=qk_interleaved,
- qkv_bias=qkv_bias,
- o_bias=o_bias,
- rope_type=rope_type,
- rope_base=rope_base,
- rope_scaling_factor=rope_scaling_factor,
- use_sliding_window=use_sliding_window,
- sliding_window=sliding_window,
- use_logn_attn=use_logn_attn,
- )
-
- self.dropout1 = nn.Dropout(drop_rate)
- self.dropout2 = nn.Dropout(drop_rate)
- self.attention_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon)
- self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon)
-
- self.feed_forward = new_feed_forward(
- hidden_size,
- int(hidden_size * mlp_ratio),
- out_features=hidden_size,
- bias=mlp_bias,
- device=device,
- dtype=dtype,
- mlp_layer_fusion=mlp_layer_fusion,
- multiple_of=multiple_of,
- activation_type="swiglu" if use_swiglu else "gelu",
- )
-
- self.use_swiglu = use_swiglu
- self.use_scaled_init = use_scaled_init
- self.residual_in_fp32 = residual_in_fp32 # only makes sense when using prenorm
- self.return_residual = False
-
- if init_type == "normal":
- self.init_func = normal_
- self.scaled_init_func = scaled_init_method_normal
- else:
- self.init_func = uniform_
- self.scaled_init_func = scaled_init_method_uniform
-
- self.reset_parameters()
-
- def reset_parameters(self):
- with torch.no_grad():
- for name, param in self.attention.named_parameters():
- if param.ndim == 1:
- param.data.zero_()
- elif "wq" in name or "wk" in name or "wv" in name:
- self.init_func(std=self.attn_wqkv_init_std)(param.data)
- elif self.use_scaled_init: # wo
- self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
- else:
- self.init_func(std=self.attn_other_init_std)(param.data)
-
- for name, param in self.feed_forward.named_parameters():
- if self.use_swiglu:
- if self.use_scaled_init and "w2" in name:
- self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
- else:
- # candidate: w1, w3, fused_w1_w3
- self.init_func(
- std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std
- )(param.data)
- else:
- if self.use_scaled_init and "fc1" not in name:
- self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
- else:
- self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)(
- param.data
- )
-
- def forward(self, hidden_states, residual=None, **kwargs):
- if self.checkpoint and self.training:
- args = convert_attn_kwargs_to_args(kwargs)
- return activation_checkpoint(self._forward, False, hidden_states, residual, *args)
- else:
- return self._forward(hidden_states, residual, **kwargs)
-
- def _forward(self, hidden_states, residual, *args, **kwargs):
- r"""Pass the input through the encoder layer.
-
- Args:
- hidden_states: the sequence to the encoder layer (required).
- residual: hidden_states = Attn/MLP(LN(residual))
- cu_seqlens: 1d LongTensor, len(cu_seqlens) = number of packed sequences + 1
- indexes: same length as hidden_states, standing for the position of each token in its sequence
- """
- if self.prenorm:
-
- def _dropout_and_norm_attn(_residual, _hidden_states):
- _dropped = self.dropout1(_hidden_states)
- _residual = (_dropped + _residual) if _residual is not None else _dropped
- _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype))
-
- return _residual, _hidden_states
-
- if self.dropout_selective_checkpoint:
- residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states)
- else:
- residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states)
-
- if self.residual_in_fp32:
- residual = residual.to(torch.float32)
-
- mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs)
- hidden_states = self.attention(hidden_states, **mixer_kwargs)
-
- if not isinstance(self.feed_forward, nn.Identity):
- if not self.fused_dropout_add_ln:
-
- def _dropout_and_norm_ffn(_residual, _hidden_states):
- _dropped = self.dropout2(_hidden_states)
- _residual = (_dropped + _residual) if _residual is not None else _dropped
- _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype))
-
- return _residual, _hidden_states
-
- if self.dropout_selective_checkpoint:
- residual, hidden_states = activation_checkpoint(
- _dropout_and_norm_ffn, False, residual, hidden_states
- )
- else:
- residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states)
-
- if self.residual_in_fp32:
- residual = residual.to(torch.float32)
- hidden_states = self.feed_forward(hidden_states)
-
- return hidden_states + residual
- else:
- assert residual is None
-
- mixer_out = self.attention(hidden_states, **kwargs)
- if self.return_residual: # mixer out is actually a pair here
- mixer_out, hidden_states = mixer_out
- hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to(
- dtype=self.attention_norm.weight.dtype
- )
- if not isinstance(self.feed_forward, nn.Identity):
- mlp_out = self.feed_forward(hidden_states)
- if self.return_residual: # mlp out is actually a pair here
- mlp_out, hidden_states = mlp_out
- hidden_states = self.ffn_norm((self.dropout2(mlp_out)) + hidden_states).to(
- dtype=self.ffn_norm.weight.dtype
- )
- return hidden_states
-
-
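The boolean `checkpoint` flag consumed by each decoder layer above is derived in the model classes (e.g. `Qwen2` below) from a float fraction: `checkpoint_layer_num = int(num_layers * checkpoint)` and `checkpoint=lid < checkpoint_layer_num`, so only the first fraction of layers recompute activations during backward. A tiny sketch of that mapping (the layer count and fractions here are arbitrary):

```python
num_layers = 12
for checkpoint in (0.0, 0.25, 1.0):
    checkpoint_layer_num = int(num_layers * checkpoint)
    flags = [lid < checkpoint_layer_num for lid in range(num_layers)]
    # 0.0 -> no layer checkpointed, 0.25 -> first 3 layers, 1.0 -> all 12 layers
    print(checkpoint, sum(flags))
```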
-class Qwen2(BaseModel):
- """
- 1D Packed Flash Qwen.
-
- Args:
- num_layers (int): The number of layer. 12 by default.
- hidden_size (int): The size of hidden state. 768 by default.
- num_attention_heads (int): The number of attention head. 12 by default.
- vocab_size (int): The size of vocabulary. 50304 by default.
- mlp_ratio (int): The ratio of MLP layers. 4 by default.
- attn_drop_rate (float): The dropout rate of attention module. 0.0 by default.
- drop_rate (float): The dropout rate of input hidden state. 0.0 by default.
- dtype (torch.dtype): The type of data. torch.float by default.
- checkpoint (float): The proportion of layers to apply activation checkpointing to. 1.0 by default.
- layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default.
- first (bool): Whether input embedding layer or not. False by default.
- last (bool): Whether output embedding layer or not. False by default.
- embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
- parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
- start_layer_idx (int): The index of start layer in the pipeline. 0 by default.
- device (Optional[Union[str, torch.device]]): The device will be used. None by default.
- residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
- norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
- embedding_init_std (float): std used to init embedding weight. 0.02 by default,
- attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default,
- attn_other_init_std (float): std used to init attn_other weight. 0.02 by default,
- ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu
- otherwise init fc1 weight in ffn. 0.02 by default,
- ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default,
- out_head_init_std (float): std used to init output lmhead weight. 0.02 by default,
- init_type (str): Initialization type. Use uniform or normal. "normal" by default,
- extra_pred_tokens (int): The number of extra output heads for multi-token prediction. 0 by default.
- rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
- multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2.
- """
-
- def __init__(
- self,
- num_layers: int = 12,
- hidden_size: int = 768,
- num_attention_heads: int = 12,
- num_kv_attention_heads: int = 12,
- vocab_size: int = 50304,
- mlp_ratio: int = 4,
- attn_drop_rate: float = 0.0,
- drop_rate: float = 0.0,
- max_position_embeddings: int = 2048,
- dtype: torch.dtype = torch.float,
- checkpoint: float = 1.0,
- layer_norm_epsilon: float = 1e-5,
- first: bool = False,
- last: bool = False,
- embed_grad_scale: float = 0.1,
- parallel_output: bool = True,
- start_layer_idx: int = 0,
- use_dynamic_ntk_rope: bool = False,
- device: Optional[torch.device] = None,
- apply_post_layer_norm=False,
- qkv_bias=True,
- o_bias=False,
- mlp_bias=False,
- residual_in_fp32: bool = False,
- norm_type: str = "rmsnorm",
- qk_interleaved: bool = False,
- is_reward: bool = False,
- dropout_selective_checkpoint: bool = True,
- use_scaled_init: bool = True,
- use_swiglu: bool = True,
- embedding_init_std: float = 0.02,
- attn_wqkv_init_std: float = 0.02,
- attn_other_init_std: float = 0.02,
- ffn_uplayer_init_std: float = 0.02,
- ffn_other_init_std: float = 0.02,
- out_head_init_std: float = 0.02,
- init_type: str = "normal",
- extra_pred_tokens: int = 0,
- rope_type: str = "normal",
- rope_base: int = 10000,
- rope_scaling_factor: float = 1.0,
- use_sliding_window: bool = False,
- max_window_layers: int = 0,
- sliding_window: int = None,
- mlp_layer_fusion: bool = False,
- multiple_of: int = 256,
- scale_attn_weights: bool = False, # Qwen1
- use_logn_attn: bool = False, # Qwen1
- ):
- super().__init__()
-
- self.embed_grad_scale = embed_grad_scale
-
- checkpoint_layer_num = int(num_layers * checkpoint)
-
- if first:
- self.embed_tokens = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
- for _, param in self.embed_tokens.named_parameters():
- if init_type == "normal":
- normal_(std=embedding_init_std)(param)
- else:
- uniform_(std=embedding_init_std)(param)
-
- self.layers = nn.ModuleList(
- [
- Qwen2Decoder(
- hidden_size=hidden_size,
- num_attention_heads=num_attention_heads,
- num_kv_attention_heads=num_kv_attention_heads,
- mlp_ratio=mlp_ratio,
- attn_drop_rate=attn_drop_rate,
- drop_rate=drop_rate,
- dtype=dtype,
- layer_norm_epsilon=layer_norm_epsilon,
- checkpoint=lid < checkpoint_layer_num,
- layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation
- use_dynamic_ntk_rope=use_dynamic_ntk_rope,
- residual_in_fp32=residual_in_fp32,
- device=device,
- apply_post_layer_norm=apply_post_layer_norm,
- fused_dropout_add_ln=False,
- qkv_bias=qkv_bias,
- o_bias=o_bias,
- mlp_bias=mlp_bias,
- norm_type=norm_type,
- dropout_selective_checkpoint=dropout_selective_checkpoint,
- use_scaled_init=use_scaled_init,
- use_swiglu=use_swiglu,
- qk_interleaved=qk_interleaved,
- attn_wqkv_init_std=attn_wqkv_init_std,
- attn_other_init_std=attn_other_init_std,
- ffn_uplayer_init_std=ffn_uplayer_init_std,
- ffn_other_init_std=ffn_other_init_std,
- init_type=init_type,
- rope_type=rope_type,
- rope_base=rope_base,
- rope_scaling_factor=rope_scaling_factor,
- use_sliding_window=use_sliding_window and lid >= max_window_layers,
- sliding_window=sliding_window,
- mlp_layer_fusion=mlp_layer_fusion,
- multiple_of=multiple_of,
- max_position_embeddings=max_position_embeddings,
- scale_attn_weights=scale_attn_weights,
- use_logn_attn=use_logn_attn,
- )
- for lid in range(num_layers)
- ]
- )
-
- if last:
- if not apply_post_layer_norm:
- self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon)
-
- self.output = new_linear(
- name="output",
- in_features=hidden_size,
- out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
- bias=False,
- device=device,
- dtype=dtype,
- is_reward=is_reward,
- weight_scale=embed_grad_scale,
- )
-
- for _, param in self.output.named_parameters():
- if init_type == "normal":
- normal_(std=out_head_init_std)(param)
- else:
- uniform_(std=out_head_init_std)(param)
-
- if extra_pred_tokens > 0:
- self.extra_pred_tokens = extra_pred_tokens
- assert not is_reward, "extra_pred_tokens > 0 means using multi token prediction, not implement for RLHF"
- self.extra_outputs = nn.ModuleList(
- [
- new_linear(
- name="output",
- in_features=hidden_size,
- out_features=vocab_size,
- bias=False,
- device=device,
- dtype=dtype,
- is_reward=is_reward,
- weight_scale=embed_grad_scale,
- )
- for _ in range(self.extra_pred_tokens)
- ]
- )
- for _, param in self.extra_outputs.named_parameters():
- if init_type == "normal":
- normal_(std=out_head_init_std)(param)
- else:
- uniform_(std=out_head_init_std)(param)
-
- self.parallel_output = parallel_output
-
- def forward(self, hidden_states=None, input_ids=None, **kwargs):
- # attention_mask: compute attention on the places where the value is 1
- if hasattr(self, "embed_tokens"):
- hidden_states = self.embed_tokens(input_ids)
- if self.embed_grad_scale != 1:
- hidden_states = (
- self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach()
- )
-
- for _, block in enumerate(self.layers):
- hidden_states = block(
- hidden_states,
- residual=None,
- **kwargs,
- )
-
- if hasattr(self, "norm"):
- hidden_states = self.norm(hidden_states.to(self.norm.weight.dtype))
- if hasattr(self, "extra_pred_tokens") and self.extra_pred_tokens > 0:
- extra_hidden_states_list = [self.extra_outputs[i](hidden_states) for i in range(self.extra_pred_tokens)]
- else:
- extra_hidden_states_list = None
- if hasattr(self, "output"):
- hidden_states = self.output(hidden_states)
-
- if extra_hidden_states_list is not None:
- return (hidden_states, extra_hidden_states_list)
-
- return hidden_states
-
- @staticmethod
- def load_hf_weights(folder: str, model: nn.Module) -> None:
- assert folder is not None, "Please specify the folder of the pretrained model"
- if gpc.is_rank_for_log():
- logger.info(f"Loading pretrained model from {folder}")
-
- fns = get_fns(folder)
- model_fns = [
- os.path.join(folder, fn)
- for fn in fns
- if (fn.endswith(".bin") and fn.startswith("pytorch_model"))
- or (fn.endswith(".safetensors") and fn.startswith("model"))
- ]
- model_fns.sort()
-
- state_dict = {}
- for model_fn in model_fns:
- state_dict.update(llm_load(model_fn, map_location="cpu"))
-
- tp_size = gpc.get_world_size(ParallelMode.TENSOR)
- tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
- wp_size = gpc.get_world_size(ParallelMode.WEIGHT)
- wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
- tp_mode = gpc.config.parallel.tensor["mode"]
- split_size = wp_size if tp_mode == "isp" else tp_size
- local_rank = wp_rank if tp_mode == "isp" else tp_rank
- row_dim = 0 if tp_mode == "isp" else 1
- if gpc.config.model.get("embed_split_hidden", True):
- embed_concat_dim = 1
- else:
- embed_concat_dim = 0
-
- new_state_dict = {}
-
- # embedding
- if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)):
- new_state_dict["embed_tokens.weight"] = torch.chunk(
- state_dict.pop("model.embed_tokens.weight"),
- split_size,
- dim=embed_concat_dim,
- )[local_rank]
-
- for idx, i in enumerate(range(model.first_layer, model.last_layer)):
- layer_ids = i
-
- # attn
- state_dict[f"layers.{i}.attention.wq.weight"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"),
- split_size,
- dim=0,
- )[local_rank]
- state_dict[f"layers.{i}.attention.wq.bias"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.self_attn.q_proj.bias"),
- split_size,
- dim=0,
- )[local_rank]
- state_dict[f"layers.{i}.attention.wk.weight"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"),
- split_size,
- dim=0,
- )[local_rank]
- state_dict[f"layers.{i}.attention.wk.bias"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.self_attn.k_proj.bias"),
- split_size,
- dim=0,
- )[local_rank]
- state_dict[f"layers.{i}.attention.wv.weight"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"),
- split_size,
- dim=0,
- )[local_rank]
- state_dict[f"layers.{i}.attention.wv.bias"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.self_attn.v_proj.bias"),
- split_size,
- dim=0,
- )[local_rank]
- state_dict[f"layers.{i}.attention.wo.weight"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"),
- split_size,
- dim=row_dim,
- )[local_rank]
-
- # ffn
- state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"),
- split_size,
- dim=0,
- )[local_rank]
- state_dict[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"),
- split_size,
- dim=0,
- )[local_rank]
- state_dict[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk(
- state_dict.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"),
- split_size,
- dim=row_dim,
- )[local_rank]
-
- # attn norm
- state_dict[f"layers.{i}.attention_norm.weight"] = state_dict.pop(
- f"model.layers.{layer_ids}.input_layernorm.weight"
- )
- # ffn norm
- state_dict[f"layers.{i}.ffn_norm.weight"] = state_dict.pop(
- f"model.layers.{layer_ids}.post_attention_layernorm.weight"
- )
-
- # replace value within decoder layer
- for name in list(state_dict.keys()):
- if name.startswith(f"layers.{i}"):
- new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name)
-
- # output
- if gpc.is_last_rank(ParallelMode.PIPELINE):
- new_state_dict["output.weight"] = torch.chunk(
- state_dict.pop("lm_head.weight"),
- split_size,
- dim=0,
- )[local_rank]
- new_state_dict["norm.weight"] = state_dict.pop("model.norm.weight")
-
- missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
-
- if gpc.get_local_rank(ParallelMode.DATA) == 0:
- pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
- logger.info(
- f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
- f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
- )
-
- internlm_accelerator.empty_cache()
-
- @staticmethod
- def convert_internevo2hf_weights(src: str, tgt: str) -> None:
- model_config = gpc.config.model
- tp_mode = gpc.config.parallel.tensor["mode"]
- row_dim = 0 if tp_mode == "isp" else 1
-
- # load states
- states, num_shards = Qwen2.load_sharded_states(src)
-
- # convert state_dict
- state_dict = {}
- embedding_key_list = ["tok_embeddings.weight", "embed_tokens.weight", None]
- for layer_i in tqdm(range(model_config["num_layers"])):
- # attn norm, mlp norm
- state_dict.update(
- {
- f"model.layers.{layer_i}.input_layernorm.weight": states[0][
- f"layers.{layer_i}.attention_norm.weight"
- ].clone(),
- f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][
- f"layers.{layer_i}.ffn_norm.weight"
- ].clone(),
- }
- )
- # attn wqkv weight and bias
- state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = torch.cat(
- [states[i][f"layers.{layer_i}.attention.wq.weight"] for i in range(num_shards)],
- dim=0,
- )
- state_dict[f"model.layers.{layer_i}.self_attn.q_proj.bias"] = torch.cat(
- [states[i][f"layers.{layer_i}.attention.wq.bias"] for i in range(num_shards)],
- dim=0,
- )
- state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = torch.cat(
- [states[i][f"layers.{layer_i}.attention.wk.weight"] for i in range(num_shards)],
- dim=0,
- )
- state_dict[f"model.layers.{layer_i}.self_attn.k_proj.bias"] = torch.cat(
- [states[i][f"layers.{layer_i}.attention.wk.bias"] for i in range(num_shards)],
- dim=0,
- )
- state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
- [states[i][f"layers.{layer_i}.attention.wv.weight"] for i in range(num_shards)],
- dim=0,
- )
- state_dict[f"model.layers.{layer_i}.self_attn.v_proj.bias"] = torch.cat(
- [states[i][f"layers.{layer_i}.attention.wv.bias"] for i in range(num_shards)],
- dim=0,
- )
- # attn wo weight
- state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
- [states[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=row_dim
- )
-
- # mlp
- state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
- [states[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
- )
- state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
- [states[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=row_dim
- )
- state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
- [states[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
- )
-
- # embedding, head
- for embedding_key in embedding_key_list:
- if embedding_key in states[0]:
- break
- if embedding_key is None:
- raise KeyError("Cannot find embedding key!")
- if model_config["embed_split_hidden"]:
- embed_concat_dim = 1
- tok_emb_list = [states[i][embedding_key] for i in range(num_shards)]
- else:
- embed_concat_dim = 0
- _, size_1 = states[0][embedding_key].shape
- embdim_pertp = size_1 // num_shards
- tok_emb_list = [
- torch.concat(
- [
- states[tp][embedding_key][:, embdim_pertp * local_rank : embdim_pertp * (local_rank + 1)]
- for tp in range(num_shards)
- ],
- dim=0,
- )
- for local_rank in range(num_shards)
- ]
- state_dict.update(
- {
- "model.norm.weight": states[0]["norm.weight"],
- "model.embed_tokens.weight": torch.cat(tok_emb_list, dim=embed_concat_dim),
- "lm_head.weight": torch.cat([states[i]["output.weight"] for i in range(num_shards)], dim=0),
- },
- )
-
- # save state_dict to hf format
- shards, index = shard_checkpoint(state_dict, weights_name=SAFE_WEIGHTS_NAME)
- for shard_file, shard in shards.items():
- llm_save(save_path=os.path.join(tgt, shard_file), saved_obj=shard, metadata={"format": "pt"})
- if index is not None:
- # Save the index as well
- llm_save(save_path=os.path.join(tgt, SAFE_WEIGHTS_INDEX_NAME), saved_obj=index)
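`load_hf_weights` and `convert_internevo2hf_weights` above are mirror images of each other: column-parallel projections (wq/wk/wv, w1/w3, the lm head) are chunked along dim 0 per tensor-parallel rank, row-parallel projections (wo, w2) along dim 1 (dim 0 under the `isp` mode), and the export path concatenates the per-rank shards back along the same dimensions. A minimal round-trip sketch in plain `torch` with no distributed runtime (the 1024x1024 shapes and `tp_size` value are assumptions):

```python
import torch

tp_size = 4
w_q = torch.randn(1024, 1024)   # column-parallel: output features split along dim 0
w_o = torch.randn(1024, 1024)   # row-parallel: input features split along dim 1

q_shards = torch.chunk(w_q, tp_size, dim=0)   # what load_hf_weights keeps per rank
o_shards = torch.chunk(w_o, tp_size, dim=1)

# Concatenating the shards on the same dims recovers the original HF layout,
# which is what the convert_internevo2hf_weights path does with torch.cat.
assert torch.equal(torch.cat(q_shards, dim=0), w_q)
assert torch.equal(torch.cat(o_shards, dim=1), w_o)
```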
diff --git a/internlm/model/modeling_qwen2_moe.py b/internlm/model/modeling_qwen2_moe.py
deleted file mode 100644
index cfa98098a..000000000
--- a/internlm/model/modeling_qwen2_moe.py
+++ /dev/null
@@ -1,559 +0,0 @@
-# Copyright (c) InternLM. All rights reserved.
-import math
-from typing import Optional
-
-import torch
-from torch import nn
-
-from internlm.accelerator import get_accelerator
-from internlm.core.context import ParallelMode
-from internlm.core.context.parallel_context import global_context as gpc
-from internlm.initialize.initialize_tensor import (
- normal_,
- scaled_init_method_normal,
- scaled_init_method_uniform,
- uniform_,
-)
-from internlm.model.base_model import BaseModel
-from internlm.model.modules.embedding import Embedding1D
-from internlm.model.modules.linear import new_linear
-from internlm.model.modules.mha import SWA
-from internlm.model.modules.mlp import new_feed_forward
-from internlm.model.modules.norm import new_layer_norm
-from internlm.model.moe.moe import Qwen2MoE
-from internlm.model.utils import (
- convert_attn_args_to_kwargs,
- convert_attn_kwargs_to_args,
-)
-from internlm.solver.activation_checkpoint import activation_checkpoint
-from internlm.utils.logger import get_logger
-
-internlm_accelerator = get_accelerator()
-logger = get_logger(__file__)
-
-
-class Qwen2MoeDecoder(nn.Module):
- """
- 1D Packed Flash Qwen2 MoE Layer.
-
- Args:
- hidden_size (int): The hidden size of model. 768 by default.
- num_attention_heads (int): The number of attention heads. 12 by default.
- mlp_ratio (int): The ratio of MLP layers. 4 by default.
- attn_drop_rate (float): The dropout rate of attention module. 0 by default.
- drop_rate (float): The dropout rate of the input hidden state. 0.0 by default.
- dtype (torch.dtype): Type of data. torch.float by default.
- layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
- checkpoint (bool): Whether to use checkpointing to save VRAM. True by default.
- layer_idx (int): The index of current layer. 0 by default.
- residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
- device (Optional[Union[str, torch.device]]): The device will be used.
- norm_type (str): Use RMS norm or layernorm. "rmsnorm" by default.
- attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default,
- attn_other_init_std (float): std used to init attn_other weight. 0.02 by default,
- ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu
- otherwise init fc1 weight in ffn. 0.02 by default,
- ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default,
- init_type (str): Initialization type. Use uniform or normal. "normal" by default,
- rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
- multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2.
- """
-
- def __init__(
- self,
- hidden_size: int = 768,
- num_attention_heads: int = 12,
- num_kv_attention_heads: int = 12,
- mlp_ratio: int = 4,
- attn_drop_rate: float = 0,
- drop_rate: float = 0.0,
- max_position_embeddings: int = 2048,
- dtype: torch.dtype = torch.float,
- layer_norm_epsilon: float = 1e-6,
- checkpoint: bool = False,
- layer_idx: int = 0,
- use_dynamic_ntk_rope: bool = False,
- residual_in_fp32: bool = False,
- device: Optional[torch.device] = None,
- apply_post_layer_norm: bool = False,
- fused_dropout_add_ln: bool = True,
- qkv_bias=True,
- o_bias=False,
- mlp_bias=False,
- norm_type: str = "rmsnorm",
- qk_interleaved: bool = False,
- dropout_selective_checkpoint: bool = True,
- use_scaled_init: bool = True,
- use_swiglu: bool = True,
- attn_wqkv_init_std: float = 0.02,
- attn_other_init_std: float = 0.02,
- ffn_uplayer_init_std: float = 0.02,
- ffn_other_init_std: float = 0.02,
- init_type: str = "normal",
- rope_type: str = "normal",
- rope_base: int = 10000,
- rope_scaling_factor: float = 1.0,
- use_sliding_window: bool = False,
- sliding_window: int = None,
- mlp_layer_fusion: bool = False,
- multiple_of: int = 256,
- scale_attn_weights: bool = False, # Qwen1
- use_logn_attn: bool = False, # Qwen1
- num_experts: int = 1,
- top_k: int = 1,
- num_shared_experts: int = 0,
- moe_layer_kwargs: dict = None,
- ):
- super().__init__()
- self.checkpoint = checkpoint
- # dropout selective checkpoint can only be enabled when checkpoint is disabled.
- self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
- self.layer_idx = layer_idx
- self.prenorm = not apply_post_layer_norm
- assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here"
- self.fused_dropout_add_ln = fused_dropout_add_ln
- self.attn_wqkv_init_std = attn_wqkv_init_std
- self.attn_other_init_std = attn_other_init_std
- self.ffn_uplayer_init_std = ffn_uplayer_init_std
- self.ffn_other_init_std = ffn_other_init_std
-
- head_dim = hidden_size // num_attention_heads
-
- if scale_attn_weights:
- softmax_scale = None
- else:
- softmax_scale = 1 / math.sqrt(head_dim)
- self.attention = SWA(
- embed_dim=hidden_size,
- num_heads=num_attention_heads,
- num_kv_heads=num_kv_attention_heads,
- dropout=attn_drop_rate,
- max_position_embeddings=max_position_embeddings,
- softmax_scale=softmax_scale,
- causal=True,
- layer_idx=layer_idx,
- use_dynamic_ntk_rope=use_dynamic_ntk_rope,
- rotary_emb_dim=head_dim,
- rotary_emb_scale_base=0,
- device=device,
- dtype=dtype,
- qk_interleaved=qk_interleaved,
- qkv_bias=qkv_bias,
- o_bias=o_bias,
- rope_type=rope_type,
- rope_base=rope_base,
- rope_scaling_factor=rope_scaling_factor,
- use_sliding_window=use_sliding_window,
- sliding_window=sliding_window,
- use_logn_attn=use_logn_attn,
- )
-
- self.dropout1 = nn.Dropout(drop_rate)
- self.dropout2 = nn.Dropout(drop_rate)
- self.attention_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon)
- self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon)
-
- self.num_experts = num_experts
- if num_experts <= 1: # dense, not MoE
- self.feed_forward = new_feed_forward(
- hidden_size,
- int(hidden_size * mlp_ratio),
- out_features=hidden_size,
- bias=mlp_bias,
- device=device,
- dtype=dtype,
- mlp_layer_fusion=mlp_layer_fusion,
- multiple_of=multiple_of,
- activation_type="swiglu" if use_swiglu else "gelu",
- )
- else:
- # replace mlp by MoE module. The expert in MoE is a FeedForward module.
- # mlp_cls = get_mlp_cls(self.tp_mode)
- self.feed_forward = Qwen2MoE(
- hidden_size,
- int(hidden_size * mlp_ratio),
- out_features=hidden_size,
- num_experts=num_experts,
- top_k=top_k,
- num_shared_experts=num_shared_experts,
- moe_layer_kwargs=moe_layer_kwargs,
- device=device,
- dtype=dtype,
- mlp_layer_fusion=mlp_layer_fusion,
- multiple_of=multiple_of,
- activation_type="swiglu" if use_swiglu else "gelu",
- )
-
- self.use_swiglu = use_swiglu
- self.use_scaled_init = use_scaled_init
- self.residual_in_fp32 = residual_in_fp32 # only makes sense when using prenorm
- self.return_residual = False
-
- if init_type == "normal":
- self.init_func = normal_
- self.scaled_init_func = scaled_init_method_normal
- else:
- self.init_func = uniform_
- self.scaled_init_func = scaled_init_method_uniform
-
- self.reset_parameters()
-
- def reset_parameters(self):
- with torch.no_grad():
- for name, param in self.attention.named_parameters():
- if param.ndim == 1:
- param.data.zero_()
- elif "wq" in name or "wk" in name or "wv" in name:
- self.init_func(std=self.attn_wqkv_init_std)(param.data)
- elif self.use_scaled_init: # wo
- self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
- else:
- self.init_func(std=self.attn_other_init_std)(param.data)
-
- for name, param in self.feed_forward.named_parameters():
- if self.use_swiglu:
- if self.use_scaled_init and "w2" in name:
- self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
- else:
- # candidate: w1, w3, fused_w1_w3
- self.init_func(
- std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std
- )(param.data)
- else:
- if self.use_scaled_init and "fc1" not in name:
- self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
- else:
- self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)(
- param.data
- )
-
- def forward(self, hidden_states, residual=None, **kwargs):
- if self.checkpoint and self.training:
- args = convert_attn_kwargs_to_args(kwargs)
- return activation_checkpoint(self._forward, False, hidden_states, residual, *args)
- else:
- return self._forward(hidden_states, residual, **kwargs)
-
- def _forward(self, hidden_states, residual, *args, **kwargs):
- r"""Pass the input through the encoder layer.
-
- Args:
- hidden_states: the sequence to the encoder layer (required).
- residual: hidden_states = Attn/MLP(LN(residual))
- cu_seqlens: 1d LongTensor, len(cu_seqlens) = number of packed sequences + 1
- indexes: same length as hidden_states, standing for the position of each token in its sequence
- """
- if self.prenorm:
-
- def _dropout_and_norm_attn(_residual, _hidden_states):
- _dropped = self.dropout1(_hidden_states)
- _residual = (_dropped + _residual) if _residual is not None else _dropped
- _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype))
-
- return _residual, _hidden_states
-
- if self.dropout_selective_checkpoint:
- residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states)
- else:
- residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states)
-
- if self.residual_in_fp32:
- residual = residual.to(torch.float32)
-
- mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs)
- hidden_states = self.attention(hidden_states, **mixer_kwargs)
-
- if not isinstance(self.feed_forward, nn.Identity):
- if not self.fused_dropout_add_ln:
-
- def _dropout_and_norm_ffn(_residual, _hidden_states):
- _dropped = self.dropout2(_hidden_states)
- _residual = (_dropped + _residual) if _residual is not None else _dropped
- _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype))
-
- return _residual, _hidden_states
-
- if self.dropout_selective_checkpoint:
- residual, hidden_states = activation_checkpoint(
- _dropout_and_norm_ffn, False, residual, hidden_states
- )
- else:
- residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states)
-
- if self.residual_in_fp32:
- residual = residual.to(torch.float32)
-
- if self.num_experts <= 1: # dense mlp output
- hidden_states = self.feed_forward(hidden_states)
- moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype)
- else: # MoE output
- hidden_states, moe_loss, _ = self.feed_forward(hidden_states)
-
- return hidden_states + residual, moe_loss
- else:
- assert residual is None
-
- mixer_out = self.attention(hidden_states, **kwargs)
- if self.return_residual: # mixer out is actually a pair here
- mixer_out, hidden_states = mixer_out
- hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to(
- dtype=self.attention_norm.weight.dtype
- )
- if not isinstance(self.feed_forward, nn.Identity):
- if self.num_experts <= 1: # dense mlp output
- mlp_out = self.feed_forward(hidden_states)
- moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype)
- else: # MoE output
- mlp_out, moe_loss, _ = self.feed_forward(hidden_states)
-
- if self.return_residual: # mlp out is actually a pair here
- mlp_out, hidden_states = mlp_out
- hidden_states = self.ffn_norm((self.dropout2(mlp_out)) + hidden_states).to(
- dtype=self.ffn_norm.weight.dtype
- )
- return hidden_states, moe_loss
-
-
-class Qwen2Moe(BaseModel):
- """
- 1D Packed Flash Qwen2 MoE.
-
- Args:
- num_layers (int): The number of layer. 12 by default.
- hidden_size (int): The size of hidden state. 768 by default.
- num_attention_heads (int): The number of attention head. 12 by default.
- vocab_size (int): The size of vocabulary. 50304 by default.
- mlp_ratio (int): The ratio of MLP layers. 4 by default.
- attn_drop_rate (float): The dropout rate of attention module. 0.0 by default.
- drop_rate (float): The dropout rate of input hidden state. 0.0 by default.
- dtype (torch.dtype): The type of data. torch.float by default.
- checkpoint (float): The proportion of layers to apply activation checkpointing to. 1.0 by default.
- layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default.
- first (bool): Whether input embedding layer or not. False by default.
- last (bool): Whether output embedding layer or not. False by default.
- embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
- parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
- start_layer_idx (int): The index of start layer in the pipeline. 0 by default.
- device (Optional[Union[str, torch.device]]): The device to be used. None by default.
- residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
- norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
- embedding_init_std (float): std used to init embedding weight. 0.02 by default,
- attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default,
- attn_other_init_std (float): std used to init attn_other weight. 0.02 by default,
- ffn_uplayer_init_std (float): std used to init the w1 and w2 weights in ffn when using glu,
- otherwise the fc1 weight in ffn. 0.02 by default,
- ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default,
- out_head_init_std (float): std used to init output lmhead weight. 0.02 by default,
- init_type (str): Initialization type. Use uniform or normal. "normal" by default,
- extra_pred_tokens (int): The number of extra output heads for multi-token prediction. 0 by default.
- rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
- multiple_of (int): Round the SwiGLU hidden layer size up to a multiple of this value (a large power of 2). 256 by default.
- """
-
- def __init__(
- self,
- num_layers: int = 12,
- hidden_size: int = 768,
- num_attention_heads: int = 12,
- num_kv_attention_heads: int = 12,
- vocab_size: int = 50304,
- mlp_ratio: int = 4,
- attn_drop_rate: float = 0.0,
- drop_rate: float = 0.0,
- max_position_embeddings: int = 2048,
- dtype: torch.dtype = torch.float,
- checkpoint: float = 1.0,
- layer_norm_epsilon: float = 1e-5,
- first: bool = False,
- last: bool = False,
- embed_grad_scale: float = 0.1,
- parallel_output: bool = True,
- start_layer_idx: int = 0,
- use_dynamic_ntk_rope: bool = False,
- device: Optional[torch.device] = None,
- apply_post_layer_norm=False,
- qkv_bias=True,
- o_bias=False,
- mlp_bias=False,
- residual_in_fp32: bool = False,
- norm_type: str = "rmsnorm",
- qk_interleaved: bool = False,
- is_reward: bool = False,
- dropout_selective_checkpoint: bool = True,
- use_scaled_init: bool = True,
- use_swiglu: bool = True,
- embedding_init_std: float = 0.02,
- attn_wqkv_init_std: float = 0.02,
- attn_other_init_std: float = 0.02,
- ffn_uplayer_init_std: float = 0.02,
- ffn_other_init_std: float = 0.02,
- out_head_init_std: float = 0.02,
- init_type: str = "normal",
- extra_pred_tokens: int = 0,
- rope_type: str = "normal",
- rope_base: int = 10000,
- rope_scaling_factor: float = 1.0,
- use_sliding_window: bool = False,
- max_window_layers: int = 0,
- sliding_window: int = None,
- mlp_layer_fusion: bool = False,
- multiple_of: int = 256,
- scale_attn_weights: bool = False, # Qwen1
- use_logn_attn: bool = False, # Qwen1
- moe_type: str = None, # pylint: disable=W0613
- num_experts: int = 1,
- top_k: int = 1,
- num_shared_experts: int = 0,
- moe_layer_kwargs: dict = None,
- ):
- super().__init__()
-
- self.embed_grad_scale = embed_grad_scale
-
- checkpoint_layer_num = int(num_layers * checkpoint)
-
- if first:
- self.embed_tokens = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
- for _, param in self.embed_tokens.named_parameters():
- if init_type == "normal":
- normal_(std=embedding_init_std)(param)
- else:
- uniform_(std=embedding_init_std)(param)
-
- self.layers = nn.ModuleList(
- [
- Qwen2MoeDecoder(
- hidden_size=hidden_size,
- num_attention_heads=num_attention_heads,
- num_kv_attention_heads=num_kv_attention_heads,
- mlp_ratio=mlp_ratio,
- attn_drop_rate=attn_drop_rate,
- drop_rate=drop_rate,
- dtype=dtype,
- layer_norm_epsilon=layer_norm_epsilon,
- checkpoint=lid < checkpoint_layer_num,
- layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation
- use_dynamic_ntk_rope=use_dynamic_ntk_rope,
- residual_in_fp32=residual_in_fp32,
- device=device,
- apply_post_layer_norm=apply_post_layer_norm,
- fused_dropout_add_ln=False,
- qkv_bias=qkv_bias,
- o_bias=o_bias,
- mlp_bias=mlp_bias,
- norm_type=norm_type,
- dropout_selective_checkpoint=dropout_selective_checkpoint,
- use_scaled_init=use_scaled_init,
- use_swiglu=use_swiglu,
- qk_interleaved=qk_interleaved,
- attn_wqkv_init_std=attn_wqkv_init_std,
- attn_other_init_std=attn_other_init_std,
- ffn_uplayer_init_std=ffn_uplayer_init_std,
- ffn_other_init_std=ffn_other_init_std,
- init_type=init_type,
- rope_type=rope_type,
- rope_base=rope_base,
- rope_scaling_factor=rope_scaling_factor,
- use_sliding_window=use_sliding_window and lid >= max_window_layers,
- sliding_window=sliding_window,
- mlp_layer_fusion=mlp_layer_fusion,
- multiple_of=multiple_of,
- max_position_embeddings=max_position_embeddings,
- scale_attn_weights=scale_attn_weights,
- use_logn_attn=use_logn_attn,
- num_experts=num_experts,
- top_k=top_k,
- num_shared_experts=num_shared_experts,
- moe_layer_kwargs=moe_layer_kwargs,
- )
- for lid in range(num_layers)
- ]
- )
-
- if last:
- if not apply_post_layer_norm:
- self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon)
-
- self.output = new_linear(
- name="output",
- in_features=hidden_size,
- out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
- bias=False,
- device=device,
- dtype=dtype,
- is_reward=is_reward,
- weight_scale=embed_grad_scale,
- )
-
- for _, param in self.output.named_parameters():
- if init_type == "normal":
- normal_(std=out_head_init_std)(param)
- else:
- uniform_(std=out_head_init_std)(param)
-
- if extra_pred_tokens > 0:
- self.extra_pred_tokens = extra_pred_tokens
- assert not is_reward, "extra_pred_tokens > 0 means using multi token prediction, which is not implemented for RLHF"
- self.extra_outputs = nn.ModuleList(
- [
- new_linear(
- name="output",
- in_features=hidden_size,
- out_features=vocab_size,
- bias=False,
- device=device,
- dtype=dtype,
- is_reward=is_reward,
- weight_scale=embed_grad_scale,
- )
- for _ in range(self.extra_pred_tokens)
- ]
- )
- for _, param in self.extra_outputs.named_parameters():
- if init_type == "normal":
- normal_(std=out_head_init_std)(param)
- else:
- uniform_(std=out_head_init_std)(param)
-
- self.parallel_output = parallel_output
-
- def forward(self, hidden_states=None, input_ids=None, **kwargs):
- # attention_mask: compute attention only at positions where the value is 1
- # the old condition may fail when using a shared embedding
- if gpc.is_pipeline_first_stage() and input_ids is not None:
- hidden_states = self.embed_tokens(input_ids)
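- # GLM-130B style gradient shrink: the forward value is unchanged, but the embedding gradient is scaled by embed_grad_scale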
- if self.embed_grad_scale != 1:
- hidden_states = (
- self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach()
- )
-
- moe_losses = []
- for _, block in enumerate(self.layers):
- hidden_states, moe_loss = block(
- hidden_states,
- residual=None,
- **kwargs,
- )
- moe_losses.append(moe_loss)
-
- if hasattr(self, "norm"):
- hidden_states = self.norm(hidden_states.to(self.norm.weight.dtype))
- if hasattr(self, "extra_pred_tokens") and self.extra_pred_tokens > 0:
- extra_hidden_states_list = [self.extra_outputs[i](hidden_states) for i in range(self.extra_pred_tokens)]
- else:
- extra_hidden_states_list = None
- if hasattr(self, "output"):
- hidden_states = self.output(hidden_states)
-
- if extra_hidden_states_list is not None:
- return (hidden_states, extra_hidden_states_list), moe_losses
-
- return hidden_states, moe_losses
-
- @staticmethod
- def load_hf_weights(folder: str, model: nn.Module) -> None:
- raise NotImplementedError
-
- @staticmethod
- def convert_internevo2hf_weights(src: str, tgt: str) -> None:
- raise NotImplementedError
diff --git a/internlm/monitor/__init__.py b/internlm/monitor/__init__.py
index 2bcfa2ccf..6f5e511f3 100644
--- a/internlm/monitor/__init__.py
+++ b/internlm/monitor/__init__.py
@@ -1,9 +1,15 @@
-from .monitor import initialize_monitor_manager, internevo_monitor, send_alert_message
-from .utils import set_env_var
+from .alert import send_feishu_msg_with_webhook
+from .monitor import (
+ initialize_monitor_manager,
+ internevo_monitor,
+ monitor_manager,
+ send_alert_message,
+)
__all__ = [
"send_alert_message",
"initialize_monitor_manager",
- "set_env_var",
"internevo_monitor",
+ "monitor_manager",
+ "send_feishu_msg_with_webhook",
]
diff --git a/internlm/monitor/monitor.py b/internlm/monitor/monitor.py
index fc33de62a..252a0380b 100644
--- a/internlm/monitor/monitor.py
+++ b/internlm/monitor/monitor.py
@@ -12,10 +12,10 @@
from internlm.accelerator.abstract_accelerator import get_accelerator
from internlm.core.context import global_context as gpc
-from internlm.monitor.alert import send_feishu_msg_with_webhook
-from internlm.utils.common import SingletonMeta
+from internlm.monitor import send_feishu_msg_with_webhook
+from internlm.utils.common import SingletonMeta, set_env_var
-from .utils import get_job_key, set_env_var
+from .utils import get_job_key
logger = logging.getLogger(__file__)
internlm_accelerator = get_accelerator()
diff --git a/internlm/monitor/utils.py b/internlm/monitor/utils.py
index 34360b521..0bdd3db2e 100644
--- a/internlm/monitor/utils.py
+++ b/internlm/monitor/utils.py
@@ -6,10 +6,6 @@ def now_time():
return datetime.now().strftime("%b%d_%H-%M-%S")
-def set_env_var(key, value):
- os.environ[str(key)] = str(value)
-
-
def get_job_id():
job_id = "none"
if os.getenv("SLURM_JOB_ID") is not None:
diff --git a/internlm/solver/activation_checkpoint.py b/internlm/solver/activation_checkpoint.py
index 2b5c9e4ed..93d7a1ba1 100644
--- a/internlm/solver/activation_checkpoint.py
+++ b/internlm/solver/activation_checkpoint.py
@@ -10,16 +10,10 @@
from torch.utils.checkpoint import check_backward_validity, detach_variable
from internlm.accelerator import get_accelerator
-from internlm.core.context.parallel_context import global_context as gpc
-from internlm.core.context.random import (
- get_current_mode,
- get_states,
- set_mode,
- set_seed_states,
- sync_states,
-)
-
-from ..utils.common import get_current_device
+from internlm.core.context import get_current_mode, get_states
+from internlm.core.context import global_context as gpc
+from internlm.core.context import set_mode, set_seed_states, sync_states
+from internlm.utils.common import get_current_device
internlm_accelerator = get_accelerator()
diff --git a/internlm/solver/optimizer/__init__.py b/internlm/solver/optimizer/__init__.py
index 55070fc33..7f848c9bd 100644
--- a/internlm/solver/optimizer/__init__.py
+++ b/internlm/solver/optimizer/__init__.py
@@ -1,8 +1,9 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
+from .base_optimizer import BaseOptimizer
from .fsdp_optimizer import FSDPadaptOptimizer
from .hybrid_zero_optim import HybridZeroOptimizer
from .hybrid_zero_optim_v2 import HybridZeroOptimizer_v2
-__all__ = ["FSDPadaptOptimizer", "HybridZeroOptimizer", "HybridZeroOptimizer_v2"]
+__all__ = ["FSDPadaptOptimizer", "HybridZeroOptimizer", "BaseOptimizer", "HybridZeroOptimizer_v2"]
diff --git a/internlm/solver/optimizer/fsdp_optimizer.py b/internlm/solver/optimizer/fsdp_optimizer.py
index 94cc411c6..d4f3cc811 100644
--- a/internlm/solver/optimizer/fsdp_optimizer.py
+++ b/internlm/solver/optimizer/fsdp_optimizer.py
@@ -8,15 +8,16 @@
from torch.optim import Optimizer
from internlm.accelerator import get_accelerator
-from internlm.core.context import Config, ParallelMode
+from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.solver.optimizer.base_optimizer import BaseOptimizer
+from internlm.solver.optimizer import BaseOptimizer
from internlm.solver.optimizer.utils import (
DynamicGradScaler,
get_norm,
release_param_grad,
)
-from internlm.utils.common import get_tensor_norm, move_norm_to_cuda
+from internlm.utils.common import get_current_device, get_tensor_norm, move_norm_to_cuda
+from internlm.utils.config import Config
from internlm.utils.logger import get_logger
try:
@@ -36,6 +37,7 @@
def compute_norm(
gradients: Iterable[torch.Tensor],
parameters: Iterable[torch.Tensor],
+ zero_mode,
) -> float:
"""Get L2 norm
Arguments:
@@ -60,7 +62,15 @@ def compute_norm(
if DTENSOR_SUPPORTED and isinstance(total_norm, DTensor):
total_norm = total_norm.full_tensor()
- dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.GLOBAL))
+ if gpc.is_using_parallel_mode(zero_mode):
+ dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(zero_mode))
+
+ # Average the norm across expert-parallel ranks: MoE params are not synced by the regular allreduce,
+ # while the model- and zero-parallel groups have already been reduced above.
+ if zero_mode == ParallelMode.EXPERT_DATA:
+ scaled_norm = torch.tensor(
+ total_norm * 1.0 / float(gpc.get_world_size(ParallelMode.EXPERT)),
+ device=get_current_device(),
+ dtype=torch.float,
+ )
+ dist.all_reduce(scaled_norm, group=gpc.get_group(ParallelMode.EXPERT))
+ total_norm = scaled_norm.item()
if torch.is_tensor(total_norm):
total_norm = total_norm.item()
@@ -111,10 +121,14 @@ def __init__(
# fp16 share mem space with model.FlatParam, fp32 share mem space with optim.param_group
self._fp16_param_groups = dict()
self._fp32_param_tensor_groups = dict()
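+ # remember each param group's parallel mode (param_group["optimizer_mode"]); it is used when reducing grad norms in compute_norm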
+ self._broadcast_parallel_mode = []
# init fp16 and fp32 params
for group_idx, param_group in enumerate(self.optim.param_groups):
group_params = param_group["params"]
+
+ zero_mode = param_group["optimizer_mode"]
+ self._broadcast_parallel_mode.append(zero_mode)
# fp16 FlatParam storage
self._fp16_param_groups[group_idx] = group_params
@@ -141,7 +155,7 @@ def _compute_norm_with_fsdp_flatten(self, group_id):
norm_group = 0
if len(params) <= 0 or len(gradients) <= 0:
return norm_group
- norm_group = compute_norm(gradients=gradients, parameters=params)
+ norm_group = compute_norm(
+ gradients=gradients, parameters=params, zero_mode=self._broadcast_parallel_mode[group_id]
+ )
return norm_group
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 6620dda2e..d45c63e08 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -8,22 +8,22 @@
import torch
import torch.distributed as dist
+from torch._utils import _flatten_dense_tensors
from torch.optim import Optimizer
from internlm.accelerator import AcceleratorType, get_accelerator
-from internlm.core.context import Config, ParallelMode
-from internlm.core.context import global_context as gpc
-from internlm.core.context.parallel_context import (
+from internlm.core.context import (
IS_REPLICA_EXPERT_DATA_PARALLEL,
IS_REPLICA_ZERO_PARALLEL,
IS_TENSOR_EXPERT_DATA_PARALLEL,
IS_TENSOR_ZERO_PARALLEL,
IS_WEIGHT_EXPERT_DATA_PARALLEL,
IS_WEIGHT_ZERO_PARALLEL,
+ ParallelMode,
)
-from internlm.core.parallel.comm.isp import ISPCommunicatorWrapper
-from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler
-from internlm.model.modules.utils import is_moe_param
+from internlm.core.context import global_context as gpc
+from internlm.core.parallel.comm import ISPCommunicatorWrapper, ParamAsyncBcastHandler
+from internlm.model.model_ops.modules.utils import is_moe_param
from internlm.monitor import send_alert_message
from internlm.solver.optimizer.store import (
BucketStore,
@@ -33,7 +33,6 @@
)
from internlm.solver.optimizer.utils import (
DynamicGradScaler,
- flatten,
get_grad_accumulate_object,
has_inf_or_nan,
reduce_tensor,
@@ -42,6 +41,7 @@
sync_param,
)
from internlm.utils.common import get_current_device
+from internlm.utils.config import Config
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import is_using_isp, should_reduce_replica_param
@@ -210,7 +210,7 @@ def __init__(
if rank not in self.param_group_no_params_ranks[group_id]:
tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
with torch.no_grad():
- flat_tensor = flatten(tensor_list)
+ flat_tensor = _flatten_dense_tensors(tensor_list)
flat_tensor = flat_tensor.data.to(get_current_device())
self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor)
sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list)
@@ -288,7 +288,7 @@ def _partition_param_list(self, group_id, param_group):
if group_id not in self.meta_for_zero[rank_to_go]:
self.meta_for_zero[rank_to_go][group_id] = {}
- from internlm.train.pipeline import map_fqn_local_to_global
+ from internlm.initialize.initialize_model import map_fqn_local_to_global
global_fqn = map_fqn_local_to_global[param.fqn] if param.fqn in map_fqn_local_to_global else param.fqn
self.meta_for_zero[rank_to_go][group_id][global_fqn] = {
@@ -839,7 +839,7 @@ def _step(self, closure=None, norms=None):
# create flat gradient for the flat fp32 params
gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
with torch.no_grad():
- flat_fp16_avg_grads = flatten(gradients)
+ flat_fp16_avg_grads = _flatten_dense_tensors(gradients)
self._grad_store.reset_average_gradients_by_group(group_id)
gradients = None # release cuda memory
diff --git a/internlm/solver/optimizer/hybrid_zero_optim_v2.py b/internlm/solver/optimizer/hybrid_zero_optim_v2.py
index 11158aba5..c167b53eb 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim_v2.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim_v2.py
@@ -5,17 +5,18 @@
import torch
import torch.distributed as dist
+from torch._utils import _flatten_dense_tensors
from torch.optim import Optimizer
-from internlm.core.context import Config, ParallelMode
-from internlm.core.context import global_context as gpc
-from internlm.core.context.parallel_context import (
+from internlm.core.context import (
IS_REPLICA_ZERO_PARALLEL,
IS_TENSOR_EXPERT_DATA_PARALLEL,
IS_TENSOR_ZERO_PARALLEL,
IS_WEIGHT_ZERO_PARALLEL,
+ ParallelMode,
)
-from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler
+from internlm.core.context import global_context as gpc
+from internlm.core.parallel.comm import ParamAsyncBcastHandler
from internlm.monitor import send_alert_message
from internlm.solver.optimizer.store import (
BucketStore_v2,
@@ -24,12 +25,12 @@
)
from internlm.solver.optimizer.utils import (
DynamicGradScaler,
- flatten,
reduce_tensor,
release_param_grad,
sync_param,
)
from internlm.utils.common import get_current_device
+from internlm.utils.config import Config
from internlm.utils.logger import get_logger
from internlm.utils.parallel import is_using_isp, is_using_sequence_parallel
@@ -669,7 +670,9 @@ def step(self, closure=None):
# Update working parameters
for working_param, all_splited_param in zip(working_params_list[gather_idx], all_splited_param_list):
- working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].view_as(working_param))
+ working_param.data.copy_(
+ _flatten_dense_tensors(all_splited_param)[: working_param.numel()].view_as(working_param)
+ )
for group_id in range(self.num_param_groups):
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py
index 5aeae887f..b3532dfb0 100644
--- a/internlm/solver/optimizer/utils.py
+++ b/internlm/solver/optimizer/utils.py
@@ -8,7 +8,7 @@
import torch
import torch.distributed as dist
from torch import Tensor
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from torch._utils import _unflatten_dense_tensors
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
@@ -39,14 +39,6 @@
inf = math.inf
-def flatten(input_):
- return _flatten_dense_tensors(input_)
-
-
-def unflatten(flat, tensors):
- return _unflatten_dense_tensors(flat, tensors)
-
-
def get_grad_accumulate_object(tensor):
"""
Return the AccumulateGrad of the input tensor
@@ -176,7 +168,7 @@ def sync_param(flat_tensor, tensor_list):
:type flat_tensor: torch.Tensor
:type tensor_list: List[torch.Tensor]
"""
- updated_params = unflatten(flat_tensor, tensor_list)
+ updated_params = _unflatten_dense_tensors(flat_tensor, tensor_list)
# update the tensor data
for p, q in zip(tensor_list, updated_params):
diff --git a/internlm/train/__init__.py b/internlm/train/__init__.py
deleted file mode 100644
index f3c680da4..000000000
--- a/internlm/train/__init__.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from .pipeline import (
- get_scheduler_hooks,
- initialize_llm_profile,
- initialize_model_and_parallel_communicator,
- initialize_optimizer,
- initialize_parallel_communicator,
- load_new_batch,
- record_current_batch_training_metrics,
- set_fp32_attr_for_model,
- set_parallel_attr_for_param_groups,
-)
-
-__all__ = [
- "initialize_llm_profile",
- "initialize_model_and_parallel_communicator",
- "initialize_parallel_communicator",
- "initialize_optimizer",
- "load_new_batch",
- "record_current_batch_training_metrics",
- "get_scheduler_hooks",
- "set_parallel_attr_for_param_groups",
- "set_fp32_attr_for_model",
-]
diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py
deleted file mode 100644
index 0586cafc7..000000000
--- a/internlm/train/pipeline.py
+++ /dev/null
@@ -1,1342 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import collections
-import functools
-import itertools
-import math
-import os
-import time
-from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union
-
-import torch
-import torch.distributed as dist
-from torch import nn
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp.fully_sharded_data_parallel import (
- BackwardPrefetch,
- ShardingStrategy,
-)
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
-from torch.utils.data import DataLoader
-
-from internlm.accelerator import AcceleratorType, get_accelerator
-from internlm.checkpoint.utils import init_fsdp_v1
-from internlm.core.context import (
- IS_REPLICA_EXPERT_DATA_PARALLEL,
- IS_REPLICA_ZERO_PARALLEL,
- IS_TENSOR_EXPERT_DATA_PARALLEL,
- IS_TENSOR_ZERO_PARALLEL,
- IS_WEIGHT_EXPERT_DATA_PARALLEL,
- IS_WEIGHT_ZERO_PARALLEL,
- ParallelMode,
-)
-from internlm.core.context import global_context as gpc
-from internlm.core.context.random import set_mode
-from internlm.core.naive_amp import (
- NaiveAMPModel,
- set_fp32_attr_to_module,
- unwrap_naive_amp,
-)
-from internlm.core.parallel.comm.isp import (
- EmbeddingWeightParallelCommunicator,
- HeadWeightParallelCommunicator,
- ISPCommModelConfig,
- ISPCommunicator,
- ISPCommunicatorSchedulerHook,
- ISPCommunicatorWrapper,
-)
-from internlm.core.parallel.comm.tensor import (
- EmbeddingSequenceParallelCommunicator,
- EmbeddingTensorParallelCommunicator,
- HeadSequenceParallelCommunicator,
- HeadTensorParallelCommunicator,
- LinearRole,
- MoESequenceParallelCommunicator,
- SequenceParallelCommunicator,
- TensorParallelCommunicator,
-)
-from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler
-from internlm.core.trainer import TrainState
-from internlm.data.utils import unpack_type_ids
-from internlm.model.builder import create_model
-from internlm.model.metrics import SchedulerMetricHook
-from internlm.model.modules.embedding import Embedding1D
-from internlm.model.modules.linear import (
- ColumnParallelLinear,
- GroupedColumnLinear,
- GroupedRowLinear,
- GroupedWPLinear,
- ParallelLinearWithCommExt,
- RewardModelLinear,
- RowParallelLinear,
- ScaleColumnParallelLinear,
- new_linear,
-)
-from internlm.model.modules.norm import new_layer_norm
-from internlm.model.moe import Experts, MoE
-from internlm.model.moe.moe import Qwen2MoE
-from internlm.model.ops.norm import RMSNorm
-from internlm.model.registry import register_model_initializer
-from internlm.monitor import set_env_var
-from internlm.monitor.monitor import monitor_manager as mm
-from internlm.solver.optimizer import (
- FSDPadaptOptimizer,
- HybridZeroOptimizer,
- HybridZeroOptimizer_v2,
-)
-from internlm.solver.optimizer.compatible_adamw import new_compatible_adamw
-from internlm.solver.schedulers.beta2_scheduler import Beta2Scheduler
-from internlm.solver.schedulers.lr_scheduler import FineTuneCosineAnnealingWarmupLR
-from internlm.train.utils import create_param_groups, map_param_block, timeout_input
-from internlm.utils.common import DummyProfile, SchedulerHook, get_current_device
-from internlm.utils.lazy import LazyObject
-from internlm.utils.logger import get_logger
-from internlm.utils.megatron_timers import megatron_timer as timer
-from internlm.utils.parallel import (
- is_replica_expert_data_parallel_parameter,
- is_replica_zero_parallel_parameter,
- is_tensor_expert_data_parallel_parameter,
- is_tensor_zero_parallel_parameter,
- is_using_fsdp,
- is_using_hf,
- is_using_isp,
- is_weight_expert_data_parallel_parameter,
- is_weight_zero_parallel_parameter,
- sync_model_param,
- sync_model_replica_param_group,
-)
-from internlm.utils.timeout import llm_timeout
-from internlm.utils.utils import TensorParallelMode
-
-try:
- import torch_npu
-except (ImportError, ModuleNotFoundError):
- pass
-
-try:
- from torch.distributed._composable.fsdp import fully_shard
-
- FSDP2_SUPPORTED = True
-except (ImportError, ModuleNotFoundError):
- FSDP2_SUPPORTED = False
-
-
-try:
- from torch.distributed.checkpoint.state_dict import (
- StateDictOptions,
- set_model_state_dict,
- )
-
- DCP_SUPPORTED = True
-except (ImportError, ModuleNotFoundError):
- DCP_SUPPORTED = False
-
-
-IS_INJECTED = "is_injected"
-
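- # maps HF-style linear module names to this repo's internal linear names when injecting new_linear modules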
-LINEAR2NEWLINEAR_NAME_MAPPING = dict(
- q_proj="wq",
- k_proj="wk",
- v_proj="wv",
- o_proj="wo",
- gate_proj="w1",
- down_proj="w2",
- up_proj="w3",
- lm_head="head",
- W_pack="wqkv",
-)
-
-logger = get_logger(__file__)
-internlm_accelerator = get_accelerator()
-
-# For universal checkpoint
-# record offset and complete_size of param in each layer
-map_layer_attr = {}
-map_fqn_local_to_global = {}
-map_fqn_global_to_local = {}
-
-
-def set_param_unique_tracking_name(model):
- for chunk_id, chunk in enumerate(unwrap_naive_amp(model)):
- # Important: only works for llama-class models
- childrens = chunk.named_children()
- for children_name, children in childrens:
- if isinstance(children, nn.ModuleList):
- for idx, block in enumerate(children):
- for name, child in block.named_modules():
- if name == "":
- continue
-
- full_name = f"{chunk_id}.{idx}.{name}"
- name_parts = f"{full_name}.weight".split(".", 2)
- # global_id for pipeline parallel case
- global_id = model.first_layer + idx
- local_fqn = f"{children_name}." + ".".join(name_parts[1:])
- global_fqn = f"{children_name}.{global_id}." + ".".join(name_parts[2:])
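- # local_fqn is relative to this pipeline stage, while global_fqn uses the absolute layer id for universal checkpointing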
-
- if isinstance(child, (ParallelLinearWithCommExt)):
- setattr(
- child.weight,
- "tracking_name",
- f"{full_name}.weight",
- )
- if child.bias is not None:
- setattr(
- child.bias,
- "tracking_name",
- f"{full_name}.bias",
- )
-
- setattr(
- child.weight,
- "fqn",
- f"{local_fqn}",
- )
- if child.bias is not None:
- setattr(
- child.bias,
- "fqn",
- f"{local_fqn}",
- )
-
- assert hasattr(child, "offset"), f"{child}"
- map_fqn_local_to_global[local_fqn] = global_fqn
- map_fqn_global_to_local[global_fqn] = local_fqn
-
- assert global_fqn not in map_layer_attr, f"{map_layer_attr} exists"
- map_layer_attr[global_fqn] = {
- "offset": getattr(child, "offset", [0] * len(child.weight.size())),
- "complete_size": getattr(child, "complete_size", list(child.weight.size())),
- }
-
- elif isinstance(child, (RMSNorm)):
- map_fqn_local_to_global[local_fqn] = global_fqn
- map_fqn_global_to_local[global_fqn] = local_fqn
- setattr(
- child.weight,
- "fqn",
- f"{local_fqn}",
- )
- map_layer_attr[global_fqn] = {
- "offset": getattr(child, "offset", [0] * len(child.weight.size())),
- "complete_size": getattr(child, "complete_size", list(child.weight.size())),
- }
-
- else:
- full_name = f"{chunk_id}.{children_name}"
- local_fqn = f"{children_name}.weight"
- assert getattr(children, "bias", None) is None
- if isinstance(children, Embedding1D):
- setattr(
- children.weight,
- "tracking_name",
- f"{chunk_id}_embeddings.weight",
- )
- assert local_fqn not in map_layer_attr, f"{map_layer_attr} exists"
- else:
- setattr(
- children.weight,
- "tracking_name",
- f"{full_name}.weight",
- )
- assert local_fqn not in map_layer_attr, f"{map_layer_attr} exists"
-
- setattr(
- children.weight,
- "fqn",
- f"{local_fqn}",
- )
- if getattr(children, "bias", None) is not None:
- if children.bias is not None:
- setattr(
- children.bias,
- "fqn",
- f"{local_fqn}",
- )
-
- map_layer_attr[local_fqn] = {
- "offset": getattr(children, "offset", [0] * len(children.weight.size())),
- "complete_size": getattr(children, "complete_size", list(children.weight.size())),
- }
-
-
-def generate_meta_data(optimizer):
- if not gpc.config.ckpt.need_metadata:
- return
-
- if gpc.get_world_size(ParallelMode.PIPELINE) > 1:
- assert optimizer.meta_for_zero is not None
- dst = gpc.get_ranks_in_group(ParallelMode.PIPELINE)[0]
- if gpc.get_global_rank() == dst:
- output = [None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
- else:
- output = None
-
- dist.gather_object(optimizer.meta_for_zero, output, dst=dst, group=gpc.get_group(ParallelMode.PIPELINE))
- pp_gather_output = output
-
- else:
- pp_gather_output = [optimizer.meta_for_zero]
-
- tp_parallel = ParallelMode.WEIGHT if is_using_isp() else ParallelMode.TENSOR
- if gpc.get_world_size(tp_parallel) > 1:
- dst = gpc.get_ranks_in_group(tp_parallel)[0]
- if gpc.get_global_rank() == dst:
- output = [None for _ in range(gpc.get_world_size(tp_parallel))]
- else:
- output = None
-
- dist.gather_object(pp_gather_output, output, dst=dst, group=gpc.get_group(tp_parallel))
- final_output = output
- else:
- final_output = [pp_gather_output]
-
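- # on global rank 0, final_output is indexed as [tp_or_wp_rank][pp_rank][zero1_rank]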
- if gpc.get_global_rank() == 0:
- assert len(final_output) == gpc.get_world_size(tp_parallel)
- assert len(final_output[0]) == gpc.get_world_size(ParallelMode.PIPELINE)
- assert len(final_output[0][0]) == gpc.get_world_size(ParallelMode.ZERO1)
- tp_mode = "wp_size" if is_using_isp() else "tp_size"
- final_meta = {
- "parallel_setting": {
- tp_mode: gpc.get_world_size(tp_parallel),
- "pp_size": gpc.get_world_size(ParallelMode.PIPELINE),
- "zero1_size": gpc.get_world_size(ParallelMode.ZERO1),
- },
- "metaData": final_output,
- }
-
- if gpc.config.ckpt.generate_meta_data.enable:
- save_path = os.path.join(gpc.config.ckpt.generate_meta_data.path, "metadata.pt")
- torch.save(final_meta, save_path)
- logger.info(f"Successfully generate metadata.pt in {gpc.config.ckpt.generate_meta_data.path}")
-
- return final_meta
- return None
-
-
-def set_fp32_attr_for_model(model: Union[nn.Module, nn.ModuleList]):
- if not isinstance(model, nn.ModuleList):
- model = [model]
-
- for _chunk in model:
- for _, module in _chunk.named_modules():
- if isinstance(module, (RMSNorm, nn.LayerNorm)) and gpc.config.get("use_fp32_norm", False):
- set_fp32_attr_to_module(module)
-
-
-def set_parallel_attr_for_param_groups(model: Union[nn.Module, nn.ModuleList]):
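- # tag every parameter with a parallel attribute (replica / tensor / weight, zero / expert-data); the assertion below enforces it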
- def _check_module_pure_dp(name, module): # pylint: disable=W0613
- for param in module.parameters():
- setattr(param, IS_REPLICA_ZERO_PARALLEL, True)
-
- def _check_module(name, module):
- # layer_norm
- if isinstance(module, (RMSNorm, nn.LayerNorm)):
- for param in module.parameters():
- setattr(param, IS_REPLICA_ZERO_PARALLEL, True)
-
- if isinstance(module, (MoE, Qwen2MoE)):
- for param in module.moe_layer.gate.parameters():
- setattr(param, IS_REPLICA_ZERO_PARALLEL, True)
- if hasattr(module, "coefficient"):
- for param in module.coefficient.parameters():
- setattr(param, IS_REPLICA_ZERO_PARALLEL, True)
-
- # embedding and head
- if isinstance(module, (Embedding1D, ScaleColumnParallelLinear)):
- for param in module.parameters():
- if gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp():
- setattr(param, IS_WEIGHT_ZERO_PARALLEL, True)
- elif gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp():
- setattr(param, IS_TENSOR_ZERO_PARALLEL, True)
-
- # for moe linear module
- if isinstance(module, nn.Linear) and not isinstance(module, ParallelLinearWithCommExt):
- for param in module.parameters():
- setattr(param, IS_REPLICA_ZERO_PARALLEL, True)
-
- if isinstance(module, Experts):
- for param in module.parameters():
- if (
- gpc.is_initialized(ParallelMode.TENSOR)
- and not is_using_isp()
- and getattr(gpc.config.parallel.expert, "no_tp", False)
- ):
- setattr(param, IS_REPLICA_EXPERT_DATA_PARALLEL, True)
- elif gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp():
- setattr(param, IS_TENSOR_EXPERT_DATA_PARALLEL, True)
- elif gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp():
- setattr(param, IS_WEIGHT_EXPERT_DATA_PARALLEL, True)
- # for non-moe linear module
- elif isinstance(module, ParallelLinearWithCommExt):
- for param in module.parameters():
- if gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp():
- setattr(param, IS_TENSOR_ZERO_PARALLEL, True)
- elif gpc.is_initialized(ParallelMode.WEIGHT) and is_using_isp():
- setattr(param, IS_WEIGHT_ZERO_PARALLEL, True)
-
- # for vit and vit project
- if "vision_tower" in name.lower() or "vision_proj" in name.lower():
- for param in module.parameters():
- setattr(param, IS_REPLICA_ZERO_PARALLEL, True)
-
- for _chunk in unwrap_naive_amp(model):
- if not is_using_fsdp():
- # special case for pure dp mode
- if (
- isinstance(gpc.config.parallel["tensor"], dict)
- and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name)
- == TensorParallelMode.mtp.name
- and gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL)
- ):
- _check_module_func = _check_module_pure_dp
- else:
- _check_module_func = _check_module
- # set param parallel attribute
- for name, module in _chunk.named_modules():
- _check_module_func(name, module)
-
- for name, param in _chunk.named_parameters():
- assert (
- is_replica_zero_parallel_parameter(param)
- or is_tensor_zero_parallel_parameter(param)
- or is_weight_zero_parallel_parameter(param)
- or is_tensor_expert_data_parallel_parameter(param)
- or is_weight_expert_data_parallel_parameter(param)
- or is_replica_expert_data_parallel_parameter(param)
- ), f"parameter with name: {name} has no parallel attribute."
-
-
-@llm_timeout(func_name="initialize_model_and_parallel_communicator")
-def initialize_model_and_parallel_communicator(
- pre_process_func: Optional[Callable] = None, post_process_func: Optional[Callable] = None
-):
- """
- Initialize model with Automatic Mixed Precision.
- Returns:
- A tuple of (model, isp_communicator):
- The neural network model to be trained or evaluated, and
- an isp communicator for managing comp/comm overlap.
- """
- if pre_process_func:
- pre_process_output = pre_process_func()
-
- register_model_initializer()
-
- model = create_model()
-
- if post_process_func:
- post_process_func(pre_process_output)
-
- return inject_model(model)
-
-
-def inject_model(model):
- """
- Inject model with Automatic Mixed Precision.
-
- Args:
- model (torch.nn.Module):
- The bare neural network model to be trained or evaluated.
-
- Returns:
- A tuple of (model, isp_communicator):
- The injected neural network model to be trained or evaluated, and
- an isp communicator for managing comp/comm overlap.
- """
- if hasattr(model, IS_INJECTED) and getattr(model, IS_INJECTED):
- return model
-
- # For non-HF cases, set tracking name for parameters
- if not is_using_hf():
- set_param_unique_tracking_name(model)
-
- # For non-fsdp cases, set model inject helper
- if not is_using_fsdp():
- inject_model_helper(model, inject_info=gpc.config.model.get("inject_info", None))
-
- # should be set before NaiveAMPModel
- set_fp32_attr_for_model(model)
-
- if isinstance(model, nn.ModuleList):
- model = nn.ModuleList(
- [
- NaiveAMPModel(
- model=_m,
- output_to_fp32=False, # manually controlled by interleaved pipeline scheduler
- dtype=gpc.config.model.get("dtype", torch.half),
- sync_buffer=False,
- )
- for _m in model
- ]
- )
- else:
- model = NaiveAMPModel(
- model=model,
- output_to_fp32=gpc.is_no_pp_or_last_stage(),
- dtype=gpc.config.model.get("dtype", torch.half),
- sync_buffer=False,
- )
-
- set_parallel_attr_for_param_groups(model)
-
- # This sync is very important: the model weights kept in the optimizer are copied
- # from the original parameters in memory, so we must make sure the dp sync
- # does not cause the weights in the optimizer to diverge from the original parameters.
- if not is_using_fsdp() or gpc.config.parallel.fsdp.get("init_method", "cuda") == "cuda":
- sync_model_param(model)
-
- # This function is needed to make sure parameters that are not split by tensor parallelism
- # are the same across the tensor-parallel group.
- sync_model_replica_param_group(model)
-
- # Change random state mode to ParallelMode.DATA after model is built, guaranteeing the random
- # state in the same dp group are all the same.
- random_mode = ParallelMode.WEIGHT_DATA if is_using_isp() else ParallelMode.DATA
- set_mode(random_mode)
-
- # initialize isp communicator
- isp_communicator = initialize_parallel_communicator(model)
-
- model = wrap_FSDP_model(model)
-
- # set is_injected flag
- setattr(model, "IS_INJECTED", True)
-
- return model, isp_communicator
-
-
-_T = TypeVar("_T")
-
-
-def _submodule_filter(model: Union[nn.Module, nn.ModuleList], target_cls: Union[_T, Tuple[_T]]) -> Iterable[_T]:
- for _chunk in unwrap_naive_amp(model):
- for _module in _chunk.modules():
- if not isinstance(_module, target_cls):
- continue
-
- yield _module
-
-
-def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]):
- """
- Initialize communicator for isp tensor parallel mode.
-
- Args:
- model (:class:`torch.nn.Module`): Your model instance to be trained or evaluated.
-
- Returns:
- An isp communicator wrapper for managing comp/comm overlap (None when isp is not enabled).
- """
- isp_communicator_wrapper = None
- _retain_out_sharded = gpc.config.model.get("parallel_output", True)
-
- if is_using_isp():
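- # ISP (weight parallel) mode: register a communicator that overlaps weight communication with computation over the WEIGHT group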
- isp_communicator = ISPCommunicator(
- model,
- ISPCommModelConfig(
- gpc.config.model.dtype,
- get_current_device(),
- gpc.config.model.checkpoint,
- ),
- gpc.config.parallel.weight.overlap and not is_using_fsdp(),
- gpc.get_group(ParallelMode.WEIGHT),
- is_moe=False,
- selective_ckpt_offload=gpc.config.get("selective_checkpoint_offload", False),
- early_reduce_scatter_release=gpc.config.parallel.weight.early_reduce_scatter_release,
- enable_layer_fuse_isp_comm=gpc.config.parallel.weight.get("layer_fuse_isp_comm", False),
- )
- # register communicator for isp column parallel linear.
- ColumnParallelLinear.register_cls_communicator(isp_communicator)
- # row parallel linear will not be used.
- RowParallelLinear.register_cls_communicator(None)
- _head_communicator = HeadWeightParallelCommunicator(
- weight_process_group=gpc.get_group(ParallelMode.WEIGHT),
- seq_process_group=gpc.get_group(ParallelMode.TENSOR),
- retain_out_sharded=_retain_out_sharded,
- )
- _embedding_communicator = EmbeddingWeightParallelCommunicator(ParallelMode.WEIGHT)
-
- if gpc.config.model.get("num_experts", 1) > 1:
- # register communicator for moe isp column parallel linear.
- # NOTE: this will overwrite the registered communicator
- moe_isp_communicator = ISPCommunicator(
- model,
- ISPCommModelConfig(
- gpc.config.model.dtype,
- get_current_device(),
- gpc.config.model.checkpoint,
- ),
- gpc.config.parallel.expert_weight.overlap,
- gpc.get_group(ParallelMode.EXPERT_WEIGHT),
- is_moe=True,
- early_reduce_scatter_release=gpc.config.parallel.expert_weight.early_reduce_scatter_release,
- enable_layer_fuse_isp_comm=gpc.config.parallel.expert_weight.get("layer_fuse_isp_comm", False),
- )
- for moe in _submodule_filter(model, Experts):
- for column_linear in _submodule_filter(moe, (ColumnParallelLinear, GroupedWPLinear)):
- column_linear.register_communicator(moe_isp_communicator)
- for row_linear in _submodule_filter(moe, RowParallelLinear):
- row_linear.register_communicator(None)
-
- isp_communicator_wrapper = ISPCommunicatorWrapper([isp_communicator, moe_isp_communicator])
- else:
- isp_communicator_wrapper = ISPCommunicatorWrapper([isp_communicator])
-
- # register communicator for mtp/msp/fsp linear.
-
- # tensor parallel
- if gpc.config.parallel.tensor.mode == TensorParallelMode.mtp.name:
- ColumnParallelLinear.register_cls_communicator(
- TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.COLUMN)
- )
- RowParallelLinear.register_cls_communicator(
- TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.ROW)
- )
-
- if gpc.config.model.get("num_experts", 1) > 1:
- GroupedColumnLinear.register_cls_communicator(
- TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.COLUMN)
- )
- GroupedRowLinear.register_cls_communicator(
- TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.ROW)
- )
- GroupedWPLinear.register_cls_communicator(None)
- # treat as sequence parallel if no_tp
- if gpc.config.parallel.expert.no_tp:
- _column_communicator = TensorParallelCommunicator(
- process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.COLUMN
- )
- _row_communicator = TensorParallelCommunicator(
- process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.ROW
- )
- for moe in _submodule_filter(model, MoE):
- # 1. the linears inside MoE fall back to the no-tp communication pattern
- for column_linear in _submodule_filter(moe, ColumnParallelLinear):
- column_linear.register_communicator(_column_communicator)
- for row_linear in _submodule_filter(moe, RowParallelLinear):
- row_linear.register_communicator(_row_communicator)
- # 2. register MoESequenceParallelCommunicator for MoE layer
- MoESequenceParallelCommunicator(ParallelMode.TENSOR, reverse=True).register_module_hook(moe)
-
- _head_communicator = HeadTensorParallelCommunicator(ParallelMode.TENSOR, _retain_out_sharded)
- _embedding_communicator = EmbeddingTensorParallelCommunicator(ParallelMode.TENSOR)
- # sequence parallel
- if gpc.config.parallel.tensor.mode in (TensorParallelMode.msp.name, TensorParallelMode.fsp.name):
- save_total_input_as_activation = gpc.config.parallel.tensor.mode == TensorParallelMode.msp.name
-
- ColumnParallelLinear.register_cls_communicator(
- SequenceParallelCommunicator(
- process_group=gpc.get_group(ParallelMode.TENSOR),
- role=LinearRole.COLUMN,
- save_total_input_as_activation=save_total_input_as_activation,
- )
- )
- RowParallelLinear.register_cls_communicator(
- SequenceParallelCommunicator(
- gpc.get_group(ParallelMode.TENSOR),
- role=LinearRole.ROW,
- save_total_input_as_activation=save_total_input_as_activation,
- )
- )
- if gpc.config.model.get("num_experts", 1) > 1:
- GroupedColumnLinear.register_cls_communicator(
- SequenceParallelCommunicator(
- process_group=gpc.get_group(ParallelMode.TENSOR),
- role=LinearRole.COLUMN,
- save_total_input_as_activation=save_total_input_as_activation,
- )
- )
- GroupedRowLinear.register_cls_communicator(
- SequenceParallelCommunicator(
- gpc.get_group(ParallelMode.TENSOR),
- role=LinearRole.ROW,
- save_total_input_as_activation=save_total_input_as_activation,
- )
- )
- GroupedWPLinear.register_cls_communicator(None)
- if gpc.config.parallel.expert.no_tp:
- _column_communicator = TensorParallelCommunicator(
- process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.COLUMN
- )
- _row_communicator = TensorParallelCommunicator(
- process_group=gpc.get_group(ParallelMode.EXPERT_TENSOR), role=LinearRole.ROW
- )
- for moe in _submodule_filter(model, MoE):
- # 1. the linears inside MoE fall back to the no-tp communication pattern
- for column_linear in _submodule_filter(moe, ColumnParallelLinear):
- column_linear.register_communicator(_column_communicator)
- for row_linear in _submodule_filter(moe, RowParallelLinear):
- row_linear.register_communicator(_row_communicator)
-
- _head_communicator = HeadSequenceParallelCommunicator(
- ParallelMode.TENSOR, _retain_out_sharded, save_total_input_as_activation
- )
-
- _embedding_communicator = EmbeddingSequenceParallelCommunicator(ParallelMode.TENSOR)
-
- # register communicator for embedding layer.
- if not is_using_fsdp():
- for embedding in _submodule_filter(model, Embedding1D):
- _embedding_communicator.register_module_hook(embedding)
-
- # register communicator for head layer.
- ScaleColumnParallelLinear.register_cls_communicator(_head_communicator)
- RewardModelLinear.register_cls_communicator(_head_communicator)
-
- return isp_communicator_wrapper
-
-
-@llm_timeout(func_name="initialize_optimizer")
-def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicator: ISPCommunicatorWrapper = None):
- """
- Initialize optimizer.
-
- Args:
- model (:class:`torch.nn.Module`): Your model instance to be trained or evaluated.
-
- Returns:
- A tuple of (optimizer, beta2_scheduler, lr_scheduler).
- """
-
- adam_cfg = gpc.config.adam
- zero_cfg = gpc.config.hybrid_zero_optimizer
- grad_scal_cfg = gpc.config.grad_scaler
- use_apex_adam = getattr(gpc.config, "use_apex_adam", False)
-
- if "use_split_tensor_optim" in zero_cfg and zero_cfg.use_split_tensor_optim:
- map_param_block(model)
-
- params = create_param_groups(model, adam_cfg.weight_decay)
-
- naive_optimizer = new_compatible_adamw(
- params=params,
- lr=adam_cfg.lr,
- betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
- eps=adam_cfg.adam_eps,
- use_apex_adam=use_apex_adam,
- )
-
- if (
- zero_cfg.overlap_sync_grad
- and gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
- and gpc.is_pipeline_first_stage() is False
- ):
- # When pipeline parallelism is enabled, we prefer to only enable optimizer
- # gradient communication overlap in the first stage, to avoid amplifying
- # the communication overhead stage by stage in cases where the optimizer
- # communication overhead is greater than the compute overhead.
- # For pipeline stages except the first, even if overlap is not enabled,
- # their gradient synchronization overhead can be well hidden by
- # the inherent bubbles of pipeline parallelism.
- zero_cfg.overlap_sync_grad = False
-
- if zero_cfg.overlap_sync_param:
- param_bcast_sync_handler = ParamAsyncBcastHandler(ParallelMode.ZERO1, model, isp_communicator)
- else:
- param_bcast_sync_handler = None
-
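- # pick the optimizer wrapper: HybridZeroOptimizer by default, the _v2 variant when use_split_tensor_optim is set, or FSDPadaptOptimizer under FSDP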
- if not is_using_fsdp():
- if (
- "use_split_tensor_optim" not in gpc.config.hybrid_zero_optimizer
- or not gpc.config.hybrid_zero_optimizer.use_split_tensor_optim
- ):
- optimizer = HybridZeroOptimizer(
- naive_optimizer,
- grad_scal_cfg=grad_scal_cfg,
- zero_cfg=zero_cfg,
- param_bcast_sync_handler=param_bcast_sync_handler,
- isp_communicator=isp_communicator,
- )
- else:
- optimizer = HybridZeroOptimizer_v2(
- naive_optimizer,
- grad_scal_cfg=grad_scal_cfg,
- zero_cfg=zero_cfg,
- param_bcast_sync_handler=param_bcast_sync_handler,
- isp_communicator=isp_communicator,
- )
- else:
- optimizer = FSDPadaptOptimizer(
- naive_optimizer,
- grad_scal_cfg=grad_scal_cfg,
- zero_cfg=zero_cfg,
- )
-
- beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
-
- lr_scheduler = FineTuneCosineAnnealingWarmupLR(optimizer, **gpc.config.lr_scheduler)
-
- return optimizer, beta2_scheduler, lr_scheduler
-
-
-def get_scheduler_hooks(metric, zero_optim, isp_communicator_wrapper) -> List[SchedulerHook]:
- scheduler_hooks: List[SchedulerHook] = []
-
- if metric is not None:
- scheduler_hooks.append(
- SchedulerMetricHook(
- metric=metric,
- skip=(
- gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
- and hasattr(gpc.config.model, "num_chunks")
- and gpc.config.model.num_chunks > 1
- and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
- ),
- ),
- )
-
- if isp_communicator_wrapper is not None:
- for isp_communicator in isp_communicator_wrapper.isp_communicators:
- if isp_communicator is not None and isp_communicator.overlap:
- scheduler_hooks.append(ISPCommunicatorSchedulerHook(isp_communicator, zero_optim))
-
- return scheduler_hooks
-
-
-@llm_timeout(func_name="load_new_batch")
-def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
- """
- Load and return the new batch data based on training data loader.
-
- Args:
- train_dl (torch.utils.data.DataLoader): Dataloader for training.
- train_iter (Iterable): Data iterator from which to get a batch of data, obtained by calling iter(dataloader).
- train_state (TrainState): Current training state.
-
- Returns: A batch of data and the updated train_iter.
- """
-
- timer("batch-gen").start()
- try:
- batch = next(train_iter) # structure is ({'input_ids': Tensor, 'cu_seqlens': Tensor}, Tensor)
- if hasattr(train_state, "batch_sampler_iter"):
- next(train_state.batch_sampler_iter)
- except StopIteration:
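- # the dataloader is exhausted: start a new epoch and reset the consumed-sample and batch-sampler state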
- train_iter = iter(train_dl)
- batch = next(train_iter)
- train_state.num_consumed_samples_in_epoch = 0
- if hasattr(train_state, "batch_sampler"):
- train_state.batch_sampler.batch_count = 0
- train_state.batch_sampler.num_consumed_samples_in_epoch = 0
- train_state.batch_sampler_iter = iter(train_state.batch_sampler)
- next(train_state.batch_sampler_iter)
- timer("batch-gen").stop()
-
- if batch[0].get("type_ids", None) is not None:
- # if use_packed_dataset is False, we need to unpack type_ids
- if not gpc.config.data.use_packed_dataset:
- batch[0]["type_ids"] = unpack_type_ids(batch[0]["type_ids"], batch[0]["cu_seqlens"])
-
- return batch, train_iter
-
-
-def initialize_llm_profile(profiling: bool = False, start_time: str = None):
- """Initialize and return the profiler context manager instance."""
-
- if profiling and gpc.get_local_rank(ParallelMode.DATA) == 0 and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
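- # only ranks that are first in both the data- and tensor-parallel groups produce a trace; all other ranks get a DummyProfile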
- schedule_config = {"wait": 1, "warmup": 1, "active": 1, "repeat": 1, "skip_first": 3}
- trace_path = (
- f"RUN/{gpc.config.JOB_NAME}/{start_time}/traces/rank{gpc.get_global_rank()}_"
- f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
- f"wp{gpc.get_local_rank(ParallelMode.WEIGHT)}_"
- f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}"
- )
- if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU:
- experimental_config = torch_npu.profiler._ExperimentalConfig(
- aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
- profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
- l2_cache=False,
- )
- llm_profile = torch_npu.profiler.profile(
- activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU],
- schedule=torch_npu.profiler.schedule(**schedule_config),
- on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(trace_path),
- record_shapes=True,
- profile_memory=True,
- with_stack=False,
- with_flops=False,
- with_modules=False,
- experimental_config=experimental_config,
- )
- logger.info(f"Do profiling for NPU on rank {gpc.get_global_rank()}!")
- else:
- llm_profile = torch.profiler.profile(
- activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
- schedule=torch.profiler.schedule(**schedule_config),
- on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_path),
- with_stack=True,
- with_modules=True,
- profile_memory=True,
- )
- logger.info(f"Do profiling for GPU on rank {gpc.get_global_rank()}!")
- else:
- llm_profile = DummyProfile()
-
- return llm_profile
-
-
-@llm_timeout(func_name="record_current_batch_training_metrics")
-def record_current_batch_training_metrics(
- get_tflops_func,
- logger,
- writer,
- success_update,
- batch_count,
- batch,
- train_state,
- optimizer,
- beta2_scheduler,
- engine,
- start_time,
- very_begining_time,
- loss,
- moe_loss,
- grad_norm,
- metric,
-):
- """
- Print some training metrics of current batch.
- Print some training metrics of the current batch.
-
- set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time()))
-
- timer.store_last_timers()
- if success_update in (0, True):
- train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
- if gpc.is_no_pp_or_last_stage():
- acc_perplex = metric.get_metric()
-
- if success_update and gpc.is_rank_for_log():
- lr = optimizer.param_groups[0]["lr"]
- if hasattr(engine.optimizer, "grad_scaler"):
- scaler = engine.optimizer.grad_scaler._scale.item()
- elif hasattr(engine.optimizer.optim, "grad_scaler"):
- scaler = engine.optimizer.optim.grad_scaler._scale.item()
-
- num_tokens_in_batch = batch[1].nelement()
- real_num_tokens = math.ceil(acc_perplex.pop("real_token_num") / gpc.get_world_size(ParallelMode.GLOBAL))
- num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]])
- max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]])
- max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]])
- min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]])
- time_cost = time.time() - start_time
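- # tk_per_gpu: tokens processed per GPU in this step (batch tokens scaled from the data-parallel group to the global world size)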
- tk_per_gpu = round(
- num_tokens_in_batch * gpc.get_world_size(ParallelMode.DATA) / gpc.get_world_size(ParallelMode.GLOBAL),
- 4,
- )
- tgs_statistic = train_state.tgs_statistic
- tgs_statistic["sum_step"] += 1
- tgs_statistic["sum_tg"] += tk_per_gpu
- tgs_statistic["total_time"] = time.time() - very_begining_time
- tgs_statistic["sum_last_tg_10"] += tk_per_gpu
- tgs_statistic["sum_last_time_10"] += time_cost
- tgs_statistic["sum_last_tg_50"] += tk_per_gpu
- tgs_statistic["sum_last_time_50"] += time_cost
- tgs_statistic["SMA_tg_50"] += tk_per_gpu
- tgs_statistic["SMA_time_50"] += time_cost
- tgs_statistic["SMA_tg_50_list"].append(tk_per_gpu)
- tgs_statistic["SMA_time_50_list"].append(time_cost)
- if tgs_statistic["sum_step"] > 50:
- tgs_statistic["SMA_tg_50"] -= tgs_statistic["SMA_tg_50_list"][0]
- tgs_statistic["SMA_time_50"] -= tgs_statistic["SMA_time_50_list"][0]
- tgs_statistic["SMA_tg_50_list"].popleft()
- tgs_statistic["SMA_time_50_list"].popleft()
-
- last_tgs_1 = round(tk_per_gpu / time_cost, 2)
- tgs_statistic["sum_tgs"] += last_tgs_1
-
- if tgs_statistic["sum_step"] % 10 == 0:
- tgs_statistic["last_tgs_10"] = round(tgs_statistic["sum_last_tg_10"] / tgs_statistic["sum_last_time_10"], 2)
- tgs_statistic["sum_last_tg_10"] = 0
- tgs_statistic["sum_last_time_10"] = 0
-
- if tgs_statistic["sum_step"] % 50 == 0:
- tgs_statistic["last_tgs_50"] = round(tgs_statistic["sum_last_tg_50"] / tgs_statistic["sum_last_time_50"], 2)
- tgs_statistic["sum_last_tg_50"] = 0
- tgs_statistic["sum_last_time_50"] = 0
-
- last_tgs_10 = tgs_statistic["last_tgs_10"]
- last_tgs_50 = tgs_statistic["last_tgs_50"]
-
- tgs_all = round(tgs_statistic["sum_tg"] / tgs_statistic["total_time"], 2)
- tgs_avg = round(tgs_statistic["sum_tgs"] / tgs_statistic["sum_step"], 2)
- tgs_SMA = round(tgs_statistic["SMA_tg_50"] / tgs_statistic["SMA_time_50"], 2)
-
- tflops = get_tflops_func(time_cost)
-
- tgs_origin = round(
- num_tokens_in_batch
- * gpc.get_world_size(ParallelMode.DATA)
- / gpc.get_world_size(ParallelMode.GLOBAL)
- / time_cost,
- 2,
- )
-
- real_tgs = round(
- real_num_tokens / time_cost,
- 2,
- )
-
- infos = {
- "tflops": tflops,
- "step": batch_count,
- "loss": loss.item() - moe_loss.item() if moe_loss is not None else loss.item(),
- "real_tgs": real_tgs,
- "tgs (tokens/gpu/second)": tgs_origin,
- "tgs/last_tgs_1": last_tgs_1,
- "tgs/tgs_all": tgs_all,
- "tgs/tgs_avg": tgs_avg,
- "tgs/tgs_SMA": tgs_SMA,
- "tgs/last_tgs_10": last_tgs_10,
- "tgs/last_tgs_50": last_tgs_50,
- "lr": lr,
- "loss_scale": scaler,
- "grad_norm": grad_norm,
- }
- if moe_loss is not None:
- infos["moe_loss"] = moe_loss.item()
-
- infos["micro_num"] = len(batch[1])
- infos["num_consumed_tokens"] = train_state.num_consumed_tokens
- infos["inf_nan_skip_batches"] = train_state.inf_nan_skip_batches
- infos["num_samples_in_batch"] = num_samples_in_batch # the total number of samples packed in this batch
- infos["largest_length"] = max_length_in_batch # the longest input
- infos["largest_batch"] = max_samples_in_batch # the batch with the most samples
- infos["smallest_batch"] = min_samples_in_batch
- infos["adam_beta2"] = beta2_scheduler.get_beta2()
-
- fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2)
- infos["fwd_bwd_time"] = fwd_bwd_time
- bwd_time = round(timer("bwd").elapsed(), 2)
- infos["bwd_time"] = bwd_time
-
- for key, value in acc_perplex.items():
- infos[key] = value
-
- line = ""
- for key, value in infos.items():
- line += f"{key}={value} "
- if isinstance(value, dict):
- writer.add_scalars(key=key, value=value, step=train_state.step_count)
- else:
- writer.add_scalar(key=key, value=value, step=train_state.step_count)
-
- logger.info(line)
-
- # if loss spike occurs, send alert info to feishu
- mm.monitor_loss_spike(
- alert_address=gpc.config.monitor.alert.feishu_alert_address,
- step_count=batch_count,
- cur_step_loss=loss.item(),
- )
-
-
-def inject_embed(model: nn.Module, inject=False, interactive=False) -> None:
- def traverse(module):
- for name, child in module.named_children():
- if isinstance(child, nn.Embedding) and not isinstance(child, Embedding1D):
- msg = (
- f"To get parallel training enabled, module {name} of type {nn.Embedding.__name__} "
- + f"is required to be replaced with {Embedding1D.__name__}."
- )
- if inject:
- help_msg = f"Do you want to replace {name}? (y/n)"
- opt = timeout_input(
- f"{msg}\n{help_msg}",
- default="y",
- timeout=60,
- interactive=interactive,
- )
- if opt in ["y", "yes"]:
- child_new = Embedding1D(
- num_embeddings=child.num_embeddings,
- embedding_dim=child.embedding_dim,
- padding_idx=child.padding_idx,
- ).to(device=child.weight.device, dtype=child.weight.dtype)
- setattr(module, name, child_new)
- else:
- if gpc.is_rank_for_log():
- logger.warning(f"Skip replacing {name}")
- else:
- if gpc.is_rank_for_log():
- logger.warning(msg)
- else:
- traverse(child)
-
- traverse(model)
-
-
-def inject_linear(model: nn.Module, inject=False, interactive=False) -> None:
- def traverse(module):
- for name, child in module.named_children():
- if isinstance(child, nn.Linear) and not isinstance(child, ParallelLinearWithCommExt):
- msg = (
- f"To get parallel training enabled, module {name} of type {nn.Linear.__name__} "
- + f"is required to be replaced with {new_linear.__name__}."
- )
- if inject:
- help_msg = f"Do you want to replace {name}? (y/n)"
- opt = timeout_input(
- f"{msg}\n{help_msg}",
- default="y",
- timeout=60,
- interactive=interactive,
- )
- if opt in ["y", "yes"]:
- child_new = new_linear(
- name=LINEAR2NEWLINEAR_NAME_MAPPING.get(name, name),
- in_features=child.in_features,
- out_features=child.out_features,
- bias=child.bias is not None,
- ).to(device=child.weight.device, dtype=child.weight.dtype)
- setattr(module, name, child_new)
- else:
- if gpc.is_rank_for_log():
- logger.warning(f"Skip replacing {name}")
- else:
- if gpc.is_rank_for_log():
- logger.warning(msg)
- else:
- traverse(child)
-
- traverse(model)
-
-
-def inject_norm(model: nn.Module, inject=False, interactive=False) -> None:
- def traverse(module):
- for name, child in module.named_children():
- cls_name = type(child).__name__
- if "RMSNorm" in cls_name:
- msg = (
- f"To re-use unified RMSNorm implementation, {cls_name} "
- + f"is suggested to be replaced with {new_layer_norm.__name__}."
- )
- if inject:
- help_msg = f"Do you want to replace {name}? (y/n)"
- opt = timeout_input(
- f"{msg}\n{help_msg}",
- default="y",
- timeout=60,
- interactive=interactive,
- )
- if opt in ["y", "yes"]:
- child_new = new_layer_norm(
- norm_type="rmsnorm",
- normalized_shape=child.weight.shape,
- eps=child.variance_epsilon,
- ).to(device=child.weight.device, dtype=child.weight.dtype)
- setattr(module, name, child_new)
- else:
- if gpc.is_rank_for_log():
- logger.warning(f"Skip replacing {name}")
- else:
- if gpc.is_rank_for_log():
- logger.warning(msg)
- else:
- traverse(child)
-
- traverse(model)
-
-
-def inject_config(model: nn.Module) -> None:
- # Compatibility for Vision-Language Model
- if hasattr(model.config, "text_config"):
- llm_cfg = model.config.text_config
- else:
- llm_cfg = model.config
- gpc.config.model.vocab_size = gpc.config.VOCAB_SIZE = llm_cfg.vocab_size
- gpc.config.model.hidden_size = gpc.config.HIDDEN_SIZE = llm_cfg.hidden_size
- gpc.config.model.num_layers = gpc.config.NUM_LAYER = llm_cfg.num_hidden_layers
- # Compatibility for Mamba
- if hasattr(llm_cfg, "num_attention_heads"):
- gpc.config.model.num_attention_heads = gpc.config.NUM_ATTENTION_HEAD = llm_cfg.num_attention_heads
- gpc.config.model.mlp_ratio = gpc.config.MLP_RATIO = llm_cfg.intermediate_size / llm_cfg.hidden_size
- # For models that use GQA
- if hasattr(llm_cfg, "num_key_value_heads"):
- gpc.config.model.num_kv_attention_heads = gpc.config.NUM_KV_ATTENTION_HEAD = llm_cfg.num_key_value_heads
-
-
-def _get_modules_to_materialize(
- root_module: nn.Module,
- ignored_modules: Set[nn.Module],
-) -> List[nn.Module]:
- # Run BFS to collect the modules to materialize via `reset_parameters()`,
- # stopping at any module with FSDP already applied or at ignored modules.
- modules_to_materialize: List[nn.Module] = []
- queue = collections.deque([root_module])
- visited_modules: Set[nn.Module] = {root_module}
- while queue:
- module = queue.popleft()
- modules_to_materialize.append(module)
- for child_module in module.children():
- if child_module not in visited_modules and child_module not in ignored_modules:
- visited_modules.add(child_module)
- queue.append(child_module)
- return modules_to_materialize
-
-
-def _materialize_meta_module(
- root_module: nn.Module,
- ignored_modules: Set[nn.Module],
- device_id: Optional[torch.device],
-) -> None:
- # Run default meta device initialization
- modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules)
- module = None
- try:
- # Assume that each module's `reset_parameters()` only initializes its
- # own parameters and not those of its children
- with torch.no_grad():
- for module in modules_to_materialize:
- # As a contract to the user, only call `reset_parameters()` if
- # the module has directly managed parameters/buffers
- module_state_iter = itertools.chain(module.parameters(recurse=False), module.buffers(recurse=False))
- has_module_states = len(list(module_state_iter)) > 0
- if has_module_states:
- module.to_empty(device=device_id, recurse=False)
- module.reset_parameters() # type: ignore[operator]
- except BaseException as e:
- logger.warning(
- "Unable to call `reset_parameters()` for module on meta "
- f"device with error {str(e)}. Please ensure that your module of"
- f"type {type(module)} implements a `reset_parameters()` method." # type: ignore[possibly-undefined]
- )
- raise e
-
-
-def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
- if is_using_fsdp():
- assert isinstance(model, nn.Module), "Currently FSDP does not support pipeline parallel."
- wrap_cls = tuple(
- LazyObject(warp_cls["mod"], warp_cls["mod_cls"]).build() for warp_cls in gpc.config.get("fsdp_wrap_cls", [])
- )
- fsdp_mode = gpc.config.parallel.fsdp.get("mode", "v1")
- fsdp_init_method = gpc.config.parallel.fsdp.get("init_method", "cuda")
-
- if fsdp_mode == "v1":
- model = FSDP(
- module=model,
- process_group=gpc.get_group(ParallelMode.GLOBAL),
- sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO2: SHARD_GRAD_OP, ZeRO3: FULL_SHARD
- auto_wrap_policy=functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=set(wrap_cls)),
- sync_module_states=fsdp_init_method != "cuda", # sync model paramters
- forward_prefetch=True,
- backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
- limit_all_gathers=True,
- use_orig_params=True,
- device_id=None if fsdp_init_method == "cuda" else get_current_device(), # needed for sync_module_states
- )
- # For FSDP v1, to get ckpt resuming work normally, we do dummy forward.
- # This hack is needed due to FSDP v1 lazy initialization in model construction.
- # FYI: https://github.com/pytorch/pytorch/issues/113496
- model = init_fsdp_v1(model, get_current_device())
- elif FSDP2_SUPPORTED and fsdp_mode == "v2":
- fsdp_kwargs = {
- "reshard_after_forward": True, # ZeRO2: False, ZeRO3: True
- }
- for module in model.modules():
- if isinstance(module, wrap_cls):
- fully_shard(module, **fsdp_kwargs)
- fully_shard(model, **fsdp_kwargs)
- if fsdp_init_method == "meta":
- _materialize_meta_module(model, set(), get_current_device())
- elif fsdp_init_method == "cpu":
- model.to(get_current_device())
- else:
- raise ValueError(f"Unsupported FSDP mode: {fsdp_mode}")
-
- if is_using_hf() and not gpc.config.ckpt.get("auto_resume", False):
- load_ckpt_info = gpc.config.ckpt.load_ckpt_info
- load_ckpt_path = load_ckpt_info.get("path", None)
- load_ckpt_content = load_ckpt_info.get("content", [])
- if load_ckpt_path:
- assert load_ckpt_content == (
- "model",
- ), "If auto_resume=False and checkpoint path is given, only model can be loaded"
- if DCP_SUPPORTED:
- hf = gpc.config.hf
- mod = LazyObject(hf.mod, hf.mod_cls)
- mod = mod.build()
- state_dict = mod.from_pretrained(
- pretrained_model_name_or_path=load_ckpt_path, use_safetensors=True
- ).state_dict()
- state_dict = {f"model.{key}": state_dict[key].clone().detach() for key in state_dict}
- set_model_state_dict(
- model=model, model_state_dict=state_dict, options=StateDictOptions(full_state_dict=True)
- )
- del state_dict
- internlm_accelerator.empty_cache()
- else:
- raise RuntimeError("DCP is not supported in this version of PyTorch.")
-
- return model
-
-
-def inject_model_helper(model: Union[nn.Module, nn.ModuleList], inject_info: Optional[Dict] = None) -> None:
- """
- Inject model helper functions.
-
- Args:
- model (Union[nn.Module, nn.ModuleList]):
- For built-in models, it is nn.Module for no pp and nn.ModuleList for pp.
- For injected models, it is nn.Module.
- inject_info (Optional[Dict]): configurations for injected_models.
- """
- # parse inject_info
- if inject_info is not None:
- inject = inject_info.get("inject", False)
- interactive = inject_info.get("interactive", False)
- modules = inject_info.get("modules", [])
- reset_params = inject_info.get("reset_params", False)
- extra_linear2newlinear = inject_info.get("extra_linear2newlinear", {})
- else:
- inject = False
- interactive = False
- modules = []
- reset_params = False
- extra_linear2newlinear = {}
-
- LINEAR2NEWLINEAR_NAME_MAPPING.update(extra_linear2newlinear)
-
- inject_funcs = {
- "embed": inject_embed,
- "linear": inject_linear,
- "norm": inject_norm,
- }
-
- # inject config
- if inject:
- inject_config(model)
-
- if not isinstance(model, nn.ModuleList):
- model = [model]
- for _chunk in model:
- # Special case for pure dp mode: skip
- if (
- isinstance(gpc.config.parallel["tensor"], dict)
- and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.mtp.name
- and gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL)
- ):
- continue
- # In-place replacement or check for modules: "embed", "linear", "norm"
- # (1) If inject=True, in-place replacement
- # (2) If inject=False, check
- for mod in modules:
- inject_funcs[mod](_chunk, inject, interactive)
- # reset parameters if needed, model should have reset_parameters() method
- if reset_params:
- _chunk.reset_parameters()
- for _chunk in model:
- # If model is initialized on cpu, model should be moved to cuda device after injection
- if not next(_chunk.parameters()).is_cuda:
- _chunk.to(get_current_device())
-
- # print injected model
- if inject and gpc.is_rank_for_log():
- logger.info(
- f"inject is enabled, please check the model carefully, "
- f"if there are any problems, please report issue to us. "
- f"The injected model is \n {model}"
- )
diff --git a/internlm/train/utils.py b/internlm/train/utils.py
deleted file mode 100644
index d1bf4fe90..000000000
--- a/internlm/train/utils.py
+++ /dev/null
@@ -1,116 +0,0 @@
-from typing import Dict, Tuple
-
-import torch
-from torch import nn
-
-from internlm.core.context.parallel_context import ParallelMode
-from internlm.core.context.parallel_context import global_context as gpc
-from internlm.core.naive_amp import unwrap_naive_amp
-from internlm.model.modules.utils import is_moe_param
-from internlm.utils.logger import get_logger
-
-logger = get_logger(__file__)
-
-
-def split_params_into_different_groups_for_optimizer(
- param_groups: Tuple[Dict],
-) -> Tuple[Dict]:
- """Split parameters into different groups for optimizer
-
- Args:
- param_groups (Tuple[Dict]): The list of parameter groups to split
- Input Example:
- >>> (
- >>> {'name': 'default', 'params': [tensor], 'weight_decay' :xxx},
- >>> )
-
- Returns:
- Tuple[Dict]: list of params groups for optimizer
- Output Example:
- >>> (
- >>> {'name': 'default', 'params': [tensor], 'weight_decay' :xxx},
- >>> {'name': 'embed_head', 'params': [tensor], 'weight_decay' :xxx},
- >>> {'name': 'fp32', 'params': [tensor], 'weight_decay' :xxx},
- >>> )
- """
-
- if isinstance(param_groups, tuple):
- param_groups = list(param_groups) # Tuple cannot be modified
- elif isinstance(param_groups, dict):
- param_groups = [param_groups]
- elif not isinstance(param_groups, list):
- raise ValueError(f"Unknown param group type of {type(param_groups)}")
-
- new_groups = {}
- # create new groups for fp32 parameter group
- new_groups["fp32"] = {"name": "fp32", "params": [], "optimizer_mode": ParallelMode.ZERO1}
-
- if gpc.config.model.get("num_experts", 1) > 1:
- for key in gpc.expert_parallel_group_names:
- new_groups[key] = {"name": key, "moe": True, "params": [], "optimizer_mode": ParallelMode.EXPERT_DATA}
-
- for pgroup in param_groups:
- # copy attribute from origin group, we assume the input param_groups only
- # have one group, so the attribute will not be copyed multiple times.
- for ori_key in pgroup.keys():
- if ori_key not in ("name", "params"):
- for _, group in new_groups.items():
- group[ori_key] = pgroup[ori_key]
- # assign param
- origin_params = []
- for param in pgroup["params"]:
- # moe param means MoE is enabled
- if is_moe_param(param):
- new_groups[param.group_name]["params"].append(param)
- elif param.dtype == torch.float32 and gpc.config.model.dtype != torch.float32:
- new_groups["fp32"]["params"].append(param)
- else:
- origin_params.append(param)
-
- # default param group, which is the first group in the param groups
- pgroup["params"] = origin_params
- pgroup["optimizer_mode"] = ParallelMode.ZERO1
-
- # param groups may contain empty groups, such as fp32
- param_groups.extend(new_groups.values())
-
- return tuple(param_groups)
-
-
-def create_param_groups(model, weight_decay):
- parameters = {
- "params": [param for param in model.parameters() if param.requires_grad],
- "name": "default",
- "weight_decay": weight_decay,
- }
- return split_params_into_different_groups_for_optimizer(parameters)
-
-
-def map_param_block(model):
- for _chunk in unwrap_naive_amp(model):
- for name, children in _chunk.named_children():
- if isinstance(children, nn.ModuleList):
- for idx, block in enumerate(children):
- block_name = name + f"_{idx}"
- for param in block.parameters():
- setattr(param, "block_name", block_name)
- else:
- for param in children.parameters():
- setattr(param, "block_name", name)
-
-
-def timeout_input(printout, default, timeout=None, interactive=True):
- if not interactive:
- return default
- import select
- import sys
-
- if gpc.is_rank_for_log():
- logger.info(printout)
-
- i, _, _ = select.select([sys.stdin], [], [], timeout)
- if i:
- msg = sys.stdin.readline().strip()
- return default if len(msg) == 0 else msg
- else:
- return default
diff --git a/internlm/utils/common.py b/internlm/utils/common.py
index 56ebcfbe6..c444e456a 100644
--- a/internlm/utils/common.py
+++ b/internlm/utils/common.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
+import argparse
import bisect
import inspect
import os
@@ -15,7 +16,6 @@
import numpy as np
import torch
-import internlm
from internlm.accelerator import AcceleratorType, get_accelerator
from internlm.utils.logger import get_logger
@@ -24,8 +24,39 @@
internlm_accelerator = get_accelerator()
+def get_default_parser():
+ """Reads user command line and uses an argument parser to parse the input arguments.
+ Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.
+
+ Returns:
+ Parser: Returns the parser with the default arguments, the user may add customized arguments into this parser.
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, help="path to the config file")
+ parser.add_argument(
+ "--launcher",
+ type=str,
+ default="slurm",
+ choices=["slurm", "torch"],
+ help="launcher for launching distributed environment",
+ )
+ parser.add_argument("--host", type=str, help="the master address for distributed training")
+ parser.add_argument("--port", type=int, default=8888, help="the master port for distributed training")
+ parser.add_argument("--world_size", type=int, help="world size for distributed training")
+ parser.add_argument("--rank", type=int, help="rank for the default process group")
+ parser.add_argument("--local_rank", type=int, help="local rank on the node")
+ parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
+ parser.add_argument("--seed", type=int, default=1024)
+ parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.")
+ parser.add_argument("--enable_ali_topology", default=False, action="store_true", help="enable ali switch topology.")
+ parser.add_argument(
+ "--disable_volc_topology", default=False, action="store_true", help="disable volc switch topology."
+ )
+ return parser
+
+
def parse_args():
- parser = internlm.get_default_parser()
+ parser = get_default_parser()
args = parser.parse_args()
return args
@@ -318,3 +349,7 @@ def __setitem__(self, key, value):
mapping[key] = value
return
self.maps[0][key] = value
+
+
+def set_env_var(key, value):
+ os.environ[str(key)] = str(value)
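
For reference, here is a minimal usage sketch of the relocated argument parser together with the new `set_env_var` helper; the extra `--dtype` flag and the argv values passed to `parse_args` are illustrative assumptions and are not part of this change.

```python
# Minimal sketch (not part of the diff): exercising get_default_parser and set_env_var
# as added to internlm/utils/common.py above. The "--dtype" flag is a hypothetical
# user-defined extension, and the argv list below is purely illustrative.
from internlm.utils.common import get_default_parser, set_env_var

parser = get_default_parser()
parser.add_argument("--dtype", type=str, default="torch.bfloat16")  # hypothetical extra flag

args = parser.parse_args(["--config", "./configs/7B_sft.py", "--launcher", "torch", "--port", "29500"])

# set_env_var casts both key and value to str before exporting them,
# so passing the integer port value is fine.
set_env_var("MASTER_PORT", args.port)
set_env_var("INTERNLM_LAUNCHER", args.launcher)
```
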
diff --git a/internlm/utils/config.py b/internlm/utils/config.py
new file mode 100644
index 000000000..7d54d0ca8
--- /dev/null
+++ b/internlm/utils/config.py
@@ -0,0 +1,103 @@
+import inspect
+import sys
+from importlib.machinery import SourceFileLoader
+from pathlib import Path
+
+
+class Config(dict):
+ """This is a wrapper class for dict objects so that values of which can be
+ accessed as attributes.
+
+ Args:
+ config (dict): The dict object to be wrapped.
+ """
+
+ def __init__(self, config: dict = None): # pylint: disable=W0231
+ if config is not None:
+ for k, v in config.items():
+ self._add_item(k, v)
+
+ def __missing__(self, key):
+ raise KeyError(key)
+
+ def __getattr__(self, key):
+ try:
+ value = super().__getitem__(key)
+ return value
+ except KeyError:
+ raise AttributeError(key)
+
+ def __setattr__(self, key, value):
+ super().__setitem__(key, value)
+
+ def _add_item(self, key, value):
+ if isinstance(value, dict):
+ self.__setattr__(key, Config(value))
+ else:
+ self.__setattr__(key, value)
+
+ def update(self, config):
+ assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects."
+ for k, v in config.items():
+ self._add_item(k, v)
+ return self
+
+ @staticmethod
+ def from_file(filename: str):
+ """Reads a python file and constructs a corresponding :class:`Config` object.
+
+ Args:
+            filename (str): Name of the file used to construct the returned object.
+
+ Returns:
+ :class:`Config`: A :class:`Config` object constructed with information in the file.
+
+ Raises:
+            AssertionError: Raised if the file does not exist or is not a .py file.
+ """
+
+ # check config path
+ if isinstance(filename, str):
+ filepath = Path(filename).absolute()
+ elif isinstance(filename, Path):
+ filepath = filename.absolute()
+
+ assert filepath.exists(), f"{filename} is not found, please check your configuration path"
+
+ # check extension
+ extension = filepath.suffix
+ assert extension == ".py", "only .py files are supported"
+
+ # import the config as module
+ remove_path = False
+    if str(filepath.parent) not in sys.path:
+        sys.path.insert(0, str(filepath.parent))
+ remove_path = True
+
+ module_name = filepath.stem
+ source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath))
+ module = source_file.load_module() # pylint: disable=W4902,E1120,W1505
+
+ # load into config
+ config = Config()
+
+ for k, v in module.__dict__.items():
+ if k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v):
+ continue
+ else:
+ config._add_item(k, v)
+
+ # remove module
+ del sys.modules[module_name]
+ if remove_path:
+ sys.path.pop(0)
+
+ return config
+
+
+def get_config_value(config, key, default):
+    try:
+        value = config[key]
+    except KeyError:
+        value = default
+    return value
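
A minimal usage sketch of the new `Config` wrapper and `get_config_value` follows; the config path and the keys read below are assumptions chosen for illustration and are not asserted by this diff.

```python
# Minimal sketch (not part of the diff): using internlm.utils.config as added above.
# The config path and the keys accessed here are illustrative assumptions.
from internlm.utils.config import Config, get_config_value

cfg = Config.from_file("./configs/7B_sft.py")  # imports the .py file and wraps its globals

# Nested dicts are wrapped recursively, so values are reachable both ways.
print(cfg.model.num_layers, cfg["model"]["hidden_size"])

# get_config_value falls back to the supplied default when the key is missing.
micro_num = get_config_value(cfg.data, "micro_num", 4)

# update() accepts either a plain dict or another Config instance.
cfg.update({"JOB_NAME": "demo_run"})
```
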
diff --git a/internlm/utils/lazy.py b/internlm/utils/lazy.py
index e67c63aa2..e8dc8d860 100644
--- a/internlm/utils/lazy.py
+++ b/internlm/utils/lazy.py
@@ -1,3 +1,4 @@
+# adapted from https://github.com/open-mmlab/mmengine/blob/main/mmengine/config/lazy.py
# Copyright (c) OpenMMLab. All rights reserved.
import abc
import importlib
@@ -43,7 +44,7 @@ class LazyObject:
During parsing process, the syntax like:
Examples:
- >>> import torch.nn as nn
+ >>> from torch import nn
>>> from mmdet.models import RetinaNet
>>> import mmcls.models
>>> import mmcls.datasets
@@ -52,7 +53,7 @@ class LazyObject:
Will be parsed as:
Examples:
- >>> # import torch.nn as nn
+ >>> # from torch import nn
>>> nn = lazyObject('torch.nn')
>>> # from mmdet.models import RetinaNet
>>> RetinaNet = lazyObject('mmdet.models', 'RetinaNet')
diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py
index 843c31fcb..b02af1e66 100644
--- a/internlm/utils/parallel.py
+++ b/internlm/utils/parallel.py
@@ -13,7 +13,7 @@
ParallelMode,
)
from internlm.core.context import global_context as gpc
-from internlm.model.modules.utils import is_gate_param
+from internlm.model.model_ops.modules.utils import is_gate_param
from internlm.utils.utils import TensorParallelMode
diff --git a/internlm/utils/timeout.py b/internlm/utils/timeout.py
index 5b09f9d5a..d3720ca21 100644
--- a/internlm/utils/timeout.py
+++ b/internlm/utils/timeout.py
@@ -39,7 +39,7 @@ def __exit__(self, error_type, value, traceback):
timeout_threshold_dict = {
- "initialize_distributed_env": 240,
+ "init_distributed": 240,
"nopp_forward_backward_step": 360,
"initialize_model_and_parallel_communicator": 60,
"initialize_optimizer": 60,
diff --git a/internlm/utils/utils.py b/internlm/utils/utils.py
index c9abe3a5a..c45f561fa 100644
--- a/internlm/utils/utils.py
+++ b/internlm/utils/utils.py
@@ -50,13 +50,7 @@ class ModelType(Enum):
INTERNLM2 = 2
LLAMA2 = 3
INTERNLM_MoE = 4
- LLAVA = 5
- QWEN2 = 6
- BAICHUAN2 = 7
- GEMMA = 8
- QWEN2MOE = 9
- MIXTRALMOE = 10
- INTERNLM3 = 11
+ INTERNLM3 = 5
class DataType(Enum):
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index a545f766c..419fa22cb 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -1,20 +1,22 @@
-transformers
+transformers<4.47.0
sentencepiece
datasets
numpy
+scipy
+decorator
tqdm
einops
-psutil
-packaging
-pre-commit
-ninja
-gputil
-pytest
boto3
botocore
-torch-scatter
pyecharts
py-libnuma
pynvml
+psutil
+gputil
tensorboard
--f https://data.pyg.org/whl/torch-2.1.0+cu118.html
+ninja
+packaging
+pre-commit
+pylint
+pytest
+image
diff --git a/requirements/torch.txt b/requirements/torch.txt
deleted file mode 100644
index c9a04b3d8..000000000
--- a/requirements/torch.txt
+++ /dev/null
@@ -1,4 +0,0 @@
---extra-index-url https://download.pytorch.org/whl/cu118
-torch==2.1.0+cu118
-torchvision==0.16.0+cu118
-torchaudio==2.1.0+cu118
diff --git a/setup.py b/setup.py
index f37599543..673b86577 100644
--- a/setup.py
+++ b/setup.py
@@ -1,56 +1,56 @@
import os
-import re
import sys
-import subprocess
-from setuptools import setup, find_packages
-from setuptools.command.install import install
+from typing import List
+
+from setuptools import find_packages, setup
pwd = os.path.dirname(__file__)
+
def readme():
- with open(os.path.join(pwd, 'README.md')) as f:
+ with open(os.path.join(pwd, "README.md")) as f:
content = f.read()
return content
+
def get_version():
- with open(os.path.join(pwd, 'version.txt'), 'r') as f:
+ with open(os.path.join(pwd, "version.txt"), encoding="utf-8") as f:
content = f.read()
return content
-def has_nvcc():
- try:
- subprocess.run(['nvcc', '--version'], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
- return True
- except (subprocess.CalledProcessError, FileNotFoundError):
- return False
-
-def fetch_requirements(path):
- with open(path, 'r') as fd:
- return [r.strip() for r in fd.readlines() if 'torch-scatter' not in r and not r.startswith('-f ')]
-
-if has_nvcc():
- install_requires = [
- fetch_requirements('requirements/runtime.txt'),
- 'rotary_emb',
- 'xentropy',
- ]
-else:
- install_requires = [
- fetch_requirements('requirements/runtime.txt'),
- ]
+
+def get_requires() -> List[str]:
+ with open(os.path.join("requirements", "runtime.txt"), encoding="utf-8") as f:
+ file_content = f.read()
+ lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
+ return lines
+
+
+extra_require = {
+ "torch": ["torch>=2.1.0"],
+ "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "numpy==1.26.4"],
+}
+
+
+if sys.platform.startswith("linux"):
+ extra_require["torch"].append("flash-attn>=2.6.3")
+
setup(
- name='InternEvo',
+ name="InternEvo",
version=get_version(),
- description='an open-sourced lightweight training framework aims to support model pre-training without the need for extensive dependencies',
+ description="Lightweight training framework for LLM",
+ author="InternEvo team",
+ license="Apache 2.0 License",
long_description=readme(),
- long_description_content_type='text/markdown',
- packages=find_packages(),
- install_requires=install_requires,
+ long_description_content_type="text/markdown",
+ packages=find_packages(exclude=["tests"]),
+ install_requires=get_requires(),
+ extras_require=extra_require,
classifiers=[
- 'Programming Language :: Python :: 3.10',
- 'Intended Audience :: Developers',
- 'Intended Audience :: Education',
- 'Intended Audience :: Science/Research',
+ "Programming Language :: Python :: 3.10",
+ "Intended Audience :: Developers",
+ "Intended Audience :: Education",
+ "Intended Audience :: Science/Research",
],
)
diff --git a/tests/common_fixture.py b/tests/common_fixture.py
index e5a8b9aa1..0d7fd95dc 100644
--- a/tests/common_fixture.py
+++ b/tests/common_fixture.py
@@ -5,12 +5,11 @@
import numpy as np
import torch
-import internlm
from internlm.accelerator import get_accelerator
from internlm.core.context import global_context as gpc
-from internlm.core.context.parallel_context import Config
from internlm.data.utils import unpack_type_ids
-from internlm.initialize.launch import args_sanity_check
+from internlm.initialize import initialize_launcher
+from internlm.utils.config import Config
internlm_accelerator = get_accelerator()
@@ -41,14 +40,12 @@
model=dict(
checkpoint=False,
num_attention_heads=32,
- embed_split_hidden=True,
vocab_size=103168,
embed_grad_scale=1,
parallel_output=True,
hidden_size=4096,
num_layers=32,
mlp_ratio=8 / 3,
- apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
@@ -119,9 +116,7 @@ def build_environment(rank, world_size, free_port, config):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(free_port)
internlm_accelerator.empty_cache()
- # launcher="torch"
- internlm.launch_from_torch(config=config, seed=1024)
- args_sanity_check()
+ initialize_launcher(config=config, launcher="torch", distributed_port=8888, seed=1024, args_check=True, dist_backend="nccl")
def seed_all(seed, cuda_deterministic=False):
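
For context, the per-worker bootstrap that the updated test fixtures converge on can be sketched as follows; the rank/world-size environment variables and the free-port plumbing are assumptions about surrounding fixture code the hunk above does not show, and only the `initialize_launcher` call mirrors it directly.

```python
# Minimal sketch (not part of the diff) of the per-worker test bootstrap after this change.
# The RANK/LOCAL_RANK/WORLD_SIZE handling and the free port are illustrative assumptions.
import os

from internlm.accelerator import get_accelerator
from internlm.initialize import initialize_launcher

internlm_accelerator = get_accelerator()


def bootstrap(rank: int, world_size: int, free_port: int, config):
    # torch's env:// rendezvous reads these variables in every spawned worker.
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(free_port)
    internlm_accelerator.empty_cache()
    initialize_launcher(
        config=config,
        launcher="torch",
        distributed_port=8888,
        seed=1024,
        args_check=True,
        dist_backend="nccl",
    )
```
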
diff --git a/tests/test_core/test_pipeline.py b/tests/test_core/test_pipeline.py
index 79105cda4..95ef7687d 100644
--- a/tests/test_core/test_pipeline.py
+++ b/tests/test_core/test_pipeline.py
@@ -5,9 +5,9 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.core.context.parallel_context import Config
from internlm.solver.optimizer.compatible_adamw import new_compatible_adamw
from internlm.utils.common import get_current_device
+from internlm.utils.config import Config
from tests.test_core.utils import (
MlpModel,
MyLoss,
diff --git a/tests/test_core/utils.py b/tests/test_core/utils.py
index 5ccaccaf3..1436e2988 100644
--- a/tests/test_core/utils.py
+++ b/tests/test_core/utils.py
@@ -5,7 +5,6 @@
from torch import nn
from torch.testing import assert_close
-import internlm
from internlm.accelerator import get_accelerator
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
@@ -17,8 +16,9 @@
NonPipelineScheduler,
PipelineScheduler,
)
-from internlm.model.metrics import SchedulerMetricHook
-from internlm.train import initialize_optimizer
+from internlm.initialize import initialize_launcher
+from internlm.initialize.initialize_optimizer import initialize_optimizer
+from internlm.model.model_ops.metrics import SchedulerMetricHook
from internlm.utils.common import get_current_device
internlm_accelerator = get_accelerator()
@@ -155,8 +155,7 @@ def build_environment(rank, world_size, config):
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "33333"
internlm_accelerator.empty_cache()
- # launcher="torch"
- internlm.launch_from_torch(config=config, seed=1024)
+ initialize_launcher(config=config, launcher="torch", distributed_port=8888, seed=1024, args_check=False, dist_backend="nccl")
def loose_close(a, b, dtype: torch.dtype = torch.float32):
diff --git a/tests/test_data/test_batch_sampler.py b/tests/test_data/test_batch_sampler.py
index 7600b7637..cf4400c0f 100644
--- a/tests/test_data/test_batch_sampler.py
+++ b/tests/test_data/test_batch_sampler.py
@@ -6,21 +6,18 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-
-# from internlm.core.context import ParallelMode
-from internlm.core.context.parallel_context import Config
from internlm.core.trainer import TrainState
from internlm.data import (
build_train_loader_with_data_type,
build_valid_loader_with_data_type,
)
-from internlm.eval.evaluation import (
- switch_evaluation_mode,
- switch_evaluation_pipeline_scheduler,
-)
-from internlm.train import load_new_batch
+from internlm.eval import switch_evaluation_mode, switch_evaluation_pipeline_scheduler
+from internlm.core.trainer import load_new_batch
+
+# from internlm.core.context import ParallelMode
+from internlm.utils.config import Config
-# from internlm.core.context.parallel_context import global_context as gpc
+# from internlm.core.context import global_context as gpc
from tests.test_core.utils import build_environment, init_model_and_optim
micro_bszs = [1, 2]
diff --git a/tests/test_infer/test_generate.py b/tests/test_infer/test_generate.py
index 14741b494..ad8f36a3a 100644
--- a/tests/test_infer/test_generate.py
+++ b/tests/test_infer/test_generate.py
@@ -5,8 +5,10 @@
from sentencepiece import SentencePieceProcessor
from internlm.apis.inference import SequenceGenerator, batch_tokenize
-from internlm.initialize import initialize_distributed_env # noqa: E402
-from internlm.train import initialize_model_and_parallel_communicator
+from internlm.initialize import initialize_launcher # noqa: E402
+from internlm.initialize.initialize_model import (
+ initialize_model_and_parallel_communicator,
+)
def set_seed(seed: int = 1024):
@@ -23,7 +25,6 @@ def load_and_generate(path, model_type="INTERNLM2", tokenizer_path=""):
model_cfg = os.path.join(path, "model_config.pt")
model_wt = os.path.join(path, "model_tp0_pp0.pt")
model_config = torch.load(model_cfg)
- model_config["apply_post_layer_norm"] = False
if model_config.get("adapt_hf") is not None:
model_config.pop("adapt_hf")
evo_cfg = dict(
@@ -36,7 +37,7 @@ def load_and_generate(path, model_type="INTERNLM2", tokenizer_path=""):
sequence_parallel=0,
),
)
- initialize_distributed_env(evo_cfg, master_port=23574, args_check=False)
+ initialize_launcher(evo_cfg, distributed_port=23574, args_check=False)
tokenizer = SentencePieceProcessor(tokenizer_path) # pylint: disable=E1121
diff --git a/tests/test_infer/test_trainer_generate.py b/tests/test_infer/test_trainer_generate.py
index c3149dda3..4b6e0967d 100644
--- a/tests/test_infer/test_trainer_generate.py
+++ b/tests/test_infer/test_trainer_generate.py
@@ -3,23 +3,23 @@
import pytest
from sentencepiece import SentencePieceProcessor
-import internlm # noqa: E402
from internlm.apis.inference import SequenceGenerator, batch_tokenize
from internlm.checkpoint import CheckpointManager # noqa: E402
from internlm.core.context import global_context as gpc # noqa: E402
-from internlm.core.trainer import TrainState, Trainer # noqa: E402
+from internlm.core.trainer import Trainer, TrainState # noqa: E402
from internlm.data import build_train_loader_with_data_type # noqa: E402
-from internlm.initialize import initialize_distributed_env # noqa: E402
-from internlm.model.losses import InternLoss # noqa: E402
-from internlm.train import ( # noqa: E402
- get_scheduler_hooks,
+from internlm.initialize import initialize_launcher # noqa: E402
+from internlm.initialize.initialize_model import ( # noqa: E402
initialize_model_and_parallel_communicator,
- initialize_optimizer,
)
+from internlm.initialize.initialize_optimizer import initialize_optimizer
+from internlm.initialize import initialize_trainer
+from internlm.model.model_ops.losses import InternLoss # noqa: E402
+from internlm.core.trainer import get_scheduler_hooks # noqa: E402
def setup_generator(config, tokenizer):
- initialize_distributed_env(config=config)
+ initialize_launcher(config=config)
model, isp_communicator = initialize_model_and_parallel_communicator()
@@ -45,7 +45,7 @@ def setup_generator(config, tokenizer):
ckpt_manager.try_resume_training(train_state)
# initialize trainer
- engine, scheduler = internlm.initialize_trainer(
+ engine, scheduler = initialize_trainer(
model=model,
optimizer=optimizer,
criterion=criterion,
diff --git a/tests/test_model/test_embedding.py b/tests/test_model/test_embedding.py
index d8b58f552..b7252e376 100644
--- a/tests/test_model/test_embedding.py
+++ b/tests/test_model/test_embedding.py
@@ -3,7 +3,7 @@
import pytest
import torch
-from internlm.model.modules.embedding import Embedding1D
+from internlm.model.model_ops.modules.embedding import Embedding1D
from internlm.utils.common import get_current_device
from tests.common_fixture import find_free_port
from tests.test_model.test_model_internlm import build_environment, seed_all
diff --git a/tests/test_model/test_feed_forward.py b/tests/test_model/test_feed_forward.py
index 311f30d7e..c55e55716 100644
--- a/tests/test_model/test_feed_forward.py
+++ b/tests/test_model/test_feed_forward.py
@@ -1,7 +1,7 @@
import pytest
import torch
-from internlm.model.modules.mlp import new_feed_forward, split_fused_mlp_weight
+from internlm.model.model_ops.modules.mlp import new_feed_forward, split_fused_mlp_weight
from internlm.utils.common import get_current_device
SEQ_LEN = 64
diff --git a/tests/test_model/test_fused_precision/test_fused_precision.py b/tests/test_model/test_fused_precision/test_fused_precision.py
index d0b79aaef..98d3511c3 100644
--- a/tests/test_model/test_fused_precision/test_fused_precision.py
+++ b/tests/test_model/test_fused_precision/test_fused_precision.py
@@ -6,9 +6,11 @@
from torch import nn
from internlm.core.naive_amp import NaiveAMPModel, set_fp32_attr_to_module
-from internlm.model.modeling_internlm import InternLM1Decoder
-from internlm.train.pipeline import initialize_parallel_communicator
-from internlm.train.utils import create_param_groups
+from internlm.initialize.initialize_communicator import initialize_parallel_communicator
+from internlm.model.model_implementations.transformers.modeling_internlm import (
+ InternLM1Decoder,
+)
+from internlm.initialize.initialize_optimizer import create_param_groups
from internlm.utils.common import get_current_device
from tests.common_fixture import find_free_port
from tests.test_model.test_model_internlm import build_environment, seed_all
diff --git a/tests/test_model/test_model_internlm.py b/tests/test_model/test_model_internlm.py
index e2655d291..084702d59 100644
--- a/tests/test_model/test_model_internlm.py
+++ b/tests/test_model/test_model_internlm.py
@@ -6,25 +6,27 @@
import torch
from torch import nn
-import internlm
from internlm.accelerator import get_accelerator
from internlm.core.context import ParallelMode
-from internlm.core.context.parallel_context import Config
-from internlm.core.context.parallel_context import global_context as gpc
-from internlm.core.parallel.comm.tensor import (
+from internlm.core.context import global_context as gpc
+from internlm.core.parallel.comm import (
HeadTensorParallelCommunicator,
LinearRole,
TensorParallelCommunicator,
)
from internlm.core.parallel.comm.utils import gather_forward_split_backward
-from internlm.model.modeling_internlm import InternLM1Decoder
-from internlm.model.modules.linear import (
+from internlm.initialize import initialize_launcher
+from internlm.model.model_implementations.transformers.modeling_internlm import (
+ InternLM1Decoder,
+)
+from internlm.model.model_ops.modules.linear import (
ColumnParallelLinear,
RowParallelLinear,
ScaleColumnParallelLinear,
new_linear,
)
from internlm.utils.common import get_current_device
+from internlm.utils.config import Config
from tests.common_fixture import find_free_port
internlm_accelerator = get_accelerator()
@@ -52,14 +54,12 @@
model=dict(
checkpoint=False,
num_attention_heads=2,
- embed_split_hidden=True,
vocab_size=103168,
embed_grad_scale=1,
parallel_output=True,
hidden_size=1024,
num_layers=2,
mlp_ratio=1,
- apply_post_layer_norm=False,
dtype=torch.bfloat16,
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
@@ -83,8 +83,7 @@ def build_environment(rank, world_size, free_port):
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = free_port
internlm_accelerator.empty_cache()
- # launcher="torch"
- internlm.launch_from_torch(config=config, seed=1024)
+ initialize_launcher(config=config, launcher="torch", distributed_port=8888, seed=1024, args_check=False, dist_backend="nccl")
def seed_all(seed, cuda_deterministic=False):
diff --git a/tests/test_model/test_norm.py b/tests/test_model/test_norm.py
index 83861b365..2e32d8b1b 100644
--- a/tests/test_model/test_norm.py
+++ b/tests/test_model/test_norm.py
@@ -3,7 +3,7 @@
import pytest
import torch
-from internlm.model.modules.norm import new_layer_norm
+from internlm.model.model_ops.modules.norm import new_layer_norm
from internlm.utils.common import get_current_device
from tests.common_fixture import find_free_port
from tests.test_model.test_model_internlm import build_environment, seed_all
diff --git a/tests/test_model/test_npu_ops/test_flash_attention.py b/tests/test_model/test_npu_ops/test_flash_attention.py
index a2a8b91b8..96c11cde3 100644
--- a/tests/test_model/test_npu_ops/test_flash_attention.py
+++ b/tests/test_model/test_npu_ops/test_flash_attention.py
@@ -12,11 +12,14 @@
from torch import nn
from internlm.accelerator import AcceleratorType, get_accelerator
-from internlm.core.context import Config
from internlm.core.context import global_context as gpc
-from internlm.model.ops.attention import SelfAttention
-from internlm.model.ops.utils import pack_output_after_attn, unpack_qkv_before_attn
+from internlm.model.model_ops.ops.attention import SelfAttention
+from internlm.model.model_ops.ops.utils import (
+ pack_output_after_attn,
+ unpack_qkv_before_attn,
+)
from internlm.utils.common import get_current_device, set_random_seed
+from internlm.utils.config import Config
HEAD_NUM = 32
HIDDEN_SZIE = 4096
@@ -139,7 +142,7 @@ def npu_transform(B, S, N_KV, dtype):
def deeplink_fwd_transform(B, S, N_KV, dtype):
from deeplink_ext.internevo_ops import FlashSelfAttention
- from internlm.model.modules.multi_head_attention import CrossAttention
+ from internlm.model.model_ops.modules.multi_head_attention import CrossAttention
set_random_seed(1024)
softmax_scale = 1 / math.sqrt(HEAD_DIM)
diff --git a/tests/test_model/test_npu_ops/test_npu_rmsnorm.py b/tests/test_model/test_npu_ops/test_npu_rmsnorm.py
index adeb37c00..74116655f 100644
--- a/tests/test_model/test_npu_ops/test_npu_rmsnorm.py
+++ b/tests/test_model/test_npu_ops/test_npu_rmsnorm.py
@@ -2,8 +2,8 @@
import torch
from internlm.accelerator import AcceleratorType, get_accelerator
-from internlm.model.ops.norm import _RMSNorm as RMSNormTorch
-from internlm.model.ops.norm import _RMSNormNPU as RMSNormNPU
+from internlm.model.model_ops.ops.norm import _RMSNorm as RMSNormTorch
+from internlm.model.model_ops.ops.norm import _RMSNormNPU as RMSNormNPU
from internlm.utils.common import get_current_device
internlm_accelerator = get_accelerator()
diff --git a/tests/test_model/test_npu_ops/test_rotary_embed.py b/tests/test_model/test_npu_ops/test_rotary_embed.py
index 8fca38ce2..71f5d4312 100644
--- a/tests/test_model/test_npu_ops/test_rotary_embed.py
+++ b/tests/test_model/test_npu_ops/test_rotary_embed.py
@@ -3,7 +3,7 @@
from torch import nn
from internlm.accelerator import get_accelerator
-from internlm.model.ops.rotary_emb import (
+from internlm.model.model_ops.ops.rotary_emb import (
ApplyRotaryEmb,
rotary_emb_in_rotate_half_style,
)
diff --git a/tests/test_solver/test_optimizer.py b/tests/test_solver/test_optimizer.py
index ca470ffc9..617035c76 100644
--- a/tests/test_solver/test_optimizer.py
+++ b/tests/test_solver/test_optimizer.py
@@ -9,12 +9,13 @@
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
-import internlm
from internlm.accelerator import get_accelerator
-from internlm.core.context.parallel_context import Config, ParallelMode
-from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler
+from internlm.core.context import ParallelMode
+from internlm.core.parallel.comm import ParamAsyncBcastHandler
+from internlm.initialize import initialize_launcher
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.utils.common import get_current_device
+from internlm.utils.config import Config
internlm_accelerator = get_accelerator()
@@ -95,8 +96,7 @@ def build_environment(rank, world_size):
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12345"
internlm_accelerator.empty_cache()
- # launcher="torch"
- internlm.launch_from_torch(config=config, seed=1024)
+ initialize_launcher(config=config, launcher="torch", distributed_port=8888, seed=1024, args_check=False, dist_backend="nccl")
def loose_close(a, b, dtype: torch.dtype = torch.float32):
diff --git a/tests/test_training/7B_check_acc.py b/tests/test_training/7B_check_acc.py
index 70b612c1d..eb8d32705 100644
--- a/tests/test_training/7B_check_acc.py
+++ b/tests/test_training/7B_check_acc.py
@@ -128,7 +128,6 @@
checkpoint=False,
num_chunks=1,
num_attention_heads=NUM_ATTENTION_HEAD,
- embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
@@ -136,7 +135,6 @@
num_layers=NUM_LAYER,
no_bias=True,
mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
diff --git a/tests/test_training/7B_check_init.py b/tests/test_training/7B_check_init.py
index 27794dd02..9097a47c6 100644
--- a/tests/test_training/7B_check_init.py
+++ b/tests/test_training/7B_check_init.py
@@ -133,7 +133,6 @@
checkpoint=False,
num_chunks=1,
num_attention_heads=NUM_ATTENTION_HEAD,
- embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
@@ -141,7 +140,6 @@
num_layers=NUM_LAYER,
no_bias=True,
mlp_ratio=MLP_RATIO,
- apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
diff --git a/tests/test_training/test_forward_output_no_fa.py b/tests/test_training/test_forward_output_no_fa.py
index ab81dbeed..f76cdcd52 100644
--- a/tests/test_training/test_forward_output_no_fa.py
+++ b/tests/test_training/test_forward_output_no_fa.py
@@ -7,21 +7,21 @@
import pytest
import torch
-import internlm
from internlm.accelerator import get_accelerator
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.core.context.parallel_context import Config
from internlm.core.trainer import Trainer
from internlm.data import build_train_loader_with_data_type
-from internlm.initialize.launch import args_sanity_check
-from internlm.model.losses import InternLoss
-from internlm.model.metrics import AccPerplex, SchedulerMetricHook
-from internlm.train import (
+from internlm.initialize import initialize_launcher
+from internlm.initialize.initialize_model import (
initialize_model_and_parallel_communicator,
- initialize_optimizer,
)
+from internlm.initialize.initialize_optimizer import initialize_optimizer
+from internlm.initialize import initialize_trainer
+from internlm.model.model_ops.losses import InternLoss
+from internlm.model.model_ops.metrics import AccPerplex, SchedulerMetricHook
from internlm.utils.common import get_current_device
+from internlm.utils.config import Config
from internlm.utils.logger import get_logger
logger = get_logger(__file__)
@@ -54,14 +54,12 @@
model=dict(
checkpoint=True,
num_attention_heads=32,
- embed_split_hidden=True,
vocab_size=92544,
embed_grad_scale=1,
parallel_output=False,
hidden_size=4096,
num_layers=32,
mlp_ratio=8 / 3,
- apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
@@ -133,8 +131,7 @@ def build_environment(rank, world_size, free_port, config):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(free_port)
internlm_accelerator.empty_cache()
- internlm.launch_from_torch(config=config, seed=1024)
- args_sanity_check()
+ initialize_launcher(config=config, launcher="torch", distributed_port=8888, seed=1024, args_check=True, dist_backend="nccl")
def seed_all(seed, cuda_deterministic=False):
@@ -198,7 +195,7 @@ def train_check_output(args):
),
]
- engine, scheduler = internlm.initialize_trainer(
+ engine, scheduler = initialize_trainer(
model=model,
optimizer=optimizer,
criterion=criterion,
diff --git a/tests/test_training/test_load_ckpt_loss.py b/tests/test_training/test_load_ckpt_loss.py
index f9516c279..7a554622e 100644
--- a/tests/test_training/test_load_ckpt_loss.py
+++ b/tests/test_training/test_load_ckpt_loss.py
@@ -1,6 +1,7 @@
import multiprocessing as mp
from internlm.accelerator import get_accelerator
+from internlm.initialize.initialize_optimizer import initialize_optimizer
backup_ForkingPickler = mp.reduction.ForkingPickler
backup_dump = mp.reduction.dump
@@ -14,7 +15,6 @@
import torch # noqa: E402 #pylint: disable=wrong-import-position
import torch.distributed as dist # noqa: E402 #pylint: disable=wrong-import-position
-import internlm # noqa: E402 #pylint: disable=wrong-import-position
from internlm.checkpoint import ( # noqa: E402 #pylint: disable=wrong-import-position
CheckpointManager,
)
@@ -24,35 +24,39 @@
from internlm.core.context import ( # noqa: E402 #pylint: disable=wrong-import-position
global_context as gpc,
)
-from internlm.core.context.parallel_context import ( # noqa: E402 #pylint: disable=wrong-import-position
- Config,
-)
from internlm.core.trainer import ( # noqa: E402 #pylint: disable=wrong-import-position
- TrainState,
Trainer,
+ TrainState,
)
from internlm.data import ( # noqa: E402 #pylint: disable=wrong-import-position
build_train_loader_with_data_type,
)
-from internlm.initialize.launch import ( # noqa: E402 #pylint: disable=wrong-import-position
- args_sanity_check,
+from internlm.initialize import ( # noqa: E402 #pylint: disable=wrong-import-position
+ initialize_launcher
+)
+from internlm.initialize.initialize_model import ( # noqa: E402 #pylint: disable=wrong-import-position
+ initialize_model_and_parallel_communicator,
)
-from internlm.model.losses import ( # noqa: E402 #pylint: disable=wrong-import-position
+from internlm.initialize import ( # noqa: E402 #pylint: disable=wrong-import-position
+ initialize_trainer,
+)
+from internlm.model.model_ops.losses import ( # noqa: E402 #pylint: disable=wrong-import-position
InternLoss,
)
-from internlm.model.metrics import ( # noqa: E402 #pylint: disable=wrong-import-position
+from internlm.model.model_ops.metrics import ( # noqa: E402 #pylint: disable=wrong-import-position
AccPerplex,
SchedulerMetricHook,
)
-from internlm.train import ( # noqa: E402 #pylint: disable=wrong-import-position
- initialize_model_and_parallel_communicator,
- initialize_optimizer,
+from internlm.core.trainer import ( # noqa: E402 #pylint: disable=wrong-import-position
load_new_batch,
)
from internlm.utils.common import ( # noqa: E402 #pylint: disable=wrong-import-position
get_current_device,
launch_time,
)
+from internlm.utils.config import ( # noqa: E402 #pylint: disable=wrong-import-position
+ Config,
+)
from internlm.utils.logger import ( # noqa: E402 #pylint: disable=wrong-import-position
get_logger,
)
@@ -93,14 +97,12 @@
model=dict(
checkpoint=False,
num_attention_heads=16,
- embed_split_hidden=True,
vocab_size=103168,
embed_grad_scale=1,
parallel_output=True,
hidden_size=4096,
num_layers=16,
mlp_ratio=8 / 3,
- apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
@@ -173,9 +175,7 @@ def build_environment(rank, world_size, free_port, config):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(free_port)
internlm_accelerator.empty_cache()
- # launcher="torch"
- internlm.launch_from_torch(config=config, seed=1024)
- args_sanity_check()
+ initialize_launcher(config=config, launcher="torch", distributed_port=8888, seed=1024, args_check=True, dist_backend="nccl")
def seed_all(seed, cuda_deterministic=False):
@@ -265,7 +265,7 @@ def train_model(args):
),
]
- engine, scheduler = internlm.initialize_trainer(
+ engine, scheduler = initialize_trainer(
model=model,
optimizer=optimizer,
criterion=criterion,
diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py
index fb6cd820b..904a757b3 100644
--- a/tests/test_training/test_loss.py
+++ b/tests/test_training/test_loss.py
@@ -1,28 +1,36 @@
import math
import os
+from functools import reduce
import pytest
import torch
import torch.distributed as dist
-import internlm
from internlm.accelerator import AcceleratorType, get_accelerator
from internlm.checkpoint import CheckpointManager
-from internlm.core.context import Config, ParallelMode
+from internlm.checkpoint.load_funcs import LOAD_FUNC_DICT
+from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.core.trainer import Trainer, TrainState
-from internlm.data import build_train_loader_with_data_type
-from internlm.initialize import initialize_distributed_env
-from internlm.model.losses import InternLoss
-from internlm.train import (
+from internlm.core.parallel.shard import partition_uniform
+from internlm.core.trainer import (
+ Trainer,
+ TrainState,
get_scheduler_hooks,
- initialize_model_and_parallel_communicator,
- initialize_optimizer,
load_new_batch,
)
+from internlm.data import build_train_loader_with_data_type
+from internlm.initialize import initialize_launcher, initialize_trainer
+from internlm.initialize.initialize_model import (
+ initialize_model_and_parallel_communicator,
+)
+from internlm.initialize.initialize_optimizer import initialize_optimizer
+from internlm.model.model_ops.losses import InternLoss
+from internlm.model.model_ops.utils import get_parallel_size_from_file
from internlm.utils.common import BatchSkipper, launch_time
+from internlm.utils.config import Config
from internlm.utils.gputest import empty_cache_and_diag
from internlm.utils.megatron_timers import megatron_timer as timer
+from internlm.utils.storage_manager import get_fns, llm_load
CONFIG_FILE_PATH = os.getenv("CONFIG_FILE_PATH", "./configs/7B_internlm2.py")
INTERNLM2_CKPT_PATH = os.path.join(os.environ["share_path"], "quailty_assurance/test_loss_pri/model_ckpt")
@@ -43,11 +51,199 @@
4.799427032470703,
]
-
cur_loss_list = []
internlm_accelerator = get_accelerator()
+def load_internlm2_with_dynamic_parallel_size(folder, model):
+ """Load InternLM2 with dynamic parallel size."""
+ assert folder is not None, "Please specify the folder of the pretrained model"
+ assert gpc.config.model_type in ["INTERNLM2"], "dynamic_parallel is only for INTERNLM2"
+
+ fns = get_fns(folder)
+ model_fns, old_tp, old_pp = get_parallel_size_from_file(fns) # pylint: disable=W0612
+
+ tp = gpc.get_world_size(ParallelMode.TENSOR)
+ tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
+ assert old_tp % tp == 0 or tp % old_tp == 0, (
+ f"Expected TP size in loaded checkpoint to be fit with TP size in current config, but got {old_tp} in "
+ f"checkpoint and {tp} in current config"
+ )
+
+ correspond_tps = []
+
+ if old_tp <= tp:
+ correspond_tps.append(tp_rank // (tp // old_tp))
+ ratio = tp // old_tp
+ rank = tp_rank % ratio
+ else:
+ for i in range(old_tp // tp):
+ correspond_tps.append(tp_rank * (old_tp // tp) + i)
+ rank = 0
+ ratio = 1
+
+ current_states = {}
+
+ pp = gpc.get_world_size(ParallelMode.PIPELINE) # noqa: F841 # pylint: disable=W0612
+
+    assert gpc.config.model.num_chunks == 1, "Only num_chunks == 1 is supported when loading with dynamic parallel size"
+
+ old_pp_partition = partition_uniform(gpc.config.model.num_layers, old_pp, 1)
+
+ for idx, parts in enumerate(old_pp_partition):
+ start, end = parts[0]
+ if model.last_layer <= start or model.first_layer >= end:
+ continue
+ tmp_states = {}
+
+ for correspond_tp in correspond_tps:
+ model_name = f"model_tp{correspond_tp}_pp{idx}.pt"
+ states = llm_load(os.path.join(folder, model_name), map_location="cpu")
+ states = {k.replace("model.", ""): v for k, v in states.items()}
+ for i in range(start, end):
+ if i >= model.last_layer:
+ break
+ if i < model.first_layer:
+ continue
+
+ for name in list(states.keys()):
+ if f".{i-start}." in name:
+ to_name = name.replace(f".{i-start}.", f".{i-model.first_layer}.")
+
+ if gpc.config.model_type == "INTERNLM2":
+ if "norm" in name:
+ tmp_states[to_name] = [states.pop(name)]
+ elif any(x in name for x in ("wo", "w2")):
+ tmp_states[to_name] = tmp_states.get(to_name, [])
+ tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=1)[rank])
+ elif any(x in name for x in ("w1", "w3")):
+ tmp_states[to_name] = tmp_states.get(to_name, [])
+ tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=0)[rank])
+ elif any(x in name for x in ("wqkv",)):
+ tmp_states[to_name] = tmp_states.get(to_name, [])
+ if tp > gpc.config.model.num_kv_attention_heads:
+ assert old_tp <= gpc.config.model.num_kv_attention_heads, (
+ f"`old_tp ({old_tp}) => tp ({tp})` is not supported. "
+ "At least one of `tp` and `old_tp` should be less than or "
+ "equal to `num_kv_attention_heads`"
+ )
+ # Suitable for cases where the num_kv_attention_head is small,
+ # but you want to have a large TP Size
+ q_per_kv = (
+ gpc.config.model.num_attention_heads // gpc.config.model.num_kv_attention_heads
+ )
+ head_dim = gpc.config.model.hidden_size // gpc.config.model.num_attention_heads
+ index = torch.concat(
+ (
+ torch.arange(q_per_kv).chunk(ratio, dim=0)[tp_rank % ratio],
+ torch.tensor([q_per_kv, q_per_kv + 1]),
+ )
+ )
+ index = index + (q_per_kv + 2) * (tp_rank // ratio)
+ index = index % (
+ (q_per_kv + 2) * (gpc.config.model.num_kv_attention_heads / old_tp)
+ )
+ index = index * head_dim
+ index = index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(
+ index.shape[0]
+ )
+ tmp_states[to_name].append(
+ torch.index_select(states.pop(name), 0, index.to(torch.int32))
+ )
+ else:
+ tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=0)[rank])
+ else:
+ raise KeyError(f"Unknown key {name}.")
+
+ else:
+ assert False, "unsupported model type"
+
+ if "tok_embeddings.weight" in states and model.first_layer == 0:
+ tmp_states["tok_embeddings.weight"] = tmp_states.get("tok_embeddings.weight", [])
+ tmp_states["tok_embeddings.weight"].append(states["tok_embeddings.weight"].chunk(ratio, dim=1)[rank])
+ if "output.weight" in states and model.last_layer == gpc.config.model.num_layers:
+ tmp_states["norm.weight"] = [states["norm.weight"]]
+ tmp_states["output.weight"] = tmp_states.get("output.weight", [])
+ tmp_states["output.weight"].append(states["output.weight"].chunk(ratio, dim=0)[rank])
+
+ states = {}
+
+ for name in list(tmp_states.keys()):
+ data = tmp_states.pop(name)
+ if len(data) == 1:
+ current_states[name] = data[0]
+ else:
+ current_states[name] = torch.concat(
+ data, dim=1 if name == "tok_embeddings.weight" or any(x in name for x in ("wo", "w2")) else 0
+ )
+ # Merge copied kv heads
+ if "wqkv" in name and old_tp > gpc.config.model.num_kv_attention_heads:
+ assert (
+ tp <= gpc.config.model.num_kv_attention_heads
+ ), "new_tp should be less than or equal to num_kv_attention_heads"
+ head_dim = gpc.config.model.hidden_size // gpc.config.model.num_attention_heads
+ q_per_kv = gpc.config.model.num_attention_heads // gpc.config.model.num_kv_attention_heads
+ copied_times = old_tp // gpc.config.model.num_kv_attention_heads
+ cur_q_per_kv = q_per_kv // copied_times
+
+ # pylint: disable=all
+ def duplicate_kv_index(i):
+ if i % (cur_q_per_kv + 2) >= cur_q_per_kv:
+ return i
+ else:
+ return -100
+
+ def unique_kv_index(i):
+ if i // (cur_q_per_kv + 2) == copied_times - 1 or i % (cur_q_per_kv + 2) < cur_q_per_kv:
+ return i
+ else:
+ return -100
+
+ # pylint: enable=all
+
+ # Verify
+ duplicate_index = [duplicate_kv_index(i) for i in range((cur_q_per_kv + 2) * copied_times)]
+ duplicate_index = [i for i in duplicate_index if i != -100]
+ duplicate_index = _duplicate_index = torch.tensor(duplicate_index)
+ for i in range(gpc.config.model.num_kv_attention_heads // tp - 1):
+ duplicate_index = torch.concat(
+ (duplicate_index, _duplicate_index + duplicate_index.max() + 1), dim=0
+ )
+ duplicate_kv = []
+ for index in duplicate_index.reshape(-1, copied_times * 2).chunk(copied_times, dim=-1):
+ index = index.reshape(-1) * head_dim
+ index = index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(index.shape[0])
+ duplicate_kv.append(torch.index_select(current_states[name], 0, index))
+ assert reduce(
+ lambda x, y: x and y,
+ [torch.allclose(duplicate_kv[0], x, atol=1e-5) for x in duplicate_kv[1:]],
+ ), "Copied kv heads are not equal after training!"
+
+ # Merge: keep every q head but only one copy of each k/v head
+ unique_index = [unique_kv_index(i) for i in range((cur_q_per_kv + 2) * copied_times)]
+ unique_index = [i for i in unique_index if i != -100]
+ unique_index = _unique_index = torch.tensor(unique_index)
+ for i in range(gpc.config.model.num_kv_attention_heads // tp - 1):
+ unique_index = torch.concat((unique_index, _unique_index + unique_index.max() + 1), dim=0)
+ unique_index = unique_index * head_dim
+ unique_index = unique_index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(
+ unique_index.shape[0]
+ )
+ current_states[name] = torch.index_select(current_states[name], 0, unique_index)
+ missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)
+
+ if gpc.get_local_rank(ParallelMode.DATA) == 0:
+ pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
+ print(
+ f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
+ f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}",
+ flush=True,
+ )
+
+
+LOAD_FUNC_DICT["internlm2_test"] = load_internlm2_with_dynamic_parallel_size
+
+
def train(
dp_size: int = 1,
tp_size: int = 1,
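The index arithmetic in the split path above (new `tp` larger than `num_kv_attention_heads`) is compact, so here is a minimal standalone sketch of it. It assumes, as an interpretation of the loader, that `ratio` is the ratio between the new and old tensor-parallel sizes and that each old-tp shard of the fused `wqkv` weight stores, per kv group, `q_per_kv` query heads followed by one key head and one value head; the helper name and toy sizes are made up for illustration.

```python
import torch

def wqkv_row_index(tp_rank, tp, old_tp, num_heads, num_kv_heads, head_dim):
    """Rows of one old-tp wqkv shard that the given new-tp rank keeps (sketch)."""
    ratio = tp // old_tp                      # assumption: same meaning as `ratio` above
    q_per_kv = num_heads // num_kv_heads
    # this rank's slice of the q heads inside one kv group, plus the shared k and v heads
    index = torch.concat(
        (
            torch.arange(q_per_kv).chunk(ratio, dim=0)[tp_rank % ratio],
            torch.tensor([q_per_kv, q_per_kv + 1]),
        )
    )
    # shift to the kv group this rank reads, then wrap inside a single old-tp shard
    index = index + (q_per_kv + 2) * (tp_rank // ratio)
    index = index % ((q_per_kv + 2) * (num_kv_heads // old_tp))
    # expand head-level indices into row-level indices
    index = index * head_dim
    return index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(index.shape[0])

# toy model: 8 q heads, 2 kv heads, head_dim 4, re-sharding old_tp=2 -> tp=4
for r in range(4):
    heads = wqkv_row_index(r, tp=4, old_tp=2, num_heads=8, num_kv_heads=2, head_dim=4)[::4] // 4
    print(f"rank {r} keeps heads {heads.tolist()} of its old shard")
```

Ranks that share a kv group end up selecting the same k/v rows, which is exactly the duplication that the merge path has to undo when converting back to a smaller `tp`.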
@@ -129,7 +325,7 @@ def train(
config.model.parallel_output = False
config.model.checkpoint = True
- initialize_distributed_env(config=config, launcher=launcher)
+ initialize_launcher(config=config, launcher=launcher)
assert hasattr(gpc, "config") and gpc.config is not None
gpc.config.ckpt.need_metadata = False
@@ -201,7 +397,7 @@ def train(
metric = None
# initialize trainer
- engine, scheduler = internlm.initialize_trainer(
+ engine, scheduler = initialize_trainer(
model=model,
optimizer=optimizer,
criterion=criterion,
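For the reverse re-sharding (old `tp` larger than `num_kv_attention_heads`), the kv heads were replicated across the old shards, so the merge path above first checks that every copy is still numerically identical and then keeps a single one. A toy sketch of the two index filters with made-up sizes (`cur_q_per_kv = 2`, `copied_times = 2`), assuming the layout described in the hunk: after concatenating the old shards, each kv group appears as `copied_times` consecutive blocks of `(cur_q_per_kv + 2)` heads, i.e. `cur_q_per_kv` query heads followed by one copied k and one copied v head.

```python
cur_q_per_kv, copied_times = 2, 2
block = cur_q_per_kv + 2

def duplicate_kv_index(i):
    # keep only the k/v positions of every block; all copies should hold identical weights
    return i if i % block >= cur_q_per_kv else -100

def unique_kv_index(i):
    # keep every q position, but the k/v positions of the last copy only
    return i if i // block == copied_times - 1 or i % block < cur_q_per_kv else -100

positions = range(block * copied_times)
print([duplicate_kv_index(i) for i in positions])  # [-100, -100, 2, 3, -100, -100, 6, 7]
print([unique_kv_index(i) for i in positions])     # [0, 1, -100, -100, 4, 5, 6, 7]
```

The `-100` sentinels are filtered out before the surviving head indices are scaled by `head_dim`, mirroring the list comprehensions in the hunk above.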
diff --git a/tests/test_training/test_no_fa_train_temp.py b/tests/test_training/test_no_fa_train_temp.py
index 0b0493bb2..e0715d67d 100644
--- a/tests/test_training/test_no_fa_train_temp.py
+++ b/tests/test_training/test_no_fa_train_temp.py
@@ -2,19 +2,19 @@
import pytest
-import internlm
from internlm.accelerator import get_accelerator
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.trainer import Trainer
from internlm.data import build_train_loader_with_data_type
-from internlm.model.losses import InternLoss
-from internlm.model.metrics import AccPerplex
-from internlm.train import (
- get_scheduler_hooks,
+from internlm.initialize.initialize_model import (
initialize_model_and_parallel_communicator,
- initialize_optimizer,
)
+from internlm.initialize.initialize_optimizer import initialize_optimizer
+from internlm.initialize import initialize_trainer
+from internlm.model.model_ops.losses import InternLoss
+from internlm.model.model_ops.metrics import AccPerplex
+from internlm.core.trainer import get_scheduler_hooks
from internlm.utils.logger import get_logger
from tests.common_fixture import (
build_environment,
@@ -50,7 +50,7 @@ def train_check(args):
# set seed
seed_all(1024)
- # initialize model and isp communicator
+ # initialize model and isp communicator
model, isp_communicator = initialize_model_and_parallel_communicator()
# initialize loss function
@@ -67,7 +67,7 @@ def train_check(args):
dataset_types=dataset_types,
)
- engine, scheduler = internlm.initialize_trainer(
+ engine, scheduler = initialize_trainer(
model=model,
optimizer=optimizer,
criterion=criterion,
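The same import relocations recur in every test and tool file touched below; as a reading aid, here is the mapping in one place, with the new locations taken verbatim from this patch and the old locations noted in comments (nothing beyond what the diff itself states is implied about the package layout):

```python
from internlm.initialize import initialize_launcher, initialize_trainer          # replaces initialize_distributed_env / internlm.initialize_trainer
from internlm.initialize.initialize_model import initialize_model_and_parallel_communicator   # was in internlm.train
from internlm.initialize.initialize_optimizer import initialize_optimizer        # was in internlm.train
from internlm.initialize.initialize_profiler import initialize_llm_profile       # was in internlm.train
from internlm.core.trainer import get_scheduler_hooks, record_current_batch_training_metrics  # were in internlm.train
from internlm.model.model_ops.losses import InternLoss                           # was internlm.model.losses
from internlm.model.model_ops.metrics import AccPerplex, SchedulerMetricHook     # was internlm.model.metrics
from internlm.model.model_implementations.builder import create_model            # was internlm.model.builder
from internlm.eval import evaluate_on_val_dls, switch_evaluation_mode            # was internlm.eval.evaluation
from internlm.utils.config import Config, get_config_value                       # was internlm.core.context.parallel_context / internlm.initialize.launch
```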
diff --git a/tests/test_training/test_norm_weight.py b/tests/test_training/test_norm_weight.py
index 1306da69b..c6c1be04e 100644
--- a/tests/test_training/test_norm_weight.py
+++ b/tests/test_training/test_norm_weight.py
@@ -5,19 +5,19 @@
import pytest
import torch
-import internlm
from internlm.accelerator import get_accelerator
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.trainer import Trainer
from internlm.data import build_train_loader_with_data_type
-from internlm.model.losses import InternLoss
-from internlm.model.metrics import AccPerplex
-from internlm.train import (
- get_scheduler_hooks,
+from internlm.initialize.initialize_model import (
initialize_model_and_parallel_communicator,
- initialize_optimizer,
)
+from internlm.initialize.initialize_optimizer import initialize_optimizer
+from internlm.initialize import initialize_trainer
+from internlm.model.model_ops.losses import InternLoss
+from internlm.model.model_ops.metrics import AccPerplex
+from internlm.core.trainer import get_scheduler_hooks
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from tests.common_fixture import (
@@ -87,7 +87,7 @@ def train_check_norm_weight(args):
dataset_types=dataset_types,
)
- engine, scheduler = internlm.initialize_trainer(
+ engine, scheduler = initialize_trainer(
model=model,
optimizer=optimizer,
criterion=criterion,
diff --git a/tests/test_training/test_swap_nb_loss_and_gradnorm.py b/tests/test_training/test_swap_nb_loss_and_gradnorm.py
index 84b79d9f0..5780cf609 100644
--- a/tests/test_training/test_swap_nb_loss_and_gradnorm.py
+++ b/tests/test_training/test_swap_nb_loss_and_gradnorm.py
@@ -9,25 +9,25 @@
import torch.distributed as dist
from tqdm import tqdm
-import internlm
from internlm.accelerator import get_accelerator
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.core.context.parallel_context import Config
from internlm.core.trainer import Trainer
from internlm.data import (
build_train_loader_with_data_type,
build_valid_loader_with_data_type,
)
-from internlm.eval.evaluation import switch_evaluation_mode
-from internlm.initialize.launch import args_sanity_check
-from internlm.model.losses import InternLoss
-from internlm.model.metrics import AccPerplex, SchedulerMetricHook
-from internlm.train import (
+from internlm.eval import switch_evaluation_mode
+from internlm.initialize import initialize_launcher
+from internlm.initialize.initialize_model import (
initialize_model_and_parallel_communicator,
- initialize_optimizer,
)
+from internlm.initialize.initialize_optimizer import initialize_optimizer
+from internlm.initialize import initialize_trainer
+from internlm.model.model_ops.losses import InternLoss
+from internlm.model.model_ops.metrics import AccPerplex, SchedulerMetricHook
from internlm.utils.common import get_current_device
+from internlm.utils.config import Config
from internlm.utils.logger import get_logger
logger = get_logger(__file__)
@@ -63,14 +63,12 @@
model=dict(
checkpoint=False,
num_attention_heads=16,
- embed_split_hidden=True,
vocab_size=103168,
embed_grad_scale=1,
parallel_output=True,
hidden_size=4096,
num_layers=16,
mlp_ratio=8 / 3,
- apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
@@ -136,9 +134,7 @@ def build_environment(rank, world_size, config):
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "33333"
internlm_accelerator.empty_cache()
- # launcher="torch"
- internlm.launch_from_torch(config=config, seed=1024)
- args_sanity_check()
+ initialize_launcher(config=config, launcher="torch", distributed_port=8888, seed=1024, args_check=True, dist_backend="nccl")
def seed_all(seed, cuda_deterministic=False):
@@ -302,7 +298,7 @@ def exam_loss(args):
),
]
- engine, scheduler = internlm.initialize_trainer(
+ engine, scheduler = initialize_trainer(
model=model,
optimizer=optimizer,
criterion=criterion,
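The launcher change in this test is more than a rename: the single `initialize_launcher` call now also covers what the old code did with `internlm.launch_from_torch` plus an explicit `args_sanity_check()`. A condensed sketch of the updated `build_environment` helper, keeping only the launch-related lines; the keyword values are the ones used in the hunk above, the rank/world-size environment lines follow the pattern of the common fixtures elsewhere in this patch, and any defaults beyond that should be treated as unknown.

```python
import os

from internlm.initialize import initialize_launcher

def build_environment(rank: int, world_size: int, config) -> None:
    """Single-node test bootstrap (sketch); `config` is the test's model/parallel config dict."""
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "33333"
    # one call replaces launch_from_torch(...) followed by args_sanity_check()
    initialize_launcher(
        config=config,
        launcher="torch",
        distributed_port=8888,   # replaces the old master_port keyword
        seed=1024,
        args_check=True,
        dist_backend="nccl",
    )
```

The same keyword rename appears wherever `master_port` was previously passed, e.g. `distributed_port=12377` in the fixtures further down.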
diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py
index 5edd08f39..a14039d69 100644
--- a/tests/test_training/train_CI.py
+++ b/tests/test_training/train_CI.py
@@ -12,11 +12,13 @@
import torch
import torch.distributed as dist
+from internlm.initialize.initialize_optimizer import initialize_optimizer
+from internlm.initialize.initialize_profiler import initialize_llm_profile
+
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(script_dir, "../../"))
sys.path.append(project_root)
-import internlm # noqa: E402
from internlm.checkpoint import CheckpointManager # noqa: E402
from internlm.core.context import ParallelMode # noqa: E402
from internlm.core.context import global_context as gpc # noqa: E402
@@ -25,21 +27,21 @@
build_train_loader_with_data_type,
build_valid_loader_with_data_type,
)
-from internlm.eval.evaluation import evaluate_on_val_dls # noqa: E402
-from internlm.initialize import initialize_distributed_env # noqa: E402
-from internlm.model.losses import InternLoss # noqa: E402
-from internlm.model.metrics import AccPerplex, SchedulerMetricHook # noqa: E402
-from internlm.monitor import ( # noqa: E402
- initialize_monitor_manager,
- send_alert_message,
-)
-from internlm.monitor.monitor import monitor_manager as mm # noqa: E402
-from internlm.train import ( # noqa: E402
- initialize_llm_profile,
+from internlm.eval import evaluate_on_val_dls # noqa: E402
+from internlm.initialize import initialize_launcher # noqa: E402
+from internlm.initialize.initialize_model import ( # noqa: E402
initialize_model_and_parallel_communicator,
- initialize_optimizer,
- record_current_batch_training_metrics,
)
+from internlm.initialize import initialize_trainer # noqa: E402
+from internlm.model.model_ops.losses import InternLoss # noqa: E402
+from internlm.model.model_ops.metrics import ( # noqa: E402
+ AccPerplex,
+ SchedulerMetricHook,
+)
+from internlm.monitor import initialize_monitor_manager # noqa: E402
+from internlm.monitor import monitor_manager as mm # noqa: E402
+from internlm.monitor import send_alert_message # noqa: E402
+from internlm.core.trainer import record_current_batch_training_metrics # noqa: E402
from internlm.utils.common import ( # noqa: E402
BatchSkipper,
get_current_device,
@@ -131,7 +133,7 @@ def main(args):
current_time = objs[0]
# initialize model
- model , _ = initialize_model_and_parallel_communicator()
+ model, _ = initialize_model_and_parallel_communicator()
with open(args.config, "r") as f:
config_lines = f.readlines()
@@ -196,7 +198,7 @@ def main(args):
),
]
- engine, scheduler = internlm.initialize_trainer(
+ engine, scheduler = initialize_trainer(
model=model,
optimizer=optimizer,
criterion=criterion,
@@ -377,7 +379,7 @@ def main(args):
hostname = socket.gethostname()
# initialize distributed environment
- initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
+ initialize_launcher(config=args.config, launcher=args.launcher, distributed_port=args.port, seed=args.seed)
assert hasattr(gpc, "config") and gpc.config is not None
# initialize monitor manager context
diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py
index 6e0f40892..a1d882403 100644
--- a/tests/test_utils/common_fixture.py
+++ b/tests/test_utils/common_fixture.py
@@ -6,13 +6,13 @@
import torch
from internlm.core.context import global_context as gpc
-from internlm.core.context.parallel_context import Config
from internlm.core.naive_amp import NaiveAMPModel
-from internlm.model.builder import create_model
-from internlm.model.registry import register_model_initializer
-from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
-from internlm.train.utils import create_param_groups
+from internlm.model.model_implementations.builder import create_model
+from internlm.model.model_implementations.registry import register_model_initializer
+from internlm.solver.optimizer import HybridZeroOptimizer
+from internlm.initialize.initialize_optimizer import create_param_groups
from internlm.utils.common import SingletonMeta
+from internlm.utils.config import Config
OSS_NAME = os.environ.get("OSS_BUCKET_NAME", None)
OSS_IP = os.environ.get("OSS_IP", None)
@@ -67,14 +67,12 @@
model=dict(
checkpoint=False,
num_attention_heads=2,
- embed_split_hidden=True,
vocab_size=103168,
embed_grad_scale=1,
parallel_output=True,
hidden_size=1024,
num_layers=2,
mlp_ratio=1,
- apply_post_layer_norm=False,
dtype=torch.bfloat16,
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
@@ -154,21 +152,21 @@ def reset_singletons():
def reset_seed():
- from internlm.core.context.random import _SEED_MANAGER
+ from internlm.core.context import _SEED_MANAGER
_SEED_MANAGER.reset()
@pytest.fixture(scope="module")
def init_dist_and_model(rank=0, world_size=1):
- from internlm.initialize import initialize_distributed_env
+ from internlm.initialize import initialize_launcher
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12377"
- initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False)
+ initialize_launcher(config=init_config, launcher="torch", distributed_port=12377, args_check=False)
# setup
print("set up", flush=True)
diff --git a/tests/test_utils/test_model_checkpoint.py b/tests/test_utils/test_model_checkpoint.py
index 5fe8b3c49..83b645dbe 100644
--- a/tests/test_utils/test_model_checkpoint.py
+++ b/tests/test_utils/test_model_checkpoint.py
@@ -10,10 +10,10 @@
import torch.distributed as dist
from internlm.checkpoint import CheckpointManager
-from internlm.core.context.parallel_context import Config
from internlm.core.trainer import TrainState
-from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
+from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.utils.common import SingletonMeta
+from internlm.utils.config import Config
from internlm.utils.storage_manager import wait_async_upload_finish
from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import
ASYNC_TMP_FOLDER,
@@ -28,38 +28,6 @@
# (TOTAL_STEP, CKPT_EVERY, SNPASHOT_EVERY)
step_info_list = [(8, 4, 2), (3, 4, 2), (1, 6, 3)]
ckpt_config_list = [
- # Old interface format
- dict(
- enable_save_ckpt=True,
- save_ckpt_folder=BOTO_SAVE_PATH,
- load_optimizer=True,
- checkpoint_every=0,
- async_upload=True,
- async_upload_tmp_folder=ASYNC_TMP_FOLDER,
- snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]) if BOTO_SAVE_PATH is not None else None,
- oss_snapshot_freq=0,
- stop_file_path=None,
- load_model_only_folder=None,
- load_given_ckpt=False,
- load_ckpt_folder=None,
- is_old_api=True,
- ),
- # Old interface format
- dict(
- enable_save_ckpt=True,
- save_ckpt_folder=LOCAL_SAVE_PATH,
- load_optimizer=True,
- checkpoint_every=0,
- async_upload=False,
- async_upload_tmp_folder=ASYNC_TMP_FOLDER,
- snapshot_ckpt_folder="/".join([LOCAL_SAVE_PATH, "snapshot"]),
- oss_snapshot_freq=0,
- stop_file_path=None,
- load_model_only_folder=None,
- load_given_ckpt=False,
- load_ckpt_folder=None,
- is_old_api=True,
- ),
# New interface format
dict(
enable_save_ckpt=True,
@@ -201,8 +169,8 @@ def return_latest_save_path(save_ckpt_folder, total_step, snapshot_freq, ckpt_fr
@pytest.mark.parametrize("step_info", step_info_list)
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
def test_ckpt_mm(step_info, ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-import
- from internlm.core.context import global_context as gpc
from internlm.checkpoint.checkpoint_manager import CheckpointLoadMask
+ from internlm.core.context import global_context as gpc
ckpt_config = Config(ckpt_config)
total_step, checkpoint_every, oss_snapshot_freq = step_info
@@ -297,9 +265,9 @@ def test_ckpt_mm(step_info, ckpt_config, init_dist_and_model): # noqa # pylint:
def query_quit_file(rank, world_size=2):
- from internlm.core.context import global_context as gpc
- from internlm.initialize import initialize_distributed_env
from internlm.checkpoint.checkpoint_manager import CheckpointSaveType
+ from internlm.core.context import global_context as gpc
+ from internlm.initialize import initialize_launcher
ckpt_config = Config(
dict(
@@ -325,7 +293,7 @@ def query_quit_file(rank, world_size=2):
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12376"
- initialize_distributed_env(config=init_config, launcher="torch", master_port=12376, args_check=False)
+ initialize_launcher(config=init_config, launcher="torch", distributed_port=12376, args_check=False)
train_state = TrainState(init_config, None)
ckpt_mm = CheckpointManager(ckpt_config, model=None, optimizer=None)
if rank == 0:
diff --git a/tests/test_utils/test_storage_manager.py b/tests/test_utils/test_storage_manager.py
index 9454a8369..57021e5ce 100644
--- a/tests/test_utils/test_storage_manager.py
+++ b/tests/test_utils/test_storage_manager.py
@@ -3,8 +3,7 @@
import pytest
import torch
-from internlm.core.context.parallel_context import Config
-from internlm.initialize.launch import get_config_value
+from internlm.utils.config import Config, get_config_value
from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import
ALI_SAVE_PATH,
BOTO_SAVE_PATH,
diff --git a/tests/test_utils/test_timeout.py b/tests/test_utils/test_timeout.py
index 49a49d27e..4f9cd47a8 100644
--- a/tests/test_utils/test_timeout.py
+++ b/tests/test_utils/test_timeout.py
@@ -65,14 +65,14 @@ def local_timeout(rank, _):
def gpc_timeout(rank, world_size):
- from internlm.initialize import initialize_distributed_env
+ from internlm.initialize import initialize_launcher
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12377"
- initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False)
+ initialize_launcher(config=init_config, launcher="torch", distributed_port=12377, args_check=False)
try:
nccl_timeout_func(rank)
diff --git a/tools/README.md b/tools/README.md
index a24040cae..9861ce39f 100644
--- a/tools/README.md
+++ b/tools/README.md
@@ -160,7 +160,6 @@ LLaMA 7B推理的例子:
num_chunks=1,
checkpoint=0.2,
dtype="torch.bfloat16",
- embed_split_hidden=True,
num_layers=32,
hidden_size=4096,
vocab_size=32000,
@@ -171,7 +170,6 @@ LLaMA 7B推理的例子:
mlp_ratio=2.675,
use_flash_attn=True,
norm_type="rmsnorm",
- apply_post_layer_norm=False,
no_bias=True,
layer_norm_epsilon=1e-5,
),
diff --git a/tools/load_internlm2_model.py b/tools/load_internlm2_model.py
index 4b639003e..121c0842e 100644
--- a/tools/load_internlm2_model.py
+++ b/tools/load_internlm2_model.py
@@ -10,8 +10,10 @@
from internlm.apis.inference import SequenceGenerator
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
-from internlm.initialize.launch import initialize_distributed_env
-from internlm.train import initialize_model_and_parallel_communicator
+from internlm.initialize import initialize_launcher
+from internlm.initialize.initialize_model import (
+ initialize_model_and_parallel_communicator,
+)
from internlm.utils.storage_manager import get_fns, init_storage_manager, llm_load
from tools.interface import GenerationConfig
@@ -180,7 +182,7 @@ def initialize_internlm_model(
if gpc.is_rank_for_log():
logger.info(f"model_config: {model_config}.")
- initialize_distributed_env(
+ initialize_launcher(
config=dict(
model_type=model_type,
model=model_config,
@@ -193,7 +195,7 @@ def initialize_internlm_model(
),
launcher="torch" if use_torchrun_starter() else "slurm",
seed=seed,
- master_port=23574,
+ distributed_port=23574,
args_check=False,
)
# Directly get the origin model without NativeAMP wrapper.
@@ -284,7 +286,6 @@ def get_default_parser():
num_chunks=1,
checkpoint=0.2,
dtype="torch.bfloat16",
- embed_split_hidden=True,
num_layers=32,
hidden_size=4096,
vocab_size=92544,
@@ -296,7 +297,6 @@ def get_default_parser():
use_flash_attn=True,
norm_type="rmsnorm",
qk_interleaved=True,
- apply_post_layer_norm=False,
no_bias=True,
layer_norm_epsilon=1e-5,
rope_base=1000000,
diff --git a/tools/moe_group_ckpt_converter.py b/tools/moe_group_ckpt_converter.py
index d3fefb7c7..e07d6a273 100644
--- a/tools/moe_group_ckpt_converter.py
+++ b/tools/moe_group_ckpt_converter.py
@@ -8,7 +8,6 @@
from tqdm import tqdm
sys.path.append(".")
-import internlm # noqa: E402,F401 # pylint: disable=W0611,C0413
moe_str_prefix = None
weight_key_suffix = ".weight"
diff --git a/tools/tokenizer.py b/tools/tokenizer.py
index e67874f9d..8eba0c64a 100644
--- a/tools/tokenizer.py
+++ b/tools/tokenizer.py
@@ -7,7 +7,7 @@
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, "tokenizer_internlm.model")
-sys.path.append(os.path.join(current_dir, "../transformers"))
+sys.path.append(os.path.join(current_dir, "../huggingface_models"))
from internlm_model import InternLMTokenizer # noqa: E402 # pylint: disable=C0413
tokenizer = InternLMTokenizer(vocab_file=model_path, add_bos_token=True, add_eos_token=True)
diff --git a/train.py b/train.py
deleted file mode 100755
index 6e5e1399f..000000000
--- a/train.py
+++ /dev/null
@@ -1,43 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from internlm.core.context import global_context as gpc
-from internlm.core.trainer_builder import TrainerBuilder
-from internlm.data import (
- build_train_loader_with_data_type,
- build_valid_loader_with_data_type,
-)
-from internlm.initialize import initialize_distributed_env
-from internlm.model.builder import create_model
-from internlm.monitor import internevo_monitor
-from internlm.utils.common import parse_args
-
-
-@internevo_monitor(feishu_alert=True, clean_run=True)
-def main(args):
- # initialize model
- model = create_model()
-
- # initialize train dataloader
- train_dl, dataset_types = build_train_loader_with_data_type()
-
- # initialize validation dataloader
- val_dls = build_valid_loader_with_data_type()
-
- # build trainer
- merged_args = {**vars(args), "dataset_types": dataset_types}
- trainer = TrainerBuilder(model, train_dl, val_dls, **merged_args)
-
- # training
- trainer.fit()
-
-
-if __name__ == "__main__":
- args = parse_args()
-
- # Initialize distributed environment
- initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
- assert hasattr(gpc, "config") and gpc.config is not None
-
- # Run the main function with parsed arguments
- main(args)
diff --git a/train.py b/train.py
new file mode 120000
index 000000000..744178299
--- /dev/null
+++ b/train.py
@@ -0,0 +1 @@
+internlm/launcher/launch.py
\ No newline at end of file
diff --git a/version.txt b/version.txt
index be14282b7..c52db9804 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.5.3
+0.5.3
\ No newline at end of file
diff --git a/web_demo_internlm.py b/web_demo_internlm.py
index abe0568e7..b89e2ae24 100644
--- a/web_demo_internlm.py
+++ b/web_demo_internlm.py
@@ -21,14 +21,12 @@
"internlm-chat-7b": dict(
checkpoint=False,
num_attention_heads=32,
- embed_split_hidden=True,
vocab_size=103168,
embed_grad_scale=1,
parallel_output=False,
hidden_size=4096,
num_layers=32,
mlp_ratio=8 / 3,
- apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
@@ -39,14 +37,12 @@
"internlm-chat-7b-v1.1": dict(
checkpoint=False,
num_attention_heads=32,
- embed_split_hidden=True,
vocab_size=103168,
embed_grad_scale=1,
parallel_output=False,
hidden_size=4096,
num_layers=32,
mlp_ratio=8 / 3,
- apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
|