FastPersist rebase #467

Open · wants to merge 7 commits into base: staging-fast-persist

26 changes: 24 additions & 2 deletions megatron/arguments.py
@@ -42,6 +42,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
parser = _add_memoryopt_args(parser)
parser = _add_activation_checkpoint_args(parser)
parser = _add_distillation_args(parser)
parser = _add_fast_checkpointing_args(parser)
parser = _add_inference_args(parser)
parser = _add_transformer_engine_args(parser)
parser = _add_retro_args(parser)
@@ -182,6 +183,16 @@ def validate_args(args, defaults={}):
else:
args.virtual_pipeline_model_parallel_size = None

# MoE
if len(args.num_experts) == 1 and args.num_experts[0] == -1:
# Hack to set number of experts to world size
args.num_experts = [int(args.world_size)]

if args.moe_expert_parallel_size is None:
assert len(args.num_experts) == 1, \
f'Unspecified --moe-expert-parallel-size only supported for single value --num-experts'
args.moe_expert_parallel_size = min(int(args.num_experts[0]), int(args.world_size))

# Parameters dtype.
args.params_dtype = torch.float
if args.fp16:
@@ -947,8 +958,8 @@ def _add_training_args(parser):
help='DeepSpeed inference engine being used')
group.add_argument('--cpu-optimizer', action='store_true',
help='Run optimizer on CPU')
group.add_argument('--cpu_torch_adam', action='store_true',
help='Use Torch Adam as optimizer on CPU.')
group.add_argument('--torch_adam', action='store_true',
help='Use Torch Adam as optimizer.')
group.add_argument('--ds_fused_adam', action='store_true',
help='Use DeepSpeed FusedAdam as optimizer.')
group.add_argument('--no-pipeline-parallel', action='store_true',
@@ -1564,6 +1575,17 @@ def _add_distillation_args(parser):

return parser

def _add_fast_checkpointing_args(parser):
group = parser.add_argument_group('Fast Checkpointing configuration')
group.add_argument('--checkpoint-io-buffer-size', type=int, default=None,
help="Fast checkpointing I/O buffer size")
group.add_argument('--checkpoint-data-parallel', type=str, default=None,
help='Fast checkpointing data parallelism mode.')
group.add_argument('--aio-intra-op-parallelism', type=int, default=None,
help='AIO intra op parallelism.')
group.add_argument('--checkpoint-writer-decoupled', action='store_true',
help='Decoupled checkpointing.')
return parser

def _add_profiler_args(parser):
group = parser.add_argument_group(title='profiling configuration')
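A minimal, self-contained sketch of how the new Fast Checkpointing argument group parses under argparse. The values passed here (the buffer size and the 'replica' mode string) are illustrative placeholders, not defaults or mode names defined by this PR.

```python
import argparse

def add_fast_checkpointing_args(parser):
    # Mirror of _add_fast_checkpointing_args above, reproduced so the snippet runs standalone.
    group = parser.add_argument_group('Fast Checkpointing configuration')
    group.add_argument('--checkpoint-io-buffer-size', type=int, default=None,
                       help='Fast checkpointing I/O buffer size')
    group.add_argument('--checkpoint-data-parallel', type=str, default=None,
                       help='Fast checkpointing data parallelism mode.')
    group.add_argument('--aio-intra-op-parallelism', type=int, default=None,
                       help='AIO intra op parallelism.')
    group.add_argument('--checkpoint-writer-decoupled', action='store_true',
                       help='Decoupled checkpointing.')
    return parser

parser = add_fast_checkpointing_args(argparse.ArgumentParser())
args = parser.parse_args([
    '--checkpoint-io-buffer-size', str(64 * 1024 * 1024),  # illustrative size, not a PR default
    '--checkpoint-data-parallel', 'replica',                # hypothetical mode string
    '--checkpoint-writer-decoupled',
])
print(args.checkpoint_io_buffer_size)    # 67108864
print(args.checkpoint_data_parallel)     # replica
print(args.checkpoint_writer_decoupled)  # True
print(args.aio_intra_op_parallelism)     # None (left unset)
```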
5 changes: 5 additions & 0 deletions megatron/core/parallel_state.py
@@ -670,6 +670,11 @@ def get_data_parallel_src_rank():
return _DATA_PARALLEL_GLOBAL_RANKS[0]


def get_data_parallel_group_ranks():
"""Return all the ranks in the data parallel group."""
assert _DATA_PARALLEL_GLOBAL_RANKS is not None, "Data parallel group is not initialized"
return _DATA_PARALLEL_GLOBAL_RANKS

def get_pipeline_model_parallel_first_rank():
"""Return the global rank of the first process in the pipeline for the
current tensor parallel group"""
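A hedged usage sketch (not code from this PR): once Megatron's parallel state has been initialized elsewhere, the new accessor lets callers inspect their whole data-parallel group, for example to elect a single checkpoint writer per group.

```python
import torch.distributed as dist
from megatron.core import parallel_state

def is_data_parallel_writer():
    """True on the lowest global rank of this process's data-parallel group,
    a plausible place to elect one checkpoint writer per DP group."""
    ranks = parallel_state.get_data_parallel_group_ranks()
    return dist.get_rank() == min(ranks)
```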
12 changes: 6 additions & 6 deletions megatron/data/indexed_dataset.py
@@ -98,9 +98,9 @@ def write_longs(f, a):
3: np.int16,
4: np.int32,
5: np.int64,
6: np.float64,
7: np.float32,
8: np.uint16,
6: float,
7: np.double,
8: np.uint16
}


@@ -271,8 +271,8 @@ class IndexedDatasetBuilder(object):
np.int16: 2,
np.int32: 4,
np.int64: 8,
np.float32: 4,
np.float64: 8,
float: 4,
np.double: 8
}

def __init__(self, out_file, dtype=np.int32):
@@ -496,7 +496,7 @@ def __getstate__(self):
def __setstate__(self, state):
self._do_init(state, skip_warmup=True)

def _do_init(self, path, skip_warmup):
def _do_init(self, path, skip_warmup=True):
self._path = path
self._index = self.Index(index_file_path(self._path), skip_warmup)

12 changes: 8 additions & 4 deletions megatron/optimizer/__init__.py
@@ -1,5 +1,8 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

from torch.optim import AdamW as TorchAdamW
from torch.optim import Adam as TorchAdam

from deepspeed.accelerator import get_accelerator
if get_accelerator().device_name() == 'cuda':
from apex.optimizers import FusedAdam as Adam
@@ -82,7 +85,7 @@ def get_megatron_optimizer(model,

if args.cpu_optimizer:
assert args.optimizer == 'adam', 'CPU offloading is for Adam'
if args.cpu_torch_adam:
if args.torch_adam:
cpu_adam_optimizer = torch.optim.AdamW
else:
from deepspeed.ops.adam import DeepSpeedCPUAdam
@@ -95,10 +98,11 @@
else:
if args.optimizer == 'adam':
if args.ds_fused_adam:
global Adam
from deepspeed.ops.adam import FusedAdam
Adam = FusedAdam
optimizer = Adam(param_groups,
adam_optimizer = FusedAdam
else:
adam_optimizer = TorchAdamW if args.torch_adam else Adam
optimizer = adam_optimizer(param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
betas=(args.adam_beta1, args.adam_beta2),
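A distilled sketch of the selection branch after this refactor (simplified signature, not the real get_megatron_optimizer): DeepSpeed FusedAdam wins when --ds_fused_adam is set, otherwise --torch_adam selects torch.optim.AdamW, otherwise the platform default Adam (Apex FusedAdam on CUDA) is used.

```python
from torch.optim import AdamW as TorchAdamW

def pick_adam_class(args, default_adam):
    # default_adam stands in for the module-level Adam chosen per accelerator.
    if args.ds_fused_adam:
        from deepspeed.ops.adam import FusedAdam
        return FusedAdam
    return TorchAdamW if args.torch_adam else default_adam
```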
36 changes: 26 additions & 10 deletions megatron/training.py
@@ -19,7 +19,7 @@
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from enum import Enum

from megatron import get_args
from megatron import get_args, is_rank_0
from megatron import get_signal_handler
from megatron import get_timers
from megatron import get_tensorboard_writer
@@ -42,13 +42,12 @@
from megatron.initialize import set_jit_fusion_options
from megatron.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import check_adlr_autoresume_termination, CHECKPOINT_SIZE, get_checkpoint_folder_size
from megatron.utils import unwrap_model, found_kill_switch
from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.utils import report_memory, throughput_calculator, checkpoint_throughput_calculator, update_rotary_pos_emb
from megatron.model.vision.knn_monitor import compute_feature_bank
from megatron.utils import report_memory, throughput_calculator, checkpoint_throughput_calculator, update_rotary_pos_emb, set_ds_auto_config_values, moe_parameters_in_billions
from megatron.arguments import core_transformer_config_from_args
from megatron.profiler import setup_profiler, trigger, on_step_begin, on_step_end

@@ -83,7 +82,7 @@ def _create_ds_config_dict():

# Clear config path
args.deepspeed_config = None

ds_config_dict = set_ds_auto_config_values(ds_config_dict)
return ds_config_dict


@@ -263,9 +262,13 @@ def pretrain(train_valid_test_dataset_provider,
test_data_iterator, model,
iteration, process_non_loss_data_func, config,
verbose=True, write_to_tensorboard=not args.skip_train, test=True)
if args.deepspeed:
model[0].destroy()

return model



def update_train_iters(args):

# For iteration-based training, we don't need to do anything
@@ -577,7 +580,7 @@ def setup_model_and_optimizer(model_provider_func,

if args.deepspeed:
print_rank_0("DeepSpeed is enabled.")
pp = mpu.get_pipeline_model_parallel_world_size()

if args.data_efficiency_curriculum_learning and build_train_valid_test_datasets_provider is not None:
train_ds = None
# Only need to build dataset on tp rank 0 since Megatron has the
@@ -1182,6 +1185,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
log_string += ' samples per second: {:.3f} |'.format(samples_per_sec)
log_string += ' tokens per gpu per second (tgs): {:.3f} |'.format(tokens_per_gpu_per_second)
log_string += ' TFLOPs: {:.2f} |'.format(tflops)
log_string += ' params(B): {:.2f} |'.format(approx_parameters_in_billions)
log_string += ' moe params(B): {:.2f} |'.format(moe_parameters_in_billions())
total_loss_dict[advanced_iters_key] = 0
total_loss_dict[skipped_iters_key] = 0
total_loss_dict[nan_iters_key] = 0
@@ -1195,17 +1200,18 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
return report_memory_flag


CHECKPOINT_GB = None
def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler):
global CHECKPOINT_GB
timers = get_timers()
# Extra barrier is added to make sure
# all ranks report the max time.
# all ranks report the max time.
timers('save-checkpoint', log_level=0).start(barrier=True)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
timers('save-checkpoint').stop(barrier=True)
checkpoint_throughput_calculator(model, timers('save-checkpoint').elapsed(reset=False))
checkpoint_throughput_calculator(model, timers('save-checkpoint').elapsed(reset=False, barrier=True))
timers.log(['save-checkpoint'])


def train(forward_step_func, model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator,
process_non_loss_data_func):
@@ -1308,6 +1314,15 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
params_norm = None
if args.log_params_norm:
params_norm = calc_params_l2_norm(model)

# Checkpointing
saved_checkpoint = False
if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler)
saved_checkpoint = True

report_memory_flag = training_log(loss_dict, total_loss_dict,
optimizer.param_groups[0]['lr'],
iteration, loss_scale,
@@ -1355,7 +1370,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
done_cuda, op=torch.distributed.ReduceOp.MAX)
done = done_cuda.item()
if done:
if not saved_checkpoint:
if args.save and not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler)
print_datetime('exiting program after {} minutes'.format(train_time))
@@ -1394,6 +1409,7 @@ def evaluate(forward_step_func,
args = get_args()

if args.vision_pretraining and args.vision_pretraining_type == "dino":
from megatron.model.vision.knn_monitor import compute_feature_bank
compute_feature_bank(model)

# Turn on evaluation mode which disables dropout.
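A minimal sketch of the barrier-bracketed save-timing pattern used in save_checkpoint_and_time above, with time.monotonic() standing in for Megatron's timers; save_fn is a placeholder for save_checkpoint(...).

```python
import time
import torch.distributed as dist

def timed_save(save_fn, *save_args):
    if dist.is_initialized():
        dist.barrier()                # start together so every rank measures the same window
    start = time.monotonic()
    save_fn(*save_args)
    if dist.is_initialized():
        dist.barrier()                # include the slowest writer before stopping the clock
    return time.monotonic() - start   # seconds, the input to checkpoint_throughput_calculator
```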
64 changes: 64 additions & 0 deletions megatron/utils.py
@@ -22,6 +22,8 @@
from megatron.core import mpu
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
from megatron.model.module import param_is_not_shared

CHECKPOINT_SIZE = None
from megatron.model.rotary_pos_embedding import RotaryEmbedding


@@ -259,6 +261,14 @@ def is_rank_0():
else:
return True

def moe_parameters_in_billions():
args = get_args()
self_attn = (4 * args.num_layers * (args.hidden_size**2)) + (2 * args.hidden_size)
non_expert_fc = (4 * args.num_layers * (args.hidden_size**2)) + (2 * args.hidden_size)
expert_fc = int(args.num_experts[0]) * non_expert_fc

return (self_attn + non_expert_fc + expert_fc) / 1e9

def get_parameters_in_billions(model):
gpus_per_model = torch.distributed.get_world_size(group=mpu.get_model_parallel_group())

@@ -311,6 +321,60 @@ def throughput_calculator(model, args, iteration_time, total_iterations):
tflops = flops_per_iteration / (elapsed_time_per_iter * args.world_size * (10**12))
return samples_per_second, tflops, approx_parameters_in_billions


def _get_folder_size(folder):
size = 0
for path, _, files in os.walk(folder):
size += sum([os.path.getsize(os.path.join(path, f)) for f in files])
return size

def get_checkpoint_folder_size(iteration):
args = get_args()
if args.local_rank == 0:
folder = os.path.join(get_args().save, f'global_step{iteration}')
size_tensor = torch.tensor(_get_folder_size(folder)).cuda()
else:
size_tensor = torch.tensor(0).cuda()

torch.distributed.reduce(tensor=size_tensor, dst=0)
return int(size_tensor)

def _replace_auto_config_values(old_config, replace_dict):
new_config = {}
for key, value in old_config.items():
if type(value) == dict:
new_config[key] = _replace_auto_config_values(value, replace_dict)
elif value == "auto" and replace_dict.get(key, None) is not None:
new_config[key] = replace_dict[key]
else:
new_config[key] = old_config[key]

return new_config


def set_ds_auto_config_values(ds_config_dict):
from deepspeed.runtime.constants import TRAIN_MICRO_BATCH_SIZE_PER_GPU, TRAIN_BATCH_SIZE
from deepspeed.runtime.model_checkpointing.constants import (
CHECKPOINT_IO_BUFFER_SIZE,
CHECKPOINT_DATA_PARALLEL,
CHECKPOINT_WRITER_DECOUPLED
)
from deepspeed.runtime.swap_tensor.constants import AIO_INTRA_OP_PARALLELISM

args = get_args()

replace_dict = {
TRAIN_BATCH_SIZE: args.global_batch_size,
TRAIN_MICRO_BATCH_SIZE_PER_GPU: args.micro_batch_size,
CHECKPOINT_IO_BUFFER_SIZE: args.checkpoint_io_buffer_size,
CHECKPOINT_DATA_PARALLEL: args.checkpoint_data_parallel,
CHECKPOINT_WRITER_DECOUPLED: args.checkpoint_writer_decoupled,
AIO_INTRA_OP_PARALLELISM: args.aio_intra_op_parallelism
}

ds_config = _replace_auto_config_values(ds_config_dict, replace_dict)
return ds_config

def checkpoint_throughput_calculator(model, latency_second):
approx_parameters_in_billions = get_parameters_in_billions(model)
checkpoint_multiplier = 14 # fp16 weights (2), fp32 weights (4), fp32 momentum (4), fp32 variance (4)
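A self-contained illustration of the "auto" substitution helper added above, copying the function as written in this diff. The nested keys in the sample config are simplified stand-ins, not the real DeepSpeed config constants.

```python
def _replace_auto_config_values(old_config, replace_dict):
    # Same logic as the helper added in megatron/utils.py: recurse into nested
    # dicts and replace "auto" only when a non-None replacement is available.
    new_config = {}
    for key, value in old_config.items():
        if type(value) == dict:
            new_config[key] = _replace_auto_config_values(value, replace_dict)
        elif value == "auto" and replace_dict.get(key, None) is not None:
            new_config[key] = replace_dict[key]
        else:
            new_config[key] = old_config[key]
    return new_config

ds_config = {
    "train_batch_size": "auto",
    "checkpoint": {"io_buffer_size": "auto", "writer_decoupled": "auto"},
}
replacements = {"train_batch_size": 256, "io_buffer_size": 64 * 1024 * 1024}
print(_replace_auto_config_values(ds_config, replacements))
# {'train_batch_size': 256,
#  'checkpoint': {'io_buffer_size': 67108864, 'writer_decoupled': 'auto'}}
```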
2 changes: 1 addition & 1 deletion pretrain_gpt.py
@@ -20,7 +20,7 @@

import deepspeed
from deepspeed.runtime.utils import see_memory_usage
from deepspeed.accelerator.real_accelerator import get_accelerator
from deepspeed.accelerator import get_accelerator
from deepspeed.sequence.fpdt_layer import FPDT_InputConstruct
import os
import subprocess
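For reference, a hedged one-liner showing the accessor behind the updated import path; the printed value is illustrative and depends on the installed accelerator.

```python
# device_name() is the same call megatron/optimizer/__init__.py uses above to
# pick between Apex FusedAdam and the accelerator-provided Adam.
from deepspeed.accelerator import get_accelerator

print(get_accelerator().device_name())  # e.g. 'cuda' on a GPU machine
```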