#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import gc
import json
import logging
import os
import shutil
import socket
import traceback
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm

from internlm.accelerator import get_accelerator
from internlm.apis.inference import SequenceGenerator
from internlm.core.context import global_context as gpc
from internlm.data import build_generation_loader_with_data_type
from internlm.initialize import initialize_distributed_env
from internlm.monitor import initialize_monitor_manager
from internlm.monitor.monitor import monitor_manager as mm
from internlm.train import initialize_model, initialize_parallel_communicator
from internlm.utils.common import (
    enable_pytorch_expandable_segments,
    launch_time,
    parse_args,
)
from internlm.utils.gputest import empty_cache_and_diag
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import get_parallel_log_file_name
from internlm.utils.storage_manager import init_storage_manager
from tools.load_internlm2_model import get_model_device, merge_pp_within_tp

# global llm logger
logger = logging.getLogger(__file__)
internlm_accelerator = get_accelerator()


def get_latest_subdirectory(folder_path):
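    """Return the most recently created subdirectory of ``folder_path``.

    An optional ``<prefix>:`` in front of the path (e.g. a storage backend name) is
    preserved in the returned value. Returns ``None`` if the folder has no
    subdirectories.
    """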
    if ":" in folder_path:
        prefix, folder_path = folder_path.split(":", 1)
        prefix += ":"
    else:
        prefix = ""
    subdirectories = [name for name in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, name))]
    subdirectories_sorted = sorted(
        subdirectories, key=lambda x: os.path.getctime(os.path.join(folder_path, x)), reverse=True
    )
    if subdirectories_sorted:
        return prefix + os.path.join(folder_path, subdirectories_sorted[0])
    else:
        return None


def main():
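    """Run text generation with a merged InternLM checkpoint.

    The behaviour is driven by the `generation` section of the config; only
    `output_folder` is mandatory, all other keys fall back to the defaults below.
    An illustrative (not exhaustive) example:

        generation = dict(
            output_folder="generation_output",  # required
            # ckpt_folder: optional, defaults to the newest subdirectory of ckpt.save_ckpt_folder
            batch_size=1,
            max_length=100,
            do_sample=True,
            temperature=1.0,
        )

    The generated token ids are written to `<output_folder>/<dataset>.bin` together
    with a `<dataset>.bin.meta` index file.
    """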
    enable_pytorch_expandable_segments()

    generation_config = gpc.config["generation"]

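    # Wrap the raw `generation` dict in a small namespace-like class so its fields can
    # be read as attributes below, with defaults filled in for any missing keys.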
    generation_config = type(
        "",
        (object,),
        {
            "output_folder": Path(generation_config["output_folder"]),
            "ckpt_folder": generation_config["ckpt_folder"]
            if "ckpt_folder" in generation_config
            else get_latest_subdirectory(gpc.config.ckpt.save_ckpt_folder),
            "data_folder": generation_config["data_folder"] if "data_folder" in generation_config else None,
            "batch_size": generation_config.get("batch_size", None),
            "eos_id": generation_config.get("eos_id", 2),
            "bos_id": generation_config.get("bos_id", 1),
            "pad_id": generation_config.get("pad_id", 1),
            "additional_eos_token_list": generation_config.get("additional_eos_token_list", None),
            "max_length": generation_config.get("max_length", 100),
            "do_sample": generation_config.get("do_sample", True),
            "temperature": generation_config.get("temperature", 1.0),
            "num_beams": generation_config.get("num_beams", 1),
            "top_k": generation_config.get("top_k", 50),
            "top_p": generation_config.get("top_p", 1.0),
            "repetition_penalty": generation_config.get("repetition_penalty", 1.0),
            "length_penalty": generation_config.get("length_penalty", 1.0),
        },
    )

    generation_config.output_folder.mkdir(exist_ok=True, parents=True)

    # get and broadcast current time
    current_time = launch_time()
    objs = [current_time]
    torch.distributed.broadcast_object_list(objs, src=0)
    current_time = objs[0].replace(":", ".")
    global logger
    logger = get_logger(
        __file__, launch_time=current_time, job_name=gpc.config.JOB_NAME, file_name=get_parallel_log_file_name()
    )

    try:
        init_storage_manager(False, None, None)
    except AssertionError:
        # the storage manager is optional here; ignore assertion failures from init
        pass

    # initialize model
    model = initialize_model()
    _ = initialize_parallel_communicator(model)
    model = model.model  # unwrap to the underlying torch module

    state_dict = merge_pp_within_tp(generation_config.ckpt_folder, del_model_prefix=True)
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if len(missing_keys) != 0:
        logger.warning(f"Missing keys when loading checkpoint: {missing_keys}")
    if len(unexpected_keys) != 0:
        logger.warning(f"Unexpected keys when loading checkpoint: {unexpected_keys}")

    # Resolve the configured parameter dtype. "torch.tf32" is not a real torch dtype:
    # it means keep the weights in float32 and allow TF32 matmul / cuDNN kernels.
    param_dtype = gpc.config.model.dtype
    if isinstance(param_dtype, str):
        if param_dtype == "torch.tf32":
            param_dtype = torch.float32
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cuda.matmul.allow_tf32 = True
        else:
            param_dtype = eval(param_dtype)  # pylint: disable=W0123

    model.to(param_dtype)
    model.eval()
    torch.distributed.barrier()

    data_cfg = gpc.config.data
    if generation_config.data_folder:
        data_cfg.valid_folder = generation_config.data_folder
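    # build one generation dataloader per dataset (returned as a mapping: dataset name -> dataloader)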
    gene_dls = build_generation_loader_with_data_type(data_cfg, generation_config)

    sequence_generator = SequenceGenerator(
        decoder=model,
        eos_token_id=generation_config.eos_id,
        pad_token_id=generation_config.pad_id,
        bos_token_id=generation_config.bos_id,
        additional_eos_token_list=generation_config.additional_eos_token_list,
    )
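
    # For each generation dataset, run batched generation and write the results to
    # <output_folder>/<dataset>.bin (one JSON record of token ids per sample), plus a
    # <dataset>.bin.meta numpy file of (byte offset, record length) pairs for indexing.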

    ds_count = 0
    gc.disable()
    with torch.inference_mode():
        for ds_name, gene_dl in gene_dls.items():
            if len(gene_dl) == 0:
                logger.info(f"Validation dataset: {ds_name} is empty")
                continue
            timer(f"dataset {ds_count}").start()

            # pylint: disable=forgotten-debug-statement
            all_output_str = []
            # pylint: disable=unused-variable
            for val_idx, (labels, input_ids) in tqdm(
                enumerate(gene_dl),
                desc="generate.",
                total=len(gene_dl),
                position=1,
                leave=False,
            ):
                empty_cache_and_diag(val_idx, interval=gpc.config.data.empty_cache_and_diag_interval)
                input_ids = torch.LongTensor(input_ids)
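                # prompts that already reach max_length are not fed to the generator;
                # they are truncated to max_length and written out unchanged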
                if input_ids.size(1) >= generation_config.max_length:
                    logger.warning(
                        f"Not generating for the {val_idx}'th batch, because its sequence "
                        f"length {input_ids.size(1)} already reaches the max generation "
                        f"length {generation_config.max_length}"
                    )
                    output_ids = input_ids[:, : generation_config.max_length, ...]
                else:
                    # clamp token ids into the valid vocabulary range before moving to the model device
                    input_ids = input_ids.clamp(min=0, max=gpc.config.model.vocab_size - 1)
                    input_ids = input_ids.to(get_model_device(model))
                    output_ids = sequence_generator.generate(
                        tokens=input_ids,
                        max_length=generation_config.max_length,
                        do_sample=generation_config.do_sample,
                        temperature=generation_config.temperature,
                        num_beams=generation_config.num_beams,
                        top_k=generation_config.top_k,
                        top_p=generation_config.top_p,
                        repetition_penalty=generation_config.repetition_penalty,
                        length_penalty=generation_config.length_penalty,
                    )
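                # drop everything before the first non-pad token, then store each sample
                # as a JSON-encoded record of token ids together with its byte length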
                for output in output_ids:
                    not_pad_indices = torch.nonzero(output != generation_config.pad_id)
                    if not_pad_indices.nelement() != 0:
                        sequence = output[not_pad_indices[0] :]
                    else:
                        sequence = output
                    sequence = sequence.tolist()
                    line = json.dumps({"tokens": sequence}).encode("utf-8")
                    all_output_str.append((line, len(line)))

            bin_meta, last_position = [], 0
            with open(generation_config.output_folder.joinpath(f"{ds_name}.bin"), "wb") as file:
                for line, token_num in all_output_str:
                    file.write(line)
                    bin_meta.append((last_position, token_num))
                    last_position += len(line)

            with open(generation_config.output_folder.joinpath(f"{ds_name}.bin.meta"), "wb") as file:
                np.save(file, bin_meta)

            timer(f"dataset {ds_count}").stop()
            ds_count += 1


if __name__ == "__main__":
    args = parse_args()
    hostname = socket.gethostname()

    # initialize distributed environment
    initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
    assert hasattr(gpc, "config") and gpc.config is not None
    assert "generation" in gpc.config, f"Please set the `generation` config in the `{args.config}` file"
    assert (
        "output_folder" in gpc.config["generation"]
    ), "Must set `output_folder` in the `generation` config to specify where generated data is saved"

    # initialize monitor manager context
    with initialize_monitor_manager(
        job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address
    ):
        try:
            main()
        except Exception:
            logger.error(
                f"Exception raised from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}",
            )
            mm.monitor_exception(
                alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc()
            )

            # internlm_accelerator.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
        finally:
            # when shared memory is used, local rank 0 cleans up all files under shm_path
            devices_per_node = internlm_accelerator.device_count()
            local_rank = gpc.get_global_rank() % devices_per_node
            if gpc.config.data.use_shm and local_rank == 0:
                if os.path.exists(gpc.config.data.shm_path):
                    shutil.rmtree(gpc.config.data.shm_path)