|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | +###################################################################### |
| 7 | +# |
| 8 | +# To run these unit tests, use the following command: |
| 9 | +# |
| 10 | +# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_tp.py |
| 11 | +# |
| 12 | +####################################################################### |
| 13 | + |
| 14 | +import copy |
| 15 | +import os |
| 16 | + |
| 17 | +import pytest |
| 18 | +import torch |
| 19 | +from torch import distributed as dist |
| 20 | +from torch import nn |
| 21 | +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh |
| 22 | +from torch.nn import functional as F |
| 23 | + |
| 24 | +# this feature requires CUDA and SM89+ |
| 25 | +if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9): |
| 26 | + pytest.skip( |
| 27 | + "CUDA not available or compute capability < 8.9", allow_module_level=True |
| 28 | + ) |
| 29 | + |
| 30 | +from torchao.float8.float8_utils import compute_error |
| 31 | +from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig |
| 32 | +from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor |
| 33 | +from torchao.quantization.quant_api import quantize_ |
| 34 | + |
| 35 | +# this test requires torchtitan |
| 36 | +try: |
| 37 | + from torchtitan.experiments.llama4.infra.parallelize import apply_moe_tp |
| 38 | + from torchtitan.experiments.llama4.model.args import TransformerModelArgs |
| 39 | + from torchtitan.experiments.llama4.model.moe import MoE |
| 40 | +except ImportError: |
| 41 | + import warnings |
| 42 | + |
| 43 | + warnings.warn("torchtitan not installed, skipping MoE tests.") |
| 44 | + pytest.skip(allow_module_level=True) |
| 45 | + |
| 46 | + |
| 47 | +@pytest.mark.parametrize( |
| 48 | + "target_fqns", |
| 49 | + [ |
| 50 | + ["experts"], |
| 51 | + ["experts,shared_expert"], |
| 52 | + ], |
| 53 | +) |
| 54 | +def test_moe_float8_training_tp_sp(target_fqns: list[str]): |
| 55 | + assert torch.cuda.is_available() |
| 56 | + |
| 57 | + # setup distributed for fsdp |
| 58 | + mesh = setup_distributed() |
| 59 | + |
| 60 | + # define model args |
| 61 | + model_args = TransformerModelArgs( |
| 62 | + moe_enabled=True, |
| 63 | + num_experts=8, |
| 64 | + dim=256, |
| 65 | + vocab_size=1024, |
| 66 | + ) |
| 67 | + init_std = 0.02 |
| 68 | + device = torch.device("cuda") |
| 69 | + |
| 70 | + # reference bf16 MoE |
| 71 | + ref_model = MoE(model_args).to(torch.bfloat16).cuda() |
| 72 | + torch.manual_seed(1) |
| 73 | + ref_model.init_weights(init_std, device) |
| 74 | + |
| 75 | + # target MoE for testing conversion |
| 76 | + model = copy.deepcopy(ref_model) |
| 77 | + |
| 78 | + # assert starting params are identical for both models |
| 79 | + for param1, param2 in zip(model.parameters(), ref_model.parameters()): |
| 80 | + assert torch.equal(param1, param2) |
| 81 | + |
| 82 | + # convert MoE to float8 training |
| 83 | + def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: |
| 84 | + for target_fqn in target_fqns: |
| 85 | + if target_fqn in cur_fqn: |
| 86 | + return True |
| 87 | + return False |
| 88 | + |
| 89 | + # quantize test model |
| 90 | + config = MoETrainingConfig() |
| 91 | + quantize_(model, config=config, filter_fn=moe_module_filter_fn) |
| 92 | + |
| 93 | + # validate that only the experts were converted |
| 94 | + _validate_model_conversion( |
| 95 | + model, |
| 96 | + target_fqns=target_fqns, |
| 97 | + ) |
| 98 | + |
| 99 | + # apply TP |
| 100 | + apply_moe_tp(model, mesh) |
| 101 | + apply_moe_tp(ref_model, mesh) |
| 102 | + |
| 103 | + # inputs |
| 104 | + batch, seq, dim = 8, 2048, 256 |
| 105 | + ref_x = torch.randn( |
| 106 | + batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device |
| 107 | + ) |
| 108 | + x = ref_x.detach().clone().requires_grad_(True) |
| 109 | + |
| 110 | + # forward pass |
| 111 | + ref_out = ref_model(ref_x) |
| 112 | + out = model(x) |
| 113 | + |
| 114 | + # validate output |
| 115 | + out_sqnr = compute_error(out, ref_out) |
| 116 | + assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}." |
| 117 | + |
| 118 | + # compute loss |
| 119 | + labels = torch.ones_like(ref_out) |
| 120 | + ref_loss = F.mse_loss(ref_out, labels) |
| 121 | + out_loss = F.mse_loss(out, labels) |
| 122 | + |
| 123 | + # backward pass |
| 124 | + ref_loss.backward() |
| 125 | + out_loss.backward() |
| 126 | + |
| 127 | + # validate input gradient |
| 128 | + input_grad_sqnr = compute_error(x.grad, ref_x.grad) |
| 129 | + assert input_grad_sqnr.item() >= 30.0, ( |
| 130 | + f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}." |
| 131 | + ) |
| 132 | + |
| 133 | + # validate param gradients |
| 134 | + for param1, param2 in zip(model.parameters(), ref_model.parameters()): |
| 135 | + param_grad_sqnr = compute_error(param1.grad, param2.grad) |
| 136 | + assert param_grad_sqnr.item() >= 25.0, ( |
| 137 | + f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}." |
| 138 | + ) |
| 139 | + |
| 140 | + dist.destroy_process_group() |
| 141 | + |
| 142 | + |
| 143 | +def _validate_model_conversion( |
| 144 | + root_module: nn.Module, |
| 145 | + target_fqns: list[str], |
| 146 | +): |
| 147 | + def _recursive_validate( |
| 148 | + module: nn.Module, |
| 149 | + cur_fqn: str, |
| 150 | + ): |
| 151 | + is_allowed_module = any([target_fqn in cur_fqn for target_fqn in target_fqns]) |
| 152 | + |
| 153 | + # check current module params |
| 154 | + for param_name, param in module.named_parameters(recurse=False): |
| 155 | + is_converted_type = isinstance(param, ScaledGroupedMMTensor) |
| 156 | + if is_converted_type: |
| 157 | + assert is_allowed_module, ( |
| 158 | + f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}." |
| 159 | + ) |
| 160 | + if not is_allowed_module: |
| 161 | + assert not is_converted_type, ( |
| 162 | + f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}." |
| 163 | + ) |
| 164 | + |
| 165 | + # recursively check child modules |
| 166 | + for child_name, child_module in module.named_children(): |
| 167 | + child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name |
| 168 | + _recursive_validate(child_module, child_fqn) |
| 169 | + |
| 170 | + _recursive_validate(root_module, "") |
| 171 | + |
| 172 | + |
| 173 | +def setup_distributed(): |
| 174 | + rank = int(os.environ["RANK"]) |
| 175 | + world_size = int(os.environ["WORLD_SIZE"]) |
| 176 | + dist.init_process_group("nccl", rank=rank, world_size=world_size) |
| 177 | + device_mesh = init_device_mesh("cuda", (world_size,)) |
| 178 | + # seed must be the same in all processes |
| 179 | + torch.manual_seed(1) |
| 180 | + torch.cuda.set_device(rank) |
| 181 | + return device_mesh |
| 182 | + |
| 183 | + |
| 184 | +def apply_moe_tp( |
| 185 | + model: nn.Module, |
| 186 | + tp_mesh: DeviceMesh, |
| 187 | +): |
| 188 | + from torch.distributed.tensor import Partial, Replicate, Shard |
| 189 | + from torch.distributed.tensor.parallel import ( |
| 190 | + PrepareModuleInputOutput, |
| 191 | + parallelize_module, |
| 192 | + ) |
| 193 | + from torchtitan.experiments.llama4.infra.expert_parallel import ( |
| 194 | + NoParallel, |
| 195 | + TensorParallel, |
| 196 | + ) |
| 197 | + |
| 198 | + moe_layer_plan = { |
| 199 | + # input / output sharding on the seqlen dim |
| 200 | + # all-gather for input, reduce-scatter for output |
| 201 | + "moe": PrepareModuleInputOutput( |
| 202 | + input_layouts=(Shard(1),), |
| 203 | + desired_input_layouts=(Replicate(),), |
| 204 | + use_local_input=True, |
| 205 | + output_layouts=(Partial(),), |
| 206 | + desired_output_layouts=(Shard(1),), |
| 207 | + ), |
| 208 | + # replicate computation for the router |
| 209 | + "moe.router.gate": NoParallel(), |
| 210 | + # input Replicate, output Partial |
| 211 | + "moe.experts": TensorParallel(output_layout=Partial()), |
| 212 | + "moe.shared_expert": TensorParallel(output_layout=Partial()), |
| 213 | + } |
| 214 | + parallelize_module( |
| 215 | + module=model, |
| 216 | + device_mesh=tp_mesh, |
| 217 | + parallelize_plan=moe_layer_plan, |
| 218 | + ) |
0 commit comments