From fa23d1e9f13d8e8d167f6a53c26626f36a5cf9fa Mon Sep 17 00:00:00 2001 From: wangguoteng <877825076@qq.com> Date: Wed, 5 Jun 2024 10:24:23 +0800 Subject: [PATCH 1/3] feat(simulator): support parallel cost simulator for internevo --- internlm/simulator/common.py | 249 +++++++++++++ internlm/simulator/predict_cost_model.py | 325 +++++++++++++++++ internlm/simulator/profiler/__init__.py | 0 .../simulator/profiler/benchmark/__init__.py | 7 + .../simulator/profiler/benchmark/all2all.py | 137 ++++++++ .../profiler/benchmark/all_gather.py | 48 +++ .../profiler/benchmark/all_reduce.py | 46 +++ .../profiler/benchmark/base_benchmark.py | 6 + .../simulator/profiler/benchmark/broadcast.py | 46 +++ .../simulator/profiler/benchmark/linear.py | 56 +++ .../profiler/benchmark/multi_head_attn.py | 134 +++++++ .../profiler/benchmark/reduce_scatter.py | 48 +++ internlm/simulator/profiler/profiler.py | 131 +++++++ internlm/simulator/tracker/comm_tracker.py | 195 +++++++++++ internlm/simulator/tracker/comp_tracker.py | 12 + internlm/simulator/tracker/global_var.py | 0 internlm/simulator/tracker/mem_tracker.py | 46 +++ internlm/simulator/tracker/module_tracker.py | 73 ++++ internlm/simulator/utils.py | 207 +++++++++++ simulation_train.py | 328 ++++++++++++++++++ 20 files changed, 2094 insertions(+) create mode 100644 internlm/simulator/common.py create mode 100644 internlm/simulator/predict_cost_model.py create mode 100644 internlm/simulator/profiler/__init__.py create mode 100644 internlm/simulator/profiler/benchmark/__init__.py create mode 100644 internlm/simulator/profiler/benchmark/all2all.py create mode 100644 internlm/simulator/profiler/benchmark/all_gather.py create mode 100644 internlm/simulator/profiler/benchmark/all_reduce.py create mode 100644 internlm/simulator/profiler/benchmark/base_benchmark.py create mode 100644 internlm/simulator/profiler/benchmark/broadcast.py create mode 100644 internlm/simulator/profiler/benchmark/linear.py create mode 100644 internlm/simulator/profiler/benchmark/multi_head_attn.py create mode 100644 internlm/simulator/profiler/benchmark/reduce_scatter.py create mode 100644 internlm/simulator/profiler/profiler.py create mode 100644 internlm/simulator/tracker/comm_tracker.py create mode 100644 internlm/simulator/tracker/comp_tracker.py create mode 100644 internlm/simulator/tracker/global_var.py create mode 100644 internlm/simulator/tracker/mem_tracker.py create mode 100644 internlm/simulator/tracker/module_tracker.py create mode 100644 internlm/simulator/utils.py create mode 100644 simulation_train.py diff --git a/internlm/simulator/common.py b/internlm/simulator/common.py new file mode 100644 index 000000000..7c5e73003 --- /dev/null +++ b/internlm/simulator/common.py @@ -0,0 +1,249 @@ +import math +import os + +import torch +import torch.distributed as dist +from torch.distributed import GroupMember + + +# TODO: 这里需要增加一个broadcast +class CommOp: + ALL2ALL = "all2all" + ALLREDUCE = "all_reduce" + REDUCESCATTER = "reduce_scatter" + ALLGATHER = "all_gather" + LINEAR = "linear" + BROADCAST = "broadcast" + P2P = "p2p" + FLASH_ATTN = "flash_attn" + + +class AlgoType: + ISP = "isp" + MSP = "msp" + FSP = "fsp" + MTP = "mtp" + NONE = "none" + + +class BW: + IB = 100 * 1024**3 + A800_NVL = 150 * 1024**3 # 满速是 200 GB/s + A100_NVL = 250 * 1024**3 # 满速是 300 GB/s + + +BENCH_TYPE_LIST = [CommOp.ALL2ALL, CommOp.ALLREDUCE, CommOp.REDUCESCATTER, CommOp.ALLGATHER, CommOp.LINEAR] +# BENCH_TYPE_LIST = [CommOp.ALL2ALL, CommOp.ALLREDUCE, CommOp.REDUCESCATTER, CommOp.ALLGATHER, CommOp.LINEAR] + +K = 1024 + +KB = 1024 +MB = 1024 * KB +GB = 1024 * MB + +MS = 1000 +US = 1000 * MS + +_75GB = 75 * GB +_100GB = 100 * GB + +GLOBAL_BYTE_SIZES_LIST = [512 * KB, 1 * MB, 4 * MB, 64 * MB, 128 * MB, 256 * MB, 512 * MB, 1 * GB, 2 * GB, 4 * GB] +# GLOBAL_BYTE_SIZES_LIST = [512 * KB, 1 * MB, 4 * MB] # , 64 * MB, 128 * MB, 256 * MB] +GLOBAL_ELEM_SIZES_LIST = [dsize // 2 for dsize in GLOBAL_BYTE_SIZES_LIST] +WORLD_SIZE_LIST = [2, 4, 8, 16, 32, 64, 128] +TP_SIZE_RANGE = [1] + list(range(2, 80 + 1, 2)) + +OUT_OF_MEM_LATENCY = 10**9 + + +def cal_block_p_elem(h, multiple_of, mlp_ratio): + norm1_p_elem = h + norm2_p_elem = h + MHA = h * 3 * h + out_proj = h * h + mlp_hidden_features = multiple_of * ((int(h * mlp_ratio) + multiple_of - 1) // multiple_of) + mlp_p_elem = (h * mlp_hidden_features) * 3 + dropout1 = 0 + dropout2 = 0 + return norm1_p_elem + norm2_p_elem + MHA + out_proj + mlp_p_elem + dropout1 + dropout2 + + +def cal_model_p_elem(h, l, vocab_size, multiple_of, mlp_ratio): + embedding_p_elem = vocab_size * h + block_p_elem = l * cal_block_p_elem(h, multiple_of, mlp_ratio) + norm_p_elem = h + head_p_elem = vocab_size * h + return embedding_p_elem + block_p_elem + norm_p_elem + head_p_elem + + +def get_model_config(model_size): + if model_size == 7: + h = 4096 + a = 32 + l = 32 + elif model_size == 13: + h = 5120 + a = 40 + l = 40 + elif model_size == 20: + h = 5120 + a = 40 + l = 60 + elif model_size == 30: + h = 6144 + a = 48 + l = 60 + elif model_size == 65: + h = 8192 + a = 64 + l = 80 + elif model_size == 104: + h = 10240 + a = 80 + l = 82 + else: + raise ValueError(f"unsupport modesize: {model_size}") + + vocab_size = 103168 + mlp_ratio = 8 / 3 + multiple_of = 256 + + model_p_elem = cal_model_p_elem(h=h, l=l, vocab_size=vocab_size, multiple_of=multiple_of, mlp_ratio=mlp_ratio) + + return h, a, l, mlp_ratio, multiple_of, model_p_elem + + +def pretty_print_size(x): + if x < KB: + return f"{x} B" + elif x >= KB and x < MB: + return f"{x/KB:.3f} KB" + elif x >= MB and x < GB: + return f"{x/MB:.3f} MB" + else: + return f"{x/GB:.3f} GB" + + +def pretty_print_latency(x): + if x >= 1: + return f"{x:.3f} s" + elif x >= 1 / MS and x < 1: + return f"{x*MS:.3f} ms" + else: + return f"{x*US:.3f} us" + + +def get_local_rank(): + if "SLURM_PROCID" in os.environ: + return int(os.environ["SLURM_PROCID"]) % 8 + else: + return 0 + + +def get_world_size(): + if "SLURM_NPROCS" in os.environ: + return int(os.environ["SLURM_NPROCS"]) + else: + return 1 + + +def sync_all(): + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + +def get_bw(comm_op, size, duration, args): + n = dist.get_world_size() + tput = 0 + busbw = 0 + if comm_op == "all_to_all": + tput = size / duration + busbw = (size / duration) * ((n - 1) / n) + elif comm_op == "all_gather" or comm_op == "reduce_scatter": + size *= n + tput = size / duration + busbw = (size / duration) * ((n - 1) / n) + elif comm_op == "all_reduce": + tput = size * 2 / duration + busbw = (size / duration) * (2 * (n - 1) / n) + elif comm_op == "pt2pt" or comm_op == "broadcast": + tput = size / duration + busbw = tput + else: + print("wrong comm_op specified") + exit(0) + + if args.bw_unit == "Gbps": + tput *= 8 + busbw *= 8 + + return tput, busbw + + +sub_process_groups = {} +TORCH_DISTRIBUTED_DEFAULT_PORT = 12349 + + +def env2int(env_list, default=-1): + for e in env_list: + val = int(os.environ.get(e, -1)) + if val >= 0: + return val + return default + + +def init_torch_distributed(backend): + global dist + + # discover rank/size info from env + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = str(TORCH_DISTRIBUTED_DEFAULT_PORT) + if "MASTER_ADDR" not in os.environ: + import subprocess + + result = subprocess.check_output('scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1', shell=True) + master_addr = result.decode("utf8").strip() + if master_addr == "": + master_addr = "127.0.0.1" + os.environ["MASTER_ADDR"] = master_addr + local_rank = env2int( + ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK", "SLURM_LOCALID"] + ) + if "LOCAL_RANK" not in os.environ: + os.environ["LOCAL_RANK"] = str(local_rank) + rank = env2int(["RANK", "MPI_RANKID", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK", "SLURM_PROCID"]) + if "RANK" not in os.environ: + os.environ["RANK"] = str(rank) + world_size = env2int(["WORLD_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "SLURM_NPROCS"]) + if "WORLD_SIZE" not in os.environ: + os.environ["WORLD_SIZE"] = str(world_size) + + torch.distributed.init_process_group(backend) + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + + +def build_process_gourp(max_world_size): + global sub_process_groups + if max_world_size > 1: + init_torch_distributed("nccl") + sub_process_groups[str(dist.get_world_size())] = GroupMember.WORLD + + if dist.is_initialized(): + world_size = dist.get_world_size() + node_nums = world_size // 8 + base_num = [2, 4, 6] + [8 * i for i in range(1, node_nums)] + + for gpu_nums in base_num: + ranks = [j for j in range(gpu_nums)] + print(ranks, flush=True) + sub_process_groups[f"{gpu_nums}"] = dist.new_group(ranks) + # dist.get_process_group_ranks() + + +def get_global_rank(): + if dist.is_initialized(): + return dist.get_rank() + else: + return 0 diff --git a/internlm/simulator/predict_cost_model.py b/internlm/simulator/predict_cost_model.py new file mode 100644 index 000000000..bee64c773 --- /dev/null +++ b/internlm/simulator/predict_cost_model.py @@ -0,0 +1,325 @@ +import functools +import os +import pickle +from collections import OrderedDict +from copy import deepcopy + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from scipy.interpolate import interp1d +from sklearn.linear_model import LinearRegression +from sklearn.metrics import r2_score +from sklearn.preprocessing import PolynomialFeatures + +from internlm.core.context import Config +from internlm.simulator.common import MB, OUT_OF_MEM_LATENCY, WORLD_SIZE_LIST, CommOp + +# import profiler.benchmark +# import scipy.interpolate +from internlm.simulator.profiler.benchmark.multi_head_attn import UnitMultiHeadAttn +from internlm.simulator.profiler.profiler import run_profile + + +class PolynomialModel: + def __init__(self, degree, data, name="unknown", segments=None) -> None: + """_summary_ + + Args: + degree (int): _description_ + data (dict): _description_ + segments (dict): _description_ + """ + self.name = name + self.degree = 3 # 多项式的度数 + self.poly_features = PolynomialFeatures(degree=degree, include_bias=False) # 准备多项式回归模型 + self.data = pd.DataFrame(data) # 转换为DataFrame + if segments is None: + segments = {"all": (0, float("inf"))} + print(segments, flush=True) + self.segments = OrderedDict(segments) + self.segment_scores = {seg: {} for seg in self.segments} # 用于存储拟合结果和评分 + self.model_fit = { + seg: {card: None for card in self.data["World_Size"].unique()} for seg in self.segments + } # 存储模型 + self.see_base_value() + self.build_model() + + def see_base_value(self): + # 可视化数据 + plt.figure(figsize=(12, 6)) + for card in self.data["World_Size"].unique(): + subset = self.data[self.data["World_Size"] == card] + plt.scatter(subset["Data_B"], subset["Latency_s"], label=f"{card} cards") + + plt.xlabel("Data Transferred (MB)") + plt.ylabel("Latency (ms)") + plt.title("Transferred Latency vs Data Transferred for Different Card Numbers") + plt.legend() + plt.xscale("log") + plt.grid(True) + plt.savefig(f"{self.name}.jpg") + plt.show() + print(self.data.head()) + + def build_model(self): + # 对每个分段和卡数的数据进行拟合 + plt.figure(figsize=(12, 6)) + for seg, (low, high) in self.segments.items(): + for card in self.data["World_Size"].unique(): + subset = self.data[ + (self.data["World_Size"] == card) & (self.data["Data_B"] >= low) & (self.data["Data_B"] < high) + ] + + # 如果该段中没有足够的数据点,则跳过 + if len(subset) < 2: + continue + + # 准备数据 + X = subset["Data_B"].values.reshape(-1, 1) + y = subset["Latency_s"].values + X_poly = self.poly_features.fit_transform(X) + + # 拟合模型 + model = LinearRegression() + model.fit(X_poly, y) + y_pred = model.predict(X_poly) + self.model_fit[seg][card] = model + + # 评估模型 + score = r2_score(y, y_pred) + self.segment_scores[seg][card] = score + + # 可视化拟合结果 + plt.scatter(X / MB, y, label=f"{card} cards") + plt.plot(X / MB, y_pred, label=f"{card} cards Fit") + + # 绘制图表 + plt.xlabel("Data Transferred (MB)") + plt.ylabel("Latency (ms)") + plt.title("Segmented Polynomial Regression Fit for Different Card Numbers") + plt.xscale("log") + plt.yscale("log") + plt.legend() + plt.grid(True) + plt.savefig(f"{self.name}_fit.jpg") + plt.show() + + def return_segments(self, x): + for key, value in self.segments.items(): + low, hight = value[0], value[1] + if x >= low and x < hight: + return key + assert ValueError, f"predict value:{x} out of range" + + def predict(self, world_size, complexity): + try: + model = self.model_fit[self.return_segments(complexity)][world_size] + X_pred = self.poly_features.fit_transform([[complexity]]) + Y_pred = model.predict(X_pred)[0] + return Y_pred + except Exception as e: + print(f"e: {e}", flush=True) + import pdb + + pdb.set_trace() + + +class SplineModel: + def __init__(self): + self._data_prefix = "data/cost_data" + self.spline_model_list = {} + self.data = {} + self.load_data() + self.build_model() + + def load_data(self): + for cost_data_file in os.listdir(self._data_prefix): + name, suffix = cost_data_file.split(".") + if suffix == "pickle": + with open(f"{self._data_prefix}/{cost_data_file}", "rb") as f: + self.data[name] = pickle.load(f) + + @staticmethod + def reformat_data_to_cost_model(total_results): + reformat_data = dict() + for world_size in total_results.keys(): + list_data = [] + for complexity in total_results[world_size].keys(): + for value in total_results[world_size][complexity]: + list_data.append([value["lat"], complexity]) # p data[2][524288][0]['lat'] + + # list_data.sort(key=functools.cmp_to_key(my_compare)) + data_list = list(map(list, zip(*list_data))) + reformat_data[world_size] = {"Data_B": data_list[1], "Latency_s": data_list[0]} + + return reformat_data + + def build_model(self): + # p data[2][524288][0]['lat'] + for cost_type, cost_data in self.data.items(): + if cost_type != CommOp.FLASH_ATTN: + try: + cost_data = SplineModel.reformat_data_to_cost_model(cost_data) + except TypeError as e: + print(f"e : {e}", flush=True) + import pdb + + pdb.set_trace() + + self.spline_model_list[cost_type] = {} + for world_size, data in cost_data.items(): + try: + x = data["Data_B"] + y = data["Latency_s"] + except KeyError as e: + print(f"e : {e}", flush=True) + import pdb + + pdb.set_trace() + self.spline_model_list[cost_type][world_size] = interp1d(x, y, kind="slinear") + # self.see_base_value(cost_type, world_size, x, y) + else: # fa我们直接查表,不预测 + self.spline_model_list[cost_type] = {} + self.spline_model_list[cost_type][1] = cost_data[1] + + def predict(self, cost_type, world_size, complexity): + return self.spline_model_list[cost_type][world_size](complexity) + + def predict_cost(self, cost_type: CommOp, complexity=0, world_size=1, **kwargs): + """predict computation cost + The cost of attention will use KV mapping, and the cost of linear will + use PolynomialModel. + + Args: + cost_type (CommOp): _description_ + complexity (int, optional): _description_. Defaults to 0. + + Returns: + float: op latency. + """ + if cost_type == CommOp.FLASH_ATTN: + try: + key = UnitMultiHeadAttn.gen_store_key(**kwargs) + return self.spline_model_list[cost_type][1][key][0]["lat"] + except KeyError as e: + raise KeyError(f"not found FA key: {key}") + else: + try: + if cost_type != CommOp.LINEAR and world_size == 1: + return 0 + else: + spline_model = self.spline_model_list[cost_type][world_size] + predict = spline_model(complexity) + except ValueError: + below_bounds, above_bounds = spline_model.x[0], spline_model.x[-1] + if complexity < below_bounds: + return spline_model(below_bounds) # 如果超过下界就返回下界 + if complexity > above_bounds: + lat = spline_model(above_bounds) + return lat * complexity / above_bounds # 如果超过上界就线性扩展 + raise ValueError(f"value error for cost_type:{cost_type}, complexity:{complexity}") + except KeyError as e: + print(f"e : {e}", flush=True) + import pdb + + pdb.set_trace() + else: + return predict + + +def my_compare(a, b): + world_size_a, complexity_a = a[0], a[2] + world_size_b, complexity_b = b[0], b[2] + # print(world_size_a, world_size_b, complexity_a, complexity_b) + + if world_size_a > world_size_b: + return True + elif world_size_a < world_size_b: + return False + else: + if complexity_a > complexity_b: + return True + elif complexity_a < complexity_b: + return False + else: + assert ValueError, f"a:{a}, b:{b}" + + +class GenCostModel: + def __init__(self, is_master=True, build_type_list=None) -> None: + self._master = is_master + self._profile_args = Config( + { + "trials": 10, + "warmups": 1, + } + ) + self.cost_data = None + self._data_prefix = "data/cost_data" + self.cost_kv_data = {} + self.build_type_list = build_type_list + + def _log(self, msg: str): + if self._master: + print(msg, flush=True) + + def build_cost_model_by_key_value(self): + if self.cost_data is None: + self.cost_data = OrderedDict() + for bench_type in self.build_type_list: + self._log(f"now test {bench_type}") + self.cost_kv_data[bench_type] = run_profile(self._profile_args, bench_type) + + def load_cost_model_by_key_value(self): + self.cost_data = OrderedDict() + for bench_type in self.build_type_list: + self._log(f"now load {bench_type}") + with open(f"./data/{bench_type}.pickle", "rb") as f: + self.cost_kv_data[bench_type] = pickle.load(f) + + def draw_pic(self, data, cost_type): + plt.figure(figsize=(12, 6)) + world_sizes = list(data.index) + for vol in list(data.columns): + plt.plot(world_sizes, data[vol].values, label=f"{vol/1024**2:.2f} MB") + + plt.xlabel("GPU nums") + plt.ylabel("Latency (s)") + plt.title(f"{cost_type}") + # plt.xscale("log") + # plt.yscale("log") + plt.legend() + plt.grid(True) + plt.savefig(f"./data/pics/{cost_type}.jpg") + plt.show() + + def dump_data(self): + # p data[2][524288][0]['lat'] + for bench_type, results in self.cost_kv_data.items(): + indexs, columns = [], None + tables = [] + if bench_type != CommOp.FLASH_ATTN: + for world_size, values in results.items(): + indexs.append(world_size) + one_col = [] + tmp_columns = [] + for vol, latency in values.items(): + tmp_columns.append(vol) + one_col.append(latency[0]["lat"]) + if columns is None: + columns = deepcopy(tmp_columns) + tables.append(one_col) + + # print(f"bench_type: {bench_type}", flush=True) + # print(f"index: {indexs}", flush=True) + # print(f"columns: {columns}", flush=True) + + df = pd.DataFrame(tables, columns=columns, index=indexs) + df.to_csv(f"./data/excel/{bench_type}.csv", index=False) + + if bench_type != CommOp.LINEAR: + self.draw_pic(df, bench_type) + + with open(f"{self._data_prefix}/{bench_type}.pickle", "wb") as f: + pickle.dump(results, f) diff --git a/internlm/simulator/profiler/__init__.py b/internlm/simulator/profiler/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/internlm/simulator/profiler/benchmark/__init__.py b/internlm/simulator/profiler/benchmark/__init__.py new file mode 100644 index 000000000..4f2f0256c --- /dev/null +++ b/internlm/simulator/profiler/benchmark/__init__.py @@ -0,0 +1,7 @@ +from .all2all import * +from .all_gather import * +from .all_reduce import * +from .linear import * +from .multi_head_attn import * +from .reduce_scatter import * +from .broadcast import * diff --git a/internlm/simulator/profiler/benchmark/all2all.py b/internlm/simulator/profiler/benchmark/all2all.py new file mode 100644 index 000000000..b56e40a2a --- /dev/null +++ b/internlm/simulator/profiler/benchmark/all2all.py @@ -0,0 +1,137 @@ +import torch +import torch.distributed as dist + +from internlm.model.registry import benchmark_initializer +from internlm.simulator.common import * + +from .base_benchmark import UnitBench + +BENCH_TYPE = "all2all" + + +# @benchmark_initializer.register_module(module_name=BENCH_TYPE) +class UnitBenchAll2ALL(UnitBench): + test_loop = { + "global_size": GLOBAL_ELEM_SIZES_LIST, + "world_size": WORLD_SIZE_LIST, # 7B, (13B, 20B), 30B, 65B, 123B + "async_op": [False], # it is not work!! False, + "dtype": [torch.bfloat16], + } + + def __init__(self, world_size, async_op, dtype, global_size=None, unit_size=None) -> None: + assert global_size is None or unit_size is None + + self.unit_size = unit_size if unit_size else global_size // world_size # elements_per_gpu + self.world_size = world_size + self.dtype = dtype + self.async_op = async_op + self.group = sub_process_groups[str(world_size)] + self.do_it = dist.get_rank() in set(dist.get_process_group_ranks(self.group)) + + if dist.get_world_size() < world_size: + self.input = None + self.output = None + else: + self.output = torch.ones(self.world_size, self.unit_size, dtype=self.dtype).to(f"cuda:{get_local_rank()}") + self.input = torch.ones(self.world_size, self.unit_size, dtype=self.dtype).to(f"cuda:{get_local_rank()}") + self.input_buffer_size = self.input.element_size() * self.input.numel() + + def run(self): + if self.output is None or not self.do_it: + return + + handler = dist.all_to_all_single(self.output, self.input, async_op=self.async_op, group=self.group) + if self.async_op: + handler.wait() + + def complexity(self): + return self.input_buffer_size + + +if __name__ == "__main__": + # data = { + # "Latency_ms": [41.746, 62.982, 65.596, 101.968, 138.671, 159.773, 177.197, 190.415, 193.555, 194.056, 194.097, + # 193.776, 193.419, 193.679, 194.425, 194.462, 36.732, 55.592, 80.364, 100.85, 116.875, 133.242, + # 160.23, 178.519, 189.055, 193.55, 193.752, 193.717, 193.417, 193.686, 194.365, 194.416, 33.096, + # 48.456, 72.221, 97.357, 113.762, 125.266, 134.315, 164.453, 178.744, 187.352, 192.915, 193.512, + # 192.669, 193.47, 194.342, 194.218], + # "Cards": [64] * 16 + [128] * 16 + [256] * 16, + # "Data_MB": [512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, + # 8388608, 16777216] * 3 + # } + cards_8_lat = [ + 0.035442, + 0.038785, + 0.041076, + 0.063415, + 0.092584, + 0.151337, + 0.259346, + 0.482307, + 0.896747, + 1.737, + 3.255, + 6.431, + ] + cards_16_lat = [ + 0.086889, + 0.113204, + 0.177494, + 0.271461, + 0.45525, + 0.84743, + 1.641, + 3.103, + 6.125, + 12.177, + 24.724, + 49.03, + ] + cards_32_lat = [ + 0.102149, + 0.14717, + 0.230115, + 0.382689, + 0.681639, + 1.432, + 2.499, + 4.812, + 9.554, + 18.706, + 37.845, + 73.225, + ] + cards_64_lat = [ + 0.115658, + 0.16165, + 0.259298, + 0.43826, + 0.822096, + 1.591, + 2.967, + 5.703, + 11.148, + 22.108, + 41.188, + 98.423, + ] + assert len(cards_8_lat) == len(cards_16_lat) == len(cards_32_lat) == len(cards_64_lat) + samples = len(cards_8_lat) + data = { + "Latency_ms": cards_8_lat + cards_16_lat + cards_32_lat + cards_64_lat, + "Cards": [8] * samples + [16] * samples + [32] * samples + [64] * samples, + "Data_MB": [i * MB for i in [0.5, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]] * 4, + } + segments = { + "small": (64 * KB, 8 * MB), # 64KB - 8MB, degree =2 + "large": (8 * MB, 1 * GB), # 8MB - 1GB, degree=1 + } + + segments = { + "all": (64 * KB, 1 * GB), + } + + model = PolynomialModel(degree=2, data=data, segments=segments) + model.predict(35 * MB) + model.predict(1.2 * MB) + model.predict(678 * MB) diff --git a/internlm/simulator/profiler/benchmark/all_gather.py b/internlm/simulator/profiler/benchmark/all_gather.py new file mode 100644 index 000000000..c677f69b4 --- /dev/null +++ b/internlm/simulator/profiler/benchmark/all_gather.py @@ -0,0 +1,48 @@ +import torch +import torch.distributed as dist + +from internlm.model.registry import benchmark_initializer +from internlm.simulator.common import * + +from .base_benchmark import UnitBench + +BENCH_TYPE = "all_gather" + + +# @benchmark_initializer.register_module(module_name=BENCH_TYPE) +class UnitBenchAllGather(UnitBench): + test_loop = { + "global_size": GLOBAL_ELEM_SIZES_LIST, + "world_size": WORLD_SIZE_LIST, # 7B, (13B, 20B), 30B, 65B, 123B + "async_op": [False], # it is not work!! False, + "dtype": [torch.bfloat16], + } + + def __init__(self, world_size, async_op, dtype, global_size=None, unit_size=None) -> None: + assert global_size is None or unit_size is None + + self.unit_size = unit_size if unit_size else global_size // world_size # elements_per_gpu + self.world_size = world_size + self.dtype = dtype + self.async_op = async_op + self.group = sub_process_groups[str(world_size)] + self.do_it = dist.get_rank() in set(dist.get_process_group_ranks(self.group)) + + if dist.get_world_size() < world_size: + self.input = None + self.output = None + else: + self.output = torch.ones(self.world_size, self.unit_size, dtype=self.dtype).to(f"cuda:{get_local_rank()}") + self.input = torch.ones(self.unit_size, dtype=self.dtype).to(f"cuda:{get_local_rank()}") + self.output_buffer_size = self.output.element_size() * self.output.numel() + + def run(self): + if self.output is None or not self.do_it: + return + + handler = dist._all_gather_base(self.output, self.input, async_op=self.async_op, group=self.group) + if self.async_op: + handler.wait() + + def complexity(self): + return self.output_buffer_size diff --git a/internlm/simulator/profiler/benchmark/all_reduce.py b/internlm/simulator/profiler/benchmark/all_reduce.py new file mode 100644 index 000000000..972249b33 --- /dev/null +++ b/internlm/simulator/profiler/benchmark/all_reduce.py @@ -0,0 +1,46 @@ +import torch +import torch.distributed as dist + +from internlm.model.registry import benchmark_initializer +from internlm.simulator.common import * + +from .base_benchmark import UnitBench + +BENCH_TYPE = "all_reduce" + + +# @benchmark_initializer.register_module(module_name=BENCH_TYPE) +class UnitBenchAllReduce(UnitBench): + test_loop = { + "global_size": GLOBAL_ELEM_SIZES_LIST, + "world_size": WORLD_SIZE_LIST, # 7B, (13B, 20B), 30B, 65B, 123B + "async_op": [False], # it is not work!! False, + "dtype": [torch.bfloat16], + } + + def __init__(self, world_size, async_op, dtype, global_size=None, unit_size=None) -> None: + assert global_size is None or unit_size is None + + self.unit_size = global_size // world_size + self.world_size = world_size + self.dtype = dtype + self.async_op = async_op + self.group = sub_process_groups[str(world_size)] + self.do_it = dist.get_rank() in set(dist.get_process_group_ranks(self.group)) + + if dist.get_world_size() < world_size: + self.buffer = None + else: + self.buffer = torch.ones(self.world_size, self.unit_size, dtype=self.dtype).to(f"cuda:{get_local_rank()}") + self.input_buffer_size = self.buffer.element_size() * self.buffer.numel() + + def run(self): + if self.buffer is None or not self.do_it: + return + + handler = dist.all_reduce(self.buffer, async_op=self.async_op, group=self.group) + if self.async_op: + handler.wait() + + def complexity(self): + return self.input_buffer_size diff --git a/internlm/simulator/profiler/benchmark/base_benchmark.py b/internlm/simulator/profiler/benchmark/base_benchmark.py new file mode 100644 index 000000000..c4b46ff93 --- /dev/null +++ b/internlm/simulator/profiler/benchmark/base_benchmark.py @@ -0,0 +1,6 @@ +class UnitBench: + def run(self): + raise NotImplementedError + + def complexity(self): + raise NotImplementedError diff --git a/internlm/simulator/profiler/benchmark/broadcast.py b/internlm/simulator/profiler/benchmark/broadcast.py new file mode 100644 index 000000000..464fcde1e --- /dev/null +++ b/internlm/simulator/profiler/benchmark/broadcast.py @@ -0,0 +1,46 @@ +import torch +import torch.distributed as dist + +from internlm.model.registry import benchmark_initializer +from internlm.simulator.common import * + +from .base_benchmark import UnitBench + +BENCH_TYPE = "broadcast" + + +# @benchmark_initializer.register_module(module_name=BENCH_TYPE) +class UnitBenchBroadcast(UnitBench): + test_loop = { + "global_size": GLOBAL_ELEM_SIZES_LIST, + "world_size": WORLD_SIZE_LIST, # 7B, (13B, 20B), 30B, 65B, 123B + "async_op": [False], # it is not work!! False, + "dtype": [torch.bfloat16], + } + + def __init__(self, world_size, async_op, dtype, global_size=None, unit_size=None) -> None: + assert global_size is None or unit_size is None + + self.unit_size = unit_size if unit_size else global_size // world_size # elements_per_gpu + self.world_size = world_size + self.dtype = dtype + self.async_op = async_op + self.group = sub_process_groups[str(world_size)] + self.do_it = dist.get_rank() in set(dist.get_process_group_ranks(self.group)) + + if dist.get_world_size() < world_size: + self.output = None + else: + self.output = torch.ones(self.world_size, self.unit_size, dtype=self.dtype).to(f"cuda:{get_local_rank()}") + self.input_buffer_size = self.output.element_size() * self.output.numel() + + def run(self): + if self.output is None or not self.do_it: + return + + handler = dist.broadcast(self.output, src=0, async_op=self.async_op, group=self.group) + if self.async_op: + handler.wait() + + def complexity(self): + return self.input_buffer_size diff --git a/internlm/simulator/profiler/benchmark/linear.py b/internlm/simulator/profiler/benchmark/linear.py new file mode 100644 index 000000000..4814b0f30 --- /dev/null +++ b/internlm/simulator/profiler/benchmark/linear.py @@ -0,0 +1,56 @@ +import torch +from internlm.model.registry import benchmark_initializer +from internlm.simulator.common import * + +from .base_benchmark import UnitBench + +BENCH_TYPE = "linear" + + +# @benchmark_initializer.register_module(module_name=BENCH_TYPE) +class UnitBenchLinear(UnitBench): + test_loop = { + "seq_len": [int(0.5 * K), 1 * K, 2 * K, 4 * K, 8 * K, 16 * K, 32 * K], + "hidden_dim": [ + 512, + 1024, + 2048, + 4096, + 5120, + 6144, + 8192, + 9216, + 10240, + 11264, + 12288, + ], # 7B, (13B, 20B), 30B, 65B, 123B + "bias": [False], # it is not work!! False, + "dtype": [torch.bfloat16], + "world_size": [1], + } + + def __init__(self, seq_len, hidden_dim, bias, dtype) -> None: + self.seq_len = seq_len + self.hidden_dim = hidden_dim + self.q = torch.nn.Linear( + hidden_dim, hidden_dim, bias=bias, device=f"cuda:{get_local_rank()}", dtype=dtype + ) # (hidden_dim, hidden_dim) + self.dtype = self.q.weight.element_size() + self.x = torch.rand(1, seq_len, hidden_dim).to(self.q.weight) # (bsz, seq_len, hidden_dim) + + def run(self): + self.q(self.x) + + @staticmethod + def gen_store_key(seq_len, hidden_dim, bias, dtype): + if dtype in [torch.bfloat16, torch.float16]: + element_size = 2 + elif dtype is torch.float32: + element_size = 4 + else: + assert False + return element_size * seq_len * hidden_dim * hidden_dim + + def complexity(self): + return self.dtype * self.seq_len * self.hidden_dim * self.hidden_dim + # return f"{self.seq_len} * {self.hidden_dim} * {self.hidden_dim}" diff --git a/internlm/simulator/profiler/benchmark/multi_head_attn.py b/internlm/simulator/profiler/benchmark/multi_head_attn.py new file mode 100644 index 000000000..8ae3aa80c --- /dev/null +++ b/internlm/simulator/profiler/benchmark/multi_head_attn.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math + +import torch +from einops import rearrange +from torch import nn + +from internlm.model.registry import benchmark_initializer +from internlm.simulator.common import TP_SIZE_RANGE, K, get_local_rank + +try: + from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func + from flash_attn.modules.mha import FlashSelfAttention, SelfAttention +except ModuleNotFoundError: + print("import fa failed!", flush=True) + try: + from deeplink_ext.internevo_ops import ( + FlashCrossAttention, + FlashSelfAttention, + ) + except ModuleNotFoundError: + flash_attn_qkvpacked_func = None + FlashSelfAttention = None + SelfAttention = None + print("import dipu fa failed!", flush=True) + + +from .base_benchmark import UnitBench + +BENCH_TYPE = "flash_attn" + + +# @benchmark_initializer.register_module(module_name=BENCH_TYPE) +class UnitMultiHeadAttn(UnitBench): + test_loop = { + "seq_len": [64 * K, int(0.25 * K), int(0.5 * K), 1 * K, 2 * K, 4 * K, 8 * K, 32 * K, 16 * K], # 256 * K, 128 * K, + "num_heads_and_hidden_dim": [(64, 8192), (48, 6144), (32, 4096), (40, 5120)], # (80, 10240), + "dtype": [torch.bfloat16], + "micro_bsz": [ 2, 1], # 4, + "tp_size": TP_SIZE_RANGE, + "is_fwd": [True, False], + } + + def __init__(self, seq_len, num_heads_and_hidden_dim, dtype, micro_bsz, tp_size, is_fwd) -> None: + num_heads, embed_dim = num_heads_and_hidden_dim + self.num_heads_and_hidden_dim = num_heads_and_hidden_dim + self.TP = tp_size + self.S = seq_len + self.N = num_heads + self.H = embed_dim // self.N + self.dtype = dtype + self.dtype_size = 2 if self.dtype == torch.bfloat16 else 4 + self.B = micro_bsz + self.oom = False + self.is_fwd = is_fwd + self.causal = True + + assert num_heads % self.TP == 0, "num_heads must be divisible by tp_size" + assert num_heads >= tp_size, f"head nums must bigger then tp_size: {tp_size}" + + self.num_atten_head_tp = num_heads // self.TP + self.head_dim = self.H // num_heads + self.tp_embedding_dim = self.H // self.TP + + self.packed_length = self.S * self.B + self.device = f"cuda:{get_local_rank()}" + cu_seqlens = [i * self.S for i in range(self.B + 1)] + + weights_mem_used = self.packed_length * 3 * self.H * self.dtype_size + attn_activation = 11 * self.packed_length * self.H + mem_used = attn_activation + weights_mem_used + + self.inner_attn = FlashSelfAttention(causal=True, softmax_scale=self.H ** (0.5), attention_dropout=0.0) + + oom = False + if mem_used > 75 * 1024**3: + oom = True + + # 约束1: seqlen最大不能超过256K(不含) + # 约束2: embed_dim在被tp切过之后若大于6144, 则packed_length不能大于256k + if self.packed_length >= 256 * K and (self.H / self.TP) >= 6144: + oom = True + if self.S >= 256 * K and self.B > 1: + oom = True + if self.packed_length >= 524288 and (self.H / self.TP) >= 3072: + oom = True + if self.packed_length >= 1048576 and (self.H / self.TP) >= 2048: + oom = True + + if oom: + assert ( + False + ), f"warning : mem_used: {mem_used/1024**3:.2f} GB, seq_len: {self.S}, embed_dim: {self.H}, tp_size: {self.TP}" + + self.qkv = torch.rand( + size=(self.B * self.S, 3, self.N // self.TP, self.H), + dtype=self.dtype, + device=self.device, + requires_grad=True, + ) + + self.dtype_size = self.qkv.element_size() + self.cu_seqlens = torch.tensor(data=cu_seqlens, dtype=torch.int32, device=self.device) + self.max_seqlen= self.S + if not self.is_fwd: + self.output = self.run_fwd() + self.grad = torch.randn_like(self.output) / 32 # avoid grad is too large. + + def run(self): + if self.is_fwd: + self.run_fwd() + else: + self.run_bwd(self.output, self.grad) + + def run_fwd(self): + context = self.inner_attn(self.qkv, cu_seqlens=self.cu_seqlens, max_seqlen=self.max_seqlen, causal=self.causal) + return context + + def run_bwd(self, output, grad): + output.backward(grad, retain_graph=True) + + @staticmethod + def gen_store_key(micro_bsz, seq_len, num_heads_and_hidden_dim, tp_size, is_fwd): + _, embed_dim = num_heads_and_hidden_dim + tp_embedding_dim = embed_dim // tp_size + return f"b_{micro_bsz}_s_{seq_len}_h_{tp_embedding_dim}_fwd_{is_fwd}" + + def complexity(self): + return UnitMultiHeadAttn.gen_store_key( + self.B, self.S, self.num_heads_and_hidden_dim, self.TP, self.is_fwd + ) + # return f"{self.S} * {self.hidden_dim} * {self.hidden_dim}" diff --git a/internlm/simulator/profiler/benchmark/reduce_scatter.py b/internlm/simulator/profiler/benchmark/reduce_scatter.py new file mode 100644 index 000000000..6b8a0509f --- /dev/null +++ b/internlm/simulator/profiler/benchmark/reduce_scatter.py @@ -0,0 +1,48 @@ +import torch +import torch.distributed as dist + +from internlm.model.registry import benchmark_initializer +from internlm.simulator.common import * + +from .base_benchmark import UnitBench + +BENCH_TYPE = "reduce_scatter" + + +# @benchmark_initializer.register_module(module_name=BENCH_TYPE) +class UnitBenchAllReduceScatter(UnitBench): + test_loop = { + "global_size": GLOBAL_ELEM_SIZES_LIST, + "world_size": WORLD_SIZE_LIST, # 7B, (13B, 20B), 30B, 65B, 123B + "async_op": [False], # it is not work!! False, + "dtype": [torch.bfloat16], + } + + def __init__(self, world_size, async_op, dtype, global_size=None, unit_size=None) -> None: + assert global_size is None or unit_size is None + + self.unit_size = unit_size if unit_size else global_size // world_size # elements_per_gpu + self.world_size = world_size + self.dtype = dtype + self.async_op = async_op + self.group = sub_process_groups[str(world_size)] + self.do_it = dist.get_rank() in set(dist.get_process_group_ranks(self.group)) + + if dist.get_world_size() < world_size: + self.input = None + self.output = None + else: + self.output = torch.ones(self.unit_size, dtype=self.dtype).to(f"cuda:{get_local_rank()}") + self.input = torch.ones(self.world_size, self.unit_size, dtype=self.dtype).to(f"cuda:{get_local_rank()}") + self.input_buffer_size = self.input.element_size() * self.input.numel() + + def run(self): + if self.output is None or not self.do_it: + return + + handler = dist.reduce_scatter_tensor(self.output, self.input, async_op=self.async_op, group=self.group) + if self.async_op: + handler.wait() + + def complexity(self): + return self.input_buffer_size \ No newline at end of file diff --git a/internlm/simulator/profiler/profiler.py b/internlm/simulator/profiler/profiler.py new file mode 100644 index 000000000..95c6e43f9 --- /dev/null +++ b/internlm/simulator/profiler/profiler.py @@ -0,0 +1,131 @@ +import functools +import inspect +import os +import sys +import time +from collections import OrderedDict +from copy import deepcopy +from typing import Dict, List + +import torch +import torch.distributed as dist + +# internlm/model/registry.py +from internlm.model.registry import benchmark_initializer +from internlm.simulator.common import ( + OUT_OF_MEM_LATENCY, + get_global_rank, + get_world_size, + sync_all, +) + + +def DFS(loop_config: OrderedDict, results: OrderedDict, total_results: List): + if len(loop_config) == 0: + total_results.append(deepcopy(results)) + return + + now_key = list(loop_config.keys())[0] + now_values = loop_config[now_key] + loop_config.pop(now_key) + + for value in now_values: + results.update({now_key: value}) + DFS(loop_config, results, total_results) + + loop_config[now_key] = now_values + + +def filter_kwargs(func, kwargs): + sig = inspect.signature(func) + return {k: v for k, v in kwargs.items() if k in sig.parameters} + + +def debug_profile(bench, test_case): + if "lat" not in test_case: + test_case["lat"] = int.Maximum + + # print(f"{bench.complexity()}: micro_bsz: {test_case['micro_bsz']}, seq_len: {test_case['seq_len']}, num_heads_and_hidden_dim: {test_case['num_heads_and_hidden_dim']}, tp_size {test_case['tp_size']}, lat: {test_case['lat']}", flush=True) + + +def run_profile(args, test_type): + re_results = {} + + BENCH = benchmark_initializer.get_module(module_name=test_type) + + def run_benchmark(test_case, args): + sync_all() + # Warmups, establish connections, etc. + for _ in range(args.warmups): + try: + test_case.run() + except torch.cuda.OutOfMemoryError: + torch.cuda.empty_cache() + return OUT_OF_MEM_LATENCY + try: + sync_all() + except RuntimeError: + # self.packed_length * 3 * self.embed_dim * self.dtype_size + print( + f"packed_length: {test_case.packed_length}, embed_dim: {test_case.embed_dim}, micro_bsz: {test_case.micro_bsz}, seq_len: {test_case.seq_len}, tp:{test_case.tp_size}", + flush=True, + ) + torch.cuda.empty_cache() + return OUT_OF_MEM_LATENCY + + # time the actual comm op trials times and average it + pre = time.perf_counter() + for _ in range(args.trials): + try: + test_case.run() + except torch.cuda.OutOfMemoryError: + torch.cuda.empty_cache() + return OUT_OF_MEM_LATENCY + sync_all() + duration = time.perf_counter() - pre + + # maintain and clean performance data + avg_duration = duration / args.trials + return avg_duration + + sync_all() + # loop over various tensor sizes + test_args = OrderedDict(BENCH.test_loop) + total_cases = [] + + DFS(test_args, OrderedDict(), total_cases) + if get_global_rank() == 0: + print(f"all test case nums: {len(total_cases)}", flush=True) + + for test_case in total_cases: + world_size = test_case["world_size"] if "world_size" in test_case else 1 + + if world_size not in re_results: + re_results[world_size] = {} + + complex_tag = BENCH.gen_store_key(**filter_kwargs(BENCH.gen_store_key, test_case)) + + if complex_tag not in re_results[world_size]: + try: + bench = BENCH(**filter_kwargs(BENCH.__init__, test_case)) + except torch.cuda.OutOfMemoryError: + torch.cuda.empty_cache() + continue + except AssertionError: + # torch.cuda.empty_cache() + continue + else: + sync_all() + avg_duration = run_benchmark(bench, args) + test_case["lat"] = avg_duration + print(f"test_case: {test_case}, avg_duration: {avg_duration} ", flush=True) + + debug_profile(bench=bench, test_case=test_case) + re_results[world_size][complex_tag] = [test_case] + else: + if get_global_rank() == 0: + print( + f"Warning test_case: {test_case}, same complexity: {complex_tag}, lat:{re_results[world_size][complex_tag][0]['lat']}" + ) + + return re_results diff --git a/internlm/simulator/tracker/comm_tracker.py b/internlm/simulator/tracker/comm_tracker.py new file mode 100644 index 000000000..d24063a1d --- /dev/null +++ b/internlm/simulator/tracker/comm_tracker.py @@ -0,0 +1,195 @@ +from typing import Dict + +import torch + +from internlm.simulator.common import BW, CommOp +from internlm.simulator.predict_cost_model import SplineModel + +cost_model = None +scale_ratio = [1.415134488, 1.208864145, 1.1, 1] + + +def coll_comm_lat(comm_op, size, n): + if comm_op == CommOp.ALL2ALL: + if n <= 8: + return size * (n - 1) / n + else: + # intra_parts = 8 + one_part = size / n + return 8 * one_part * (n - 8 / n) + elif comm_op == CommOp.ALLREDUCE: + return size * 2 * (n - 1) / n + elif comm_op == CommOp.REDUCESCATTER: + return size * (n - 1) / n + elif comm_op == CommOp.ALLGATHER: + return size * (n - 1) / n + elif comm_op == CommOp.BROADCAST: + return size * (n - 1) / n + elif comm_op == CommOp.P2P: + return size + + raise ValueError(f"unknown comm_op: {comm_op}") + + +def coll_bus_bw(comm_op, size): + if comm_op == CommOp.ALL2ALL: + return size + elif comm_op == CommOp.ALLREDUCE: + return size * 2 + elif comm_op == CommOp.REDUCESCATTER: + return size + elif comm_op == CommOp.ALLGATHER: + return size + elif comm_op == CommOp.BROADCAST: + return size + elif comm_op == CommOp.P2P: + return size + + raise ValueError(f"unknown comm_op: {comm_op}") + + +# 需要判断是否打满带宽 +def get_scale_ratio(scale): + # 通信扩展惩罚系数 + if scale <= 16: + return 1 + elif 16 < scale <= 32: + return 1.1 + elif 32 < scale <= 64: + return 1.2 + elif 64 < scale <= 256: + return 1.3 + elif 256 < scale <= 512: + return 1.4 + else: + return 1.5 + + +class SingleCommMetric: + def __init__(self) -> None: + self.dur = 0 + self.volume = 0 + self.count = 0 + self.bw = 0 + + def add_new_comm(self, dur, volume, bw): + self.dur += dur + self.volume += volume + self.bw += bw + self.count += 1 + + def __repr__(self) -> str: + return f"dur: {self.dur}, volume: {self.volume}, avg bw: {self.bw/self.count:.3f} GB/s" + + def __str__(self) -> str: + return self.__repr__() + + +class CommType: + WP_PREFETCH_ALLGATHER = "wp_preftch_allgaher" + WP_WDP = "wp_wdp" + DP_ALLREDUCE = "dp_allreduce" + MSP_REDUCE_SCATTER = "msp_reduce_scatter" + MSP_ALLGATHER = "msp_allgahter" + MTP_ALLREDUCE = "mtp_allreduce" + + SP_NORM_ALLREDUCE = "sp_norm_allreduce" + + +coom_type_list = [ + CommType.WP_PREFETCH_ALLGATHER, + CommType.WP_WDP, + CommType.DP_ALLREDUCE, + CommType.MSP_ALLGATHER, + CommType.MSP_REDUCE_SCATTER, + CommType.MSP_ALLGATHER, + CommType.MSP_ALLGATHER, +] + + +class WPCommCost: + """ + WP的通信开销包括: + 1. pre-fetch allgahter + 2. wdp + """ + + def __init__(self) -> None: + pass + + +class CommTracker: + def __init__(self) -> None: + self.next_comm_type = None + self.next_parallel_mode = None + + self.comm_cost_dict: Dict[CommType, SingleCommMetric] = {} + for comm_type in coom_type_list: + self.comm_cost_dict[comm_type] = SingleCommMetric() + + def add_comm_meta(self, comm_type: CommType, parallel_mode, can_overlap): + self.next_comm_type = comm_type + self.next_parallel_mode = parallel_mode + self.can_overlap = can_overlap + + def cal_comm_cost(self, comm_op, comm_volume=1, dtype=torch.bfloat16): + """根据通信量获得近似的通信延迟,这个函数考虑了跨节点带宽content的情景 + 所以为了正确计算延迟,传入的 comm_volume 必须是以单个rank视角下的通信量 + (即代码中实际传入的通信量) + + Args: + comm_volume (int): 通信量, 单位B + parallel_mode (ParallelMode): gpc并行模式 + comm_op (CommOp, optional): 通信算子 + + Returns: + int: 通信延迟,是乘以10**4后并取整后的数值 + """ + + from internlm.core.context import ParallelMode + from internlm.core.context import global_context as gpc + + comm_type = self.next_comm_type + parallel_mode = self.next_parallel_mode + + if comm_type is None: + return + + scale = gpc.get_world_size(parallel_mode) + + if parallel_mode == ParallelMode.PIPELINE: + scale = 2 + + if scale <= 1: + return 0 + + is_intra = gpc.check_pg_is_intra(parallel_mode) + if not is_intra: + num_partner = gpc.same_group_in_one_node(parallel_mode) + assert num_partner <= 8, f"num_partner: {num_partner}" + if parallel_mode == ParallelMode.WEIGHT: + assert num_partner == 1 + if parallel_mode == ParallelMode.TENSOR: + assert num_partner == 1 + comm_volume *= num_partner + + global cost_model + try: + if cost_model is None: + cost_model = SplineModel() + + lat = cost_model.predict(comm_type, scale, comm_volume) + except FileNotFoundError: + # if comm_op == CommOp.P2P: + bw = BW.A800_NVL if is_intra else (BW.IB / get_scale_ratio(scale)) + + lat = coll_comm_lat(comm_op, comm_volume, scale) / bw # 转换成ms小数点保留两位 + + self.comm_cost_dict[comm_type].add_new_comm(lat, comm_volume, bw) + + +comm_tracker = CommTracker() + + +def get_gloabl_comm_tracker() -> CommTracker: + return comm_tracker diff --git a/internlm/simulator/tracker/comp_tracker.py b/internlm/simulator/tracker/comp_tracker.py new file mode 100644 index 000000000..04d9f7a80 --- /dev/null +++ b/internlm/simulator/tracker/comp_tracker.py @@ -0,0 +1,12 @@ +class CompTracker: + def __init__(self) -> None: + self.next_comm_type = None + self.next_parallel_mode = None + + # def add_comm_meta(self, comm_type: CommType, parallel_mode, can_overlap): + # self.next_comm_type = comm_type + # self.next_parallel_mode = parallel_mode + # self.can_overlap = can_overlap + + # def cal_comm_cost(self, comm_op, comm_volume=1, dtype=torch.bfloat16): + # pass diff --git a/internlm/simulator/tracker/global_var.py b/internlm/simulator/tracker/global_var.py new file mode 100644 index 000000000..e69de29bb diff --git a/internlm/simulator/tracker/mem_tracker.py b/internlm/simulator/tracker/mem_tracker.py new file mode 100644 index 000000000..ee27020c4 --- /dev/null +++ b/internlm/simulator/tracker/mem_tracker.py @@ -0,0 +1,46 @@ +import torch + +# from internlm.simulator.elements.tensor import FakeTensor + + +class FakeAllocator: + def __init__(self, capcity=0) -> None: + self.init_capcity = capcity + self.capcity = capcity + + def alloc(self, size): + if self.capcity - size >= 0: + self.capcity -= size + else: + raise RuntimeError(f"Out of Memory request: {size}, left: {self.capcity}") + + def free(self, size): + self.capcity += size + assert self.capcity <= self.init_capcity + + +global_allocator = FakeAllocator() + + +def get_global_allocator() -> FakeAllocator: + return global_allocator + + +class TensorTracker: + def __init__(self) -> None: + self.tensor_map = {} + + def save_tensor(self, tensor: torch.Tensor): + tid = id(tensor) + assert tid not in self.tensor_map + self.tensor_map[tid] = tensor + + def del_tensor(self, tid): + self.tensor_map.pop(tid).free_self() + + +global_tensor_manager = TensorTracker() + + +def get_global_tensor_manager(): + return global_tensor_manager diff --git a/internlm/simulator/tracker/module_tracker.py b/internlm/simulator/tracker/module_tracker.py new file mode 100644 index 000000000..b16f35edd --- /dev/null +++ b/internlm/simulator/tracker/module_tracker.py @@ -0,0 +1,73 @@ +# from typing import Self + + +# from internlm.simulator.tracker.global_var import get_pre_module_tracker +from typing import TypeVar + +_ModuleTracker = TypeVar("_ModuleTracker", bound="ModuleTracker") + +pre_module_tracker = None + +now_comm_tracker = None +now_comp_tracker = None +now_mem_tracker = None + + +def set_now_comm_tracker(tracker): + global now_comm_tracker + now_comm_tracker = tracker + + +def set_now_comp_tracker(tracker): + global now_comp_tracker + now_comp_tracker = tracker + + +def set_now_mem_tracker(tracker): + global now_mem_tracker + now_mem_tracker = tracker + + +def get_now_comm_tracker(): + return now_comm_tracker + + +def get_now_comp_tracker(): + return now_comp_tracker + + +def get_pre_module_tracker() -> _ModuleTracker: + return pre_module_tracker + + +class ModuleTracker: + def __init__(self, name: str) -> None: + + from internlm.simulator.tracker.comm_tracker import CommTracker + from internlm.simulator.tracker.comp_tracker import CompTracker + from internlm.simulator.tracker.mem_tracker import TensorTracker + + self.name = name + self.father_module = get_pre_module_tracker() + if self.father_module is not None: + self.father_module.register_submodule_tracker(self) + + self.comm_tracker = CommTracker() + self.comp_tracker = CompTracker() + self.mem_tracker = TensorTracker() + self.sub_tracker = [] + + def fwd_pre_hook(self, module, args, kwargs): + set_now_comm_tracker(self.comm_tracker) + set_now_comp_tracker(self.comp_tracker) + set_now_mem_tracker(self.mem_tracker) + print(f"[DEBUG]: call {self.name} fwd_pre_hook !", flush=True) + + def bwd_pre_hook(self, module, grad_input, grad_output): + set_now_comm_tracker(self.comm_tracker) + set_now_comp_tracker(self.comp_tracker) + set_now_mem_tracker(self.mem_tracker) + print(f"[DEBUG]: call {self.name} bwd_pre_hook !", flush=True) + + def register_submodule_tracker(self, module_tracker: _ModuleTracker): + self.sub_tracker.append(module_tracker) diff --git a/internlm/simulator/utils.py b/internlm/simulator/utils.py new file mode 100644 index 000000000..306d7716f --- /dev/null +++ b/internlm/simulator/utils.py @@ -0,0 +1,207 @@ +import math + +from internlm.simulator.common import GB, AlgoType, cal_block_p_elem, get_model_config + + +class LinsSolutionNoZ3: + def __init__( + self, + pp, + sp, + wp, + zp, + seq_len, + micro_bsz, + micro_num, + algo_type, + pp_comm_cost, + activation, + zp_comm_cost, + wp_comm_cost, + sp_comm_cost, + os_mm_cost, + p_g_mm_cost, + fwd_bwd_cost, + mem_cost, + comp_wp, + comp_attn, + world_size, + activation_ckpt, + tgs, + mem_pool_mm, + norm_activation, + head_input_activation, + head_output_activation, + block_output_activation, + wdp_comm_cost, + all_fwd_bwd_cost, + g_bsz, + pp_p2p_buffer, + rotary_emb_sincos_cache_mm, + modelsize, + backward_mem_peak, + blocks_activation, + overlap_latency, + total_latency, + ): + self.pp = pp + self.sp = sp + self.seq_len = seq_len + self.micro_bsz = micro_bsz + self.micro_num = micro_num + self.algo_type = algo_type + self.pp_comm_cost = pp_comm_cost + self.activation = activation + self.activation_ckpt = activation_ckpt + + self.wp_size = wp + self.zp_size = zp + self.zp_comm_cost = zp_comm_cost + self.wp_comm_cost = wp_comm_cost + self.os_mm_cost = os_mm_cost + self.p_g_mm_cost = p_g_mm_cost + self.sp_comm_cost = sp_comm_cost + self.total_mm_cost = mem_cost + self.fwd_bwd_cost = fwd_bwd_cost + self.comp_wp = comp_wp + self.comp_attn = comp_attn + self.world_size = world_size + self.tgs = tgs + + self.mem_pool_mm = mem_pool_mm + self.norm_activation = norm_activation + self.head_input_activation = head_input_activation + self.head_output_activation = head_output_activation + self.block_output_activation = block_output_activation + + self.wdp_comm_cost = wdp_comm_cost + self.all_fwd_bwd_cost = all_fwd_bwd_cost + self.g_bsz = g_bsz + self.pp_p2p_buffer = pp_p2p_buffer + self.rotary_emb_sincos_cache_mm = rotary_emb_sincos_cache_mm + self.modelsize = modelsize + self.backward_mem_peak = backward_mem_peak + self.blocks_activation = blocks_activation + self.overlap_latency = overlap_latency + self.total_latency = total_latency + + def __str__(self): + return self.__repr__() + + # Begin: world_size: 128, pp:1, sp:16, micro_bsz:1, micro_num:2, algo_type:isp, wp:16, zp:4 ckpt:1 + def __repr__(self): + return ( + f" world_size: {self.world_size}" + f" tgs: {self.tgs}, total_latency:{self.total_latency*10**3:.3f} ms" + f" global bsz: {self.g_bsz} \n" + f" activation ckpt: {self.activation_ckpt}" + f" seq_len: {self.seq_len}" + f" micro_bsz: {self.micro_bsz}" + f" micro_num: {self.micro_num}, \n" + f" modelsize: {self.modelsize}, algo_type: {self.algo_type}, pp_size: {self.pp}, sp_size: {self.sp}, wp_size: {self.wp_size}, zp_size: {self.zp_size}, \n" + f" one micro step fwd_bwd_cost: {self.fwd_bwd_cost*10**3:.2f} ms, all_fwd_bwd_cost: {self.all_fwd_bwd_cost*10**3:.2f} ms, overlap_latency: {self.overlap_latency*10**3:.2f} ms\n" + f" COMP: comp_wp: {self.comp_wp*10**3:.2f} ms, comp_attn: {self.comp_attn*10**3:.2f} ms, \n" + f" COMM: pp_comm_cost: {self.pp_comm_cost*10**3:.2f} ms, zp_comm_cost: {self.zp_comm_cost*10**3:.2f} ms, one layer wp_comm_cost: {self.wp_comm_cost*10**3:.2f} ms, one layer sp_comm_cost: {self.sp_comm_cost*10**3:.2f} ms, wdp_comm_cost: {self.wdp_comm_cost*10**3:.2f} ms \n" + f" total mem_cost: {self.total_mm_cost /GB:.2f} GB \n" + f" Not evictable MEM: os_mm_cost: {self.os_mm_cost/GB:.2f} GB, p_g_mm_cost: {self.p_g_mm_cost/GB:.2f} GB, isp_mem_pool: {self.mem_pool_mm/GB:.2f} GB, sincos_cache_mm: {self.rotary_emb_sincos_cache_mm/GB:.2f} GB,pp_p2p_buffer: {self.pp_p2p_buffer/GB:.2f} GB\n" + f" Activation MEM: total activation: {self.activation/GB:.2f} GB, blocks_activation: {self.blocks_activation/GB:.2f} GB, norm_activation: {self.norm_activation/GB:.2f} GB,backward_mem_peak: {self.backward_mem_peak/GB:.2f} GB \n" + f" head_input_activation: {self.head_input_activation/GB:.2f} GB, head_output_activation: {self.head_output_activation/GB:.2f} GB, block_output_activation(enable ckpt): {self.block_output_activation/GB:.2f} GB \n" + ) + + +class SPIter: + def __init__(self, gpu_nums, head_nums): + assert head_nums % 2 == 0 + stop = min(gpu_nums, head_nums) + if gpu_nums <= 8: + self.num_list = [1] + list(range(2, stop + 1, 2)) + else: + self.num_list = [1] + list(range(2, 8, 2)) + list(range(8, stop + 1, 8)) + + def __iter__(self): + return iter(self.num_list) + + def __len__(self): + return len(self.num_list) + + +class PPIter: + def __init__(self, gpu_nums, layer_nums): + # assert layer_nums % 2 == 0 + stop = int(math.log2(min(gpu_nums, layer_nums))) + self.num_list = [2**i for i in range(stop + 1)] + + def __iter__(self): + return iter(self.num_list) + + def __len__(self): + return len(self.num_list) + + +def get_bsz_strict(global_bsz: int, world_size: int, pp_size: int, sp_size: int, seq_len: int): + """ + 严格的按照 global_bsz 限制返回满足要求的 micro_bsz 和 micro_num + Args: + pp_size (int) + sp_size (int) + seq_len (int) + + Returns: + List[(int, int)]: micro_bsz, micro_num + """ + if pp_size * sp_size > world_size: + return None + + dp_world_size = world_size // pp_size // sp_size + if world_size % pp_size != 0 or world_size % sp_size != 0 or world_size % (pp_size * sp_size) != 0: + return None + + if global_bsz % dp_world_size != 0: + return None + if global_bsz % seq_len != 0: + return None + if global_bsz % (dp_world_size * seq_len) != 0: + return None + + bsz = global_bsz // dp_world_size // seq_len + + micro_bsz_num = [] + for micro_bsz in range(1, bsz + 1): + if bsz % micro_bsz == 0: + micro_num = bsz // micro_bsz + if micro_num >= pp_size: # 我们暂时不考虑 micro_num < pp_size 的情况 + micro_bsz_num.append((micro_bsz, micro_num)) + return micro_bsz_num + + +def get_bsz_approximate( + global_bsz_max: int, global_bsz_min: int, world_size: int, pp_size: int, sp_size: int, seq_len: int +): + """ + 允许global bsz在 min_bsz 和 max_bsz 之间松弛 + Args: + pp_size (int) + sp_size (int) + seq_len (int) + + Returns: + List[(int, int)]: micro_bsz, micro_num + """ + if pp_size * sp_size > world_size: + return None + + dp_world_size = world_size // pp_size // sp_size + if world_size % pp_size != 0 or world_size % sp_size != 0 or world_size % (pp_size * sp_size) != 0: + return None + + bsz_max = global_bsz_max // dp_world_size // seq_len + bsz_min = global_bsz_min // dp_world_size // seq_len + + micro_bsz_num = [] + for micro_bsz in range(1, int(bsz_max**0.5) + 1): + for micro_num in range(1, int(bsz_max**0.5) + 1): + if micro_bsz * micro_num >= bsz_min: + if micro_num >= pp_size: # 我们暂时不考虑 micro_num < pp_size 的情况 + assert micro_bsz * micro_num >= bsz_min and micro_bsz * micro_num <= bsz_max + micro_bsz_num.append((micro_bsz, micro_num)) + return micro_bsz_num diff --git a/simulation_train.py b/simulation_train.py new file mode 100644 index 000000000..fe4b66ac4 --- /dev/null +++ b/simulation_train.py @@ -0,0 +1,328 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import logging +import os +import socket +import time + +import torch +import torch.distributed as dist +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode + +import internlm +from internlm.core.context import Config, ParallelMode +from internlm.core.context import global_context as gpc +from internlm.core.context.random import reset_seed +from internlm.core.trainer import TrainState +from internlm.initialize.launch import launch +from internlm.model.losses import FlashGPTLMLoss +from internlm.simulator.common import AlgoType, CommOp +from internlm.simulator.tracker.comm_tracker import get_gloabl_comm_tracker +from internlm.simulator.tracker.mem_tracker import get_global_allocator + +# from internlm.simulator.elements.tensor import FakeTensor +from internlm.simulator.utils import PPIter, SPIter, get_bsz_approximate, get_bsz_strict +from internlm.train import ( + get_scheduler_hooks, + initialize_llm_profile, + initialize_model, + initialize_optimizer, + initialize_parallel_communicator, +) +from internlm.utils.common import ( + enable_pytorch_expandable_segments, + launch_time, + parse_args, +) + +# global llm logger +logger = logging.getLogger(__file__) + + +gloab_allocator = get_global_allocator() +global_comm_tracker = get_gloabl_comm_tracker() +from internlm.initialize.launch import args_sanity_check + + +class WaitHandler: + def wait(self): + return + + +def dummy_broadcast(tensor, src, group=None, async_op=False): + global_comm_tracker.cal_comm_cost( + comm_op=CommOp.BROADCAST, comm_volume=tensor.numel() * tensor.element_size(), dtype=tensor.dtype + ) + if async_op is True: + return WaitHandler() + + +def dummy_allreduce(tensor, op, group=None, async_op=False): + global_comm_tracker.cal_comm_cost( + comm_op=CommOp.ALLREDUCE, comm_volume=tensor.numel() * tensor.element_size(), dtype=tensor.dtype + ) + if async_op is True: + return WaitHandler() + + +def dummy_allgahter(tensor_list, tensor, group=None, async_op=False): + if async_op is True: + return WaitHandler() + + +def dummy_reduce_scatter(output, input_list, op, group=None, async_op=False): + if async_op is True: + return WaitHandler() + + +def dummy_reduce_scatter(output, input_list, op, group=None, async_op=False): + if async_op is True: + return WaitHandler() + + +def dummy_all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False): + if async_op is True: + return WaitHandler() + + +def dummy_batch_isend_irecv(p2p_op_list): + return [WaitHandler() for _ in range(len(p2p_op_list))] + + +def dummy_barrier(group=None, async_op=False, device_ids=None): + if async_op is True: + return WaitHandler() + + +old_bcast = dist.broadcast +old_all_reduce = dist.all_reduce +old_all_gahter = dist.all_gather +old_reduce_scatter = dist.reduce_scatter +old_all_to_all = dist.all_to_all +old_batch_isend_irecv = dist.batch_isend_irecv +old_barrier = dist.barrier + +dist.broadcast = dummy_broadcast +dist.all_reduce = dummy_allreduce +dist.all_gather = dummy_allgahter +dist.reduce_scatter = dummy_reduce_scatter +dist.all_to_all = dummy_all_to_all +dist.batch_isend_irecv = dummy_batch_isend_irecv +dist.barrier = dummy_barrier + + +def main(args): + very_begining_time = time.time() + enable_pytorch_expandable_segments() + + # init setting + skip_batches = gpc.config.data.skip_batches + total_steps = gpc.config.data.total_steps + valid_every = gpc.config.data.valid_every + label_smoothing = gpc.config.loss.label_smoothing + + # initialize model + model = initialize_model() + model = model.to("cuda") + # print(model) + # for prefix, module in model.named_modules(): + # print(f"prefix: {prefix}, module: {module}", flush=True) + # for prefix, param in model.named_parameters(): + # print(f"prefix: {prefix}, param: {param}", flush=True) + + # initialize isp communicator + isp_communicator = initialize_parallel_communicator(model) + + # initialize loss function + criterion = FlashGPTLMLoss(parallel_output=gpc.config.model.parallel_output, label_smoothing=label_smoothing) + + # initialize the train and validation data loader + # initialize and resume train state + train_state = TrainState(gpc.config, None) + + optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator) + + # initialize trainer + trainer, train_dl, _, _ = internlm.initialize_trainer( + model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=None, + lr_scheduler=lr_scheduler, + beta2_scheduler=beta2_scheduler, + scheduler_hooks=get_scheduler_hooks(None, optimizer, isp_communicator), + ) + + trainer.train() + total_steps = 10 + + S = gpc.config.data["seq_len"] + micro_num = gpc.config.data["micro_num"] + micro_bsz = gpc.config.data["micro_bsz"] + + batch = [ + { + "input_ids": torch.tensor(micro_num * [list(range(micro_bsz * S))], dtype=torch.int64), + "cu_seqlens": torch.tensor(micro_num * [[0, S]], dtype=torch.int64), + "indexes": torch.tensor(micro_num * [list(range(S))], dtype=torch.int64), + # 'type_ids': torch.tensor(micro_num* [list(range(S))], dtype=torch.int32 ), + }, + torch.tensor(micro_num * [list(range(micro_bsz * S))], dtype=torch.int64), + ] + print(batch) + with initialize_llm_profile(profiling=True, start_time=launch_time()) as prof: + for batch_count in range(train_state.batch_count, total_steps): + s = time.time() + # record the consumed samples in training + train_state.batch_count = batch_count + train_state.num_consumed_samples_in_epoch += len(batch[1]) + + # zero the grads of parameters + trainer.zero_grad() + + if hasattr(gpc.config.model, "num_experts"): + trainer.execute_schedule( + batch, + forward_only=False, + return_loss=True, + return_output_label=False, + ) + else: + trainer.execute_schedule( + batch, + forward_only=False, + return_loss=True, + return_output_label=False, + ) + + if isp_communicator and isp_communicator.enable_memory_pool: + isp_communicator.memory_pool.reset_lazy_pools() + + trainer_result = trainer.step() + print(f"ont step use time: {time.time() -s :.3f} s", flush=True) + prof.step() + import pdb + + pdb.set_trace() + + +def run_loop( + global_bsz, + world_size, + args, + use_fixed_micro_bsz=False, + use_strict_bsz=True, + global_bsz_max=1, + global_bsz_min=1, + debug=True, +): + gpc.load_config(config=Config.from_file(args.config)) + gpc.set_fake_mode(True) + + L = gpc.config.model["num_layers"] + KV_H = gpc.config.model["num_kv_attention_heads"] + S = gpc.config.data["seq_len"] + H = gpc.config.model["hidden_size"] + MICRO_BSZ = gpc.config.data["micro_bsz"] + MICRO_NUM = gpc.config.data["micro_num"] + + pp_search_range = PPIter(world_size, L) + sp_search_range = SPIter(world_size, KV_H) + wp_search_ranges = SPIter(world_size, world_size) + # zp_search_ranges_max = SPIter(world_size, world_size) + solutions_list = [] + algo_list = [AlgoType.ISP, AlgoType.MSP, AlgoType.FSP] + + for pp_i, pp in enumerate(pp_search_range): + for sp_i, sp in enumerate(sp_search_range): + if not use_fixed_micro_bsz: + if use_strict_bsz: + bs_bns = get_bsz_strict(global_bsz, world_size, pp, sp, S) + else: + bs_bns = get_bsz_approximate(global_bsz_max, global_bsz_min, world_size, pp, sp, S) + if bs_bns is None or len(bs_bns) == 0: + if debug: + print( + f"NO solu: pp:{pp} , sp:{sp} can't find micro_bsz/micro_num for" + f"world_size:{world_size}, seq_len:{S}, global bsz range: [{global_bsz_min}-{global_bsz_max}]!", + flush=True, + ) + continue + else: + bs_bns = [(MICRO_BSZ, MICRO_NUM)] + + for micro_bsz, micro_num in bs_bns: + for algo_type in algo_list: + for activation_ckpt in [0, 1]: + for wp_i, wp in enumerate(wp_search_ranges): + if algo_type in [AlgoType.MSP, AlgoType.FSP]: + if wp > 1: + if debug: + print("NO solu: msp, fsp not support wp>1 !", flush=True) + continue # msp, fsp禁掉fsdp,我们目前还不支持 + # zp的搜索空间是被wp限制的,同时他不是按照8的倍数变化的,是,1,2,3, ...这样递增的 + zp_search_range = world_size // pp // sp // wp # 这里的sp对于msp和fsp来说是tp + else: + zp_search_range = world_size // pp // wp # internlm实现的zp和deepspeed不一样,zp是在切wp的基础上再切的 + + try: + assert H % sp == 0, f"embed_dim:{H} must be divisible by sp: {sp}" + assert KV_H % sp == 0, f"num_heads: {KV_H} must be divisible by sp: {sp}" + assert KV_H >= sp, f"num_heads: {KV_H} must bigger then sp: {sp}" + except AssertionError as e: + if debug: + print(f"NO solu: head assert {e}", flush=True) + continue + + for zp_i, zp in enumerate(range(1, zp_search_range + 1)): + # set config + print( + f"activation_ckpt: {activation_ckpt}, micro_num: {micro_num}, micro_bsz: {micro_bsz}, pp: {pp}, wp: {wp}, zp: {zp}, sp: {sp}, {str(algo_type)}", + flush=True, + ) + gpc.config.model["checkpoint"] = activation_ckpt + gpc.config.parallel["zero1"]["size"] = zp + gpc.config.parallel["tensor"]["size"] = sp + gpc.config.parallel["tensor"]["mode"] = str(algo_type) + gpc.config.parallel["pipeline"]["size"] = pp + gpc.config.parallel["weight"]["size"] = wp + + gpc.config.data["micro_num"] = micro_num + gpc.config.data["micro_bsz"] = micro_bsz + + gpc.destroy() + reset_seed() + + launch( + config=gpc.config, + local_rank=0, + rank=0, + world_size=world_size, + host="127.0.0.1", + port=12345, + backend="nccl", + seed=0, + fake_mode=fake_mode, + ) + args_sanity_check() + assert hasattr(gpc, "config") and gpc.config is not None + + with FakeTensorMode(): + main(args) + + +if __name__ == "__main__": + args = parse_args() + hostname = socket.gethostname() + world_size = args.world_size + + fake_mode = "fake_mode" in os.environ + + # initialize distributed environment + print(f"fake_mode: {fake_mode}", flush=True) + + gloab_allocator.init_capcity = 80 * 1024**3 + gloab_allocator.capcity = 80 * 1024**3 + + run_loop(global_bsz=4096 * 1024, world_size=world_size, args=args) From b536cd31325bcade988f757c272d6fb35949ea8c Mon Sep 17 00:00:00 2001 From: wangguoteng <877825076@qq.com> Date: Mon, 5 Aug 2024 19:30:24 +0800 Subject: [PATCH 2/3] add simulation_train_formulaic --- gen_profiler_data.py | 5 + internlm/core/context/parallel_context.py | 349 ++++++++- .../process_group_initializer_simplified.py | 229 ++++++ internlm/core/context/random.py | 15 + internlm/core/parallel/comm/isp.py | 12 +- internlm/core/parallel/comm/tensor.py | 20 +- internlm/core/parallel/comm/utils.py | 4 +- internlm/core/parallel/shard.py | 269 ++++++- .../core/scheduler/no_pipeline_scheduler.py | 1 + internlm/data/build_dataloader.py | 2 +- internlm/data/tokenized/dummy_dataset.py | 2 +- internlm/data/utils.py | 7 +- internlm/initialize/launch.py | 90 ++- internlm/model/metrics.py | 3 +- internlm/model/modeling_internlm2.py | 5 + internlm/model/modules/embedding.py | 5 +- internlm/model/modules/linear.py | 6 + internlm/model/modules/mha.py | 8 + internlm/model/modules/norm.py | 7 + internlm/model/modules/utils.py | 6 +- internlm/model/ops/attention.py | 20 +- internlm/model/ops/cross_entropy.py | 8 + internlm/model/ops/linear.py | 17 +- internlm/model/ops/norm.py | 23 +- internlm/model/registry.py | 8 +- internlm/simulator/common.py | 41 +- internlm/simulator/formulas/__init__.py | 0 internlm/simulator/formulas/algo.py | 386 ++++++++++ internlm/simulator/formulas/comm.py | 271 +++++++ internlm/simulator/formulas/comp.py | 200 +++++ internlm/simulator/formulas/mem.py | 227 ++++++ internlm/simulator/formulas/overlap.py | 68 ++ internlm/simulator/predict_cost_model.py | 325 -------- .../simulator/profiler/benchmark/__init__.py | 34 +- .../simulator/profiler/benchmark/all2all.py | 114 +-- .../profiler/benchmark/all_gather.py | 28 +- .../profiler/benchmark/all_reduce.py | 22 +- .../simulator/profiler/benchmark/broadcast.py | 26 +- .../simulator/profiler/benchmark/linear.py | 7 +- .../profiler/benchmark/multi_head_attn.py | 319 ++++---- .../profiler/benchmark/reduce_scatter.py | 29 +- internlm/simulator/profiler/perf_comm.py | 454 ++++++++++++ internlm/simulator/profiler/profiler.py | 243 ++++-- internlm/simulator/tracker/comm_tracker.py | 85 +-- internlm/solver/optimizer/compatible_adamw.py | 4 +- .../solver/optimizer/hybrid_zero_optim.py | 38 +- internlm/solver/optimizer/utils.py | 14 +- internlm/train/pipeline.py | 3 + internlm/utils/common.py | 9 + simulation_train.py | 154 +++- simulation_train_formulaic.py | 691 ++++++++++++++++++ 51 files changed, 4041 insertions(+), 872 deletions(-) create mode 100644 gen_profiler_data.py create mode 100644 internlm/core/context/process_group_initializer_simplified.py create mode 100644 internlm/simulator/formulas/__init__.py create mode 100644 internlm/simulator/formulas/algo.py create mode 100644 internlm/simulator/formulas/comm.py create mode 100644 internlm/simulator/formulas/comp.py create mode 100644 internlm/simulator/formulas/mem.py create mode 100644 internlm/simulator/formulas/overlap.py delete mode 100644 internlm/simulator/predict_cost_model.py create mode 100644 internlm/simulator/profiler/perf_comm.py create mode 100644 simulation_train_formulaic.py diff --git a/gen_profiler_data.py b/gen_profiler_data.py new file mode 100644 index 000000000..b22ed5b51 --- /dev/null +++ b/gen_profiler_data.py @@ -0,0 +1,5 @@ + +from internlm.simulator.profiler.perf_comm import gen_perf + +if __name__ == "__main__": + gen_perf() \ No newline at end of file diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index 6b23fdae6..8fa4b2fe2 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -4,25 +4,29 @@ # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context import inspect +import os import random import socket import sys +from dataclasses import dataclass from importlib.machinery import SourceFileLoader from pathlib import Path -from typing import Union +from typing import List, Union import numpy as np import torch import torch.distributed as dist from internlm.accelerator import get_accelerator +from internlm.core.context.process_group_initializer_simplified import Initializer, ParallelMeta from internlm.utils.common import SingletonMeta from internlm.utils.logger import get_logger from internlm.utils.timeout import LLM_NCCL_TIMEOUT from . import process_group_initializer as pgroup_initializer -from .process_group_initializer import ParallelMode +from .process_group_initializer_simplified import ParallelMode from .random import add_seed, get_seeds, set_mode +from internlm.utils.common import get_args IS_REPLICA_ZERO_PARALLEL = "is_replica_zero_parallel" # for isp, with optimizer split in dp group @@ -127,6 +131,18 @@ def from_file(filename: str): return config +@dataclass +class ClusterInfo: + # name: str + name: str + peak_tflops: float + capacity: float + intra_bw: float + inter_bw: float + gpu_per_node: int + node_num: int + + class ParallelContext(metaclass=SingletonMeta): """This class provides interface functions for users to get the parallel context, such as the global rank, the local rank, the world size, etc. of each device. @@ -134,6 +150,14 @@ class ParallelContext(metaclass=SingletonMeta): """ def __init__(self): + # load config from file + self._config = None + self._group_map = {} + self.clusters = [] + self.micro_num_list = None + self._init_attr() + + def _init_attr(self): # distributed settings self._global_ranks = dict() self._local_ranks = dict() @@ -141,9 +165,7 @@ def __init__(self): self._groups = dict() self._cpu_groups = dict() self._ranks_in_group = dict() - - # load config from file - self._config = None + self._all_ranks = dict() # default parallel args, will be overwritten during process group intialization self.world_size = 1 @@ -341,6 +363,13 @@ def get_world_size(self, parallel_mode: ParallelMode): self._check_parallel_mode(parallel_mode) return self._world_sizes.get(parallel_mode, 1) + def get_group_size(self, process_group): + if self.fake_mode: + mode = self._group_map[id(process_group)] + return self.get_world_size(mode) + else: + return dist.get_world_size(process_group) + def get_group(self, parallel_mode: ParallelMode): """Returns the group of the current device for `parallel_mode`. @@ -351,7 +380,10 @@ def get_group(self, parallel_mode: ParallelMode): torch.distributed.ProcessGroup: The group of the current device for `parallel_mode`. """ self._check_parallel_mode(parallel_mode) - return self._groups[parallel_mode] + if parallel_mode not in self._groups: + return None + else: + return self._groups[parallel_mode] def get_ranks_in_group(self, parallel_mode: ParallelMode): """Returns the rank of the current device for `parallel_mode` in the group. @@ -369,7 +401,16 @@ def get_cpu_group(self, parallel_mode: ParallelMode): self._check_parallel_mode(parallel_mode) return self._cpu_groups[parallel_mode] - def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int, use_cpu: bool = False): + def init_global_dist( + self, + rank: int, + world_size: int, + backend: str, + host: str, + port: int, + use_cpu: bool = False, + fake_mode: bool = False, + ): """Initializes the global distributed environment Args: @@ -380,36 +421,60 @@ def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port (str): the master port for distributed training. use_cpu (bool): whether to set up cpu process group. """ - # initialize the default process group - init_method = f"tcp://[{host}]:{port}" - dist.init_process_group( - rank=rank, - world_size=world_size, - backend=backend, - init_method=init_method, - timeout=LLM_NCCL_TIMEOUT, - ) - # None will give the default global process group for pytorch dist operations - ranks = list(range(world_size)) - if use_cpu: - cpu_group = ( - dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) - if dist.get_backend() != "gloo" - else None + # find cluster info + if "clusters" not in self.config: + nv_info = { + "rank_range": [0, 8], + "peak_tflops": 320, + "capacity": 80 * 1024**3, + "intra_bw": 150, + "inter_bw": 100, + } + self.set_cluster_info("nv_cluster", nv_info) + else: + for cluster in self.config.clusters: + self.clusters.append(ClusterInfo(**cluster)) + + # initialize the default process group + if not fake_mode: + init_method = f"tcp://[{host}]:{port}" + dist.init_process_group( + rank=rank, + world_size=world_size, + backend=backend, + init_method=init_method, + timeout=LLM_NCCL_TIMEOUT, ) + + # None will give the default global process group for pytorch dist operations + ranks = list(range(world_size)) + if use_cpu: + cpu_group = ( + dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) + if dist.get_backend() != "gloo" + else None + ) + else: + cpu_group = None + + group = dist.GroupMember.WORLD else: - cpu_group = None - self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL) + ranks = list(range(world_size)) + group, cpu_group = None, None + + self._register_dist(rank, world_size, group, cpu_group, ranks, [list(range(world_size))], ParallelMode.GLOBAL) self._global_ranks[ParallelMode.GLOBAL] = rank - def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode): + def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, all_ranks, mode): self._check_parallel_mode(mode) self._local_ranks[mode] = local_rank self._world_sizes[mode] = world_size self._groups[mode] = process_group self._cpu_groups[mode] = cpu_group self._ranks_in_group[mode] = ranks_in_group + self._group_map[id(process_group)] = mode + self._all_ranks[mode] = all_ranks def check_sanity(self): """Checks sanity of the parallel context. @@ -450,6 +515,28 @@ def check_sanity(self): "will introduce redundancy when saving fsdp model ckpts, recommend setting them to same value" ) + if self.tensor_mode == "isp": + assert ( + self.zero1_parallel_size <= self.weight_data_parallel_size + ), f"zero1_size:{self.zero1_parallel_size} should be less than wdp_size:{self.weight_data_parallel_size}" + assert self.weight_data_parallel_size % self.zero1_parallel_size == 0, ( + f"weight_data_parallel_size:{self.weight_data_parallel_size} % " + f"zero1_parallel_size: {self.zero1_parallel_size} != 0" + ) + else: + assert ( + self.zero1_parallel_size <= self.data_parallel_size + ), f"zero1_size:{self.zero1_parallel_size} should be less than dp_size:{self.data_parallel_size}" + assert ( + self.data_parallel_size % self.zero1_parallel_size == 0 + ), f"data_parallel_size:{self.data_parallel_size} % zero1_parallel_size: {self.zero1_parallel_size} != 0" + + assert self.zero1_parallel_size >= 1 and self.zero1_parallel_size <= self.world_size + + assert ( + self.data_parallel_size % self.num_experts == 0 or self.num_experts % self.data_parallel_size == 0 + ), "can not place the experts evenly" + def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str): if key in config: ele = config[key] @@ -462,10 +549,11 @@ def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str) f'{"Parallel configuration does not support this kind of argument, please use int or dict"}' ) - def init_parallel_groups(self): + def init_parallel_groups(self, fake_mode: bool = False): """Initializes the parallel groups.""" # get rank and world size + self.fake_mode = fake_mode rank = self.get_global_rank() world_size = self.get_world_size(ParallelMode.GLOBAL) self.world_size = world_size @@ -488,7 +576,14 @@ def init_parallel_groups(self): self._set_parallel_size_from_config(parallel_config, "tensor", "tensor_parallel_size") self._set_parallel_size_from_config(parallel_config, "pipeline", "pipeline_parallel_size") self._set_parallel_size_from_config(parallel_config, "zero1", "zero1_parallel_size") + + + if get_args().use_simplified_gp_init: + self._init_use_simplified_pg(rank, world_size, parallel_config) + else: + self._init_pg(rank, world_size, parallel_config) + def _init_pg(self, rank, world_size, parallel_config): # the user should not set the data parallel size manually # instead, it should be calculated based on other parallel config self.sequence_parallel_size = self.tensor_parallel_size @@ -496,7 +591,11 @@ def init_parallel_groups(self): self.weight_data_parallel_size = max( 1, self.world_size // self.pipeline_parallel_size // self.weight_parallel_size ) - if isinstance(parallel_config["tensor"], dict) and parallel_config["tensor"]["mode"] == "isp": + + if ( + isinstance(parallel_config["tensor"], dict) + and parallel_config["tensor"]["mode"] == "isp" + ): if self.zero1_parallel_size == -1: self.zero1_parallel_size = self.weight_data_parallel_size self.zero1_parallel_size = max(1, self.zero1_parallel_size) @@ -523,7 +622,8 @@ def init_parallel_groups(self): if "sequence_parallel" not in parallel_config: parallel_config._add_item("sequence_parallel", True) if isinstance(parallel_config["tensor"], int) or ( - isinstance(parallel_config["tensor"], dict) and parallel_config["tensor"]["mode"] == "mtp" + isinstance(parallel_config["tensor"], dict) + and parallel_config["tensor"]["mode"] == "mtp" ): parallel_config["sequence_parallel"] = False @@ -564,7 +664,11 @@ def init_parallel_groups(self): initializers.append(pgroup_initializer.Initializer_Weight_Data(*initializer_args)) initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args)) initializers.append(pgroup_initializer.Initializer_Data(*initializer_args)) - if isinstance(parallel_config["tensor"], dict) and parallel_config["tensor"]["mode"] == "isp": + initializers.append(pgroup_initializer.Initializer_ISP_Data(*initializer_args)) + if ( + isinstance(parallel_config["tensor"], dict) + and parallel_config["tensor"]["mode"] == TensorParallelMode.isp.name + ): initializers.append(pgroup_initializer.Initializer_Zero1_ISP(*initializer_args)) else: initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args)) @@ -582,6 +686,66 @@ def init_parallel_groups(self): self._register_dist(*args) else: self._register_dist(*parallel_setting) + + def _init_use_simplified_pg(self, rank, world_size, parallel_config): + try: + self.tensor_mode = parallel_config["tensor"]["mode"] + except AttributeError: + self.tensor_mode = "mtp" + + self.num_experts = self.config.model.get("num_experts", 1) + + # the user should not set the data parallel size manually + # instead, it should be calculated based on other parallel config + self.sequence_parallel_size = self.tensor_parallel_size + self.data_parallel_size = max(1, self.world_size // self.pipeline_parallel_size // self.sequence_parallel_size) + self.weight_data_parallel_size = max( + 1, self.world_size // self.pipeline_parallel_size // self.weight_parallel_size + ) + + if self.tensor_mode == "isp": + if self.zero1_parallel_size == -1: + self.zero1_parallel_size = self.weight_data_parallel_size + else: + if self.zero1_parallel_size == -1: + self.zero1_parallel_size = self.data_parallel_size + + # set sequence parallel value + if self.tensor_mode == "mtp": + parallel_config["sequence_parallel"] = False + else: + parallel_config._add_item("sequence_parallel", True) + + # by default, expert_parallel_size equals to data_parallel_size, but if the number of experts is smaller + # than data_parallel_size, set expert_parallel_size to be the number of experts to make sure each device + # has one expert. + self.expert_parallel_size = min(self.data_parallel_size, self.num_experts) + + self.check_sanity() + + parallel_info = { + "tp": ParallelMeta(self.tensor_parallel_size, ParallelMode.TENSOR), + "wp": ParallelMeta(self.weight_parallel_size, ParallelMode.WEIGHT), + "pp": ParallelMeta(self.pipeline_parallel_size, ParallelMode.PIPELINE), + "dp": ParallelMeta(self.data_parallel_size, ParallelMode.DATA), + "zero1": ParallelMeta(self.zero1_parallel_size, ParallelMode.ZERO1), + "wdp": ParallelMeta(self.weight_data_parallel_size, ParallelMode.WEIGHT_DATA), + "ep": ParallelMeta(self.expert_parallel_size, ParallelMode.EXPERT), + "edp": ParallelMeta(self.data_parallel_size // self.expert_parallel_size, ParallelMode.EXPERT_DATA), + "intra_dp": ParallelMeta(-1, ParallelMode.INTRA_DP_SZIE), + "inter_dp": ParallelMeta(-1, ParallelMode.INTER_DP_SZIE), + } + + initializer = Initializer(rank, world_size, self.fake_mode, self.tensor_mode, parallel_info) + parallel_settings = initializer.init_dist_group() + + for name, parallel_setting in parallel_settings.items(): + # print(f"name: {name}, parallel_setting: {parallel_setting}") + if isinstance(parallel_setting, list): + for args in parallel_setting: + self._register_dist(*args) + else: + self._register_dist(*parallel_setting) def is_initialized(self, parallel_mode: ParallelMode): """Returns a boolean value indicating whether `parallel_mode` is initialized @@ -589,14 +753,19 @@ def is_initialized(self, parallel_mode: ParallelMode): """ return parallel_mode in self._groups + def set_fake_mode(self, fake_mode: bool = False): + self.fake_mode = fake_mode + def destroy(self): """Destroys the current distributed parallel environment.""" - for mode, group in self._groups.items(): - if mode is not ParallelMode.GLOBAL: - dist.destroy_process_group(group) - # destroy global process group - dist.destroy_process_group() - self._groups.clear() + if not self.fake_mode: + for mode, group in self._groups.items(): + if mode is not ParallelMode.GLOBAL: + dist.destroy_process_group(group) + # destroy global process group + dist.destroy_process_group() + + self._init_attr() def set_device(self, device_ordinal: int = None): """Sets distributed processes to be bound to devices. @@ -624,7 +793,6 @@ def set_seed(self, seed: int, dpseed_with_tpoffset: bool = False): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) - assert internlm_accelerator.is_available() # data parallel seed are kept the same in the same pipeline stage dp_seed = seed @@ -669,5 +837,112 @@ def set_virtual_pipeline_parallel_size(self, size): def set_virtual_pipeline_parallel_rank(self, rank): self.virtual_pipeline_parallel_rank = rank + def get_parallel_all_ranks(self, mode): + return self._all_ranks[mode] + + def get_cluster_local_rank(self): + devices_offset = 0 + for i, cluster in enumerate(self.clusters): + devices_offset += cluster.gpu_per_node * cluster.node_num + if self.get_global_rank() < devices_offset: + return i + raise ValueError + + def get_model_parallel_size(self): + return self.get_world_size(ParallelMode.PIPELINE) * self.get_world_size(ParallelMode.TENSOR) + + def check_pg_is_intra(self, parallel_mode: ParallelMode): + pg_group_ranks = self.get_ranks_in_group(parallel_mode) + if len(pg_group_ranks) > 8: + return False + else: + min_rank = min(pg_group_ranks) + max_rank = max(pg_group_ranks) + return (max_rank - min_rank) <= 7 + + def same_group_in_one_node(self, parallel_mode: ParallelMode): + """获得一个节点内有多少个相同类型的PG, 在跨节点通信时会存在带宽竞争 + 这里返回的相同PG的数量会乘上每个rank的通信数据量大小 + + Args: + parallel_mode (ParallelMode): + + Returns: + int: 一个节点内相同类型的PG的数量 + """ + pg_group_ranks = self.get_ranks_in_group(parallel_mode) + pg_group_ranks = sorted(pg_group_ranks) + if len(pg_group_ranks) == 1: + return 1 + else: + stride = pg_group_ranks[1] - pg_group_ranks[0] + if stride >= 8: + return 8 + else: + return stride + + # def set_cluster_info(self, name: str, info: dict): + # self.clusters[name] = ClusterInfo(**info) + + def get_cluster_info(self, name: str): + return self.clusters[name] + + def get_cluster_name_from_ip(self): + """ + node_ip_list = [ + 'metax-c500-1', + 'metax-c500-2', + 'nvidia-node-1', + 'nvidia-node-2', + ] + """ + hostname = socket.gethostname() + cluster_name = hostname.split("-")[0] + return cluster_name + + def sort_rank_based_on_ip_and_capacity(self): + Capacity = [] + + def sort_rank(x, y): + x_name = self.get_cluster_name_from_ip(x) + y_name = self.get_cluster_name_from_ip(y) + if x_name == y_name: + return x_name > y_name + else: + x_c = self.clusters[x_name]["capacity"] + y_c = self.clusters[y_name]["capacity"] + return x_c > y_c + + for cluster_name, cluster_info in self.clusters.items(): + peak_tflops.append(cluster_info["peak_tflops"]) + # Alpha.append(cluster_info.rank_range[-1] - cluster_info.rank_range[-1] + 1) + Capacity.append(cluster_info["capacity"]) + + def switch_topology_aware_rank_scheduling(): + """ + Switch topology-aware rank scheduling can optimize the performance of small-scale + collective communications. Currently only supported in Alibaba Cloud. + """ + + local_rank = int(os.environ["LOCAL_RANK"]) + cluster_name = get_cluster_name_from_ip() + + try: + if cluster_name == "Ali": + pass + else: + rank = int(os.environ["MLP_WORKER_RACK_RANK_INDEX"]) * 8 + local_rank + except Exception as e: + logger.error( + f"The switch topology awareness error is reported, the reason is: {e}", + "but don’t worry, this error will not affect normal training.", + "If you train on Alibaba or Volcano Cloud, please contact wangguoteng or lijiaxing", + ) + else: + # If there is no any error, hack torch rank. + os.environ["RANK"] = str(rank) + if local_rank == 0: + logger.info("Successfully bound node switch affinity!") + global_context = ParallelContext() diff --git a/internlm/core/context/process_group_initializer_simplified.py b/internlm/core/context/process_group_initializer_simplified.py new file mode 100644 index 000000000..c1423a5ae --- /dev/null +++ b/internlm/core/context/process_group_initializer_simplified.py @@ -0,0 +1,229 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from copy import deepcopy +from enum import Enum + +import torch +import torch.distributed as dist + +from internlm.utils.timeout import LLM_NCCL_TIMEOUT +from internlm.core.context.process_group_initializer import ParallelMode + +class ParallelMeta: + def __init__(self, parallel_size, mode) -> None: + self.parallel_size = parallel_size + self.mode = mode + + def __str__(self) -> str: + return self.__repr__() + + def __repr__(self) -> str: + return f"{self.mode}, {self.parallel_size}" + + +def determine_intra_inter_size_of_group(one_group_indexs, intra_range=8): + "Determine the inter size and intra size of a rank group." + gourp_size = len(one_group_indexs) + if gourp_size == 1: + return 1, 1 + else: + group_stride = one_group_indexs[1] - one_group_indexs[0] + if group_stride >= intra_range: + return 1, gourp_size + else: + intra_size = intra_range // group_stride + inter_size = gourp_size // intra_size + return max(1, intra_size), max(1, inter_size) + + +class Initializer: + def __init__( + self, + rank: int, + world_size: int, + fake_mode: bool = False, + tensor_mode: str = "fsp", + parallel_info: dict = None, + ): + """Initialize communication groups + + Args: + rank (int): global rank + world_size (int): world size + fake_mode (bool, optional): Whether to create actual NCCL communication + groups.Defaults to False. + tensor_mode (str, optional): ISP/FSP/MSP. Defaults to "fsp". + parallel_info (dict, optional): parallel_info. Defaults to None. + """ + self.rank = rank + self.world_size = world_size + self.fake_mode = fake_mode + self.tensor_mode = tensor_mode + self.parallel_info = parallel_info + + # assert sequence_parallel_size == tensor_parallel_size + super().__init__() + + def init_dist_group(self, use_cpu: bool = False): + parallel_info, world_size = self.parallel_info, self.world_size + + wp_size = parallel_info["wp"].parallel_size + # tp_size = parallel_info["tp"].parallel_size + # pp_size = parallel_info["pp"].parallel_size + wdp_size = parallel_info["wdp"].parallel_size + zero1_size = parallel_info["zero1"].parallel_size + ep_size = parallel_info["ep"].parallel_size + edp_size = parallel_info["edp"].parallel_size + + re_group_args = {} + + # stride_order means the placement priority of PG groups. + stride_order = ["tp", "dp", "pp"] + strides = {} + + def assemble_group(all_ranks, dim_name): + for ranks in all_ranks: + if self.fake_mode or len(all_ranks) == 1: + group, group_cpu = None, None + else: + group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT) + if use_cpu: + group_cpu = ( + dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT) + if dist.get_backend() != "gloo" + else group + ) + else: + group_cpu = None + + if self.rank in ranks: + local_rank = ranks.tolist().index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks.tolist() + + new_all_ranks = [] + for ranks in all_ranks: + new_all_ranks.append(ranks.tolist()) + + return ( + local_rank, + group_world_size, + process_group, + cpu_group, + ranks_in_group, + new_all_ranks, + parallel_info[dim_name].mode, + ) + + def split_orthogonal_sub_group(dim_name, indexs, size, stride): + assert size <= world_size, f"{dim_name} stride: {size} should less then worldsize: {world_size} !" + + indexs = indexs.reshape(-1, stride).T.reshape(-1) + all_ranks = torch.split(indexs, size) + + return indexs, assemble_group(all_ranks, dim_name) + + def split_horizontal_sub_group(dim_name, indexs, size, stride): + assert size <= world_size, f"{dim_name} stride: {size} should less then worldsize: {world_size} !" + + indexs = indexs.reshape(stride, -1).reshape(-1) + all_ranks = torch.split(indexs, size) + + return indexs, assemble_group(all_ranks, dim_name) + + count = 0 + for dim_name in stride_order: + parallel_size = parallel_info[dim_name].parallel_size + if parallel_size == 1: + continue + + if count == 0: + strides[dim_name] = 1 + else: + strides[dim_name] = strides[old_dim_name] * parallel_info[old_dim_name].parallel_size + + father_indexs, group_args = split_orthogonal_sub_group( + dim_name, torch.arange(start=0, end=world_size), size=parallel_size, stride=strides[dim_name] + ) + re_group_args[dim_name] = group_args + + if dim_name == "dp": + """ + "EP, EDP, and ZeRO are auxiliary parallel modes within DP." + """ + if wp_size == 1 and self.tensor_mode != "isp": + re_group_args["zero1"] = split_horizontal_sub_group("zero1", father_indexs, zero1_size, zero1_size)[ + 1 + ] + print(f"re_group_args['zero1']: {re_group_args['zero1']}") + + # MoE expert group is subgroup of data parallel group + if ep_size > 1: + ep_indexs, group_ep_args = split_horizontal_sub_group( + "ep", father_indexs, size=ep_size, stride=ep_size + ) + re_group_args["ep"] = group_ep_args + re_group_args["edp"] = split_orthogonal_sub_group("edp", ep_indexs, edp_size, ep_size)[1] + + one_group_indexs = group_args[4] # one group ranks + intra_dp_size, inter_dp_size = determine_intra_inter_size_of_group(one_group_indexs) + + # It will be used in drawing heatmap. + parallel_info["intra_dp"].parallel_size = intra_dp_size + parallel_info["inter_dp"].parallel_size = inter_dp_size + + # The only parallel group with a higher priority than DP is TP. + # see: stride_order = ["tp", "dp", "pp"] + high_priority_group = parallel_info["tp"].parallel_size + + re_group_args["intra_dp"] = split_horizontal_sub_group( + "intra_dp", father_indexs, size=intra_dp_size, stride=high_priority_group + )[1] + + re_group_args["inter_dp"] = split_orthogonal_sub_group( + "inter_dp", father_indexs, size=inter_dp_size, stride=intra_dp_size + )[1] + + elif dim_name == "tp": + """ + The situation with isp is somewhat complex. When using isp, the head/embedding is partitioned + according to the Megatron-TP method and uses the TP communication group, while other modules + are partitioned according to the WP communication group and reuse the TP communication group + (but perform DeepSpeed-Ulysses instead of Megatron-TP). Therefore, + for head/embedding, their Zero1 communication group is orthogonal to the TP group, + for other modules, their Zero1 communication group is the Wdp communication group + (orthogonal to the WP/TP communication groups). + FIXME: Can this be further simplified? + """ + if self.tensor_mode == "isp": + if wp_size == 1: + re_group_args["zero1"] = split_horizontal_sub_group( + "zero1", father_indexs, zero1_size, zero1_size + )[1] + else: + wp_index, re_group_args["wp"] = split_horizontal_sub_group( + "wp", torch.arange(start=0, end=world_size), wp_size, wp_size + ) + re_group_args["wdp"] = split_orthogonal_sub_group("wdp", wp_index, wdp_size, wp_size)[1] + re_group_args["zero1"] = split_orthogonal_sub_group( + "zero1", father_indexs, zero1_size, wp_size + )[1] + + count += 1 + old_dim_name = dim_name + + for name, info in parallel_info.items(): + if info.parallel_size == 1: + # If the degree of parallelism is 1, for logical consistency, + # we still need to create a logical communication group + re_group_args[name] = assemble_group([torch.tensor([self.rank])], name) + + # If two groups are orthogonal to each other and one group has a parallelism degree of 1, + # then the parallelism degree of the other group is world_size. + if parallel_info["wp"].parallel_size == 1: + re_group_args["wdp"] = tuple(list(deepcopy(re_group_args["dp"]))[0:-1] + [parallel_info["wdp"].mode]) + + return re_group_args diff --git a/internlm/core/context/random.py b/internlm/core/context/random.py index 7a0b138ad..80661319f 100644 --- a/internlm/core/context/random.py +++ b/internlm/core/context/random.py @@ -2,6 +2,7 @@ # -*- encoding: utf-8 -*- # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context +import os from contextlib import contextmanager from torch import Tensor @@ -10,6 +11,8 @@ from .process_group_initializer import ParallelMode +fake_mode = "fake_mode" in os.environ + internlm_accelerator = get_accelerator() @@ -35,11 +38,15 @@ def seed_states(self): def set_state(self, parallel_mode: ParallelMode, state: Tensor): """Sets the state of the seed manager for `parallel_mode`.""" + if fake_mode: + return assert parallel_mode in self._seed_states, f"{parallel_mode} not found in seed manager" self._seed_states[parallel_mode] = state def set_mode(self, parallel_mode: ParallelMode, update_rng_current_mode: bool = True): """Sets the current mode of the seed manager.""" + if fake_mode: + return if update_rng_current_mode and self.current_mode: # save state for current mode self._seed_states[self._current_mode] = internlm_accelerator.get_rng_state() @@ -50,6 +57,8 @@ def set_mode(self, parallel_mode: ParallelMode, update_rng_current_mode: bool = def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = False): """Adds a seed to the seed manager for `parallel_mode`.""" + if fake_mode: + return assert isinstance(parallel_mode, ParallelMode), "Invalid ParallelMode" if not overwrite: assert parallel_mode not in self._seed_states, f"Seed for {parallel_mode} exists" @@ -63,6 +72,8 @@ def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = Fal internlm_accelerator.set_rng_state(current_state) def reset(self): + if fake_mode: + return self._current_mode = None self._seeds = {} self._seed_states = {} @@ -131,3 +142,7 @@ def seed(parallel_mode: ParallelMode): yield _SEED_MANAGER.set_mode(parallel_mode) finally: _SEED_MANAGER.set_mode(current_mode) + + +def reset_seed(): + _SEED_MANAGER.reset() diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py index 14637912b..c3aed5cb4 100644 --- a/internlm/core/parallel/comm/isp.py +++ b/internlm/core/parallel/comm/isp.py @@ -523,7 +523,7 @@ def register_prerequisite_for_forward_prefetch_hooks(self, prerequisite_func: Ca def weight_hook( self, tensor: torch.Tensor, async_op: bool = False, module: nn.Module = None, is_bias: bool = False ) -> torch.Tensor: - if dist.get_world_size(self.process_group) <= 1: + if gpc.get_group_size(self.process_group) <= 1: return tensor if not self.overlap: @@ -545,7 +545,7 @@ def grad_hook( reduce_op: dist.ReduceOp = dist.ReduceOp.AVG, is_bias: bool = False, ) -> Tuple[torch.Tensor, AsyncCommHandle]: - if dist.get_world_size(self.process_group) <= 1: + if gpc.get_group_size(self.process_group) <= 1: return tensor, DUMMY_HANDLE_CONST if not self.overlap: @@ -573,7 +573,7 @@ def grad_hook( result, handle = ( self._get_constant_zero( ( - tensor.shape[0] // dist.get_world_size(self.process_group), + tensor.shape[0] // gpc.get_group_size(self.process_group), *tensor.shape[1:], ) ), @@ -634,10 +634,10 @@ def forward(ctx, group: dist.ProcessGroup, input_: torch.Tensor, scatter_idx: in ctx.scatter_idx = scatter_idx ctx.gather_idx = gather_idx - if dist.get_world_size(group) <= 1: + if gpc.get_group_size(group) <= 1: return input_ - seq_world_size = dist.get_world_size(group) + seq_world_size = gpc.get_group_size(group) input_list = [t.contiguous() for t in torch.tensor_split(input_, seq_world_size, scatter_idx)] output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)] @@ -647,7 +647,7 @@ def forward(ctx, group: dist.ProcessGroup, input_: torch.Tensor, scatter_idx: in @staticmethod def backward(ctx, *grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None, None]: - if dist.get_world_size(ctx.group) <= 1: + if gpc.get_group_size(ctx.group) <= 1: return (None, *grad_output, None, None) return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None) diff --git a/internlm/core/parallel/comm/tensor.py b/internlm/core/parallel/comm/tensor.py index 47086ad96..f557ddac6 100644 --- a/internlm/core/parallel/comm/tensor.py +++ b/internlm/core/parallel/comm/tensor.py @@ -127,7 +127,7 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T """ all reduce grad_input only for column parallel linear when backward. """ - if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.ROW: + if gpc.get_group_size(self._process_group) <= 1 or self._role == LinearRole.ROW: return grad_input, DUMMY_HANDLE_CONST return all_reduce_raw(grad_input, process_group=self._process_group, async_op=async_op) @@ -136,7 +136,7 @@ def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[tor """ all reduce output only for row parallel linear when forward. """ - if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: + if gpc.get_group_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: return output, DUMMY_HANDLE_CONST return all_reduce_raw(output, process_group=self._process_group, async_op=async_op) @@ -173,7 +173,7 @@ def input_hook( # 2. row parallel linear should not allgather input. # 3. column parallel linear should not allgather input if save_total_input_as_activation and backward is True. if ( - dist.get_world_size(self._process_group) <= 1 + gpc.get_group_size(self._process_group) <= 1 or self._role == LinearRole.ROW or (is_forward is False and self._save_total_input) ): @@ -187,7 +187,7 @@ def grad_output_hook( """ all gather grad_output only for row parallel linear when backward. """ - if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: + if gpc.get_group_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: return grad_output, DUMMY_HANDLE_CONST return all_gather_raw(grad_output, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM) @@ -196,7 +196,7 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T """ reduce scatter grad_input only for column parallel linear when backward. """ - if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.ROW: + if gpc.get_group_size(self._process_group) <= 1 or self._role == LinearRole.ROW: return grad_input, DUMMY_HANDLE_CONST return reduce_scatter_raw( @@ -207,7 +207,7 @@ def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[tor """ reduce scatter output only for row parallel linear when forward. """ - if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: + if gpc.get_group_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: return output, DUMMY_HANDLE_CONST return reduce_scatter_raw(output, process_group=self._process_group, async_op=async_op, reduce_dim=_REDUCE_DIM) @@ -230,7 +230,7 @@ def grad_output_hook( """ split grad_output if retain_out_sharded is False. """ - if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1: + if self._retain_out_sharded or gpc.get_group_size(self._process_group) <= 1: return grad_output, DUMMY_HANDLE_CONST return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1) @@ -241,7 +241,7 @@ def output_hook( """ all gather output for head layer if retain_out_sharded is False. """ - if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1: + if self._retain_out_sharded or gpc.get_group_size(self._process_group) <= 1: return output, DUMMY_HANDLE_CONST return _gather(output, parallel_mode=self._parallel_mode, dim=-1) @@ -271,7 +271,7 @@ def grad_output_hook( """ split grad_output if retain_out_sharded is False. """ - if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1: + if self._retain_out_sharded or gpc.get_group_size(self._process_group) <= 1: return grad_output, DUMMY_HANDLE_CONST return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1) @@ -283,7 +283,7 @@ def output_hook( """ all gather output for head layer if retain_out_sharded is False. """ - if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1: + if self._retain_out_sharded or gpc.get_group_size(self._process_group) <= 1: return output, DUMMY_HANDLE_CONST return _gather(output, parallel_mode=self._parallel_mode, dim=-1) diff --git a/internlm/core/parallel/comm/utils.py b/internlm/core/parallel/comm/utils.py index dbfeb3fda..8f3a39c97 100644 --- a/internlm/core/parallel/comm/utils.py +++ b/internlm/core/parallel/comm/utils.py @@ -181,7 +181,7 @@ def all_gather_raw( gather_dim: int = 0, memory_pool_allocator: Callable = None, ): - world_size = dist.get_world_size(process_group) + world_size = gpc.get_group_size(process_group) if world_size <= 1: return input_, None @@ -204,7 +204,7 @@ def reduce_scatter_raw( reduce_dim: int = 0, memory_pool_allocator: Callable = None, ): - world_size = dist.get_world_size(process_group) + world_size = gpc.get_group_size(process_group) assert input_.shape[reduce_dim] % world_size == 0 if world_size <= 1: diff --git a/internlm/core/parallel/shard.py b/internlm/core/parallel/shard.py index 33c187ec5..7cbcdc9e4 100644 --- a/internlm/core/parallel/shard.py +++ b/internlm/core/parallel/shard.py @@ -2,9 +2,11 @@ shard strategies for parallel """ -from typing import Callable +from typing import Callable, List +import numpy as np import torch +import torch.distributed as dist from torch import nn from internlm.core.context import ParallelMode @@ -72,6 +74,271 @@ def partition_uniform(num_items: int, pipeline_parallel_size: int, num_chunks: i return parts +def revise_load_balance_v1(balance_list, min_value): + for i in range(len(balance_list)): + # 如果layer数量不够PP切分,先尝试从他后面一个cluster的PP中借一个layer + count, sentinel = 1, balance_list[i] + while balance_list[i] < min_value: + if count == len(balance_list): + if sentinel == balance_list[i]: + raise RuntimeError(f"Unable to continue splitting, balance_list: {balance_list}") + count, sentinel = 1, balance_list[i] + + next_cluster = (i + count) % len(balance_list) + if balance_list[next_cluster] - 1 >= min_value: + balance_list[next_cluster] -= 1 + balance_list[i] += 1 + count += 1 + + +def find_most_max_deviation(distributed, adjust_value): + """返回和分布最大偏离的idx(注意不一定是绝对值的最大和最小) + + Args: + distributed (_type_): _description_ + adjust_value (_type_): _description_ + + Returns: + _type_: _description_ + """ + # import pdb; pdb.set_trace(); + sums_act = sum(adjust_value) + sums_dist = sum(distributed) + + if abs(sums_dist - 1) >= 1e-6: + distributed = list(map(lambda x: x / sums_dist, distributed)) + if abs(sums_act - 1) >= 1e-6: + adjust_value = list(map(lambda x: x / sums_act, adjust_value)) + + max_positive_diff = 0 + max_negative_diff = 0 + + max_pos_idx = -1 + max_neg_idx = -1 + + for i, zips in enumerate(zip(distributed, adjust_value)): + a, b = zips + diff = abs(a - b) + if b >= a: + if diff > max_positive_diff: + max_positive_diff = diff + max_pos_idx = i + else: + if diff > max_negative_diff: + max_negative_diff = diff + max_neg_idx = i + + return max_pos_idx, max_neg_idx + + +def find_min_max_value(adjust_value): + min_value, max_value = float("inf"), -float("inf") + min_idx, max_idx = -1, -1 + for i in range(len(adjust_value)): + if adjust_value[i] > max_value: + max_value = adjust_value[i] + max_idx = i + if adjust_value[i] < min_value: + min_value = adjust_value[i] + min_idx = i + return min_idx, max_idx + + +def greedy_filling(min_value, total_diff, balance_list, max_pos_idx, max_neg_idx): + if total_diff > 0: # 超出的部分需要减去,比如layer数量 + # 但是每个clusater有自身的下界要求 + if balance_list[max_pos_idx] - 1 >= min_value: + balance_list[max_pos_idx] -= 1 + total_diff -= 1 + return False + else: + _, maxidx = find_min_max_value(balance_list) + if balance_list[maxidx] - 1 < min_value: + raise ValueError(f"Unable to continue splitting, balance_list: {balance_list}, min_value: {min_value}") + balance_list[maxidx] -= 1 + total_diff -= 1 + return False + + if total_diff < 0: # 不足的部分我们直接补齐,但是一般来说我们没有上界的要求 + balance_list[max_neg_idx] += 1 + total_diff += 1 + return False + + return True + + +def PP_mem_balance_filling(min_value, total_diff, balance_list): + # 如果是PP,如果需要 + if total_diff > 0: + for i in range(len(balance_list)): + if total_diff > 0: + if balance_list[i] - 1 < min_value: + raise ValueError( + f"Unable to continue splitting, balance_list: {balance_list}, min_value: {min_value}" + ) + balance_list[i] -= 1 + total_diff -= 1 + else: + return True + + if total_diff < 0: + for i in range(len(balance_list) - 1, 0, -1): + if total_diff < 0: + balance_list[i] += 1 + total_diff += 1 + else: + return True + + return False + + +def revise_load_balance_v2( + base_value: int, min_value: int, distributed: List[float], relax_boundary: bool = False +) -> List[int]: + """_summary_ + + Args: + base_value (int): 基准值,各个cluster的具体值根据该值上下浮动 + min_value (int): 每一项取值的下界 + distributed (List[float]): 负载均衡的分布 + total_sums (List[int]): 原始输入的总和 + relax_boundary (bool, optional): 是否可以放松总和的上界. Defaults to False. + 如果对Layer进行负载均衡则必须为False,如果对micro_num则为True + + Raises: + ValueError: 某些情况下负载均衡是不可行的,比如 micro_num = 1 等情况 + 这个时候我们会放弃负载均衡,沿用用户原始的配置 + + Returns: + List[int]: 负载均衡后的结果 + """ + # 检查每一项目 + all_nums = len(distributed) * base_value + sums_dist = sum(distributed) + distributed = list(map(lambda x: x / sums_dist, distributed)) + balance_list = list(map(lambda ratio: round(all_nums * ratio), distributed)) + + while True: + total_diff = sum(balance_list) - all_nums + + max_pos_idx, max_neg_idx = find_most_max_deviation(distributed, balance_list) + + # 检查总和是否等于初始值(layer数量和global bsz),尝试进行靠拢 + if not greedy_filling(min_value, total_diff, balance_list, max_pos_idx, max_neg_idx): + continue + + # 检查每一项是否满足下界,在尽量不改变sum的情况下继续微调 + if balance_list[max_neg_idx] < min_value: + # 从最大正偏移处借一个值 + if balance_list[max_pos_idx] - 1 >= min_value: + balance_list[max_pos_idx] -= 1 + balance_list[max_neg_idx] += 1 + else: + # 如果不能再借 + if not relax_boundary: + raise ValueError(f"Unable to continue splitting, balance_list: {balance_list}") + else: + balance_list[max_neg_idx] += 1 + relax_boundary = False + else: + break + + return balance_list + + +def weighted_sum(weight, value): + w_sums = sum(weight) + if abs(w_sums - 1) >= 1e-6: + weight = list(map(lambda x: x / w_sums, weight)) + + sums = 0 + for w, v in zip(weight, value): + sums += w * v + return sums + + +def cluster_load_balance(): + + peak_tflops = [] + capacities = [] + gpus_per_cluster = [] + + for cluster_info in gpc.clusters: + peak_tflops.append(cluster_info.peak_tflops) + capacities.append(cluster_info.capacity) + gpus = cluster_info.node_num * cluster_info.gpu_per_node + gpus_per_cluster.append(gpus) + + # capacity_first_cluster = sorted(cluster_list, key=lambda x: x.capacity) + # tflops_first_cluster = sorted(cluster_list, key=lambda x: x.peak_tflops) + + global_bsz = gpc.config.data.global_bsz + micro_bsz = gpc.config.data.micro_bsz + seq_len = gpc.config.data.seq_len + dp_size = gpc.get_world_size(ParallelMode.DATA) + cluster_name = gpc.clusters[gpc.get_cluster_local_rank()].name + rank = gpc.get_global_rank() + + # 根据单卡的峰值tflops来确定micro_num比例 + tflops = [] + for cluster in gpc.clusters: + tflops.append(cluster.peak_tflops) + + # 负载均衡 + if gpc.get_world_size(ParallelMode.PIPELINE) == 1: + # import pdb; pdb.set_trace() + micro_num_all = global_bsz // (micro_bsz * seq_len) + micro_num = micro_num_all // dp_size + + min_value = 1 + base_value = micro_num + total_sums = micro_num_all # TODO: 需不要考虑dp的大小 + relax_boundary = True + + else: + pp_size = gpc.get_world_size(ParallelMode.PIPELINE) + assert len(gpc.clusters) % 2 == 0 + layer_per_cluster = gpc.config.model.layer_num // (pp_size // len(gpc.clusters)) + + min_value = pp_size // len(gpc.clusters) # 每个pp stage至少分到一层 + base_value = layer_per_cluster + total_sums = gpc.config.model.layer_num + relax_boundary = False + + balance_results = revise_load_balance_v2( + base_value=base_value, + min_value=min_value, + distributed=tflops, + relax_boundary=relax_boundary, + ) + + new_sum = sum(balance_results) + old_sum = base_value * len(gpc.clusters) + if new_sum != old_sum: + if relax_boundary: + print(f"Warrning: allow relax constraints, now/old: {new_sum}/{old_sum}") + else: + raise ValueError(f"Unexcepted no relax_boundary but new_sum != base_value: {new_sum}/{old_sum}") + + if gpc.get_world_size(ParallelMode.PIPELINE) == 1: + gpc.config.data.micro_num = balance_results[gpc.get_cluster_local_rank()] + gpc.micro_num_list = np.array(balance_results) + + print( + f"Rank: {rank}, cluster_name: {cluster_name}, balance_results: {balance_results}, \ +balance micro_num: {gpc.config.data.micro_num}" + ) + else: + new_layer_num = balance_results[gpc.get_cluster_local_rank()] + + print( + f"Rank: {rank}, cluster_name: {cluster_name},balance_results: {balance_results}, \ +balance PP layer: {new_layer_num}" + ) + + return balance_results + + def pipeline_parallel_sharding_wrapper( num_layers: int, num_chunks: int, model_builder: Callable, device: torch.device, **kwargs ): diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index 339a404e3..ab20d38db 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -207,6 +207,7 @@ def forward_backward_step( for _current_accum_step in range(self._grad_accum_size): if engine.optimizer is not None: + engine.optimizer.current_accum_step = _current_accum_step if _current_accum_step == self._grad_accum_size - 1: engine.optimizer.skip_grad_reduce = False else: diff --git a/internlm/data/build_dataloader.py b/internlm/data/build_dataloader.py index c2c0ea690..dd6dde8a4 100644 --- a/internlm/data/build_dataloader.py +++ b/internlm/data/build_dataloader.py @@ -59,7 +59,7 @@ def get_tokenized_train_loader_items(data_cfg): folder=data_cfg.train_folder, packed_length=data_cfg.packed_length, max_length_per_sample=data_cfg.seq_len, - show_progress=dist.get_rank() == 0, + show_progress=gpc.get_global_rank() == 0, min_length=data_cfg.get("min_length", 0), min_length_dict=data_cfg.get("min_length_dict", None), pack_sample_into_one=data_cfg.get("pack_sample_into_one", False), diff --git a/internlm/data/tokenized/dummy_dataset.py b/internlm/data/tokenized/dummy_dataset.py index ab5012f48..06d741844 100644 --- a/internlm/data/tokenized/dummy_dataset.py +++ b/internlm/data/tokenized/dummy_dataset.py @@ -15,7 +15,7 @@ class RandomDataset(Dataset): """ - def __init__(self, num_samples=10000, max_len=1024, fixed_seqlen: bool = False) -> None: + def __init__(self, num_samples=500, max_len=1024, fixed_seqlen: bool = False) -> None: super().__init__() rng = np.random.RandomState(1999) max_num = rng.randint(1, 30, size=(num_samples,)) diff --git a/internlm/data/utils.py b/internlm/data/utils.py index 91585a707..59ca20134 100644 --- a/internlm/data/utils.py +++ b/internlm/data/utils.py @@ -9,6 +9,8 @@ from internlm.core.context import global_context as gpc from internlm.core.parallel.comm.utils import _split +fake_mode = "fake_mode" in os.environ + def get_dataset_type_ids_map(path): dirlist = list(os.listdir(path)) @@ -67,7 +69,10 @@ def packed_data_normalizer(data, label): data["indexes"] = data["indexes"][0] data["cu_seqlens"] = data["cu_seqlens"][0].squeeze(0) - data["max_seqlen"] = (data["cu_seqlens"][1:] - data["cu_seqlens"][:-1]).max().item() + if fake_mode: + data["max_seqlen"] = gpc.config.data["seq_len"] + else: + data["max_seqlen"] = (data["cu_seqlens"][1:] - data["cu_seqlens"][:-1]).max().item() # Move to parallel package for standardization if gpc.config.parallel.sequence_parallel and gpc.config.parallel["tensor"].get("mode", "mtp") == "isp": diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index df8356818..9a719cf8f 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -12,7 +12,7 @@ from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import Config from internlm.core.context import global_context as gpc -from internlm.core.context.process_group_initializer import ParallelMode +from internlm.core.context.process_group_initializer_simplified import ParallelMode from internlm.utils.common import get_master_node from internlm.utils.gputest import warmup_process_group from internlm.utils.logger import get_logger @@ -32,7 +32,10 @@ internlm_accelerator = get_accelerator() -def get_default_parser(): +_INTERNEVO_PARSER = None + + +def add_default_arguments(): """Reads user command line and uses an argument parser to parse the input arguments. Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed. @@ -40,7 +43,7 @@ def get_default_parser(): Parser: Returns the parser with the default arguments, the user may add customized arguments into this parser. """ parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str, help="path to the config file") + parser.add_argument("--config", type=str, default="", help="path to the config file") parser.add_argument( "--launcher", type=str, @@ -60,10 +63,44 @@ def get_default_parser(): parser.add_argument( "--disable_volc_topology", default=False, action="store_true", help="disable volc switch topology." ) + parser.add_argument("--fake_mode", default=False, action="store_true", help="Simulation run.") + return parser -def args_sanity_check(): +def add_simulator_arguments(parser): + group = parser.add_argument_group(title="simulator") + group.add_argument("--model_size", type=int, default=7, help="model parameters.") + group.add_argument( + "--draw_heatmap", default=False, action="store_true", help="wheater to draw model communication heatmap." + ) + group.add_argument("--draw_heatmap_path", type=str, default="./comm_matrix") + group.add_argument( + "--run_all_solu", + action="store_true", + default=False, + help="Whether to perform a full solution. \ +If not, it will only calculate the estimated TGS based on the parallel configuration provided in the config.", + ) + group.add_argument("--global_batch_size", type=int, default=4 * 1024**2, help="Global batch size limitation.") + group.add_argument( + "--pre_profiling_data_path", type=str, help="The path to pre-profiled performance data on the target cluster." + ) + group.add_argument("--use_simplified_gp_init", action="store_true", default=False) + return parser + + +def get_default_parser(): + global _INTERNEVO_PARSER + if _INTERNEVO_PARSER is None: + _INTERNEVO_PARSER = add_default_arguments() + + _INTERNEVO_PARSER = add_simulator_arguments(_INTERNEVO_PARSER) + + return _INTERNEVO_PARSER + + +def args_sanity_check(verbose=False): assert gpc.config is not None, "config is not load!" if "JOB_NAME" not in gpc.config: @@ -114,7 +151,7 @@ def args_sanity_check(): assert data.seq_len is not None, "'seq_len' must be given a value" assert data.micro_bsz is not None, "'micro_bsz' must be given a value" - if "packed_length" in data and gpc.is_rank_for_log(): + if "packed_length" in data and gpc.is_rank_for_log() and verbose: logger.warning("packed_length would be ignored and will be setted as seq_len * micro_bsz.") data._add_item("packed_length", data.seq_len * data.micro_bsz) @@ -127,7 +164,7 @@ def args_sanity_check(): if "gradient_accumulation" not in data: data._add_item("gradient_accumulation", data.micro_num) - if gpc.is_rank_for_log(): + if gpc.is_rank_for_log() and verbose: logger.info(f"gradient_accumulation size will be setted to {data.micro_num}.") else: if pp == 1: @@ -167,7 +204,7 @@ def args_sanity_check(): if "fixed_random_dataset_seqlen" not in data: data._add_item("fixed_random_dataset_seqlen", True) - if gpc.is_rank_for_log(): + if gpc.is_rank_for_log() and verbose: logger.info("+" * 15 + " Data Info " + "+" * 15) # pylint: disable=W1201 logger.info(f"seq_len: {data.seq_len}") logger.info(f"micro_num: {data.micro_num}") @@ -197,7 +234,7 @@ def args_sanity_check(): assert "save_ckpt_folder" in ckpt prefix_list = ["boto3:", "volc:", "oss2:"] if not any(ckpt.save_ckpt_folder.startswith(prefix) for prefix in prefix_list): - if gpc.is_rank_for_log(): + if gpc.is_rank_for_log() and verbose: logger.warning( "Storing ckpt on file system does not support asynchronous storage, will use sync save!" ) @@ -230,7 +267,7 @@ def args_sanity_check(): # to auto-load latest checkpoint. ckpt._add_item("auto_resume", True) - if gpc.is_rank_for_log(): + if gpc.is_rank_for_log() and verbose: logger.info("+" * 15 + " Ckpt Info " + "+" * 15) # pylint: disable=W1201 logger.info(f"is enable save ckpt: {ckpt.enable_save_ckpt}") logger.info(f"save_ckpt_folder: {ckpt.save_ckpt_folder}") @@ -248,7 +285,7 @@ def args_sanity_check(): "resume_tb_folder", os.environ["resume_tb_folder"] if "resume_tb_folder" in os.environ else None ) - if gpc.is_rank_for_log(): + if gpc.is_rank_for_log() and verbose: logger.info(f"tensorboard_folder: {gpc.config.tensorboard_folder}") logger.info(f"resume_tb_folder: {gpc.config.resume_tb_folder}") @@ -257,7 +294,7 @@ def args_sanity_check(): torch.backends.cudnn.deterministic = gpc.config.get("cudnn_deterministic", False) clip_grad_norm = gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0) - if gpc.is_rank_for_log(): + if gpc.is_rank_for_log() and verbose: logger.info("+" * 15 + " Other Info " + "+" * 15) # pylint: disable=W1201 logger.info(f"cudnn.benchmark: {torch.backends.cudnn.benchmark }") logger.info(f"cudnn.deterministic: {torch.backends.cudnn.deterministic }") @@ -279,13 +316,14 @@ def args_sanity_check(): torch.backends.cuda.matmul.allow_tf32 = True gpc.config.model.dtype = torch.float32 else: - assert gpc.config.model.dtype in [ - "torch.float16", - "torch.half", - "torch.bfloat16", - "torch.float32", - "torch.tf32", - ] + if isinstance(gpc.config.model.dtype, str): + assert gpc.config.model.dtype in [ + "torch.float16", + "torch.half", + "torch.bfloat16", + "torch.float32", + "torch.tf32", + ] if "checkpoint" in model: if model.checkpoint is True: @@ -297,7 +335,7 @@ def args_sanity_check(): model.checkpoint >= 0 and model.checkpoint <= 1 ), f'model.checkpoint: "{model.checkpoint}" should >=0 and <=1' - if gpc.is_rank_for_log(): + if gpc.is_rank_for_log() and verbose: logger.info("+" * 15 + " Model Info " + "+" * 15) # pylint: disable=W1201 logger.info(f"Model: {gpc.config.model}") @@ -321,7 +359,7 @@ def args_sanity_check(): # Try to change user setting if internlm_accelerator.get_accelerator_backend() is not AcceleratorType.GPU: gpc.config.model.update({"parallel_output": False}) - if old_parallel_output is True and gpc.is_rank_for_log(): + if old_parallel_output is True and gpc.is_rank_for_log() and verbose: logger.warning( "'parallel_output' is converted from 'True' to 'False'." "Because 'parallel_output' only support by FlashCrossEntropyLoss." @@ -426,7 +464,7 @@ def args_sanity_check(): alert = gpc.config.monitor.alert - if alert.enable_feishu_alert and not alert.feishu_alert_address and gpc.is_rank_for_log(): + if alert.enable_feishu_alert and not alert.feishu_alert_address and gpc.is_rank_for_log() and verbose: logger.warning("alert is enable but alert_address is not set") optim_ckpt = gpc.config.hybrid_zero_optimizer @@ -437,7 +475,7 @@ def args_sanity_check(): optim_ckpt._add_item("overlap_sync_grad", False) if "overlap_sync_param" not in optim_ckpt: optim_ckpt._add_item("overlap_sync_param", False) - if gpc.is_rank_for_log(): + if gpc.is_rank_for_log() and verbose: logger.info( f"overlap_sync_grad:{optim_ckpt.overlap_sync_grad}, overlap_sync_param:{optim_ckpt.overlap_sync_param}" ) @@ -469,6 +507,7 @@ def launch( backend: str = "nccl", local_rank: int = None, seed: int = 1024, + fake_mode: bool = False, ): """This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input arguments are not given. Then initialize and set distributed environment by calling global_context's functions. @@ -500,16 +539,19 @@ def launch( gpc.load_config(config) # init default process group - gpc.init_global_dist(rank, world_size, backend, host, port) + gpc.init_global_dist(rank, world_size, backend, host, port, fake_mode=fake_mode) # init process groups for different parallel modes from config - gpc.init_parallel_groups() + gpc.init_parallel_groups(fake_mode) # set cuda device if internlm_accelerator.is_available(): # if local rank is not given, calculate automatically gpc.set_device(local_rank) + if fake_mode: + return + gpc.set_seed(seed) warmup_process_group() diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py index 54cc41ba2..f6ffc4ed9 100644 --- a/internlm/model/metrics.py +++ b/internlm/model/metrics.py @@ -3,6 +3,7 @@ import torch from internlm.accelerator import AcceleratorType, get_accelerator +from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.model.ops.cross_entropy import new_cross_entropy from internlm.utils.common import SchedulerHook, get_current_device @@ -98,7 +99,7 @@ def __init__(self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str self.total_log_probs = torch.Tensor([0]).to(device=device) self.tp_pg = tp_pg self.dp_pg = dp_pg - self.tp_local_rank = torch.distributed.get_rank(self.tp_pg) + self.tp_local_rank = gpc.get_local_rank(ParallelMode.TENSOR) self.tokenizer = tokenizer self.total_bytes = torch.Tensor([0]).to(device=device).view(1) self.batch_shift = 0 diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py index c3b894120..c54b41044 100644 --- a/internlm/model/modeling_internlm2.py +++ b/internlm/model/modeling_internlm2.py @@ -22,6 +22,7 @@ convert_attn_args_to_kwargs, convert_attn_kwargs_to_args, ) +from internlm.simulator.tracker.module_tracker import ModuleTracker from internlm.solver.activation_checkpoint import activation_checkpoint from internlm.utils.logger import get_logger @@ -118,6 +119,10 @@ def __init__( self.use_dynamic_ntk_rope = use_dynamic_ntk_rope head_dim = hidden_size // num_attention_heads + self.module_tracker = ModuleTracker(self._get_name()) + self.register_forward_pre_hook(self.module_tracker.fwd_pre_hook, with_kwargs=True) + self.register_full_backward_hook(self.module_tracker.bwd_pre_hook) + self.attention = GQA( embed_dim=hidden_size, num_heads=num_attention_heads, diff --git a/internlm/model/modules/embedding.py b/internlm/model/modules/embedding.py index fa922daaa..76ba5d556 100644 --- a/internlm/model/modules/embedding.py +++ b/internlm/model/modules/embedding.py @@ -10,6 +10,7 @@ from internlm.core.context import global_context as gpc from internlm.model.ops.rotary_emb import apply_rotary_emb +from internlm.utils.common import get_current_device class Embedding1D(nn.Module): @@ -43,7 +44,9 @@ def __init__( self.embed_kwargs = kwargs embed_dim_per_partition = embedding_dim // gpc.tensor_parallel_size - self.weight = nn.Parameter(torch.empty((num_embeddings, embed_dim_per_partition), dtype=dtype)) + self.weight = nn.Parameter( + torch.empty((num_embeddings, embed_dim_per_partition), dtype=dtype, device=get_current_device()) + ) def forward(self, input_: Tensor) -> Tensor: return F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 443539703..2af623f7e 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -343,6 +343,12 @@ def __init__( else: super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) + from internlm.simulator.tracker.module_tracker import ModuleTracker + + self.module_tracker = ModuleTracker(self._get_name()) + self.register_forward_pre_hook(self.module_tracker.fwd_pre_hook, with_kwargs=True) + self.register_full_backward_hook(self.module_tracker.bwd_pre_hook) + def forward(self, input: torch.Tensor) -> torch.Tensor: # pylint: disable=W0622 _class_name = self.__class__.__name__ assert self._communicator is not None, f"{_class_name} should register with a communicator first." diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py index e06697260..806406590 100644 --- a/internlm/model/modules/mha.py +++ b/internlm/model/modules/mha.py @@ -12,6 +12,7 @@ from internlm.model.modules.linear import new_linear from internlm.model.modules.utils import update_kv_cache from internlm.model.ops.attention import CrossAttention, SelfAttention +from internlm.simulator.tracker.module_tracker import ModuleTracker from internlm.utils.logger import get_logger logger = get_logger(__file__) @@ -91,6 +92,10 @@ def __init__( assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" + self.module_tracker = ModuleTracker(self._get_name()) + self.register_forward_pre_hook(self.module_tracker.fwd_pre_hook, with_kwargs=True) + self.register_full_backward_hook(self.module_tracker.bwd_pre_hook) + if self.rotary_emb_dim > 0: self.rotary_emb = new_rotary_embedding( self.rotary_emb_dim, @@ -422,9 +427,12 @@ def _training(self, x, **kwargs): Arguments: x: (batch, seqlen, hidden_dim) """ + print(f"x shape: {x.shape}", flush=True) + # wqkv if self.enable_qkv_fusion: qkv = self.wqkv(x) + print(f"qkv shape: {qkv.shape}", flush=True) qkv = rearrange(qkv, "b s (h gs d) -> b s h gs d", gs=self.q_per_kv + 2, d=self.head_dim) q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :], qkv[..., -1, :]) q = rearrange(q, "b s h gs d -> b s (h gs) d") diff --git a/internlm/model/modules/norm.py b/internlm/model/modules/norm.py index b94cdd435..d478a4172 100644 --- a/internlm/model/modules/norm.py +++ b/internlm/model/modules/norm.py @@ -2,6 +2,7 @@ layer norm modules """ +import os from typing import List, Union import torch @@ -9,10 +10,16 @@ from internlm.model.ops.norm import RMSNorm +# from internlm.simulator.fake_ops import FakeLayerNorm Shape = Union[int, List[int], torch.Size] +fake_mode = "fake_mode" in os.environ + + def new_layer_norm(norm_type: str, normalized_shape: Shape, eps: float = 1e-5): + # if fake_mode: + # return FakeLayerNorm(normalized_shape, eps) if norm_type == "rmsnorm": return RMSNorm(normalized_shape, eps) else: # default: layernorm diff --git a/internlm/model/modules/utils.py b/internlm/model/modules/utils.py index dd86cb1cb..e1a49bf7f 100644 --- a/internlm/model/modules/utils.py +++ b/internlm/model/modules/utils.py @@ -1,6 +1,8 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import os + import torch import torch.nn.functional as F from einops import rearrange @@ -20,7 +22,9 @@ def Silu(w1_o, w2_o): return F.silu(w1_o) * w2_o -Silu = torch.jit.script(Silu) +fake_mode = "fake_mode" in os.environ +if not fake_mode: + Silu = torch.jit.script(Silu) def update_kv_cache(kv, inference_params, layer_idx): diff --git a/internlm/model/ops/attention.py b/internlm/model/ops/attention.py index 9205652aa..06bf89eb2 100644 --- a/internlm/model/ops/attention.py +++ b/internlm/model/ops/attention.py @@ -7,6 +7,7 @@ """ import math +import os from typing import Callable, Tuple import torch @@ -66,6 +67,11 @@ except (ModuleNotFoundError, ImportError): gpu_flash_attn_impl = False +# from internlm.simulator.ops.attention import FakeFlashAttention + +fake_mode = "fake_mode" in os.environ + + internlm_accelerator = get_accelerator() device_backend = internlm_accelerator.get_accelerator_backend() @@ -127,10 +133,15 @@ def _flash_varlen_kvpacked_attn( # compatible data format: [1, packelen, 3, n_head, headim] q, kv = q.squeeze(dim=0), kv.squeeze(dim=0) + # if fake_mode: + # fa_func = FakeFlashAttention.apply + # else: + fa_func = _flash_varlen_kvpacked_func + # input_idxs: 0: q, 1: kv output = _flash_float32_compatibility_wrapper( (0, 1), - _flash_varlen_kvpacked_func, + fa_func, q, kv, cu_seqlens_q, @@ -167,10 +178,15 @@ def _flash_varlen_qkvsplited_attn( # compatible data format: [1, packelen, 3, n_head, headim] q, k, v = q.squeeze(dim=0), k.squeeze(dim=0), v.squeeze(dim=0) + # if fake_mode: + # fa_func = FakeFlashAttention.apply + # else: + fa_func = _flash_varlen_qkvsplited_func + # input_idxs: 0: q, 1: k, 2: v output = _flash_float32_compatibility_wrapper( (0, 1, 2), - _flash_varlen_qkvsplited_func, + fa_func, q, k, v, diff --git a/internlm/model/ops/cross_entropy.py b/internlm/model/ops/cross_entropy.py index f3fdccf96..153f77f74 100644 --- a/internlm/model/ops/cross_entropy.py +++ b/internlm/model/ops/cross_entropy.py @@ -25,6 +25,9 @@ logger = get_logger(__file__) internlm_accelerator = get_accelerator() +import os + +fake_mode = "fake_mode" in os.environ # TODO: ops是否需要实现更加统一的形式 def new_cross_entropy( @@ -34,6 +37,11 @@ def new_cross_entropy( parallel_output: bool = False, **kwargs, ): + if fake_mode: + kwargs.pop("inplace_backward", None) + return nn.CrossEntropyLoss( + ignore_index=ignore_index, reduction=reduction, label_smoothing=label_smoothing, **kwargs + ) if parallel_output: assert ( gpc.config.model.get("use_flash_attn", False) and flash_cross_entropy_impl diff --git a/internlm/model/ops/linear.py b/internlm/model/ops/linear.py index eeffddc03..da3eda5b4 100644 --- a/internlm/model/ops/linear.py +++ b/internlm/model/ops/linear.py @@ -6,6 +6,7 @@ This file implements support for the linear layer operators. """ +import os from typing import Optional, Tuple import torch @@ -13,6 +14,10 @@ from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import global_context as gpc +from internlm.simulator.ops.linear import ( + _fake_linear_bwdward_op, + _fake_linear_forward_op, +) try: from fused_dense_lib import linear_bias_wgrad as _flash_linear_backward_op @@ -23,6 +28,8 @@ internlm_accelerator = get_accelerator() +fake_mode = "fake_mode" in os.environ + def _select_ops_binding(dtype: torch.dtype, is_cuda: bool = True) -> None: dtype_eligible = dtype in (torch.float16, torch.bfloat16) or ( @@ -32,10 +39,14 @@ def _select_ops_binding(dtype: torch.dtype, is_cuda: bool = True) -> None: is_gpu_backend = internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU flash_attn_eligible = flash_attn_impl and dtype_eligible and is_cuda - if use_flash_attn and is_gpu_backend and flash_attn_eligible: - return _torch_linear_forward_op, _flash_linear_backward_op - else: + if fake_mode: + # return _fake_linear_forward_op, _fake_linear_bwdward_op return _torch_linear_forward_op, _linear_bias_wgrad_torch + else: + if use_flash_attn and is_gpu_backend and flash_attn_eligible: + return _torch_linear_forward_op, _flash_linear_backward_op + else: + return _torch_linear_forward_op, _linear_bias_wgrad_torch def _linear_bias_wgrad_torch(_input: torch.Tensor, grad_output: torch.Tensor, has_d_bias: bool): diff --git a/internlm/model/ops/norm.py b/internlm/model/ops/norm.py index 8ade10caa..2230436f1 100644 --- a/internlm/model/ops/norm.py +++ b/internlm/model/ops/norm.py @@ -1,14 +1,19 @@ # adopted from https://github.com/NVIDIA/apex/blob/master/apex/normalization/fused_layer_norm import numbers +import os import torch from torch.nn import init from torch.nn.parameter import Parameter from internlm.accelerator import AcceleratorType, get_accelerator +from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger +fake_mode = "fake_mode" in os.environ + + logger = get_logger(__file__) internlm_accelerator = get_accelerator() @@ -54,11 +59,11 @@ def __init__(self, normalized_shape, eps=1e-5): normalized_shape = (normalized_shape,) self.normalized_shape = torch.Size(normalized_shape) self.eps = eps - self.weight = Parameter(torch.empty(*normalized_shape)) + self.weight = Parameter(torch.empty(*normalized_shape, device=get_current_device())) self.reset_parameters() def forward(self, _input: torch.Tensor): - if apex_rmsnorm_impl: + if apex_rmsnorm_impl and not fake_mode: _norm_func = mixed_dtype_fused_rms_norm_affine else: _norm_func = manual_rms_norm @@ -73,8 +78,12 @@ def extra_repr(self): # TODO: Support deeplink in a more unified manner -RMSNorm = ( - MixedFusedRMSNorm - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.DIPU and deeplink_rmsnorm_impl - else _RMSNorm -) + +if fake_mode: + RMSNorm = _RMSNorm +else: + RMSNorm = ( + MixedFusedRMSNorm + if internlm_accelerator.get_accelerator_backend() == AcceleratorType.DIPU and deeplink_rmsnorm_impl + else _RMSNorm + ) diff --git a/internlm/model/registry.py b/internlm/model/registry.py index e91a22551..9700c7224 100644 --- a/internlm/model/registry.py +++ b/internlm/model/registry.py @@ -6,7 +6,8 @@ from internlm.model.modeling_internlm import InternLM1 from internlm.model.modeling_internlm2 import InternLM2 from internlm.model.modeling_llama import Llama2 -from internlm.model.modeling_llava import Llava + +# from internlm.model.modeling_llava import Llava from internlm.model.modeling_moe import Internlm1MoE @@ -37,7 +38,7 @@ def register_module(self, module_name: str, func: Callable): AssertionError: Raises an AssertionError if the module has already been registered before. """ - assert module_name not in self._registry, f"{module_name} already registered in {self.name}" + # assert module_name not in self._registry, f"{module_name} already registered in {self.name}" self._registry[module_name] = func @@ -73,6 +74,7 @@ def has(self, module_name: str): model_initializer = Registry("model_initializer") +benchmark_initializer = Registry("benchmark_initializer") def register_model_initializer() -> None: @@ -80,4 +82,4 @@ def register_model_initializer() -> None: model_initializer.register_module("INTERNLM2_PUBLIC", InternLM2) model_initializer.register_module("LLAMA2", Llama2) model_initializer.register_module("INTERNLM_MoE", Internlm1MoE) - model_initializer.register_module("LLAVA", Llava) + # model_initializer.register_module("LLAVA", Llava) diff --git a/internlm/simulator/common.py b/internlm/simulator/common.py index 7c5e73003..eb7f9aee0 100644 --- a/internlm/simulator/common.py +++ b/internlm/simulator/common.py @@ -6,8 +6,7 @@ from torch.distributed import GroupMember -# TODO: 这里需要增加一个broadcast -class CommOp: +class CostType: ALL2ALL = "all2all" ALLREDUCE = "all_reduce" REDUCESCATTER = "reduce_scatter" @@ -32,8 +31,11 @@ class BW: A100_NVL = 250 * 1024**3 # 满速是 300 GB/s -BENCH_TYPE_LIST = [CommOp.ALL2ALL, CommOp.ALLREDUCE, CommOp.REDUCESCATTER, CommOp.ALLGATHER, CommOp.LINEAR] -# BENCH_TYPE_LIST = [CommOp.ALL2ALL, CommOp.ALLREDUCE, CommOp.REDUCESCATTER, CommOp.ALLGATHER, CommOp.LINEAR] +BENCH_TYPE_LIST = [CostType.ALL2ALL, CostType.ALLREDUCE, CostType.REDUCESCATTER, CostType.ALLGATHER, CostType.LINEAR] +# BENCH_TYPE_LIST = [CostType.ALL2ALL, CostType.ALLREDUCE, CostType.REDUCESCATTER, CostType.ALLGATHER, CostType.LINEAR] + + +POSITIVE_INFINITY = 1e12 K = 1024 @@ -47,7 +49,8 @@ class BW: _75GB = 75 * GB _100GB = 100 * GB -GLOBAL_BYTE_SIZES_LIST = [512 * KB, 1 * MB, 4 * MB, 64 * MB, 128 * MB, 256 * MB, 512 * MB, 1 * GB, 2 * GB, 4 * GB] +GLOBAL_BYTE_SIZES_LIST = [1 * KB, 512 * KB, 1 * MB, 4 * MB, 32 * MB, 64 * MB, 128 * MB, 256 * MB, 512 * MB, 1 * GB] +# GLOBAL_BYTE_SIZES_LIST = [64 * MB, 128 * MB] # GLOBAL_BYTE_SIZES_LIST = [512 * KB, 1 * MB, 4 * MB] # , 64 * MB, 128 * MB, 256 * MB] GLOBAL_ELEM_SIZES_LIST = [dsize // 2 for dsize in GLOBAL_BYTE_SIZES_LIST] WORLD_SIZE_LIST = [2, 4, 8, 16, 32, 64, 128] @@ -56,21 +59,29 @@ class BW: OUT_OF_MEM_LATENCY = 10**9 -def cal_block_p_elem(h, multiple_of, mlp_ratio): +def cal_block_p_elem(h, q_head, kv_head, multiple_of, mlp_ratio): norm1_p_elem = h norm2_p_elem = h - MHA = h * 3 * h + + Wq = h * h + Wkv = 2 * h * (h * kv_head // q_head) + out_proj = h * h mlp_hidden_features = multiple_of * ((int(h * mlp_ratio) + multiple_of - 1) // multiple_of) mlp_p_elem = (h * mlp_hidden_features) * 3 dropout1 = 0 dropout2 = 0 - return norm1_p_elem + norm2_p_elem + MHA + out_proj + mlp_p_elem + dropout1 + dropout2 + + # 简化公式: + # 2 * h + 2* h * h + 2 * h * (h * kv_head // q_head) + (h * H_MLP) * 3 + # H * (2 + 2 * H + 2 * H * KV_head // Q_head + H_MLP * 3) + + return norm1_p_elem + norm2_p_elem + Wq + Wkv + out_proj + mlp_p_elem + dropout1 + dropout2 -def cal_model_p_elem(h, l, vocab_size, multiple_of, mlp_ratio): +def cal_model_p_elem(h, q_head, kv_head, l, vocab_size, mlp_ratio, multiple_of=256): embedding_p_elem = vocab_size * h - block_p_elem = l * cal_block_p_elem(h, multiple_of, mlp_ratio) + block_p_elem = l * cal_block_p_elem(h, q_head, kv_head, multiple_of, mlp_ratio) norm_p_elem = h head_p_elem = vocab_size * h return embedding_p_elem + block_p_elem + norm_p_elem + head_p_elem @@ -147,10 +158,14 @@ def get_world_size(): return 1 -def sync_all(): - torch.cuda.synchronize() +def sync_all(group=None): if dist.is_initialized(): - dist.barrier() + dist.barrier(group) + torch.cuda.synchronize() + + +def sync_local(): + torch.cuda.synchronize() def get_bw(comm_op, size, duration, args): diff --git a/internlm/simulator/formulas/__init__.py b/internlm/simulator/formulas/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/internlm/simulator/formulas/algo.py b/internlm/simulator/formulas/algo.py new file mode 100644 index 000000000..06e2b61de --- /dev/null +++ b/internlm/simulator/formulas/algo.py @@ -0,0 +1,386 @@ +from internlm.simulator.formulas.overlap import TransformerOverlap +from internlm.simulator.common import AlgoType, CostType + + +class BaseAlgo: + def __init__( + self, + config: dict, + cost_data: object, + model_config: dict, + X: list = [], + C: list = [], + A: list = [], + num_strategies: int = 0, + ) -> None: + + self._world_size = config["world_size"] + self._global_batch_size = config["global_batch_size"] + self._sequence_length = config["sequence_length"] + self._model_size = config["model_size"] + self._grad_acc = config["grad_acc"] + self._SP = config["SP"] + self._micro_batch_size = config["micro_bs"] + self._vocab_size = config["vocab_size"] + self._dtype_size = 2 # the sizeof(model.dtype) + self._os_size_ratio = 2 if self._dtype_size == 2 else 1 # the sizeof(OS.P) + self._p_size = self._model_size * 10**9 # model size + self._l = model_config["l"] + self._h = model_config["h"] + self._a = model_config["a"] + self._mlp_ratio = model_config["mlp_ratio"] + self._multiple_of = model_config["multiple_of"] + self._cost_data = cost_data + self._num_strategies = num_strategies + + self.overlap_res = TransformerOverlap( + b=self._dtype_size, + s=self._sequence_length, + h=self._h, + # a=self._a, + num_layers=self._l, + dtype_size=self._dtype_size, + mlp_ratio=self._mlp_ratio, + multiple_of=self._multiple_of, + vocab_size=self._vocab_size, + cost_data=self._cost_data, + ) + + # the combination of parallel strategies + # X[i][j]: i->P,G,OS; j->2,4,6,... + self.X = X + # the communication cost + # C[i][j] means the communication cost of stratege X[i][j] + self.C = C + # the memory cost + # A[i][j] means the memory cost of stratege X[i][j] + self.A = A + + def _lookup_comm_cost(self, type: CostType, world_size, complexity): + return self._cost_data[type].predict(world_size, complexity) + + def get_XCA(self): + return self.X, self.C, self.A + + def set_memory_threshold(self): + """set the memory threshold""" + pass + + def get_comm_cost(self): + """get the communication cost""" + pass + + def get_mem_cost(self): + """get the memory cost""" + pass + + +class ISP(BaseAlgo): + def __init__( + self, config: dict, cost_data: object, model_config: dict, X: dict, C: dict, A: dict, num_strategies: int + ) -> None: + super().__init__(config, cost_data, model_config, X, C, A, num_strategies) + self.algo_type = AlgoType.ISP + + def set_memory_threshold(self): + self._activation = ( + self._dtype_size + * self._micro_batch_size + * self._sequence_length + * self._h + * (34 + (5 * self._a * self._sequence_length / self._h)) + / self._SP + ) + self._memory_threshold = 80 * (1024**3) - self._activation + if self._memory_threshold < 0: + print(f"!!!warning!!!: self._memory_threshold: {self._memory_threshold} < 0") + print(f"activation: {self._activation:.4f} GB") + return self._memory_threshold + + def _get_os_comm_cost(self, comm_range): + if comm_range <= 1: + return 0 + comm_cost = self._dtype_size * self._p_size + return self._lookup_comm_cost(CostType.ALLGATHER, comm_range, comm_cost) # TODO: Should be broadcast + + def _comm_cost(self, i: int, j: int): + """ + Get communication cost. + + Args: + i (int): p (i==0), g (i==1), os (i==2) + j (int): node count + + Returns: + float: communication cost + + commu cost = fwd + bwd + optimizer + + fwd = sp + wp + bwd = sp + wp + optimizer = zp + + 其中 wp的通信可以overlap + """ + # self._SP_comm = self._get_sp_comm_cost(self._SP) + + if j != 0: + comm_range = j * 8 + else: + comm_range = 1 # no comm cost + + # 算overlap的通信开销 + overlap_cost = self.overlap_res._get_overlap(comm_range, self._SP, self.algo_type) + + # 算os的通信开销 + if comm_range == 1: + os_comm_cost = 0 + else: + os_comm_cost = self._get_os_comm_cost(comm_range) + + # 总的通信开销 + comm_cost = os_comm_cost + overlap_cost + + return comm_cost + + def get_comm_cost(self): + for i in range(3): + for j in range(self._num_strategies): + # TODO:这里需要支持更多的切分策略 + if j != 1 and j % 2 != 0: # 节点数为奇数的时候 + self.C[i][j] = self.C[i][j - 1] * 1.2 + else: # 节点数为偶数 + self.C[i][j] = self._comm_cost(i, j) + + def _mem_cost(self, i: int, j: int): + if i == 0: + if j == 0: + # 不切P + return self._dtype_size * self._model_size + # 对P切j*8份 + return self._dtype_size * self._model_size / (j * 8) + elif i == 1: + if j == 0: + # 不切G + return self._dtype_size * self._model_size + # 对G切j*8份 + return self._dtype_size * self._model_size / (j * 8) + else: + if j == 0: + # 不切OS + return self._dtype_size * self._os_size_ratio * 3 * self._model_size + # 对OS切j*8份 + return self._dtype_size * self._os_size_ratio * 3 * self._model_size / (j * 8) + + def get_mem_cost(self): + for i in range(3): + for j in range(self._num_strategies): + if j != 1 and j % 2 != 0: + self.A[i][j] = self.A[i][j - 1] * 0.8 + else: + self.A[i][j] = self._mem_cost(i, j) + + +class MSP(BaseAlgo): + def __init__( + self, config: dict, cost_data: object, model_config: dict, X: dict, C: dict, A: dict, num_strategies: int + ) -> None: + super().__init__(config, cost_data, model_config, X, C, A, num_strategies) + self.algo_type = AlgoType.MSP + + def set_memory_threshold(self): + self._activation = ( + self._dtype_size + * self._micro_batch_size + * self._sequence_length + * self._h + * (4 + 30 / self._SP + (5 * self._a * self._sequence_length / self._h / self._SP)) + ) + self._memory_threshold = 80 * (1024**3) - self._activation + if self._memory_threshold < 0: + print(f"!!!warning!!!: self._memory_threshold: {self._memory_threshold} < 0") + print(f"activation: {self._activation:.4f} GB") + return self._memory_threshold + + def _get_os_comm_cost(self, comm_range): + if comm_range <= 1: + return 0 + comm_cost = self._dtype_size * self._p_size + return self._lookup_comm_cost(CostType.ALLGATHER, comm_range, comm_cost) # TODO: Should be broadcast + + def _comm_cost(self, i: int, j: int): + """ + Get communication cost. + + Args: + i (int): p (i==0), g (i==1), os (i==2) + j (int): node count + + Returns: + float: communication cost + + commu cost = fwd + bwd + optimizer + + fwd = sp + wp + bwd = sp + wp + optimizer = zp + + 其中 wp的通信可以overlap + """ + # self._SP_comm = self._get_sp_comm_cost(self._SP) + + if j != 0: + comm_range = j * 8 + else: + comm_range = 1 # no comm cost + + # 算overlap的通信开销 + overlap_cost = self.overlap_res._get_overlap(comm_range, self._SP, self.algo_type) + + # 算os的通信开销 + if comm_range == 1: + os_comm_cost = 0 + else: + os_comm_cost = self._get_os_comm_cost(comm_range) + + # 总的通信开销 + comm_cost = os_comm_cost + overlap_cost + + return comm_cost + + def get_comm_cost(self): + for i in range(3): + for j in range(self._num_strategies): + # TODO:这里需要支持更多的切分策略 + if j != 1 and j % 2 != 0: # 节点数为奇数的时候 + self.C[i][j] = self.C[i][j - 1] * 1.2 + else: # 节点数为偶数 + self.C[i][j] = self._comm_cost(i, j) + + def _mem_cost(self, i: int, j: int): + if i == 0: + if j == 0: + # 不切P + return self._dtype_size * self._model_size + # 对P切j*8份 + return self._dtype_size * self._model_size / (j * 8) + elif i == 1: + if j == 0: + # 不切G + return self._dtype_size * self._model_size + # 对G切j*8份 + return self._dtype_size * self._model_size / (j * 8) + else: + if j == 0: + # 不切OS + return self._dtype_size * self._os_size_ratio * 3 * self._model_size + # 对OS切j*8份 + return self._dtype_size * self._os_size_ratio * 3 * self._model_size / (j * 8) + + def get_mem_cost(self): + for i in range(3): + for j in range(self._num_strategies): + if j != 1 and j % 2 != 0: + self.A[i][j] = self.A[i][j - 1] * 0.8 + else: + self.A[i][j] = self._mem_cost(i, j) + + +class FSP(BaseAlgo): + def __init__( + self, config: dict, cost_data: object, model_config: dict, X: dict, C: dict, A: dict, num_strategies: int + ) -> None: + super().__init__(config, cost_data, model_config, X, C, A, num_strategies) + self.algo_type = AlgoType.FSP + + def set_memory_threshold(self): + self._activation = ( + self._dtype_size + * self._micro_batch_size + * self._sequence_length + * self._h + * (34 + (5 * self._a * self._sequence_length / self._h)) + / self._SP + ) + self._memory_threshold = 80 * (1024**3) - self._activation + if self._memory_threshold < 0: + print(f"!!!warning!!!: self._memory_threshold: {self._memory_threshold} < 0") + print(f"activation: {self._activation:.4f} GB") + return self._memory_threshold + + def _comm_cost(self, i: int, j: int): + """ + Get communication cost. + + Args: + i (int): p (i==0), g (i==1), os (i==2) + j (int): node count + + Returns: + float: communication cost + + commu cost = fwd + bwd + optimizer + + fwd = sp + wp + bwd = sp + wp + optimizer = zp + + 其中 wp的通信可以overlap + """ + # self._SP_comm = self._get_sp_comm_cost(self._SP) + + if j != 0: + comm_range = j * 8 + else: + comm_range = 1 # no comm cost + + # 算overlap的通信开销 + overlap_cost = self.overlap_res._get_overlap(comm_range, self._SP, self.algo_type) + + # 算os的通信开销 + if comm_range == 1: + os_comm_cost = 0 + else: + os_comm_cost = self._get_os_comm_cost(comm_range) + + # 总的通信开销 + comm_cost = os_comm_cost + overlap_cost + + return comm_cost + + def get_comm_cost(self): + for i in range(3): + for j in range(self._num_strategies): + # TODO:这里需要支持更多的切分策略 + if j != 1 and j % 2 != 0: # 节点数为奇数的时候 + self.C[i][j] = self.C[i][j - 1] * 1.2 + else: # 节点数为偶数 + self.C[i][j] = self._comm_cost(i, j) + + def _mem_cost(self, i: int, j: int): + if i == 0: + if j == 0: + # 不切P + return self._dtype_size * self._model_size + # 对P切j*8份 + return self._dtype_size * self._model_size / (j * 8) + elif i == 1: + if j == 0: + # 不切G + return self._dtype_size * self._model_size + # 对G切j*8份 + return self._dtype_size * self._model_size / (j * 8) + else: + if j == 0: + # 不切OS + return self._dtype_size * self._os_size_ratio * 3 * self._model_size + # 对OS切j*8份 + return self._dtype_size * self._os_size_ratio * 3 * self._model_size / (j * 8) + + def get_mem_cost(self): + for i in range(3): + for j in range(self._num_strategies): + if j != 1 and j % 2 != 0: + self.A[i][j] = self.A[i][j - 1] * 0.8 + else: + self.A[i][j] = self._mem_cost(i, j) diff --git a/internlm/simulator/formulas/comm.py b/internlm/simulator/formulas/comm.py new file mode 100644 index 000000000..a5435d3af --- /dev/null +++ b/internlm/simulator/formulas/comm.py @@ -0,0 +1,271 @@ +from internlm.simulator.common import AlgoType, CostType +from internlm.simulator.profiler.perf_comm import ( + allgather, + allreduce, + alltoall, + get_comm_cost, + reducescatter, +) +# from internlm.simulator.utils import CommPredict + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc + + +class TransformerCommunication: + def __init__( + self, + b, + s, + h, + vocab_size, + mlp_ratio, + multiple_of, + dtype_size, + ckpt=0, + wdp_size=1, + ): + self.b = b # Batch size + self.s = s # Sequence length + self.h = h # Hidden size + + self.qkv_communication_latency = 0 + self.post_attention_communication_latency = 0 + self.first_linear_communication_latency = 0 + self.second_linear_communication_latency = 0 + self.attention_all_to_all_communication_latency = 0 + + self.mlp_ratio = mlp_ratio + self.multiple_of = multiple_of + self.dtype_size = dtype_size + self.mlp_hidden_size = self.multiple_of * ( + (int(self.h * self.mlp_ratio) + self.multiple_of - 1) // self.multiple_of + ) + + self.ckpt = ckpt # activation checkpoint + + # self.toal_comm = self.communication_isp() + + def communication_isp(self): + """ + ckpt: means the activation checkpoint, {0 or 1} + + sp communication: + + comm(sp) = comm(forward, sp) + comm(backward, sp) + + comm(forward, sp) = 4 * comm(all2all, s/sp, b, h) * (ckpt + 1) + + comm(backward, sp) = 4 * comm(all2all, s/sp, b, h) + + wp communication: (In our implementation, the wp communication of ckpt==1 is the same as ckpt==0) + + comm(wp) = comm(forwad, wp) + comm(backward, wp) + + comm(forward, wp) = comm(all_gather, (wqkv, wo, mlp)) + + comm(backward, wp) = comm(all_gather, (wqkv, wo, mlp)) + comm(reduceScatter, (wqkv, wo, mlp)) + + wdp communication: (actually wdp communication should be included in the optimizer communication) + """ + + self.wp_scale = gpc.get_world_size(ParallelMode.WEIGHT) + self.sp_scale = gpc.get_world_size(ParallelMode.TENSOR) + + # wp communication + qkv_wp_volume = 3 * self.dtype_size * self.h**2 + wo_wp_volume = self.dtype_size * self.h**2 + mlp_w1_volume = self.dtype_size * self.h * self.mlp_hidden_size + + qkv_latency = 2 * allgather(qkv_wp_volume, ParallelMode.WEIGHT, comm_nums=2) + reducescatter( + qkv_wp_volume, ParallelMode.WEIGHT + ) + wo_latency = 2 * allgather(wo_wp_volume, ParallelMode.WEIGHT, comm_nums=2) + reducescatter(wo_wp_volume, ParallelMode.WEIGHT) + mlp_w1_latency = 2 * allgather(mlp_w1_volume, ParallelMode.WEIGHT, comm_nums=2) + reducescatter( + mlp_w1_volume, ParallelMode.WEIGHT + ) + mlp_w2_latency = mlp_w1_latency + mlp_w3_latency = mlp_w1_latency + + # sp communication + all2all_volume = self.s / self.sp_scale * self.b * self.h * self.dtype_size + + sp_all2all_comm_time = 4 * (self.ckpt + 1) + 4 + all2all_latency = alltoall(all2all_volume, ParallelMode.TENSOR, comm_nums=sp_all2all_comm_time) + + wp_comm_latency = qkv_latency + wo_latency + mlp_w1_latency + mlp_w2_latency + mlp_w3_latency + sp_comm_latency = sp_all2all_comm_time * all2all_latency # forward + backward + + # wdp communication + # wdp_volume = self.model_para / gpc.get_world_size(ParallelMode.WEIGHT_DATA) # TODO: 这个通信量是否合理? + # wdp_latency = allreduce(wdp_volume, ParallelMode.WEIGHT_DATA) + + return wp_comm_latency, sp_comm_latency + + def communication_msp(self): + """ + ckpt: means the activation checkpoint, {0 or 1} + + sp communication: + + comm(sp) = comm(forward, sp) + comm(backward, sp) + + comm(forward, sp) = (2 * comm(all_gather, s, b, h) + 2 * comm(reduceScatter, s, b, h)) * (ckpt + 1) + + comm(backward, sp) = 2 * comm(reduceScatter, s, b, h) + 2 * comm(all_gather, s, b, h) + + wp communication: + + comm(wp) = comm(forwad, wp) + comm(backward, wp) + + comm(forward, wp) = comm(all_gather, (wqkv, wo, mlp)) + + comm(backward, wp) = comm(all_gather, (wqkv, wo, mlp)) + comm(reduceScatter, (wqkv, wo, mlp)) + + wdp communication: (actually wdp communication should be included in the optimizer communication) + """ + self.wp_scale = gpc.get_world_size(ParallelMode.WEIGHT) + self.sp_scale = gpc.get_world_size(ParallelMode.TENSOR) + + # compute sp communication + # all_gather and reduceScatter have the same commu volume + # the communication volume in backward is equal to the forward + qkv_sp_volume = self.s * self.b * self.h * self.dtype_size # the forward all-gather + wo_sp_volume = self.s * self.b * self.h * self.dtype_size # the forward reduceScatter + mlp_w1_sp_volume = qkv_sp_volume # the forward all-gather + mlp_w2_sp_volume = self.s * self.b * self.h * self.dtype_size # the forward reduceScatter + + # compute the sp latency (forward + backward) + sp_forward = ( + allgather(qkv_sp_volume, ParallelMode.TENSOR) + + reducescatter(wo_sp_volume, ParallelMode.TENSOR) + + allgather(mlp_w1_sp_volume, ParallelMode.TENSOR) + + reducescatter(mlp_w2_sp_volume, ParallelMode.TENSOR) + ) + + sp_backward = ( + reducescatter(qkv_sp_volume, ParallelMode.TENSOR) + + allgather(wo_sp_volume, ParallelMode.TENSOR) + + reducescatter(mlp_w1_sp_volume, ParallelMode.TENSOR) + + allgather(mlp_w2_sp_volume, ParallelMode.TENSOR) + ) + + sp_forward = sp_forward * (self.ckpt + 1) + + sp_comm_latency = sp_forward + sp_backward + + # commpute wp communication + qkv_wp_volume = 3 * self.h * self.h / self.sp_scale * self.dtype_size + wo_wp_volume = self.h * self.h / self.sp_scale * self.dtype_size + + # w2 and w3 have the same volume as w1 + mlp_w1_wp_volume = self.h * self.mlp_hidden_size / self.sp_scale * self.dtype_size + + qkv_wp_latency = 2 * allgather(qkv_wp_volume, ParallelMode.WEIGHT, comm_nums=2) + reducescatter( + qkv_wp_volume, ParallelMode.WEIGHT + ) + wo_wp_latency = 2 * allgather(wo_wp_volume, ParallelMode.WEIGHT, comm_nums=2) + reducescatter( + wo_wp_volume, ParallelMode.WEIGHT + ) + mlp_w1_wp_latency = 2 * allgather(mlp_w1_wp_volume, ParallelMode.WEIGHT, comm_nums=2) + reducescatter( + mlp_w1_wp_volume, ParallelMode.WEIGHT + ) + mlp_w2_wp_latency = mlp_w1_wp_latency + + wp_comm_latency = qkv_wp_latency + wo_wp_latency + mlp_w1_wp_latency + mlp_w2_wp_latency + + # wdp communication + # wdp_volume = self.model_para // self.sp_scale // self.wp_scale # TODO: 这个通信量是否合理? + # wdp_latency = allreduce(wdp_volume, ParallelMode.WEIGHT_DATA) + + return wp_comm_latency, sp_comm_latency + + def communication_fsp(self): + """ + ckpt: means the activation checkpoint, {0 or 1} + + sp communication: + + comm(sp) = comm(forward, sp) + comm(backward, sp) + + comm(forward, sp) = (2 * comm(all_gather, s, b, h) + 2 * comm(reduceScatter, s, b, h)) * (ckpt + 1) + + comm(backward, sp) = 2 * comm(reduceScatter, s, b, h) + 4 * comm(all_gather, s, b, h) + + wp communication: + + comm(wp) = comm(forwad, wp) + comm(backward, wp) + + comm(forward, wp) = comm(all_gather, (wqkv, wo, mlp)) + + comm(backward, wp) = comm(all_gather, (wqkv, wo, mlp)) + comm(reduceScatter, (wqkv, wo, mlp)) + + wdp communication: (actually wdp communication should be included in the optimizer communication) + """ + + self.wp_scale = gpc.get_world_size(ParallelMode.WEIGHT) + self.sp_scale = gpc.get_world_size(ParallelMode.TENSOR) + + # compute sp communication + # all_gather and reduceScatter have the same commu volume + # the communication volume in backward is equal to the forward + qkv_sp_volume = self.s * self.b * self.h * self.dtype_size # the forward all-gather + wo_sp_volume = self.s * self.b * self.h * self.dtype_size # the forward reduceScatter + mlp_w1_sp_volume = qkv_sp_volume # the forward all-gather + mlp_w2_sp_volume = self.s * self.b * self.h * self.dtype_size # the forward reduceScatter + + # compute the sp latency (forward + backward) + sp_forward = ( + allgather(qkv_sp_volume, ParallelMode.TENSOR) + + reducescatter(wo_sp_volume, ParallelMode.TENSOR) + + allgather(mlp_w1_sp_volume, ParallelMode.TENSOR) + + reducescatter(mlp_w2_sp_volume, ParallelMode.TENSOR) + ) + + sp_backward = ( + allgather(qkv_sp_volume, ParallelMode.TENSOR) + + reducescatter(qkv_sp_volume, ParallelMode.TENSOR) + + allgather(wo_sp_volume, ParallelMode.TENSOR) + + allgather(mlp_w1_sp_volume, ParallelMode.TENSOR) + + reducescatter(mlp_w1_sp_volume, ParallelMode.TENSOR) + + allgather(mlp_w2_sp_volume, ParallelMode.TENSOR) + ) + + sp_forward = sp_forward * (self.ckpt + 1) + + sp_comm_latency = sp_forward + sp_backward + + # commpute wp communication + qkv_wp_volume = 3 * self.h * self.h / self.sp_scale * self.dtype_size + wo_wp_volume = self.h * self.h / self.sp_scale * self.dtype_size + + # w2 and w3 have the same volume as w1 + mlp_w1_wp_volume = self.h * self.mlp_hidden_size / self.sp_scale * self.dtype_size + + qkv_wp_latency = 2 * allgather(qkv_wp_volume, ParallelMode.WEIGHT, comm_nums=2) + reducescatter( + qkv_wp_volume, ParallelMode.WEIGHT + ) + wo_wp_latency = 2 * allgather(wo_wp_volume, ParallelMode.WEIGHT, comm_nums=2) + reducescatter( + wo_wp_volume, ParallelMode.WEIGHT + ) + mlp_w1_wp_latency = 2 * allgather(mlp_w1_wp_volume, ParallelMode.WEIGHT, comm_nums=2) + reducescatter( + mlp_w1_wp_volume, ParallelMode.WEIGHT + ) + mlp_w2_wp_latency = mlp_w1_wp_latency + + wp_comm_latency = qkv_wp_latency + wo_wp_latency + mlp_w1_wp_latency + mlp_w2_wp_latency + + # wdp communication + # wdp_volume = self.model_para // self.sp_scale // self.wp_scale # TODO: 这个通信量是否合理? + # wdp_latency = allreduce(wdp_volume, ParallelMode.WEIGHT_DATA) + + return wp_comm_latency, sp_comm_latency + + def communication(self, algo_type): + if algo_type == AlgoType.ISP: + return self.communication_isp() + elif algo_type == AlgoType.MSP: + return self.communication_msp() + elif algo_type == AlgoType.FSP: + return self.communication_fsp() + raise ValueError(f"Unkown algo_type: {algo_type}") diff --git a/internlm/simulator/formulas/comp.py b/internlm/simulator/formulas/comp.py new file mode 100644 index 000000000..04f26a226 --- /dev/null +++ b/internlm/simulator/formulas/comp.py @@ -0,0 +1,200 @@ +from internlm.simulator.common import AlgoType, CostType +from internlm.simulator.profiler.perf_comm import get_cal_cost, get_fa_cost + + +def get_linear_cost(complexity): + return get_cal_cost(CostType.LINEAR, complexity) # 转换成ms小数点保留两位 + + +def get_atten_cost_polynomial(complexity): + return get_cal_cost(CostType.LINEAR, complexity) + + +def get_atten_cost_predict(micro_bsz, seq_len, hidden_dim, q_head, kv_head, is_fwd): + """_summary_ + + Args: + micro_bsz (int): b + seq_len (int): seqlen, 注意这里是完整的seqlen + hidden_dim (int): 原始的head_dim + num_heads (int): 原始的num_heads + sp_tp (int): sp for isp, tp for msp/fsp + + Returns: + int: latency of fa, unit is second. + """ + predict = get_fa_cost( + micro_bsz=micro_bsz, + seqlen=seq_len, + hidden_size=hidden_dim, + q_head=q_head, + kv_head=kv_head, + dtype=2, + is_fwd=is_fwd, + ) + return predict + + +class TransformerComputation: + def __init__( + self, + a, + a_kv, + b, + s, + h, + vocab_size, + sp_scale, + dtype_size, + mlp_ratio, + multiple_of, + use_fa=True, + cost_data=None, + ckpt=0, + ): + self.a = a + self.a_kv = a_kv + self.b = b # Batch size + self.s = s # Sequence length + self.h = h # Hidden size + self.sp_scale = sp_scale + self.qkv_computation = 0 + self.qkt_computation = 0 + # self.score_v_computation = 0 + # self.post_attention_linear = 0 + # self.first_linear = 0 + # self.second_linear = 0 + # self.logits_computation = 0 + # self.attention_computation = 0 + # self.flash_attention_computation = 0 + # self.mlp_computation = 0 + self.vocab_size = vocab_size + self.dtype_size = dtype_size + self.mlp_ratio = mlp_ratio + self.multiple_of = multiple_of + self.mlp_hidden_size = self.multiple_of * ( + (int(self.h * self.mlp_ratio) + self.multiple_of - 1) // self.multiple_of + ) + self.ckpt = ckpt + self.use_fa = use_fa + + def _compute_embedding(self, scale): + """ + the head computation is the same as embedding computation. + msp and fsp share the same computation. + + scale: the scale factor. when the algo is isp, the scale is one; else the scale is self.sp_scale. + """ + volumn = self.dtype_size * self.b * self.s * self.vocab_size * self.h / scale + try: + latency = get_linear_cost(volumn) + except Exception: + #import pdb; pdb.set_trace() + return get_linear_cost(9895604649984) + return latency + + def _compute_linears(self): + """ + compute the latency for linears in one transformer layer, such as wqkv, wo, mlp + """ + + # wqkv + # ISP: (b, s/sp, h) * (h, 3h) + # MSP or FSP: (b, s, h) * (h, 3h/sp) + qkv_volumn = 3 * self.dtype_size * self.b * self.s * self.h * self.h / self.sp_scale + qkv_latency = get_linear_cost(qkv_volumn) + + # wo + # ISP: (b, s/sp, h) * (h, h) + # MSP or FSP: (b, s, h/sp) * (h/sp, h) + wo_volumn = self.dtype_size * self.b * self.s * self.h * self.h / self.sp_scale + wo_latency = get_linear_cost(wo_volumn) + + # mlp w1 + # ISP: (b, s/sp, h) * (h, mlp_h) + # MSP or FSP: (b, s, h) * (h, mlp_h/sp) + w1_volumn = self.dtype_size * self.b * self.s * self.h * self.mlp_hidden_size / self.sp_scale + w1_latency = get_linear_cost(w1_volumn) + + # mlp w2 + # ISP: (b, s/sp, h) * (h, mlp_h) + # MSP or FSP: (b, s, h/sp) * (h/sp, mlp_h) + w2_volumn = self.dtype_size * self.b * self.s * self.h * self.mlp_hidden_size / self.sp_scale + w2_latency = get_linear_cost(w2_volumn) + + # mlp w3 + # ISP: (b, s/sp, mlp_h) * (mlp_h, h) + # MSP or FSP: (b, s, mlp_h/sp) * (mlp_h/sp, h) + w3_volumn = self.dtype_size * self.b * self.s * self.h * self.mlp_hidden_size / self.sp_scale + w3_latency = get_linear_cost(w3_volumn) + + total_latency = qkv_latency + wo_latency + w1_latency + w2_latency + w3_latency + + return total_latency + + def _compute_attn(self, is_fwd): + """ + compute the latency for attention in one transformer layer + """ + if self.use_fa: + # 由于我们目前不支持搜索ring attn,所以这里我们sp/tp只切head数量 + a = self.a // self.sp_scale + a_kv = self.a_kv // self.sp_scale + + total_latency = get_atten_cost_predict(self.b, self.s, self.h, a, a_kv, is_fwd) + else: + # QK^T matrix multiplication + # (b, s, h/sp) * (b, s, h/sp)^T + qkt_volume = self.dtype_size * self.b * self.s * self.s * self.h / self.sp_scale + qkt_latency = get_atten_cost_polynomial(qkt_volume) + + # Score dot V + # (b, s, s) * (b, s, h/sp) + score_v_volume = self.dtype_size * self.b * self.s * self.s * self.h / self.sp_scale + score_v_latency = get_atten_cost_polynomial(score_v_volume) + + total_latency = qkt_latency + score_v_latency + + return total_latency + + def _computation(self, embedding_scale): + # TODO: the following computation exclude norm computation + """ + ckpt: activation checkpoint {0 or 1} + + the computation latency for each transformer layer + + compu(msp) = compu(forward) + compu(backward) + + compu(backward) = 2 * compu(forward) + + compu(forward) = (compu(linear, (wqkv, wo, mlp)) + compu(attn)) * (ckpt + 1) + """ + + # compute the latency for embedding and head + embedding_latency = self._compute_embedding(embedding_scale) + head_latency = embedding_latency + + # compute the latency for linears + linears_latency = self._compute_linears() * (self.ckpt + 1) + self._compute_linears() * 2 + + # compute the latency for attention + attn_latency = self._compute_attn(is_fwd=True) * (self.ckpt + 1) + self._compute_attn(is_fwd=False) + + # the computation for each transformer layer + # transformer_latency = linears_latency + attn_latency + + return linears_latency, attn_latency + + def total_computation(self, algo_type): + if algo_type == AlgoType.ISP: + # return self.total_computation_isp() + return self._computation(1.0) + else: + return self._computation(self.sp_scale) + + +# Example usage +# Assuming values for b (batch size), s (sequence length), h (hidden size), num_layers, and vocab_size +# b, s, h, num_layers, vocab_size = 1, 16384, 4096, 32, 10000 +# transformer_comp = TransformerComputation(b, s, h,num_layers,vocab_size) diff --git a/internlm/simulator/formulas/mem.py b/internlm/simulator/formulas/mem.py new file mode 100644 index 000000000..bc7859a9d --- /dev/null +++ b/internlm/simulator/formulas/mem.py @@ -0,0 +1,227 @@ +from internlm.simulator.common import AlgoType + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc + + +# 所有公式计算都变成无状态的,计算结果完全由外部传入的参数决定,内部不进行诸如切pp这样的操作 +def get_isp_memory_threshold( + dtype_size: int, + micro_batch_size: int, + sequence_length: int, + hidden_dim: int, + use_fa: int, + head_num: int, + layer_num: int, + activation_ckpt: int, + sp_size: int, +): + """ + Args: + dtype_size (int): bf16=2, fp32=4 + micro_batch_size (int): + sequence_length (int): + hidden_dim (int): + use_fa (int): 0 or 1 + head_num (int): + layer_num (int): + activation_ckpt (int): 0 or 1 + + Returns: + int: activation memory usage. + """ + # TODO: ht mark pp情况下,rank0的激活值会累积pp个micro_bsz,所以这里是不是还得再乘一个pp_size? + # TODO: wgt mark 应该不需要,每个pp拿到L/pp个layer,又最多保存pp个micro_num的激活, + # rank0相当于还是L份layer的激活 + + if activation_ckpt: + layer_num = 0 + + """ (0) dropout input: 2bsh + (1) attention: 2bsh (qkv input) + 3*2bsh(attn input) + 2bsh(attn_out_padded) + 2bsh(out project input)-> 12bsh + (2) dropout input: 2bsh + (3) MLP: 2 * (1 + 8/3 + 8/3 + 8/3 +8/3)* bsh = 70 * bsh / 3 + w1_o = self.w1(x) # 8/3 + w2_o = self.w2(x) # 8/3 + w3_in = Silu(w1_o, w2_o) # 8/3 + 8/3 + out = self.w3(w3_in) # 8/3 + total: 16bsh + 70 * bsh / 3 = 118 * bsh /3 + """ + activation = ( + dtype_size + * micro_batch_size + * sequence_length + * hidden_dim + * (118 / 3 + (1 - use_fa) * (5 * head_num * sequence_length / hidden_dim)) + / sp_size + ) * layer_num + return activation + + +def get_msp_memory_threshold( + dtype_size: int, + micro_batch_size: int, + sequence_length: int, + hidden_dim: int, + use_fa: int, + head_num: int, + layer_num: int, + activation_ckpt: int, + sp_size: int, +): + if activation_ckpt: + layer_num = 0 + + activation = ( + dtype_size + * micro_batch_size + * sequence_length + * hidden_dim + * ( + 12 / 3 + ((118 - 12) / 3) / sp_size + (1 - use_fa) * (5 * head_num * sequence_length / hidden_dim / sp_size) + ) # TODO: check + ) * layer_num + return activation + + +def get_fsp_memory_threshold( + dtype_size: int, + micro_batch_size: int, + sequence_length: int, + hidden_dim: int, + use_fa: int, + head_num: int, + layer_num: int, + activation_ckpt: int, + sp_size: int, +): + if activation_ckpt: + layer_num = 0 + + activation = ( + dtype_size + * micro_batch_size + * sequence_length + * hidden_dim + * (118 / 3 + (1 - use_fa) * (5 * head_num * sequence_length / hidden_dim)) + / sp_size + ) * layer_num # 显存阈值根据pp0来计算,需要micro_num >= pp,stage_0需要保存 pp 份才成立 + return activation + + +# tp=1,sp=1 +# seql_len=512, hidden_dim 4096, no tp,sp +# embed shape: torch.Size([1, 4096, 512]) 1 +# block shape: torch.Size([4096, 512]) +# head shape: torch.Size([4096, 103168]) + +# tp=4,sp=1 +# seql_len=512, hidden_dim 4096 +# embed shape: torch.Size([1, 4096, 512]) +# block shape: torch.Size([4096, 512]) +# head shape: torch.Size([4096, 25792]) + +# tp=4,sp=4 +# embed shape: torch.Size([1, 1024, 512]) +# block shape: torch.Size([1024, 512]) +# head shape: torch.Size([4096, 25792]) + +# WP不省激活,因此不受wp影响 +# 这里只计算一层的激活,不受pp影响 + + +# embedding output +def get_embedding_output_mm(micro_bsz, seq_len, hidden_dim, sp, algo, dtype_size): + # [b, hidden_dim, seql_len] + # sp的world_size是从tp的pg中获得的 + sp_worldsize = gpc.get_world_size(ParallelMode.TENSOR) + # assert sp == sp_worldsize, f"sp={sp}, sp_world_size:{sp_worldsize}" + assert sp_worldsize == sp, f"sp={sp}, sp_world_size:{sp_worldsize}, algo: {algo}" + return dtype_size * micro_bsz * seq_len * hidden_dim // sp + + +# block output +def get_block_output_mm(micro_bsz, seq_len, hidden_dim, sp, dtype_size): + # [hidden_dim, packed_length] + sp_worldsize = gpc.get_world_size(ParallelMode.TENSOR) + assert sp == sp_worldsize, f"sp={sp}, sp_world_size:{sp_worldsize}" + return dtype_size * micro_bsz * seq_len * hidden_dim // sp + + +# norm output +def get_norm_output_mm(micro_bsz, seq_len, hidden_dim, sp, dtype_size): + # [hidden_dim, packed_length] + sp_worldsize = gpc.get_world_size(ParallelMode.TENSOR) + assert sp == sp_worldsize, f"sp={sp}, sp_world_size:{sp_worldsize}" + return 4 * micro_bsz * seq_len * hidden_dim // sp # norm的输出是fp32的 + + +# head output +def get_head_output_mm(micro_bsz, seq_len, vocab_size, dtype_size): + # [seq_len, vocab_size] + return micro_bsz * dtype_size * seq_len * vocab_size // gpc.get_world_size(ParallelMode.TENSOR) + + +# head input +def get_head_input_mm(micro_bsz, seq_len, hidden_dim, dtype_size, tp_size, algo): + # [seq_len, vocab_size] + if algo in [AlgoType.ISP, AlgoType.FSP]: + return micro_bsz * dtype_size * seq_len * hidden_dim // tp_size + else: + return 0 + + +# rotary embedding sin/cos cache +def get_rotary_emb_sincos_cache_mm(seq_len, pp_size, hidden_dim, head_nums, layer_nums, dtype_size): + # [sin,cos] * dtype_size * pp切后的layer_nums * 不切的seq_len * head_dim // 2 + return 2 * dtype_size * (layer_nums // pp_size) * seq_len * (hidden_dim // head_nums) // 2 + + +def get_backward_mem_peak(seq_len, micro_bsz, dtype_size, vocab_size, tp_size, hidden_size): + # 这个函数是峰值位置 + head_input_grad = 2 * dtype_size * seq_len * micro_bsz * hidden_size # 512 MB (1份激活1份激活的梯度) + reduce_scatter_grad = head_input_grad / tp_size # 512 MB / 8 + head_weight_grad = dtype_size * hidden_size * vocab_size / tp_size # 100.b MB + return head_input_grad + reduce_scatter_grad + head_weight_grad + + +def get_memory_pool_mm(mlp_ratio, hidden_size, dtype_size): + mlp_hidden_size = int(hidden_size * mlp_ratio) + mlp_hidden_size = 256 * ((mlp_hidden_size + 256 - 1) // 256) + module_Wqkv = 3 * hidden_size * hidden_size * dtype_size + module_out_proj = hidden_size * hidden_size * dtype_size + module_w1 = mlp_hidden_size * hidden_size * dtype_size + module_w2 = mlp_hidden_size * hidden_size * dtype_size + module_w3 = mlp_hidden_size * hidden_size * dtype_size + prefetch_two_layers_weight = 2 * (module_Wqkv + module_out_proj + module_w1 + module_w2 + module_w3) + + return prefetch_two_layers_weight * 2 # all_gather + reduce_scatter approximately + + +def get_p2p_buffer_size(dtype_size, seq_len, sp_size, micro_bsz, hidden_dim): + return dtype_size * (seq_len // sp_size) * micro_bsz * hidden_dim + + +def get_block_threshold( + algo: AlgoType, + **kwargs, +): + """get_block_threshold 获得一层激活的显存占用 + 注意: + (1) seqlen一定是没有被sp切过的 + (2) 公式是基于fp16计算的, 所以传入的 dtype_size 要除以2 + Args: + dtype_size (int): 数据元素大小, 单位B + seq_len (int): 没有被切过的seq_len + + Returns: + float : 一个layer的显存占用, 单位B + """ + if algo == AlgoType.ISP: + return get_isp_memory_threshold(**kwargs) + elif algo == AlgoType.MSP: + return get_msp_memory_threshold(**kwargs) + elif algo == AlgoType.FSP: + return get_fsp_memory_threshold(**kwargs) + + assert ValueError(f"unknow algo: {algo}") diff --git a/internlm/simulator/formulas/overlap.py b/internlm/simulator/formulas/overlap.py new file mode 100644 index 000000000..dbd77c63a --- /dev/null +++ b/internlm/simulator/formulas/overlap.py @@ -0,0 +1,68 @@ +from internlm.simulator.formulas.comm import TransformerCommunication +from internlm.simulator.formulas.comp import TransformerComputation +from internlm.simulator.common import get_model_config + + +# 1. dtype 加入复杂度 +# 2. comm 没有乘以 laynum +# 3. atten 计算还没加 +# 4. mmeory check +# 5. 集成simulator +class TransformerOverlapOneLayer: + def __init__( + self, + micro_bsz, + seq_len, + vocab_size, + dtype_size, + sp_size, + pp_size, + world_size, + ckpt, + hidden_dim, + num_head, + num_kv_head, + mlp_ratio, + multiple_of, + ): + self.b = micro_bsz # Batch size + self.s = seq_len # Sequence length + self.vocab_size = vocab_size + self.sp_scale = sp_size + self.dtype_size = dtype_size + self.world_size = world_size + self.pp_size = pp_size + + self.h, self.a, self.a_kv, self.mlp_ratio, self.multiple_of = hidden_dim, num_head, num_kv_head, mlp_ratio, multiple_of + + self.ckpt = ckpt # the activation checkpoint + + def _get_overlap(self, algo_type): + # 一个transformer layer的通信时延 (forward + backward) + comm_wp, comm_sp = TransformerCommunication( + self.b, + self.s, + self.h, + self.vocab_size, + dtype_size=self.dtype_size, + mlp_ratio=self.mlp_ratio, + multiple_of=self.multiple_of, + ckpt=self.ckpt, + ).communication(algo_type) + + # 一个transformer layer的计算时延 (forward + backward) + comp_wp, comp_attn = TransformerComputation( + self.a, + self.a_kv, + self.b, + self.s, + self.h, + self.vocab_size, + dtype_size=self.dtype_size, + mlp_ratio=self.mlp_ratio, + multiple_of=self.multiple_of, + sp_scale=self.sp_scale, + ckpt=self.ckpt, + ).total_computation(algo_type) + + return comm_wp, comm_sp, comp_wp, comp_attn diff --git a/internlm/simulator/predict_cost_model.py b/internlm/simulator/predict_cost_model.py deleted file mode 100644 index bee64c773..000000000 --- a/internlm/simulator/predict_cost_model.py +++ /dev/null @@ -1,325 +0,0 @@ -import functools -import os -import pickle -from collections import OrderedDict -from copy import deepcopy - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -from scipy.interpolate import interp1d -from sklearn.linear_model import LinearRegression -from sklearn.metrics import r2_score -from sklearn.preprocessing import PolynomialFeatures - -from internlm.core.context import Config -from internlm.simulator.common import MB, OUT_OF_MEM_LATENCY, WORLD_SIZE_LIST, CommOp - -# import profiler.benchmark -# import scipy.interpolate -from internlm.simulator.profiler.benchmark.multi_head_attn import UnitMultiHeadAttn -from internlm.simulator.profiler.profiler import run_profile - - -class PolynomialModel: - def __init__(self, degree, data, name="unknown", segments=None) -> None: - """_summary_ - - Args: - degree (int): _description_ - data (dict): _description_ - segments (dict): _description_ - """ - self.name = name - self.degree = 3 # 多项式的度数 - self.poly_features = PolynomialFeatures(degree=degree, include_bias=False) # 准备多项式回归模型 - self.data = pd.DataFrame(data) # 转换为DataFrame - if segments is None: - segments = {"all": (0, float("inf"))} - print(segments, flush=True) - self.segments = OrderedDict(segments) - self.segment_scores = {seg: {} for seg in self.segments} # 用于存储拟合结果和评分 - self.model_fit = { - seg: {card: None for card in self.data["World_Size"].unique()} for seg in self.segments - } # 存储模型 - self.see_base_value() - self.build_model() - - def see_base_value(self): - # 可视化数据 - plt.figure(figsize=(12, 6)) - for card in self.data["World_Size"].unique(): - subset = self.data[self.data["World_Size"] == card] - plt.scatter(subset["Data_B"], subset["Latency_s"], label=f"{card} cards") - - plt.xlabel("Data Transferred (MB)") - plt.ylabel("Latency (ms)") - plt.title("Transferred Latency vs Data Transferred for Different Card Numbers") - plt.legend() - plt.xscale("log") - plt.grid(True) - plt.savefig(f"{self.name}.jpg") - plt.show() - print(self.data.head()) - - def build_model(self): - # 对每个分段和卡数的数据进行拟合 - plt.figure(figsize=(12, 6)) - for seg, (low, high) in self.segments.items(): - for card in self.data["World_Size"].unique(): - subset = self.data[ - (self.data["World_Size"] == card) & (self.data["Data_B"] >= low) & (self.data["Data_B"] < high) - ] - - # 如果该段中没有足够的数据点,则跳过 - if len(subset) < 2: - continue - - # 准备数据 - X = subset["Data_B"].values.reshape(-1, 1) - y = subset["Latency_s"].values - X_poly = self.poly_features.fit_transform(X) - - # 拟合模型 - model = LinearRegression() - model.fit(X_poly, y) - y_pred = model.predict(X_poly) - self.model_fit[seg][card] = model - - # 评估模型 - score = r2_score(y, y_pred) - self.segment_scores[seg][card] = score - - # 可视化拟合结果 - plt.scatter(X / MB, y, label=f"{card} cards") - plt.plot(X / MB, y_pred, label=f"{card} cards Fit") - - # 绘制图表 - plt.xlabel("Data Transferred (MB)") - plt.ylabel("Latency (ms)") - plt.title("Segmented Polynomial Regression Fit for Different Card Numbers") - plt.xscale("log") - plt.yscale("log") - plt.legend() - plt.grid(True) - plt.savefig(f"{self.name}_fit.jpg") - plt.show() - - def return_segments(self, x): - for key, value in self.segments.items(): - low, hight = value[0], value[1] - if x >= low and x < hight: - return key - assert ValueError, f"predict value:{x} out of range" - - def predict(self, world_size, complexity): - try: - model = self.model_fit[self.return_segments(complexity)][world_size] - X_pred = self.poly_features.fit_transform([[complexity]]) - Y_pred = model.predict(X_pred)[0] - return Y_pred - except Exception as e: - print(f"e: {e}", flush=True) - import pdb - - pdb.set_trace() - - -class SplineModel: - def __init__(self): - self._data_prefix = "data/cost_data" - self.spline_model_list = {} - self.data = {} - self.load_data() - self.build_model() - - def load_data(self): - for cost_data_file in os.listdir(self._data_prefix): - name, suffix = cost_data_file.split(".") - if suffix == "pickle": - with open(f"{self._data_prefix}/{cost_data_file}", "rb") as f: - self.data[name] = pickle.load(f) - - @staticmethod - def reformat_data_to_cost_model(total_results): - reformat_data = dict() - for world_size in total_results.keys(): - list_data = [] - for complexity in total_results[world_size].keys(): - for value in total_results[world_size][complexity]: - list_data.append([value["lat"], complexity]) # p data[2][524288][0]['lat'] - - # list_data.sort(key=functools.cmp_to_key(my_compare)) - data_list = list(map(list, zip(*list_data))) - reformat_data[world_size] = {"Data_B": data_list[1], "Latency_s": data_list[0]} - - return reformat_data - - def build_model(self): - # p data[2][524288][0]['lat'] - for cost_type, cost_data in self.data.items(): - if cost_type != CommOp.FLASH_ATTN: - try: - cost_data = SplineModel.reformat_data_to_cost_model(cost_data) - except TypeError as e: - print(f"e : {e}", flush=True) - import pdb - - pdb.set_trace() - - self.spline_model_list[cost_type] = {} - for world_size, data in cost_data.items(): - try: - x = data["Data_B"] - y = data["Latency_s"] - except KeyError as e: - print(f"e : {e}", flush=True) - import pdb - - pdb.set_trace() - self.spline_model_list[cost_type][world_size] = interp1d(x, y, kind="slinear") - # self.see_base_value(cost_type, world_size, x, y) - else: # fa我们直接查表,不预测 - self.spline_model_list[cost_type] = {} - self.spline_model_list[cost_type][1] = cost_data[1] - - def predict(self, cost_type, world_size, complexity): - return self.spline_model_list[cost_type][world_size](complexity) - - def predict_cost(self, cost_type: CommOp, complexity=0, world_size=1, **kwargs): - """predict computation cost - The cost of attention will use KV mapping, and the cost of linear will - use PolynomialModel. - - Args: - cost_type (CommOp): _description_ - complexity (int, optional): _description_. Defaults to 0. - - Returns: - float: op latency. - """ - if cost_type == CommOp.FLASH_ATTN: - try: - key = UnitMultiHeadAttn.gen_store_key(**kwargs) - return self.spline_model_list[cost_type][1][key][0]["lat"] - except KeyError as e: - raise KeyError(f"not found FA key: {key}") - else: - try: - if cost_type != CommOp.LINEAR and world_size == 1: - return 0 - else: - spline_model = self.spline_model_list[cost_type][world_size] - predict = spline_model(complexity) - except ValueError: - below_bounds, above_bounds = spline_model.x[0], spline_model.x[-1] - if complexity < below_bounds: - return spline_model(below_bounds) # 如果超过下界就返回下界 - if complexity > above_bounds: - lat = spline_model(above_bounds) - return lat * complexity / above_bounds # 如果超过上界就线性扩展 - raise ValueError(f"value error for cost_type:{cost_type}, complexity:{complexity}") - except KeyError as e: - print(f"e : {e}", flush=True) - import pdb - - pdb.set_trace() - else: - return predict - - -def my_compare(a, b): - world_size_a, complexity_a = a[0], a[2] - world_size_b, complexity_b = b[0], b[2] - # print(world_size_a, world_size_b, complexity_a, complexity_b) - - if world_size_a > world_size_b: - return True - elif world_size_a < world_size_b: - return False - else: - if complexity_a > complexity_b: - return True - elif complexity_a < complexity_b: - return False - else: - assert ValueError, f"a:{a}, b:{b}" - - -class GenCostModel: - def __init__(self, is_master=True, build_type_list=None) -> None: - self._master = is_master - self._profile_args = Config( - { - "trials": 10, - "warmups": 1, - } - ) - self.cost_data = None - self._data_prefix = "data/cost_data" - self.cost_kv_data = {} - self.build_type_list = build_type_list - - def _log(self, msg: str): - if self._master: - print(msg, flush=True) - - def build_cost_model_by_key_value(self): - if self.cost_data is None: - self.cost_data = OrderedDict() - for bench_type in self.build_type_list: - self._log(f"now test {bench_type}") - self.cost_kv_data[bench_type] = run_profile(self._profile_args, bench_type) - - def load_cost_model_by_key_value(self): - self.cost_data = OrderedDict() - for bench_type in self.build_type_list: - self._log(f"now load {bench_type}") - with open(f"./data/{bench_type}.pickle", "rb") as f: - self.cost_kv_data[bench_type] = pickle.load(f) - - def draw_pic(self, data, cost_type): - plt.figure(figsize=(12, 6)) - world_sizes = list(data.index) - for vol in list(data.columns): - plt.plot(world_sizes, data[vol].values, label=f"{vol/1024**2:.2f} MB") - - plt.xlabel("GPU nums") - plt.ylabel("Latency (s)") - plt.title(f"{cost_type}") - # plt.xscale("log") - # plt.yscale("log") - plt.legend() - plt.grid(True) - plt.savefig(f"./data/pics/{cost_type}.jpg") - plt.show() - - def dump_data(self): - # p data[2][524288][0]['lat'] - for bench_type, results in self.cost_kv_data.items(): - indexs, columns = [], None - tables = [] - if bench_type != CommOp.FLASH_ATTN: - for world_size, values in results.items(): - indexs.append(world_size) - one_col = [] - tmp_columns = [] - for vol, latency in values.items(): - tmp_columns.append(vol) - one_col.append(latency[0]["lat"]) - if columns is None: - columns = deepcopy(tmp_columns) - tables.append(one_col) - - # print(f"bench_type: {bench_type}", flush=True) - # print(f"index: {indexs}", flush=True) - # print(f"columns: {columns}", flush=True) - - df = pd.DataFrame(tables, columns=columns, index=indexs) - df.to_csv(f"./data/excel/{bench_type}.csv", index=False) - - if bench_type != CommOp.LINEAR: - self.draw_pic(df, bench_type) - - with open(f"{self._data_prefix}/{bench_type}.pickle", "wb") as f: - pickle.dump(results, f) diff --git a/internlm/simulator/profiler/benchmark/__init__.py b/internlm/simulator/profiler/benchmark/__init__.py index 4f2f0256c..0c9f1c06d 100644 --- a/internlm/simulator/profiler/benchmark/__init__.py +++ b/internlm/simulator/profiler/benchmark/__init__.py @@ -1,7 +1,27 @@ -from .all2all import * -from .all_gather import * -from .all_reduce import * -from .linear import * -from .multi_head_attn import * -from .reduce_scatter import * -from .broadcast import * +from internlm.model.registry import Registry, benchmark_initializer +from internlm.simulator.profiler.benchmark import ( + all2all, + all_gather, + all_reduce, + broadcast, + linear, + reduce_scatter, +) + +# from .all_gather import * +# from .all_reduce import * +# from .broadcast import * +# from .linear import * +# from .multi_head_attn import * +# from .reduce_scatter import * + + +def register_comm_pref_initializer() -> None: + benchmark_initializer.register_module(all2all.BENCH_TYPE, all2all.UnitBenchAll2ALL) + benchmark_initializer.register_module(all_gather.BENCH_TYPE, all_gather.UnitBenchAllGather) + benchmark_initializer.register_module(all_reduce.BENCH_TYPE, all_reduce.UnitBenchAllReduce) + benchmark_initializer.register_module(broadcast.BENCH_TYPE, broadcast.UnitBenchBroadcast) + benchmark_initializer.register_module(reduce_scatter.BENCH_TYPE, reduce_scatter.UnitBenchAllReduceScatter) + benchmark_initializer.register_module(linear.BENCH_TYPE, linear.UnitBenchLinear) + + # model_initializer.register_module("LLAVA", Llava) diff --git a/internlm/simulator/profiler/benchmark/all2all.py b/internlm/simulator/profiler/benchmark/all2all.py index b56e40a2a..c070586c5 100644 --- a/internlm/simulator/profiler/benchmark/all2all.py +++ b/internlm/simulator/profiler/benchmark/all2all.py @@ -3,6 +3,7 @@ from internlm.model.registry import benchmark_initializer from internlm.simulator.common import * +from internlm.utils.common import get_current_device from .base_benchmark import UnitBench @@ -13,125 +14,38 @@ class UnitBenchAll2ALL(UnitBench): test_loop = { "global_size": GLOBAL_ELEM_SIZES_LIST, - "world_size": WORLD_SIZE_LIST, # 7B, (13B, 20B), 30B, 65B, 123B "async_op": [False], # it is not work!! False, "dtype": [torch.bfloat16], } - def __init__(self, world_size, async_op, dtype, global_size=None, unit_size=None) -> None: + def __init__(self, async_op, dtype, group, global_size=None, unit_size=None) -> None: assert global_size is None or unit_size is None + world_size = dist.get_world_size(group) + assert world_size > 0, f"group: {group}" self.unit_size = unit_size if unit_size else global_size // world_size # elements_per_gpu self.world_size = world_size self.dtype = dtype self.async_op = async_op - self.group = sub_process_groups[str(world_size)] - self.do_it = dist.get_rank() in set(dist.get_process_group_ranks(self.group)) + self.group = group - if dist.get_world_size() < world_size: - self.input = None - self.output = None - else: - self.output = torch.ones(self.world_size, self.unit_size, dtype=self.dtype).to(f"cuda:{get_local_rank()}") - self.input = torch.ones(self.world_size, self.unit_size, dtype=self.dtype).to(f"cuda:{get_local_rank()}") + device = get_current_device() + + if self.group is not None: + self.output = torch.ones(self.world_size, self.unit_size, dtype=self.dtype, device=device) + self.input = torch.ones(self.world_size, self.unit_size, dtype=self.dtype, device=device) self.input_buffer_size = self.input.element_size() * self.input.numel() def run(self): - if self.output is None or not self.do_it: + if self.group is None: return handler = dist.all_to_all_single(self.output, self.input, async_op=self.async_op, group=self.group) if self.async_op: handler.wait() - def complexity(self): + def bw_complexity(self): return self.input_buffer_size - -if __name__ == "__main__": - # data = { - # "Latency_ms": [41.746, 62.982, 65.596, 101.968, 138.671, 159.773, 177.197, 190.415, 193.555, 194.056, 194.097, - # 193.776, 193.419, 193.679, 194.425, 194.462, 36.732, 55.592, 80.364, 100.85, 116.875, 133.242, - # 160.23, 178.519, 189.055, 193.55, 193.752, 193.717, 193.417, 193.686, 194.365, 194.416, 33.096, - # 48.456, 72.221, 97.357, 113.762, 125.266, 134.315, 164.453, 178.744, 187.352, 192.915, 193.512, - # 192.669, 193.47, 194.342, 194.218], - # "Cards": [64] * 16 + [128] * 16 + [256] * 16, - # "Data_MB": [512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, - # 8388608, 16777216] * 3 - # } - cards_8_lat = [ - 0.035442, - 0.038785, - 0.041076, - 0.063415, - 0.092584, - 0.151337, - 0.259346, - 0.482307, - 0.896747, - 1.737, - 3.255, - 6.431, - ] - cards_16_lat = [ - 0.086889, - 0.113204, - 0.177494, - 0.271461, - 0.45525, - 0.84743, - 1.641, - 3.103, - 6.125, - 12.177, - 24.724, - 49.03, - ] - cards_32_lat = [ - 0.102149, - 0.14717, - 0.230115, - 0.382689, - 0.681639, - 1.432, - 2.499, - 4.812, - 9.554, - 18.706, - 37.845, - 73.225, - ] - cards_64_lat = [ - 0.115658, - 0.16165, - 0.259298, - 0.43826, - 0.822096, - 1.591, - 2.967, - 5.703, - 11.148, - 22.108, - 41.188, - 98.423, - ] - assert len(cards_8_lat) == len(cards_16_lat) == len(cards_32_lat) == len(cards_64_lat) - samples = len(cards_8_lat) - data = { - "Latency_ms": cards_8_lat + cards_16_lat + cards_32_lat + cards_64_lat, - "Cards": [8] * samples + [16] * samples + [32] * samples + [64] * samples, - "Data_MB": [i * MB for i in [0.5, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]] * 4, - } - segments = { - "small": (64 * KB, 8 * MB), # 64KB - 8MB, degree =2 - "large": (8 * MB, 1 * GB), # 8MB - 1GB, degree=1 - } - - segments = { - "all": (64 * KB, 1 * GB), - } - - model = PolynomialModel(degree=2, data=data, segments=segments) - model.predict(35 * MB) - model.predict(1.2 * MB) - model.predict(678 * MB) + def algo_complexity(self): + return self.input_buffer_size diff --git a/internlm/simulator/profiler/benchmark/all_gather.py b/internlm/simulator/profiler/benchmark/all_gather.py index c677f69b4..60eb94ff1 100644 --- a/internlm/simulator/profiler/benchmark/all_gather.py +++ b/internlm/simulator/profiler/benchmark/all_gather.py @@ -3,6 +3,7 @@ from internlm.model.registry import benchmark_initializer from internlm.simulator.common import * +from internlm.utils.common import get_current_device from .base_benchmark import UnitBench @@ -13,36 +14,37 @@ class UnitBenchAllGather(UnitBench): test_loop = { "global_size": GLOBAL_ELEM_SIZES_LIST, - "world_size": WORLD_SIZE_LIST, # 7B, (13B, 20B), 30B, 65B, 123B "async_op": [False], # it is not work!! False, "dtype": [torch.bfloat16], } - def __init__(self, world_size, async_op, dtype, global_size=None, unit_size=None) -> None: + def __init__(self, async_op, dtype, group, global_size=None, unit_size=None) -> None: assert global_size is None or unit_size is None + world_size = dist.get_world_size(group) self.unit_size = unit_size if unit_size else global_size // world_size # elements_per_gpu self.world_size = world_size self.dtype = dtype self.async_op = async_op - self.group = sub_process_groups[str(world_size)] - self.do_it = dist.get_rank() in set(dist.get_process_group_ranks(self.group)) - - if dist.get_world_size() < world_size: - self.input = None - self.output = None - else: - self.output = torch.ones(self.world_size, self.unit_size, dtype=self.dtype).to(f"cuda:{get_local_rank()}") - self.input = torch.ones(self.unit_size, dtype=self.dtype).to(f"cuda:{get_local_rank()}") + self.group = group + + device = get_current_device() + + if self.group is not None: + self.output = torch.ones(self.world_size, self.unit_size, dtype=self.dtype, device=device) + self.input = torch.ones(self.unit_size, dtype=self.dtype, device=device) self.output_buffer_size = self.output.element_size() * self.output.numel() def run(self): - if self.output is None or not self.do_it: + if self.group is None: return handler = dist._all_gather_base(self.output, self.input, async_op=self.async_op, group=self.group) if self.async_op: handler.wait() - def complexity(self): + def bw_complexity(self): + return self.output_buffer_size + + def algo_complexity(self): return self.output_buffer_size diff --git a/internlm/simulator/profiler/benchmark/all_reduce.py b/internlm/simulator/profiler/benchmark/all_reduce.py index 972249b33..c4c48b264 100644 --- a/internlm/simulator/profiler/benchmark/all_reduce.py +++ b/internlm/simulator/profiler/benchmark/all_reduce.py @@ -3,6 +3,7 @@ from internlm.model.registry import benchmark_initializer from internlm.simulator.common import * +from internlm.utils.common import get_current_device from .base_benchmark import UnitBench @@ -13,34 +14,35 @@ class UnitBenchAllReduce(UnitBench): test_loop = { "global_size": GLOBAL_ELEM_SIZES_LIST, - "world_size": WORLD_SIZE_LIST, # 7B, (13B, 20B), 30B, 65B, 123B "async_op": [False], # it is not work!! False, "dtype": [torch.bfloat16], } - def __init__(self, world_size, async_op, dtype, global_size=None, unit_size=None) -> None: + def __init__(self, async_op, dtype, group, global_size=None, unit_size=None) -> None: assert global_size is None or unit_size is None + world_size = dist.get_world_size(group) self.unit_size = global_size // world_size self.world_size = world_size self.dtype = dtype self.async_op = async_op - self.group = sub_process_groups[str(world_size)] - self.do_it = dist.get_rank() in set(dist.get_process_group_ranks(self.group)) + self.group = group + device = get_current_device() - if dist.get_world_size() < world_size: - self.buffer = None - else: - self.buffer = torch.ones(self.world_size, self.unit_size, dtype=self.dtype).to(f"cuda:{get_local_rank()}") + if self.group is not None: + self.buffer = torch.ones(self.world_size, self.unit_size, dtype=self.dtype, device=device) self.input_buffer_size = self.buffer.element_size() * self.buffer.numel() def run(self): - if self.buffer is None or not self.do_it: + if self.group is None: return handler = dist.all_reduce(self.buffer, async_op=self.async_op, group=self.group) if self.async_op: handler.wait() - def complexity(self): + def bw_complexity(self): + return 2 * self.input_buffer_size + + def algo_complexity(self): return self.input_buffer_size diff --git a/internlm/simulator/profiler/benchmark/broadcast.py b/internlm/simulator/profiler/benchmark/broadcast.py index 464fcde1e..46ad4c0fa 100644 --- a/internlm/simulator/profiler/benchmark/broadcast.py +++ b/internlm/simulator/profiler/benchmark/broadcast.py @@ -3,6 +3,7 @@ from internlm.model.registry import benchmark_initializer from internlm.simulator.common import * +from internlm.utils.common import get_current_device from .base_benchmark import UnitBench @@ -13,34 +14,37 @@ class UnitBenchBroadcast(UnitBench): test_loop = { "global_size": GLOBAL_ELEM_SIZES_LIST, - "world_size": WORLD_SIZE_LIST, # 7B, (13B, 20B), 30B, 65B, 123B "async_op": [False], # it is not work!! False, "dtype": [torch.bfloat16], } - def __init__(self, world_size, async_op, dtype, global_size=None, unit_size=None) -> None: + def __init__(self, async_op, dtype, group, global_size=None, unit_size=None) -> None: assert global_size is None or unit_size is None + world_size = dist.get_world_size(group) self.unit_size = unit_size if unit_size else global_size // world_size # elements_per_gpu self.world_size = world_size self.dtype = dtype self.async_op = async_op - self.group = sub_process_groups[str(world_size)] - self.do_it = dist.get_rank() in set(dist.get_process_group_ranks(self.group)) + self.group = group + device = get_current_device() - if dist.get_world_size() < world_size: - self.output = None - else: - self.output = torch.ones(self.world_size, self.unit_size, dtype=self.dtype).to(f"cuda:{get_local_rank()}") + if self.group is not None: + self.output = torch.ones(self.world_size, self.unit_size, dtype=self.dtype, device=device) self.input_buffer_size = self.output.element_size() * self.output.numel() def run(self): - if self.output is None or not self.do_it: + if self.group is None: return - handler = dist.broadcast(self.output, src=0, async_op=self.async_op, group=self.group) + # src是整个group中rank最小的 + src = min(dist.get_process_group_ranks(self.group)) + handler = dist.broadcast(self.output, src=src, async_op=self.async_op, group=self.group) if self.async_op: handler.wait() - def complexity(self): + def bw_complexity(self): + return self.input_buffer_size + + def algo_complexity(self): return self.input_buffer_size diff --git a/internlm/simulator/profiler/benchmark/linear.py b/internlm/simulator/profiler/benchmark/linear.py index 4814b0f30..ac522e3dc 100644 --- a/internlm/simulator/profiler/benchmark/linear.py +++ b/internlm/simulator/profiler/benchmark/linear.py @@ -1,6 +1,8 @@ import torch + from internlm.model.registry import benchmark_initializer from internlm.simulator.common import * +from internlm.utils.common import get_current_device from .base_benchmark import UnitBench @@ -26,15 +28,12 @@ class UnitBenchLinear(UnitBench): ], # 7B, (13B, 20B), 30B, 65B, 123B "bias": [False], # it is not work!! False, "dtype": [torch.bfloat16], - "world_size": [1], } def __init__(self, seq_len, hidden_dim, bias, dtype) -> None: self.seq_len = seq_len self.hidden_dim = hidden_dim - self.q = torch.nn.Linear( - hidden_dim, hidden_dim, bias=bias, device=f"cuda:{get_local_rank()}", dtype=dtype - ) # (hidden_dim, hidden_dim) + self.q = torch.nn.Linear(hidden_dim, hidden_dim, bias=bias, device=get_current_device(), dtype=dtype) self.dtype = self.q.weight.element_size() self.x = torch.rand(1, seq_len, hidden_dim).to(self.q.weight) # (bsz, seq_len, hidden_dim) diff --git a/internlm/simulator/profiler/benchmark/multi_head_attn.py b/internlm/simulator/profiler/benchmark/multi_head_attn.py index 8ae3aa80c..f4cf5d737 100644 --- a/internlm/simulator/profiler/benchmark/multi_head_attn.py +++ b/internlm/simulator/profiler/benchmark/multi_head_attn.py @@ -1,134 +1,209 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- - -import math - import torch -from einops import rearrange -from torch import nn - -from internlm.model.registry import benchmark_initializer -from internlm.simulator.common import TP_SIZE_RANGE, K, get_local_rank - -try: - from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func - from flash_attn.modules.mha import FlashSelfAttention, SelfAttention -except ModuleNotFoundError: - print("import fa failed!", flush=True) - try: - from deeplink_ext.internevo_ops import ( - FlashCrossAttention, - FlashSelfAttention, - ) - except ModuleNotFoundError: - flash_attn_qkvpacked_func = None - FlashSelfAttention = None - SelfAttention = None - print("import dipu fa failed!", flush=True) +from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func - -from .base_benchmark import UnitBench +from internlm.simulator.profiler.profiler import Timer +from internlm.utils.common import get_current_device BENCH_TYPE = "flash_attn" -# @benchmark_initializer.register_module(module_name=BENCH_TYPE) -class UnitMultiHeadAttn(UnitBench): - test_loop = { - "seq_len": [64 * K, int(0.25 * K), int(0.5 * K), 1 * K, 2 * K, 4 * K, 8 * K, 32 * K, 16 * K], # 256 * K, 128 * K, - "num_heads_and_hidden_dim": [(64, 8192), (48, 6144), (32, 4096), (40, 5120)], # (80, 10240), - "dtype": [torch.bfloat16], - "micro_bsz": [ 2, 1], # 4, - "tp_size": TP_SIZE_RANGE, - "is_fwd": [True, False], - } - - def __init__(self, seq_len, num_heads_and_hidden_dim, dtype, micro_bsz, tp_size, is_fwd) -> None: - num_heads, embed_dim = num_heads_and_hidden_dim - self.num_heads_and_hidden_dim = num_heads_and_hidden_dim - self.TP = tp_size - self.S = seq_len - self.N = num_heads - self.H = embed_dim // self.N - self.dtype = dtype - self.dtype_size = 2 if self.dtype == torch.bfloat16 else 4 - self.B = micro_bsz - self.oom = False - self.is_fwd = is_fwd - self.causal = True - - assert num_heads % self.TP == 0, "num_heads must be divisible by tp_size" - assert num_heads >= tp_size, f"head nums must bigger then tp_size: {tp_size}" - - self.num_atten_head_tp = num_heads // self.TP - self.head_dim = self.H // num_heads - self.tp_embedding_dim = self.H // self.TP - - self.packed_length = self.S * self.B - self.device = f"cuda:{get_local_rank()}" - cu_seqlens = [i * self.S for i in range(self.B + 1)] - - weights_mem_used = self.packed_length * 3 * self.H * self.dtype_size - attn_activation = 11 * self.packed_length * self.H - mem_used = attn_activation + weights_mem_used - - self.inner_attn = FlashSelfAttention(causal=True, softmax_scale=self.H ** (0.5), attention_dropout=0.0) - - oom = False - if mem_used > 75 * 1024**3: - oom = True - - # 约束1: seqlen最大不能超过256K(不含) - # 约束2: embed_dim在被tp切过之后若大于6144, 则packed_length不能大于256k - if self.packed_length >= 256 * K and (self.H / self.TP) >= 6144: - oom = True - if self.S >= 256 * K and self.B > 1: - oom = True - if self.packed_length >= 524288 and (self.H / self.TP) >= 3072: - oom = True - if self.packed_length >= 1048576 and (self.H / self.TP) >= 2048: - oom = True - - if oom: - assert ( - False - ), f"warning : mem_used: {mem_used/1024**3:.2f} GB, seq_len: {self.S}, embed_dim: {self.H}, tp_size: {self.TP}" - - self.qkv = torch.rand( - size=(self.B * self.S, 3, self.N // self.TP, self.H), - dtype=self.dtype, - device=self.device, +def run_fa_lat_test(micro_bsz, seqlen, hidden_size, q_head, kv_head, dtype, warmups=2, trials=5): + # 1, S, N, H + def run(): + device = get_current_device() + cu_seqlens = torch.tensor([i * seqlen for i in range(micro_bsz + 1)], dtype=torch.int32, device=device) + + tfwd, tbwd = Timer(True), Timer(True) + q = torch.rand( + [micro_bsz * seqlen, q_head, hidden_size // q_head], + dtype=dtype, + device=device, + requires_grad=True, + ) + kv = torch.rand( + [micro_bsz * seqlen, 2, kv_head, hidden_size // q_head], + dtype=dtype, + device=device, requires_grad=True, ) - self.dtype_size = self.qkv.element_size() - self.cu_seqlens = torch.tensor(data=cu_seqlens, dtype=torch.int32, device=self.device) - self.max_seqlen= self.S - if not self.is_fwd: - self.output = self.run_fwd() - self.grad = torch.randn_like(self.output) / 32 # avoid grad is too large. - - def run(self): - if self.is_fwd: - self.run_fwd() - else: - self.run_bwd(self.output, self.grad) - - def run_fwd(self): - context = self.inner_attn(self.qkv, cu_seqlens=self.cu_seqlens, max_seqlen=self.max_seqlen, causal=self.causal) - return context - - def run_bwd(self, output, grad): - output.backward(grad, retain_graph=True) - - @staticmethod - def gen_store_key(micro_bsz, seq_len, num_heads_and_hidden_dim, tp_size, is_fwd): - _, embed_dim = num_heads_and_hidden_dim - tp_embedding_dim = embed_dim // tp_size - return f"b_{micro_bsz}_s_{seq_len}_h_{tp_embedding_dim}_fwd_{is_fwd}" - - def complexity(self): - return UnitMultiHeadAttn.gen_store_key( - self.B, self.S, self.num_heads_and_hidden_dim, self.TP, self.is_fwd + torch.cuda.synchronize() + tfwd.start() + context = flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens_k=cu_seqlens, + cu_seqlens_q=cu_seqlens, + max_seqlen_k=seqlen, + max_seqlen_q=seqlen, + causal=True, ) - # return f"{self.S} * {self.hidden_dim} * {self.hidden_dim}" + t_fwd = tfwd.end() + + grad = torch.randn_like(context) / 32 # avoid grad is too large. + torch.cuda.synchronize() + + tbwd.start() + context.backward(grad, retain_graph=True) + t_bwd = tbwd.end() + return t_fwd, t_bwd + + for i in range(warmups): + run() + + t_fwds, t_bwds = 0, 0 + for i in range(trials): + t_fwd, t_bwd = run() + t_fwds += t_fwd + t_bwds += t_bwd + + return t_fwds / trials, t_bwds / trials + + +# from .base_benchmark import UnitBench +# import math + +# import torch +# from einops import rearrange +# from torch import nn + +# from internlm.model.registry import benchmark_initializer +# from internlm.simulator.common import TP_SIZE_RANGE, K, get_local_rank +# from internlm.utils.common import get_current_device + +# try: +# from flash_attn.flash_attn_interface import ( +# flash_attn_qkvpacked_func, +# flash_attn_varlen_func, +# ) +# from flash_attn.modules.mha import FlashSelfAttention, SelfAttention +# except ModuleNotFoundError: +# print("import fa failed!", flush=True) +# try: +# from deeplink_ext.internevo_ops import FlashCrossAttention, FlashSelfAttention +# except ModuleNotFoundError: +# flash_attn_qkvpacked_func = None +# FlashSelfAttention = None +# SelfAttention = None +# print("import dipu fa failed!", flush=True) + + +# @benchmark_initializer.register_module(module_name=BENCH_TYPE) + +# 对于FA,我们还是用on the fly的方式 profiling,并用cache缓存中间结果 +# class UnitMultiHeadAttn(UnitBench): +# # test_loop = { +# # "seq_len": [ +# # 64 * K, +# # int(0.25 * K), +# # int(0.5 * K), +# # 1 * K, +# # 2 * K, +# # 4 * K, +# # 8 * K, +# # 32 * K, +# # 16 * K, +# # ], # 256 * K, 128 * K, +# # "head_H": [(64, 8192), (48, 6144), (32, 4096), (40, 5120)], # (80, 10240), +# # "dtype": [torch.bfloat16], +# # "micro_bsz": [2, 1], # 4, +# # "tp_size": TP_SIZE_RANGE, +# # "is_fwd": [True, False], +# # } + +# def __init__(self, seq_len, num_heads_and_hidden_dim, dtype, micro_bsz, tp_size, is_fwd) -> None: +# q_head, kv_head, embed_dim = num_heads_and_hidden_dim +# self.num_heads_and_hidden_dim = num_heads_and_hidden_dim +# self.TP = tp_size +# self.S = seq_len +# self.N = num_heads +# self.H = embed_dim // self.N +# self.dtype = dtype +# self.dtype_size = 2 if self.dtype == torch.bfloat16 else 4 +# self.B = micro_bsz +# self.oom = False +# self.is_fwd = is_fwd +# self.causal = True + +# assert num_heads % self.TP == 0, "num_heads must be divisible by tp_size" +# assert num_heads >= tp_size, f"head nums must bigger then tp_size: {tp_size}" + +# self.num_atten_head_tp = num_heads // self.TP +# self.head_dim = self.H // num_heads +# self.tp_embedding_dim = self.H // self.TP + +# self.packed_length = self.S * self.B +# self.device = f"cuda:{get_local_rank()}" +# cu_seqlens = [i * self.S for i in range(self.B + 1)] + +# weights_mem_used = self.packed_length * 3 * self.H * self.dtype_size +# attn_activation = 11 * self.packed_length * self.H +# mem_used = attn_activation + weights_mem_used + +# self.inner_attn = FlashSelfAttention(causal=True, softmax_scale=self.H ** (0.5), attention_dropout=0.0) + +# oom = False +# if mem_used > 75 * 1024**3: +# oom = True + +# # 约束1: seqlen最大不能超过256K(不含) +# # 约束2: embed_dim在被tp切过之后若大于6144, 则packed_length不能大于256k +# if self.packed_length >= 256 * K and (self.H / self.TP) >= 6144: +# oom = True +# if self.S >= 256 * K and self.B > 1: +# oom = True +# if self.packed_length >= 524288 and (self.H / self.TP) >= 3072: +# oom = True +# if self.packed_length >= 1048576 and (self.H / self.TP) >= 2048: +# oom = True + +# if oom: +# assert ( +# False +# ), f"warning : mem_used: {mem_used/1024**3:.2f} GB, seq_len: {self.S}, embed_dim: {self.H}, tp_size: {self.TP}" + +# self.qkv = torch.rand( +# size=(self.B * self.S, 3, self.N // self.TP, self.H), +# dtype=self.dtype, +# device=self.device, +# requires_grad=True, +# ) + +# self.dtype_size = self.qkv.element_size() +# self.cu_seqlens = torch.tensor(data=cu_seqlens, dtype=torch.int32, device=self.device) +# self.max_seqlen = self.S +# if not self.is_fwd: +# self.output = self.run_fwd() +# self.grad = torch.randn_like(self.output) / 32 # avoid grad is too large. + +# def run(self): +# if self.is_fwd: +# self.run_fwd() +# else: +# self.run_bwd(self.output, self.grad) + +# def run_fwd(self): +# context = self.inner_attn(self.qkv, cu_seqlens=self.cu_seqlens, max_seqlen=self.max_seqlen, causal=self.causal) +# return context + +# def run_bwd(self, output, grad): +# output.backward(grad, retain_graph=True) + +# @staticmethod +# def gen_store_key(micro_bsz, seq_len, num_heads_and_hidden_dim, tp_size, is_fwd): +# _, embed_dim = num_heads_and_hidden_dim +# tp_embedding_dim = embed_dim // tp_size +# return f"b_{micro_bsz}_s_{seq_len}_h_{tp_embedding_dim}_fwd_{is_fwd}" + +# def complexity(self): +# return UnitMultiHeadAttn.gen_store_key(self.B, self.S, self.num_heads_and_hidden_dim, self.TP, self.is_fwd) +# # return f"{self.S} * {self.hidden_dim} * {self.hidden_dim}" + + +if __name__ == "__main__": + + micro_bsz, seqlen, hidden_size, q_head, kv_head, dtype = 1, 4096, 4096, 32, 8, torch.bfloat16 + t_fwd, t_bwd = run_fwd(micro_bsz, seqlen, hidden_size, q_head, kv_head, dtype) + print(f"t_fwd: {t_fwd}, t_bwd: {t_bwd}", flush=True) diff --git a/internlm/simulator/profiler/benchmark/reduce_scatter.py b/internlm/simulator/profiler/benchmark/reduce_scatter.py index 6b8a0509f..39da2275e 100644 --- a/internlm/simulator/profiler/benchmark/reduce_scatter.py +++ b/internlm/simulator/profiler/benchmark/reduce_scatter.py @@ -3,6 +3,7 @@ from internlm.model.registry import benchmark_initializer from internlm.simulator.common import * +from internlm.utils.common import get_current_device from .base_benchmark import UnitBench @@ -13,36 +14,36 @@ class UnitBenchAllReduceScatter(UnitBench): test_loop = { "global_size": GLOBAL_ELEM_SIZES_LIST, - "world_size": WORLD_SIZE_LIST, # 7B, (13B, 20B), 30B, 65B, 123B "async_op": [False], # it is not work!! False, "dtype": [torch.bfloat16], } - def __init__(self, world_size, async_op, dtype, global_size=None, unit_size=None) -> None: + def __init__(self, async_op, dtype, group, global_size=None, unit_size=None) -> None: assert global_size is None or unit_size is None + device = get_current_device() + world_size = dist.get_world_size(group) self.unit_size = unit_size if unit_size else global_size // world_size # elements_per_gpu self.world_size = world_size self.dtype = dtype self.async_op = async_op - self.group = sub_process_groups[str(world_size)] - self.do_it = dist.get_rank() in set(dist.get_process_group_ranks(self.group)) - - if dist.get_world_size() < world_size: - self.input = None - self.output = None - else: - self.output = torch.ones(self.unit_size, dtype=self.dtype).to(f"cuda:{get_local_rank()}") - self.input = torch.ones(self.world_size, self.unit_size, dtype=self.dtype).to(f"cuda:{get_local_rank()}") + self.group = group + + if self.group is not None: + self.output = torch.ones(self.unit_size, dtype=self.dtype, device=device) + self.input = torch.ones(self.world_size, self.unit_size, dtype=self.dtype, device=device) self.input_buffer_size = self.input.element_size() * self.input.numel() def run(self): - if self.output is None or not self.do_it: + if self.group is None: return handler = dist.reduce_scatter_tensor(self.output, self.input, async_op=self.async_op, group=self.group) if self.async_op: handler.wait() - def complexity(self): - return self.input_buffer_size \ No newline at end of file + def bw_complexity(self): + return self.input_buffer_size + + def algo_complexity(self): + return self.input_buffer_size diff --git a/internlm/simulator/profiler/perf_comm.py b/internlm/simulator/profiler/perf_comm.py new file mode 100644 index 000000000..58402bfb5 --- /dev/null +++ b/internlm/simulator/profiler/perf_comm.py @@ -0,0 +1,454 @@ +import functools +import math +import os +import socket + +import torch +import torch.distributed as dist + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.simulator.common import BW, POSITIVE_INFINITY, CostType +from internlm.simulator.profiler.benchmark import register_comm_pref_initializer +from internlm.simulator.profiler.benchmark.multi_head_attn import run_fa_lat_test +from internlm.simulator.profiler.profiler import ( + draw_cal_pics, + draw_pics, + run_cal_profile, + run_comm_profile, + sync_all, +) +from internlm.utils.common import get_args, get_master_node, parse_args + +cost_model = None +scale_ratio = [1.415134488, 1.208864145, 1.1, 1] + +fa_cost_cache = {} + + +def get_group_id(rank, gpus_per_node, intra_size, inter_size): + intra_x = rank % gpus_per_node + inter_y = rank // gpus_per_node + x_idx = intra_x // intra_size + y_idx = inter_y // inter_size + # y_stride = gpus_per_node // intra_size + return x_idx, y_idx + + +def gen_cal_key(op_type: CostType): + return f"{op_type}" + + +def gen_comm_key(op_name, intra_size, inter_size): + return f"{op_name}_intra_{intra_size}_inter_{inter_size}" + + +def new_process_group(world_size, gpus_per_node, intra_size, inter_size): + node_nums = world_size // gpus_per_node + intra_group_stride = gpus_per_node // intra_size + inter_group_stride = node_nums // inter_size + + gid_2_group = [[None for _ in range(intra_group_stride)] for _ in range(inter_group_stride)] + + for j_outer in range(inter_group_stride): + for i_outer in range(intra_group_stride): + base_idx = i_outer * intra_size + j_outer * inter_size * gpus_per_node + ranks = [] + for j in range(inter_size): + idx = base_idx + j * gpus_per_node + ranks.extend(list(range(idx, idx + intra_size, 1))) + # if dist.get_rank() == 0: + # print(f"base_idx: {base_idx}, intra_size: {intra_size}, inter_size: {inter_size}, key: {key}, ranks: {ranks}", flush=True) + group = dist.new_group(ranks, backend="nccl") + gid_2_group[j_outer][i_outer] = (group, ranks) + + return gid_2_group + + +def gen_perf(): + if "RANK" not in os.environ: + os.environ["RANK"] = os.environ["SLURM_PROCID"] + if "WORLD_SIZE" not in os.environ: + os.environ["WORLD_SIZE"] = os.environ["SLURM_NPROCS"] + if "LOCAL_RANK" not in os.environ: + os.environ["LOCAL_RANK"] = str(int(os.environ["RANK"]) % 8) + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = get_master_node() + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = str(12345) + + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + host = os.environ["MASTER_ADDR"] + port = int(os.environ["MASTER_PORT"]) + + gpus_per_node = 8 + node_num = world_size // gpus_per_node + config = dict( + parallel=dict( + zero1=dict(size=1), + tensor=dict(size=gpus_per_node, mode="mtp"), + pipeline=dict(size=world_size // gpus_per_node, interleaved_overlap=True), + weight=dict(size=1, overlap=True, memory_pool=True), + ), + clusters=[ + { + "name": "nv_cluster", + "peak_tflops": 320, + "capacity": 80 * 1024**3, + "intra_bw": 150, + "inter_bw": 100, + "gpu_per_node": 8, + "node_num": 1, + }, + { + "name": "mx_cluster", + "peak_tflops": 240, + "capacity": 64 * 1024**3, + "intra_bw": 150, + "inter_bw": 100, + "gpu_per_node": 8, + "node_num": 1, + }, + ], + ) + + gpc.load_config(config) + + init_method = f"tcp://[{host}]:{port}" + dist.init_process_group( + rank=rank, + world_size=world_size, + backend="nccl", + init_method=init_method, + ) + group = dist.GroupMember.WORLD + + gpc._register_dist(rank, world_size, group, None, list(range(world_size)), ParallelMode.GLOBAL) + gpc._global_ranks[ParallelMode.GLOBAL] = rank + gpc.set_device(local_rank) + + comm_test_list = [ + CostType.ALL2ALL, + CostType.ALLREDUCE, + CostType.REDUCESCATTER, + CostType.ALLGATHER, + CostType.BROADCAST, + ] + + register_comm_pref_initializer() + + intra_comm_nums = int(math.log(gpus_per_node)) + 1 # 0,1,2,3 + inter_comm_nums = int(math.log(node_num)) + 1 + + data_path = f"./prof_data" + cal_pic_path = f"{data_path}/pics/cal" + comm_pic_path = f"{data_path}/pics/comm" + + if dist.get_rank() == 0: + os.makedirs(comm_pic_path, exist_ok=True) + if dist.get_rank() == 0: + os.makedirs(cal_pic_path, exist_ok=True) + + spline_model_dict = {} + + if dist.get_rank() == 0: + comp_test_list = [CostType.LINEAR] + for test_op in comp_test_list: + tflop, tflops = run_cal_profile(test_op) + spline_model = draw_cal_pics(cal_pic_path, f"{test_op}", tflop, tflops) + spline_model_dict[gen_cal_key(test_op)] = spline_model + + sync_all() + + for i in range(inter_comm_nums): + for j in range(intra_comm_nums): + inter_size, intra_size = 2**i, 2**j + if inter_size * intra_size != 1: + + x_idx, y_idx = get_group_id(rank, gpus_per_node, intra_size, inter_size) + groups = new_process_group(world_size, gpus_per_node, intra_size, inter_size) + + for test_type in comm_test_list: + key = gen_comm_key(test_op, intra_size, inter_size) + if dist.get_rank() == 0: + print( + f"key: {key}, inter_size: {inter_size}, intra_size: {intra_size}, ranks: {groups[y_idx][x_idx][1]}", + flush=True, + ) + pg = groups[y_idx][x_idx][0] + assert ( + pg != -100 + ), f"key: {key}, x_idx: {x_idx}, y_idx: {y_idx}, rank: {gpc.get_global_rank()}, ranks: {groups[y_idx][x_idx][1]}" + comm_vols, bws = run_comm_profile(test_type, pg, key) + sync_all() + if dist.get_rank() == 0: + spline_model_dict[key] = draw_pics(comm_pic_path, key, comm_vols, bws) + + print(f"rank: {gpc.get_global_rank()}, all done!", flush=True) + + if dist.get_rank() == 0: + pt = os.path.join(data_path, "data.pt") + with open(pt, "wb") as f: + torch.save(spline_model_dict, f) + + +def init_cost_model(cost_model_path): + global cost_model + with open(cost_model_path, "rb") as f: + cost_model = torch.load(f) + + +def coll_algo_bw(comm_op, size, n): + if comm_op == CostType.ALL2ALL: + if n <= 8: + return size * (n - 1) / n + else: + # intra_parts = 8 + one_part = size / n + return 8 * one_part * (n - 8 / n) + elif comm_op == CostType.ALLREDUCE: + return size * 2 * (n - 1) / n + elif comm_op == CostType.REDUCESCATTER: + return size * (n - 1) / n + elif comm_op == CostType.ALLGATHER: + return size * (n - 1) / n + elif comm_op == CostType.BROADCAST: + return size * (n - 1) / n + elif comm_op == CostType.P2P: + return size + + raise ValueError(f"unknown comm_op: {comm_op}") + + +def coll_bus_bw(comm_op, size, n): + if comm_op == CostType.ALL2ALL: + return size + elif comm_op == CostType.ALLREDUCE: + return size * 2 + elif comm_op == CostType.REDUCESCATTER: + return size + elif comm_op == CostType.ALLGATHER: + return size + elif comm_op == CostType.BROADCAST: + return size + elif comm_op == CostType.P2P: + return size + + raise ValueError(f"unknown comm_op: {comm_op}") + + +# 需要判断是否打满带宽 +def get_scale_ratio(scale): + # 通信扩展惩罚系数 + if scale <= 16: + return 1 + elif 16 < scale <= 32: + return 1.1 + elif 32 < scale <= 64: + return 1.2 + elif 64 < scale <= 256: + return 1.3 + elif 256 < scale <= 512: + return 1.4 + else: + return 1.5 + + +comm_matrix_dict = {} + + +def draw_heatmap(comm_nums: int, comm_volume: int, parallel_mode, use_rail_optim=False): + """ "Draw a heatmap for communication volume of different parallel mode." + + Args: + comm_nums (int): "Communication volume." + comm_volume (int): "Communication volume." + parallel_mode (_type_): + use_rail_optim (bool, optional): Whether to consider multi-track optimization. + Defaults to False. + """ + + global comm_matrix_dict + + if parallel_mode not in comm_matrix_dict: + comm_matrix_dict[parallel_mode] = [ + [0 for _ in range(gpc.get_world_size(ParallelMode.GLOBAL))] + for _ in range(gpc.get_world_size(ParallelMode.GLOBAL)) + ] + + comm_mat = comm_matrix_dict[parallel_mode] + + all_ranks = gpc.get_parallel_all_ranks(parallel_mode) + + for sub_id in range(len(all_ranks)): + ranks = all_ranks[sub_id] + for i in range(len(ranks)): + world_size = len(ranks) + + if parallel_mode in [ParallelMode.TENSOR, ParallelMode.WEIGHT]: + _comm_volume = comm_volume * gpc.config.model["num_layers"] + elif parallel_mode == ParallelMode.PIPELINE: + _comm_volume = 8 * 2 * comm_volume * comm_nums + # elif parallel_mode == ParallelMode.ZERO1: + # _comm_volume = comm_volume * world_size + else: + _comm_volume = comm_volume + + if _comm_volume < 0: + _comm_volume = -1 * _comm_volume + + chunk_size = _comm_volume / world_size + is_intra = gpc.check_pg_is_intra(parallel_mode) + + print(f"sub_id: {sub_id}, parallel_mode: {parallel_mode}, is_intra: {is_intra}", flush=True) + if is_intra: # hard code, nvswitch + for j in range(len(ranks)): + if j != i: + # print(f"len: {len(ranks)}, i: {i}, j: {j}", flush=True) + comm_mat[ranks[i]][ranks[j]] += chunk_size + comm_mat[ranks[j]][ranks[i]] += chunk_size + else: + if use_rail_optim: + pass + else: + if parallel_mode in [ParallelMode.DATA, ParallelMode.ZERO1]: + inter_all_ranks = gpc.get_parallel_all_ranks(ParallelMode.INTER_DP_SZIE) + intra_all_ranks = gpc.get_parallel_all_ranks(ParallelMode.INTRA_DP_SZIE) + + if parallel_mode == ParallelMode.DATA: + chunk_size /= 2 + + for k in range(len(inter_all_ranks)): + t_ranks = inter_all_ranks[k] + for p in range(len(t_ranks)): + for q in range(len(t_ranks)): + if p != q: + comm_mat[t_ranks[p]][ + t_ranks[q] + ] += chunk_size # / 4 # += (chunk_size // len(t_ranks)) + comm_mat[t_ranks[q]][ + t_ranks[p] + ] += chunk_size # / 4 # += (chunk_size // len(t_ranks)) + + for k in range(len(intra_all_ranks)): + t_ranks = intra_all_ranks[k] + for p in range(len(t_ranks)): + for q in range(len(t_ranks)): + if p != q: + comm_mat[t_ranks[p]][ + t_ranks[q] + ] += chunk_size # / 4 # += (chunk_size // len(t_ranks)) + comm_mat[t_ranks[q]][ + t_ranks[p] + ] += chunk_size # / 4 # += (chunk_size // len(t_ranks)) + + return + elif parallel_mode == ParallelMode.ZERO1: + inter_all_ranks = gpc.get_parallel_all_ranks(ParallelMode.INTER_DP_SZIE) + for k in range(len(inter_all_ranks)): + t_ranks = inter_all_ranks[k] + for p in range(len(t_ranks)): + for q in range(len(t_ranks)): + if p != q: + comm_mat[t_ranks[p]][t_ranks[q]] += chunk_size // len(t_ranks) + comm_mat[t_ranks[q]][t_ranks[p]] += chunk_size // len(t_ranks) + return + elif parallel_mode == ParallelMode.PIPELINE: + if i < len(ranks) - 1: + comm_mat[ranks[i]][ranks[(i + 1) % world_size]] += _comm_volume + comm_mat[ranks[(i + 1) % world_size]][ranks[i]] += _comm_volume + else: + assert False + + +def get_comm_cost_from_logic(comm_volume: int, parallel_mode: ParallelMode, comm_op: CostType = None, comm_nums=1): + """根据通信量获得近似的通信延迟,这个函数考虑了跨节点带宽content的情景 + 所以为了正确计算延迟,传入的 comm_volume 必须是以单个rank视角下的通信量 + (即代码中实际传入的通信量) + + Args: + comm_volume (int): 通信量, 单位B + parallel_mode (ParallelMode): gpc并行模式 + comm_op (CostType, optional): 通信算子 + + Returns: + int: 通信延迟,是乘以10**4后并取整后的数值 + """ + scale = gpc.get_world_size(parallel_mode) + + if scale > 1 and get_args().draw_heatmap: + draw_heatmap(comm_nums, comm_volume, parallel_mode) + + if parallel_mode == ParallelMode.PIPELINE: + scale = 2 + + if scale <= 1: + return 0 + + is_intra = gpc.check_pg_is_intra(parallel_mode) + if not is_intra: + num_partner = gpc.same_group_in_one_node(parallel_mode) + assert num_partner <= 8, f"num_partner: {num_partner}" + if parallel_mode == ParallelMode.WEIGHT: + assert num_partner == 1 + if parallel_mode == ParallelMode.TENSOR: + assert num_partner == 1 + comm_volume *= num_partner + + bw = BW.A800_NVL if is_intra else (BW.IB / get_scale_ratio(scale)) + return coll_algo_bw(comm_op, comm_volume, scale) / bw # 转换成ms小数点保留两位 + + +def get_comm_cost_from_cost_data(comm_volume: int, parallel_mode: ParallelMode, comm_op: CostType = None): + """这里最佳的实现感觉是仿照NCCL的写法,建立起完整的通信代价矩阵,难点是如何确定一次集合通信包含了几个环 + (难道把nccl建图和搜索最优路径的代码用python重写一遍?) + + Args: + comm_volume (int): _description_ + parallel_mode (ParallelMode): _description_ + comm_op (CostType, optional): _description_. Defaults to None. + """ + pass + + +def get_cal_cost(cal_op, flop): + global cost_model + assert cost_model is not None + try: + flops = cost_model[gen_cal_key(cal_op)](flop) + except Exception as e: + print(f"error: {e}", flush=True) + return POSITIVE_INFINITY + else: + return flop / flops # latency in second. + + +def get_fa_cost(micro_bsz, seqlen, hidden_size, q_head, kv_head, dtype, is_fwd): + fa_key = f"{micro_bsz}_{seqlen}_{hidden_size}_{q_head}_{kv_head}" + + if fa_key not in fa_cost_cache: + print(f"not found FA key : {fa_key}, do profiling...") + try: + t_fwd, t_bwd = run_fa_lat_test(micro_bsz, seqlen, hidden_size, q_head, kv_head, dtype=torch.bfloat16) + except RuntimeError as e: + print(f"{e}, fa run fail", flush=True) + t_fwd, t_bwd = float("inf"), float("inf") + + fa_cost_cache[fa_key] = t_fwd, t_bwd + + if is_fwd: + return fa_cost_cache[fa_key][0] + else: + return fa_cost_cache[fa_key][1] + + +get_comm_cost = get_comm_cost_from_logic + +allgather = functools.partial(get_comm_cost, comm_op=CostType.ALLGATHER) +reducescatter = functools.partial(get_comm_cost, comm_op=CostType.REDUCESCATTER) +broadcast = functools.partial(get_comm_cost, comm_op=CostType.BROADCAST) +p2p = functools.partial(get_comm_cost, comm_op=CostType.P2P) +alltoall = functools.partial(get_comm_cost, comm_op=CostType.ALL2ALL) +allreduce = functools.partial(get_comm_cost, comm_op=CostType.ALLREDUCE) diff --git a/internlm/simulator/profiler/profiler.py b/internlm/simulator/profiler/profiler.py index 95c6e43f9..9d28bbfd4 100644 --- a/internlm/simulator/profiler/profiler.py +++ b/internlm/simulator/profiler/profiler.py @@ -1,25 +1,51 @@ -import functools import inspect import os -import sys import time from collections import OrderedDict from copy import deepcopy from typing import Dict, List +import matplotlib.pyplot as plt import torch import torch.distributed as dist +from scipy.interpolate import interp1d # internlm/model/registry.py +# from internlm.model.registry import benchmark_initializer from internlm.model.registry import benchmark_initializer from internlm.simulator.common import ( + GLOBAL_BYTE_SIZES_LIST, OUT_OF_MEM_LATENCY, - get_global_rank, - get_world_size, sync_all, + sync_local, ) +class Timer: + def __init__(self, use_event) -> None: + self.use_event = use_event + if use_event: + self.start_t = torch.cuda.Event(enable_timing=True) + self.end_t = torch.cuda.Event(enable_timing=True) + + def start(self): + if self.use_event: + self.start_t.record() + else: + self.start_t = time.time() + + def end(self, group=None): + if self.use_event: + self.end_t.record() + if group != None: + dist.barrier(group) + torch.cuda.synchronize() + return self.start_t.elapsed_time(self.end_t) / 1000 + else: + torch.cuda.synchronize() + return time.time() - self.start_t + + def DFS(loop_config: OrderedDict, results: OrderedDict, total_results: List): if len(loop_config) == 0: total_results.append(deepcopy(results)) @@ -41,31 +67,22 @@ def filter_kwargs(func, kwargs): return {k: v for k, v in kwargs.items() if k in sig.parameters} -def debug_profile(bench, test_case): - if "lat" not in test_case: - test_case["lat"] = int.Maximum - - # print(f"{bench.complexity()}: micro_bsz: {test_case['micro_bsz']}, seq_len: {test_case['seq_len']}, num_heads_and_hidden_dim: {test_case['num_heads_and_hidden_dim']}, tp_size {test_case['tp_size']}, lat: {test_case['lat']}", flush=True) - - -def run_profile(args, test_type): - re_results = {} - +def run_cal_profile(test_type, warmups=2, trials=5): BENCH = benchmark_initializer.get_module(module_name=test_type) - def run_benchmark(test_case, args): - sync_all() + def run_benchmark(test_case): # Warmups, establish connections, etc. - for _ in range(args.warmups): + timer = Timer(use_event=True) + for _ in range(warmups): try: test_case.run() except torch.cuda.OutOfMemoryError: torch.cuda.empty_cache() return OUT_OF_MEM_LATENCY + try: - sync_all() + sync_local() except RuntimeError: - # self.packed_length * 3 * self.embed_dim * self.dtype_size print( f"packed_length: {test_case.packed_length}, embed_dim: {test_case.embed_dim}, micro_bsz: {test_case.micro_bsz}, seq_len: {test_case.seq_len}, tp:{test_case.tp_size}", flush=True, @@ -74,58 +91,180 @@ def run_benchmark(test_case, args): return OUT_OF_MEM_LATENCY # time the actual comm op trials times and average it - pre = time.perf_counter() - for _ in range(args.trials): + duration = 0 + for _ in range(trials): + timer.start() try: test_case.run() except torch.cuda.OutOfMemoryError: torch.cuda.empty_cache() return OUT_OF_MEM_LATENCY - sync_all() - duration = time.perf_counter() - pre + + duration += timer.end() + + # maintain and clean performance data + avg_duration = duration / trials + return avg_duration + + sync_local() + # loop over various tensor sizes + test_args = OrderedDict(BENCH.test_loop) + total_cases = [] + + DFS(test_args, OrderedDict(), total_cases) + + tflop = [] + tflops_list = [] + for _, test_case in enumerate(total_cases): + + try: + bench = BENCH(**filter_kwargs(BENCH.__init__, test_case)) + except torch.cuda.OutOfMemoryError: + torch.cuda.empty_cache() + break + except AssertionError: + torch.cuda.empty_cache() + break + else: + sync_local() + complexity = bench.complexity() + if complexity in tflop: + continue + + avg_lat = run_benchmark(bench) + tflops = complexity / avg_lat + + tflop.append(complexity) + tflops_list.append(tflops) + + print( + f"complexity: {complexity/ 10**12:.3f}, tflops:{tflops/ 10**12:.3f}, avg_duration: {avg_lat*1000:.3f} ms", + flush=True, + ) + + return tflop, tflops_list + + +def run_comm_profile(test_type, group, plot_name, warmups=5, trials=20): + + BENCH = benchmark_initializer.get_module(module_name=test_type) + + def run_benchmark(test_case, group): + # Warmups, establish connections, etc. + timer = Timer(use_event=True) + for _ in range(warmups): + test_case.run() + + sync_all(group) + + # time the actual comm op trials times and average it + duration = 0 + for _ in range(trials): + timer.start() + test_case.run() + duration += timer.end(group) # maintain and clean performance data - avg_duration = duration / args.trials + avg_duration = duration / trials return avg_duration - sync_all() # loop over various tensor sizes test_args = OrderedDict(BENCH.test_loop) total_cases = [] DFS(test_args, OrderedDict(), total_cases) - if get_global_rank() == 0: - print(f"all test case nums: {len(total_cases)}", flush=True) + comm_vols, bws = [], [] for test_case in total_cases: - world_size = test_case["world_size"] if "world_size" in test_case else 1 + test_case["group"] = group + bench = BENCH(**filter_kwargs(BENCH.__init__, test_case)) - if world_size not in re_results: - re_results[world_size] = {} + avg_duration = run_benchmark(bench, group) - complex_tag = BENCH.gen_store_key(**filter_kwargs(BENCH.gen_store_key, test_case)) + comm_vol = bench.bw_complexity() + bw = comm_vol / avg_duration - if complex_tag not in re_results[world_size]: - try: - bench = BENCH(**filter_kwargs(BENCH.__init__, test_case)) - except torch.cuda.OutOfMemoryError: - torch.cuda.empty_cache() - continue - except AssertionError: - # torch.cuda.empty_cache() - continue + comm_vols.append(comm_vol) + bws.append(bw) + if dist.get_rank() == 0: + print( + f" plot_name: {plot_name}, Buff: {test_case['global_size']/1024**3:.3f} GB, avg bw: {bw/ 1024**3:.3f} GB/s", + flush=True, + ) + + return comm_vols, bws + + +def draw_pics(base_path, plot_name, comm_vols, bws): + x, y = [], [] + + spline_model = interp1d(comm_vols, bws, kind="slinear") + + end = GLOBAL_BYTE_SIZES_LIST[-1] // 1024**2 + for i in range(1, end + 1): + vol = i * 1024**2 + try: + predice_bw = spline_model(vol) + except ValueError: + if vol < GLOBAL_BYTE_SIZES_LIST[0]: + predice_bw = spline_model(GLOBAL_BYTE_SIZES_LIST[0]) else: - sync_all() - avg_duration = run_benchmark(bench, args) - test_case["lat"] = avg_duration - print(f"test_case: {test_case}, avg_duration: {avg_duration} ", flush=True) + predice_bw = spline_model(GLOBAL_BYTE_SIZES_LIST[-1]) - debug_profile(bench=bench, test_case=test_case) - re_results[world_size][complex_tag] = [test_case] - else: - if get_global_rank() == 0: - print( - f"Warning test_case: {test_case}, same complexity: {complex_tag}, lat:{re_results[world_size][complex_tag][0]['lat']}" - ) + x.append(vol / 1024**2) + y.append(predice_bw / 1024**3) + + bws = list(map(lambda x: x / 1024**3, bws)) + comm_vols = list(map(lambda x: x / 1024**2, comm_vols)) + + pic_path = os.path.join(base_path, plot_name + ".jpg") + + plt.figure(figsize=(12, 6)) + plt.scatter(comm_vols, bws, label="True value") + plt.plot(x, y, label="Fit value") + plt.xlabel("Data Transferred (MB)") + plt.ylabel("Bandwidth (GB/s)") + plt.title(f"Bandwidth Spline Fit for {plot_name} at different data volume") + plt.legend() + plt.grid(True) + plt.savefig(pic_path) + plt.show() + + return spline_model + + +def draw_cal_pics(base_path, plot_name, tflop, tflops): + # x, y = [], [] + + spline_model = interp1d(tflop, tflops, kind="slinear") + + # start = tflop[0] + # end = tflop[-1] + # for complexity in range(start, end+1): + # try: + # predice_tflops = spline_model(complexity) + # except ValueError: + # if complexity < tflop[0]: + # predice_tflops = spline_model(tflop[0]) + # elif complexity > tflop[-1]: + # predice_tflops = spline_model(tflop[-1]) + + # x.append(complexity) + # y.append(predice_tflops) + + pic_path = os.path.join(base_path, plot_name + ".jpg") + tflop = list(map(lambda x: x / 10**12, tflop)) + tflops = list(map(lambda x: x / 10**12, tflops)) + + plt.figure(figsize=(12, 6)) + plt.scatter(tflop, tflops, label=f"True value") + # plt.plot(x, y, label=f"Fit value") + plt.xlabel("tflop") + plt.ylabel("Tflops") + plt.title(f"Tflops Spline Fit for {plot_name} at different tflop") + plt.legend() + plt.grid(True) + plt.savefig(pic_path) + plt.show() - return re_results + return spline_model diff --git a/internlm/simulator/tracker/comm_tracker.py b/internlm/simulator/tracker/comm_tracker.py index d24063a1d..19fbcc523 100644 --- a/internlm/simulator/tracker/comm_tracker.py +++ b/internlm/simulator/tracker/comm_tracker.py @@ -2,47 +2,48 @@ import torch -from internlm.simulator.common import BW, CommOp -from internlm.simulator.predict_cost_model import SplineModel +from internlm.simulator.common import BW, CostType + +# from internlm.simulator.predict_cost_model import SplineModel cost_model = None scale_ratio = [1.415134488, 1.208864145, 1.1, 1] def coll_comm_lat(comm_op, size, n): - if comm_op == CommOp.ALL2ALL: + if comm_op == CostType.ALL2ALL: if n <= 8: return size * (n - 1) / n else: # intra_parts = 8 one_part = size / n return 8 * one_part * (n - 8 / n) - elif comm_op == CommOp.ALLREDUCE: + elif comm_op == CostType.ALLREDUCE: return size * 2 * (n - 1) / n - elif comm_op == CommOp.REDUCESCATTER: + elif comm_op == CostType.REDUCESCATTER: return size * (n - 1) / n - elif comm_op == CommOp.ALLGATHER: + elif comm_op == CostType.ALLGATHER: return size * (n - 1) / n - elif comm_op == CommOp.BROADCAST: + elif comm_op == CostType.BROADCAST: return size * (n - 1) / n - elif comm_op == CommOp.P2P: + elif comm_op == CostType.P2P: return size raise ValueError(f"unknown comm_op: {comm_op}") def coll_bus_bw(comm_op, size): - if comm_op == CommOp.ALL2ALL: + if comm_op == CostType.ALL2ALL: return size - elif comm_op == CommOp.ALLREDUCE: + elif comm_op == CostType.ALLREDUCE: return size * 2 - elif comm_op == CommOp.REDUCESCATTER: + elif comm_op == CostType.REDUCESCATTER: return size - elif comm_op == CommOp.ALLGATHER: + elif comm_op == CostType.ALLGATHER: return size - elif comm_op == CommOp.BROADCAST: + elif comm_op == CostType.BROADCAST: return size - elif comm_op == CommOp.P2P: + elif comm_op == CostType.P2P: return size raise ValueError(f"unknown comm_op: {comm_op}") @@ -132,60 +133,8 @@ def add_comm_meta(self, comm_type: CommType, parallel_mode, can_overlap): self.next_parallel_mode = parallel_mode self.can_overlap = can_overlap - def cal_comm_cost(self, comm_op, comm_volume=1, dtype=torch.bfloat16): - """根据通信量获得近似的通信延迟,这个函数考虑了跨节点带宽content的情景 - 所以为了正确计算延迟,传入的 comm_volume 必须是以单个rank视角下的通信量 - (即代码中实际传入的通信量) - - Args: - comm_volume (int): 通信量, 单位B - parallel_mode (ParallelMode): gpc并行模式 - comm_op (CommOp, optional): 通信算子 - - Returns: - int: 通信延迟,是乘以10**4后并取整后的数值 - """ - - from internlm.core.context import ParallelMode - from internlm.core.context import global_context as gpc - - comm_type = self.next_comm_type - parallel_mode = self.next_parallel_mode - - if comm_type is None: - return - - scale = gpc.get_world_size(parallel_mode) - - if parallel_mode == ParallelMode.PIPELINE: - scale = 2 - - if scale <= 1: - return 0 - - is_intra = gpc.check_pg_is_intra(parallel_mode) - if not is_intra: - num_partner = gpc.same_group_in_one_node(parallel_mode) - assert num_partner <= 8, f"num_partner: {num_partner}" - if parallel_mode == ParallelMode.WEIGHT: - assert num_partner == 1 - if parallel_mode == ParallelMode.TENSOR: - assert num_partner == 1 - comm_volume *= num_partner - - global cost_model - try: - if cost_model is None: - cost_model = SplineModel() - - lat = cost_model.predict(comm_type, scale, comm_volume) - except FileNotFoundError: - # if comm_op == CommOp.P2P: - bw = BW.A800_NVL if is_intra else (BW.IB / get_scale_ratio(scale)) - - lat = coll_comm_lat(comm_op, comm_volume, scale) / bw # 转换成ms小数点保留两位 - - self.comm_cost_dict[comm_type].add_new_comm(lat, comm_volume, bw) + def cal_comm_cost(self, comm_op, group, src=None, comm_volume=1, dtype=torch.bfloat16): + pass comm_tracker = CommTracker() diff --git a/internlm/solver/optimizer/compatible_adamw.py b/internlm/solver/optimizer/compatible_adamw.py index bca8c2746..8d97ddd31 100644 --- a/internlm/solver/optimizer/compatible_adamw.py +++ b/internlm/solver/optimizer/compatible_adamw.py @@ -33,7 +33,9 @@ def new_compatible_adamw(params, lr: float = 0.001, betas: Tuple[float, float] = "Use fused AdamaW to avoid nan grad norm when " "model size is larger and use_fp32_norm=True, Please note this!" ) - adam_extra_kwargs["fused"] = True + + if fake_mode = "fake_mode" not in os.environ: + adam_extra_kwargs["fused"] = True elif backend is AcceleratorType.NPU: if gpc.is_rank_for_log(): logger.warning( diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 1d26bd871..e9ddda8f7 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -6,6 +6,7 @@ from itertools import product from typing import List, Optional +import numpy as np import torch import torch.distributed as dist from torch.optim import Optimizer @@ -51,6 +52,14 @@ logger = get_logger(__file__) internlm_accelerator = get_accelerator() +import os + +from internlm.simulator.tracker.comm_tracker import CommType, get_gloabl_comm_tracker + +fake_mode = "fake_mode" in os.environ + +comm_tracker = get_gloabl_comm_tracker() + class HybridZeroOptimizer(BaseOptimizer): """ @@ -234,6 +243,9 @@ def __init__( self.has_params = sum(self.param_group_has_params) != 0 # flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled. self.skip_grad_reduce = False + self.current_accum_step = 0 + + self._unbalance_micro_num = True self._attach_reduction_hook() @@ -318,11 +330,10 @@ def _define_and_attach(param, reduce_rank=None): ) def reduction_layernorm_func(): + parallel_mode = ParallelMode.WEIGHT if self.use_isp else ParallelMode.TENSOR + comm_tracker.add_comm_meta(CommType.SP_NORM_ALLREDUCE, parallel_mode, can_overlap=True) handle = reduce_tensor( - param.grad, - dtype=None, - dst_rank=reduce_rank, - parallel_mode=ParallelMode.WEIGHT if self.use_isp else ParallelMode.TENSOR, + param.grad, dtype=None, dst_rank=reduce_rank, parallel_mode=parallel_mode ) handle.wait() @@ -342,6 +353,16 @@ def extra_layernorm_reduce_grad_hook(*args): # pylint: disable=W0613 if self.skip_grad_reduce is False: reduction_layernorm_func() + # define hook for sequence_parallel + def unbalance_micro_num_loss_scale_hook(grad): # pylint: disable=W0613 + if self.skip_grad_reduce is True: # 只在梯度累加的时候生效 + left_step = np.max(gpc.micro_num_list) - self.current_accum_step + scale_denominator = np.sum(left_step >= gpc.micro_num_list) + scale = gpc.get_world_size(ParallelMode.DATA) / scale_denominator + + return grad * scale + return grad + # get the AccumulateGrad object of the param itself # If these objects are not kept, reduction hooks may not be attached successfully. accum_grad_obj = get_grad_accumulate_object(param) @@ -366,6 +387,10 @@ def extra_layernorm_reduce_grad_hook(*args): # pylint: disable=W0613 ): accum_grad_obj.register_hook(accum_grad_hook) + if self._unbalance_micro_num: + # 注意这个hook必须要在梯度累加之前被调用,所以不能采用在 accum_grad_obj 上注册hook,而是需要直接注册在param上 + param.register_hook(unbalance_micro_num_loss_scale_hook) + if self._overlap_sync_grad: accum_grad_obj.register_hook(reduce_grad_hook) @@ -439,6 +464,7 @@ def _store_and_try_reduce_grads_by_bucket(self, param, reduce_rank=None): current_bucket = self._bucket_store[group_id] if current_bucket.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size: + comm_tracker.add_comm_meta(CommType.DP_ALLREDUCE, current_bucket.get_dp_parallel_mode(), can_overlap=True) self._reduce_grads_stored_in_bucket(current_bucket, reduce_rank) # the param must not be reduced to ensure correctness @@ -672,7 +698,9 @@ def step(self, closure=None): # we need to reduce the gradients left in the communication bucket for group_id in range(self.num_param_groups): - self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None) + current_bucket = self._bucket_store[group_id] + comm_tracker.add_comm_meta(CommType.DP_ALLREDUCE, current_bucket.get_dp_parallel_mode(), can_overlap=False) + self._reduce_grads_stored_in_bucket(current_bucket, reduce_rank=None) # wait grads reduced and clear reduced grads for bucket in self._bucket_in_progress: diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py index 9279c1138..4a27d9700 100644 --- a/internlm/solver/optimizer/utils.py +++ b/internlm/solver/optimizer/utils.py @@ -35,6 +35,9 @@ APEX_AVAILABLE = False inf = math.inf +import os + +fake_mode = "fake_mode" in os.environ def flatten(input_): @@ -189,7 +192,7 @@ def multi_tensor_l2norm_torch(tensor_list, per_tensor): def calc_l2_norm(grads): norm = 0.0 if len(grads) > 0: - if APEX_AVAILABLE: + if APEX_AVAILABLE and not fake_mode: dummy_overflow_buf = torch.tensor([0], device=get_current_device(), dtype=torch.int32) norm, _ = multi_tensor_applier( amp_C.multi_tensor_l2norm, @@ -228,13 +231,13 @@ def reduce_grads(gradients, parameters, weight_parallel_mode): if ( gpc.is_initialized(ParallelMode.PIPELINE) and hasattr(p, "pipeline_shared_module_pg") - and dist.get_rank(p.pipeline_shared_module_pg) == 0 + # and dist.get_rank(p.pipeline_shared_module_pg) == 0 ): # if shared between different pipe, only count o parallel_grads.append(g.data.float()) elif ( gpc.is_initialized(ParallelMode.PIPELINE) and hasattr(p, "pipeline_shared_module_pg") - and dist.get_rank(p.pipeline_shared_module_pg) != 0 + # and dist.get_rank(p.pipeline_shared_module_pg) != 0 ): continue elif ( @@ -355,7 +358,10 @@ def compute_norm(gradients, parameters, norm_type=2, zero_mode=ParallelMode.ZERO dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(zero_mode)) if torch.is_tensor(total_norm): - total_norm = total_norm.item() + if fake_mode: + total_norm = 0.1 + else: + total_norm = total_norm.item() # Need to allreduce(avg) the norms across different ranks because moe params will not be synced during allreduce # model and zero have been reduced!!! diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index a9c4a5d0c..9e947a78e 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -2,6 +2,7 @@ # -*- encoding: utf-8 -*- import math +import os import time from typing import Callable, Iterable, List, Optional, Tuple, TypeVar, Union @@ -461,6 +462,8 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None): f"wp{gpc.get_local_rank(ParallelMode.WEIGHT)}_" f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}" ) + os.makedirs(trace_path, exist_ok=True) + if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU: experimental_config = torch_npu.profiler._ExperimentalConfig( aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, diff --git a/internlm/utils/common.py b/internlm/utils/common.py index df4583d43..9a02be1aa 100644 --- a/internlm/utils/common.py +++ b/internlm/utils/common.py @@ -21,6 +21,15 @@ logger = get_logger(__file__) internlm_accelerator = get_accelerator() +_INTERNEVO_ARGS = None + + +def get_args(): + global _INTERNEVO_ARGS + if _INTERNEVO_ARGS is None: + _INTERNEVO_ARGS = parse_args() + return _INTERNEVO_ARGS + def parse_args(): parser = internlm.get_default_parser() diff --git a/simulation_train.py b/simulation_train.py index fe4b66ac4..dd145a82c 100644 --- a/simulation_train.py +++ b/simulation_train.py @@ -14,10 +14,11 @@ from internlm.core.context import Config, ParallelMode from internlm.core.context import global_context as gpc from internlm.core.context.random import reset_seed +from internlm.core.parallel.shard import cluster_load_balance from internlm.core.trainer import TrainState from internlm.initialize.launch import launch from internlm.model.losses import FlashGPTLMLoss -from internlm.simulator.common import AlgoType, CommOp +from internlm.simulator.common import AlgoType, CostType, cal_model_p_elem from internlm.simulator.tracker.comm_tracker import get_gloabl_comm_tracker from internlm.simulator.tracker.mem_tracker import get_global_allocator @@ -52,7 +53,11 @@ def wait(self): def dummy_broadcast(tensor, src, group=None, async_op=False): global_comm_tracker.cal_comm_cost( - comm_op=CommOp.BROADCAST, comm_volume=tensor.numel() * tensor.element_size(), dtype=tensor.dtype + comm_op=CostType.BROADCAST, + group=group, + src=src, + comm_volume=tensor.numel() * tensor.element_size(), + dtype=tensor.dtype, ) if async_op is True: return WaitHandler() @@ -60,28 +65,43 @@ def dummy_broadcast(tensor, src, group=None, async_op=False): def dummy_allreduce(tensor, op, group=None, async_op=False): global_comm_tracker.cal_comm_cost( - comm_op=CommOp.ALLREDUCE, comm_volume=tensor.numel() * tensor.element_size(), dtype=tensor.dtype + comm_op=CostType.ALLREDUCE, group=group, comm_volume=tensor.numel() * tensor.element_size(), dtype=tensor.dtype ) if async_op is True: return WaitHandler() def dummy_allgahter(tensor_list, tensor, group=None, async_op=False): + tensor = torch.concat(tensor_list).view(-1) + + global_comm_tracker.cal_comm_cost( + comm_op=CostType.ALLGATHER, group=group, comm_volume=tensor.numel() * tensor.element_size(), dtype=tensor.dtype + ) if async_op is True: return WaitHandler() def dummy_reduce_scatter(output, input_list, op, group=None, async_op=False): - if async_op is True: - return WaitHandler() + tensor = torch.concat(input_list).view(-1) + global_comm_tracker.cal_comm_cost( + comm_op=CostType.REDUCESCATTER, + group=group, + comm_volume=tensor.numel() * tensor.element_size(), + dtype=tensor.dtype, + ) -def dummy_reduce_scatter(output, input_list, op, group=None, async_op=False): if async_op is True: return WaitHandler() def dummy_all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False): + tensor = torch.concat(input_tensor_list).view(-1) + + global_comm_tracker.cal_comm_cost( + comm_op=CostType.ALL2ALL, group=group, comm_volume=tensor.numel() * tensor.element_size(), dtype=tensor.dtype + ) + if async_op is True: return WaitHandler() @@ -112,6 +132,55 @@ def dummy_barrier(group=None, async_op=False, device_ids=None): dist.barrier = dummy_barrier +def cal_C(layer_nums=1, has_input_embeding=False, has_output_embeding=False, use_fp16=True): + """ + Do the calculation of total flops required for a single one token under this parameter size, + without considering model parallelism other than PP. (TODO: support MoE?) + """ + if use_fp16: + element_size = 2 + + fp32_element_size = 4 + + S, V, H = gpc.config.data.seq_len, gpc.config.model.vocab_size, gpc.config.model.hidden_size + mlp_ratio = gpc.config.MLP_RATIO + + q_head_nums = gpc.config.model.num_attention_heads + kv_head_nums = gpc.config.model.num_kv_attention_heads + + q_dim = H + kv_dim = (H // q_head_nums) * kv_head_nums + + hidden_features = int(H * mlp_ratio) + + EMBEDING, HEAD, LOSS = 0, 0, 0 + if has_input_embeding: # vocab, H + EMBEDING = element_size * V * H + + if has_output_embeding: # cross entropy + HEAD = element_size * V * H + LOSS = fp32_element_size * V * H # fp32 + + QKV = element_size * H * (q_dim + 2 * kv_dim) + + ATTN = 2 * 1 * S * kv_dim # S * S * H -> 1 * S * H + + ATTN_OUT = 2 * (q_dim + 2 * kv_dim) * H + + w1 = H * hidden_features + w3 = H * hidden_features + w2 = hidden_features * H + + MLP = w1 + w2 + w3 + + # 2 * H * (H + 2 * H * V_head // Q_head) + 2 * S * kv_dim + 3 * H * H_MLP + LAYER = QKV + ATTN + ATTN_OUT + MLP + + LAYER_ALL = LAYER * layer_nums + + return LAYER_ALL + EMBEDING + HEAD + LOSS + + def main(args): very_begining_time = time.time() enable_pytorch_expandable_segments() @@ -170,7 +239,6 @@ def main(args): }, torch.tensor(micro_num * [list(range(micro_bsz * S))], dtype=torch.int64), ] - print(batch) with initialize_llm_profile(profiling=True, start_time=launch_time()) as prof: for batch_count in range(train_state.batch_count, total_steps): s = time.time() @@ -202,9 +270,6 @@ def main(args): trainer_result = trainer.step() print(f"ont step use time: {time.time() -s :.3f} s", flush=True) prof.step() - import pdb - - pdb.set_trace() def run_loop( @@ -245,7 +310,8 @@ def run_loop( if debug: print( f"NO solu: pp:{pp} , sp:{sp} can't find micro_bsz/micro_num for" - f"world_size:{world_size}, seq_len:{S}, global bsz range: [{global_bsz_min}-{global_bsz_max}]!", + f"world_size:{world_size}, seq_len:{S}, global bsz range: \ +[{global_bsz_min}-{global_bsz_max}]!", flush=True, ) continue @@ -260,11 +326,14 @@ def run_loop( if wp > 1: if debug: print("NO solu: msp, fsp not support wp>1 !", flush=True) - continue # msp, fsp禁掉fsdp,我们目前还不支持 - # zp的搜索空间是被wp限制的,同时他不是按照8的倍数变化的,是,1,2,3, ...这样递增的 - zp_search_range = world_size // pp // sp // wp # 这里的sp对于msp和fsp来说是tp + continue + # Zp's search space is constrained by Wp, and it does not change in multiples of 8; + # instead, it increments as 1, 2, 3, .. + zp_search_range = world_size // pp // sp // wp # Here, sp for msp and fsp is tp. else: - zp_search_range = world_size // pp // wp # internlm实现的zp和deepspeed不一样,zp是在切wp的基础上再切的 + # The implementation of zp in InternEvo is different from DeepSpeed. + # Zp is further partitioned on the basis of wp + zp_search_range = world_size // pp // wp try: assert H % sp == 0, f"embed_dim:{H} must be divisible by sp: {sp}" @@ -278,9 +347,12 @@ def run_loop( for zp_i, zp in enumerate(range(1, zp_search_range + 1)): # set config print( - f"activation_ckpt: {activation_ckpt}, micro_num: {micro_num}, micro_bsz: {micro_bsz}, pp: {pp}, wp: {wp}, zp: {zp}, sp: {sp}, {str(algo_type)}", + f"activation_ckpt: {activation_ckpt}, micro_num: {micro_num}, \ +micro_bsz: {micro_bsz}, pp: {pp}, wp: {wp}, zp: {zp}, sp: {sp}, {str(algo_type)}", flush=True, ) + + gpc.destroy() gpc.config.model["checkpoint"] = activation_ckpt gpc.config.parallel["zero1"]["size"] = zp gpc.config.parallel["tensor"]["size"] = sp @@ -291,7 +363,6 @@ def run_loop( gpc.config.data["micro_num"] = micro_num gpc.config.data["micro_bsz"] = micro_bsz - gpc.destroy() reset_seed() launch( @@ -303,7 +374,7 @@ def run_loop( port=12345, backend="nccl", seed=0, - fake_mode=fake_mode, + fake_mode=True, ) args_sanity_check() assert hasattr(gpc, "config") and gpc.config is not None @@ -312,17 +383,56 @@ def run_loop( main(args) +def run_single(activation_ckpt=False, zp=1, sp=1, algo_type="fsp", pp=1, wp=1, global_bsz=4 * 1024 * 1024): + gpc.load_config(config=Config.from_file(args.config)) + gpc.set_fake_mode(True) + + gpc.config.model["checkpoint"] = activation_ckpt + gpc.config.parallel["zero1"]["size"] = zp + gpc.config.parallel["tensor"]["size"] = sp + gpc.config.parallel["tensor"]["mode"] = str(algo_type) + gpc.config.parallel["pipeline"]["size"] = pp + gpc.config.parallel["weight"]["size"] = wp + + gpc.config.data["global_bsz"] = global_bsz + # gpc.config.data["micro_num"] = micro_num + gpc.config.data["micro_bsz"] = 1 + + gpc.destroy() + reset_seed() + + launch( + config=gpc.config, + local_rank=0, + rank=0, + world_size=world_size, + host="127.0.0.1", + port=12345, + backend="nccl", + seed=0, + fake_mode=True, + ) + args_sanity_check() + assert hasattr(gpc, "config") and gpc.config is not None + + cluster_load_balance() + + # with FakeTensorMode(): + # main(args) + + if __name__ == "__main__": args = parse_args() hostname = socket.gethostname() world_size = args.world_size - fake_mode = "fake_mode" in os.environ - + # fake_mode = True # "fake_mode" in os.environ + os.environ["fake_mode"] = "1" # initialize distributed environment - print(f"fake_mode: {fake_mode}", flush=True) + print(f"fake_mode !!", flush=True) gloab_allocator.init_capcity = 80 * 1024**3 gloab_allocator.capcity = 80 * 1024**3 - run_loop(global_bsz=4096 * 1024, world_size=world_size, args=args) + # run_loop(global_bsz=4096 * 1024, world_size=world_size, args=args) + run_single() diff --git a/simulation_train_formulaic.py b/simulation_train_formulaic.py new file mode 100644 index 000000000..858bd05d5 --- /dev/null +++ b/simulation_train_formulaic.py @@ -0,0 +1,691 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import logging +import os +import socket + +import torch + +# from internlm.core.context.parallel_context import reset_global_context +from internlm.core.context import Config, ParallelMode +from internlm.core.context import global_context as gpc +from internlm.core.context.random import reset_seed +from internlm.core.parallel.shard import cluster_load_balance, partition_uniform +from internlm.initialize.launch import args_sanity_check, launch +from internlm.simulator.common import AlgoType, cal_block_p_elem, cal_model_p_elem + +# from internlm.simulator.formulas.context import ParallelMode, check_and_modify_parallel_config +# from internlm.simulator.formulas.context import global_context as gpc +from internlm.simulator.formulas.mem import ( + get_backward_mem_peak, + get_block_output_mm, + get_block_threshold, + get_head_input_mm, + get_head_output_mm, + get_memory_pool_mm, + get_norm_output_mm, + get_p2p_buffer_size, + get_rotary_emb_sincos_cache_mm, +) +from internlm.simulator.formulas.overlap import TransformerOverlapOneLayer +from internlm.simulator.profiler.perf_comm import ( + allreduce, + broadcast, + comm_matrix_dict, + init_cost_model, + p2p, +) +from internlm.simulator.tracker.comm_tracker import get_gloabl_comm_tracker +from internlm.simulator.tracker.mem_tracker import get_global_allocator + +# from internlm.simulator.elements.tensor import FakeTensor +from internlm.simulator.utils import ( + LinsSolutionNoZ3, + PPIter, + SPIter, + get_bsz_approximate, + get_bsz_strict, +) +from internlm.utils.common import get_args, parse_args + +# global llm logger +logger = logging.getLogger(__file__) + + +gloab_allocator = get_global_allocator() +global_comm_tracker = get_gloabl_comm_tracker() + + +def comm_dp_cost(dtype_size, algo, pp_blocks_elem, embedding_elem, zp) -> float: + """The communication overhead introduced by partitioning OS for parameter synchronization""" + # The communication overhead introduced by Wdp, where the input is the communication + # volume from a DP rank perspective. + # The parameters of MSP and FSP are partitioned by TP and need to be divided by sp_size. + # The parameters of ISP are partitioned by WP and need to be divided by wp_size. + if algo in [AlgoType.MSP, AlgoType.FSP]: + # gradient sync + wdp_latency = allreduce(dtype_size * (pp_blocks_elem + embedding_elem), ParallelMode.DATA) + + # parameter sync + # zp_latency = zp * broadcast(dtype_size * (pp_blocks_elem + embedding_elem) / zp, + # ParallelMode.ZERO1, comm_nums=zp) + zp_latency = zp * broadcast(dtype_size * pp_blocks_elem / zp, ParallelMode.ZERO1, comm_nums=zp) + + elif algo == AlgoType.ISP: + # gradient sync + wdp_block_latency = allreduce(dtype_size * pp_blocks_elem, ParallelMode.WEIGHT_DATA) + wdp_embedding_latency = allreduce(dtype_size * embedding_elem, ParallelMode.DATA) + wdp_latency = wdp_block_latency + wdp_embedding_latency + + # parameter sync + block_zp_latency = zp * broadcast(dtype_size * pp_blocks_elem / zp, ParallelMode.ZERO1, comm_nums=zp) + embedding_zp_latency = broadcast(dtype_size * embedding_elem, ParallelMode.DATA) + zp_latency = max(block_zp_latency, embedding_zp_latency) + + return zp_latency, wdp_latency + + +def pp_comm_overhead(dtype_size, seq_len, hidden_size, pp_size, sp_size, micro_bsz, micro_num): + """Calculate the latency of P2P communication in PP.""" + if pp_size == 1: + return 0 + + p2p_buffer_size = get_p2p_buffer_size(dtype_size, seq_len, sp_size, micro_bsz, hidden_size) + + warmup_p2p_num = min(pp_size, micro_num) + one_f_one_b_p2p_num = micro_num - 1 + cooldown_p2p_num = min(pp_size, micro_num) + + p2p_count = warmup_p2p_num + one_f_one_b_p2p_num + cooldown_p2p_num + p2p_latency = p2p_count * p2p(p2p_buffer_size, ParallelMode.PIPELINE, comm_nums=p2p_count) + return p2p_latency + + +def cal_cost( + pp, + sp, + wp, + zp, + micro_bsz, + micro_num, + algo_type, + world_size, + activation_ckpt, + pp_num_layers=None, + max_pp_num_layers=None, + debug=True, + overlap_wdp=True, +) -> LinsSolutionNoZ3: + if pp_num_layers is None or max_pp_num_layers is None: + max_pp_num_layers = 0 + parts = partition_uniform(gpc.config.model.num_layers, pipeline_parallel_size=pp, num_chunks=1) + for part in parts: + start, end = part[0] + num_layer = end - start + if num_layer > max_pp_num_layers: + max_pp_num_layers = num_layer + # max_pp_rank = pp_rank + + pp_num_layers = max_pp_num_layers + + assert max_pp_num_layers > 0 + + # Anti-fragmentation penalty + # if algo_type in [AlgoType.MSP, AlgoType.FSP]: + # if sp * zp * wp * pp < (gpc.config.model_size / 1.5): + # if debug: + # print(f"NO solu: skip sp*zp*wp*pp< 4 solu!\n", flush=True) + # return None + # else: + # if zp * wp * pp < (gpc.config.model_size / 1.5): + # if debug: + # print(f"NO solu: skip zp*wp*pp< 4 solu!\n", flush=True) + # return None + + now_global_bsz = micro_bsz * micro_num * gpc.config.data.seq_len * gpc.get_world_size(ParallelMode.DATA) + + dp = gpc.get_world_size(ParallelMode.DATA) + one_layer_elem = cal_block_p_elem( + gpc.config.model.hidden_size, + q_head=gpc.config.model.num_attention_heads, + kv_head=gpc.config.model.num_kv_attention_heads, + multiple_of=gpc.config.model.multiple_of, + mlp_ratio=gpc.config.model.mlp_ratio, + ) + + print(f"pp_num_layers: {pp_num_layers}, one_layer_elem: {one_layer_elem}", flush=True) + pp_blocks_elem = pp_num_layers * one_layer_elem + embedding_dp_shared_range = 1 if dp <= 1 else 2 + head_num = 1 if pp > 1 else 2 + embedding_elem = gpc.config.model.vocab_size * gpc.config.model.hidden_size + + if algo_type in [AlgoType.MSP, AlgoType.FSP]: + embedding_elem_parallel = head_num * embedding_elem / wp / sp + block_elem_parallel = pp_blocks_elem / wp / sp + total_p_element = block_elem_parallel + embedding_elem_parallel + total_os_element = total_p_element / zp + os_mm_cost = gpc.config.dtype_size * gpc.config.fp32_ratio * 3 * total_os_element # zp显存消耗 + p_g_mm_cost = 2 * gpc.config.dtype_size * total_p_element # wp显存消耗 + else: + embedding_elem_parallel = head_num * embedding_elem / sp + block_elem_parallel = pp_blocks_elem / wp + total_p_element = block_elem_parallel + embedding_elem_parallel + total_os_element = ( + block_elem_parallel / zp + embedding_elem_parallel / embedding_dp_shared_range + ) # embeding不会被zp切 + os_mm_cost = gpc.config.dtype_size * gpc.config.fp32_ratio * 3 * total_os_element # zp显存消耗 + p_g_mm_cost = 2 * gpc.config.dtype_size * total_p_element # wp显存消耗 + + zp_comm_cost, wdp_comm_cost = comm_dp_cost( + dtype_size=gpc.config.dtype_size, + algo=algo_type, + pp_blocks_elem=block_elem_parallel, + embedding_elem=embedding_elem_parallel, + zp=zp, + ) # 计算dp相关的通信开销 + + # zp_comm_cost=0 + if overlap_wdp: + wdp_comm_cost = 0 + + blocks_activation = get_block_threshold( + algo=algo_type, + micro_batch_size=micro_bsz, + layer_num=gpc.config.model.num_layers, # 显存阈值根据pp0来计算 + sp_size=sp, + activation_ckpt=activation_ckpt, + hidden_dim=gpc.config.model.hidden_size, + sequence_length=gpc.config.data.seq_len, # 这里一定要传入没切过的seqlen + use_fa=gpc.config.model.use_flash_attn, + head_num=gpc.config.model.num_attention_heads, + dtype_size=gpc.config.dtype_size // 2, # dtype_size要除以2,因为激活值计算公式是默认按照fp16类型来的 + ) # isp激活的话,不需要除以wp,因为需要allgather + + if algo_type == AlgoType.ISP: + isp_mem_pool = get_memory_pool_mm( + gpc.config.model.mlp_ratio, gpc.config.model.hidden_size, gpc.config.dtype_size + ) + else: + isp_mem_pool = 0 + + pp_p2p_buffer = ( + get_p2p_buffer_size(gpc.config.dtype_size, gpc.config.data.seq_len, sp, micro_bsz, gpc.config.model.hidden_size) + if pp > 1 + else 0 + ) + + # 下面这些激活的计算不受到重计算的影响 + norm_activation = get_norm_output_mm( + micro_bsz, gpc.config.data.seq_len, gpc.config.model.hidden_size, sp=sp, dtype_size=gpc.config.dtype_size + ) + + head_input_activation = get_head_input_mm( + micro_bsz, + gpc.config.data.seq_len, + gpc.config.model.hidden_size, + dtype_size=gpc.config.dtype_size, + tp_size=sp, + algo=algo_type, + ) + head_output_activation = get_head_output_mm( + micro_bsz, gpc.config.data.seq_len, gpc.config.model.vocab_size, dtype_size=gpc.config.dtype_size + ) + rotary_emb_sincos_cache_mm = get_rotary_emb_sincos_cache_mm( + seq_len=gpc.config.data.seq_len, + pp_size=pp, + hidden_dim=gpc.config.model.hidden_size, + head_nums=gpc.config.model.num_attention_heads, + layer_nums=gpc.config.model.num_layers, + dtype_size=gpc.config.dtype_size, + ) + # 对于pp0,占用的激活仍然是 layer_num 份 + block_output_activation = ( + gpc.config.model.num_layers + * get_block_output_mm( + micro_bsz, gpc.config.data.seq_len, gpc.config.model.hidden_size, sp=sp, dtype_size=gpc.config.dtype_size + ) + ) * activation_ckpt # 只有开启重计算才需要额外加上这部分block激活的输出 + backward_mem_peak = get_backward_mem_peak( + seq_len=gpc.config.data.seq_len, + micro_bsz=micro_bsz, + dtype_size=gpc.config.dtype_size, + vocab_size=gpc.config.model.vocab_size, + tp_size=sp, + hidden_size=gpc.config.model.hidden_size, + ) + activation = ( + blocks_activation + + norm_activation + + head_input_activation + + head_output_activation + + block_output_activation + + backward_mem_peak + ) + + # 总显存开销 + mem_cost1 = ( + p_g_mm_cost + os_mm_cost + activation + isp_mem_pool + rotary_emb_sincos_cache_mm + pp_p2p_buffer + ) # fwd_bwd显存峰值(需要加上Grad吗?) + mem_cost2 = p_g_mm_cost + os_mm_cost / 3 * 5 # adamw的显存峰值 + mem_cost = max(mem_cost1, mem_cost2) + if mem_cost > gpc.config.mem_threshold: + # A[pp_i][sp_i][wp_i][zp_i] = _100GB + # C[pp_i][sp_i][wp_i][zp_i] = 0 + if debug: + print( + f"NO solu: mem_cost: {mem_cost/1024**3:.2f} GB > mem_threshold: \ +{gpc.config.mem_threshold/1024**3:.2f} GB ---- p_g_mm_cost: {p_g_mm_cost/1024**3:.2f} GB, \ +os_mm_cost: {os_mm_cost/1024**3:.2f} GB, activation: {activation/1024**3:.2f} GB\n", + flush=True, + ) + return None + # else: + # A[pp_i][sp_i][wp_i][zp_i] = mem_cost + + try: + (wp_comm_cost, sp_comm_cost, comp_wp, comp_attn,) = TransformerOverlapOneLayer( + micro_bsz=micro_bsz, + sp_size=sp, + pp_size=pp, + world_size=world_size, + ckpt=activation_ckpt, + seq_len=gpc.config.data.seq_len, # 这里需要传原始的seqlen,因为这个类里面还会切sp + vocab_size=gpc.config.model.vocab_size, + dtype_size=gpc.config.dtype_size, + hidden_dim=gpc.config.model.hidden_size, + num_head=gpc.config.model.num_attention_heads, + num_kv_head=gpc.config.model.num_kv_attention_heads, + mlp_ratio=gpc.config.model.mlp_ratio, + multiple_of=gpc.config.model.multiple_of, + )._get_overlap(algo_type) + except KeyError as e: + print(f"not found FA key: {e}", flush=True) + return None + + if wp > 1: + overlap_latency = min(comp_wp, wp_comm_cost) * gpc.config.wp_penalty_coefficient + max(comp_wp, wp_comm_cost) + else: + overlap_latency = comp_wp + + def overlaped_fwd_bwd_cost(): + return overlap_latency + sp_comm_cost + comp_attn + + if pp == 1: + fwd_bwd_cost = gpc.config.model.num_layers * overlaped_fwd_bwd_cost() + grad_acc = micro_num + all_fwd_bwd_cost = grad_acc * fwd_bwd_cost # 算上梯度累积的fwdbwd开销 + pp_comm_cost = 0 + else: + # 注意这里要使用 max_pp_num_layers 来计算pp的延迟,而不是pp0的 num layer + fwd_bwd_cost = max_pp_num_layers * overlaped_fwd_bwd_cost() # 1个pp micro step的fwd_bwd开销 + all_fwd_bwd_cost = micro_num * fwd_bwd_cost # pp的idea开销(不含bubble) + pp_p2p_cost = pp_comm_overhead( + dtype_size=gpc.config.dtype_size, + seq_len=gpc.config.data.seq_len, + hidden_size=gpc.config.model.hidden_size, + pp_size=pp, + sp_size=sp, + micro_bsz=micro_bsz, + micro_num=micro_num, + ) # pp的p2p延迟 + pp_bubble_cost = (pp - 1) * fwd_bwd_cost # pp的bubble开销 + pp_comm_cost = pp_p2p_cost + pp_bubble_cost # pp总的额外开销 + + total_latency = all_fwd_bwd_cost + pp_comm_cost + wdp_comm_cost + zp_comm_cost # fwd_bwd_cost 乘上梯度累加 + + # 计算tgs,为了方便取max这里乘了一个-1 + tgs = (-1 * now_global_bsz) / (world_size * total_latency) + + solu = LinsSolutionNoZ3( + pp=pp, + sp=sp, + wp=wp, + zp=zp, + seq_len=gpc.config.data.seq_len, + micro_bsz=micro_bsz, + micro_num=micro_num, + algo_type=algo_type, + pp_comm_cost=pp_comm_cost, + activation=activation, + zp_comm_cost=zp_comm_cost, + wp_comm_cost=wp_comm_cost, + sp_comm_cost=sp_comm_cost, + os_mm_cost=os_mm_cost, + p_g_mm_cost=p_g_mm_cost, + fwd_bwd_cost=fwd_bwd_cost, + mem_cost=mem_cost, + comp_wp=comp_wp, + comp_attn=comp_attn, + world_size=world_size, + activation_ckpt=activation_ckpt, + tgs=-1 * tgs, + mem_pool_mm=isp_mem_pool, + norm_activation=norm_activation, + head_input_activation=head_input_activation, + head_output_activation=head_output_activation, + block_output_activation=block_output_activation, + wdp_comm_cost=wdp_comm_cost, + all_fwd_bwd_cost=all_fwd_bwd_cost, + g_bsz=now_global_bsz, + pp_p2p_buffer=pp_p2p_buffer, + rotary_emb_sincos_cache_mm=rotary_emb_sincos_cache_mm, + modelsize=gpc.config.param_elements / 10**9, + backward_mem_peak=backward_mem_peak, + blocks_activation=blocks_activation, + overlap_latency=overlap_latency, + total_latency=total_latency, + ) + + gpc.destroy() # 销毁device mesh + return solu + + +def run_loop( + global_bsz, + world_size, + args, + use_fixed_micro_bsz=False, + use_strict_bsz=True, + global_bsz_max=1, + global_bsz_min=1, + debug=True, +): + gpc.load_config(config=Config.from_file(args.config)) + gpc.set_fake_mode(True) + + min_comm_cost, msp_min_cost, fsp_min_cost, isp_min_cost = ( + float("inf"), + float("inf"), + float("inf"), + float("inf"), + ) + min_cost_solution, msp_min_solu, fsp_min_solu, isp_min_solu = None, None, None, None + + L = gpc.config.model["num_layers"] + KV_H = gpc.config.model["num_kv_attention_heads"] + S = gpc.config.data["seq_len"] + H = gpc.config.model["hidden_size"] + MICRO_BSZ = gpc.config.data["micro_bsz"] + MICRO_NUM = gpc.config.data["micro_num"] + + pp_search_range = PPIter(world_size, L) + sp_search_range = SPIter(world_size, KV_H) + wp_search_ranges = SPIter(world_size, world_size) + # zp_search_ranges_max = SPIter(world_size, world_size) + solutions_list = [] + algo_list = [AlgoType.ISP, AlgoType.MSP, AlgoType.FSP] + + gpc.config["param_elements"] = cal_model_p_elem( + h=gpc.config.model.hidden_size, + q_head=gpc.config.model.num_attention_heads, + kv_head=gpc.config.model.num_kv_attention_heads, + l=gpc.config.model.num_layers, + vocab_size=gpc.config.model.vocab_size, + mlp_ratio=gpc.config.model.mlp_ratio, + multiple_of=gpc.config.model.multiple_of, + ) + print(f"param_elements: {gpc.config['param_elements']}", flush=TimeoutError) + + for _, pp in enumerate(pp_search_range): + for _, sp in enumerate(sp_search_range): + if not use_fixed_micro_bsz: + if use_strict_bsz: + bs_bns = get_bsz_strict(global_bsz, world_size, pp, sp, S) + else: + bs_bns = get_bsz_approximate(global_bsz_max, global_bsz_min, world_size, pp, sp, S) + + if bs_bns is None or len(bs_bns) == 0: + if debug: + print( + f"NO solu: pp:{pp} , sp:{sp} can't find micro_bsz/micro_num for" + f"world_size:{world_size}, seq_len:{S}, \ +global bsz range: [{global_bsz_min}-{global_bsz_max}]!", + flush=True, + ) + continue + else: + bs_bns = [(MICRO_BSZ, MICRO_NUM)] + + for micro_bsz, micro_num in bs_bns: + for algo_type in algo_list: + for activation_ckpt in [0, 1]: + for _, wp in enumerate(wp_search_ranges): + if algo_type in [AlgoType.MSP, AlgoType.FSP]: + if wp > 1: + if debug: + print("NO solu: msp, fsp not support wp>1 !", flush=True) + continue + # Zp's search space is constrained by Wp, and it does not change in + # multiples of 8; instead, it increments as 1, 2, 3, .. + # Here, sp for msp and fsp is tp. + zp_search_range = world_size // pp // sp // wp + else: + # The implementation of zp in InternEvo is different from DeepSpeed. + # Zp is further partitioned on the basis of wp + zp_search_range = world_size // pp // wp + try: + assert H % sp == 0, f"embed_dim:{H} must be divisible by sp: {sp}" + assert KV_H % sp == 0, f"num_heads: {KV_H} must be divisible by sp: {sp}" + assert KV_H >= sp, f"num_heads: {KV_H} must bigger then sp: {sp}" + if algo_type != AlgoType.ISP: + assert ( + wp * sp * pp * zp_search_range <= world_size + ), f"{algo_type} not support wp and sp share same pg group." + except AssertionError as e: + if debug: + print(f"NO solu: head assert {e}", flush=True) + continue + + for _, zp in enumerate(range(1, zp_search_range + 1)): + # set config + print( + f"activation_ckpt: {activation_ckpt}, micro_num: {micro_num}, \ +micro_bsz: {micro_bsz}, pp: {pp}, wp: {wp}, zp: {zp}, sp: {sp}, {str(algo_type)}", + flush=True, + ) + + # reset_global_context() + + gpc.destroy() + gpc.config.model["checkpoint"] = activation_ckpt + gpc.config.parallel["zero1"]["size"] = zp + gpc.config.parallel["tensor"]["size"] = sp + gpc.config.parallel["tensor"]["mode"] = str(algo_type) + gpc.config.parallel["pipeline"]["size"] = pp + gpc.config.parallel["weight"]["size"] = wp + gpc.config.model_size = 7 + + gpc.config.data["micro_num"] = micro_num + gpc.config.data["micro_bsz"] = micro_bsz + + gpc.config["mem_threshold"] = 80 * 1024**3 + gpc.config["wp_penalty_coefficient"] = 0.1 + gpc.config["dtype_size"] = 2 + gpc.config["fp32_ratio"] = 2 + + reset_seed() + + try: + launch( + config=gpc.config, + local_rank=0, + rank=0, + world_size=world_size, + host="127.0.0.1", + port=12345, + backend="nccl", + seed=0, + fake_mode=True, + ) + args_sanity_check() + assert hasattr(gpc, "config") and gpc.config is not None + except AssertionError as e: + if debug: + print(f"NO solu: build gpc failed: {e}\n", flush=True) + continue + except ZeroDivisionError as e: + if debug: + print(f"NO solu: build gpc failed: {e}\n", flush=True) + continue + + solu = cal_cost( + pp=pp, + sp=sp, + wp=wp, + zp=zp, + micro_bsz=micro_bsz, + micro_num=micro_num, + algo_type=algo_type, + world_size=world_size, + activation_ckpt=activation_ckpt, + ) + if solu is None: + continue + cost = solu.tgs + solutions_list.append(solu) + if cost < min_comm_cost: + min_comm_cost = cost + min_cost_solution = solu + + print(f"solu: {solu}", flush=True) + + if algo_type == AlgoType.MSP: + if cost < msp_min_cost: + msp_min_cost = cost + msp_min_solu = solu + elif algo_type == AlgoType.FSP: + if cost < fsp_min_cost: + fsp_min_cost = cost + fsp_min_solu = solu + elif algo_type == AlgoType.ISP: + if cost < isp_min_cost: + isp_min_cost = cost + isp_min_solu = solu + + return solutions_list, min_comm_cost, min_cost_solution, msp_min_solu, fsp_min_solu, isp_min_solu + + +def run_warrper(global_bsz, world_size, args): + solutions_list, min_comm_cost, min_cost_solution, msp_min_solu, fsp_min_solu, isp_min_solu = run_loop( + global_bsz=global_bsz, world_size=world_size, args=args + ) + + if min_cost_solution is not None: + solutions_list = sorted(solutions_list, key=lambda solu: solu.tgs, reverse=True) + print("--------------------- END -----------------------", flush=True) + # print("Max TGS:", min_comm_cost * -1) + for i, solu in enumerate(solutions_list): + if i > 5: + break + print(f"Top{i} Solution:", solu, flush=True) + + print("--------------------- MSP best solution -----------------------", flush=True) + if msp_min_solu is not None: + print(f"self.msp_min_solu : {msp_min_solu}") + print("--------------------- FSP best solution -----------------------", flush=True) + if fsp_min_solu is not None: + print(f"self.fsp_min_solu : {fsp_min_solu}") + print("--------------------- ISP best solution -----------------------", flush=True) + if isp_min_solu is not None: + print(f"self.isp_min_solu : {isp_min_solu}") + + final_res = { + "algo_type": min_cost_solution.algo_type, + "seq_len": min_cost_solution.seq_len, + "micro_num": min_cost_solution.micro_num, + "micro_bsz": min_cost_solution.micro_bsz, + "pp_size": min_cost_solution.pp, + "tp_size": min_cost_solution.sp, + "wp_size": min_cost_solution.wp_size, + "zp_size": min_cost_solution.zp_size, + "activation_ckpt": bool(min_cost_solution.activation_ckpt), + } + print(final_res) + else: + print("No solution found") + + +def run_single(global_bsz=4 * 1024 * 1024): + gpc.load_config(config=Config.from_file(args.config)) + gpc.set_fake_mode(True) + print(f"gpc.config.parallel: {gpc.config.parallel}") + + gpc.config.data["global_bsz"] = global_bsz + gpc.config.model_size = args.model_size + gpc.config["mem_threshold"] = 80 * 1024**3 + gpc.config["wp_penalty_coefficient"] = 0.1 + gpc.config["dtype_size"] = 2 + gpc.config["fp32_ratio"] = 2 + gpc.config["param_elements"] = cal_model_p_elem( + h=gpc.config.model.hidden_size, + q_head=gpc.config.model.num_attention_heads, + kv_head=gpc.config.model.num_kv_attention_heads, + l=gpc.config.model.num_layers, + vocab_size=gpc.config.model.vocab_size, + mlp_ratio=gpc.config.model.mlp_ratio, + multiple_of=gpc.config.model.multiple_of, + ) + + reset_seed() + + launch( + config=gpc.config, + local_rank=0, + rank=0, + world_size=world_size, + host="127.0.0.1", + port=12345, + backend="nccl", + seed=0, + fake_mode=True, + ) + args_sanity_check() + assert hasattr(gpc, "config") and gpc.config is not None + + # cluster_load_balance() + + solu = cal_cost( + pp=gpc.config.parallel["pipeline"]["size"], + sp=gpc.config.parallel["tensor"]["size"], + wp=gpc.config.parallel["weight"]["size"], + zp=gpc.config.parallel["zero1"]["size"], + micro_bsz=gpc.config.data["micro_bsz"], + micro_num=gpc.config.data["micro_num"], + algo_type=gpc.config.parallel["tensor"]["mode"], + world_size=world_size, + activation_ckpt=gpc.config.model["checkpoint"], + ) + + assert solu is not None + + print(f"solu: {solu}") + + # /mnt/inspurfs/wangguoteng.p/comm_matrix + name = f"internlm2_{gpc.config.model_size}B.pt" + pt = os.path.join(get_args().draw_heatmap_path, name) + + new_dict = {} + for name, mat in comm_matrix_dict.items(): + print(f"name: {name}, mat: {mat}", flush=True) + new_dict[str(name)] = mat + + with open(pt, "wb") as f: + torch.save(new_dict, f=f) + + +if __name__ == "__main__": + args = parse_args() + hostname = socket.gethostname() + world_size = args.world_size + + init_cost_model(get_args().pre_profiling_data_path) + + os.environ["fake_mode"] = "1" + gloab_allocator.init_capcity = 80 * 1024**3 + gloab_allocator.capcity = 80 * 1024**3 + + if get_args().run_all_solu: + run_warrper(4096 * 1024, world_size, args) + else: + run_single(get_args().global_batch_size) From bf8b40e60ac4adf698498c1ace617dd0b93561a9 Mon Sep 17 00:00:00 2001 From: wangguoteng <877825076@qq.com> Date: Thu, 12 Sep 2024 21:06:19 +0800 Subject: [PATCH 3/3] fix and update readme --- gen_profiler_data.py | 3 +- internlm/core/context/parallel_context.py | 112 ++----------- .../core/context/process_group_initializer.py | 4 + .../process_group_initializer_simplified.py | 2 - internlm/initialize/launch.py | 5 +- internlm/model/ops/linear.py | 4 - internlm/simulator/README.md | 61 ++++++++ .../profiler/benchmark/multi_head_attn.py | 147 ------------------ internlm/simulator/profiler/perf_comm.py | 50 +++--- internlm/simulator/profiler/profiler.py | 17 -- simulation_train_formulaic.py | 41 +++-- 11 files changed, 140 insertions(+), 306 deletions(-) create mode 100644 internlm/simulator/README.md diff --git a/gen_profiler_data.py b/gen_profiler_data.py index b22ed5b51..37e3b78a3 100644 --- a/gen_profiler_data.py +++ b/gen_profiler_data.py @@ -1,5 +1,4 @@ - from internlm.simulator.profiler.perf_comm import gen_perf if __name__ == "__main__": - gen_perf() \ No newline at end of file + gen_perf() diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index 8fa4b2fe2..f44449dcc 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -18,15 +18,13 @@ import torch.distributed as dist from internlm.accelerator import get_accelerator -from internlm.core.context.process_group_initializer_simplified import Initializer, ParallelMeta -from internlm.utils.common import SingletonMeta +from internlm.utils.common import SingletonMeta, get_args from internlm.utils.logger import get_logger from internlm.utils.timeout import LLM_NCCL_TIMEOUT from . import process_group_initializer as pgroup_initializer -from .process_group_initializer_simplified import ParallelMode +from .process_group_initializer import ParallelMode from .random import add_seed, get_seeds, set_mode -from internlm.utils.common import get_args IS_REPLICA_ZERO_PARALLEL = "is_replica_zero_parallel" # for isp, with optimizer split in dp group @@ -422,20 +420,6 @@ def init_global_dist( use_cpu (bool): whether to set up cpu process group. """ - # find cluster info - if "clusters" not in self.config: - nv_info = { - "rank_range": [0, 8], - "peak_tflops": 320, - "capacity": 80 * 1024**3, - "intra_bw": 150, - "inter_bw": 100, - } - self.set_cluster_info("nv_cluster", nv_info) - else: - for cluster in self.config.clusters: - self.clusters.append(ClusterInfo(**cluster)) - # initialize the default process group if not fake_mode: init_method = f"tcp://[{host}]:{port}" @@ -576,8 +560,7 @@ def init_parallel_groups(self, fake_mode: bool = False): self._set_parallel_size_from_config(parallel_config, "tensor", "tensor_parallel_size") self._set_parallel_size_from_config(parallel_config, "pipeline", "pipeline_parallel_size") self._set_parallel_size_from_config(parallel_config, "zero1", "zero1_parallel_size") - - + if get_args().use_simplified_gp_init: self._init_use_simplified_pg(rank, world_size, parallel_config) else: @@ -592,10 +575,7 @@ def _init_pg(self, rank, world_size, parallel_config): 1, self.world_size // self.pipeline_parallel_size // self.weight_parallel_size ) - if ( - isinstance(parallel_config["tensor"], dict) - and parallel_config["tensor"]["mode"] == "isp" - ): + if isinstance(parallel_config["tensor"], dict) and parallel_config["tensor"]["mode"] == "isp": if self.zero1_parallel_size == -1: self.zero1_parallel_size = self.weight_data_parallel_size self.zero1_parallel_size = max(1, self.zero1_parallel_size) @@ -622,8 +602,7 @@ def _init_pg(self, rank, world_size, parallel_config): if "sequence_parallel" not in parallel_config: parallel_config._add_item("sequence_parallel", True) if isinstance(parallel_config["tensor"], int) or ( - isinstance(parallel_config["tensor"], dict) - and parallel_config["tensor"]["mode"] == "mtp" + isinstance(parallel_config["tensor"], dict) and parallel_config["tensor"]["mode"] == "mtp" ): parallel_config["sequence_parallel"] = False @@ -665,10 +644,7 @@ def _init_pg(self, rank, world_size, parallel_config): initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args)) initializers.append(pgroup_initializer.Initializer_Data(*initializer_args)) initializers.append(pgroup_initializer.Initializer_ISP_Data(*initializer_args)) - if ( - isinstance(parallel_config["tensor"], dict) - and parallel_config["tensor"]["mode"] == TensorParallelMode.isp.name - ): + if isinstance(parallel_config["tensor"], dict) and parallel_config["tensor"]["mode"] == "isp": initializers.append(pgroup_initializer.Initializer_Zero1_ISP(*initializer_args)) else: initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args)) @@ -686,7 +662,7 @@ def _init_pg(self, rank, world_size, parallel_config): self._register_dist(*args) else: self._register_dist(*parallel_setting) - + def _init_use_simplified_pg(self, rank, world_size, parallel_config): try: self.tensor_mode = parallel_config["tensor"]["mode"] @@ -723,6 +699,11 @@ def _init_use_simplified_pg(self, rank, world_size, parallel_config): self.check_sanity() + from internlm.core.context.process_group_initializer_simplified import ( + Initializer, + ParallelMeta, + ) + parallel_info = { "tp": ParallelMeta(self.tensor_parallel_size, ParallelMode.TENSOR), "wp": ParallelMeta(self.weight_parallel_size, ParallelMode.WEIGHT), @@ -861,14 +842,14 @@ def check_pg_is_intra(self, parallel_mode: ParallelMode): return (max_rank - min_rank) <= 7 def same_group_in_one_node(self, parallel_mode: ParallelMode): - """获得一个节点内有多少个相同类型的PG, 在跨节点通信时会存在带宽竞争 - 这里返回的相同PG的数量会乘上每个rank的通信数据量大小 + """Get the number of the same type of PG within a node. There will be bandwidth competition during cross-node communication. + The number of the same PG returned here will be multiplied by the communication data size of each rank. Args: parallel_mode (ParallelMode): Returns: - int: 一个节点内相同类型的PG的数量 + int: The number of the same type of PG within a node. """ pg_group_ranks = self.get_ranks_in_group(parallel_mode) pg_group_ranks = sorted(pg_group_ranks) @@ -881,68 +862,5 @@ def same_group_in_one_node(self, parallel_mode: ParallelMode): else: return stride - # def set_cluster_info(self, name: str, info: dict): - # self.clusters[name] = ClusterInfo(**info) - - def get_cluster_info(self, name: str): - return self.clusters[name] - - def get_cluster_name_from_ip(self): - """ - node_ip_list = [ - 'metax-c500-1', - 'metax-c500-2', - 'nvidia-node-1', - 'nvidia-node-2', - ] - """ - hostname = socket.gethostname() - cluster_name = hostname.split("-")[0] - return cluster_name - - def sort_rank_based_on_ip_and_capacity(self): - Capacity = [] - - def sort_rank(x, y): - x_name = self.get_cluster_name_from_ip(x) - y_name = self.get_cluster_name_from_ip(y) - if x_name == y_name: - return x_name > y_name - else: - x_c = self.clusters[x_name]["capacity"] - y_c = self.clusters[y_name]["capacity"] - return x_c > y_c - - for cluster_name, cluster_info in self.clusters.items(): - peak_tflops.append(cluster_info["peak_tflops"]) - # Alpha.append(cluster_info.rank_range[-1] - cluster_info.rank_range[-1] + 1) - Capacity.append(cluster_info["capacity"]) - - def switch_topology_aware_rank_scheduling(): - """ - Switch topology-aware rank scheduling can optimize the performance of small-scale - collective communications. Currently only supported in Alibaba Cloud. - """ - - local_rank = int(os.environ["LOCAL_RANK"]) - cluster_name = get_cluster_name_from_ip() - - try: - if cluster_name == "Ali": - pass - else: - rank = int(os.environ["MLP_WORKER_RACK_RANK_INDEX"]) * 8 + local_rank - except Exception as e: - logger.error( - f"The switch topology awareness error is reported, the reason is: {e}", - "but don’t worry, this error will not affect normal training.", - "If you train on Alibaba or Volcano Cloud, please contact wangguoteng or lijiaxing", - ) - else: - # If there is no any error, hack torch rank. - os.environ["RANK"] = str(rank) - if local_rank == 0: - logger.info("Successfully bound node switch affinity!") - global_context = ParallelContext() diff --git a/internlm/core/context/process_group_initializer.py b/internlm/core/context/process_group_initializer.py index 5519c7d84..1d27bf93d 100644 --- a/internlm/core/context/process_group_initializer.py +++ b/internlm/core/context/process_group_initializer.py @@ -62,6 +62,10 @@ class ParallelMode(Enum): # grouped query attention GQA = "gqa" + + INTRA_DP_SZIE = "intra_dp" + + INTER_DP_SZIE = "inter_dp" class ProcessGroupInitializer(ABC): diff --git a/internlm/core/context/process_group_initializer_simplified.py b/internlm/core/context/process_group_initializer_simplified.py index c1423a5ae..257b4f663 100644 --- a/internlm/core/context/process_group_initializer_simplified.py +++ b/internlm/core/context/process_group_initializer_simplified.py @@ -2,13 +2,11 @@ # -*- encoding: utf-8 -*- from copy import deepcopy -from enum import Enum import torch import torch.distributed as dist from internlm.utils.timeout import LLM_NCCL_TIMEOUT -from internlm.core.context.process_group_initializer import ParallelMode class ParallelMeta: def __init__(self, parallel_size, mode) -> None: diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 9a719cf8f..e7a35dc33 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -12,7 +12,7 @@ from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import Config from internlm.core.context import global_context as gpc -from internlm.core.context.process_group_initializer_simplified import ParallelMode +from internlm.core.context.process_group_initializer import ParallelMode from internlm.utils.common import get_master_node from internlm.utils.gputest import warmup_process_group from internlm.utils.logger import get_logger @@ -86,7 +86,8 @@ def add_simulator_arguments(parser): group.add_argument( "--pre_profiling_data_path", type=str, help="The path to pre-profiled performance data on the target cluster." ) - group.add_argument("--use_simplified_gp_init", action="store_true", default=False) + group.add_argument("--use_simplified_gp_init", action="store_true", default=True) + return parser diff --git a/internlm/model/ops/linear.py b/internlm/model/ops/linear.py index da3eda5b4..6ea9ac62f 100644 --- a/internlm/model/ops/linear.py +++ b/internlm/model/ops/linear.py @@ -14,10 +14,6 @@ from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import global_context as gpc -from internlm.simulator.ops.linear import ( - _fake_linear_bwdward_op, - _fake_linear_forward_op, -) try: from fused_dense_lib import linear_bias_wgrad as _flash_linear_backward_op diff --git a/internlm/simulator/README.md b/internlm/simulator/README.md new file mode 100644 index 000000000..f9f0f964f --- /dev/null +++ b/internlm/simulator/README.md @@ -0,0 +1,61 @@ +# InternLM Simulator + + +## 1. Introduction +The solver mainly consists of two components: +1. `profiling`: Collects the time consumption of each stage during the model training process in advance and saves it as data files and image files. +2. `simulation`: Simulates the model training process based on the collected data files and outputs the time consumption of each stage during the training process. + +## 2. Usage + +### 2.1 Generate profiling data + +There are two types of profiling data: +1. '`linear`' profiling data, include: [`LINEAR`] +2. '`Communication`' profiling data, include: [`ALL2ALL`, `ALLREDUCE`, `REDUCESCATTER`, `ALLGATHER`, `BROADCAST`] + + +Note: +1. It is recommended to use more than 64 GPUs for data collection to ensure more accurate communication data. +2. `Flash Attention` information is not collected in advance but is collected on the fly during the simulation and stored in the cache. This is because there are many variables that affect the performance of flash attention, and collecting in advance cannot cover all variables. + +```python +# generate profiling data +torchrun --nproc-per-node=8 gen_profiler_data.py + +# the profiling data will be saved in the following path +./prof_data +├── data.pt +└── pics + ├── cal + │ └── linear.jpg + └── comm + ├── all2all_intra_2_inter_1.jpg + ├── all2all_intra_4_inter_1.jpg + ├── all_gather_intra_2_inter_1.jpg + ├── all_gather_intra_4_inter_1.jpg + ├── all_reduce_intra_2_inter_1.jpg + ├── all_reduce_intra_4_inter_1.jpg + ├── broadcast_intra_2_inter_1.jpg + ├── broadcast_intra_4_inter_1.jpg + ├── reduce_scatter_intra_2_inter_1.jpg + └── reduce_scatter_intra_4_inter_1.jpg + +``` + +### 2.2 Run simulation +Running the solver does not require a GPU (although some packages may require a GPU environment, if you encounter any issues, please raise an issue). Currently, the solver only supports the formulaic solving method using simulation_train_formulaic.py, which requires a config file and profiling data file as follows: + +```bash + +python simulation_train_formulaic.py --pre_profiling_data_path ./prof_data/data.pt --config configs/7B_internlm2.py --run_all_solu --model_size 7 --world_size 128 --global_batch_size 4194304 + +# explanation: +python simulation_train_formulaic.py + --pre_profiling_data_path ./prof_data/data.pt # profiling data file + --config configs/7B_internlm2.py # model configuration file + --run_all_solu # whether to iterate and solve all possible solutions + --model_size 7 # means 7B model, if you want to run 70B model, you can set model_size to 70 + --world_size 128 # solving range is 128 cards + --global_batch_size 4194304 # global batch size, 4M +``` diff --git a/internlm/simulator/profiler/benchmark/multi_head_attn.py b/internlm/simulator/profiler/benchmark/multi_head_attn.py index f4cf5d737..bec18e1c4 100644 --- a/internlm/simulator/profiler/benchmark/multi_head_attn.py +++ b/internlm/simulator/profiler/benchmark/multi_head_attn.py @@ -60,150 +60,3 @@ def run(): t_bwds += t_bwd return t_fwds / trials, t_bwds / trials - - -# from .base_benchmark import UnitBench -# import math - -# import torch -# from einops import rearrange -# from torch import nn - -# from internlm.model.registry import benchmark_initializer -# from internlm.simulator.common import TP_SIZE_RANGE, K, get_local_rank -# from internlm.utils.common import get_current_device - -# try: -# from flash_attn.flash_attn_interface import ( -# flash_attn_qkvpacked_func, -# flash_attn_varlen_func, -# ) -# from flash_attn.modules.mha import FlashSelfAttention, SelfAttention -# except ModuleNotFoundError: -# print("import fa failed!", flush=True) -# try: -# from deeplink_ext.internevo_ops import FlashCrossAttention, FlashSelfAttention -# except ModuleNotFoundError: -# flash_attn_qkvpacked_func = None -# FlashSelfAttention = None -# SelfAttention = None -# print("import dipu fa failed!", flush=True) - - -# @benchmark_initializer.register_module(module_name=BENCH_TYPE) - -# 对于FA,我们还是用on the fly的方式 profiling,并用cache缓存中间结果 -# class UnitMultiHeadAttn(UnitBench): -# # test_loop = { -# # "seq_len": [ -# # 64 * K, -# # int(0.25 * K), -# # int(0.5 * K), -# # 1 * K, -# # 2 * K, -# # 4 * K, -# # 8 * K, -# # 32 * K, -# # 16 * K, -# # ], # 256 * K, 128 * K, -# # "head_H": [(64, 8192), (48, 6144), (32, 4096), (40, 5120)], # (80, 10240), -# # "dtype": [torch.bfloat16], -# # "micro_bsz": [2, 1], # 4, -# # "tp_size": TP_SIZE_RANGE, -# # "is_fwd": [True, False], -# # } - -# def __init__(self, seq_len, num_heads_and_hidden_dim, dtype, micro_bsz, tp_size, is_fwd) -> None: -# q_head, kv_head, embed_dim = num_heads_and_hidden_dim -# self.num_heads_and_hidden_dim = num_heads_and_hidden_dim -# self.TP = tp_size -# self.S = seq_len -# self.N = num_heads -# self.H = embed_dim // self.N -# self.dtype = dtype -# self.dtype_size = 2 if self.dtype == torch.bfloat16 else 4 -# self.B = micro_bsz -# self.oom = False -# self.is_fwd = is_fwd -# self.causal = True - -# assert num_heads % self.TP == 0, "num_heads must be divisible by tp_size" -# assert num_heads >= tp_size, f"head nums must bigger then tp_size: {tp_size}" - -# self.num_atten_head_tp = num_heads // self.TP -# self.head_dim = self.H // num_heads -# self.tp_embedding_dim = self.H // self.TP - -# self.packed_length = self.S * self.B -# self.device = f"cuda:{get_local_rank()}" -# cu_seqlens = [i * self.S for i in range(self.B + 1)] - -# weights_mem_used = self.packed_length * 3 * self.H * self.dtype_size -# attn_activation = 11 * self.packed_length * self.H -# mem_used = attn_activation + weights_mem_used - -# self.inner_attn = FlashSelfAttention(causal=True, softmax_scale=self.H ** (0.5), attention_dropout=0.0) - -# oom = False -# if mem_used > 75 * 1024**3: -# oom = True - -# # 约束1: seqlen最大不能超过256K(不含) -# # 约束2: embed_dim在被tp切过之后若大于6144, 则packed_length不能大于256k -# if self.packed_length >= 256 * K and (self.H / self.TP) >= 6144: -# oom = True -# if self.S >= 256 * K and self.B > 1: -# oom = True -# if self.packed_length >= 524288 and (self.H / self.TP) >= 3072: -# oom = True -# if self.packed_length >= 1048576 and (self.H / self.TP) >= 2048: -# oom = True - -# if oom: -# assert ( -# False -# ), f"warning : mem_used: {mem_used/1024**3:.2f} GB, seq_len: {self.S}, embed_dim: {self.H}, tp_size: {self.TP}" - -# self.qkv = torch.rand( -# size=(self.B * self.S, 3, self.N // self.TP, self.H), -# dtype=self.dtype, -# device=self.device, -# requires_grad=True, -# ) - -# self.dtype_size = self.qkv.element_size() -# self.cu_seqlens = torch.tensor(data=cu_seqlens, dtype=torch.int32, device=self.device) -# self.max_seqlen = self.S -# if not self.is_fwd: -# self.output = self.run_fwd() -# self.grad = torch.randn_like(self.output) / 32 # avoid grad is too large. - -# def run(self): -# if self.is_fwd: -# self.run_fwd() -# else: -# self.run_bwd(self.output, self.grad) - -# def run_fwd(self): -# context = self.inner_attn(self.qkv, cu_seqlens=self.cu_seqlens, max_seqlen=self.max_seqlen, causal=self.causal) -# return context - -# def run_bwd(self, output, grad): -# output.backward(grad, retain_graph=True) - -# @staticmethod -# def gen_store_key(micro_bsz, seq_len, num_heads_and_hidden_dim, tp_size, is_fwd): -# _, embed_dim = num_heads_and_hidden_dim -# tp_embedding_dim = embed_dim // tp_size -# return f"b_{micro_bsz}_s_{seq_len}_h_{tp_embedding_dim}_fwd_{is_fwd}" - -# def complexity(self): -# return UnitMultiHeadAttn.gen_store_key(self.B, self.S, self.num_heads_and_hidden_dim, self.TP, self.is_fwd) -# # return f"{self.S} * {self.hidden_dim} * {self.hidden_dim}" - - -if __name__ == "__main__": - - micro_bsz, seqlen, hidden_size, q_head, kv_head, dtype = 1, 4096, 4096, 32, 8, torch.bfloat16 - t_fwd, t_bwd = run_fwd(micro_bsz, seqlen, hidden_size, q_head, kv_head, dtype) - print(f"t_fwd: {t_fwd}, t_bwd: {t_bwd}", flush=True) diff --git a/internlm/simulator/profiler/perf_comm.py b/internlm/simulator/profiler/perf_comm.py index 58402bfb5..782609909 100644 --- a/internlm/simulator/profiler/perf_comm.py +++ b/internlm/simulator/profiler/perf_comm.py @@ -125,7 +125,8 @@ def gen_perf(): ) group = dist.GroupMember.WORLD - gpc._register_dist(rank, world_size, group, None, list(range(world_size)), ParallelMode.GLOBAL) + # local_rank, world_size, process_group, cpu_group, ranks_in_group, all_ranks, mode + gpc._register_dist(rank, world_size, group, None, list(range(world_size)), list(range(world_size)), ParallelMode.GLOBAL) gpc._global_ranks[ParallelMode.GLOBAL] = rank gpc.set_device(local_rank) @@ -162,29 +163,30 @@ def gen_perf(): sync_all() - for i in range(inter_comm_nums): - for j in range(intra_comm_nums): - inter_size, intra_size = 2**i, 2**j - if inter_size * intra_size != 1: - - x_idx, y_idx = get_group_id(rank, gpus_per_node, intra_size, inter_size) - groups = new_process_group(world_size, gpus_per_node, intra_size, inter_size) - - for test_type in comm_test_list: - key = gen_comm_key(test_op, intra_size, inter_size) - if dist.get_rank() == 0: - print( - f"key: {key}, inter_size: {inter_size}, intra_size: {intra_size}, ranks: {groups[y_idx][x_idx][1]}", - flush=True, - ) - pg = groups[y_idx][x_idx][0] - assert ( - pg != -100 - ), f"key: {key}, x_idx: {x_idx}, y_idx: {y_idx}, rank: {gpc.get_global_rank()}, ranks: {groups[y_idx][x_idx][1]}" - comm_vols, bws = run_comm_profile(test_type, pg, key) - sync_all() - if dist.get_rank() == 0: - spline_model_dict[key] = draw_pics(comm_pic_path, key, comm_vols, bws) + for test_op in comm_test_list: + for i in range(inter_comm_nums): + for j in range(intra_comm_nums): + inter_size, intra_size = 2**i, 2**j + if inter_size * intra_size != 1: + + x_idx, y_idx = get_group_id(rank, gpus_per_node, intra_size, inter_size) + groups = new_process_group(world_size, gpus_per_node, intra_size, inter_size) + + for test_type in comm_test_list: + key = gen_comm_key(test_op, intra_size, inter_size) + if dist.get_rank() == 0: + print( + f"key: {key}, inter_size: {inter_size}, intra_size: {intra_size}, ranks: {groups[y_idx][x_idx][1]}", + flush=True, + ) + pg = groups[y_idx][x_idx][0] + assert ( + pg != -100 + ), f"key: {key}, x_idx: {x_idx}, y_idx: {y_idx}, rank: {gpc.get_global_rank()}, ranks: {groups[y_idx][x_idx][1]}" + comm_vols, bws = run_comm_profile(test_type, pg, key) + sync_all() + if dist.get_rank() == 0: + spline_model_dict[key] = draw_pics(comm_pic_path, key, comm_vols, bws) print(f"rank: {gpc.get_global_rank()}, all done!", flush=True) diff --git a/internlm/simulator/profiler/profiler.py b/internlm/simulator/profiler/profiler.py index 9d28bbfd4..adea31fd6 100644 --- a/internlm/simulator/profiler/profiler.py +++ b/internlm/simulator/profiler/profiler.py @@ -234,24 +234,7 @@ def draw_pics(base_path, plot_name, comm_vols, bws): def draw_cal_pics(base_path, plot_name, tflop, tflops): - # x, y = [], [] - spline_model = interp1d(tflop, tflops, kind="slinear") - - # start = tflop[0] - # end = tflop[-1] - # for complexity in range(start, end+1): - # try: - # predice_tflops = spline_model(complexity) - # except ValueError: - # if complexity < tflop[0]: - # predice_tflops = spline_model(tflop[0]) - # elif complexity > tflop[-1]: - # predice_tflops = spline_model(tflop[-1]) - - # x.append(complexity) - # y.append(predice_tflops) - pic_path = os.path.join(base_path, plot_name + ".jpg") tflop = list(map(lambda x: x / 10**12, tflop)) tflops = list(map(lambda x: x / 10**12, tflops)) diff --git a/simulation_train_formulaic.py b/simulation_train_formulaic.py index 858bd05d5..8b3654ab6 100644 --- a/simulation_train_formulaic.py +++ b/simulation_train_formulaic.py @@ -12,6 +12,7 @@ from internlm.core.context import global_context as gpc from internlm.core.context.random import reset_seed from internlm.core.parallel.shard import cluster_load_balance, partition_uniform +from internlm.initialize import get_default_parser from internlm.initialize.launch import args_sanity_check, launch from internlm.simulator.common import AlgoType, cal_block_p_elem, cal_model_p_elem @@ -82,6 +83,8 @@ def comm_dp_cost(dtype_size, algo, pp_blocks_elem, embedding_elem, zp) -> float: block_zp_latency = zp * broadcast(dtype_size * pp_blocks_elem / zp, ParallelMode.ZERO1, comm_nums=zp) embedding_zp_latency = broadcast(dtype_size * embedding_elem, ParallelMode.DATA) zp_latency = max(block_zp_latency, embedding_zp_latency) + else: + raise ValueError(f"Invalid algo type: {algo}") return zp_latency, wdp_latency @@ -384,16 +387,20 @@ def overlaped_fwd_bwd_cost(): def run_loop( global_bsz, world_size, - args, + config_path, use_fixed_micro_bsz=False, use_strict_bsz=True, global_bsz_max=1, global_bsz_min=1, debug=True, ): - gpc.load_config(config=Config.from_file(args.config)) + gpc.load_config(config=Config.from_file(config_path)) gpc.set_fake_mode(True) + if "multiple_of" not in gpc.config.model: + gpc.config.model["multiple_of"] = 256 + print(f"multiple_of not in config, use default value: {gpc.config.model.multiple_of}") + min_comm_cost, msp_min_cost, fsp_min_cost, isp_min_cost = ( float("inf"), float("inf"), @@ -566,9 +573,9 @@ def run_loop( return solutions_list, min_comm_cost, min_cost_solution, msp_min_solu, fsp_min_solu, isp_min_solu -def run_warrper(global_bsz, world_size, args): +def run_warrper(world_size, global_bsz, config_path): solutions_list, min_comm_cost, min_cost_solution, msp_min_solu, fsp_min_solu, isp_min_solu = run_loop( - global_bsz=global_bsz, world_size=world_size, args=args + global_bsz=global_bsz, world_size=world_size, config_path=config_path ) if min_cost_solution is not None: @@ -606,7 +613,17 @@ def run_warrper(global_bsz, world_size, args): print("No solution found") -def run_single(global_bsz=4 * 1024 * 1024): +def get_world_size(): + if "WORLD_SIZE" in os.environ: + return int(os.environ["WORLD_SIZE"]) + else: + if "SLURM_NTASKS" in os.environ: + return int(os.environ["SLURM_NTASKS"]) + else: + return 1 + + +def run_single(world_size, global_bsz=4 * 1024 * 1024): gpc.load_config(config=Config.from_file(args.config)) gpc.set_fake_mode(True) print(f"gpc.config.parallel: {gpc.config.parallel}") @@ -675,17 +692,19 @@ def run_single(global_bsz=4 * 1024 * 1024): if __name__ == "__main__": - args = parse_args() + parser = get_default_parser() + args = parser.parse_args() + hostname = socket.gethostname() - world_size = args.world_size + global_batch_size = args.global_batch_size - init_cost_model(get_args().pre_profiling_data_path) + init_cost_model(args.pre_profiling_data_path) os.environ["fake_mode"] = "1" gloab_allocator.init_capcity = 80 * 1024**3 gloab_allocator.capcity = 80 * 1024**3 - if get_args().run_all_solu: - run_warrper(4096 * 1024, world_size, args) + if args.run_all_solu: + run_warrper(args.world_size, args.global_batch_size, args.config) else: - run_single(get_args().global_batch_size) + run_single(args.world_size, args.global_batch_size)