diff --git a/test/prototype/moe_training/test_fsdp.py b/test/prototype/moe_training/test_fsdp.py
index 4994a76854..074fd3e4a0 100644
--- a/test/prototype/moe_training/test_fsdp.py
+++ b/test/prototype/moe_training/test_fsdp.py
@@ -16,9 +16,10 @@
 
 from torchao.float8.float8_utils import compute_error
 from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
-from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
 from torchao.quantization.quant_api import quantize_
 
+from .testing_utils import _validate_model_conversion
+
 # this test requires torchtitan
 try:
     from torchtitan.experiments.llama4.model.args import TransformerModelArgs
@@ -119,36 +120,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     dist.destroy_process_group()
 
 
-def _validate_model_conversion(
-    root_module: nn.Module,
-    target_fqns: list[str],
-):
-    def _recursive_validate(
-        module: nn.Module,
-        cur_fqn: str,
-    ):
-        is_allowed_module = cur_fqn in target_fqns
-
-        # check current module params
-        for param_name, param in module.named_parameters(recurse=False):
-            is_converted_type = isinstance(param, ScaledGroupedMMTensor)
-            if is_converted_type:
-                assert is_allowed_module, (
-                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
-                )
-            if not is_allowed_module:
-                assert not is_converted_type, (
-                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
-                )
-
-        # recursively check child modules
-        for child_name, child_module in module.named_children():
-            child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
-            _recursive_validate(child_module, child_fqn)
-
-    _recursive_validate(root_module, "")
-
-
 def setup_distributed():
     rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
