diff --git a/test/prototype/moe_training/test_fsdp.py b/test/prototype/moe_training/test_fsdp.py
index 4994a76854..302256dbec 100644
--- a/test/prototype/moe_training/test_fsdp.py
+++ b/test/prototype/moe_training/test_fsdp.py
@@ -1,3 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+######################################################################
+#
+# To run these unit tests, use the following command:
+#
+# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_fsdp.py
+#
+######################################################################
+
 import copy
 import os
 
diff --git a/test/prototype/moe_training/test_tp.py b/test/prototype/moe_training/test_tp.py
new file mode 100644
index 0000000000..3984ae2d40
--- /dev/null
+++ b/test/prototype/moe_training/test_tp.py
@@ -0,0 +1,219 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+######################################################################
+#
+# To run these unit tests, use the following command:
+#
+# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_tp.py
+#
+######################################################################
+
+import copy
+import os
+
+import pytest
+import torch
+from torch import distributed as dist
+from torch import nn
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.nn import functional as F
+
+# this feature requires CUDA and SM89+
+if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
+    pytest.skip(
+        "CUDA not available or compute capability < 8.9", allow_module_level=True
+    )
+
+from torchao.float8.float8_utils import compute_error
+from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
+from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
+from torchao.quantization.quant_api import quantize_
+
+# this test requires torchtitan
+try:
+    from torchtitan.experiments.llama4.infra.parallelize import apply_moe_tp
+    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
+    from torchtitan.experiments.llama4.model.moe import MoE
+except ImportError:
+    import warnings
+
+    warnings.warn("torchtitan not installed, skipping MoE tests.")
+    pytest.skip(allow_module_level=True)
+
+
+@pytest.mark.parametrize(
+    "target_fqns",
+    [
+        ["experts"],
+        ["experts", "shared_expert"],
+    ],
+)
+def test_moe_float8_training_tp_sp(target_fqns: list[str]):
+    assert torch.cuda.is_available()
+
+    # setup distributed for tp
+    mesh = setup_distributed()
+
+    # define model args
+    model_args = TransformerModelArgs(
+        moe_enabled=True,
+        num_experts=8,
+        dim=256,
+        vocab_size=1024,
+    )
+    init_std = 0.02
+    device = torch.device("cuda")
+
+    # reference bf16 MoE
+    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    torch.manual_seed(1)
+    ref_model.init_weights(init_std, device)
+
+    # target MoE for testing conversion
+    model = copy.deepcopy(ref_model)
+
+    # assert starting params are identical for both models
+    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
+        assert torch.equal(param1, param2)
+
+    # convert MoE to float8 training
+    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+        for target_fqn in target_fqns:
+            if target_fqn in cur_fqn:
+                return True
+        return False
+
+    # quantize test model
+    config = MoETrainingConfig()
+    quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+
+    # validate that only the experts were converted
+    _validate_model_conversion(
+        model,
+        target_fqns=target_fqns,
+    )
+
+    # apply TP
+    apply_moe_tp(model, mesh)
+    apply_moe_tp(ref_model, mesh)
+
+    # inputs
+    batch, seq, dim = 8, 2048, 256
+    ref_x = torch.randn(
+        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
+    )
+    x = ref_x.detach().clone().requires_grad_(True)
+
+    # forward pass
+    ref_out = ref_model(ref_x)
+    out = model(x)
+
+    # validate output
+    out_sqnr = compute_error(out, ref_out)
+    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."
+
+    # compute loss
+    labels = torch.ones_like(ref_out)
+    ref_loss = F.mse_loss(ref_out, labels)
+    out_loss = F.mse_loss(out, labels)
+
+    # backward pass
+    ref_loss.backward()
+    out_loss.backward()
+
+    # validate input gradient
+    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
+    assert input_grad_sqnr.item() >= 30.0, (
+        f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}."
+    )
+
+    # validate param gradients
+    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
+        param_grad_sqnr = compute_error(param1.grad, param2.grad)
+        assert param_grad_sqnr.item() >= 25.0, (
+            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
+        )
+
+    dist.destroy_process_group()
+
+
+def _validate_model_conversion(
+    root_module: nn.Module,
+    target_fqns: list[str],
+):
+    def _recursive_validate(
+        module: nn.Module,
+        cur_fqn: str,
+    ):
+        is_allowed_module = any([target_fqn in cur_fqn for target_fqn in target_fqns])
+
+        # check current module params
+        for param_name, param in module.named_parameters(recurse=False):
+            is_converted_type = isinstance(param, ScaledGroupedMMTensor)
+            if is_converted_type:
+                assert is_allowed_module, (
+                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
+                )
+            if not is_allowed_module:
+                assert not is_converted_type, (
+                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
+                )
+
+        # recursively check child modules
+        for child_name, child_module in module.named_children():
+            child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
+            _recursive_validate(child_module, child_fqn)
+
+    _recursive_validate(root_module, "")
+
+
+def setup_distributed():
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+    device_mesh = init_device_mesh("cuda", (world_size,))
+    # seed must be the same in all processes
+    torch.manual_seed(1)
+    torch.cuda.set_device(rank)
+    return device_mesh
+
+
+def apply_moe_tp(
+    model: nn.Module,
+    tp_mesh: DeviceMesh,
+):
+    # based on the llama4 MoE TP implementation here: https://github.com/pytorch/torchtitan/blob/d9cc6b4df341eec27768b5ab9cead87ef595dbc2/torchtitan/experiments/llama4/infra/parallelize.py#L147C1-L180C10
+    from torch.distributed.tensor import Partial, Replicate, Shard
+    from torch.distributed.tensor.parallel import (
+        PrepareModuleInputOutput,
+        parallelize_module,
+    )
+    from torchtitan.experiments.llama4.infra.expert_parallel import (
+        NoParallel,
+        TensorParallel,
+    )
+
+    moe_layer_plan = {
+        # input / output sharding on the seqlen dim
+        # all-gather for input, reduce-scatter for output
+        "moe": PrepareModuleInputOutput(
+            input_layouts=(Shard(1),),
+            desired_input_layouts=(Replicate(),),
+            use_local_input=True,
+            output_layouts=(Partial(),),
+            desired_output_layouts=(Shard(1),),
+        ),
+        # replicate computation for the router
+        "moe.router.gate": NoParallel(),
+        # input Replicate, output Partial
+        "moe.experts": TensorParallel(output_layout=Partial()),
+        "moe.shared_expert": TensorParallel(output_layout=Partial()),
+    }
+    parallelize_module(
+        module=model,
+        device_mesh=tp_mesh,
+        parallelize_plan=moe_layer_plan,
+    )
diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index 29adffd831..602f2cf626 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from typing import Optional
 
 import torch
