Commit bf8b40e

fix and update readme
1 parent b536cd3 commit bf8b40e

11 files changed: +140 -306 lines

gen_profiler_data.py

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,4 @@
-
 from internlm.simulator.profiler.perf_comm import gen_perf
 
 if __name__ == "__main__":
-    gen_perf()
+    gen_perf()

internlm/core/context/parallel_context.py

Lines changed: 15 additions & 97 deletions
@@ -18,15 +18,13 @@
 import torch.distributed as dist
 
 from internlm.accelerator import get_accelerator
-from internlm.core.context.process_group_initializer_simplified import Initializer, ParallelMeta
-from internlm.utils.common import SingletonMeta
+from internlm.utils.common import SingletonMeta, get_args
 from internlm.utils.logger import get_logger
 from internlm.utils.timeout import LLM_NCCL_TIMEOUT
 
 from . import process_group_initializer as pgroup_initializer
-from .process_group_initializer_simplified import ParallelMode
+from .process_group_initializer import ParallelMode
 from .random import add_seed, get_seeds, set_mode
-from internlm.utils.common import get_args
 
 IS_REPLICA_ZERO_PARALLEL = "is_replica_zero_parallel"
 # for isp, with optimizer split in dp group
@@ -422,20 +420,6 @@ def init_global_dist(
             use_cpu (bool): whether to set up cpu process group.
         """
 
-        # find cluster info
-        if "clusters" not in self.config:
-            nv_info = {
-                "rank_range": [0, 8],
-                "peak_tflops": 320,
-                "capacity": 80 * 1024**3,
-                "intra_bw": 150,
-                "inter_bw": 100,
-            }
-            self.set_cluster_info("nv_cluster", nv_info)
-        else:
-            for cluster in self.config.clusters:
-                self.clusters.append(ClusterInfo(**cluster))
-
         # initialize the default process group
         if not fake_mode:
             init_method = f"tcp://[{host}]:{port}"
@@ -576,8 +560,7 @@ def init_parallel_groups(self, fake_mode: bool = False):
         self._set_parallel_size_from_config(parallel_config, "tensor", "tensor_parallel_size")
         self._set_parallel_size_from_config(parallel_config, "pipeline", "pipeline_parallel_size")
         self._set_parallel_size_from_config(parallel_config, "zero1", "zero1_parallel_size")
-
-
+
         if get_args().use_simplified_gp_init:
             self._init_use_simplified_pg(rank, world_size, parallel_config)
         else:
@@ -592,10 +575,7 @@ def _init_pg(self, rank, world_size, parallel_config):
             1, self.world_size // self.pipeline_parallel_size // self.weight_parallel_size
         )
 
-        if (
-            isinstance(parallel_config["tensor"], dict)
-            and parallel_config["tensor"]["mode"] == "isp"
-        ):
+        if isinstance(parallel_config["tensor"], dict) and parallel_config["tensor"]["mode"] == "isp":
             if self.zero1_parallel_size == -1:
                 self.zero1_parallel_size = self.weight_data_parallel_size
             self.zero1_parallel_size = max(1, self.zero1_parallel_size)
@@ -622,8 +602,7 @@ def _init_pg(self, rank, world_size, parallel_config):
         if "sequence_parallel" not in parallel_config:
             parallel_config._add_item("sequence_parallel", True)
         if isinstance(parallel_config["tensor"], int) or (
-            isinstance(parallel_config["tensor"], dict)
-            and parallel_config["tensor"]["mode"] == "mtp"
+            isinstance(parallel_config["tensor"], dict) and parallel_config["tensor"]["mode"] == "mtp"
         ):
             parallel_config["sequence_parallel"] = False
 
@@ -665,10 +644,7 @@ def _init_pg(self, rank, world_size, parallel_config):
         initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_Data(*initializer_args))
         initializers.append(pgroup_initializer.Initializer_ISP_Data(*initializer_args))
-        if (
-            isinstance(parallel_config["tensor"], dict)
-            and parallel_config["tensor"]["mode"] == TensorParallelMode.isp.name
-        ):
+        if isinstance(parallel_config["tensor"], dict) and parallel_config["tensor"]["mode"] == "isp":
             initializers.append(pgroup_initializer.Initializer_Zero1_ISP(*initializer_args))
         else:
             initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
@@ -686,7 +662,7 @@ def _init_pg(self, rank, world_size, parallel_config):
                 self._register_dist(*args)
             else:
                 self._register_dist(*parallel_setting)
-
+
     def _init_use_simplified_pg(self, rank, world_size, parallel_config):
         try:
             self.tensor_mode = parallel_config["tensor"]["mode"]
@@ -723,6 +699,11 @@ def _init_use_simplified_pg(self, rank, world_size, parallel_config):
 
         self.check_sanity()
 
+        from internlm.core.context.process_group_initializer_simplified import (
+            Initializer,
+            ParallelMeta,
+        )
+
         parallel_info = {
             "tp": ParallelMeta(self.tensor_parallel_size, ParallelMode.TENSOR),
             "wp": ParallelMeta(self.weight_parallel_size, ParallelMode.WEIGHT),
@@ -861,14 +842,14 @@ def check_pg_is_intra(self, parallel_mode: ParallelMode):
         return (max_rank - min_rank) <= 7
 
     def same_group_in_one_node(self, parallel_mode: ParallelMode):
-        """获得一个节点内有多少个相同类型的PG, 在跨节点通信时会存在带宽竞争
-        这里返回的相同PG的数量会乘上每个rank的通信数据量大小
+        """Get the number of the same type of PG within a node. There will be bandwidth competition during cross-node communication.
+        The number of the same PG returned here will be multiplied by the communication data size of each rank.
 
         Args:
            parallel_mode (ParallelMode):
 
         Returns:
-            int: 一个节点内相同类型的PG的数量
+            int: The number of the same type of PG within a node.
         """
         pg_group_ranks = self.get_ranks_in_group(parallel_mode)
         pg_group_ranks = sorted(pg_group_ranks)
@@ -881,68 +862,5 @@ def same_group_in_one_node(self, parallel_mode: ParallelMode):
         else:
             return stride
 
-    # def set_cluster_info(self, name: str, info: dict):
-    #     self.clusters[name] = ClusterInfo(**info)
-
-    def get_cluster_info(self, name: str):
-        return self.clusters[name]
-
-    def get_cluster_name_from_ip(self):
-        """
-        node_ip_list = [
-            'metax-c500-1',
-            'metax-c500-2',
-            'nvidia-node-1',
-            'nvidia-node-2',
-        ]
-        """
-        hostname = socket.gethostname()
-        cluster_name = hostname.split("-")[0]
-        return cluster_name
-
-    def sort_rank_based_on_ip_and_capacity(self):
-        Capacity = []
-
-        def sort_rank(x, y):
-            x_name = self.get_cluster_name_from_ip(x)
-            y_name = self.get_cluster_name_from_ip(y)
-            if x_name == y_name:
-                return x_name > y_name
-            else:
-                x_c = self.clusters[x_name]["capacity"]
-                y_c = self.clusters[y_name]["capacity"]
-                return x_c > y_c
-
-        for cluster_name, cluster_info in self.clusters.items():
-            peak_tflops.append(cluster_info["peak_tflops"])
-            # Alpha.append(cluster_info.rank_range[-1] - cluster_info.rank_range[-1] + 1)
-            Capacity.append(cluster_info["capacity"])
-
-    def switch_topology_aware_rank_scheduling():
-        """
-        Switch topology-aware rank scheduling can optimize the performance of small-scale
-        collective communications. Currently only supported in Alibaba Cloud.
-        """
-
-        local_rank = int(os.environ["LOCAL_RANK"])
-        cluster_name = get_cluster_name_from_ip()
-
-        try:
-            if cluster_name == "Ali":
-                pass
-            else:
-                rank = int(os.environ["MLP_WORKER_RACK_RANK_INDEX"]) * 8 + local_rank
-        except Exception as e:
-            logger.error(
-                f"The switch topology awareness error is reported, the reason is: {e}",
-                "but don't worry, this error will not affect normal training.",
-                "If you train on Alibaba or Volcano Cloud, please contact wangguoteng or lijiaxing",
-            )
-        else:
-            # If there is no any error, hack torch rank.
-            os.environ["RANK"] = str(rank)
-            if local_rank == 0:
-                logger.info("Successfully bound node switch affinity!")
-
 
 global_context = ParallelContext()

internlm/core/context/process_group_initializer.py

Lines changed: 4 additions & 0 deletions
@@ -62,6 +62,10 @@ class ParallelMode(Enum):
 
     # grouped query attention
     GQA = "gqa"
+
+    INTRA_DP_SZIE = "intra_dp"
+
+    INTER_DP_SZIE = "inter_dp"
 
 
 class ProcessGroupInitializer(ABC):

internlm/core/context/process_group_initializer_simplified.py

Lines changed: 0 additions & 2 deletions
@@ -2,13 +2,11 @@
 # -*- encoding: utf-8 -*-
 
 from copy import deepcopy
-from enum import Enum
 
 import torch
 import torch.distributed as dist
 
 from internlm.utils.timeout import LLM_NCCL_TIMEOUT
-from internlm.core.context.process_group_initializer import ParallelMode
 
 class ParallelMeta:
     def __init__(self, parallel_size, mode) -> None:

internlm/initialize/launch.py

Lines changed: 3 additions & 2 deletions
@@ -12,7 +12,7 @@
 from internlm.accelerator import AcceleratorType, get_accelerator
 from internlm.core.context import Config
 from internlm.core.context import global_context as gpc
-from internlm.core.context.process_group_initializer_simplified import ParallelMode
+from internlm.core.context.process_group_initializer import ParallelMode
 from internlm.utils.common import get_master_node
 from internlm.utils.gputest import warmup_process_group
 from internlm.utils.logger import get_logger
@@ -86,7 +86,8 @@ def add_simulator_arguments(parser):
     group.add_argument(
         "--pre_profiling_data_path", type=str, help="The path to pre-profiled performance data on the target cluster."
     )
-    group.add_argument("--use_simplified_gp_init", action="store_true", default=False)
+    group.add_argument("--use_simplified_gp_init", action="store_true", default=True)
+
     return parser
 
 
internlm/model/ops/linear.py

Lines changed: 0 additions & 4 deletions
@@ -14,10 +14,6 @@
 
 from internlm.accelerator import AcceleratorType, get_accelerator
 from internlm.core.context import global_context as gpc
-from internlm.simulator.ops.linear import (
-    _fake_linear_bwdward_op,
-    _fake_linear_forward_op,
-)
 
 try:
     from fused_dense_lib import linear_bias_wgrad as _flash_linear_backward_op

internlm/simulator/README.md

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
# InternLM Simulator

## 1. Introduction
The solver mainly consists of two components:
1. `profiling`: Collects the time consumption of each stage of the model training process in advance and saves it as data files and image files.
2. `simulation`: Simulates the model training process based on the collected data files and outputs the time consumption of each stage of the training process.

## 2. Usage

### 2.1 Generate profiling data

There are two types of profiling data:
1. `linear` profiling data, which includes: [`LINEAR`]
2. `Communication` profiling data, which includes: [`ALL2ALL`, `ALLREDUCE`, `REDUCESCATTER`, `ALLGATHER`, `BROADCAST`]

Note:
1. It is recommended to use more than 64 GPUs for data collection to ensure more accurate communication data.
2. `Flash Attention` information is not collected in advance; it is collected on the fly during the simulation and stored in the cache, because many variables affect flash attention performance and collecting in advance cannot cover all of them.

```python
# generate profiling data
torchrun --nproc-per-node=8 gen_profiler_data.py

# the profiling data will be saved in the following path
./prof_data
├── data.pt
└── pics
    ├── cal
    │   └── linear.jpg
    └── comm
        ├── all2all_intra_2_inter_1.jpg
        ├── all2all_intra_4_inter_1.jpg
        ├── all_gather_intra_2_inter_1.jpg
        ├── all_gather_intra_4_inter_1.jpg
        ├── all_reduce_intra_2_inter_1.jpg
        ├── all_reduce_intra_4_inter_1.jpg
        ├── broadcast_intra_2_inter_1.jpg
        ├── broadcast_intra_4_inter_1.jpg
        ├── reduce_scatter_intra_2_inter_1.jpg
        └── reduce_scatter_intra_4_inter_1.jpg
```
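
If you want to sanity-check the collected data before running the solver, `data.pt` can be opened with `torch.load`. The sketch below is only an illustration: it assumes nothing about the layout of the profiling object beyond it having been written with `torch.save`, and the printed keys will depend on the profiler version.

```python
# Minimal sketch: peek at the profiling data produced by gen_profiler_data.py.
# Assumption: ./prof_data/data.pt was written with torch.save; its exact
# structure is version-dependent, so only the top level is inspected here.
import torch

data = torch.load("./prof_data/data.pt", map_location="cpu")

print(type(data))
if isinstance(data, dict):
    # Print the top-level keys (e.g. per-operation or per-group entries)
    # without assuming anything about the nested values.
    for key in data:
        print(key)
```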

### 2.2 Run simulation
Running the solver does not require a GPU (although some packages may require a GPU environment; if you encounter any issues, please raise an issue). Currently, the solver only supports the formulaic solving method in `simulation_train_formulaic.py`, which requires a config file and a profiling data file, as follows:

```bash
python simulation_train_formulaic.py --pre_profiling_data_path ./prof_data/data.pt --config configs/7B_internlm2.py --run_all_solu --model_size 7 --world_size 128 --global_batch_size 4194304

# explanation:
python simulation_train_formulaic.py
    --pre_profiling_data_path ./prof_data/data.pt  # profiling data file
    --config configs/7B_internlm2.py               # model configuration file
    --run_all_solu                                 # whether to iterate over and solve all possible solutions
    --model_size 7                                 # a 7B model; to run a 70B model, set model_size to 70
    --world_size 128                               # the solving range is 128 GPUs
    --global_batch_size 4194304                    # global batch size, 4M
```
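
For scale, `--global_batch_size 4194304` is 4 × 1024 × 1024 = 4,194,304 (the "4M" noted above); spread evenly over a 128-GPU solving range, that averages 32,768 per GPU per step.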