diff --git a/test/prototype/moe_training/test_fsdp.sh b/test/prototype/moe_training/test_fsdp.sh
index 353ad3fad2..5f858061f4 100755
--- a/test/prototype/moe_training/test_fsdp.sh
+++ b/test/prototype/moe_training/test_fsdp.sh
@@ -1 +1 @@
-torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp.py
+torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp.py -s
diff --git a/test/prototype/moe_training/test_tp.py b/test/prototype/moe_training/test_tp.py
new file mode 100644
index 0000000000..1088f01654
--- /dev/null
+++ b/test/prototype/moe_training/test_tp.py
@@ -0,0 +1,245 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+######################################################################
+#
+# To run these unit tests, use the following command:
+#
+# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_tp.py
+#
+#######################################################################
+
+import copy
+import os
+
+import pytest
+import torch
+from torch import distributed as dist
+from torch import nn
+from torch.distributed._tensor import DTensor
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed.tensor import Partial, Replicate, Shard
+from torch.nn import functional as F
+
+try:
+    from torch.distributed.tensor.parallel import (
+        PrepareModuleInputOutput,
+        parallelize_module,
+    )
+except ImportError:
+    import warnings
+
+    warnings.warn(
+        "torch version is too old, these tests require nightly build. Skipping MoE training tests."
+    )
+    )
+    pytest.skip(allow_module_level=True)
+
+
+# this feature requires CUDA and SM89+
+if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
+    pytest.skip(
+        "CUDA not available or compute capability < 8.9", allow_module_level=True
+    )
+
+from torchao.float8.float8_utils import compute_error
+from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
+from torchao.quantization.quant_api import quantize_
+
+from .testing_utils import _validate_model_conversion
+
+# this test requires torchtitan
+try:
+    from torchtitan.experiments.llama4.infra.expert_parallel import (
+        ExpertParallel,
+        ExpertTensorParallel,
+        NoParallel,
+        TensorParallel,
+    )
+    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
+    from torchtitan.experiments.llama4.model.moe import MoE
+except ImportError:
+    import warnings
+
+    warnings.warn("torchtitan not installed, skipping MoE tests.")
+    pytest.skip(allow_module_level=True)
+
+
+@pytest.mark.parametrize(
+    "target_fqns",
+    [
+        ["experts"],
+        # TODO: investigate hang when shared_expert is converted
+        # ["experts,shared_expert"],
+    ],
+)
+def test_moe_float8_training_tp(target_fqns: list[str]):
+    assert torch.cuda.is_available()
+
+    # setup distributed for tp
+    mesh = setup_distributed()
+
+    # define model args
+    model_args = TransformerModelArgs(
+        moe_enabled=True,
+        num_experts=8,
+        dim=256,
+        vocab_size=1024,
+    )
+    init_std = 0.02
+    device = torch.device("cuda")
+
+    # reference bf16 MoE
+    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    torch.manual_seed(1)
+    ref_model.init_weights(init_std, device)
+
+    # target MoE for testing conversion
+    model = copy.deepcopy(ref_model)
+
+    # assert starting params are identical for both models
+    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
+        assert torch.equal(param1, param2)
+
+    # convert MoE to float8 training
+    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+        for target_fqn in target_fqns:
+            if target_fqn in cur_fqn:
+                return True
+        return False
+
+    # quantize test model
+    config = MoETrainingConfig()
+    quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+
+    # validate that only the experts were converted
+    _validate_model_conversion(
+        model,
+        target_fqns=target_fqns,
+    )
+
+    # apply TP
+    apply_moe_ep_tp(model, tp_mesh=mesh, ep_mesh=None, ep_tp_mesh=None)
+    apply_moe_ep_tp(ref_model, tp_mesh=mesh, ep_mesh=None, ep_tp_mesh=None)
+
+    # Rough validation that parallelization was applied properly.
+    assert isinstance(model.experts.w1.data, DTensor), (
+        "test model experts.w1 is not a DTensor"
+    )
+    assert isinstance(model.experts.w2.data, DTensor), (
+        "test model experts.w2 is not a DTensor"
+    )
+    assert isinstance(model.experts.w3.data, DTensor), (
+        "test model experts.w3 is not a DTensor"
+    )
+    assert isinstance(ref_model.experts.w1.data, DTensor), (
+        "ref model experts.w1 is not a DTensor"
+    )
+    assert isinstance(ref_model.experts.w2.data, DTensor), (
+        "ref model experts.w2 is not a DTensor"
+    )
+    assert isinstance(ref_model.experts.w3.data, DTensor), (
+        "ref model experts.w3 is not a DTensor"
+    )
+
+    # inputs
+    batch, seq, dim = 8, 2048, 256
+    ref_x = torch.randn(
+        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
+    )
+    x = ref_x.detach().clone().requires_grad_(True)
+
+    # forward pass
+    ref_out = ref_model(ref_x)
+    out = model(x)
+
+    # validate output
+    out_sqnr = compute_error(out, ref_out)
+    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."
+
+    # compute loss
+    labels = torch.ones_like(ref_out)
+    ref_loss = F.mse_loss(ref_out, labels)
+    out_loss = F.mse_loss(out, labels)
+
+    # backward pass
+    ref_loss.backward()
+    out_loss.backward()
+
+    # validate input gradient
+    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
+    assert input_grad_sqnr.item() >= 28.0, (
+        f"SQNR must be >= 28.0, got {input_grad_sqnr.item()}."
+    )
+
+    # validate param gradients
+    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
+        param_grad_sqnr = compute_error(param1.grad, param2.grad)
+        assert param_grad_sqnr.item() >= 25.0, (
+            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
+        )
+
+    dist.destroy_process_group()
+
+
+def setup_distributed():
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+    device_mesh = init_device_mesh("cuda", (world_size,))
+    # seed must be the same in all processes
+    torch.manual_seed(1)
+    torch.cuda.set_device(rank)
+    return device_mesh
+
+
+def apply_moe_ep_tp(
+    model: nn.Module,
+    tp_mesh: DeviceMesh | None,
+    ep_mesh: DeviceMesh | None,
+    ep_tp_mesh: DeviceMesh | None,
+):
+    # Modified version of the MoE parallelization from https://github.com/pytorch/torchtitan/pull/1324/
+    # that supports a single MoE layer independent of a transformer.
+    if tp_mesh is not None:
+        moe_layer_plan = {
+            # input / output sharding on the seqlen dim
+            # all-gather for input, reduce-scatter for output
+            "moe": PrepareModuleInputOutput(
+                input_layouts=(Shard(1),),
+                desired_input_layouts=(Replicate(),),
+                use_local_input=True,
+                output_layouts=(Partial(),),
+                desired_output_layouts=(Shard(1),),
+            ),
+            # replicate computation for the router
+            "moe.router.gate": NoParallel(),
+            # input Replicate, output Partial
+            "moe.shared_expert": TensorParallel(),
+        }
+        parallelize_module(
+            module=model,
+            device_mesh=tp_mesh,
+            parallelize_plan=moe_layer_plan,
+        )
+
+    # if ep_mesh is not None:
+    experts_mesh, experts_plan = None, None
+    if ep_mesh is None:
+        experts_mesh = tp_mesh
+        # input Replicate, output Partial
+        experts_plan = TensorParallel()
+    elif tp_mesh is None:
+        experts_mesh = ep_mesh
+        # input / output sharding on the batch / tokens dim
+        experts_plan = ExpertParallel()
+    else:
+        experts_mesh = ep_tp_mesh
+        experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
+
+    parallelize_module(
+        module=model.experts,
+        device_mesh=experts_mesh,
+        parallelize_plan=experts_plan,
+    )
diff --git a/test/prototype/moe_training/test_tp.sh b/test/prototype/moe_training/test_tp.sh
old mode 100644
new mode 100755
index 16905c0538..2ab7636113
--- a/test/prototype/moe_training/test_tp.sh
+++ b/test/prototype/moe_training/test_tp.sh
@@ -1 +1 @@
-torchrun --nproc_per_node=2 -m pytest test/prototype/moe_training/test_tp.py
+torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_tp.py -s
diff --git a/test/prototype/moe_training/test_training.py b/test/prototype/moe_training/test_training.py
index 71320af83e..7087d1d571 100644
--- a/test/prototype/moe_training/test_training.py
+++ b/test/prototype/moe_training/test_training.py
@@ -13,9 +13,10 @@
 
 from torchao.float8.float8_utils import compute_error
 from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
