diff --git a/torchtitan/experiments/forge/README.md b/torchtitan/experiments/forge/README.md
new file mode 100644
index 000000000..9f58e8c1b
--- /dev/null
+++ b/torchtitan/experiments/forge/README.md
@@ -0,0 +1,14 @@
+## `ForgeEngine`
+
+The `forge` folder contains a lightweight training engine that serves as a streamlined subset of the `Trainer` class from [torchtitan/train.py](/torchtitan/train.py). The engine provides only the essential constructor, making it flexible for a variety of downstream applications.
+
+The [`ForgeEngine`](engine.py) takes a [`ForgeJobConfig`](job_config.py) to
+- Initialize an SPMD distributed training environment
+- Construct and scale models via n-D parallelisms and meta-device initialization
+- Provide necessary training components and utilities
+
+**Primary Use Case**: The engine is designed for building trainers in post-training workflows where multiple specialized components (trainer, generator, replay buffer, parameter server, etc.) work together.
+
+The folder also provides a train spec registration function, [`register_train_spec`](train_spec.py), which lets users extend beyond the core set of models and training components available in torchtitan for specific training requirements.
+
+The [example_train.py](./example_train.py) script demonstrates how to use `ForgeEngine` for pretraining, reproducing the functionality of [torchtitan/train.py](/torchtitan/train.py).
diff --git a/torchtitan/experiments/forge/__init__.py b/torchtitan/experiments/forge/__init__.py
new file mode 100644
index 000000000..1654959ce
--- /dev/null
+++ b/torchtitan/experiments/forge/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .engine import ForgeEngine
+from .job_config import ForgeJobConfig
+from .train_spec import ForgeTrainSpec, register_train_spec
+
+__all__ = ["ForgeEngine", "ForgeJobConfig", "ForgeTrainSpec", "register_train_spec"]
diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py
new file mode 100644
index 000000000..ee0d4299a
--- /dev/null
+++ b/torchtitan/experiments/forge/engine.py
@@ -0,0 +1,276 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
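+
+# ForgeEngine is a streamlined subset of the Trainer class in torchtitan/train.py:
+# it keeps only the constructor, which initializes the SPMD distributed environment,
+# builds and parallelizes the model with meta-device initialization, and exposes the
+# resulting training components (model_parts, optimizers, lr_schedulers, checkpointer,
+# loss_fn, ...) for downstream trainers to drive.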
+ +import importlib +import os +from typing import Generator + +import torch +from torch.distributed.elastic.multiprocessing.errors import record + +import torchtitan.protocols.train_spec as train_spec_module +from torchtitan.components.checkpoint import CheckpointManager +from torchtitan.components.ft import FTManager, init_ft_manager +from torchtitan.components.loss import rescale_accumulated_loss +from torchtitan.distributed import ParallelDims, utils as dist_utils +from torchtitan.protocols.model_converter import build_model_converters +from torchtitan.protocols.train_spec import BaseModelArgs +from torchtitan.tools import utils + +from .job_config import ForgeJobConfig +from .train_spec import ForgeTrainSpec, get_train_spec + + +class ForgeEngine(torch.distributed.checkpoint.stateful.Stateful): + # core configs + job_config: ForgeJobConfig + parallel_dims: ParallelDims + train_spec: ForgeTrainSpec + + # swappable training components in ForgeTrainSpec + model_parts: list[torch.nn.Module] + loss_fn: train_spec_module.LossFunction + optimizers: train_spec_module.OptimizersContainer + lr_schedulers: train_spec_module.LRSchedulersContainer + + # non-swappable training components + checkpointer: CheckpointManager + ft_manager: FTManager + + # runtime utilities + device: torch.device + gc_handler: utils.GarbageCollection + gradient_accumulation_steps: int + train_context: Generator[None, None, None] + pp_has_first_stage: bool + pp_has_last_stage: bool + + # Fields in ForgeEngine which are not in original Trainer + # for dataloading + tokenizer: train_spec_module.BaseTokenizer | None + dp_degree: int + dp_rank: int + # for logging + model_args: BaseModelArgs + num_flops_per_token: float + model_param_count: int + global_batch_size: int + + # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html + @record + def __init__(self, job_config: ForgeJobConfig): + torch._C._log_api_usage_once("torchtitan.train") + + self.job_config = job_config + + if job_config.experimental.custom_import: + importlib.import_module(job_config.experimental.custom_import) + + device_module, device_type = utils.device_module, utils.device_type + self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}") + # Device has to be set before creating TorchFT manager. + device_module.set_device(self.device) + + # init distributed and build meshes + dist_utils.init_distributed(job_config) + world_size = int(os.environ["WORLD_SIZE"]) + parallelism_config = job_config.parallelism + self.parallel_dims = parallel_dims = ParallelDims( + dp_shard=parallelism_config.data_parallel_shard_degree, + dp_replicate=parallelism_config.data_parallel_replicate_degree, + cp=parallelism_config.context_parallel_degree, + tp=parallelism_config.tensor_parallel_degree, + pp=parallelism_config.pipeline_parallel_degree, + ep=parallelism_config.expert_parallel_degree, + world_size=world_size, + ) + + world_mesh = parallel_dims.world_mesh + if parallel_dims.dp_enabled: + dp_mesh = world_mesh["dp"] + dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() + else: + dp_degree, dp_rank = 1, 0 + + self.ft_manager = init_ft_manager(job_config) + # If TorchFT is enabled, the dp_rank and dp_degree, which are used for + # dataloader must be changed. 
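+        # get_dp_info() returns the degree/rank adjusted for the fault-tolerant
+        # replica group, so data sharding stays consistent under TorchFT.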
+ if self.ft_manager.enabled: + dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank) + + self.dp_degree, self.dp_rank = dp_degree, dp_rank + + # take control of garbage collection to avoid stragglers + self.gc_handler = utils.GarbageCollection( + gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug + ) + + # Set random seed, and maybe enable deterministic mode + # (mainly for debugging, expect perf loss). + dist_utils.set_determinism( + world_mesh, + self.device, + job_config.training.seed, + job_config.training.deterministic, + ) + self.train_spec = get_train_spec(job_config.model.name) + + # build tokenizer + self.tokenizer = tokenizer = ( + self.train_spec.build_tokenizer_fn(job_config) + if self.train_spec.build_tokenizer_fn is not None + else None + ) + + # build model (using meta init) + self.model_args = model_args = self.train_spec.model_args[ + job_config.model.flavor + ] + # set the model args from training job configs + model_args.update_from_config(job_config, tokenizer) + + with torch.device("meta"): + model = self.train_spec.model_cls(model_args) + + # Build the collection of model converters. No-op if `model.converters` empty + model_converters = build_model_converters(job_config, parallel_dims) + model_converters.convert(model) + + # calculate model size and flops per token + ( + self.model_param_count, + self.num_flops_per_token, + ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len) + + # move sharded model to CPU/GPU and initialize weights via DTensor + if job_config.checkpoint.create_seed_checkpoint: + init_device = "cpu" + buffer_device = None + elif job_config.training.enable_cpu_offload: + init_device = "cpu" + buffer_device = device_type + else: + init_device = device_type + buffer_device = None + + self.loss_fn = self.train_spec.build_loss_fn(job_config) + + # verify batch sizes + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. 
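+            # i.e., it defaults to local_batch_size * dp_degree, computed on the next line.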
+ global_batch_size = job_config.training.local_batch_size * dp_degree + assert global_batch_size > 0 + assert ( + global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0 + ), ( + f"global batch size must be multiple of local batch size times " + f"data-parallel degree ({global_batch_size} " + f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)" + ) + self.global_batch_size = global_batch_size + + # calculate gradient accumulation steps + self.gradient_accumulation_steps = global_batch_size // ( + job_config.training.local_batch_size * dp_degree + ) + assert self.gradient_accumulation_steps > 0 + self.loss_fn = rescale_accumulated_loss( + self.loss_fn, self.gradient_accumulation_steps + ) + + # apply parallelisms and initialization + if parallel_dims.pp_enabled: + if not self.train_spec.pipelining_fn: + raise RuntimeError( + f"Pipeline Parallel is enabled but {self.train_spec.name} " + f"does not support pipelining" + ) + + # apply both PT-D Pipeline Parallel and SPMD-style PT-D techniques + ( + self.pp_schedule, + self.model_parts, + self.pp_has_first_stage, + self.pp_has_last_stage, + ) = self.train_spec.pipelining_fn( + model, + parallel_dims, + job_config, + self.device, + model_args, + self.train_spec.parallelize_fn, + self.loss_fn, + ) + # when PP is enabled, `model` obj is no longer used after this point, + # model_parts is used instead + del model + + for m in self.model_parts: + m.to_empty(device=init_device) + with torch.no_grad(): + m.init_weights(buffer_device=buffer_device) + m.train() + else: + # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel + model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) + + model.to_empty(device=init_device) + with torch.no_grad(): + model.init_weights(buffer_device=buffer_device) + model.train() + + self.model_parts = [model] + + if ( + self.ft_manager.enabled + and job_config.fault_tolerance.semi_sync_method is None + ): + self.ft_manager.set_all_reduce_hook(self.model_parts) + + # build optimizer after applying parallelisms to the model + self.optimizers = self.train_spec.build_optimizers_fn( + self.model_parts, job_config, parallel_dims, self.ft_manager + ) + self.lr_schedulers = self.train_spec.build_lr_schedulers_fn( + self.optimizers, job_config + ) + # Post optimizer step model converters hook. + # e.g. 
calculate float8 dynamic amax/scale for all-parameter for FSDP2 + # where it issues a single all-reduce for all parameters at once for better performance + self.optimizers.register_step_post_hook( + lambda *args, **kwargs: model_converters.post_optimizer_hook( + self.model_parts + ) + ) + + self.checkpointer = CheckpointManager( + dataloader=None, + model_parts=self.model_parts, + optimizers=self.optimizers, + lr_schedulers=self.lr_schedulers, + states={"train_state": self}, + job_config=job_config, + ft_manager=self.ft_manager, + ) + + loss_parallel_enabled = ( + parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel + ) + self.train_context = dist_utils.get_train_context( + loss_parallel_enabled, + parallelism_config.enable_compiled_autograd, + ) + self.maybe_enable_amp = dist_utils.maybe_enable_amp( + parallel_dims, + job_config.training.mixed_precision_param, + device_type, + ) + + def close(self) -> None: + if self.checkpointer: + self.checkpointer.close() diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py new file mode 100644 index 000000000..c19c57c95 --- /dev/null +++ b/torchtitan/experiments/forge/example_train.py @@ -0,0 +1,351 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import time +from datetime import timedelta +from typing import Any, Iterable, Optional + +import torch +from torch.distributed.elastic.multiprocessing.errors import record + +import torchtitan.protocols.train_spec as train_spec_module +from torchtitan.components.dataloader import DataloaderStopIteration +from torchtitan.components.ft import maybe_semi_sync_training +from torchtitan.components.metrics import build_metrics_processor +from torchtitan.components.validate import build_validator +from torchtitan.config_manager import ConfigManager, JobConfig +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.distributed import utils as dist_utils +from torchtitan.tools import utils +from torchtitan.tools.logging import init_logger, logger +from torchtitan.tools.profiling import ( + maybe_enable_memory_snapshot, + maybe_enable_profiling, +) + +from .engine import ForgeEngine + + +class Trainer(ForgeEngine): + dataloader: train_spec_module.BaseDataLoader + validator: train_spec_module.BaseValidator + metrics_processor: train_spec_module.MetricsProcessor + + # additional training states + step: int + + # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html + @record + def __init__(self, job_config: JobConfig): + logger.info(f"Starting job: {job_config.job.description}") + + if job_config.job.print_args: + logger.info(f"Running with args: {job_config.to_dict()}") + + # NOTE: Here we are passing in JobConfig as a superset of ForgeJobConfig + super().__init__(job_config) + + # build dataloader + self.dataloader = build_hf_dataloader( + dp_world_size=self.dp_degree, + dp_rank=self.dp_rank, + tokenizer=self.tokenizer, + job_config=job_config, + ) + + model_args = self.model_args + logger.info( + f"Built {self.train_spec.name} {job_config.model.flavor} with {model_args}" + ) + + # metrics logging + self.metrics_processor = build_metrics_processor( + job_config, self.parallel_dims, model_args + ) + color = self.metrics_processor.color + + self.metrics_processor.num_flops_per_token = self.num_flops_per_token + + 
logger.info( + f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} " + f"{color.red}size: {self.model_param_count:,} total parameters{color.reset}" + ) + + # initialize device memory monitor and get peak flops for MFU calculation + device_memory_monitor = self.metrics_processor.device_memory_monitor + gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name) + logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}") + device_mem_stats = device_memory_monitor.get_peak_stats() + logger.info( + f"{utils.device_type.upper()} memory usage for model: " + f"{device_mem_stats.max_reserved_gib:.2f}GiB" + f"({device_mem_stats.max_reserved_pct:.2f}%)" + ) + + self.metrics_processor.optimizers = self.optimizers + + # Initialize trainer states that will be saved in checkpoint. + # These attributes must be initialized before checkpoint loading. + self.step = 0 + + # Build validator if validation is configured + if job_config.validation.enabled: + self.validator = build_validator( + job_config=job_config, + dp_world_size=self.dp_degree, + dp_rank=self.dp_rank, + tokenizer=self.tokenizer, + parallel_dims=self.parallel_dims, + loss_fn=self.train_spec.build_loss_fn(job_config), + validation_context=self.train_context, + maybe_enable_amp=self.maybe_enable_amp, + ) + + logger.info( + "Trainer is initialized with " + f"local batch size {job_config.training.local_batch_size}, " + f"global batch size {self.global_batch_size}, " + f"gradient accumulation steps {self.gradient_accumulation_steps}, " + f"sequence length {job_config.training.seq_len}, " + f"total steps {job_config.training.steps} " + f"(warmup {job_config.lr_scheduler.warmup_steps})." + ) + + def batch_generator( + self, data_iterable: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] + ) -> Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]: + """Returns an iterator that processes batches from the data iterator.""" + device_type = utils.device_type + data_iterator = iter(data_iterable) + + while True: + try: + batch = next(data_iterator) + except StopIteration as ex: + # If data runs out during gradient accumulation, that + # entire step will not be executed. 
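+                # Re-raise so the outer training loop can stop cleanly.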
+ raise DataloaderStopIteration() from ex + data_load_start = time.perf_counter() + input_dict, labels = batch + self.metrics_processor.ntokens_since_last_log += labels.numel() + self.metrics_processor.data_loading_times.append( + time.perf_counter() - data_load_start + ) + + # Move tensors to the appropriate device + for k, v in input_dict.items(): + if isinstance(v, torch.Tensor): + input_dict[k] = v.to(device_type) + labels = labels.to(device_type) + + yield input_dict, labels + + def forward_backward_step( + self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor + ) -> torch.Tensor: + model_parts = self.model_parts + parallel_dims = self.parallel_dims + + # apply context parallelism if cp is enabled + # ensure CP handles the separate freqs_cis buffer for each pp stage + inputs = input_dict["input"] + optional_context_parallel_ctx = ( + dist_utils.create_context_parallel_ctx( + cp_mesh=parallel_dims.world_mesh["cp"], + cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], + cp_seq_dims=[1, 1] + [0 for _ in model_parts], + cp_no_restore_buffers={inputs, labels}, + cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, + ) + if parallel_dims.cp_enabled + else None + ) + + if parallel_dims.pp_enabled: + # Pipeline Parallel forward / backward inside step() call + with self.train_context(optional_context_parallel_ctx): + targets, losses = ( + (labels, []) if self.pp_has_last_stage else (None, None) + ) + if self.pp_has_first_stage: + self.pp_schedule.step( + inputs, target=targets, losses=losses, input_batch=inputs + ) + else: + self.pp_schedule.step( + target=targets, losses=losses, input_batch=inputs + ) + + # accumulate losses across pipeline microbatches + # TODO: PP+FSDP unexpectedly puts the loss back to the CPU + loss = ( + torch.mean(torch.stack(losses)).to(self.device) + if self.pp_has_last_stage + else torch.tensor([-1.0], device=self.device) + ) + else: + # Non-PP forward / backward + with self.train_context(optional_context_parallel_ctx): + assert len(model_parts) == 1 + with self.maybe_enable_amp: + pred = model_parts[0](inputs) + loss = self.loss_fn(pred, labels) + # need to free to before bwd to avoid peaking memory + del pred + loss.backward() + + return loss + + def train_step( + self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] + ): + self.optimizers.zero_grad() + + # Keep these variables local to shorten the code as these are + # the major variables that are used in the training loop. + parallel_dims = self.parallel_dims + + accumulated_losses = [] + # If data runs out during gradient accumulation, that + # entire step will not be executed. + for microbatch in range(self.gradient_accumulation_steps): + input_dict, labels = next(data_iterator) + loss = self.forward_backward_step(input_dict, labels) + accumulated_losses.append(loss.detach()) + + grad_norm = dist_utils.clip_grad_norm_( + [p for m in self.model_parts for p in m.parameters()], + self.job_config.training.max_norm, + foreach=True, + pp_mesh=( + parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None + ), + parallel_dims=parallel_dims, + ) + self.checkpointer.maybe_wait_for_staging() + self.optimizers.step() + self.lr_schedulers.step() + + # Reduce the data collected over gradient accumulation steps. 
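+        # Summing (not averaging) is correct here: each microbatch loss was already
+        # rescaled by 1/gradient_accumulation_steps via rescale_accumulated_loss.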
+ loss = torch.sum(torch.stack(accumulated_losses)) + + # log metrics + if not self.metrics_processor.should_log(self.step): + return + + if parallel_dims.dp_cp_enabled or self.ft_manager.enabled: + loss = loss.detach() + # Skip ft manager communication when using semi sync training + use_ft_pg = ( + self.ft_manager.enabled + and self.job_config.fault_tolerance.semi_sync_method is None + ) + ft_pg = self.ft_manager.replicate_pg if use_ft_pg else None + global_avg_loss, global_max_loss = ( + dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg), + dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg), + ) + else: + global_avg_loss = global_max_loss = loss.detach().item() + + self.metrics_processor.log( + self.step, + global_avg_loss, + global_max_loss, + grad_norm.item(), + ) + + @record + def train(self): + job_config = self.job_config + + self.checkpointer.load(step=job_config.checkpoint.load_step) + logger.info(f"Training starts at step {self.step + 1}.") + + with ( + maybe_enable_profiling(job_config, global_step=self.step) as torch_profiler, + maybe_enable_memory_snapshot( + job_config, global_step=self.step + ) as memory_profiler, + maybe_semi_sync_training( + job_config, + ft_manager=self.ft_manager, + model_parts=self.model_parts, + optimizer=self.optimizers, + ), + ): + data_iterator = self.batch_generator(self.dataloader) + while self.step < job_config.training.steps: + self.step += 1 + self.gc_handler.run(self.step) + try: + self.train_step(data_iterator) + except DataloaderStopIteration: + logger.warning("Ran out of data; last step was canceled.") + break + + # Run validation if validator is available + if ( + self.job_config.validation.enabled + and self.validator.should_validate(self.step) + ): + self.validator.validate(self.model_parts) + + self.checkpointer.save( + self.step, last_step=(self.step == job_config.training.steps) + ) + + # signal the profiler that the next profiling step has started + if torch_profiler: + torch_profiler.step() + if memory_profiler: + memory_profiler.step() + + # reduce timeout after first train step for faster signal + # (assuming lazy init and compilation are finished) + if self.step == 1: + dist_utils.set_pg_timeouts( + timeout=timedelta( + seconds=job_config.comm.train_timeout_seconds + ), + world_mesh=self.parallel_dims.world_mesh, + ) + + if torch.distributed.get_rank() == 0: + logger.info("Sleeping 2 seconds for other ranks to complete") + time.sleep(2) + + logger.info("Training completed") + + def state_dict(self) -> dict[str, Any]: + return {"step": self.step} + + def load_state_dict(self, state_dict: dict[str, Any]): + self.step = state_dict["step"] + + def close(self) -> None: + if self.metrics_processor: + self.metrics_processor.close() + super().close() + + +if __name__ == "__main__": + init_logger() + config_manager = ConfigManager() + config = config_manager.parse_args() + trainer: Optional[Trainer] = None + + try: + trainer = Trainer(config) + trainer.train() + except Exception: + if trainer: + trainer.close() + raise + else: + trainer.close() + torch.distributed.destroy_process_group() + logger.info("Process group destroyed.") diff --git a/torchtitan/experiments/forge/job_config.py b/torchtitan/experiments/forge/job_config.py new file mode 100644 index 000000000..cbe4dcd2a --- /dev/null +++ b/torchtitan/experiments/forge/job_config.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import asdict, dataclass, field +from typing import Any + +from torchtitan.config_manager import ( + ActivationCheckpoint, + Checkpoint, + Comm, + Experimental, + FaultTolerance, + Float8, + LRScheduler, + Model, + MX, + Optimizer, + Parallelism, + Training, +) + + +@dataclass +class ForgeJobConfig: + model: Model = field(default_factory=Model) + optimizer: Optimizer = field(default_factory=Optimizer) + lr_scheduler: LRScheduler = field(default_factory=LRScheduler) + training: Training = field(default_factory=Training) + parallelism: Parallelism = field(default_factory=Parallelism) + checkpoint: Checkpoint = field(default_factory=Checkpoint) + activation_checkpoint: ActivationCheckpoint = field( + default_factory=ActivationCheckpoint + ) + float8: Float8 = field(default_factory=Float8) + mx: MX = field(default_factory=MX) + comm: Comm = field(default_factory=Comm) + fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance) + experimental: Experimental = field(default_factory=Experimental) + + def to_dict(self) -> dict[str, Any]: + return asdict(self) diff --git a/torchtitan/experiments/forge/train_spec.py b/torchtitan/experiments/forge/train_spec.py new file mode 100644 index 000000000..f52affa66 --- /dev/null +++ b/torchtitan/experiments/forge/train_spec.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + +# Import torchtitan.models to ensure all train specs are registered +import torchtitan.models # noqa: F401 + +from torchtitan.protocols.train_spec import ( + _train_specs, + BaseModelArgs, + LossFunctionBuilder, + LRSchedulersBuilder, + ModelProtocol, + OptimizersBuilder, + ParallelizeFunction, + PipeliningFunction, + TokenizerBuilder, + TrainSpec, +) + + +@dataclass +class ForgeTrainSpec: + name: str + model_cls: type[ModelProtocol] + model_args: dict[str, BaseModelArgs] + parallelize_fn: ParallelizeFunction + pipelining_fn: PipeliningFunction | None + build_optimizers_fn: OptimizersBuilder + build_lr_schedulers_fn: LRSchedulersBuilder + build_tokenizer_fn: TokenizerBuilder | None + build_loss_fn: LossFunctionBuilder + + +# Copy and transform train specs from torchtitan.protocols.train_spec._train_specs +# This happens during import after all models have been registered +_forge_train_specs = {} + + +def register_train_spec(train_spec: ForgeTrainSpec) -> None: + global _forge_train_specs + if train_spec.name in _forge_train_specs: + raise ValueError(f"Model {train_spec.name} is already registered.") + + _forge_train_specs[train_spec.name] = train_spec + + +def get_train_spec(name: str) -> ForgeTrainSpec: + global _forge_train_specs + if name not in _forge_train_specs: + raise ValueError(f"Model {name} is not registered.") + return _forge_train_specs[name] + + +def _transform_train_spec(original_spec: TrainSpec): + """Transform the original train spec to ForgeTrainSpec format.""" + # Create a new TrainSpec with only the fields we need in forge + return ForgeTrainSpec( + name=original_spec.name, + model_cls=original_spec.model_cls, + model_args=original_spec.model_args, + parallelize_fn=original_spec.parallelize_fn, + pipelining_fn=original_spec.pipelining_fn, + 
build_optimizers_fn=original_spec.build_optimizers_fn, + build_lr_schedulers_fn=original_spec.build_lr_schedulers_fn, + build_tokenizer_fn=original_spec.build_tokenizer_fn, + build_loss_fn=original_spec.build_loss_fn, + ) + + +# Populate _forge_train_specs with transformed specs +for name, spec in _train_specs.items(): + register_train_spec(_transform_train_spec(spec))
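
Usage sketch: the snippet below is a minimal, hypothetical driver showing how a downstream post-training component might use `ForgeEngine` outside of `example_train.py`. It assumes a torchrun launch (so `LOCAL_RANK`/`WORLD_SIZE` are set) and a `batches` iterable yielding `(input_dict, labels)` pairs already placed on the training device, and it deliberately skips pipeline/context parallelism, AMP, profiling, and metrics. Only the names defined in this diff (`ForgeEngine`, `ForgeJobConfig`, `model_parts`, `optimizers`, `lr_schedulers`, `loss_fn`, `checkpointer`) are real; `train_loop`, `batches`, and `num_steps` are placeholders.

```python
# Minimal sketch (not part of this diff): a bare-bones driver built on ForgeEngine.
from torchtitan.distributed import utils as dist_utils
from torchtitan.experiments.forge import ForgeEngine, ForgeJobConfig


def train_loop(job_config: ForgeJobConfig, batches, num_steps: int) -> None:
    engine = ForgeEngine(job_config)
    try:
        engine.checkpointer.load(step=job_config.checkpoint.load_step)
        data = iter(batches)
        for step in range(1, num_steps + 1):
            engine.optimizers.zero_grad()
            # One optimizer step spans several microbatches (gradient accumulation);
            # loss_fn is already rescaled by ForgeEngine, so plain backward() suffices.
            for _ in range(engine.gradient_accumulation_steps):
                input_dict, labels = next(data)
                assert len(engine.model_parts) == 1  # sketch ignores pipeline parallelism
                pred = engine.model_parts[0](input_dict["input"])
                loss = engine.loss_fn(pred, labels)
                del pred  # free activations before backward to reduce peak memory
                loss.backward()
            dist_utils.clip_grad_norm_(
                [p for m in engine.model_parts for p in m.parameters()],
                job_config.training.max_norm,
                foreach=True,
                pp_mesh=None,
                parallel_dims=engine.parallel_dims,
            )
            engine.optimizers.step()
            engine.lr_schedulers.step()
            engine.checkpointer.save(step, last_step=(step == num_steps))
    finally:
        engine.close()
```

Custom models and training components can be plugged in through `register_train_spec`. The sketch below simply mirrors the `ForgeTrainSpec` fields; every `my_*`/`My*` name is a placeholder for user-provided code, not an API introduced here.

```python
# Hypothetical sketch: registering a model stack that is not in torchtitan's core registry.
from torchtitan.experiments.forge import ForgeTrainSpec, register_train_spec

register_train_spec(
    ForgeTrainSpec(
        name="my_model",
        model_cls=MyModel,
        model_args={"tiny": MyModelArgs()},
        parallelize_fn=my_parallelize_fn,
        pipelining_fn=None,
        build_optimizers_fn=my_build_optimizers,
        build_lr_schedulers_fn=my_build_lr_schedulers,
        build_tokenizer_fn=None,
        build_loss_fn=my_build_loss,
    )
)
```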