@@ -18,11 +19,13 @@
     _is_column_major,
 )
 
+logger: logging.Logger = logging.getLogger(__name__)
+
 
 def _scaled_grouped_mm(
     A: torch.Tensor,
     B_t: torch.Tensor,
-    offs: torch.Tensor,
+    offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
 ) -> torch.Tensor:
     """
@@ -36,8 +39,8 @@ def _scaled_grouped_mm(
         and in column-major memory layout.
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
-        use_triton_for_per_group_scales (bool): Whether to use custom triton kernels to compute per-group scales. Default is True.
     """
+    logger.info("Using scaled_grouped_mm")
     return _Float8GroupedMM.apply(
         A,
         B_t,
@@ -54,12 +57,11 @@ def forward(
         ctx,
         A: torch.Tensor,
         B_t: torch.Tensor,
-        offs: torch.Tensor,
+        offs: Optional[torch.Tensor] = None,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
-        use_triton_for_per_group_scales: bool = True,
     ) -> torch.Tensor:
-        # torchao _scaled_grouped_mm only supports A=2D, B=3D.
-        assert A.ndim == 2, "A must be 2D"
+        # torchao _scaled_grouped_mm only supports A=2D|3D and B=3D.
+        assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D"
         assert B_t.ndim == 3, "B must be 3D"
 
         assert A.size(-1) % 16 == 0, (
@@ -76,7 +78,6 @@ def forward(
         assert B_t.dtype == torch.float32 or B_t.dtype == torch.bfloat16, (
             "B must be float32 or bfloat16"
         )
-        assert offs.dtype == torch.int32, "offs must be int32"
 
         # Assert A and B dims are compatible for a scaled grouped GEMM.
         assert A.size(-1) == B_t.size(-2), (
@@ -152,9 +153,11 @@ def forward(
         return torch._scaled_grouped_mm(
             A_fp8_row_major,
             B_t_fp8_col_major,
-            A_scales.squeeze().reciprocal(),
-            B_t_scales.squeeze().reciprocal(),
-            offs,
+            # Squeeze A scales to: (B, S, 1) => (B, S), or (B*S, 1) => (B*S)
+            A_scales.squeeze(-1).reciprocal(),
+            # Squeeze B scales to: (B, 1, N) => (B, N)
+            B_t_scales.squeeze(1).reciprocal(),
+            offs=offs,
             out_dtype=out_dtype,
             use_fast_accum=True,
         )
@@ -194,9 +197,9 @@ def backward(ctx, grad_output: torch.Tensor):
         grad_A = torch._scaled_grouped_mm(
             grad_output_fp8_row_major,
             B_fp8_col_major,
-            grad_output_scales.squeeze().reciprocal(),
-            B_scales.squeeze().reciprocal(),
-            offs,
+            grad_output_scales.squeeze(-1).reciprocal(),
+            B_scales.squeeze(1).reciprocal(),
+            offs=offs,
             out_dtype=out_dtype,
             use_fast_accum=True,
         )
@@ -241,7 +244,7 @@ def backward(ctx, grad_output: torch.Tensor):
             A_fp8_col_major,
             grad_output_t_scales.reciprocal(),
             A_scales.reciprocal(),
-            offs,
+            offs=offs,
             out_dtype=out_dtype,
             use_fast_accum=True,
         )
diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py
index b41527a4ae..1b0046a5c8 100644
--- a/torchao/prototype/moe_training/tensor.py
+++ b/torchao/prototype/moe_training/tensor.py
@@ -9,13 +9,16 @@
 
 import torch
 import torch.utils._pytree as pytree
+from torch import nn
 from torch._prims_common import suggest_memory_format
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp import MixedPrecisionPolicy
+from torch.autograd.grad_mode import _unsafe_preserve_version_counter
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 
 logger: logging.Logger = logging.getLogger(__name__)
 
-
 _ops_to_preserve_subclass = {
     torch.ops.aten.empty_like.default,
     torch.ops.aten.new_zeros.default,
@@ -37,7 +40,7 @@ class ScaledGroupedMMTensor(torch.Tensor):
     differentiable _scaled_grouped_mm autograd function.
     """
 
-    grouped_mm_func_name = "_grouped_mm"
+    grouped_mm_func_names = {"_grouped_mm", "_grouped_mm.default"}
     offs_arg_name = "offs"
 
     @staticmethod
@@ -46,6 +49,7 @@ def __new__(
         tensor: torch.Tensor,
         dtype: torch.dtype,
     ):
+        logger.debug(f"__new__: Creating ScaledGroupedMMTensor with dtype={dtype}")
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
@@ -64,14 +68,16 @@ def __init__(
         tensor: torch.Tensor,
         dtype: torch.dtype,
     ):
-        self._data = tensor
+        self._data = tensor.to(dtype)
         self._dtype = dtype
+        logger.debug(f"__init__: ScaledGroupedMMTensor with self._data.dtype={self._data.dtype} and dtype={dtype}")
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
-        logger.info(f"{func.__name__}, args: {args}, kwargs: {kwargs}")
+        logger.debug(f"func: {func.__name__}, args={args}, kwargs={kwargs}")
+
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
-        if func.__name__ == cls.grouped_mm_func_name:
+        if func.__name__ in cls.grouped_mm_func_names:
             # Use torchao scaled grouped mm with dynamic quant for
             # "2d x 3d with offsets" case (used for routed experts).
             # Otherwise, fall back to regular grouped mm.
@@ -80,10 +86,9 @@ def __torch_function__(cls, func, types, args, kwargs={}):
             # used for shared experts. This is basically the grouped_mm
             # kernel handling a bmm.
             A, B = args[0], args[1]
-            A_is_2d = A.dim() == 2
+            A_is_2d_or_3d = A.dim() in (2, 3)
             B_is_3d = B.dim() == 3
-            has_offs = kwargs.get(cls.offs_arg_name) is not None
-            if A_is_2d and B_is_3d and has_offs:
+            if A_is_2d_or_3d and B_is_3d:
                 return _scaled_grouped_mm(
                     *args,
                     **kwargs,
@@ -96,21 +101,13 @@ def __torch_function__(cls, func, types, args, kwargs={}):
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
+        logger.debug(f"dispatch: {func.__name__}, args={args}, kwargs={kwargs}")
         # detach is special case
         if func == torch.ops.aten.detach.default:
             return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)
 
         # unwrap args and kwargs
-        dtype: Optional[torch.dtype] = None
-
-        def unwrap(t):
-            nonlocal dtype
-            if dtype is None:
-                dtype = t._dtype
-            else:
-                assert t._dtype == dtype
-            return t._data
-
+        unwrap = lambda x: x._data
         args, kwargs = pytree.tree_map_only(
             ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
         )
@@ -125,16 +122,17 @@ def unwrap(t):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x, dtype),
+            lambda x: ScaledGroupedMMTensor(x, x.dtype),
             out,
         )
 
     def __repr__(self):
-        return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"
+        return f"ScaledGroupedMMTensor(data.dtype={self._data.dtype}, self.dtype={self._dtype})"
 
     def __tensor_flatten__(self):
         return ["_data"], {"_dtype": self._dtype}
 
+    @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         return ScaledGroupedMMTensor(
             inner_tensors["_data"],
@@ -142,9 +140,18 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
             flatten_spec["_dtype"],
         )
 
-    def fsdp_pre_all_gather(self, mesh):
+    # fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81
+    def fsdp_pre_all_gather(
+        self,
+        mesh: DeviceMesh,
+        outer_size: torch.Size,
+        outer_stride: tuple[int, ...],
+        module: nn.Module,
+        mp_policy: MixedPrecisionPolicy,
+    ):
         all_gather_inputs = (self._data,)
         all_gather_metadata = ()
+        logger.debug(f"fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, param_dtype: {mp_policy.param_dtype}")
         return all_gather_inputs, all_gather_metadata
 
     def fsdp_post_all_gather(
@@ -156,6 +163,13 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
+        logger.debug(f"fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}")
+
+        if out is not None:
+            # with torch.no_grad():
+            #     out.copy_(data)
+            return
+
         output = ScaledGroupedMMTensor(data, param_dtype)
         inner_tensors = (data,)
         return output, inner_tensors