
Commit fb09282

feat(tools): update InternEvo style ckpt inference tool. (#260)
1 parent: 2ac2d08

File tree: 15 files changed, +504 -44 lines


configs/7B_internlm2.py

Lines changed: 16 additions & 1 deletion
```diff
@@ -1,5 +1,5 @@
 JOB_NAME = "7b_internlm2_train"
-model_type="INTERNLM2_PUBLIC"
+model_type = "INTERNLM2_PUBLIC"
 DO_ALERT = False
 
 VOCAB_SIZE = 92544
@@ -205,3 +205,18 @@
 # metric_dtype can be "fp32" or other string
 # only when set to "fp32" will use fp32 to calc in metrics
 # metric_dtype = "fp32"
+
+generation = dict(
+    ckpt_folder="/path/to/saved/ckpt",
+    output_folder="/path/to/save/generation",
+    batch_size=1,
+    eos_id=[2, 0],
+    bos_id=1,
+    max_length=100,
+    do_sample=True,
+    temperature=1.0,
+    top_k=50,
+    top_p=1.0,
+    repetition_penalty=1,
+    length_penalty=1.0,
+)
```
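For context on the sampling fields above: `temperature`, `top_k`, and `top_p` control standard autoregressive decoding, and `repetition_penalty` down-weights tokens that already appear in the output. A minimal sketch of the usual top-k/top-p filtering, for illustration only (InternEvo's actual implementation is `internlm.apis.inference.SequenceGenerator`, not this function):

```python
import torch

# Illustrative sketch of standard top-k / top-p ("nucleus") sampling,
# following the widely used Hugging Face-style filtering. Not InternEvo code.
def sample_next_token(logits: torch.Tensor, temperature: float = 1.0,
                      top_k: int = 50, top_p: float = 1.0) -> torch.Tensor:
    # logits: (batch, vocab) scores for the next token.
    logits = logits / max(temperature, 1e-5)  # <1 sharpens, >1 flattens
    if top_k > 0:
        # Mask everything below the k-th highest logit.
        kth = torch.topk(logits, min(top_k, logits.size(-1)))[0][..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p < 1.0:
        # Keep the smallest set of tokens whose cumulative probability > top_p.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        sorted_mask = cum_probs > top_p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = False  # always keep the most likely token
        mask = sorted_mask.scatter(-1, sorted_idx, sorted_mask)
        logits = logits.masked_fill(mask, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # (batch, 1)
```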

configs/_base_/models/internlm2_1B.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -25,7 +25,7 @@
     mlp_ratio=MLP_RATIO,
     multiple_of=MULTIPLE_OF,
     norm_type="rmsnorm",
-    adapt_hf=True,
+    qk_interleaved=False,
     apply_post_layer_norm=False,
     no_bias=True,
     layer_norm_epsilon=1e-5,
```

configs/_base_/models/internlm2_20B.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -23,7 +23,7 @@
     num_kv_attention_heads=NUM_KV_ATTENTION_HEAD,
     mlp_ratio=MLP_RATIO,
     norm_type="rmsnorm",
-    adapt_hf=True,
+    qk_interleaved=False,
     apply_post_layer_norm=False,
     no_bias=True,
     layer_norm_epsilon=1e-5,
```

configs/_base_/models/internlm2_7B.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -23,7 +23,7 @@
     num_kv_attention_heads=NUM_KV_ATTENTION_HEAD,
     mlp_ratio=MLP_RATIO,
     norm_type="rmsnorm",
-    adapt_hf=False,
+    qk_interleaved=True,
     apply_post_layer_norm=False,
     no_bias=True,
     layer_norm_epsilon=1e-5,
```
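Note the polarity flip across these three configs: `adapt_hf=True` becomes `qk_interleaved=False` and `adapt_hf=False` becomes `qk_interleaved=True`, so the new flag states directly whether the checkpoint stores its query/key projections in the interleaved rotary layout. A hypothetical sketch of what converting between such layouts typically involves (the exact InternEvo weight layout is not shown in this diff and may differ):

```python
import torch

# Hypothetical illustration: reorder the rows of a q or k projection weight
# from the interleaved RoPE layout (dimensions rotated in adjacent pairs:
# d0,d1 / d2,d3 / ...) to the half-split layout (first half rotated against
# the second half). The real conversion lives in InternEvo's tools/ and its
# exact layout may differ from this sketch.
def interleaved_to_half(w: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:
    rows = w.view(num_heads, head_dim, -1)
    even, odd = rows[:, 0::2, :], rows[:, 1::2, :]  # paired dims -> two halves
    return torch.cat([even, odd], dim=1).reshape(w.shape)
```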

doc/usage.md

Lines changed: 25 additions & 0 deletions
````diff
@@ -459,6 +459,31 @@ $ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py -
 2023-07-07 12:29:16,994 INFO train.py:323 in record_current_batch_training_metrics -- tflops=189.3109313713174,step=5,loss=9.822169303894043,tgs (tokens/gpu/second)=4262.67,lr=1.4000000000000001e-06,loss_scale=65536.0,grad_norm=47.10386835560855,micro_num=4,num_consumed_tokens=786432,inf_nan_skip_batches=0,num_samples_in_batch=17,largest_length=2048,largest_batch=6,smallest_batch=3,adam_beta2=0.95,fwd_bwd_time=3.69
 ```
 
+### Load a trained checkpoint and generate
+
+If the distributed environment is launched on slurm, the command for a two-node, 16-GPU run is as follows:
+```bash
+$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python generate.py --config ./configs/7B_sft.py
+```
+
+Add a `generation` config to the configuration file:
+```
+generation = dict(
+    ckpt_folder="/path/to/saved/ckpt",
+    output_folder="/path/to/save/generation",
+    batch_size=1,
+    eos_id=[2, 0],
+    bos_id=1,
+    max_length=100,
+    do_sample=True,
+    temperature=1.0,
+    top_k=50,
+    top_p=1.0,
+    repetition_penalty=1,
+    length_penalty=1.0,
+)
+```
+
 ### Long text generation
 
 During inference, Dynamic NTK RoPE can be used in place of the original RoPE, allowing the model to adapt to long-text input and output and extrapolate to a 16K context.
````
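On the Dynamic NTK RoPE mentioned above: the common trick is to enlarge the RoPE frequency base as the input grows past the training context, so the low-frequency bands stretch to cover longer sequences. A sketch of the widely used formula, shown here for background only (the function names and `scaling_factor` parameter are illustrative, not InternEvo's API):

```python
import torch

def dynamic_ntk_base(base: float, dim: int, seq_len: int, max_train_len: int,
                     scaling_factor: float = 1.0) -> float:
    """Dynamic NTK scaling: once the sequence exceeds the training context,
    rescale the RoPE base. Follows the common NTK-aware formula."""
    if seq_len <= max_train_len:
        return base
    return base * (
        (scaling_factor * seq_len / max_train_len) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))

def rope_inv_freq(base: float, dim: int) -> torch.Tensor:
    # Standard RoPE inverse frequencies for a per-head dimension `dim`.
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
```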

generate.py

Lines changed: 251 additions & 0 deletions
New file:

```python
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import gc
import json
import logging
import os
import shutil
import socket
import traceback
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm

from internlm.accelerator import get_accelerator
from internlm.apis.inference import SequenceGenerator
from internlm.core.context import global_context as gpc
from internlm.data import build_generation_loader_with_data_type
from internlm.initialize import initialize_distributed_env
from internlm.monitor import initialize_monitor_manager
from internlm.monitor.monitor import monitor_manager as mm
from internlm.train import initialize_model, initialize_parallel_communicator
from internlm.utils.common import (
    enable_pytorch_expandable_segments,
    launch_time,
    parse_args,
)
from internlm.utils.gputest import empty_cache_and_diag
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import get_parallel_log_file_name
from internlm.utils.storage_manager import init_storage_manager
from tools.load_internlm2_model import get_model_device, merge_pp_within_tp

# global llm logger
logger = logging.getLogger(__file__)
internlm_accelerator = get_accelerator()


def get_latest_subdirectory(folder_path):
    # Paths may carry a storage prefix like "boto3:"; split it off first.
    if ":" in folder_path:
        prefix, folder_path = folder_path.split(":", 1)
        prefix += ":"
    else:
        prefix = ""
    subdirectories = [name for name in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, name))]
    subdirectories_sorted = sorted(
        subdirectories, key=lambda x: os.path.getctime(os.path.join(folder_path, x)), reverse=True
    )
    if subdirectories_sorted:
        return prefix + os.path.join(folder_path, subdirectories_sorted[0])
    else:
        return None


def main():
    enable_pytorch_expandable_segments()

    generation_config = gpc.config["generation"]

    # Wrap the raw dict in an ad-hoc attribute holder, filling in defaults.
    generation_config = type(
        "",
        (object,),
        {
            "output_folder": Path(generation_config["output_folder"]),
            "ckpt_folder": generation_config["ckpt_folder"]
            if "ckpt_folder" in generation_config
            else get_latest_subdirectory(gpc.config.ckpt.save_ckpt_folder),
            "data_folder": generation_config["data_folder"] if "data_folder" in generation_config else None,
            "batch_size": generation_config.get("batch_size", None),
            "eos_id": generation_config.get("eos_id", 2),
            "bos_id": generation_config.get("bos_id", 1),
            "pad_id": generation_config.get("bos_id", 1),
            "additional_eos_token_list": generation_config.get("additional_eos_token_list", None),
            "max_length": generation_config.get("max_length", 100),
            "do_sample": generation_config.get("do_sample", True),
            "temperature": generation_config.get("temperature", 1.0),
            "num_beams": generation_config.get("num_beams", 1),
            "top_k": generation_config.get("top_k", 50),
            "top_p": generation_config.get("top_p", 1.0),
            "repetition_penalty": generation_config.get("repetition_penalty", 1),
            "length_penalty": generation_config.get("length_penalty", 1.0),
        },
    )

    if not os.path.exists(generation_config.output_folder.absolute()):
        generation_config.output_folder.mkdir(exist_ok=True, parents=True)

    # get and broadcast current time
    current_time = launch_time()
    objs = [current_time]
    torch.distributed.broadcast_object_list(objs, src=0)
    current_time = objs[0].replace(":", ".")
    global logger
    logger = get_logger(
        __file__, launch_time=current_time, job_name=gpc.config.JOB_NAME, file_name=get_parallel_log_file_name()
    )

    try:
        init_storage_manager(False, None, None)
    except AssertionError:
        pass
    except Exception as e:
        raise e

    # initialize model
    model = initialize_model()
    _ = initialize_parallel_communicator(model)
    model = model.model

    state_dict = merge_pp_within_tp(generation_config.ckpt_folder, del_model_prefix=True)
    missing_k, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if len(missing_k) != 0:
        logger.warning(f"Warning: missing keys {missing_k}")
    if len(unexpected_keys) != 0:
        logger.warning(f"Warning: unexpected keys {unexpected_keys}")

    param_dtype = gpc.config.model.dtype
    if isinstance(param_dtype, str):
        try:
            param_dtype = eval(param_dtype)  # pylint: disable=W0123
        except (NameError, AttributeError):
            # "torch.tf32" is not a real torch dtype; keep the string and
            # handle it below.
            pass
    if param_dtype == "torch.tf32":
        param_dtype = torch.float32
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True

    model.to(param_dtype)
    model.eval()
    torch.distributed.barrier()

    data_cfg = gpc.config.data
    if generation_config.data_folder:
        data_cfg.valid_folder = generation_config.data_folder
    gene_dls = build_generation_loader_with_data_type(data_cfg, generation_config)

    sequence_generator = SequenceGenerator(
        decoder=model,
        eos_token_id=generation_config.eos_id,
        pad_token_id=generation_config.pad_id,
        bos_token_id=generation_config.bos_id,
        additional_eos_token_list=generation_config.additional_eos_token_list,
    )

    ds_count = 0
    gc.disable()
    with torch.inference_mode():
        for ds_name, gene_dl in gene_dls.items():
            if len(gene_dl) == 0:
                logger.info(f"Validation dataset: {ds_name} is empty")
                continue
            timer(f"dataset {ds_count}").start()

            # pylint: disable=forgotten-debug-statement
            all_output_str = []
            # pylint: disable=unused-variable
            for val_idx, (labels, input_ids) in tqdm(
                enumerate(gene_dl),
                desc="generate.",
                total=len(gene_dl),
                position=1,
                leave=False,
            ):
                empty_cache_and_diag(val_idx, interval=gpc.config.data.empty_cache_and_diag_interval)
                input_ids = torch.LongTensor(input_ids)
                if input_ids.size(1) >= generation_config.max_length:
                    logger.warning(
                        f"Not generating for the {val_idx}'th batch, because the sequence "
                        f"length of the batch is {input_ids.size(1)} over the max generation "
                        f"length {generation_config.max_length}"
                    )
                    output_ids = input_ids[:, : generation_config.max_length, ...]
                else:
                    input_ids = input_ids.clamp(min=0, max=gpc.config.model.vocab_size).to(get_model_device(model))
                    output_ids = sequence_generator.generate(
                        tokens=input_ids,
                        max_length=generation_config.max_length,
                        do_sample=generation_config.do_sample,
                        temperature=generation_config.temperature,
                        num_beams=generation_config.num_beams,
                        top_k=generation_config.top_k,
                        top_p=generation_config.top_p,
                        repetition_penalty=generation_config.repetition_penalty,
                        length_penalty=generation_config.length_penalty,
                    )
                for output in output_ids:
                    # Strip leading padding before serializing.
                    not_pad_indices = torch.nonzero(output != generation_config.pad_id)
                    if not_pad_indices.nelement() != 0:
                        sequence = output[not_pad_indices[0] :]
                    else:
                        sequence = output
                    sequence = sequence.tolist()
                    line = str.encode(json.dumps({"tokens": sequence}))
                    all_output_str.append(
                        (
                            line,
                            len(line),
                        )
                    )

            bin_meta, last_position = [], 0
            with open(generation_config.output_folder.joinpath(f"{ds_name}.bin"), "wb") as file:
                for line, token_num in all_output_str:
                    file.write(line)
                    # token_num here is the byte length of the JSON record.
                    bin_meta.append((last_position, token_num))
                    last_position += len(line)

            with open(generation_config.output_folder.joinpath(f"{ds_name}.bin.meta"), "wb") as file:
                np.save(file, bin_meta)

            timer(f"dataset {ds_count}").stop()
            ds_count += 1


if __name__ == "__main__":
    args = parse_args()
    hostname = socket.gethostname()

    # initialize distributed environment
    initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
    assert hasattr(gpc, "config") and gpc.config is not None
    assert "generation" in gpc.config, f"Please set `generation` config in `{args.config}` file"
    assert (
        "output_folder" in gpc.config["generation"]
    ), "Must set `output_folder` for the save folder of generation data"

    # initialize monitor manager context
    with initialize_monitor_manager(
        job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address
    ):
        try:
            main()
        except Exception:
            logger.error(
                f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}",
            )
            mm.monitor_exception(
                alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc()
            )

            # internlm_accelerator.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle")
        finally:
            # local rank0 delete all files in shm_path, when use shm
            devices_per_node = internlm_accelerator.device_count()
            local_rank = gpc.get_global_rank() % devices_per_node
            if gpc.config.data.use_shm and local_rank == 0:
                if os.path.exists(gpc.config.data.shm_path):
                    shutil.rmtree(gpc.config.data.shm_path)
```
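The script writes each dataset's generations as concatenated JSON records in `{ds_name}.bin`, with `{ds_name}.bin.meta` holding `(byte offset, byte length)` pairs saved via `np.save`. A small sketch (not part of the commit) of reading the pair back:

```python
import json
import numpy as np

# Minimal reader for the .bin/.bin.meta pair produced by generate.py.
# Each meta row gives the byte offset and byte length of one JSON record.
def read_generations(bin_path: str):
    meta = np.load(bin_path + ".meta")
    samples = []
    with open(bin_path, "rb") as f:
        for offset, length in meta:
            f.seek(int(offset))
            record = json.loads(f.read(int(length)))
            samples.append(record["tokens"])
    return samples

# Usage: token_lists = read_generations("/path/to/save/generation/<ds_name>.bin")
```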

internlm/data/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1,9 +1,11 @@
 from .build_dataloader import (
+    build_generation_loader_with_data_type,
     build_train_loader_with_data_type,
     build_valid_loader_with_data_type,
 )
 
 __all__ = [
     "build_train_loader_with_data_type",
     "build_valid_loader_with_data_type",
+    "build_generation_loader_with_data_type",
 ]
```
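With this export, callers can build generation dataloaders from the package root, as `generate.py` does. Assuming the distributed environment and a `generation_config` are already initialized as in that script, usage looks like:

```python
from internlm.core.context import global_context as gpc
from internlm.data import build_generation_loader_with_data_type

# Returns a dict mapping each validation dataset's name to its dataloader,
# exactly as iterated by generate.py.
gene_dls = build_generation_loader_with_data_type(gpc.config.data, generation_config)
for ds_name, gene_dl in gene_dls.items():
    print(ds_name, len(gene_dl))
```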
