diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py
index 5509eb1cc2..2255d25a6b 100644
--- a/test/float8/test_dtensor.py
+++ b/test/float8/test_dtensor.py
@@ -183,7 +183,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
     loss.backward()


-def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=32):
     tensorwise_config = Float8LinearConfig(emulate=True)
     _test_lowp_mlp_tensor_parallelism_base(
         mesh, tensorwise_config, size, compile=False, allgather_in_lowp=True
@@ -198,7 +198,7 @@ def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
     )


-def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
+def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=32):
     tensorwise_config = Float8LinearConfig(emulate=True)
     _test_lowp_mlp_tensor_parallelism_base(
         mesh, tensorwise_config, size, compile=True, allgather_in_lowp=True
diff --git a/test/float8/test_fsdp2_tp.py b/test/float8/test_fsdp2_tp.py
index 93c7735149..8a735c5865 100644
--- a/test/float8/test_fsdp2_tp.py
+++ b/test/float8/test_fsdp2_tp.py
@@ -34,6 +34,8 @@
 )
 from torchao.testing.training.dtensor_utils import ToyModel

+torch.set_float32_matmul_precision("high")
+

 def setup_distributed():
     world_size = int(os.environ.get("WORLD_SIZE", -1))
@@ -61,7 +63,7 @@ def _test_fp8_mlp_tensor_parallelism_base(
         enable_fsdp_float8_all_gather=True,
     )

-    toy_model = ToyModel().to(device)
+    toy_model = ToyModel(size).to(device)

     tp_model = copy.deepcopy(toy_model)
     tp_model = convert_to_float8_training(tp_model, config=config)
@@ -94,11 +96,11 @@ def _test_fp8_mlp_tensor_parallelism_base(


 # TODO(future PR): test numerics, and add more cases
-def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
+def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=32):
     _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False)


-def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
+def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=32):
     _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)


diff --git a/test/prototype/mx_formats/test_mx_dtensor.py b/test/prototype/mx_formats/test_mx_dtensor.py
index 4aefb3874e..4f5cce1a2a 100644
--- a/test/prototype/mx_formats/test_mx_dtensor.py
+++ b/test/prototype/mx_formats/test_mx_dtensor.py
@@ -68,9 +68,9 @@ def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4):
     )


-def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
+def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=128):
     config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
-    config.block_size = 16
+    config.block_size = 32
     _test_lowp_mlp_tensor_parallelism_base(
         mesh, config, size, compile=False, allgather_in_lowp=False
     )
@@ -79,11 +79,26 @@ def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16):
     )


+def _test_mxfp8_mlp_tensor_parallelism_dim1_triton(mesh: DeviceMesh, size=128):
+    config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
+    config.block_size = 32
+    config.use_fp8_dim1_cast_triton_kernel = True
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, config, size, compile=False, allgather_in_lowp=False
+    )
+    # TODO(future PR): enable compile here, currently seeing
+    # https://www.internalfb.com/phabricator/paste/view/P1851219639
+    # _test_lowp_mlp_tensor_parallelism_base(
+    #     mesh, config, size, compile=True, allgather_in_lowp=False
+    # )
+
+
 if __name__ == "__main__":
     device_mesh = setup_distributed()
     tests = [
         _test_dtensor_cast_to_mxfp8,
         _test_mxfp8_mlp_tensor_parallelism,
+        _test_mxfp8_mlp_tensor_parallelism_dim1_triton,
     ]

     for test in tqdm(tests, desc="Running tests"):
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index a051974e28..e1e37ea7fa 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -8,6 +8,8 @@

 import numpy as np
 import torch
+from torch.distributed.tensor import Replicate, Shard
+from torch.distributed.tensor.experimental import register_sharding
 from torch.utils._triton import has_triton

 from torchao.prototype.custom_fp_utils import (
@@ -1315,7 +1317,6 @@ def triton_to_mxfp8_dim1(
         * `col_scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim1
         """
         assert x.is_contiguous(), "`x` must be contiguous"
-        assert x.dtype == torch.bfloat16
         assert inner_block_size <= 32

         # Get tensor shape
@@ -1363,6 +1364,16 @@ def triton_to_mxfp8_dim1(
             col_scale.view(torch.float8_e8m0fnu),
         )

+    @register_sharding(torch.ops.torchao.triton_to_mxfp8_dim1.default)
+    def custom_triton_to_mxfp8_dim1_sharding(x, inner_block_size=32):
+        replicate = ([Replicate(), Replicate()], [Replicate(), None])
+        # Note that the data is returned transposed, which is why
+        # we flip the sharding dim below
+        shard_dim0 = ([Shard(1), Shard(1)], [Shard(0), None])
+        shard_dim1 = ([Shard(0), Shard(0)], [Shard(1), None])
+        acceptable_shardings = [replicate, shard_dim0, shard_dim1]
+        return acceptable_shardings
+
     def triton_to_mxfp8_dim1_reference(
         x_hp: torch.Tensor, block_size
     ) -> Tuple[torch.Tensor, torch.Tensor]:
diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py
index 4db029480f..4d2744fd7e 100644
--- a/torchao/prototype/mx_formats/mx_linear.py
+++ b/torchao/prototype/mx_formats/mx_linear.py
@@ -12,6 +12,7 @@

 import torch
 import torch.nn.functional as F
+from torch.distributed._tensor import DTensor

 from torchao.prototype.mx_formats.config import (
     MXGemmKernelChoice,
@@ -25,6 +26,46 @@
 )


+def _triton_to_mxfp8_dim1_wrapper(
+    a, block_size, elem_dtype, hp_dtype, gemm_kernel_choice
+):
+    a_data, a_scale = triton_to_mxfp8_dim1(a, block_size)
+    if isinstance(a_data, DTensor):
+        assert isinstance(a_scale, DTensor)
+        a_data_local = a_data.to_local()
+        a_scale_local = a_scale.to_local()
+        inner = MXTensor(
+            a_scale_local,
+            a_data_local.t(),
+            elem_dtype,
+            block_size,
+            hp_dtype,
+            False,
+            gemm_kernel_choice,
+            False,
+        )
+        mx_tensor = DTensor.from_local(
+            inner,
+            a_data.device_mesh,
+            a_data.placements,
+            run_check=False,
+            shape=a_data.t().size(),
+            stride=a_data.t().stride(),
+        )
+    else:
+        mx_tensor = MXTensor(
+            a_scale,
+            a_data.t(),
+            elem_dtype,
+            block_size,
+            hp_dtype,
+            False,
+            gemm_kernel_choice,
+            False,
+        )
+    return mx_tensor
+
+
 @torch._dynamo.allow_in_graph
 class mx_mm(torch.autograd.Function):
     # There are three gemms in a forward + backward of a Linear layer:
@@ -95,20 +136,9 @@ def backward(ctx, grad_output_hp: torch.Tensor):
         )

         if use_fp8_dim1_cast_triton_kernel:
-            weight_mx_dim1_data, weight_mx_dim1_scale = triton_to_mxfp8_dim1(
-                weight_hp, block_size
+            weight_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
+                weight_hp, block_size, w_elem_dtype, weight_hp.dtype, gemm_kernel_choice
             )
-            weight_mx_dim1 = MXTensor(
-                weight_mx_dim1_scale.reshape(-1),
-                weight_mx_dim1_data.t(),
-                w_elem_dtype,
-                block_size,
-                weight_hp.dtype,
-                False,
-                gemm_kernel_choice,
-                False,
-            )
-
         else:
             weight_hp_t_c = weight_hp.t().contiguous()
             weight_mx_dim1 = MXTensor.to_mx(
@@ -124,18 +154,12 @@ def backward(ctx, grad_output_hp: torch.Tensor):

         # input_t @ grad_output = grad_weight
         if use_fp8_dim1_cast_triton_kernel:
-            grad_output_mx_dim1_data, grad_output_mx_dim1_scale = triton_to_mxfp8_dim1(
-                grad_output_hp_r, block_size
-            )
-            grad_output_mx_dim1 = MXTensor(
-                grad_output_mx_dim1_scale.reshape(-1),
-                grad_output_mx_dim1_data.t(),
-                grad_elem_dtype,
+            grad_output_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
+                grad_output_hp_r,
                 block_size,
+                grad_elem_dtype,
                 grad_output_hp_r.dtype,
-                False,
                 gemm_kernel_choice,
-                False,
             )
         else:
             grad_output_mx_dim1 = MXTensor.to_mx(
@@ -146,18 +170,12 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             )

         if use_fp8_dim1_cast_triton_kernel:
-            input_t_mx_dim0_tmp_data, input_t_mx_dim0_tmp_scale = triton_to_mxfp8_dim1(
-                input_hp_r, block_size
-            )
-            input_t_mx_dim0_tmp = MXTensor(
-                input_t_mx_dim0_tmp_scale.reshape(-1),
-                input_t_mx_dim0_tmp_data.t(),
-                in_elem_dtype,
+            input_t_mx_dim0_tmp = _triton_to_mxfp8_dim1_wrapper(
+                input_hp_r,
                 block_size,
+                in_elem_dtype,
                 input_hp_r.dtype,
-                False,
                 gemm_kernel_choice,
-                False,
             )
             input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
         else:
diff --git a/torchao/testing/training/dtensor_utils.py b/torchao/testing/training/dtensor_utils.py
index 7ebf67d53c..acbfbb6a3e 100644
--- a/torchao/testing/training/dtensor_utils.py
+++ b/torchao/testing/training/dtensor_utils.py
@@ -32,11 +32,11 @@
 class FeedForward(nn.Module):
     """MLP based model"""

-    def __init__(self):
+    def __init__(self, size):
         super(FeedForward, self).__init__()
-        self.w1 = nn.Linear(16, 32, bias=False)
-        self.w2 = nn.Linear(16, 32, bias=False)
-        self.out_proj = nn.Linear(32, 16, bias=False)
+        self.w1 = nn.Linear(size, size * 2, bias=False)
+        self.w2 = nn.Linear(size, size * 2, bias=False)
+        self.out_proj = nn.Linear(size * 2, size, bias=False)

     def forward(self, x):
         x = F.silu(self.w1(x)) * self.w2(x)
@@ -45,9 +45,9 @@ def forward(self, x):


 class ToyModel(nn.Module):
-    def __init__(self):
+    def __init__(self, size):
         super(ToyModel, self).__init__()
-        self.ffn = FeedForward()
+        self.ffn = FeedForward(size)

     def forward(self, x):
         return self.ffn(x)
@@ -56,7 +56,7 @@ def forward(self, x):
 def _test_lowp_mlp_tensor_parallelism_base(
     mesh: DeviceMesh,
     config: Union[Float8LinearConfig, MXLinearConfig],
-    size=16,
+    size=32,
     compile: bool = False,
     allgather_in_lowp: bool = False,
 ):
@@ -67,7 +67,7 @@ def _test_lowp_mlp_tensor_parallelism_base(
     if isinstance(config, MXLinearConfig):
         convert_model_func = quantize_

-    toy_model = ToyModel().to(device)
+    toy_model = ToyModel(size).to(device)
     toy_model_fp8 = copy.deepcopy(toy_model)
     convert_model_func(toy_model_fp8, config=config)

@@ -151,8 +151,8 @@ def _test_lowp_mlp_tensor_parallelism_base(
         sp_model = torch.compile(sp_model)
         sp_model2 = torch.compile(sp_model2)

-    x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
-    go_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
+    x_fp32 = torch.rand(2, size * 2, size, device=device, requires_grad=False)
+    go_fp32 = torch.rand(2, size * 2, size, device=device, requires_grad=False)
     x_fp32_tp_input = x_fp32.clone()
     go_fp32_tp = go_fp32.clone()
     x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
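Note, illustrative only and not part of the diff: a minimal sketch of one way the new `register_sharding` rule could be exercised, namely calling `torch.ops.torchao.triton_to_mxfp8_dim1` directly on a `DTensor` without redistributing it first. It assumes a CUDA build of PyTorch with Triton available, a torchao install that registers the custom op on import of `torchao.prototype.mx_formats.kernels`, a launch via torchrun, and made-up mesh and tensor shapes.

# sketch.py -- hypothetical usage, e.g. torchrun --nproc_per_node=2 sketch.py
import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

import torchao.prototype.mx_formats.kernels  # noqa: F401  (registers the custom op)

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

# A row-wise sharded bf16 activation, similar to what the TP MLP produces.
x = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
x_dt = distribute_tensor(x, mesh, [Shard(0)])

# With the sharding rule registered for torchao::triton_to_mxfp8_dim1,
# DTensor can dispatch the op onto the local shards directly.
data, scale = torch.ops.torchao.triton_to_mxfp8_dim1(x_dt, 32)

# The kernel returns the data transposed, so an input sharded on dim 0
# should come back sharded on dim 1 (the `shard_dim0` entry in the rule above).
print(type(data), data.placements, scale.placements)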