2 changes: 1 addition & 1 deletion ci_scripts/train/load_ckpt.sh
@@ -8,7 +8,7 @@ source ./ci_scripts/common/variables.sh
readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts"
readonly CKPTS40_PATH="$GITHUB_WORKSPACE/llm_ckpts/40"
readonly CKPTS40_OUTPUT="${CKPTS40_PATH}/*.pt"
expected_num=22
expected_num=23
exit_code=0

source ./ci_scripts/common/basic_func.sh
2 changes: 1 addition & 1 deletion ci_scripts/train/slurm_train.sh
@@ -8,7 +8,7 @@ source ./ci_scripts/common/variables.sh
readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts"
readonly CKPTS20_PATH="$GITHUB_WORKSPACE/llm_ckpts/20"
readonly CKPTS20_OUTPUT="${CKPTS20_PATH}/*.pt"
expected_num=22
expected_num=23
exit_code=0

source ./ci_scripts/common/basic_func.sh
2 changes: 1 addition & 1 deletion ci_scripts/train/torchrun.sh
@@ -8,7 +8,7 @@ source ./ci_scripts/common/variables.sh
readonly CKPTS_PATH="$GITHUB_WORKSPACE/llm_ckpts"
readonly CKPTS20_PATH="$GITHUB_WORKSPACE/llm_ckpts/20"
readonly CKPTS_OUTPUT="${CKPTS20_PATH}/*.pt"
expected_num=22
expected_num=23
exit_code=0

source ./ci_scripts/common/basic_func.sh
4 changes: 4 additions & 0 deletions configs/demo.py
@@ -49,6 +49,10 @@
async_upload=True, # async ckpt upload. (only works for boto3 ckpt)
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
# If enable_save_ckpt=True, metadata is generated automatically alongside checkpoints.
# If generate_meta_data.enable=True, metadata is generated independently in generate_meta_data.path during initialization.
# To generate metadata only, without saving checkpoints, enable generate_meta_data.
generate_meta_data=dict(enable=False, path='./')
)

TRAIN_FOLDER = None # "/path/to/dataset"
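The comments above describe two ways to obtain metadata. As a rough illustration (not part of this PR; the path value is a placeholder), a ckpt config that only generates metadata during initialization, without saving any checkpoints, could look like this:

# Hypothetical ckpt config: generate metadata only, no checkpoint saving.
# Keys follow the configs/demo.py diff above; the path value is a placeholder.
ckpt = dict(
    enable_save_ckpt=False,  # skip saving model/optimizer checkpoints
    generate_meta_data=dict(enable=True, path="./meta_out"),  # write metadata here during initialization
)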
6 changes: 6 additions & 0 deletions internlm/checkpoint/checkpoint_manager.py
@@ -229,6 +229,7 @@ def __init__(
model_config=None,
model_config_file=None,
feishu_address=None,
meta_data=None,
) -> None:
"""
CheckpointManager is used to decide when to store ckpt. If it is an asynchronous
@@ -247,6 +248,7 @@ def __init__(
self.save_ckpt_folder = get_config_value(ckpt_config, "save_ckpt_folder", None)
self.oss_snapshot_freq: int = get_config_value(ckpt_config, "oss_snapshot_freq", 50)
self.stop_file_path = get_config_value(ckpt_config, "stop_file_path", None)
self.meta_data = meta_data
if self.save_ckpt_folder:
self.snapshot_ckpt_folder = get_config_value(
ckpt_config, "snapshot_ckpt_folder", os.path.join(self.save_ckpt_folder, "snapshot")
@@ -629,6 +631,10 @@ def save_checkpoint(
save_optimizer_checkpoint(optim=optimizer, state_path=folder)
timer("save-optimizer").stop()

if gpc.get_global_rank() == 0 and gpc.config.ckpt.need_metadata:
assert self.meta_data is not None
llm_save(os.path.join(folder, "metadata.pt"), saved_obj=self.meta_data)

if (
hasattr(train_state, "data_state_dict")
and gpc.get_local_rank(ParallelMode.TENSOR) == 0
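save_checkpoint now writes metadata.pt into the checkpoint folder on global rank 0. For a local save folder this can be inspected with plain PyTorch loading; a minimal sketch, assuming llm_save with a local path produces a regular torch-serialized file (the folder below is a placeholder):

import os
import torch

folder = "/path/to/llm_ckpts/20"  # hypothetical local checkpoint folder
meta = torch.load(os.path.join(folder, "metadata.pt"), map_location="cpu")
print(type(meta))  # inspect the structure produced by generate_meta_data(optimizer)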
11 changes: 9 additions & 2 deletions internlm/core/trainer_builder.py
@@ -20,6 +20,7 @@
from internlm.model.metrics import AccPerplex
from internlm.monitor.monitor import send_alert_message
from internlm.train.pipeline import (
generate_meta_data,
get_scheduler_hooks,
initialize_llm_profile,
initialize_optimizer,
@@ -124,8 +125,13 @@ def __init__(
# initialize optimizer
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator)

# generate ckpt metadata
meta_data = generate_meta_data(optimizer)

# initialize checkpoint manager and try resume training
self.ckpt_manager = self._initialize_checkpoint_manager(model, optimizer, lr_scheduler, train_dl, config_lines)
self.ckpt_manager = self._initialize_checkpoint_manager(
model, optimizer, lr_scheduler, train_dl, config_lines, meta_data
)
self.ckpt_manager.try_resume_training(train_state, self.current_time)

# initialize customized llm writer
@@ -178,7 +184,7 @@ def _initialize_criterion(self) -> FlashGPTLMLoss:
)

def _initialize_checkpoint_manager(
self, model, optimizer, lr_scheduler, train_dl, config_lines
self, model, optimizer, lr_scheduler, train_dl, config_lines, meta_data
) -> CheckpointManager:
return CheckpointManager(
ckpt_config=gpc.config.ckpt,
@@ -189,6 +195,7 @@ def _initialize_checkpoint_manager(
model_config=gpc.config.model,
model_config_file="".join(config_lines),
feishu_address=gpc.config.monitor.alert.feishu_alert_address,
meta_data=meta_data,
)

def _initialize_writer(self, train_state, config_lines) -> Writer:
8 changes: 8 additions & 0 deletions internlm/initialize/launch.py
@@ -214,6 +214,14 @@ def args_sanity_check():
if "enable_save_ckpt" not in ckpt:
ckpt._add_item("enable_save_ckpt", True)

if "generate_meta_data" not in ckpt:
ckpt._add_item("generate_meta_data", dict(enable=False, path=None))

if ckpt.enable_save_ckpt or ckpt.generate_meta_data.enable:
ckpt.need_metadata = True
else:
ckpt.need_metadata = False

# Saving checkpoint args.
if ckpt.enable_save_ckpt:
assert "checkpoint_every" in ckpt, "If enable save checkpoint, must give checkpoint_every in config.data!"
12 changes: 10 additions & 2 deletions internlm/model/modules/embedding.py
@@ -47,27 +47,35 @@ def __init__(
self.vocab_parallel = vocab_parallel

parallel_size = gpc.weight_parallel_size if is_using_isp() else gpc.tensor_parallel_size
rank = gpc.get_local_rank(ParallelMode.WEIGHT) if is_using_isp() else gpc.get_local_rank(ParallelMode.TENSOR)

if vocab_parallel:
assert num_embeddings % parallel_size == 0, f"{num_embeddings} is not divisible by {parallel_size}"

self.num_embeddings_per_partition = num_embeddings // parallel_size
self.embed_dim_per_partition = embedding_dim
self.vocab_start_index = gpc.get_local_rank(ParallelMode.TENSOR) * self.num_embeddings_per_partition
self.vocab_start_index = rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
self.offset = [self.vocab_start_index, 0]
self.tp_dim = 0
else:
assert embedding_dim % parallel_size == 0, f"{embedding_dim} is not divisible by {parallel_size}"

self.num_embeddings_per_partition = num_embeddings
self.embed_dim_per_partition = embedding_dim // parallel_size
self.vocab_start_index = 0
self.vocab_end_index = self.num_embeddings_per_partition
self.offset = [0, self.embed_dim_per_partition * rank]
self.tp_dim = 1

self.weight = nn.Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), dtype=dtype)
)

self.complete_size = [num_embeddings, embedding_dim]
setattr(self.weight, "is_embedding_param", True)
setattr(self.weight, "offset", self.offset)
setattr(self.weight, "complete_size", [num_embeddings, embedding_dim])
setattr(self.weight, "tp_dim", self.tp_dim)

def forward(self, input_: Tensor) -> Tensor:
if self.vocab_parallel and not is_using_isp():
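The offset, complete_size and tp_dim attributes attached to the weight above record where the local shard sits inside the full embedding. A minimal sketch of how a metadata consumer could turn them into index ranges into the complete weight (the helper and the sizes are illustrative, not part of this PR):

import torch

def global_slices(offset, local_shape):
    # offset[d] is the start index of the local shard along dimension d of the
    # complete weight; the shard spans local_shape[d] elements along that dimension.
    return tuple(slice(start, start + size) for start, size in zip(offset, local_shape))

# Vocab-parallel example: num_embeddings=1024, embedding_dim=8, parallel_size=4, rank=2
# gives offset=[512, 0], tp_dim=0 and a local shard of shape (256, 8).
full = torch.zeros(1024, 8)
local = torch.randn(256, 8)
full[global_slices([512, 0], local.shape)] = local  # place the shard into the complete weight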
10 changes: 10 additions & 0 deletions internlm/model/modules/linear.py
@@ -597,6 +597,7 @@ def __init__(

world_size = gpc.get_world_size(parallel_mode)
rank = gpc.get_local_rank(parallel_mode)
self.offset = None
self.tp_dim = -1  # default for weights that are not split

if split_mode != "none":
split_features = out_features if split_mode == "column" else in_features
@@ -611,11 +612,20 @@

if split_mode == "column":
super().__init__(in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype)
self.offset = [rank * local_multiple * multiple_of, 0]
self.tp_dim = 0
elif split_mode == "row":
super().__init__(local_multiple * multiple_of, out_features, bias=bias, device=device, dtype=dtype)
self.offset = [0, rank * local_multiple * multiple_of]
self.tp_dim = 1
else:
super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)

self.complete_size = [out_features, in_features]
setattr(self.weight, "offset", self.offset)
setattr(self.weight, "complete_size", [out_features, in_features])
setattr(self.weight, "tp_dim", self.tp_dim)

def forward(self, input: torch.Tensor, batch_sizes: torch.Tensor = None) -> torch.Tensor: # pylint: disable=W0622
_class_name = self.__class__.__name__
assert self._communicator is not None, f"{_class_name} should register with a communicator first."
24 changes: 4 additions & 20 deletions internlm/model/modules/mha.py
@@ -65,22 +65,6 @@ def _qkv_pre_load_convert(module: "GQA", state_dict, prefix: str, *args, **kwarg
)


def _qkv_save_convert(module: "GQA", state_dict, prefix: str, *args, **kwargs) -> Dict: # pylint: disable=W0613
wq_name, wk_name, wv_name, fused_name = (
f"{prefix}wq.weight",
f"{prefix}wk.weight",
f"{prefix}wv.weight",
f"{prefix}wqkv.weight",
)

if module.enable_qkv_fusion:
state_dict[wq_name], state_dict[wk_name], state_dict[wv_name] = split_fused_wqkv_weight(
state_dict.pop(fused_name), *args, **kwargs
)

return state_dict


class MHA(nn.Module):
"""
Multi-head self-attention and cross-attention.
@@ -462,15 +446,15 @@ def __init__(
if enable_qkv_fusion:
assert bias is False, "Fused wqkv only supports bias=False."
self.wqkv = new_linear("wqkv", embed_dim, q_dim + 2 * self.kv_dim, bias, **factory_kwargs)
self._register_load_state_dict_pre_hook(
partial(_qkv_pre_load_convert, q_dim=q_dim, kv_dim=self.kv_dim), with_module=True
)
self._register_state_dict_hook(partial(_qkv_save_convert, q_dim=q_dim, kv_dim=self.kv_dim))
else:
self.wq = new_linear("wq", embed_dim, q_dim, bias, **factory_kwargs)
self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs)
self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs)

self._register_load_state_dict_pre_hook(
partial(_qkv_pre_load_convert, q_dim=q_dim, kv_dim=self.kv_dim), with_module=True
)

self.inner_attn = SelfAttention(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout, layer_idx=layer_idx
)
21 changes: 20 additions & 1 deletion internlm/solver/optimizer/hybrid_zero_optim.py
@@ -149,7 +149,7 @@ def __init__(
assert self._param_bcast_sync_handler is not None

self._isp_communicator = isp_communicator

self.meta_for_zero = None
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
# and add buffers to parameter store for future access
@@ -165,6 +165,9 @@
zero_mode = param_group["optimizer_mode"]
self._zero_local_rank.append(gpc.get_local_rank(zero_mode))
self._zero_world_size.append(gpc.get_world_size(zero_mode))

if gpc.config.ckpt.need_metadata and self.meta_for_zero is None:
self.meta_for_zero = [{} for _ in range(gpc.get_world_size(zero_mode))]
# TODO: _broadcast_parallel_mode is not only used for broadcast; consider renaming it
self._broadcast_parallel_mode.append(zero_mode)

@@ -281,6 +284,22 @@ def _partition_param_list(self, group_id, param_group):
self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
numel_per_rank[rank_to_go] += param.numel()

if gpc.config.ckpt.need_metadata:
if group_id not in self.meta_for_zero[rank_to_go]:
self.meta_for_zero[rank_to_go][group_id] = {}

from internlm.train.pipeline import map_fqn_local_to_global

global_fqn = map_fqn_local_to_global.get(param.fqn, param.fqn)
self.meta_for_zero[rank_to_go][group_id][global_fqn] = {
"tp_dim": getattr(param, "tp_dim", -1),
"pp": gpc.get_local_rank(ParallelMode.PIPELINE),
"zero1": rank_to_go,
"fqn": param.fqn,
"shape": param.shape,
"group_id": group_id,
}

# check whether any rank is not assigned to parameters.
for rank, params in enumerate(params_per_rank):
if len(params) == 0:
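The block above fills meta_for_zero, a list indexed by zero1 rank whose entries map group_id to global FQN to a small descriptor. A hypothetical example of the resulting structure (all names, ranks and shapes below are invented for illustration):

# Sketch of one optimizer's meta_for_zero after _partition_param_list; values are made up.
meta_for_zero = [
    {  # entry for zero1 rank 0
        0: {  # parameter group id
            "model.layers.0.attention.wqkv.weight": {
                "tp_dim": 0,     # dimension the weight is split along (-1 if unsplit)
                "pp": 0,         # pipeline-parallel rank owning this parameter
                "zero1": 0,      # zero1 rank the parameter was assigned to
                "fqn": "layers.0.attention.wqkv.weight",  # local FQN before global mapping
                "shape": (6144, 4096),
                "group_id": 0,
            },
        },
    },
    # ... one dict per zero1 rank in this group's zero parallel mode
]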