From 9fc11369b33d1c56f20946c9a0cd75b9819b218a Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Wed, 19 Feb 2025 13:27:15 +0800 Subject: [PATCH 1/9] temp version --- tools/convert_ckpt_parallel.py | 432 +++++++++++++++++++++++++++++++++ 1 file changed, 432 insertions(+) create mode 100644 tools/convert_ckpt_parallel.py diff --git a/tools/convert_ckpt_parallel.py b/tools/convert_ckpt_parallel.py new file mode 100644 index 000000000..d8345ac00 --- /dev/null +++ b/tools/convert_ckpt_parallel.py @@ -0,0 +1,432 @@ +import argparse +import os +import shutil +import sys +from collections import defaultdict + +import torch +from torch._utils import _flatten_dense_tensors +from itertools import cycle +from collections import OrderedDict + + +current_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.join(current_dir, "../")) + + +def parse_args(): + args = argparse.ArgumentParser() + args.add_argument("origin_model_path", type=str, default=None) + args.add_argument("target_model_path", type=str, default=None) + args.add_argument("--target_tp_size", type=int, default=0) + args.add_argument("--target_pp_size", type=int, default=0) + args.add_argument("--model_size", type=str, default="7B", choices=["7B", "20B", "70B"]) + return args.parse_args() + + +def get_mapping(n, m): + if m % n != 0: + raise ValueError("m must be a multiple of n") + + n_list = list(range(n)) + m_list = list(range(m)) + + mapping = {} + for i, n_val in enumerate(n_list): + mapping[n_val] = m_list[i * int(m / n) : (i + 1) * int(m / n)] + + return mapping + + +def map_pp_lists(old_pp, new_pp): + result = [] + old_ranks = list(range(old_pp)) + new_ranks = list(range(new_pp)) + + if old_pp > new_pp: + ratio = old_pp // new_pp + for i in old_ranks: + result.append([new_ranks[i // ratio]]) + elif old_pp < new_pp: + ratio = new_pp // old_pp + for i in old_ranks: + result.append(new_ranks[i * ratio:(i + 1) * ratio]) + else: + for i in old_ranks: + result.append([new_ranks[i]]) + + assert len(result) == old_pp + return result + + +def sorted_state_dict(unordered_dict): + sorted_keys = sorted(unordered_dict.keys()) # 排序键名 + sorted_dict = OrderedDict() + for key in sorted_keys: + sorted_dict[key] = unordered_dict[key] + return sorted_dict + + +def flatten(input_): + return _flatten_dense_tensors(input_) + + +def unflatten_tensor(flat_tensor, states): + """ + 根据目标形状,将扁平化的张量拆分为多个子张量。 + + :param flat_tensor: 扁平化的张量 + :param shapes: 每个子张量的目标形状(list of tuples) + :return: 切分后的多个子张量列表 + """ + start = 0 + unflat_tensors = [] + + for _, state in states.items(): + shape = state['shape'] + size = torch.prod(torch.tensor(shape)) # 计算每个子张量的大小 + tensor = flat_tensor[start:start + size].reshape(*shape) # 切分并恢复形状 + unflat_tensors.append(tensor) + start += size # 更新起始位置 + + return unflat_tensors + + +def preprocess_optimizer_state(old_tp_size, old_pp_size, old_zero1_size, old_meta, folder): + processed_ckpt_states = [[[{} for _ in range(old_zero1_size)] for _ in range(old_pp_size)] for _ in range(old_tp_size)] + for old_tp_rank in range(old_tp_size): + for old_pp_rank in range(old_pp_size): + for old_zero1_rank in range(old_zero1_size): + ckpt_states = torch.load(os.path.join(folder, f"optimizer_tp{old_tp_rank}_pp{old_pp_rank}_zo{old_zero1_rank}.pt"), map_location="cpu") + base_optim_states = ckpt_states['base_optim_states']['state'] + flat_fp32_weights = ckpt_states['flat_fp32_weights'] + processed_state = ckpt_states + for group_id in list(base_optim_states.keys()): + exp_avg = base_optim_states[group_id]['exp_avg'] + exp_avg_sq = 
base_optim_states[group_id]['exp_avg_sq'] + flat_tensor = flat_fp32_weights[group_id] + metaData = old_meta['metaData'][old_tp_rank][old_pp_rank][old_zero1_rank][group_id] + + unflat_exp_avg = unflatten_tensor(exp_avg, metaData) + unflat_exp_avg_sq = unflatten_tensor(exp_avg_sq, metaData) + unflat_tensor = unflatten_tensor(flat_tensor, metaData) + + processed_state['base_optim_states']['state'][group_id]['exp_avg'] = unflat_exp_avg + processed_state['base_optim_states']['state'][group_id]['exp_avg_sq'] = unflat_exp_avg_sq + processed_state['flat_fp32_weights'][group_id] = unflat_tensor + + processed_ckpt_states[old_tp_rank][old_pp_rank][old_zero1_rank] = processed_state + + return processed_ckpt_states + + +def check_optimizer_convert(target_meta, target_states, group_id): + fqn_list = list(target_meta.keys()) + index = len(target_states['global_fqn']) - 1 + meta_fqn = fqn_list[index] + meta_shape = target_meta[meta_fqn]['shape'] + meta_group_id = target_meta[meta_fqn]['group_id'] + states_fqn = target_states['global_fqn'][-1] + states_shape = target_states['fp32_weights'][-1].shape + + print(fqn_list) + print(target_states['global_fqn']) + assert meta_fqn == states_fqn, f"states_fqn {states_fqn} and meta_fqn {meta_fqn} are not the same." + assert meta_group_id == group_id, f"For {states_fqn}: group_id {states_shape} and meta_group_id {meta_shape} are not the same." + assert meta_shape == states_shape, f"For {states_fqn}: states_shape {states_shape} and meta_shape {meta_shape} are not the same." + + +def model_tp_split(split_maps, old_pp_rank, old_tp_size, new_states, old_meta_data, new_meta_data, ratio, tp_mode, old_map_local_to_global, new_meta, folder): + for old_tp_rank in range(old_tp_size): + ckpt_states = torch.load(os.path.join(folder, f"model_{tp_mode}{old_tp_rank}_pp{old_pp_rank}.pt"), map_location='cpu') + for fqn, tensor in ckpt_states.items(): + assert len(tensor.size()) < 3, "Only support 2D or 1D tensors." 
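+            # Map the checkpoint-local fqn back to its global fqn so its tp_dim and its target (pp, zero1, group) placement can be looked up in the metadata.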
+ global_fqn = old_map_local_to_global[old_pp_rank][fqn] + tp_dim = old_meta_data[global_fqn]['tp_dim'] + assert tp_dim == new_meta_data[global_fqn]['tp_dim'], f"{global_fqn} tp_dim in old and new meta are not equal: old={tp_dim}, new={new_meta_data[fqn]['tp_dim']}" + new_pp_rank = new_meta_data[global_fqn]['pp'] + new_zero1_rank = new_meta_data[global_fqn]['zero1'] + new_fqn = new_meta_data[global_fqn]['fqn'] + group_id = new_meta_data[global_fqn]['group_id'] + + if tp_dim == -1: + for _, new_tp_rank in enumerate(split_maps[old_tp_rank]): + new_states[new_tp_rank][new_pp_rank][new_fqn] = tensor.detach().clone() + splited_shape = new_states[new_tp_rank][new_pp_rank][new_fqn].shape + meta_shape = new_meta['metaData'][new_tp_rank][new_pp_rank][new_zero1_rank][group_id][global_fqn]['shape'] + assert splited_shape == meta_shape, f"{new_fqn}: splited shape {splited_shape} is not euqal to metaData {meta_shape}" + else: + split_size = tensor.size()[tp_dim] // ratio + new_tp_splits = torch.split(tensor, split_size, dim=tp_dim) + for i, new_tp_rank in enumerate(split_maps[old_tp_rank]): + new_states[new_tp_rank][new_pp_rank][new_fqn] = new_tp_splits[i].detach().clone() + splited_shape = new_states[new_tp_rank][new_pp_rank][new_fqn].shape + meta_shape = new_meta['metaData'][new_tp_rank][new_pp_rank][new_zero1_rank][group_id][global_fqn]['shape'] + assert splited_shape == meta_shape, f"{new_fqn}: splited shape {splited_shape} is not euqal to metaData {meta_shape}" + + +def optimizer_tp_split(split_maps, old_tp_size, old_pp_rank, old_zero1_rank, old_meta, new_meta, new_meta_data, processed_ckpt_states, new_states, ratio): + for old_tp_rank in range(old_tp_size): + ckpt_states = processed_ckpt_states[old_tp_rank][old_pp_rank][old_zero1_rank] + for group_id in ckpt_states['flat_fp32_weights'].keys(): + old_metaData = old_meta['metaData'][old_tp_rank][old_pp_rank][old_zero1_rank][group_id] + exp_avg_list = ckpt_states['base_optim_states']['state'][group_id]['exp_avg'] + exp_avg_sq_list = ckpt_states['base_optim_states']['state'][group_id]['exp_avg_sq'] + fp32_weights_list = ckpt_states['flat_fp32_weights'][group_id] + + for i, global_fqn in enumerate(list(old_metaData.keys())): + tp_dim = old_metaData[global_fqn]['tp_dim'] + new_pp_rank = new_meta_data[global_fqn]['pp'] + new_zero1_rank = new_meta_data[global_fqn]['zero1'] + + if tp_dim == -1: + for _, new_tp_rank in enumerate(split_maps[old_tp_rank]): + target_new_states = new_states[new_tp_rank][new_pp_rank][new_zero1_rank][group_id] + target_new_states['global_fqn'].append(global_fqn) + target_new_states['exp_avg'].append(exp_avg_list[i].detach().clone()) + target_new_states['exp_avg_sq'].append(exp_avg_sq_list[i].detach().clone()) + target_new_states['fp32_weights'].append(fp32_weights_list[i].detach().clone()) + + check_optimizer_convert(new_meta['metaData'][new_tp_rank][new_pp_rank][new_zero1_rank][group_id], target_new_states, group_id) + else: + split_size = old_metaData[global_fqn]['shape'][tp_dim] // ratio + new_exp_avg_splits = torch.split(exp_avg_list[i], split_size, dim=tp_dim) + new_exp_avg_sq_splits = torch.split(exp_avg_sq_list[i], split_size, dim=tp_dim) + new_fp32_weights_splits = torch.split(fp32_weights_list[i], split_size, dim=tp_dim) + for j, new_tp_rank in enumerate(split_maps[old_tp_rank]): + target_new_states = new_states[new_tp_rank][new_pp_rank][new_zero1_rank][group_id] + target_new_states['global_fqn'].append(global_fqn) + target_new_states['exp_avg'].append(new_exp_avg_splits[j].detach().clone()) + 
target_new_states['exp_avg_sq'].append(new_exp_avg_sq_splits[j].detach().clone()) + target_new_states['fp32_weights'].append(new_fp32_weights_splits[j].detach().clone()) + + print(new_tp_rank, new_pp_rank, new_zero1_rank, group_id) + check_optimizer_convert(new_meta['metaData'][new_tp_rank][new_pp_rank][new_zero1_rank][group_id], target_new_states, group_id) + + +def model_tp_merge(old_pp_rank, new_states, old_tp_size, new_tp_size, tp_mode, ratio, old_meta_data, new_meta_data, old_map_local_to_global, new_meta, folder): + candidate_states = [defaultdict(list) for _ in range(new_tp_size)] + for old_tp_rank in range(old_tp_size): + ckpt_states = torch.load(os.path.join(folder, f"model_{tp_mode}{old_tp_rank}_pp{old_pp_rank}.pt"), map_location="cpu") + for fqn, tensor in ckpt_states.items(): + assert len(tensor.size()) < 3, "Only support 2D or 1D tensors." + new_tp_rank = old_tp_rank // ratio + candidate_states[new_tp_rank][fqn].append(tensor) + + for new_tp_rank, states in enumerate(candidate_states): + for fqn, tensor_list in states.items(): + global_fqn = old_map_local_to_global[old_pp_rank][fqn] + tp_dim = old_meta_data[global_fqn]['tp_dim'] + assert tp_dim == new_meta_data[global_fqn]['tp_dim'], f"{global_fqn} tp_dim in old and new meta are not equal: old={tp_dim}, new={new_meta_data[fqn]['tp_dim']}" + new_pp_rank = new_meta_data[global_fqn]['pp'] + new_zero1_rank = new_meta_data[global_fqn]['zero1'] + new_fqn = new_meta_data[global_fqn]['fqn'] + group_id = new_meta_data[global_fqn]['group_id'] + + if tp_dim == -1: + assert torch.equal(tensor_list[0], tensor_list[1]), f"{global_fqn} should not be splited by tp, but the tensors in different checkpoints are not equal." + new_states[new_tp_rank][new_pp_rank][new_fqn] = tensor_list[0].detach().clone() + else: + new_states[new_tp_rank][new_pp_rank][new_fqn] = torch.concat(tensor_list, dim=tp_dim).detach().clone() + + splited_shape = new_states[new_tp_rank][new_pp_rank][new_fqn].shape + meta_shape = new_meta['metaData'][new_tp_rank][new_pp_rank][new_zero1_rank][group_id][global_fqn]['shape'] + assert splited_shape == meta_shape, f"{new_fqn}: splited shape {splited_shape} is not euqal to metaData {meta_shape}" + + +def optimizer_tp_merge(new_tp_size, old_tp_size, old_pp_rank, old_zero1_rank, old_meta, new_meta, old_meta_data, new_meta_data, processed_ckpt_states, new_states, ratio): + candidate_exp_avg = [defaultdict(list) for _ in range(new_tp_size)] + candidate_exp_avg_sq = [defaultdict(list) for _ in range(new_tp_size)] + candidate_fp32_weights = [defaultdict(list) for _ in range(new_tp_size)] + for old_tp_rank in range(old_tp_size): + ckpt_states = processed_ckpt_states[old_tp_rank][old_pp_rank][old_zero1_rank] + for group_id in ckpt_states['flat_fp32_weights'].keys(): + old_metaData = old_meta['metaData'][old_tp_rank][old_pp_rank][old_zero1_rank][group_id] + exp_avg_list = ckpt_states['base_optim_states']['state'][group_id]['exp_avg'] + exp_avg_sq_list = ckpt_states['base_optim_states']['state'][group_id]['exp_avg_sq'] + fp32_weights_list = ckpt_states['flat_fp32_weights'][group_id] + new_tp_rank = old_tp_rank // ratio + for i, global_fqn in enumerate(list(old_metaData.keys())): + assert group_id == new_meta_data[global_fqn]['group_id'] + candidate_exp_avg[new_tp_rank][global_fqn].append(exp_avg_list[i]) + candidate_exp_avg_sq[new_tp_rank][global_fqn].append(exp_avg_sq_list[i]) + candidate_fp32_weights[new_tp_rank][global_fqn].append(fp32_weights_list[i]) + + for new_tp_rank in len(candidate_fp32_weights): + for global_fqn in 
candidate_fp32_weights.keys(): + splited_exp_avg = candidate_exp_avg[new_tp_rank][global_fqn] + splited_exp_avg_sq = candidate_exp_avg_sq[new_tp_rank][global_fqn] + splited_fp32_weights = candidate_fp32_weights[new_tp_rank][global_fqn] + + tp_dim = old_meta_data[global_fqn]['tp_dim'] + new_pp_rank = new_meta_data[global_fqn]['pp'] + new_zero1_rank = new_meta_data[global_fqn]['zero1'] + group_id = new_meta_data[global_fqn]['group_id'] + target_new_states = new_states[new_tp_rank][new_pp_rank][new_zero1_rank][group_id] + + if tp_dim == -1: + assert torch.equal(splited_fp32_weights[0], splited_fp32_weights[1]), f"{global_fqn} should not be splited by tp, but the tensors in different checkpoints are not equal." + target_new_states['global_fqn'].append(global_fqn) + target_new_states['exp_avg'].append(splited_exp_avg[0].detach().clone()) + target_new_states['exp_avg_sq'].append(splited_exp_avg_sq[0].detach().clone()) + target_new_states['fp32_weights'].append(splited_fp32_weights[0].detach().clone()) + else: + target_new_states['global_fqn'].append(global_fqn) + target_new_states['exp_avg'].append(torch.concat(splited_exp_avg, dim=tp_dim).detach().clone()) + target_new_states['exp_avg_sq'].append(torch.concat(splited_exp_avg_sq, dim=tp_dim).detach().clone()) + target_new_states['fp32_weights'].append(torch.concat(splited_fp32_weights, dim=tp_dim).detach().clone()) + + check_optimizer_convert(new_meta['metaData'][new_tp_rank][new_pp_rank][new_zero1_rank][group_id], target_new_states, group_id) + + +def convert_modeling_ckpt(old_pp_size, new_pp_size, old_tp_size, new_tp_size, old_meta_data, new_meta_data, old_map_local_to_global, tp_mode, folder, saved_folder, new_states): + for old_pp_rank in range(old_pp_size): + if old_tp_size != new_tp_size: + if old_tp_size > new_tp_size: + assert old_tp_size % new_tp_size == 0, f"Cannot convert {old_tp_size} TP to {new_tp_size} TP." + ratio = old_tp_size // new_tp_size + model_tp_merge(old_pp_rank, new_states, old_tp_size, new_tp_size, tp_mode, ratio, old_meta_data, new_meta_data, old_map_local_to_global, new_meta, folder) + else: + assert new_tp_size % old_tp_size == 0, f"Cannot convert {old_tp_size} TP to {new_tp_size} TP." + split_maps = get_mapping(old_tp_size, new_tp_size) + ratio = new_tp_size // old_tp_size + model_tp_split(split_maps, old_pp_rank, old_tp_size, new_states, old_meta_data, new_meta_data, ratio, tp_mode, old_map_local_to_global, new_meta, folder) + else: + print(f"New tp and old tp are equal. 
Directly copy.") + for old_tp_rank in range(old_tp_size): + ckpt_states = torch.load(os.path.join(folder, f"model_{tp_mode}{old_tp_rank}_pp{old_pp_rank}.pt"), map_location='cpu') + for fqn, tensor in ckpt_states.items(): + new_pp_rank = new_meta_data[fqn]['pp'] + new_states[new_tp_rank][new_pp_rank][fqn] = tensor.detach().clone() + for new_tp_rank in range(new_tp_size): + for new_pp_rank in range(new_pp_size): + # print(f"pp={new_pp_rank}, tp={new_tp_rank}: {new_states[new_tp_rank][new_pp_rank].keys()}") + file_name = f"model_{tp_mode}{new_tp_rank}_pp{new_pp_rank}.pt" + states = sorted_state_dict(new_states[new_tp_rank][new_pp_rank]) + # torch.save(os.path.join(saved_folder, file_name), states) + + print(f"Finish Modeling") + + +def convert_optimizer_ckpt(old_meta, new_meta, old_pp_size, new_pp_size, old_tp_size, new_tp_size, old_zero1_size, new_zero1_size, old_meta_data, new_meta_data, saved_folder, new_states, processed_ckpt_states): + for old_pp_rank in range(old_pp_size): + for old_zero1_rank in range(old_zero1_size): + if old_tp_size != new_tp_size: + if old_tp_size > new_tp_size: + assert old_tp_size % new_tp_size == 0, f"Cannot convert {old_tp_size} TP to {new_tp_size} TP." + ratio = old_tp_size // new_tp_size + optimizer_tp_merge(new_tp_size, old_tp_size, old_pp_rank, old_zero1_rank, old_meta, new_meta, old_meta_data, new_meta_data, processed_ckpt_states, new_states, ratio) + else: + assert new_tp_size % old_tp_size == 0, f"Cannot convert {old_tp_size} TP to {new_tp_size} TP." + split_maps = get_mapping(old_tp_size, new_tp_size) + ratio = new_tp_size // old_tp_size + optimizer_tp_split(split_maps, old_tp_size, old_pp_rank, old_zero1_rank, old_meta, new_meta, new_meta_data, processed_ckpt_states, new_states, ratio) + else: + print(f"New tp and old tp are equal. 
Directly copy.") + for old_tp_rank in range(old_tp_size): + ckpt_states = processed_ckpt_states[old_tp_rank][old_pp_rank][old_zero1_rank] + for group_id in ckpt_states['flat_fp32_weights'].keys(): + old_metaData = old_meta['metaData'][old_tp_rank][old_pp_rank][old_zero1_rank][group_id] + exp_avg_list = ckpt_states['base_optim_states']['state'][group_id]['exp_avg'] + exp_avg_sq_list = ckpt_states['base_optim_states']['state'][group_id]['exp_avg_sq'] + fp32_weights_list = ckpt_states['flat_fp32_weights'][group_id] + for i, global_fqn in enumerate(list(old_metaData.keys())): + new_pp_rank = new_meta_data[global_fqn]['pp'] + new_zero1_rank = new_meta_data[global_fqn]['zero1'] + target_new_states = new_states[new_tp_rank][new_pp_rank][new_zero1_rank][group_id] + target_new_states['global_fqn'].append(global_fqn) + target_new_states['exp_avg'].append(exp_avg_list[i].detach().clone()) + target_new_states['exp_avg_sq'].append(exp_avg_sq_list[i].detach().clone()) + target_new_states['fp32_weights'].append(fp32_weights_list[i].detach().clone()) + + for new_tp_rank in range(new_tp_size): + for new_pp_rank in range(new_pp_size): + for new_zero1_rank in range(new_zero1_size): + file_name = f"optimizer_tp{new_tp_rank}_pp{new_pp_rank}_zo{new_zero1_rank}.pt" + optimizer_state = new_states[new_tp_rank][new_pp_rank][new_zero1_rank] + base_state = processed_ckpt_states[0][0][0] + step = base_state['base_optim_states']['state'][0]['step'] + base_state['base_optim_states']['state'] = {} + base_state['flat_fp32_weights'] = {} + if 'zero_devide_optim_plan' in base_state: + base_state.pop('zero_devide_optim_plan') + + for group_id in optimizer_state.keys(): + if len(optimizer_state[group_id]['global_fqn']) == 0: + print(f"Warning: tp{new_tp_rank}_pp{new_pp_rank}_zo{new_zero1_rank}_groupId{group_id} has no param.") + continue + + assert optimizer_state[group_id]['global_fqn'] == list(new_meta['metaData'][new_tp_rank][new_pp_rank][new_zero1_rank][group_id].keys()) + flat_exp_avg = flatten(optimizer_state[group_id]['exp_avg']) + flat_exp_avg_sq = flatten(optimizer_state[group_id]['exp_avg_sq']) + flat_fp32_weights = flatten(optimizer_state[group_id]['fp32_weights']) + state = {'step': step, 'exp_avg': flat_exp_avg, 'exp_avg_sq': flat_exp_avg_sq} + base_state['base_optim_states']['state'][group_id] = state + base_state['flat_fp32_weights'][group_id] = flat_fp32_weights + + print(f"optimizer tp{new_tp_rank}_pp{new_pp_rank}_zo{new_zero1_rank}: {base_state}") + # torch.save(os.path.join(saved_folder, file_name), base_state) + + print(f"Finish optimizer") + + +if __name__ == "__main__": + folder = "/mnt/petrelfs/lijiaxing/InternEvo/llm_ckpts_3/5" + saved_folder = "/mnt/petrelfs/lijiaxing/InternEvo/llm_ckpts_3_new" + meta_path = "/mnt/petrelfs/lijiaxing/InternEvo/llm_ckpts_3_new/metadata.pt" + + + old_meta = os.path.join(folder, f"metadata.pt") + assert os.path.exists(old_meta), 'old meta file does not exist, plese generate it before converting checkpoint.' + + old_meta = torch.load(old_meta, map_location='cpu') + old_pp_size = old_meta['parallel_setting']['pp_size'] + old_zero1_size = old_meta['parallel_setting']['zero1_size'] + if 'tp_size' in old_meta['parallel_setting']: + tp_mode = 'tp' + elif 'wp_size' in old_meta['parallel_setting']: + tp_mode = 'wp' + else: + assert False, "tp or wp should be in parallel setting." 
+ old_tp_size = old_meta['parallel_setting'][f"{tp_mode}_size"] + + old_meta_data = {} + for pp_rank in range(old_pp_size): + for zero_rank in range(old_zero1_size): + for states in old_meta['metaData'][0][pp_rank][zero_rank].values(): + old_meta_data.update(states) + + old_map_local_to_global = [{} for _ in range(old_pp_size)] + for global_fqn, states in old_meta_data.items(): + old_map_local_to_global[states['pp']][states['fqn']] = global_fqn + + + new_meta = torch.load(meta_path, map_location='cpu') + new_pp_size = new_meta['parallel_setting']['pp_size'] + new_zero1_size = new_meta['parallel_setting']['zero1_size'] + new_tp_size = new_meta['parallel_setting'][f"{tp_mode}_size"] + assert set(new_meta['metaData'][0][0][0].keys()) == set(old_meta['metaData'][0][0][0].keys()), "Error: old meta and new meta have diffent group_id lists." + group_id_list = list(new_meta['metaData'][0][0][0].keys()) + + new_meta_data = {} + for pp_rank in range(new_pp_size): + for zero_rank in range(new_zero1_size): + for states in new_meta['metaData'][0][pp_rank][zero_rank].values(): + new_meta_data.update(states) + + + # new_states = [[{} for _ in range(new_pp_size)] for _ in range(new_tp_size)] + # convert_modeling_ckpt(old_pp_size, new_pp_size, old_tp_size, new_tp_size, old_meta_data, new_meta_data, old_map_local_to_global, tp_mode, folder, saved_folder, new_states) + + processed_ckpt_states = preprocess_optimizer_state(old_tp_size, old_pp_size, old_zero1_size, old_meta, folder) + group_dict = {group_id: {'global_fqn': [], 'exp_avg': [], 'exp_avg_sq': [], 'fp32_weights': []} for group_id in group_id_list} + new_states = [[[group_dict for _ in range(new_zero1_size)] for _ in range(new_pp_size)] for _ in range(new_tp_size)] + convert_optimizer_ckpt(old_meta, new_meta, old_pp_size, new_pp_size, old_tp_size, new_tp_size, old_zero1_size, new_zero1_size, old_meta_data, new_meta_data, saved_folder, new_states, processed_ckpt_states) + + + # srun -p llm_s -x HOST-10-140-60-16 --gres=gpu:8 python tools/convert_ckpt_parallel.py > debug1.log 2>&1 + # srun -p llm_s python tools/convert_ckpt_parallel.py 2>&1|tee debug1.log + + + + \ No newline at end of file From 21704bcf727c23291f5a60c6373b9fd20d8b2f94 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Wed, 19 Feb 2025 21:21:01 +0800 Subject: [PATCH 2/9] pre-commit --- tools/convert_ckpt_parallel.py | 684 ++++++++++++++++++++++----------- 1 file changed, 452 insertions(+), 232 deletions(-) diff --git a/tools/convert_ckpt_parallel.py b/tools/convert_ckpt_parallel.py index d8345ac00..72622eaf6 100644 --- a/tools/convert_ckpt_parallel.py +++ b/tools/convert_ckpt_parallel.py @@ -2,13 +2,10 @@ import os import shutil import sys -from collections import defaultdict +from collections import OrderedDict, defaultdict import torch from torch._utils import _flatten_dense_tensors -from itertools import cycle -from collections import OrderedDict - current_dir = os.path.dirname(os.path.abspath(__file__)) sys.path.append(os.path.join(current_dir, "../")) @@ -16,11 +13,11 @@ def parse_args(): args = argparse.ArgumentParser() - args.add_argument("origin_model_path", type=str, default=None) - args.add_argument("target_model_path", type=str, default=None) - args.add_argument("--target_tp_size", type=int, default=0) - args.add_argument("--target_pp_size", type=int, default=0) - args.add_argument("--model_size", type=str, default="7B", choices=["7B", "20B", "70B"]) + args.add_argument("origin_ckpt_path", type=str, default=None) + args.add_argument("target_ckpt_path", type=str, default=None) + 
args.add_argument("--origin_meta_path", type=str, default=None) + args.add_argument("--target_meta_path", type=str, default=None) + args.add_argument("--copy_file", type=bool, default=True) return args.parse_args() @@ -42,7 +39,7 @@ def map_pp_lists(old_pp, new_pp): result = [] old_ranks = list(range(old_pp)) new_ranks = list(range(new_pp)) - + if old_pp > new_pp: ratio = old_pp // new_pp for i in old_ranks: @@ -50,11 +47,11 @@ def map_pp_lists(old_pp, new_pp): elif old_pp < new_pp: ratio = new_pp // old_pp for i in old_ranks: - result.append(new_ranks[i * ratio:(i + 1) * ratio]) + result.append(new_ranks[i * ratio : (i + 1) * ratio]) else: for i in old_ranks: result.append([new_ranks[i]]) - + assert len(result) == old_pp return result @@ -74,140 +71,200 @@ def flatten(input_): def unflatten_tensor(flat_tensor, states): """ 根据目标形状,将扁平化的张量拆分为多个子张量。 - - :param flat_tensor: 扁平化的张量 - :param shapes: 每个子张量的目标形状(list of tuples) - :return: 切分后的多个子张量列表 """ start = 0 unflat_tensors = [] - + for _, state in states.items(): - shape = state['shape'] + shape = state["shape"] size = torch.prod(torch.tensor(shape)) # 计算每个子张量的大小 - tensor = flat_tensor[start:start + size].reshape(*shape) # 切分并恢复形状 + tensor = flat_tensor[start : start + size].reshape(*shape) # 切分并恢复形状 unflat_tensors.append(tensor) start += size # 更新起始位置 - + return unflat_tensors -def preprocess_optimizer_state(old_tp_size, old_pp_size, old_zero1_size, old_meta, folder): - processed_ckpt_states = [[[{} for _ in range(old_zero1_size)] for _ in range(old_pp_size)] for _ in range(old_tp_size)] +def preprocess_optimizer_state(old_tp_size, old_pp_size, old_zero1_size, old_meta, folder, old_tp_mode): + processed_ckpt_states = [ + [[{} for _ in range(old_zero1_size)] for _ in range(old_pp_size)] for _ in range(old_tp_size) + ] for old_tp_rank in range(old_tp_size): for old_pp_rank in range(old_pp_size): for old_zero1_rank in range(old_zero1_size): - ckpt_states = torch.load(os.path.join(folder, f"optimizer_tp{old_tp_rank}_pp{old_pp_rank}_zo{old_zero1_rank}.pt"), map_location="cpu") - base_optim_states = ckpt_states['base_optim_states']['state'] - flat_fp32_weights = ckpt_states['flat_fp32_weights'] + ckpt_states = torch.load( + os.path.join(folder, f"optimizer_{old_tp_mode}{old_tp_rank}_pp{old_pp_rank}_zo{old_zero1_rank}.pt"), + map_location="cpu", + ) + base_optim_states = ckpt_states["base_optim_states"]["state"] + flat_fp32_weights = ckpt_states["flat_fp32_weights"] processed_state = ckpt_states for group_id in list(base_optim_states.keys()): - exp_avg = base_optim_states[group_id]['exp_avg'] - exp_avg_sq = base_optim_states[group_id]['exp_avg_sq'] + exp_avg = base_optim_states[group_id]["exp_avg"] + exp_avg_sq = base_optim_states[group_id]["exp_avg_sq"] flat_tensor = flat_fp32_weights[group_id] - metaData = old_meta['metaData'][old_tp_rank][old_pp_rank][old_zero1_rank][group_id] - + metaData = old_meta["metaData"][old_tp_rank][old_pp_rank][old_zero1_rank][group_id] + unflat_exp_avg = unflatten_tensor(exp_avg, metaData) unflat_exp_avg_sq = unflatten_tensor(exp_avg_sq, metaData) unflat_tensor = unflatten_tensor(flat_tensor, metaData) - - processed_state['base_optim_states']['state'][group_id]['exp_avg'] = unflat_exp_avg - processed_state['base_optim_states']['state'][group_id]['exp_avg_sq'] = unflat_exp_avg_sq - processed_state['flat_fp32_weights'][group_id] = unflat_tensor - + + processed_state["base_optim_states"]["state"][group_id]["exp_avg"] = unflat_exp_avg + processed_state["base_optim_states"]["state"][group_id]["exp_avg_sq"] = 
unflat_exp_avg_sq + processed_state["flat_fp32_weights"][group_id] = unflat_tensor + processed_ckpt_states[old_tp_rank][old_pp_rank][old_zero1_rank] = processed_state - + return processed_ckpt_states def check_optimizer_convert(target_meta, target_states, group_id): fqn_list = list(target_meta.keys()) - index = len(target_states['global_fqn']) - 1 + index = len(target_states["global_fqn"]) - 1 meta_fqn = fqn_list[index] - meta_shape = target_meta[meta_fqn]['shape'] - meta_group_id = target_meta[meta_fqn]['group_id'] - states_fqn = target_states['global_fqn'][-1] - states_shape = target_states['fp32_weights'][-1].shape - - print(fqn_list) - print(target_states['global_fqn']) - assert meta_fqn == states_fqn, f"states_fqn {states_fqn} and meta_fqn {meta_fqn} are not the same." - assert meta_group_id == group_id, f"For {states_fqn}: group_id {states_shape} and meta_group_id {meta_shape} are not the same." - assert meta_shape == states_shape, f"For {states_fqn}: states_shape {states_shape} and meta_shape {meta_shape} are not the same." - + meta_shape = target_meta[meta_fqn]["shape"] + meta_group_id = target_meta[meta_fqn]["group_id"] + states_fqn = target_states["global_fqn"][-1] + states_shape = target_states["fp32_weights"][-1].shape -def model_tp_split(split_maps, old_pp_rank, old_tp_size, new_states, old_meta_data, new_meta_data, ratio, tp_mode, old_map_local_to_global, new_meta, folder): + assert meta_fqn == states_fqn, f"states_fqn {states_fqn} and meta_fqn {meta_fqn} are not the same." + assert ( + meta_group_id == group_id + ), f"For {states_fqn}: group_id {states_shape} and meta_group_id {meta_shape} are not the same." + assert ( + meta_shape == states_shape + ), f"For {states_fqn}: states_shape {states_shape} and meta_shape {meta_shape} are not the same." + + +def sort_optimizer_state(target_dict, meta_fqns): + assert len(target_dict.keys()) == len(meta_fqns), f"fqn length error: {len(target_dict.keys())} != {len(meta_fqns)}" + assert set(target_dict.keys()) == set(meta_fqns), f"fqns not equal: {list(target_dict.keys())} != {list(meta_fqns)}" + sorted_exp_avg = [target_dict[key]["exp_avg"] for key in meta_fqns] + sorted_exp_avg_sq = [target_dict[key]["exp_avg_sq"] for key in meta_fqns] + sorted_fp32_weights = [target_dict[key]["fp32_weights"] for key in meta_fqns] + + return sorted_exp_avg, sorted_exp_avg_sq, sorted_fp32_weights + + +def model_tp_split( + split_maps, + old_pp_rank, + old_tp_size, + new_states, + old_meta_data, + new_meta_data, + ratio, + old_tp_mode, + old_map_local_to_global, + new_meta, + folder, +): for old_tp_rank in range(old_tp_size): - ckpt_states = torch.load(os.path.join(folder, f"model_{tp_mode}{old_tp_rank}_pp{old_pp_rank}.pt"), map_location='cpu') + ckpt_states = torch.load( + os.path.join(folder, f"model_{old_tp_mode}{old_tp_rank}_pp{old_pp_rank}.pt"), map_location="cpu" + ) for fqn, tensor in ckpt_states.items(): assert len(tensor.size()) < 3, "Only support 2D or 1D tensors." 
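+            # tp_dim == -1 marks a parameter that is replicated across tp ranks; any other value is the dimension it is split along.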
global_fqn = old_map_local_to_global[old_pp_rank][fqn] - tp_dim = old_meta_data[global_fqn]['tp_dim'] - assert tp_dim == new_meta_data[global_fqn]['tp_dim'], f"{global_fqn} tp_dim in old and new meta are not equal: old={tp_dim}, new={new_meta_data[fqn]['tp_dim']}" - new_pp_rank = new_meta_data[global_fqn]['pp'] - new_zero1_rank = new_meta_data[global_fqn]['zero1'] - new_fqn = new_meta_data[global_fqn]['fqn'] - group_id = new_meta_data[global_fqn]['group_id'] - + tp_dim = old_meta_data[global_fqn]["tp_dim"] + assert tp_dim == new_meta_data[global_fqn]["tp_dim"], ( + f"{global_fqn} tp_dim in old and new meta are not equal: " + f"old={tp_dim}, new={new_meta_data[fqn]['tp_dim']}" + ) + + new_pp_rank = new_meta_data[global_fqn]["pp"] + new_zero1_rank = new_meta_data[global_fqn]["zero1"] + new_fqn = new_meta_data[global_fqn]["fqn"] + group_id = new_meta_data[global_fqn]["group_id"] + if tp_dim == -1: for _, new_tp_rank in enumerate(split_maps[old_tp_rank]): new_states[new_tp_rank][new_pp_rank][new_fqn] = tensor.detach().clone() splited_shape = new_states[new_tp_rank][new_pp_rank][new_fqn].shape - meta_shape = new_meta['metaData'][new_tp_rank][new_pp_rank][new_zero1_rank][group_id][global_fqn]['shape'] - assert splited_shape == meta_shape, f"{new_fqn}: splited shape {splited_shape} is not euqal to metaData {meta_shape}" + meta_shape = new_meta["metaData"][new_tp_rank][new_pp_rank][new_zero1_rank][group_id][global_fqn][ + "shape" + ] + assert ( + splited_shape == meta_shape + ), f"{new_fqn}: splited shape {splited_shape} is not euqal to metaData {meta_shape}" else: split_size = tensor.size()[tp_dim] // ratio new_tp_splits = torch.split(tensor, split_size, dim=tp_dim) for i, new_tp_rank in enumerate(split_maps[old_tp_rank]): new_states[new_tp_rank][new_pp_rank][new_fqn] = new_tp_splits[i].detach().clone() splited_shape = new_states[new_tp_rank][new_pp_rank][new_fqn].shape - meta_shape = new_meta['metaData'][new_tp_rank][new_pp_rank][new_zero1_rank][group_id][global_fqn]['shape'] - assert splited_shape == meta_shape, f"{new_fqn}: splited shape {splited_shape} is not euqal to metaData {meta_shape}" - - -def optimizer_tp_split(split_maps, old_tp_size, old_pp_rank, old_zero1_rank, old_meta, new_meta, new_meta_data, processed_ckpt_states, new_states, ratio): + meta_shape = new_meta["metaData"][new_tp_rank][new_pp_rank][new_zero1_rank][group_id][global_fqn][ + "shape" + ] + assert ( + splited_shape == meta_shape + ), f"{new_fqn}: splited shape {splited_shape} is not euqal to metaData {meta_shape}" + + +def optimizer_tp_split( + split_maps, + old_tp_size, + old_pp_rank, + old_zero1_rank, + old_meta, + new_meta_data, + processed_ckpt_states, + new_states, + ratio, +): for old_tp_rank in range(old_tp_size): ckpt_states = processed_ckpt_states[old_tp_rank][old_pp_rank][old_zero1_rank] - for group_id in ckpt_states['flat_fp32_weights'].keys(): - old_metaData = old_meta['metaData'][old_tp_rank][old_pp_rank][old_zero1_rank][group_id] - exp_avg_list = ckpt_states['base_optim_states']['state'][group_id]['exp_avg'] - exp_avg_sq_list = ckpt_states['base_optim_states']['state'][group_id]['exp_avg_sq'] - fp32_weights_list = ckpt_states['flat_fp32_weights'][group_id] - + for group_id in ckpt_states["flat_fp32_weights"].keys(): + old_metaData = old_meta["metaData"][old_tp_rank][old_pp_rank][old_zero1_rank][group_id] + exp_avg_list = ckpt_states["base_optim_states"]["state"][group_id]["exp_avg"] + exp_avg_sq_list = ckpt_states["base_optim_states"]["state"][group_id]["exp_avg_sq"] + fp32_weights_list = 
ckpt_states["flat_fp32_weights"][group_id] + for i, global_fqn in enumerate(list(old_metaData.keys())): - tp_dim = old_metaData[global_fqn]['tp_dim'] - new_pp_rank = new_meta_data[global_fqn]['pp'] - new_zero1_rank = new_meta_data[global_fqn]['zero1'] - + tp_dim = old_metaData[global_fqn]["tp_dim"] + new_pp_rank = new_meta_data[global_fqn]["pp"] + new_zero1_rank = new_meta_data[global_fqn]["zero1"] + if tp_dim == -1: for _, new_tp_rank in enumerate(split_maps[old_tp_rank]): target_new_states = new_states[new_tp_rank][new_pp_rank][new_zero1_rank][group_id] - target_new_states['global_fqn'].append(global_fqn) - target_new_states['exp_avg'].append(exp_avg_list[i].detach().clone()) - target_new_states['exp_avg_sq'].append(exp_avg_sq_list[i].detach().clone()) - target_new_states['fp32_weights'].append(fp32_weights_list[i].detach().clone()) - - check_optimizer_convert(new_meta['metaData'][new_tp_rank][new_pp_rank][new_zero1_rank][group_id], target_new_states, group_id) + if global_fqn not in target_new_states: + target_new_states[global_fqn] = {"exp_avg": None, "exp_avg_sq": None, "fp32_weights": None} + target_new_states[global_fqn]["exp_avg"] = exp_avg_list[i].detach().clone() + target_new_states[global_fqn]["exp_avg_sq"] = exp_avg_sq_list[i].detach().clone() + target_new_states[global_fqn]["fp32_weights"] = fp32_weights_list[i].detach().clone() else: - split_size = old_metaData[global_fqn]['shape'][tp_dim] // ratio + split_size = old_metaData[global_fqn]["shape"][tp_dim] // ratio new_exp_avg_splits = torch.split(exp_avg_list[i], split_size, dim=tp_dim) new_exp_avg_sq_splits = torch.split(exp_avg_sq_list[i], split_size, dim=tp_dim) new_fp32_weights_splits = torch.split(fp32_weights_list[i], split_size, dim=tp_dim) for j, new_tp_rank in enumerate(split_maps[old_tp_rank]): target_new_states = new_states[new_tp_rank][new_pp_rank][new_zero1_rank][group_id] - target_new_states['global_fqn'].append(global_fqn) - target_new_states['exp_avg'].append(new_exp_avg_splits[j].detach().clone()) - target_new_states['exp_avg_sq'].append(new_exp_avg_sq_splits[j].detach().clone()) - target_new_states['fp32_weights'].append(new_fp32_weights_splits[j].detach().clone()) - - print(new_tp_rank, new_pp_rank, new_zero1_rank, group_id) - check_optimizer_convert(new_meta['metaData'][new_tp_rank][new_pp_rank][new_zero1_rank][group_id], target_new_states, group_id) - - -def model_tp_merge(old_pp_rank, new_states, old_tp_size, new_tp_size, tp_mode, ratio, old_meta_data, new_meta_data, old_map_local_to_global, new_meta, folder): + if global_fqn not in target_new_states: + target_new_states[global_fqn] = {"exp_avg": None, "exp_avg_sq": None, "fp32_weights": None} + target_new_states[global_fqn]["exp_avg"] = new_exp_avg_splits[j].detach().clone() + target_new_states[global_fqn]["exp_avg_sq"] = new_exp_avg_sq_splits[j].detach().clone() + target_new_states[global_fqn]["fp32_weights"] = new_fp32_weights_splits[j].detach().clone() + + +def model_tp_merge( + old_pp_rank, + new_states, + old_tp_size, + new_tp_size, + old_tp_mode, + ratio, + old_meta_data, + new_meta_data, + old_map_local_to_global, + new_meta, + folder, +): candidate_states = [defaultdict(list) for _ in range(new_tp_size)] for old_tp_rank in range(old_tp_size): - ckpt_states = torch.load(os.path.join(folder, f"model_{tp_mode}{old_tp_rank}_pp{old_pp_rank}.pt"), map_location="cpu") + ckpt_states = torch.load( + os.path.join(folder, f"model_{old_tp_mode}{old_tp_rank}_pp{old_pp_rank}.pt"), map_location="cpu" + ) for fqn, tensor in ckpt_states.items(): assert 
len(tensor.size()) < 3, "Only support 2D or 1D tensors." new_tp_rank = old_tp_rank // ratio @@ -216,217 +273,380 @@ def model_tp_merge(old_pp_rank, new_states, old_tp_size, new_tp_size, tp_mode, r for new_tp_rank, states in enumerate(candidate_states): for fqn, tensor_list in states.items(): global_fqn = old_map_local_to_global[old_pp_rank][fqn] - tp_dim = old_meta_data[global_fqn]['tp_dim'] - assert tp_dim == new_meta_data[global_fqn]['tp_dim'], f"{global_fqn} tp_dim in old and new meta are not equal: old={tp_dim}, new={new_meta_data[fqn]['tp_dim']}" - new_pp_rank = new_meta_data[global_fqn]['pp'] - new_zero1_rank = new_meta_data[global_fqn]['zero1'] - new_fqn = new_meta_data[global_fqn]['fqn'] - group_id = new_meta_data[global_fqn]['group_id'] - + tp_dim = old_meta_data[global_fqn]["tp_dim"] + assert tp_dim == new_meta_data[global_fqn]["tp_dim"], ( + f"{global_fqn} tp_dim in old and new meta are not equal: " + f"old={tp_dim}, new={new_meta_data[fqn]['tp_dim']}" + ) + + new_pp_rank = new_meta_data[global_fqn]["pp"] + new_zero1_rank = new_meta_data[global_fqn]["zero1"] + new_fqn = new_meta_data[global_fqn]["fqn"] + group_id = new_meta_data[global_fqn]["group_id"] + if tp_dim == -1: - assert torch.equal(tensor_list[0], tensor_list[1]), f"{global_fqn} should not be splited by tp, but the tensors in different checkpoints are not equal." + assert torch.equal( + tensor_list[0], tensor_list[1] + ), f"{global_fqn} should not be splited by tp, but the tensors in different checkpoints are not equal." new_states[new_tp_rank][new_pp_rank][new_fqn] = tensor_list[0].detach().clone() else: new_states[new_tp_rank][new_pp_rank][new_fqn] = torch.concat(tensor_list, dim=tp_dim).detach().clone() - + splited_shape = new_states[new_tp_rank][new_pp_rank][new_fqn].shape - meta_shape = new_meta['metaData'][new_tp_rank][new_pp_rank][new_zero1_rank][group_id][global_fqn]['shape'] - assert splited_shape == meta_shape, f"{new_fqn}: splited shape {splited_shape} is not euqal to metaData {meta_shape}" - - -def optimizer_tp_merge(new_tp_size, old_tp_size, old_pp_rank, old_zero1_rank, old_meta, new_meta, old_meta_data, new_meta_data, processed_ckpt_states, new_states, ratio): + meta_shape = new_meta["metaData"][new_tp_rank][new_pp_rank][new_zero1_rank][group_id][global_fqn]["shape"] + assert ( + splited_shape == meta_shape + ), f"{new_fqn}: splited shape {splited_shape} is not euqal to metaData {meta_shape}" + + +def optimizer_tp_merge( + new_tp_size, + old_tp_size, + old_pp_rank, + old_zero1_rank, + old_meta, + old_meta_data, + new_meta_data, + processed_ckpt_states, + new_states, + ratio, +): candidate_exp_avg = [defaultdict(list) for _ in range(new_tp_size)] candidate_exp_avg_sq = [defaultdict(list) for _ in range(new_tp_size)] candidate_fp32_weights = [defaultdict(list) for _ in range(new_tp_size)] for old_tp_rank in range(old_tp_size): ckpt_states = processed_ckpt_states[old_tp_rank][old_pp_rank][old_zero1_rank] - for group_id in ckpt_states['flat_fp32_weights'].keys(): - old_metaData = old_meta['metaData'][old_tp_rank][old_pp_rank][old_zero1_rank][group_id] - exp_avg_list = ckpt_states['base_optim_states']['state'][group_id]['exp_avg'] - exp_avg_sq_list = ckpt_states['base_optim_states']['state'][group_id]['exp_avg_sq'] - fp32_weights_list = ckpt_states['flat_fp32_weights'][group_id] + for group_id in ckpt_states["flat_fp32_weights"].keys(): + old_metaData = old_meta["metaData"][old_tp_rank][old_pp_rank][old_zero1_rank][group_id] + exp_avg_list = ckpt_states["base_optim_states"]["state"][group_id]["exp_avg"] + 
exp_avg_sq_list = ckpt_states["base_optim_states"]["state"][group_id]["exp_avg_sq"] + fp32_weights_list = ckpt_states["flat_fp32_weights"][group_id] new_tp_rank = old_tp_rank // ratio for i, global_fqn in enumerate(list(old_metaData.keys())): - assert group_id == new_meta_data[global_fqn]['group_id'] + assert group_id == new_meta_data[global_fqn]["group_id"] candidate_exp_avg[new_tp_rank][global_fqn].append(exp_avg_list[i]) candidate_exp_avg_sq[new_tp_rank][global_fqn].append(exp_avg_sq_list[i]) candidate_fp32_weights[new_tp_rank][global_fqn].append(fp32_weights_list[i]) - - for new_tp_rank in len(candidate_fp32_weights): - for global_fqn in candidate_fp32_weights.keys(): + + for new_tp_rank in range(len(candidate_fp32_weights)): + for global_fqn in candidate_fp32_weights[new_tp_rank].keys(): splited_exp_avg = candidate_exp_avg[new_tp_rank][global_fqn] splited_exp_avg_sq = candidate_exp_avg_sq[new_tp_rank][global_fqn] splited_fp32_weights = candidate_fp32_weights[new_tp_rank][global_fqn] - - tp_dim = old_meta_data[global_fqn]['tp_dim'] - new_pp_rank = new_meta_data[global_fqn]['pp'] - new_zero1_rank = new_meta_data[global_fqn]['zero1'] - group_id = new_meta_data[global_fqn]['group_id'] + + tp_dim = old_meta_data[global_fqn]["tp_dim"] + new_pp_rank = new_meta_data[global_fqn]["pp"] + new_zero1_rank = new_meta_data[global_fqn]["zero1"] + group_id = new_meta_data[global_fqn]["group_id"] target_new_states = new_states[new_tp_rank][new_pp_rank][new_zero1_rank][group_id] - + if global_fqn not in target_new_states: + target_new_states[global_fqn] = {"exp_avg": None, "exp_avg_sq": None, "fp32_weights": None} + if tp_dim == -1: - assert torch.equal(splited_fp32_weights[0], splited_fp32_weights[1]), f"{global_fqn} should not be splited by tp, but the tensors in different checkpoints are not equal." - target_new_states['global_fqn'].append(global_fqn) - target_new_states['exp_avg'].append(splited_exp_avg[0].detach().clone()) - target_new_states['exp_avg_sq'].append(splited_exp_avg_sq[0].detach().clone()) - target_new_states['fp32_weights'].append(splited_fp32_weights[0].detach().clone()) + assert torch.equal( + splited_fp32_weights[0], splited_fp32_weights[1] + ), f"{global_fqn} should not be splited by tp, but the tensors in different checkpoints are not equal." 
+ target_new_states[global_fqn]["exp_avg"] = splited_exp_avg[0].detach().clone() + target_new_states[global_fqn]["exp_avg_sq"] = splited_exp_avg_sq[0].detach().clone() + target_new_states[global_fqn]["fp32_weights"] = splited_fp32_weights[0].detach().clone() else: - target_new_states['global_fqn'].append(global_fqn) - target_new_states['exp_avg'].append(torch.concat(splited_exp_avg, dim=tp_dim).detach().clone()) - target_new_states['exp_avg_sq'].append(torch.concat(splited_exp_avg_sq, dim=tp_dim).detach().clone()) - target_new_states['fp32_weights'].append(torch.concat(splited_fp32_weights, dim=tp_dim).detach().clone()) - - check_optimizer_convert(new_meta['metaData'][new_tp_rank][new_pp_rank][new_zero1_rank][group_id], target_new_states, group_id) - - -def convert_modeling_ckpt(old_pp_size, new_pp_size, old_tp_size, new_tp_size, old_meta_data, new_meta_data, old_map_local_to_global, tp_mode, folder, saved_folder, new_states): + target_new_states[global_fqn]["exp_avg"] = torch.concat(splited_exp_avg, dim=tp_dim).detach().clone() + target_new_states[global_fqn]["exp_avg_sq"] = ( + torch.concat(splited_exp_avg_sq, dim=tp_dim).detach().clone() + ) + target_new_states[global_fqn]["fp32_weights"] = ( + torch.concat(splited_fp32_weights, dim=tp_dim).detach().clone() + ) + + +def convert_modeling_ckpt( + old_pp_size, + new_pp_size, + old_tp_size, + new_tp_size, + old_meta_data, + new_meta_data, + old_map_local_to_global, + old_tp_mode, + new_tp_mode, + folder, + saved_folder, + new_states, +): + print("Begin model convert", flush=True) for old_pp_rank in range(old_pp_size): if old_tp_size != new_tp_size: if old_tp_size > new_tp_size: assert old_tp_size % new_tp_size == 0, f"Cannot convert {old_tp_size} TP to {new_tp_size} TP." ratio = old_tp_size // new_tp_size - model_tp_merge(old_pp_rank, new_states, old_tp_size, new_tp_size, tp_mode, ratio, old_meta_data, new_meta_data, old_map_local_to_global, new_meta, folder) + model_tp_merge( + old_pp_rank, + new_states, + old_tp_size, + new_tp_size, + old_tp_mode, + ratio, + old_meta_data, + new_meta_data, + old_map_local_to_global, + new_meta, + folder, + ) else: assert new_tp_size % old_tp_size == 0, f"Cannot convert {old_tp_size} TP to {new_tp_size} TP." split_maps = get_mapping(old_tp_size, new_tp_size) ratio = new_tp_size // old_tp_size - model_tp_split(split_maps, old_pp_rank, old_tp_size, new_states, old_meta_data, new_meta_data, ratio, tp_mode, old_map_local_to_global, new_meta, folder) + model_tp_split( + split_maps, + old_pp_rank, + old_tp_size, + new_states, + old_meta_data, + new_meta_data, + ratio, + old_tp_mode, + old_map_local_to_global, + new_meta, + folder, + ) else: - print(f"New tp and old tp are equal. 
Directly copy.") for old_tp_rank in range(old_tp_size): - ckpt_states = torch.load(os.path.join(folder, f"model_{tp_mode}{old_tp_rank}_pp{old_pp_rank}.pt"), map_location='cpu') + ckpt_states = torch.load( + os.path.join(folder, f"model_{old_tp_mode}{old_tp_rank}_pp{old_pp_rank}.pt"), map_location="cpu" + ) for fqn, tensor in ckpt_states.items(): - new_pp_rank = new_meta_data[fqn]['pp'] - new_states[new_tp_rank][new_pp_rank][fqn] = tensor.detach().clone() + global_fqn = old_map_local_to_global[old_pp_rank][fqn] + new_pp_rank = new_meta_data[global_fqn]["pp"] + new_fqn = new_meta_data[global_fqn]["fqn"] + new_states[old_tp_rank][new_pp_rank][new_fqn] = tensor.detach().clone() + for new_tp_rank in range(new_tp_size): for new_pp_rank in range(new_pp_size): # print(f"pp={new_pp_rank}, tp={new_tp_rank}: {new_states[new_tp_rank][new_pp_rank].keys()}") - file_name = f"model_{tp_mode}{new_tp_rank}_pp{new_pp_rank}.pt" + file_name = f"model_{new_tp_mode}{new_tp_rank}_pp{new_pp_rank}.pt" states = sorted_state_dict(new_states[new_tp_rank][new_pp_rank]) - # torch.save(os.path.join(saved_folder, file_name), states) - - print(f"Finish Modeling") - - -def convert_optimizer_ckpt(old_meta, new_meta, old_pp_size, new_pp_size, old_tp_size, new_tp_size, old_zero1_size, new_zero1_size, old_meta_data, new_meta_data, saved_folder, new_states, processed_ckpt_states): + torch.save(states, os.path.join(saved_folder, file_name)) + + print("Finish model convert", flush=True) + + +def convert_optimizer_ckpt( + old_meta, + new_meta, + old_pp_size, + new_pp_size, + old_tp_size, + new_tp_size, + old_zero1_size, + new_zero1_size, + old_meta_data, + new_meta_data, + new_tp_mode, + saved_folder, + new_states, + processed_ckpt_states, +): + print("Begin optimizer convert", flush=True) for old_pp_rank in range(old_pp_size): for old_zero1_rank in range(old_zero1_size): if old_tp_size != new_tp_size: if old_tp_size > new_tp_size: assert old_tp_size % new_tp_size == 0, f"Cannot convert {old_tp_size} TP to {new_tp_size} TP." ratio = old_tp_size // new_tp_size - optimizer_tp_merge(new_tp_size, old_tp_size, old_pp_rank, old_zero1_rank, old_meta, new_meta, old_meta_data, new_meta_data, processed_ckpt_states, new_states, ratio) + optimizer_tp_merge( + new_tp_size, + old_tp_size, + old_pp_rank, + old_zero1_rank, + old_meta, + old_meta_data, + new_meta_data, + processed_ckpt_states, + new_states, + ratio, + ) else: assert new_tp_size % old_tp_size == 0, f"Cannot convert {old_tp_size} TP to {new_tp_size} TP." split_maps = get_mapping(old_tp_size, new_tp_size) ratio = new_tp_size // old_tp_size - optimizer_tp_split(split_maps, old_tp_size, old_pp_rank, old_zero1_rank, old_meta, new_meta, new_meta_data, processed_ckpt_states, new_states, ratio) + optimizer_tp_split( + split_maps, + old_tp_size, + old_pp_rank, + old_zero1_rank, + old_meta, + new_meta_data, + processed_ckpt_states, + new_states, + ratio, + ) else: - print(f"New tp and old tp are equal. 
Directly copy.") for old_tp_rank in range(old_tp_size): ckpt_states = processed_ckpt_states[old_tp_rank][old_pp_rank][old_zero1_rank] - for group_id in ckpt_states['flat_fp32_weights'].keys(): - old_metaData = old_meta['metaData'][old_tp_rank][old_pp_rank][old_zero1_rank][group_id] - exp_avg_list = ckpt_states['base_optim_states']['state'][group_id]['exp_avg'] - exp_avg_sq_list = ckpt_states['base_optim_states']['state'][group_id]['exp_avg_sq'] - fp32_weights_list = ckpt_states['flat_fp32_weights'][group_id] + for group_id in ckpt_states["flat_fp32_weights"].keys(): + old_metaData = old_meta["metaData"][old_tp_rank][old_pp_rank][old_zero1_rank][group_id] + exp_avg_list = ckpt_states["base_optim_states"]["state"][group_id]["exp_avg"] + exp_avg_sq_list = ckpt_states["base_optim_states"]["state"][group_id]["exp_avg_sq"] + fp32_weights_list = ckpt_states["flat_fp32_weights"][group_id] for i, global_fqn in enumerate(list(old_metaData.keys())): - new_pp_rank = new_meta_data[global_fqn]['pp'] - new_zero1_rank = new_meta_data[global_fqn]['zero1'] - target_new_states = new_states[new_tp_rank][new_pp_rank][new_zero1_rank][group_id] - target_new_states['global_fqn'].append(global_fqn) - target_new_states['exp_avg'].append(exp_avg_list[i].detach().clone()) - target_new_states['exp_avg_sq'].append(exp_avg_sq_list[i].detach().clone()) - target_new_states['fp32_weights'].append(fp32_weights_list[i].detach().clone()) - + new_pp_rank = new_meta_data[global_fqn]["pp"] + new_zero1_rank = new_meta_data[global_fqn]["zero1"] + target_new_states = new_states[old_tp_rank][new_pp_rank][new_zero1_rank][group_id] + if global_fqn not in target_new_states: + target_new_states[global_fqn] = { + "exp_avg": None, + "exp_avg_sq": None, + "fp32_weights": None, + } + target_new_states[global_fqn]["exp_avg"] = exp_avg_list[i].detach().clone() + target_new_states[global_fqn]["exp_avg_sq"] = exp_avg_sq_list[i].detach().clone() + target_new_states[global_fqn]["fp32_weights"] = fp32_weights_list[i].detach().clone() + for new_tp_rank in range(new_tp_size): for new_pp_rank in range(new_pp_size): for new_zero1_rank in range(new_zero1_size): - file_name = f"optimizer_tp{new_tp_rank}_pp{new_pp_rank}_zo{new_zero1_rank}.pt" + file_name = f"optimizer_{new_tp_mode}{new_tp_rank}_pp{new_pp_rank}_zo{new_zero1_rank}.pt" optimizer_state = new_states[new_tp_rank][new_pp_rank][new_zero1_rank] + metaData = new_meta["metaData"][new_tp_rank][new_pp_rank][new_zero1_rank] + assert set(optimizer_state.keys()) == set( + metaData.keys() + ), f"group_id error: state {list((optimizer_state.keys()))} is different from {list(metaData.keys())}" + base_state = processed_ckpt_states[0][0][0] - step = base_state['base_optim_states']['state'][0]['step'] - base_state['base_optim_states']['state'] = {} - base_state['flat_fp32_weights'] = {} - if 'zero_devide_optim_plan' in base_state: - base_state.pop('zero_devide_optim_plan') - + step = base_state["base_optim_states"]["state"][0]["step"] + base_state["base_optim_states"]["state"] = {} + base_state["flat_fp32_weights"] = {} + if "zero_devide_optim_plan" in base_state: + base_state.pop("zero_devide_optim_plan") + for group_id in optimizer_state.keys(): - if len(optimizer_state[group_id]['global_fqn']) == 0: - print(f"Warning: tp{new_tp_rank}_pp{new_pp_rank}_zo{new_zero1_rank}_groupId{group_id} has no param.") - continue - - assert optimizer_state[group_id]['global_fqn'] == list(new_meta['metaData'][new_tp_rank][new_pp_rank][new_zero1_rank][group_id].keys()) - flat_exp_avg = 
flatten(optimizer_state[group_id]['exp_avg']) - flat_exp_avg_sq = flatten(optimizer_state[group_id]['exp_avg_sq']) - flat_fp32_weights = flatten(optimizer_state[group_id]['fp32_weights']) - state = {'step': step, 'exp_avg': flat_exp_avg, 'exp_avg_sq': flat_exp_avg_sq} - base_state['base_optim_states']['state'][group_id] = state - base_state['flat_fp32_weights'][group_id] = flat_fp32_weights - - print(f"optimizer tp{new_tp_rank}_pp{new_pp_rank}_zo{new_zero1_rank}: {base_state}") - # torch.save(os.path.join(saved_folder, file_name), base_state) - - print(f"Finish optimizer") - - + meta_fqns = metaData[group_id].keys() + sorted_exp_avg, sorted_exp_avg_sq, sorted_fp32_weights = sort_optimizer_state( + optimizer_state[group_id], meta_fqns + ) + flat_exp_avg = flatten(sorted_exp_avg) + flat_exp_avg_sq = flatten(sorted_exp_avg_sq) + flat_fp32_weights = flatten(sorted_fp32_weights) + state = {"step": step, "exp_avg": flat_exp_avg, "exp_avg_sq": flat_exp_avg_sq} + base_state["base_optim_states"]["state"][group_id] = state + base_state["flat_fp32_weights"][group_id] = flat_fp32_weights + + # print(f"optimizer tp{new_tp_rank}_pp{new_pp_rank}_zo{new_zero1_rank}: {base_state}") + torch.save(base_state, os.path.join(saved_folder, file_name)) + + print("Finish optimizer convert", flush=True) + + if __name__ == "__main__": - folder = "/mnt/petrelfs/lijiaxing/InternEvo/llm_ckpts_3/5" - saved_folder = "/mnt/petrelfs/lijiaxing/InternEvo/llm_ckpts_3_new" - meta_path = "/mnt/petrelfs/lijiaxing/InternEvo/llm_ckpts_3_new/metadata.pt" - - - old_meta = os.path.join(folder, f"metadata.pt") - assert os.path.exists(old_meta), 'old meta file does not exist, plese generate it before converting checkpoint.' - - old_meta = torch.load(old_meta, map_location='cpu') - old_pp_size = old_meta['parallel_setting']['pp_size'] - old_zero1_size = old_meta['parallel_setting']['zero1_size'] - if 'tp_size' in old_meta['parallel_setting']: - tp_mode = 'tp' - elif 'wp_size' in old_meta['parallel_setting']: - tp_mode = 'wp' + + args = parse_args() + folder = args.origin_ckpt_path + saved_folder = args.target_ckpt_path + + if args.origin_meta_path is not None: + old_meta_path = args.origin_meta_path + else: + old_meta_path = os.path.join(folder, "metadata.pt") + + if args.target_meta_path is not None: + new_meta_path = args.target_meta_path + else: + new_meta_path = os.path.join(saved_folder, "metadata.pt") + + assert os.path.exists( + old_meta_path + ), "old meta file does not exist, plese generate it before converting checkpoint." + assert os.path.exists( + new_meta_path + ), "new meta file does not exist, plese generate it before converting checkpoint." + + old_meta = torch.load(old_meta_path, map_location="cpu") + old_pp_size = old_meta["parallel_setting"]["pp_size"] + old_zero1_size = old_meta["parallel_setting"]["zero1_size"] + if "tp_size" in old_meta["parallel_setting"]: + old_tp_mode = "tp" + elif "wp_size" in old_meta["parallel_setting"]: + old_tp_mode = "wp" else: assert False, "tp or wp should be in parallel setting." 
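+    # The same tp/wp detection is repeated for the target metadata further below, and converting between tp and wp is rejected there.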
- old_tp_size = old_meta['parallel_setting'][f"{tp_mode}_size"] + old_tp_size = old_meta["parallel_setting"][f"{old_tp_mode}_size"] old_meta_data = {} for pp_rank in range(old_pp_size): for zero_rank in range(old_zero1_size): - for states in old_meta['metaData'][0][pp_rank][zero_rank].values(): + for states in old_meta["metaData"][0][pp_rank][zero_rank].values(): old_meta_data.update(states) old_map_local_to_global = [{} for _ in range(old_pp_size)] for global_fqn, states in old_meta_data.items(): - old_map_local_to_global[states['pp']][states['fqn']] = global_fqn - - - new_meta = torch.load(meta_path, map_location='cpu') - new_pp_size = new_meta['parallel_setting']['pp_size'] - new_zero1_size = new_meta['parallel_setting']['zero1_size'] - new_tp_size = new_meta['parallel_setting'][f"{tp_mode}_size"] - assert set(new_meta['metaData'][0][0][0].keys()) == set(old_meta['metaData'][0][0][0].keys()), "Error: old meta and new meta have diffent group_id lists." - group_id_list = list(new_meta['metaData'][0][0][0].keys()) + old_map_local_to_global[states["pp"]][states["fqn"]] = global_fqn + + new_meta = torch.load(new_meta_path, map_location="cpu") + new_pp_size = new_meta["parallel_setting"]["pp_size"] + new_zero1_size = new_meta["parallel_setting"]["zero1_size"] + if "tp_size" in new_meta["parallel_setting"]: + new_tp_mode = "tp" + elif "wp_size" in new_meta["parallel_setting"]: + new_tp_mode = "wp" + else: + assert False, "tp or wp should be in parallel setting." + # TODO: support converting between tp and wp + assert old_tp_mode == new_tp_mode, "Do not support converting between tp and wp currently." + new_tp_size = new_meta["parallel_setting"][f"{new_tp_mode}_size"] + assert set(new_meta["metaData"][0][0][0].keys()) == set( + old_meta["metaData"][0][0][0].keys() + ), "Error: old meta and new meta have diffent group_id lists." 
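+    # Both runs expose the same parameter-group ids; the target metadata is flattened below just like the source one.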
+ group_id_list = list(new_meta["metaData"][0][0][0].keys()) new_meta_data = {} for pp_rank in range(new_pp_size): for zero_rank in range(new_zero1_size): - for states in new_meta['metaData'][0][pp_rank][zero_rank].values(): + for states in new_meta["metaData"][0][pp_rank][zero_rank].values(): new_meta_data.update(states) - - - # new_states = [[{} for _ in range(new_pp_size)] for _ in range(new_tp_size)] - # convert_modeling_ckpt(old_pp_size, new_pp_size, old_tp_size, new_tp_size, old_meta_data, new_meta_data, old_map_local_to_global, tp_mode, folder, saved_folder, new_states) - - processed_ckpt_states = preprocess_optimizer_state(old_tp_size, old_pp_size, old_zero1_size, old_meta, folder) - group_dict = {group_id: {'global_fqn': [], 'exp_avg': [], 'exp_avg_sq': [], 'fp32_weights': []} for group_id in group_id_list} - new_states = [[[group_dict for _ in range(new_zero1_size)] for _ in range(new_pp_size)] for _ in range(new_tp_size)] - convert_optimizer_ckpt(old_meta, new_meta, old_pp_size, new_pp_size, old_tp_size, new_tp_size, old_zero1_size, new_zero1_size, old_meta_data, new_meta_data, saved_folder, new_states, processed_ckpt_states) - - - # srun -p llm_s -x HOST-10-140-60-16 --gres=gpu:8 python tools/convert_ckpt_parallel.py > debug1.log 2>&1 - # srun -p llm_s python tools/convert_ckpt_parallel.py 2>&1|tee debug1.log - - - - \ No newline at end of file + + new_states = [[{} for _ in range(new_pp_size)] for _ in range(new_tp_size)] + convert_modeling_ckpt( + old_pp_size=old_pp_size, + new_pp_size=new_pp_size, + old_tp_size=old_tp_size, + new_tp_size=new_tp_size, + old_meta_data=old_meta_data, + new_meta_data=new_meta_data, + old_map_local_to_global=old_map_local_to_global, + old_tp_mode=old_tp_mode, + new_tp_mode=new_tp_mode, + folder=folder, + saved_folder=saved_folder, + new_states=new_states, + ) + + processed_ckpt_states = preprocess_optimizer_state( + old_tp_size, old_pp_size, old_zero1_size, old_meta, folder, old_tp_mode + ) + new_states = [ + [[defaultdict(dict) for _ in range(new_zero1_size)] for _ in range(new_pp_size)] for _ in range(new_tp_size) + ] + convert_optimizer_ckpt( + old_meta=old_meta, + new_meta=new_meta, + old_pp_size=old_pp_size, + new_pp_size=new_pp_size, + old_tp_size=old_tp_size, + new_tp_size=new_tp_size, + old_zero1_size=old_zero1_size, + new_zero1_size=new_zero1_size, + old_meta_data=old_meta_data, + new_meta_data=new_meta_data, + new_tp_mode=new_tp_mode, + saved_folder=saved_folder, + new_states=new_states, + processed_ckpt_states=processed_ckpt_states, + ) + + if args.copy_file: + file_list = ["context.pt", "sampler.pt", "schedulder.pt"] + for file_name in file_list: + src = os.path.join(folder, file_name) + dst = os.path.join(saved_folder, file_name) + shutil.copy(src, dst) + print(f"Finish copy: {file_list}", flush=True) From f4688c5ebc82db3a7ca950012c4c2e37265968d0 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Thu, 20 Feb 2025 21:55:07 +0800 Subject: [PATCH 3/9] test version --- configs/7B_internlm2.py | 4 + internlm/checkpoint/checkpoint_manager.py | 6 + internlm/core/trainer_builder.py | 11 +- internlm/initialize/launch.py | 3 + internlm/model/modules/embedding.py | 12 +- internlm/model/modules/linear.py | 10 ++ internlm/model/modules/mha.py | 24 +-- .../solver/optimizer/hybrid_zero_optim.py | 21 ++- internlm/train/pipeline.py | 138 +++++++++++++++++- tools/convert_ckpt_parallel.py | 48 +++--- 10 files changed, 225 insertions(+), 52 deletions(-) diff --git a/configs/7B_internlm2.py b/configs/7B_internlm2.py index 97758bba4..da8cdb2f5 100644 --- 
a/configs/7B_internlm2.py +++ b/configs/7B_internlm2.py @@ -38,6 +38,10 @@ async_upload=True, # async ckpt upload. (only work for boto3 ckpt) async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. + # If enable_save_ckpt=True, metadata will be automatically generated. + # If generate_meta_data.enable=True, metadata can be independently generated in generate_meta_data.path during initialization. + # When only need to generate metadata, please set generate_meta_data to do it. + generate_meta_data=dict(enable=False, path='./') ) TRAIN_FOLDER = None diff --git a/internlm/checkpoint/checkpoint_manager.py b/internlm/checkpoint/checkpoint_manager.py index 2f7f5d4ed..93d618de0 100644 --- a/internlm/checkpoint/checkpoint_manager.py +++ b/internlm/checkpoint/checkpoint_manager.py @@ -229,6 +229,7 @@ def __init__( model_config=None, model_config_file=None, feishu_address=None, + meta_data=None, ) -> None: """ CheckpointManager is used to decide when to store ckpt. If it is an asynchronous @@ -247,6 +248,7 @@ def __init__( self.save_ckpt_folder = get_config_value(ckpt_config, "save_ckpt_folder", None) self.oss_snapshot_freq: int = get_config_value(ckpt_config, "oss_snapshot_freq", 50) self.stop_file_path = get_config_value(ckpt_config, "stop_file_path", None) + self.meta_data = meta_data if self.save_ckpt_folder: self.snapshot_ckpt_folder = get_config_value( ckpt_config, "snapshot_ckpt_folder", os.path.join(self.save_ckpt_folder, "snapshot") @@ -629,6 +631,10 @@ def save_checkpoint( save_optimizer_checkpoint(optim=optimizer, state_path=folder) timer("save-optimizer").stop() + if gpc.get_global_rank() == 0: + assert self.meta_data is not None + llm_save(os.path.join(folder, "metadata.pt"), saved_obj=self.meta_data) + if ( hasattr(train_state, "data_state_dict") and gpc.get_local_rank(ParallelMode.TENSOR) == 0 diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index d0ef284d4..784eab11a 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -20,6 +20,7 @@ from internlm.model.metrics import AccPerplex from internlm.monitor.monitor import send_alert_message from internlm.train.pipeline import ( + generate_meta_data, get_scheduler_hooks, initialize_llm_profile, initialize_optimizer, @@ -124,8 +125,13 @@ def __init__( # initialize optimizer optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator) + # generate ckpt metaData + meta_data = generate_meta_data(optimizer) + # initialize checkpoint manager and try resume training - self.ckpt_manager = self._initialize_checkpoint_manager(model, optimizer, lr_scheduler, train_dl, config_lines) + self.ckpt_manager = self._initialize_checkpoint_manager( + model, optimizer, lr_scheduler, train_dl, config_lines, meta_data + ) self.ckpt_manager.try_resume_training(train_state, self.current_time) # initialize customed llm writer @@ -178,7 +184,7 @@ def _initialize_criterion(self) -> FlashGPTLMLoss: ) def _initialize_checkpoint_manager( - self, model, optimizer, lr_scheduler, train_dl, config_lines + self, model, optimizer, lr_scheduler, train_dl, config_lines, meta_data ) -> CheckpointManager: return CheckpointManager( ckpt_config=gpc.config.ckpt, @@ -189,6 +195,7 @@ def _initialize_checkpoint_manager( model_config=gpc.config.model, model_config_file="".join(config_lines), feishu_address=gpc.config.monitor.alert.feishu_alert_address, + 
meta_data=meta_data, ) def _initialize_writer(self, train_state, config_lines) -> Writer: diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 1ac8ef31d..bdd383060 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -214,6 +214,9 @@ def args_sanity_check(): if "enable_save_ckpt" not in ckpt: ckpt._add_item("enable_save_ckpt", True) + if "generate_meta_data" not in ckpt: + ckpt._add_item("generate_meta_data", dict(enable=False, path=None)) + # Saving checkpoint args. if ckpt.enable_save_ckpt: assert "checkpoint_every" in ckpt, "If enable save checkpoint, must give checkpoint_every in config.data!" diff --git a/internlm/model/modules/embedding.py b/internlm/model/modules/embedding.py index 93fcd6b23..164686cc0 100644 --- a/internlm/model/modules/embedding.py +++ b/internlm/model/modules/embedding.py @@ -47,14 +47,17 @@ def __init__( self.vocab_parallel = vocab_parallel parallel_size = gpc.weight_parallel_size if is_using_isp() else gpc.tensor_parallel_size + rank = gpc.get_local_rank(ParallelMode.WEIGHT) if is_using_isp() else gpc.get_local_rank(ParallelMode.TENSOR) if vocab_parallel: assert num_embeddings % parallel_size == 0, f"{num_embeddings} is not divisible by {parallel_size}" self.num_embeddings_per_partition = num_embeddings // parallel_size self.embed_dim_per_partition = embedding_dim - self.vocab_start_index = gpc.get_local_rank(ParallelMode.TENSOR) * self.num_embeddings_per_partition + self.vocab_start_index = rank * self.num_embeddings_per_partition self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + self.offset = [self.vocab_start_index, 0] + self.tp_dim = 0 else: assert embedding_dim % parallel_size == 0, f"{embedding_dim} is not divisible by {parallel_size}" @@ -62,12 +65,17 @@ def __init__( self.embed_dim_per_partition = embedding_dim // parallel_size self.vocab_start_index = 0 self.vocab_end_index = self.num_embeddings_per_partition + self.offset = [0, self.embed_dim_per_partition * rank] + self.tp_dim = 1 self.weight = nn.Parameter( torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), dtype=dtype) ) - + self.complete_size = [num_embeddings, embedding_dim] setattr(self.weight, "is_embedding_param", True) + setattr(self.weight, "offset", self.offset) + setattr(self.weight, "complete_size", [num_embeddings, embedding_dim]) + setattr(self.weight, "tp_dim", self.tp_dim) def forward(self, input_: Tensor) -> Tensor: if self.vocab_parallel and not is_using_isp(): diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 856e6ba07..ac348e0e2 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -597,6 +597,7 @@ def __init__( world_size = gpc.get_world_size(parallel_mode) rank = gpc.get_local_rank(parallel_mode) + self.offset = None if split_mode != "none": split_features = out_features if split_mode == "column" else in_features @@ -611,11 +612,20 @@ def __init__( if split_mode == "column": super().__init__(in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype) + self.offset = [rank * local_multiple * multiple_of, 0] + self.tp_dim = 0 elif split_mode == "row": super().__init__(local_multiple * multiple_of, out_features, bias=bias, device=device, dtype=dtype) + self.offset = [0, rank * local_multiple * multiple_of] + self.tp_dim = 1 else: super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) + self.complete_size = [out_features, in_features] + 
setattr(self.weight, "offset", self.offset) + setattr(self.weight, "complete_size", [out_features, in_features]) + setattr(self.weight, "tp_dim", self.tp_dim) + def forward(self, input: torch.Tensor, batch_sizes: torch.Tensor = None) -> torch.Tensor: # pylint: disable=W0622 _class_name = self.__class__.__name__ assert self._communicator is not None, f"{_class_name} should register with a communicator first." diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py index 42418a212..5c8a60b3b 100644 --- a/internlm/model/modules/mha.py +++ b/internlm/model/modules/mha.py @@ -65,22 +65,6 @@ def _qkv_pre_load_convert(module: "GQA", state_dict, prefix: str, *args, **kwarg ) -def _qkv_save_convert(module: "GQA", state_dict, prefix: str, *args, **kwargs) -> Dict: # pylint: disable=W0613 - wq_name, wk_name, wv_name, fused_name = ( - f"{prefix}wq.weight", - f"{prefix}wk.weight", - f"{prefix}wv.weight", - f"{prefix}wqkv.weight", - ) - - if module.enable_qkv_fusion: - state_dict[wq_name], state_dict[wk_name], state_dict[wv_name] = split_fused_wqkv_weight( - state_dict.pop(fused_name), *args, **kwargs - ) - - return state_dict - - class MHA(nn.Module): """ Multi-head self-attention and cross-attention. @@ -462,15 +446,15 @@ def __init__( if enable_qkv_fusion: assert bias is False, "Fuesd wqkv only support bias is False." self.wqkv = new_linear("wqkv", embed_dim, q_dim + 2 * self.kv_dim, bias, **factory_kwargs) - self._register_load_state_dict_pre_hook( - partial(_qkv_pre_load_convert, q_dim=q_dim, kv_dim=self.kv_dim), with_module=True - ) - self._register_state_dict_hook(partial(_qkv_save_convert, q_dim=q_dim, kv_dim=self.kv_dim)) else: self.wq = new_linear("wq", embed_dim, q_dim, bias, **factory_kwargs) self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs) self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs) + self._register_load_state_dict_pre_hook( + partial(_qkv_pre_load_convert, q_dim=q_dim, kv_dim=self.kv_dim), with_module=True + ) + self.inner_attn = SelfAttention( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout, layer_idx=layer_idx ) diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 49f3fbcf0..478a8af25 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -149,7 +149,7 @@ def __init__( assert self._param_bcast_sync_handler is not None self._isp_communicator = isp_communicator - + self.meta_for_zero = None # iterate over the param group in the optimizer # partition these param groups for data parallel training # and add buffers to parameter store for future access @@ -165,6 +165,9 @@ def __init__( zero_mode = param_group["optimizer_mode"] self._zero_local_rank.append(gpc.get_local_rank(zero_mode)) self._zero_world_size.append(gpc.get_world_size(zero_mode)) + + if self.meta_for_zero is None: + self.meta_for_zero = [{} for _ in range(gpc.get_world_size(zero_mode))] # TODO _broadcast_parallel_mode is not only used in broadcast, maybe can change its name self._broadcast_parallel_mode.append(zero_mode) @@ -278,6 +281,22 @@ def _partition_param_list(self, group_id, param_group): else: rank_to_go = numel_per_rank.index(min(numel_per_rank)) params_per_rank[rank_to_go].append(param) + + if group_id not in self.meta_for_zero[rank_to_go]: + self.meta_for_zero[rank_to_go][group_id] = {} + + from internlm.train.pipeline import map_fqn_local_to_global + + global_fqn = map_fqn_local_to_global[param.fqn] 
if param.fqn in map_fqn_local_to_global else param.fqn + self.meta_for_zero[rank_to_go][group_id][global_fqn] = { + "tp_dim": getattr(param, "tp_dim", -1), + "pp": gpc.get_local_rank(ParallelMode.PIPELINE), + "zero1": rank_to_go, + "fqn": param.fqn, + "shape": param.shape, + "group_id": group_id, + } + self.params_per_rank_id_dict[-1][rank_to_go].append(global_id) numel_per_rank[rank_to_go] += param.numel() diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 5907a4e30..046c6f178 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -2,10 +2,12 @@ # -*- encoding: utf-8 -*- import math +import os import time from typing import Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union import torch +import torch.distributed as dist from torch import nn from torch.utils.data import DataLoader @@ -116,17 +118,32 @@ logger = get_logger(__file__) internlm_accelerator = get_accelerator() +# For universal checkpoint +# record offset and complete_size of param in each layer +map_layer_attr = {} +map_fqn_local_to_global = {} +map_fqn_global_to_local = {} + def set_param_unique_tracking_name(model): for chunk_id, chunk in enumerate(unwrap_naive_amp(model)): # Important: only works for llama-class models childrens = chunk.named_children() - for _, children in childrens: + for children_name, children in childrens: if isinstance(children, nn.ModuleList): for idx, block in enumerate(children): for name, child in block.named_modules(): + if name == "": + continue + + full_name = f"{chunk_id}.{idx}.{name}" + name_parts = f"{full_name}.weight".split(".", 2) + # global_id for pipeline parallel case + global_id = model.first_layer + idx + local_fqn = f"{children_name}." + ".".join(name_parts[1:]) + global_fqn = f"{children_name}.{global_id}." 
+ ".".join(name_parts[2:]) + if isinstance(child, (ParallelLinearWithCommExt)): - full_name = f"{chunk_id}.{idx}.{name}" setattr( child.weight, "tracking_name", @@ -138,19 +155,132 @@ def set_param_unique_tracking_name(model): "tracking_name", f"{full_name}.bias", ) + + setattr( + child.weight, + "fqn", + f"{local_fqn}", + ) + if child.bias is not None: + setattr( + child.bias, + "fqn", + f"{local_fqn}", + ) + + assert hasattr(child, "offset"), f"{child}" + map_fqn_local_to_global[local_fqn] = global_fqn + map_fqn_global_to_local[global_fqn] = local_fqn + + assert global_fqn not in map_layer_attr, f"{map_layer_attr} exists" + map_layer_attr[global_fqn] = { + "offset": getattr(child, "offset", [0] * len(child.weight.size())), + "complete_size": getattr(child, "complete_size", list(child.weight.size())), + } + + elif isinstance(child, (RMSNorm)): + map_fqn_local_to_global[local_fqn] = global_fqn + map_fqn_global_to_local[global_fqn] = local_fqn + setattr( + child.weight, + "fqn", + f"{local_fqn}", + ) + map_layer_attr[global_fqn] = { + "offset": getattr(child, "offset", [0] * len(child.weight.size())), + "complete_size": getattr(child, "complete_size", list(child.weight.size())), + } + else: + full_name = f"{chunk_id}.{children_name}" + local_fqn = f"{children_name}.weight" + assert getattr(children, "bias", None) is None if isinstance(children, Embedding1D): setattr( children.weight, "tracking_name", - f"{chunk_id}_embedding.weight", + f"{chunk_id}_embeddings.weight", ) + assert local_fqn not in map_layer_attr, f"{map_layer_attr} exists" else: setattr( children.weight, "tracking_name", - f"{chunk_id}_head.weight", + f"{full_name}.weight", ) + assert local_fqn not in map_layer_attr, f"{map_layer_attr} exists" + + setattr( + children.weight, + "fqn", + f"{local_fqn}", + ) + if getattr(children, "bias", None) is not None: + if children.bias is not None: + setattr( + children.bias, + "fqn", + f"{local_fqn}", + ) + + map_layer_attr[local_fqn] = { + "offset": getattr(children, "offset", [0] * len(children.weight.size())), + "complete_size": getattr(children, "complete_size", list(children.weight.size())), + } + + +def generate_meta_data(optimizer): + if not (gpc.config.ckpt.enable_save_ckpt or gpc.config.ckpt.generate_meta_data.enable): + return + + if gpc.get_world_size(ParallelMode.PIPELINE) > 1: + assert optimizer.meta_for_zero is not None + dst = gpc.get_ranks_in_group(ParallelMode.PIPELINE)[0] + if gpc.get_global_rank() == dst: + output = [None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))] + else: + output = None + + dist.gather_object(optimizer.meta_for_zero, output, dst=dst, group=gpc.get_group(ParallelMode.PIPELINE)) + pp_gather_output = output + + else: + pp_gather_output = [optimizer.meta_for_zero] + + tp_parallel = ParallelMode.WEIGHT if is_using_isp() else ParallelMode.TENSOR + if gpc.get_world_size(tp_parallel) > 1: + dst = gpc.get_ranks_in_group(tp_parallel)[0] + if gpc.get_global_rank() == dst: + output = [None for _ in range(gpc.get_world_size(tp_parallel))] + else: + output = None + + dist.gather_object(pp_gather_output, output, dst=dst, group=gpc.get_group(tp_parallel)) + final_output = output + else: + final_output = [pp_gather_output] + + if gpc.get_global_rank() == 0: + assert len(final_output) == gpc.get_world_size(tp_parallel) + assert len(final_output[0]) == gpc.get_world_size(ParallelMode.PIPELINE) + assert len(final_output[0][0]) == gpc.get_world_size(ParallelMode.ZERO1) + tp_mode = "wp_size" if is_using_isp() else "tp_size" + final_meta = { + 
"parallel_setting": { + tp_mode: gpc.get_world_size(tp_parallel), + "pp_size": gpc.get_world_size(ParallelMode.PIPELINE), + "zero1_size": gpc.get_world_size(ParallelMode.ZERO1), + }, + "metaData": final_output, + } + + if gpc.config.ckpt.generate_meta_data.enable: + save_path = os.path.join(gpc.config.ckpt.generate_meta_data.path, "metadata.pt") + torch.save(final_meta, save_path) + logger.info(f"Successfully generate metadata.pt in {gpc.config.ckpt.generate_meta_data.path}") + + return final_meta + return None def set_fp32_attr_for_model(model: Union[nn.Module, nn.ModuleList]): diff --git a/tools/convert_ckpt_parallel.py b/tools/convert_ckpt_parallel.py index 72622eaf6..5fedff8d4 100644 --- a/tools/convert_ckpt_parallel.py +++ b/tools/convert_ckpt_parallel.py @@ -17,7 +17,8 @@ def parse_args(): args.add_argument("target_ckpt_path", type=str, default=None) args.add_argument("--origin_meta_path", type=str, default=None) args.add_argument("--target_meta_path", type=str, default=None) - args.add_argument("--copy_file", type=bool, default=True) + args.add_argument("--copy_file", type=bool, default=True, help="enable/disable copy other file.") + args.add_argument("--convert_optimizer", type=bool, default=True, help="enable/disable optimizer converting.") return args.parse_args() @@ -620,28 +621,29 @@ def convert_optimizer_ckpt( new_states=new_states, ) - processed_ckpt_states = preprocess_optimizer_state( - old_tp_size, old_pp_size, old_zero1_size, old_meta, folder, old_tp_mode - ) - new_states = [ - [[defaultdict(dict) for _ in range(new_zero1_size)] for _ in range(new_pp_size)] for _ in range(new_tp_size) - ] - convert_optimizer_ckpt( - old_meta=old_meta, - new_meta=new_meta, - old_pp_size=old_pp_size, - new_pp_size=new_pp_size, - old_tp_size=old_tp_size, - new_tp_size=new_tp_size, - old_zero1_size=old_zero1_size, - new_zero1_size=new_zero1_size, - old_meta_data=old_meta_data, - new_meta_data=new_meta_data, - new_tp_mode=new_tp_mode, - saved_folder=saved_folder, - new_states=new_states, - processed_ckpt_states=processed_ckpt_states, - ) + if args.convert_optimizer: + processed_ckpt_states = preprocess_optimizer_state( + old_tp_size, old_pp_size, old_zero1_size, old_meta, folder, old_tp_mode + ) + new_states = [ + [[defaultdict(dict) for _ in range(new_zero1_size)] for _ in range(new_pp_size)] for _ in range(new_tp_size) + ] + convert_optimizer_ckpt( + old_meta=old_meta, + new_meta=new_meta, + old_pp_size=old_pp_size, + new_pp_size=new_pp_size, + old_tp_size=old_tp_size, + new_tp_size=new_tp_size, + old_zero1_size=old_zero1_size, + new_zero1_size=new_zero1_size, + old_meta_data=old_meta_data, + new_meta_data=new_meta_data, + new_tp_mode=new_tp_mode, + saved_folder=saved_folder, + new_states=new_states, + processed_ckpt_states=processed_ckpt_states, + ) if args.copy_file: file_list = ["context.pt", "sampler.pt", "schedulder.pt"] From 771dd1ca0c5b7743e9d3e4eef4da641a1ffc81e2 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Fri, 21 Feb 2025 10:41:41 +0800 Subject: [PATCH 4/9] config --- configs/7B_internlm2.py | 4 ---- configs/demo.py | 4 ++++ 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/configs/7B_internlm2.py b/configs/7B_internlm2.py index da8cdb2f5..97758bba4 100644 --- a/configs/7B_internlm2.py +++ b/configs/7B_internlm2.py @@ -38,10 +38,6 @@ async_upload=True, # async ckpt upload. (only work for boto3 ckpt) async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. 
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. - # If enable_save_ckpt=True, metadata will be automatically generated. - # If generate_meta_data.enable=True, metadata can be independently generated in generate_meta_data.path during initialization. - # When only need to generate metadata, please set generate_meta_data to do it. - generate_meta_data=dict(enable=False, path='./') ) TRAIN_FOLDER = None diff --git a/configs/demo.py b/configs/demo.py index e66f007f4..40c33f2cf 100644 --- a/configs/demo.py +++ b/configs/demo.py @@ -49,6 +49,10 @@ async_upload=True, # async ckpt upload. (only work for boto3 ckpt) async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. + # If enable_save_ckpt=True, metadata will be automatically generated. + # If generate_meta_data.enable=True, metadata can be independently generated in generate_meta_data.path during initialization. + # If you only need to generate metadata, enable generate_meta_data to do so. + generate_meta_data=dict(enable=False, path='./') ) TRAIN_FOLDER = None # "/path/to/dataset" From 76a3b2b5a5cb1e422b225d50bbd6d5228e958827 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Fri, 21 Feb 2025 11:18:53 +0800 Subject: [PATCH 5/9] fix ci --- internlm/initialize/launch.py | 5 +++ .../solver/optimizer/hybrid_zero_optim.py | 33 +++++++++---------- internlm/train/pipeline.py | 2 +- tests/test_training/test_loss.py | 1 + tools/convert_ckpt_parallel.py | 10 +++--- 5 files changed, 27 insertions(+), 24 deletions(-) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index bdd383060..549e37116 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -217,6 +217,11 @@ def args_sanity_check(): if "generate_meta_data" not in ckpt: ckpt._add_item("generate_meta_data", dict(enable=False, path=None)) + if ckpt.enable_save_ckpt or ckpt.generate_meta_data.enable: + ckpt.need_metadata = True + else: + ckpt.need_metadata = False + # Saving checkpoint args. if ckpt.enable_save_ckpt: assert "checkpoint_every" in ckpt, "If enable save checkpoint, must give checkpoint_every in config.data!"
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 478a8af25..d25eaafba 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -166,7 +166,7 @@ def __init__( self._zero_local_rank.append(gpc.get_local_rank(zero_mode)) self._zero_world_size.append(gpc.get_world_size(zero_mode)) - if self.meta_for_zero is None: + if gpc.config.ckpt.need_metadata and self.meta_for_zero is None: self.meta_for_zero = [{} for _ in range(gpc.get_world_size(zero_mode))] # TODO _broadcast_parallel_mode is not only used in broadcast, maybe can change its name self._broadcast_parallel_mode.append(zero_mode) @@ -281,25 +281,24 @@ def _partition_param_list(self, group_id, param_group): else: rank_to_go = numel_per_rank.index(min(numel_per_rank)) params_per_rank[rank_to_go].append(param) - - if group_id not in self.meta_for_zero[rank_to_go]: - self.meta_for_zero[rank_to_go][group_id] = {} - - from internlm.train.pipeline import map_fqn_local_to_global - - global_fqn = map_fqn_local_to_global[param.fqn] if param.fqn in map_fqn_local_to_global else param.fqn - self.meta_for_zero[rank_to_go][group_id][global_fqn] = { - "tp_dim": getattr(param, "tp_dim", -1), - "pp": gpc.get_local_rank(ParallelMode.PIPELINE), - "zero1": rank_to_go, - "fqn": param.fqn, - "shape": param.shape, - "group_id": group_id, - } - self.params_per_rank_id_dict[-1][rank_to_go].append(global_id) numel_per_rank[rank_to_go] += param.numel() + if gpc.config.ckpt.need_metadata: + if group_id not in self.meta_for_zero[rank_to_go]: + self.meta_for_zero[rank_to_go][group_id] = {} + + from internlm.train.pipeline import map_fqn_local_to_global + global_fqn = map_fqn_local_to_global[param.fqn] if param.fqn in map_fqn_local_to_global else param.fqn + self.meta_for_zero[rank_to_go][group_id][global_fqn] = { + "tp_dim": getattr(param, "tp_dim", -1), + "pp": gpc.get_local_rank(ParallelMode.PIPELINE), + "zero1": rank_to_go, + "fqn": param.fqn, + "shape": param.shape, + "group_id": group_id, + } + # check whether any rank is not assigned to parameters. 
for rank, params in enumerate(params_per_rank): if len(params) == 0: diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 046c6f178..378eb9766 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -230,7 +230,7 @@ def set_param_unique_tracking_name(model): def generate_meta_data(optimizer): - if not (gpc.config.ckpt.enable_save_ckpt or gpc.config.ckpt.generate_meta_data.enable): + if not gpc.config.ckpt.need_metadata: return if gpc.get_world_size(ParallelMode.PIPELINE) > 1: diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index 4094c5822..f9936ed94 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -125,6 +125,7 @@ def train( initialize_distributed_env(config=config, launcher=launcher) assert hasattr(gpc, "config") and gpc.config is not None + gpc.config.ckpt.need_metadata = False # check parallel config assert ( gpc.get_world_size(ParallelMode.DATA) == dp_size diff --git a/tools/convert_ckpt_parallel.py b/tools/convert_ckpt_parallel.py index 5fedff8d4..117d4673b 100644 --- a/tools/convert_ckpt_parallel.py +++ b/tools/convert_ckpt_parallel.py @@ -70,23 +70,21 @@ def flatten(input_): def unflatten_tensor(flat_tensor, states): - """ - 根据目标形状,将扁平化的张量拆分为多个子张量。 - """ start = 0 unflat_tensors = [] for _, state in states.items(): shape = state["shape"] - size = torch.prod(torch.tensor(shape)) # 计算每个子张量的大小 - tensor = flat_tensor[start : start + size].reshape(*shape) # 切分并恢复形状 + size = torch.prod(torch.tensor(shape)) + tensor = flat_tensor[start : start + size].reshape(*shape) unflat_tensors.append(tensor) - start += size # 更新起始位置 + start += size return unflat_tensors def preprocess_optimizer_state(old_tp_size, old_pp_size, old_zero1_size, old_meta, folder, old_tp_mode): + # preprocess optimizer_state to unflatten format processed_ckpt_states = [ [[{} for _ in range(old_zero1_size)] for _ in range(old_pp_size)] for _ in range(old_tp_size) ] From 88cf02dd87842f8ad2055529d9917d4a7be7802e Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Fri, 21 Feb 2025 13:24:00 +0800 Subject: [PATCH 6/9] fix_ci --- .../solver/optimizer/hybrid_zero_optim.py | 1 + tests/test_core/test_pipeline.py | 1 + tests/test_training/train_CI.py | 22 ++++++++++++++----- tests/test_utils/common_fixture.py | 1 + 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index d25eaafba..02d57df37 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -289,6 +289,7 @@ def _partition_param_list(self, group_id, param_group): self.meta_for_zero[rank_to_go][group_id] = {} from internlm.train.pipeline import map_fqn_local_to_global + global_fqn = map_fqn_local_to_global[param.fqn] if param.fqn in map_fqn_local_to_global else param.fqn self.meta_for_zero[rank_to_go][group_id][global_fqn] = { "tp_dim": getattr(param, "tp_dim", -1), diff --git a/tests/test_core/test_pipeline.py b/tests/test_core/test_pipeline.py index 180fe4b71..5d6a6efd2 100644 --- a/tests/test_core/test_pipeline.py +++ b/tests/test_core/test_pipeline.py @@ -82,6 +82,7 @@ eta_min=1e-5, last_epoch=-1, ), + ckpt=dict(need_metadata=False), ) ) diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py index b33cf4c38..325538a04 100644 --- a/tests/test_training/train_CI.py +++ b/tests/test_training/train_CI.py @@ -20,7 +20,7 @@ from internlm.checkpoint import CheckpointManager # noqa: E402 
from internlm.core.context import ParallelMode # noqa: E402 from internlm.core.context import global_context as gpc # noqa: E402 -from internlm.core.trainer import TrainState, Trainer # noqa: E402 +from internlm.core.trainer import Trainer, TrainState # noqa: E402 from internlm.data import ( # noqa: E402 build_train_loader_with_data_type, build_valid_loader_with_data_type, @@ -59,6 +59,18 @@ logger = get_logger(__file__) +def fuse_wqkv(key, state_dict) -> None: # pylint: disable=W0613 + prefix = key.rstrip("wqkv.weight") + wq_name, wk_name, wv_name = ( + f"{prefix}wq.weight", + f"{prefix}wk.weight", + f"{prefix}wv.weight", + ) + + wq, wk, wv = state_dict.pop(wq_name), state_dict.pop(wk_name), state_dict.pop(wv_name) + state_dict[key] = torch.cat([wq, wk, wv], dim=0) + + def check_model_weights(model, ckpt_path, total_equal=False): model1_dict = torch.load(ckpt_path, map_location="cuda") model2_dict = model.state_dict() @@ -66,11 +78,11 @@ copy_of_ordered_dict = model2_dict.copy() for key in copy_of_ordered_dict.keys(): - if "wqkv" in key: - model2_dict[key.replace("wqkv", "Wqkv")] = model2_dict.pop(key) - key = key.replace("wqkv", "Wqkv") if key not in model1_dict: - assert False, f"Error: The key {key} for current model dose not exist in standard ckpt!" + if "wqkv" in key: + fuse_wqkv(key, model1_dict) + else: + assert False, f"Error: The key {key} for current model does not exist in standard ckpt!" for key in model1_dict.keys(): if key in model2_dict: diff --git a/tests/test_utils/common_fixture.py index f4b34ddee..c336524e1 100644 --- a/tests/test_utils/common_fixture.py +++ b/tests/test_utils/common_fixture.py @@ -85,6 +85,7 @@ tensorboard_folder="", alert_address=None, monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)), + ckpt=dict(need_metadata=False), ) ) From 24c6901b40ba63f7c3e09745c78216ef40ef8c74 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Fri, 21 Feb 2025 14:41:35 +0800 Subject: [PATCH 7/9] fix ci --- ci_scripts/train/load_ckpt.sh | 2 +- internlm/checkpoint/checkpoint_manager.py | 2 +- tests/test_training/train_CI.py | 8 ++++++-- tools/convert_ckpt_parallel.py | 19 ++++++++++++++++++- 4 files changed, 26 insertions(+), 5 deletions(-) diff --git a/ci_scripts/train/load_ckpt.sh b/ci_scripts/train/load_ckpt.sh index 287adbd89..3b447bcf1 100644 --- a/ci_scripts/train/load_ckpt.sh +++ b/ci_scripts/train/load_ckpt.sh @@ -8,7 +8,7 @@ source ./ci_scripts/common/variables.sh readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts" readonly CKPTS40_PATH="$GITHUB_WORKSPACE/llm_ckpts/40" readonly CKPTS40_OUTPUT="${CKPTS40_PATH}/*.pt" -expected_num=22 +expected_num=23 exit_code=0 source ./ci_scripts/common/basic_func.sh diff --git a/internlm/checkpoint/checkpoint_manager.py b/internlm/checkpoint/checkpoint_manager.py index 93d618de0..4c99f63ae 100644 --- a/internlm/checkpoint/checkpoint_manager.py +++ b/internlm/checkpoint/checkpoint_manager.py @@ -631,7 +631,7 @@ def save_checkpoint( save_optimizer_checkpoint(optim=optimizer, state_path=folder) timer("save-optimizer").stop() - if gpc.get_global_rank() == 0: + if gpc.get_global_rank() == 0 and gpc.config.ckpt.need_metadata: assert self.meta_data is not None llm_save(os.path.join(folder, "metadata.pt"), saved_obj=self.meta_data) diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py index 325538a04..d1728bfdc 100644 --- a/tests/test_training/train_CI.py +++ 
b/tests/test_training/train_CI.py @@ -60,7 +60,7 @@ def fuse_wqkv(key, state_dict) -> None: # pylint: disable=W0613 - prefix = key.rstrip("wqkv.weight") + prefix = key.rstrip("Wqkv.weight") wq_name, wk_name, wv_name = ( f"{prefix}wq.weight", f"{prefix}wk.weight", @@ -78,8 +78,12 @@ def check_model_weights(model, ckpt_path, total_equal=False): copy_of_ordered_dict = model2_dict.copy() for key in copy_of_ordered_dict.keys(): + if "wqkv" in key: + model2_dict[key.replace("wqkv", "Wqkv")] = model2_dict.pop(key) + key = key.replace("wqkv", "Wqkv") + if key not in model1_dict: - if "wqkv" in key: + if "Wqkv" in key: fuse_wqkv(key, model1_dict) else: assert False, f"Error: The key {key} for current model does not exist in standard ckpt!" diff --git a/tools/convert_ckpt_parallel.py b/tools/convert_ckpt_parallel.py index 117d4673b..78a30cd3a 100644 --- a/tools/convert_ckpt_parallel.py +++ b/tools/convert_ckpt_parallel.py @@ -1,3 +1,16 @@ +""" +Usage: + python tools/convert_ckpt_parallel.py \ + <origin_ckpt_path> <target_ckpt_path> \ + (optional) [--origin_meta_path <origin_meta_path>] [--target_meta_path <target_meta_path>] \ + (optional) [--copy_file <True/False>] [--convert_optimizer <True/False>] + + When meta_path is not specified, it will automatically search and load meta in the ckpt path. + Default to convert optimizer state and copy files. +Example: + srun -p llm_s python tools/convert_ckpt_parallel.py \ + /llm_ckpt/100 /target_ckpt/converted +""" import argparse import os import shutil @@ -530,7 +543,6 @@ def convert_optimizer_ckpt( base_state["base_optim_states"]["state"][group_id] = state base_state["flat_fp32_weights"][group_id] = flat_fp32_weights - # print(f"optimizer tp{new_tp_rank}_pp{new_pp_rank}_zo{new_zero1_rank}: {base_state}") torch.save(base_state, os.path.join(saved_folder, file_name)) print("Finish optimizer convert", flush=True) @@ -559,6 +571,7 @@ def convert_optimizer_ckpt( new_meta_path ), "new meta file does not exist, please generate it before converting checkpoint." + # read and process metaData for original ckpt old_meta = torch.load(old_meta_path, map_location="cpu") old_pp_size = old_meta["parallel_setting"]["pp_size"] old_zero1_size = old_meta["parallel_setting"]["zero1_size"] @@ -570,16 +583,19 @@ def convert_optimizer_ckpt( assert False, "tp or wp should be in parallel setting." old_tp_size = old_meta["parallel_setting"][f"{old_tp_mode}_size"] + # To facilitate key query, summarize meta_data. old_meta_data = {} for pp_rank in range(old_pp_size): for zero_rank in range(old_zero1_size): for states in old_meta["metaData"][0][pp_rank][zero_rank].values(): old_meta_data.update(states) + # map local fqn to global fqn old_map_local_to_global = [{} for _ in range(old_pp_size)] for global_fqn, states in old_meta_data.items(): old_map_local_to_global[states["pp"]][states["fqn"]] = global_fqn + # read and process metaData for target ckpt new_meta = torch.load(new_meta_path, map_location="cpu") new_pp_size = new_meta["parallel_setting"]["pp_size"] new_zero1_size = new_meta["parallel_setting"]["zero1_size"] @@ -597,6 +613,7 @@ def convert_optimizer_ckpt( ), "Error: old meta and new meta have different group_id lists." group_id_list = list(new_meta["metaData"][0][0][0].keys()) + # To facilitate key query, summarize meta_data. 
new_meta_data = {} for pp_rank in range(new_pp_size): for zero_rank in range(new_zero1_size): From fdedd37c532a9e87ce029374c7e1e0077c7ecb9b Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Fri, 21 Feb 2025 14:57:55 +0800 Subject: [PATCH 8/9] fix ci --- tests/test_training/train_CI.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py index d1728bfdc..c6d87fda0 100644 --- a/tests/test_training/train_CI.py +++ b/tests/test_training/train_CI.py @@ -60,11 +60,11 @@ def fuse_wqkv(key, state_dict) -> None: # pylint: disable=W0613 - prefix = key.rstrip("Wqkv.weight") + prefix = key.rstrip(".Wqkv.weight") wq_name, wk_name, wv_name = ( - f"{prefix}wq.weight", - f"{prefix}wk.weight", - f"{prefix}wv.weight", + f"{prefix}.wq.weight", + f"{prefix}.wk.weight", + f"{prefix}.wv.weight", ) wq, wk, wv = state_dict.pop(wq_name), state_dict.pop(wk_name), state_dict.pop(wv_name) From c3af20dbf4363c66d8069c0d20fcc526be28ac77 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Fri, 21 Feb 2025 15:19:19 +0800 Subject: [PATCH 9/9] fix ci --- ci_scripts/train/slurm_train.sh | 2 +- ci_scripts/train/torchrun.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ci_scripts/train/slurm_train.sh b/ci_scripts/train/slurm_train.sh index b3117a165..ca5e840b9 100644 --- a/ci_scripts/train/slurm_train.sh +++ b/ci_scripts/train/slurm_train.sh @@ -8,7 +8,7 @@ source ./ci_scripts/common/variables.sh readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts" readonly CKPTS20_PATH="$GITHUB_WORKSPACE/llm_ckpts/20" readonly CKPTS20_OUTPUT="${CKPTS20_PATH}/*.pt" -expected_num=22 +expected_num=23 exit_code=0 source ./ci_scripts/common/basic_func.sh diff --git a/ci_scripts/train/torchrun.sh b/ci_scripts/train/torchrun.sh index 31681d02c..27c815725 100644 --- a/ci_scripts/train/torchrun.sh +++ b/ci_scripts/train/torchrun.sh @@ -8,7 +8,7 @@ source ./ci_scripts/common/variables.sh readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts" readonly CKPTS20_PATH="$GITHUB_WORKSPACE/llm_ckpts/20" readonly CKPTS_OUTPUT="${CKPTS20_PATH}/*.pt" -expected_num=22 +expected_num=23 exit_code=0 source ./ci_scripts/common/basic_func.sh
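
The snippet below is not part of the patch series above. It is a minimal, self-contained sketch of the metadata.pt layout implied by generate_meta_data() in internlm/train/pipeline.py and read back by tools/convert_ckpt_parallel.py; the parameter name, shape, and parallel sizes are made-up placeholders, and the real file is written by CheckpointManager.save_checkpoint (or via ckpt.generate_meta_data) with one nested list level per tp/wp, pp, and zero1 rank.

import torch

# Illustrative metadata.pt structure (placeholder values only).
# metaData is indexed as [tp_or_wp_rank][pp_rank][zero1_rank][group_id][global_fqn].
example_meta = {
    "parallel_setting": {"tp_size": 1, "pp_size": 1, "zero1_size": 1},
    "metaData": [  # one entry per tp/wp rank
        [  # one entry per pp rank
            [  # one entry per zero1 rank
                {
                    0: {  # optimizer param group id
                        "layers.0.attention.wqkv.weight": {  # hypothetical global fqn
                            "tp_dim": 0,          # dimension split by tp/wp, -1 if not split
                            "pp": 0,              # pipeline rank owning the param
                            "zero1": 0,           # zero1 rank owning the flat fp32 slice
                            "fqn": "layers.0.attention.wqkv.weight",  # local fqn on that pp rank
                            "shape": torch.Size([6144, 4096]),        # placeholder local shape
                            "group_id": 0,
                        }
                    }
                }
            ]
        ]
    ],
}

# The converter flattens this nesting the same way the __main__ block above does:
meta_data = {}
for pp_rank in range(example_meta["parallel_setting"]["pp_size"]):
    for zero_rank in range(example_meta["parallel_setting"]["zero1_size"]):
        for states in example_meta["metaData"][0][pp_rank][zero_rank].values():
            meta_data.update(states)
print(list(meta_data.keys()))  # ['layers.0.attention.wqkv.weight']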