
Commit d3eb82e

fix and update readme
1 parent: b536cd3

11 files changed: +125, -289 lines

gen_profiler_data.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -1,5 +1,4 @@
-
 from internlm.simulator.profiler.perf_comm import gen_perf

 if __name__ == "__main__":
-    gen_perf()
+    gen_perf()
```

internlm/core/context/parallel_context.py

Lines changed: 7 additions & 81 deletions
```diff
@@ -18,13 +18,12 @@
 import torch.distributed as dist

 from internlm.accelerator import get_accelerator
-from internlm.core.context.process_group_initializer_simplified import Initializer, ParallelMeta
 from internlm.utils.common import SingletonMeta
 from internlm.utils.logger import get_logger
 from internlm.utils.timeout import LLM_NCCL_TIMEOUT

 from . import process_group_initializer as pgroup_initializer
-from .process_group_initializer_simplified import ParallelMode
+from .process_group_initializer import ParallelMode
 from .random import add_seed, get_seeds, set_mode
 from internlm.utils.common import get_args

@@ -422,20 +421,6 @@ def init_global_dist(
         use_cpu (bool): whether to set up cpu process group.
         """

-        # find cluster info
-        if "clusters" not in self.config:
-            nv_info = {
-                "rank_range": [0, 8],
-                "peak_tflops": 320,
-                "capacity": 80 * 1024**3,
-                "intra_bw": 150,
-                "inter_bw": 100,
-            }
-            self.set_cluster_info("nv_cluster", nv_info)
-        else:
-            for cluster in self.config.clusters:
-                self.clusters.append(ClusterInfo(**cluster))
-
         # initialize the default process group
         if not fake_mode:
             init_method = f"tcp://[{host}]:{port}"
@@ -667,7 +652,7 @@ def _init_pg(self, rank, world_size, parallel_config):
             initializers.append(pgroup_initializer.Initializer_ISP_Data(*initializer_args))
         if (
             isinstance(parallel_config["tensor"], dict)
-            and parallel_config["tensor"]["mode"] == TensorParallelMode.isp.name
+            and parallel_config["tensor"]["mode"] == "isp"
         ):
             initializers.append(pgroup_initializer.Initializer_Zero1_ISP(*initializer_args))
         else:
@@ -688,6 +673,8 @@ def _init_pg(self, rank, world_size, parallel_config):
             self._register_dist(*parallel_setting)

     def _init_use_simplified_pg(self, rank, world_size, parallel_config):
+        from internlm.core.context.process_group_initializer_simplified import InitializerParallelMeta
+
         try:
             self.tensor_mode = parallel_config["tensor"]["mode"]
         except AttributeError:
@@ -861,14 +848,14 @@ def check_pg_is_intra(self, parallel_mode: ParallelMode):
         return (max_rank - min_rank) <= 7

     def same_group_in_one_node(self, parallel_mode: ParallelMode):
-        """获得一个节点内有多少个相同类型的PG, 在跨节点通信时会存在带宽竞争
-        这里返回的相同PG的数量会乘上每个rank的通信数据量大小
+        """Get the number of the same type of PG within a node. There will be bandwidth competition during cross-node communication.
+        The number of the same PG returned here will be multiplied by the communication data size of each rank.

         Args:
             parallel_mode (ParallelMode):

         Returns:
-            int: 一个节点内相同类型的PG的数量
+            int: The number of the same type of PG within a node.
         """
         pg_group_ranks = self.get_ranks_in_group(parallel_mode)
         pg_group_ranks = sorted(pg_group_ranks)
@@ -881,68 +868,7 @@ def same_group_in_one_node(self, parallel_mode: ParallelMode):
         else:
             return stride

-    # def set_cluster_info(self, name: str, info: dict):
-    #     self.clusters[name] = ClusterInfo(**info)
-
-    def get_cluster_info(self, name: str):
-        return self.clusters[name]
-
-    def get_cluster_name_from_ip(self):
-        """
-        node_ip_list = [
-            'metax-c500-1',
-            'metax-c500-2',
-            'nvidia-node-1',
-            'nvidia-node-2',
-        ]
-        """
-        hostname = socket.gethostname()
-        cluster_name = hostname.split("-")[0]
-        return cluster_name
-
-    def sort_rank_based_on_ip_and_capacity(self):
-        Capacity = []
-
-        def sort_rank(x, y):
-            x_name = self.get_cluster_name_from_ip(x)
-            y_name = self.get_cluster_name_from_ip(y)
-            if x_name == y_name:
-                return x_name > y_name
-            else:
-                x_c = self.clusters[x_name]["capacity"]
-                y_c = self.clusters[y_name]["capacity"]
-                return x_c > y_c

-        for cluster_name, cluster_info in self.clusters.items():
-            peak_tflops.append(cluster_info["peak_tflops"])
-            # Alpha.append(cluster_info.rank_range[-1] - cluster_info.rank_range[-1] + 1)
-            Capacity.append(cluster_info["capacity"])
-
-    def switch_topology_aware_rank_scheduling():
-        """
-        Switch topology-aware rank scheduling can optimize the performance of small-scale
-        collective communications. Currently only supported in Alibaba Cloud.
-        """
-
-        local_rank = int(os.environ["LOCAL_RANK"])
-        cluster_name = get_cluster_name_from_ip()
-
-        try:
-            if cluster_name == "Ali":
-                pass
-            else:
-                rank = int(os.environ["MLP_WORKER_RACK_RANK_INDEX"]) * 8 + local_rank
-        except Exception as e:
-            logger.error(
-                f"The switch topology awareness error is reported, the reason is: {e}",
-                "but don’t worry, this error will not affect normal training.",
-                "If you train on Alibaba or Volcano Cloud, please contact wangguoteng or lijiaxing",
-            )
-        else:
-            # If there is no any error, hack torch rank.
-            os.environ["RANK"] = str(rank)
-            if local_rank == 0:
-                logger.info("Successfully bound node switch affinity!")


 global_context = ParallelContext()
```
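The translated docstring explains that `same_group_in_one_node` counts how many process groups of the same type share one node, since such groups compete for bandwidth during cross-node communication. The hunk elides the method body between the sorted ranks and `return stride`, so the following is only a hedged sketch of what that stride logic plausibly computes, assuming 8 GPUs per node as in `check_pg_is_intra`; the helper name is illustrative, not InternLM's API.

```python
# Illustrative sketch only: count same-type process groups per node from the
# rank stride, assuming 8 GPUs per node (check_pg_is_intra tests `<= 7`).
def same_type_pgs_per_node(pg_group_ranks, gpus_per_node=8):
    ranks = sorted(pg_group_ranks)
    if len(ranks) <= 1:
        return 1
    stride = ranks[1] - ranks[0]  # gap between neighboring members of one group
    if stride >= gpus_per_node:
        return 1  # members sit on different nodes; one such group per node
    # Groups of this type interleave with step `stride`, so `stride` of them
    # share a node and split its egress bandwidth.
    return stride
```

With ranks `[0, 2, 4, 6]`, for example, the stride is 2 and two interleaved groups of that type occupy the node.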

internlm/core/context/process_group_initializer.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -62,6 +62,10 @@ class ParallelMode(Enum):

     # grouped query attention
     GQA = "gqa"
+
+    INTRA_DP_SZIE = "intra_dp"
+
+    INTER_DP_SZIE = "inter_dp"


 class ProcessGroupInitializer(ABC):
```
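The two added members presumably name the intra-node and inter-node data-parallel splits (the `SZIE` spelling is reproduced as committed). A minimal standalone sketch of how they behave as plain `Enum` members:

```python
from enum import Enum

class ParallelMode(Enum):
    GQA = "gqa"                 # existing member: grouped query attention
    INTRA_DP_SZIE = "intra_dp"  # added: intra-node data parallel (assumed meaning)
    INTER_DP_SZIE = "inter_dp"  # added: inter-node data parallel (assumed meaning)

# Value lookup works as for any Enum, e.g. when mapping config strings to modes.
assert ParallelMode("intra_dp") is ParallelMode.INTRA_DP_SZIE
```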

internlm/core/context/process_group_initializer_simplified.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -2,13 +2,11 @@
 # -*- encoding: utf-8 -*-

 from copy import deepcopy
-from enum import Enum

 import torch
 import torch.distributed as dist

 from internlm.utils.timeout import LLM_NCCL_TIMEOUT
-from internlm.core.context.process_group_initializer import ParallelMode

 class ParallelMeta:
     def __init__(self, parallel_size, mode) -> None:
```

internlm/initialize/launch.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -12,7 +12,7 @@
 from internlm.accelerator import AcceleratorType, get_accelerator
 from internlm.core.context import Config
 from internlm.core.context import global_context as gpc
-from internlm.core.context.process_group_initializer_simplified import ParallelMode
+from internlm.core.context.process_group_initializer import ParallelMode
 from internlm.utils.common import get_master_node
 from internlm.utils.gputest import warmup_process_group
 from internlm.utils.logger import get_logger
@@ -86,7 +86,9 @@ def add_simulator_arguments(parser):
     group.add_argument(
         "--pre_profiling_data_path", type=str, help="The path to pre-profiled performance data on the target cluster."
     )
+    group.add_argument("--use_simplified_gp_init", action="store_true", default=True)
     group.add_argument("--use_simplified_gp_init", action="store_true", default=False)
+
     return parser
```

internlm/model/ops/linear.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -14,10 +14,6 @@

 from internlm.accelerator import AcceleratorType, get_accelerator
 from internlm.core.context import global_context as gpc
-from internlm.simulator.ops.linear import (
-    _fake_linear_bwdward_op,
-    _fake_linear_forward_op,
-)

 try:
     from fused_dense_lib import linear_bias_wgrad as _flash_linear_backward_op
```

internlm/simulator/README.md

Lines changed: 54 additions & 0 deletions
# InternLM Simulator

## 1. Introduction

The simulator mainly consists of two components:
1. `profiling`: collects the time consumed by each stage of the model training process in advance and saves it as data files and image files.
2. `simulation`: simulates the model training process based on the collected data files and outputs the time consumed by each stage.

## 2. Usage

### 2.1 Generate profiling data

There are two types of profiling data:
1. `linear` profiling data, which includes: [`LINEAR`]
2. `communication` profiling data, which includes: [`ALL2ALL`, `ALLREDUCE`, `REDUCESCATTER`, `ALLGATHER`, `BROADCAST`]

Note:
1. It is recommended to use more than 64 GPUs for data collection to obtain more accurate communication data.
2. `Flash Attention` timings are not collected in advance; they are collected on the fly during the simulation and stored in a cache, because too many variables affect flash attention performance for ahead-of-time collection to cover them all (see the sketch after this note).
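A minimal sketch of the caching idea in note 2, assuming timings are keyed by the shape variables that dominate flash attention performance; the names and signature here are illustrative, not the simulator's actual API:

```python
# Illustrative only: profile flash attention lazily on first use and
# memoize the result, instead of collecting timings ahead of time.
_fa_time_cache = {}

def flash_attn_time(batch, num_heads, seqlen, head_dim, measure_fn):
    key = (batch, num_heads, seqlen, head_dim)
    if key not in _fa_time_cache:
        _fa_time_cache[key] = measure_fn(batch, num_heads, seqlen, head_dim)
    return _fa_time_cache[key]
```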
```bash
# generate profiling data
torchrun --nproc-per-node=8 gen_profiler_data.py

# the profiling data will be saved under the following path
./prof_data
├── data.pt
└── pics
    ├── cal
    │   └── linear.jpg
    └── comm
        ├── all2all_intra_2_inter_1.jpg
        ├── all2all_intra_4_inter_1.jpg
        ├── all_gather_intra_2_inter_1.jpg
        ├── all_gather_intra_4_inter_1.jpg
        ├── all_reduce_intra_2_inter_1.jpg
        ├── all_reduce_intra_4_inter_1.jpg
        ├── broadcast_intra_2_inter_1.jpg
        ├── broadcast_intra_4_inter_1.jpg
        ├── reduce_scatter_intra_2_inter_1.jpg
        └── reduce_scatter_intra_4_inter_1.jpg
```
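For reference, the `gen_profiler_data.py` entry point launched above is the thin wrapper over `gen_perf` shown in the first diff of this commit:

```python
from internlm.simulator.profiler.perf_comm import gen_perf

if __name__ == "__main__":
    gen_perf()  # collects the linear and communication profiling data
```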
### 2.2 Run simulation

```bash
python simulation_train_formulaic.py --pre_profiling_data_path ./data/profiling_data.json --config configs/exp_simluator.py
```
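The `add_simulator_arguments` hunk in `internlm/initialize/launch.py` above also registers a `--use_simplified_gp_init` store-true flag; assuming `simulation_train_formulaic.py` consumes that argument parser, appending the flag to the command above would opt in to the simplified process-group initialization path (`_init_use_simplified_pg`).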
## 3. Contribution
