Skip to content

add float8 support #1378

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: autoparallel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,28 @@ class Experimental:
needs to ensure that the path can be imported.
"""

reorder_for_compute_comm_overlap: bool = False
"""
Whether to enable inductor comm reordering passes
"""

reorder_for_compute_comm_overlap_passes: list[str] = field(
default_factory=lambda: [
"sink_waits",
"reorder_communication_preserving_peak_memory",
]
)
"""
Sequence of reordering passes (names of functions inside _inductor.comms) to call,
if reorder_for_compute_comm_overlap is enabled.
"""

reorder_prefetch_limit: int | None = None
"""
How many ops to allow moving any individual collective, if 'reorder_communication_preserving_peak_memory'
pass is enabled. default of None means unlimited
"""


@dataclass
class JobConfig:
Expand Down
1 change: 1 addition & 0 deletions torchtitan/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torchtitan.experiments.auto_parallel # noqa: F401
import torchtitan.experiments.llama4 # noqa: F401
import torchtitan.experiments.simple_fsdp # noqa: F401
7 changes: 7 additions & 0 deletions torchtitan/experiments/auto_parallel/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
## Auto Parallel

requires installing git@github.com:pytorch-labs/autoparallel.git

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4`

(or llama3-8b.toml)
31 changes: 31 additions & 0 deletions torchtitan/experiments/auto_parallel/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
from torchtitan.models.llama3 import llama3_configs, pipeline_llama, Transformer
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
from .parallelize_llama import parallelize_llama

register_train_spec(
TrainSpec(
name="llama3_auto_parallel",
cls=Transformer,
config=llama3_configs,
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
build_tokenizer_fn=build_tiktoken_tokenizer,
build_loss_fn=build_cross_entropy_loss,
)
)
85 changes: 85 additions & 0 deletions torchtitan/experiments/auto_parallel/parallelize_llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import time

import torch

from autoparallel.api import AutoParallel

from torch.distributed import DeviceMesh
from torch.distributed.tensor.placement_types import Replicate, Shard

from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims

from torchtitan.tools.logging import logger


def parallelize_llama(
model,
world_mesh: DeviceMesh,
parallel_dims: ParallelDims,
job_config: JobConfig,
):
"""
Apply tensor parallelism, activation checkpointing, torch.compile, and data
parallelism to the model.

NOTE: The passed-in model preferably should be on meta device. Otherwise,
the model must fit on GPU or CPU memory.
"""

def input_fn():
global_batch_size = job_config.training.global_batch_size
if global_batch_size < 0:
# This global batch size results in 1 gradient accumulation
# step.
dp_degree = world_mesh["dp"].size()
global_batch_size = job_config.training.local_batch_size * dp_degree
return torch.rand(
(global_batch_size, job_config.training.seq_len), device="cuda"
)

# TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP
assert (
len(world_mesh.shape) == 2
), "Only support 2D mesh (DP, TP) for now- OK if one has size=1"
assert parallel_dims.dp_shard_enabled is True, "DDP not supported yet"
assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet"
assert parallel_dims.cp_enabled is False, "CP not supported yet"
assert parallel_dims.pp_enabled is False, "PP not supported yet"

# TODO: there are multiple float8 recipes, this just hardcodes one
enable_float8_linear = "float8" in job_config.model.converters
if enable_float8_linear:
import copy
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.config import Float8LinearConfig
model = convert_to_float8_training(copy.deepcopy(model), config=Float8LinearConfig())

# bail out
# model = model_fn()
# return model

autop = AutoParallel(model, input_fn, world_mesh)
autop.add_parameter_memory_constraint(low=None, high=None)

x_sharding = (Shard(0), Replicate())

autop.add_input_constraints([x_sharding])
autop.add_output_constraints([x_sharding])
t0 = time.time()
sharding_placement = autop.optimize_placement()
t1 = time.time()
logger.info(f"AutoParallel took {t1 - t0} seconds")
parallel_mod = autop.apply_placement(sharding_placement)

if job_config.training.compile:
torch._inductor.config.reorder_for_peak_memory = False
parallel_mod = torch.compile(parallel_mod, fullgraph=True)

return parallel_mod
27 changes: 21 additions & 6 deletions torchtitan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import torch
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.tensor import DTensor

import torchtitan.components.ft as ft
import torchtitan.protocols.train_spec as train_spec_module
Expand Down Expand Up @@ -113,6 +114,21 @@ def __init__(self, job_config: JobConfig):
gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug
)

# TODO(whc)
# I do this becuase otherwise sometimes inductor will skip re-running passes like comms reordering
torch._inductor.config.force_disable_caches = True

# allow configuring inductor comms optimizations from torchtitan commandline
torch._inductor.config.reorder_for_compute_comm_overlap = (
job_config.experimental.reorder_for_compute_comm_overlap
)
torch._inductor.config.reorder_for_compute_comm_overlap_passes = (
job_config.experimental.reorder_for_compute_comm_overlap_passes
)
torch._inductor.config.reorder_prefetch_limit = (
job_config.experimental.reorder_prefetch_limit
)

# Set random seed, and maybe enable deterministic mode
# (mainly for debugging, expect perf loss).
dist_utils.set_determinism(
Expand All @@ -138,20 +154,19 @@ def __init__(self, job_config: JobConfig):
)

# build model (using meta init)
model_cls = self.train_spec.cls
model_args = self.train_spec.config[job_config.model.flavor]
model_cls = self.train_spec.cls
# set the model args from training job configs
model_args.update_from_config(job_config, tokenizer)

logger.info(
f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
)

with torch.device("meta"):
model = model_cls.from_model_args(model_args)

# Build the collection of model converters. No-op if `model.converters` empty
model_converters = build_model_converters(job_config, parallel_dims)
model_converters.convert(model)
# Build the collection of model converters. No-op if `model.converters` empty
model_converters = build_model_converters(job_config, parallel_dims)
model_converters.convert(model)

# metrics logging
build_metrics_processor_fn = (
Expand Down
Loading