-from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
 from torchao.quantization.quant_api import quantize_
 
+from .testing_utils import _validate_model_conversion
+
 # this test requires torchtitan
 try:
     from torchtitan.experiments.llama4.model.args import TransformerModelArgs
@@ -108,33 +109,3 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         assert param_grad_sqnr.item() >= 25.0, (
             f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
         )
-
-
-def _validate_model_conversion(
-    root_module: nn.Module,
-    target_fqns: list[str],
-):
-    def _recursive_validate(
-        module: nn.Module,
-        cur_fqn: str,
-    ):
-        is_allowed_module = cur_fqn in target_fqns
-
-        # check current module params
-        for param_name, param in module.named_parameters(recurse=False):
-            is_converted_type = isinstance(param, ScaledGroupedMMTensor)
-            if is_converted_type:
-                assert is_allowed_module, (
-                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
-                )
-            if not is_allowed_module:
-                assert not is_converted_type, (
-                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
-                )
-
-        # recursively check child modules
-        for child_name, child_module in module.named_children():
-            child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
-            _recursive_validate(child_module, child_fqn)
-
-    _recursive_validate(root_module, "")
diff --git a/test/prototype/moe_training/testing_utils.py b/test/prototype/moe_training/testing_utils.py
new file mode 100644
index 0000000000..cf13b81ae3
--- /dev/null
+++ b/test/prototype/moe_training/testing_utils.py
@@ -0,0 +1,33 @@
+from torch import nn
+
+from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
+
+
+def _validate_model_conversion(
+    root_module: nn.Module,
+    target_fqns: list[str],
+):
+    def _recursive_validate(
+        module: nn.Module,
+        cur_fqn: str,
+    ):
+        is_allowed_module = any([target_fqn in cur_fqn for target_fqn in target_fqns])
+
+        # check current module params
+        for param_name, param in module.named_parameters(recurse=False):
+            is_converted_type = isinstance(param, ScaledGroupedMMTensor)
+            if is_converted_type:
+                assert is_allowed_module, (
+                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
+                )
+            if not is_allowed_module:
+                assert not is_converted_type, (
+                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
+                )
+
+        # recursively check child modules
+        for child_name, child_module in module.named_children():
+            child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
+            _recursive_validate(child_module, child_fqn)
+
+    _recursive_validate(root_module, "")
diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index 5a08074d5d..d9ccdcba03 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -25,7 +25,7 @@
 def _scaled_grouped_mm(
     A: torch.Tensor,
     B_t: torch.Tensor,
-    offs: torch.Tensor,
+    offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
 ) -> torch.Tensor:
     """
@@ -35,12 +35,13 @@
     Args:
         A (bf16/float32 torch.Tensor): The first high-precision input tensor, which must be a 2D tensor of shape (M * num_groups, K)
            and in row-major memory layout.
-        B_t (bf16/float32 torch.Tensor): The second high-precision input tensor which must be 3D, which must be shape (B, K, N)
+        B_t (bf16/float32 torch.Tensor): The second high-precision input tensor, which must be 3D with shape (E, K, N)
            and in column-major memory layout.
        offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
        out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
     """
