diff --git a/README.md b/README.md
index a33da44d9..35a9d8187 100644
--- a/README.md
+++ b/README.md
@@ -114,6 +114,36 @@ Llama 3 8B model locally on 8 GPUs
 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
 ```
 
+### Fine-tuning from an existing checkpoint
+
+First, download the Llama 3.1 checkpoint and convert it to the DCP format:
+
+```bash
+export HF_TOKEN=... # get your HF token from https://huggingface.co/settings/tokens
+# Download the tokenizer and model weights
+rm -rf tmp
+uv run huggingface-cli download meta-llama/Llama-3.1-8B original/tokenizer.model --local-dir tmp
+uv run huggingface-cli download meta-llama/Llama-3.1-8B original/consolidated.00.pth --local-dir tmp
+uv run huggingface-cli download meta-llama/Llama-3.1-8B original/params.json --local-dir tmp
+# Copy the tokenizer into the assets folder and convert the model weights to the DCP format
+mkdir -p assets/tokenizer && cp tmp/original/tokenizer.model assets/tokenizer/Meta-Llama-3.1-8B-tokenizer.model
+uv run python -m scripts.convert_llama_to_dcp tmp/original/ assets/models/dcp/llama3.1-8B
+```
+
+Then you can fine-tune from the converted checkpoint:
+
+```bash
+CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" uv run ./run_train.sh \
+  --model.tokenizer_path assets/tokenizer/Meta-Llama-3.1-8B-tokenizer.model \
+  --training.max_seq_len 131072 \
+  --checkpoint.initial_load_path "assets/models/dcp/llama3.1-8B" \
+  --profiling.no_enable_profiling \
+  --activation_checkpoint.mode full \
+  --training.global_batch_size 64 \
+  --lr_scheduler.warmup_steps 40 \
+  --optimizer.lr 1e-5
+```
+
 ### Multi-Node Training
 For training on ParallelCluster/Slurm type configurations, you can use the `multinode_trainer.slurm` file to submit your sbatch job.
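Before launching the fine-tuning run, the converted checkpoint can be sanity-checked outside of torchtitan: PyTorch's `dcp_to_torch_save` (from `torch.distributed.checkpoint.format_utils`) flattens a DCP directory back into a single `torch.save` file. A minimal sketch, assuming the paths produced by the README commands above (the scratch output path is made up):

```python
# Flatten the converted DCP checkpoint into one .pt file and list a few keys.
# Caution: this materializes all 8B parameters (~16 GB) in CPU memory.
import torch
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

dcp_dir = "assets/models/dcp/llama3.1-8B"  # output of convert_llama_to_dcp
flat_path = "tmp/llama3.1-8B-flat.pt"      # hypothetical scratch file

dcp_to_torch_save(dcp_dir, flat_path)
state = torch.load(flat_path, map_location="cpu", weights_only=True)
print(sorted(state.keys())[:5])  # a few of the saved entry names
```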
diff --git a/scripts/convert_llama_to_dcp.py b/scripts/convert_llama_to_dcp.py
index fa415efad..eea5dc6f7 100644
--- a/scripts/convert_llama_to_dcp.py
+++ b/scripts/convert_llama_to_dcp.py
@@ -10,7 +10,7 @@
 import torch
 import torch.distributed.checkpoint as DCP
 
-from torchtitan.models.llama.model import precompute_freqs_cis
+from torchtitan.models.llama3.model import precompute_freqs_cis
 from torchtitan.tools.logging import init_logger, logger
 
diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index af62501ce..adcc280b2 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -252,9 +252,6 @@ def load_state_dict(state_dict):
             self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
         ) or self.ft_manager
 
-        if not self.enable_checkpoint and self.ft_manager is None:
-            return
-
         self.states = states
         self.states.update(
             {
@@ -265,6 +262,14 @@ def load_state_dict(state_dict):
             }
         )
         self.ft_states = {DATALOADER: dataloader}
+        self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
+        self.exclude_from_loading = ckpt_config.exclude_from_loading
+        self.initial_load_path = ckpt_config.initial_load_path
+        self.initial_load_model_weights_only = (
+            ckpt_config.initial_load_model_weights_only
+        )
+        if not self.enable_checkpoint and self.ft_manager is None:
+            return
 
         self.staging = False
         self.sending_to_checkpoint_mp = False
@@ -272,11 +277,6 @@ def load_state_dict(state_dict):
         self.cpu_offload_state_dict = None
         self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
 
-        self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
-        self.initial_load_path = ckpt_config.initial_load_path
-        self.initial_load_model_weights_only = (
-            ckpt_config.initial_load_model_weights_only
-        )
         self.interval = ckpt_config.interval
         async_mode = ckpt_config.async_mode.lower()
         if async_mode == AsyncMode.ASYNC or self.ft_manager:
@@ -299,7 +299,6 @@ def load_state_dict(state_dict):
 
         self.last_save_model_weights_only = ckpt_config.last_save_model_weights_only
         self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
-        self.exclude_from_loading = ckpt_config.exclude_from_loading
 
         self.mp = None
         self.async_future = None
@@ -418,21 +417,11 @@ def load(self, step: int = -1) -> bool:
         if self.ft_manager:
             self._ft_load()
 
-        if not self.enable_checkpoint:
-            return False
-
         model_only = False
-        if not os.path.exists(self.folder):
-            if self.initial_load_path:
-                checkpoint_id = self.initial_load_path
-                if not os.path.isdir(checkpoint_id):
-                    raise ValueError(
-                        "initial_load_full_checkpoint is specified but the path is not valid."
-                    )
-                model_only = self.initial_load_model_weights_only
-            else:
+        # 1. Try to load from the checkpoint folder first if it exists.
+        if os.path.exists(self.folder):
+            if not self.enable_checkpoint:
                 return False
-        else:
+
             if self.initial_load_path:
                 logger.info(
                     "`initial_load_path` is provided but the checkpoint folder exists. "
@@ -446,11 +435,33 @@ def load(self, step: int = -1) -> bool:
             if not os.path.isdir(checkpoint_id):
                 return False
+            return self.load_from_path(checkpoint_id, model_only)
+
+        # 2. Otherwise, try to load from `initial_load_path`.
+        if self.initial_load_path:
+            checkpoint_id = self.initial_load_path
+            if not os.path.isdir(checkpoint_id):
+                raise ValueError(
+                    "`initial_load_path` is specified but the path is not valid."
+                )
+            model_only = self.initial_load_model_weights_only
+            return self.load_from_path(checkpoint_id, model_only)
+
+        return False
+
+    def load_from_path(self, path: str, model_only: bool = False) -> bool:
+        """Load the checkpoint from the given path.
 
-        logger.info(f"Loading the checkpoint from {checkpoint_id}.")
+        This function loads the checkpoint from the given path. If ``model_only``
+        is True, only the model weights are loaded.
+        """
+        if not os.path.isdir(path):
+            raise ValueError(f"The path {path} is not a valid checkpoint folder.")
+
+        logger.info(f"Loading the checkpoint from {path}.")
         begin = time.monotonic()
         states = self._states_to_load(model_only)
-        dcp.load(states, checkpoint_id=checkpoint_id)
+        dcp.load(states, checkpoint_id=path)
         GarbageCollection.collect("GC collection for checkpoint loading.")
         logger.info(
             f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
         )
@@ -561,20 +572,13 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
             Dict[str, Any]: The states to load for the given step.
         """
         # For the first step, we will only load the model weights.
-        if model_only:
-            sd = self.states[MODEL].state_dict()
-            for k in excluded_parameters_for_model_only:
-                sd.pop(k, None)
-            return sd
-
-        for exclude_key in self.exclude_from_loading:
-            if exclude_key not in self.states:
-                raise ValueError(f"{exclude_key} not found in state_dict.")
-
+        states = {MODEL: self.states[MODEL]} if model_only else self.states
         states_to_load = {
-            k: v for k, v in self.states.items() if k not in self.exclude_from_loading
+            k: v for k, v in states.items() if k not in self.exclude_from_loading
         }
-
+        for exclude_key in self.exclude_from_loading:
+            if exclude_key not in states:
+                raise ValueError(f"{exclude_key} not found in state_dict.")
         if self.ft_manager:
             states_to_load.pop(DATALOADER)
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index f12b21ba5..bcd1abeb9 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -207,6 +207,12 @@ class Training:
     seq_len: int = 2048
     """Sequence length"""
 
+    max_seq_len: int | None = None
+    """
+    Maximum sequence length. If not set, the model's own default is kept. Set it
+    when fine-tuning from a pre-trained model, e.g., 131072 for Llama 3.1 8B.
+    """
+
     max_norm: float | int = 1.0
     """Max norm for gradient clipping"""
 
diff --git a/torchtitan/models/llama3/model.py b/torchtitan/models/llama3/model.py
index 20026a690..43096dbd2 100644
--- a/torchtitan/models/llama3/model.py
+++ b/torchtitan/models/llama3/model.py
@@ -42,7 +42,7 @@ class TransformerModelArgs(BaseModelArgs):
 
     def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
         self.vocab_size = tokenizer.n_words
-        self.max_seq_len = job_config.training.seq_len
+        self.max_seq_len = job_config.training.max_seq_len or self.max_seq_len
         self.eos_id = tokenizer.eos_id
 
         if job_config.activation_checkpoint.mode == "selective" and self.use_flex_attn:
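The restructured `load` gives an existing checkpoint folder priority over `initial_load_path`, so an interrupted fine-tuning run resumes from its own saved state instead of re-reading the pre-trained weights. A standalone sketch of that precedence, with a hypothetical `resolve_checkpoint` helper (the step-folder scan stands in for `_find_load_step`/`_create_checkpoint_id`):

```python
import os
import re


def resolve_checkpoint(folder, initial_load_path, enable_checkpoint,
                       initial_load_model_weights_only=True):
    """Return (path, model_only) to load, or None when starting from scratch."""
    # 1. An existing checkpoint folder always wins, mirroring load() above.
    if os.path.exists(folder):
        if not enable_checkpoint:
            return None
        steps = sorted(
            int(m.group(1))
            for name in os.listdir(folder)
            if (m := re.fullmatch(r"step-(\d+)", name))
        )
        if not steps:
            return None
        # Resume the full training state (model, optimizer, dataloader, ...).
        return os.path.join(folder, f"step-{steps[-1]}"), False
    # 2. Otherwise fall back to the pre-trained checkpoint, usually loading
    #    model weights only so optimizer and lr-scheduler state start fresh.
    if initial_load_path:
        if not os.path.isdir(initial_load_path):
            raise ValueError("`initial_load_path` is specified but the path is not valid.")
        return initial_load_path, initial_load_model_weights_only
    return None
```

The key design point is that `initial_load_path` is consulted only when the run has never checkpointed; once the checkpoint folder exists, it is ignored with a log message rather than an error.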
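On the new `training.max_seq_len` knob: in the Llama model it sizes the precomputed rotary-embedding table, which is why the README passes the pre-trained context length (131072 for Llama 3.1 8B) even though training batches are much shorter. A self-contained sketch of the standard RoPE precompute that `precompute_freqs_cis` implements (`dim` is the per-head dimension; the values below assume Llama 3.1 8B):

```python
import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # Complex rotations for `end` positions: the table's first axis is the
    # position index, so it must cover the pre-trained context window.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)


# Batches are seq_len tokens long, but the table is sized by max_seq_len so
# every position the pre-trained weights were trained on stays addressable.
seq_len, max_seq_len = 2048, 131072
freqs_cis = precompute_freqs_cis(dim=128, end=max_seq_len, theta=500_000.0)
print(freqs_cis.shape)  # torch.Size([131072, 64])
```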