-    # logger.info("Using scaled_grouped_mm")
+    # TODO: Remove once prototype is more mature. This is currently very useful for development and debugging.
+    logger.info("Using scaled_grouped_mm")
     return _Float8GroupedMM.apply(
         A,
         B_t,
@@ -57,12 +58,11 @@ def forward(
         ctx,
         A: torch.Tensor,
         B_t: torch.Tensor,
-        offs: torch.Tensor,
+        offs: Optional[torch.Tensor] = None,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
-        use_triton_for_per_group_scales: bool = True,
     ) -> torch.Tensor:
-        # torchao _scaled_grouped_mm only supports A=2D, B=3D.
-        assert A.ndim == 2, "A must be 2D"
+        # torchao _scaled_grouped_mm only supports A=2D|3D and B=3D.
+        assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D"
         assert B_t.ndim == 3, "B must be 3D"
 
         assert A.size(-1) % 16 == 0, (
@@ -79,7 +79,9 @@
         assert B_t.dtype == torch.float32 or B_t.dtype == torch.bfloat16, (
             "B must be float32 or bfloat16"
         )
-        assert offs.dtype == torch.int32, "offs must be int32"
+        assert offs is None or offs.dtype == torch.int32, (
+            "offs must be int32 tensor or None"
+        )
 
         # Assert A and B dims are compatible for a scaled grouped GEMM.
         assert A.size(-1) == B_t.size(-2), (
@@ -96,8 +98,8 @@
             B_t = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)
 
         # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
-        # A shape: (M, K)
-        # A_scales shape: (M,1)
+        # A shape: (M, K) or (B, M, K)
+        # A_scales shape: (M,1) or (B, M, 1)
         A_scales = tensor_to_scale(
             A,
             torch.float8_e4m3fn,
@@ -109,9 +111,9 @@
         A_fp8_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)
 
         # Convert B to float8, column-major for right operand of grouped GEMM.
-        # B shape: (B, K, N)
+        # B shape: (E, K, N)
         # B scales must be computed rowwise keeping the outer/final dim, so:
-        # B_scales shape: (B, 1, N)
+        # B_scales shape: (E, 1, N)
         B_t_scales = tensor_to_scale(
             B_t,
             torch.float8_e4m3fn,
@@ -127,9 +129,9 @@
         # In the backward this is needed for grad_A: grad_output @ B.
         B = B_t.contiguous().transpose(-2, -1)
 
-        # - B shape: (B, K, N)
+        # - B shape: (E, K, N)
         # - B scales must be computed rowwise keeping the outer/final dim, so:
-        # - B_scale shape: (B, 1, N)
+        # - B_scale shape: (E, 1, N)
         B_scales = tensor_to_scale(
             B,
             torch.float8_e4m3fn,
@@ -152,11 +154,17 @@
         assert _is_column_major(B_t_fp8_col_major), (
             "B must be column-major for output = A @ B"
         )
+
+        # Squeeze empty dims out of scales, to comply with grouped mm API.
+        # A_scales shape: (M,1) or (B, M, 1)
+        # B_t_scales shape: (E, 1, N)
+        A_scales = A_scales.squeeze(-1)
+        B_t_scales = B_t_scales.squeeze(1)
         return torch._scaled_grouped_mm(
             A_fp8_row_major,
             B_t_fp8_col_major,
-            A_scales.squeeze().reciprocal(),
-            B_t_scales.squeeze().reciprocal(),
+            A_scales.reciprocal(),  # Reciprocals are needed for rescaling the output.
+            B_t_scales.reciprocal(),
             offs,
             out_dtype=out_dtype,
             use_fast_accum=True,
@@ -185,7 +193,6 @@
         )
 
         # Compute grad_A.
-        #
         # grad_A = grad_output @ B
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
         assert not _is_column_major(grad_output_fp8_row_major), (
@@ -194,6 +201,12 @@
         assert _is_column_major(B_fp8_col_major), (
             "B must be column-major for grad_A = grad_output @ B"
         )
+
+        # Squeeze empty dims out of scales, to comply with grouped mm API.
+        # grad_output_scales shape: (M,1) or (B, M, 1)
+        # B_scales shape: (E, 1, N)
+        grad_output_scales = grad_output_scales.squeeze(-1)
+        B_scales = B_scales.squeeze(1)
         grad_A = torch._scaled_grouped_mm(
             grad_output_fp8_row_major,
             B_fp8_col_major,
@@ -239,6 +252,10 @@
         assert _is_column_major(A_fp8_col_major), (
             "A must be column-major for grad_B = grad_output_t @ A"
         )
+
+        # Per-token group scales computed via triton kernels above do not have
+        # the empty dim like the scales computed via tensor_to_scale, so we
+        # don't need to squeeze here.
         grad_B = torch._scaled_grouped_mm(
             grad_output_t_fp8_row_major,
             A_fp8_col_major,
diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py
index ddcc84f515..d6fce479d4 100644
--- a/torchao/prototype/moe_training/tensor.py
+++ b/torchao/prototype/moe_training/tensor.py
@@ -11,6 +11,7 @@
 import torch.utils._pytree as pytree
 from torch import nn
 from torch._prims_common import suggest_memory_format
+from torch.distributed._tensor import DTensor
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import MixedPrecisionPolicy
 
@@ -154,21 +155,33 @@ def fsdp_post_all_gather(
     ):
         (data,) = all_gather_outputs
 
-        # For training step 1+, out=unsharded param, so we need to copy data to `out`
-        # if `self._data`` and `out` do not share the same storage.
-        # Otherwise, if they do share the same storage, we can just return directly.
+        # For training step 1+, out=unsharded param.
         if out is not None:
-            assert isinstance(out, ScaledGroupedMMTensor), f"{type(out)}"
+            if isinstance(out, ScaledGroupedMMTensor):
+                out_data = out._data
+            elif isinstance(out, DTensor) and isinstance(
+                out._local_tensor, ScaledGroupedMMTensor
+            ):
+                out_data = out._local_tensor._data
+            else:
+                raise RuntimeError(
+                    f"expect out to be ScaledGroupedMMTensor or DTensor with local_tensor=ScaledGroupedMMTensor, but got {type(out)}"
+                )
+
+            # If `data` (all gather outputs) is already in the mixed precision policy param_dtype,
+            # verify it shares the same underlying storage as `out` (pre-allocated unsharded param),
+            # and then we can just return directly.
             if data.dtype == param_dtype:
                 assert (
                     data.untyped_storage().data_ptr()
-                    == out._data.untyped_storage().data_ptr()
+                    == out_data.untyped_storage().data_ptr()
                 )
             else:
-                assert out._data.dtype == param_dtype, (
-                    f"{out._data.dtype} {param_dtype}"
-                )
-                out._data.copy_(data)
+                # Otherwise, verify that `out` (pre-allocated unsharded param) has the
+                # mixed precision policy param_dtype, then copy `data` to `out`.
+                assert out_data.dtype == param_dtype, f"{out_data.dtype} {param_dtype}"
+                out_data.copy_(data)
+            return
 
         # For training step 0, out=None, so we need to return a new ScaledGroupedMMTensor.