From 5c23c6b112ccc3d24de51c4f37421602a6d0959d Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 20 Jun 2025 07:10:13 -0700 Subject: [PATCH 01/12] Update [ghstack-poisoned] --- test/float8/test_dtensor.py | 2 ++ test/float8/test_fsdp2_tp.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index 9db046b749..a9ccb35b79 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -67,6 +67,8 @@ def setup_distributed(): device_mesh = init_device_mesh("cuda", (world_size,)) # seed must be the same in all processes torch.manual_seed(1) + local_rank = torch.distributed.get_rank() + torch.cuda.set_device(local_rank) return device_mesh diff --git a/test/float8/test_fsdp2_tp.py b/test/float8/test_fsdp2_tp.py index fa3d30410b..f04b791273 100644 --- a/test/float8/test_fsdp2_tp.py +++ b/test/float8/test_fsdp2_tp.py @@ -46,6 +46,8 @@ def setup_distributed(): ) # seed must be the same in all processes torch.manual_seed(1) + local_rank = torch.distributed.get_rank() + torch.cuda.set_device(local_rank) return device_mesh From ad2ce6213a6045b99448e4b1e3cd57d87f43cde0 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 20 Jun 2025 07:30:21 -0700 Subject: [PATCH 02/12] Update [ghstack-poisoned] --- benchmarks/float8/bench_matmul.py | 2 +- benchmarks/float8/float8_roofline.py | 2 +- test/float8/test_base.py | 2 +- test/float8/test_compile.py | 2 +- test/float8/test_dtensor.py | 2 +- test/float8/test_fsdp2/test_fsdp2.py | 5 ++++- test/float8/test_fsdp2_tp.py | 2 +- test/float8/test_numerics_integration.py | 2 +- torchao/testing/{float8 => training}/__init__.py | 0 torchao/testing/{float8 => training}/dtensor_utils.py | 0 torchao/testing/{float8 => training}/fsdp2_utils.py | 0 torchao/testing/{float8 => training}/roofline_utils.py | 0 torchao/testing/{float8 => training}/test_utils.py | 0 13 files changed, 11 insertions(+), 8 deletions(-) rename torchao/testing/{float8 => training}/__init__.py (100%) rename torchao/testing/{float8 => training}/dtensor_utils.py (100%) rename torchao/testing/{float8 => training}/fsdp2_utils.py (100%) rename torchao/testing/{float8 => training}/roofline_utils.py (100%) rename torchao/testing/{float8 => training}/test_utils.py (100%) diff --git a/benchmarks/float8/bench_matmul.py b/benchmarks/float8/bench_matmul.py index e3f19d8f49..cf844fa51b 100644 --- a/benchmarks/float8/bench_matmul.py +++ b/benchmarks/float8/bench_matmul.py @@ -16,7 +16,7 @@ get_name_to_shapes_iter, ) -from torchao.testing.float8.roofline_utils import get_specs +from torchao.testing.training.roofline_utils import get_specs def benchmark_fn_in_sec(f, *args, **kwargs): diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py index f9374f835e..5a8419cde8 100644 --- a/benchmarks/float8/float8_roofline.py +++ b/benchmarks/float8/float8_roofline.py @@ -63,7 +63,7 @@ ) from torchao.prototype.mx_formats import MXLinearConfig from torchao.quantization import quantize_ -from torchao.testing.float8.roofline_utils import ( +from torchao.testing.training.roofline_utils import ( get_float8_mem_sympy, get_gemm_time_sympy, ) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 8e3efeab60..15099dc2c1 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -55,7 +55,7 @@ fp8_tensor_statistics, tensor_to_scale, ) -from torchao.testing.float8.test_utils import get_test_float8_linear_config +from torchao.testing.training.test_utils import get_test_float8_linear_config from torchao.utils 
import is_MI300, is_ROCM random.seed(0) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index ac5d1f8d96..aaf9d3d3f5 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -37,7 +37,7 @@ hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig -from torchao.testing.float8.test_utils import get_test_float8_linear_config +from torchao.testing.training.test_utils import get_test_float8_linear_config def _test_compile_base( diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index a9ccb35b79..e7220bff9f 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -57,7 +57,7 @@ ) from torchao.float8.float8_utils import tensor_to_scale from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor -from torchao.testing.float8.dtensor_utils import ToyModel +from torchao.testing.training.dtensor_utils import ToyModel torch.set_float32_matmul_precision("high") diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index 6f0cfecf41..b4c7f9fd15 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -43,7 +43,10 @@ from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic from torchao.float8.float8_tensor import GemmInputRole from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor -from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp +from torchao.testing.training.fsdp2_utils import ( + check_parity_bf16_mp, + check_parity_no_mp, +) if not is_sm_at_least_89(): pytest.skip("Unsupported CUDA device capability version", allow_module_level=True) diff --git a/test/float8/test_fsdp2_tp.py b/test/float8/test_fsdp2_tp.py index f04b791273..93c7735149 100644 --- a/test/float8/test_fsdp2_tp.py +++ b/test/float8/test_fsdp2_tp.py @@ -32,7 +32,7 @@ Float8ColwiseParallel, Float8RowwiseParallel, ) -from torchao.testing.float8.dtensor_utils import ToyModel +from torchao.testing.training.dtensor_utils import ToyModel def setup_distributed(): diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index f25c876189..db02444109 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -33,7 +33,7 @@ convert_to_float8_training, ) from torchao.float8.float8_utils import IS_ROCM, compute_error -from torchao.testing.float8.test_utils import get_test_float8_linear_config +from torchao.testing.training.test_utils import get_test_float8_linear_config torch.manual_seed(0) diff --git a/torchao/testing/float8/__init__.py b/torchao/testing/training/__init__.py similarity index 100% rename from torchao/testing/float8/__init__.py rename to torchao/testing/training/__init__.py diff --git a/torchao/testing/float8/dtensor_utils.py b/torchao/testing/training/dtensor_utils.py similarity index 100% rename from torchao/testing/float8/dtensor_utils.py rename to torchao/testing/training/dtensor_utils.py diff --git a/torchao/testing/float8/fsdp2_utils.py b/torchao/testing/training/fsdp2_utils.py similarity index 100% rename from torchao/testing/float8/fsdp2_utils.py rename to torchao/testing/training/fsdp2_utils.py diff --git a/torchao/testing/float8/roofline_utils.py b/torchao/testing/training/roofline_utils.py similarity index 100% rename from torchao/testing/float8/roofline_utils.py rename to torchao/testing/training/roofline_utils.py diff --git 
a/torchao/testing/float8/test_utils.py b/torchao/testing/training/test_utils.py similarity index 100% rename from torchao/testing/float8/test_utils.py rename to torchao/testing/training/test_utils.py From 5eb2066824cadea7028939993c0d1ab35ae09c7a Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 20 Jun 2025 07:52:19 -0700 Subject: [PATCH 03/12] Update [ghstack-poisoned] --- test/float8/test_dtensor.py | 160 ++++------------------ torchao/testing/training/dtensor_utils.py | 138 +++++++++++++++++++ 2 files changed, 161 insertions(+), 137 deletions(-) diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index e7220bff9f..5509eb1cc2 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -10,7 +10,6 @@ TODO(future): make this run in CI """ -import copy import os import pytest @@ -23,12 +22,6 @@ from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor from torch.distributed.device_mesh import DeviceMesh, init_device_mesh -from torch.distributed.tensor.parallel import ( - ColwiseParallel, - PrepareModuleInput, - RowwiseParallel, - parallelize_module, -) from torch.testing._internal.distributed._tensor.common_dtensor import ( ModelArgs, Transformer, @@ -50,14 +43,11 @@ LinearMMConfig, hp_tensor_and_scale_to_float8, ) -from torchao.float8.float8_tensor_parallel import ( - Float8ColwiseParallel, - Float8RowwiseParallel, - PrepareFloat8ModuleInput, -) from torchao.float8.float8_utils import tensor_to_scale from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor -from torchao.testing.training.dtensor_utils import ToyModel +from torchao.testing.training.dtensor_utils import ( + _test_lowp_mlp_tensor_parallelism_base, +) torch.set_float32_matmul_precision("high") @@ -193,140 +183,36 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): loss.backward() -def _test_fp8_mlp_tensor_parallelism_base( - mesh: DeviceMesh, size=16, compile: bool = False, rowwise: bool = False -): - device = mesh.device_type - - if rowwise: - config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE) - # hack around config being frozen - # TODO(future PR): we should make this nicer at the config level - object.__setattr__(config, "emulate", True) - else: - config = Float8LinearConfig(emulate=True) - - toy_model = ToyModel().to(device) - toy_model_fp8 = convert_to_float8_training(toy_model, config=config) - - tp_model = copy.deepcopy(toy_model) - tp_model = convert_to_float8_training(tp_model, config=config) - sp_model = copy.deepcopy(toy_model) - sp_model = convert_to_float8_training(sp_model, config=config) - - # For tensorwise scaling, enable float8 all_gather. - # For rowwise scaling, keep high precision all_gather. Motivation for - # not doing float8 all-gather for rowwise: tensors need to be scaled both ways, - # so for float8 all-gather we'd need to send two float8 copies per tensor, - # which is similar # bytes over the wire than just doing bfloat16 all-gather. 
- if rowwise: - colwise_parallel_cls = ColwiseParallel - rowwise_parallel_cls = RowwiseParallel - prepare_input_cls = PrepareModuleInput - else: - colwise_parallel_cls = Float8ColwiseParallel - rowwise_parallel_cls = Float8RowwiseParallel - prepare_input_cls = PrepareFloat8ModuleInput - - # vanilla TP - tp_model = parallelize_module( - tp_model, - mesh, - { - "ffn.w1": colwise_parallel_cls(), - "ffn.w2": colwise_parallel_cls(), - "ffn.out_proj": rowwise_parallel_cls(), - }, +def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16): + tensorwise_config = Float8LinearConfig(emulate=True) + _test_lowp_mlp_tensor_parallelism_base( + mesh, tensorwise_config, size, compile=False, allgather_in_lowp=True ) - # "sequence parallel" mlp computation - sp_model = parallelize_module( - sp_model, - mesh, - { - "ffn": prepare_input_cls( - input_layouts=Shard(1), desired_input_layouts=Replicate() - ), - "ffn.w1": colwise_parallel_cls(), - "ffn.w2": colwise_parallel_cls(), - "ffn.out_proj": rowwise_parallel_cls( - output_layouts=Shard(1), use_local_output=False - ), - }, + rowwise_config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE) + # hack around config being frozen + # TODO(future PR): we should make this nicer at the config level + object.__setattr__(rowwise_config, "emulate", True) + _test_lowp_mlp_tensor_parallelism_base( + mesh, rowwise_config, size, compile=False, allgather_in_lowp=False ) - # prepare_input_cls with specific submodule fqn - sp_model2 = copy.deepcopy(toy_model) - sp_model2 = convert_to_float8_training(sp_model2, config=config) - if rowwise: - prepare_input = prepare_input_cls( - input_layouts=Shard(1), - desired_input_layouts=Replicate(), - ) - else: - prepare_input = prepare_input_cls( - input_layouts=Shard(1), - desired_input_layouts=Replicate(), - fwd_config_submodule_fqn="w2", - ) - - sp_model2 = parallelize_module( - sp_model2, - mesh, - { - "ffn": prepare_input, - "ffn.w1": colwise_parallel_cls(), - "ffn.w2": colwise_parallel_cls(), - "ffn.out_proj": rowwise_parallel_cls( - output_layouts=Shard(1), use_local_output=False - ), - }, - ) - - if compile: - tp_model = torch.compile(tp_model) - sp_model = torch.compile(sp_model) - sp_model2 = torch.compile(sp_model2) - - x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False) - x_fp32_tp_input = x_fp32.clone() - x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)]) - - tp_out = tp_model(x_fp32_tp_input) - tp_out.sum().backward() - sp_out = sp_model(x_fp32_sp_input) - sp_out.sum().backward() - global_out = toy_model_fp8(x_fp32) - global_out.sum().backward() - torch.testing.assert_close(tp_out, global_out) - torch.testing.assert_close(sp_out.full_tensor(), global_out) - torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad) - torch.testing.assert_close( - tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad +def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): + tensorwise_config = Float8LinearConfig(emulate=True) + _test_lowp_mlp_tensor_parallelism_base( + mesh, tensorwise_config, size, compile=True, allgather_in_lowp=True ) - sp_out2 = sp_model2(x_fp32_sp_input) - sp_out2.sum().backward() - torch.testing.assert_close(sp_out2.full_tensor(), global_out) - torch.testing.assert_close( - tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad - ) - torch.testing.assert_close( - tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad + rowwise_config = 
Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE) + # hack around config being frozen + # TODO(future PR): we should make this nicer at the config level + object.__setattr__(rowwise_config, "emulate", True) + _test_lowp_mlp_tensor_parallelism_base( + mesh, rowwise_config, size, compile=True, allgather_in_lowp=False ) -def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16): - _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False, rowwise=False) - _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False, rowwise=True) - - -def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): - _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True, rowwise=False) - _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True, rowwise=True) - - def _test_distribute_fsdp_tensor_subclass(tp_mesh: DeviceMesh): torch.manual_seed(42) model = Transformer(ModelArgs(dropout_p=0.0, weight_tying=False)).cuda() diff --git a/torchao/testing/training/dtensor_utils.py b/torchao/testing/training/dtensor_utils.py index 84e4095263..7ac0360363 100644 --- a/torchao/testing/training/dtensor_utils.py +++ b/torchao/testing/training/dtensor_utils.py @@ -3,9 +3,27 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import copy +import torch import torch.nn as nn import torch.nn.functional as F +from torch.distributed._tensor import Replicate, Shard, distribute_tensor +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, + parallelize_module, +) + +from torchao.float8 import Float8LinearConfig +from torchao.float8.float8_linear_utils import convert_to_float8_training +from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput, +) class FeedForward(nn.Module): @@ -28,3 +46,123 @@ def __init__(self): def forward(self, x): return self.ffn(x) + + +def _test_lowp_mlp_tensor_parallelism_base( + mesh: DeviceMesh, + config: Float8LinearConfig, + size=16, + compile: bool = False, + allgather_in_lowp: bool = False, +): + device = mesh.device_type + + toy_model = ToyModel().to(device) + toy_model_fp8 = convert_to_float8_training(toy_model, config=config) + + tp_model = copy.deepcopy(toy_model) + tp_model = convert_to_float8_training(tp_model, config=config) + sp_model = copy.deepcopy(toy_model) + sp_model = convert_to_float8_training(sp_model, config=config) + + # For tensorwise scaling, enable float8 all_gather. + # For rowwise scaling, keep high precision all_gather. Motivation for + # not doing float8 all-gather for rowwise: tensors need to be scaled both ways, + # so for float8 all-gather we'd need to send two float8 copies per tensor, + # which is similar # bytes over the wire than just doing bfloat16 all-gather. 
+ if not allgather_in_lowp: + colwise_parallel_cls = ColwiseParallel + rowwise_parallel_cls = RowwiseParallel + prepare_input_cls = PrepareModuleInput + else: + colwise_parallel_cls = Float8ColwiseParallel + rowwise_parallel_cls = Float8RowwiseParallel + prepare_input_cls = PrepareFloat8ModuleInput + + # vanilla TP + tp_model = parallelize_module( + tp_model, + mesh, + { + "ffn.w1": colwise_parallel_cls(), + "ffn.w2": colwise_parallel_cls(), + "ffn.out_proj": rowwise_parallel_cls(), + }, + ) + + # "sequence parallel" mlp computation + sp_model = parallelize_module( + sp_model, + mesh, + { + "ffn": prepare_input_cls( + input_layouts=Shard(1), desired_input_layouts=Replicate() + ), + "ffn.w1": colwise_parallel_cls(), + "ffn.w2": colwise_parallel_cls(), + "ffn.out_proj": rowwise_parallel_cls( + output_layouts=Shard(1), use_local_output=False + ), + }, + ) + + # prepare_input_cls with specific submodule fqn + sp_model2 = copy.deepcopy(toy_model) + sp_model2 = convert_to_float8_training(sp_model2, config=config) + + if not allgather_in_lowp: + prepare_input = prepare_input_cls( + input_layouts=Shard(1), + desired_input_layouts=Replicate(), + ) + else: + prepare_input = prepare_input_cls( + input_layouts=Shard(1), + desired_input_layouts=Replicate(), + fwd_config_submodule_fqn="w2", + ) + + sp_model2 = parallelize_module( + sp_model2, + mesh, + { + "ffn": prepare_input, + "ffn.w1": colwise_parallel_cls(), + "ffn.w2": colwise_parallel_cls(), + "ffn.out_proj": rowwise_parallel_cls( + output_layouts=Shard(1), use_local_output=False + ), + }, + ) + + if compile: + tp_model = torch.compile(tp_model) + sp_model = torch.compile(sp_model) + sp_model2 = torch.compile(sp_model2) + + x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False) + x_fp32_tp_input = x_fp32.clone() + x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)]) + + tp_out = tp_model(x_fp32_tp_input) + tp_out.sum().backward() + sp_out = sp_model(x_fp32_sp_input) + sp_out.sum().backward() + global_out = toy_model_fp8(x_fp32) + global_out.sum().backward() + torch.testing.assert_close(tp_out, global_out) + torch.testing.assert_close(sp_out.full_tensor(), global_out) + torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad) + torch.testing.assert_close( + tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad + ) + + sp_out2 = sp_model2(x_fp32_sp_input) + sp_out2.sum().backward() + torch.testing.assert_close(sp_out2.full_tensor(), global_out) + torch.testing.assert_close( + tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad + ) + torch.testing.assert_close( + tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad + ) From 6e3df575dc4e2283495133b91a03c63446d692f0 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 20 Jun 2025 19:35:44 -0700 Subject: [PATCH 04/12] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_dtensor.py | 98 +++++++++++++++++++++++ test/prototype/mx_formats/test_dtensor.sh | 18 +++++ torchao/prototype/mx_formats/kernels.py | 3 - torchao/prototype/mx_formats/mx_tensor.py | 45 ++++++++--- torchao/testing/training/dtensor_utils.py | 23 ++++-- 5 files changed, 169 insertions(+), 18 deletions(-) create mode 100644 test/prototype/mx_formats/test_dtensor.py create mode 100755 test/prototype/mx_formats/test_dtensor.sh diff --git a/test/prototype/mx_formats/test_dtensor.py b/test/prototype/mx_formats/test_dtensor.py new file mode 100644 index 0000000000..bfc930c579 --- /dev/null +++ 
b/test/prototype/mx_formats/test_dtensor.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +""" +Test numerics of manually defined float16 TP vs mxfp8 TP of toy models + +Note: for now, this does not run in CI. +TODO(future): make this run in CI +""" + +import os + +import pytest +import torch + +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 + +if not TORCH_VERSION_AT_LEAST_2_7: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + +from torch.distributed._tensor import DTensor, Shard, distribute_tensor +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from tqdm import tqdm + +from torchao.prototype.mx_formats import MXLinearConfig +from torchao.prototype.mx_formats.mx_tensor import MXTensor +from torchao.testing.training.dtensor_utils import ( + _test_lowp_mlp_tensor_parallelism_base, +) + +torch.set_float32_matmul_precision("high") + + +def setup_distributed(): + world_size = int(os.environ.get("WORLD_SIZE", -1)) + device_mesh = init_device_mesh("cuda", (world_size,)) + # seed must be the same in all processes + torch.manual_seed(1) + local_rank = torch.distributed.get_rank() + torch.cuda.set_device(local_rank) + return device_mesh + + +def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4): + device = mesh.device_type + + x_fp32 = torch.rand(size, size, device=device) + x_fp8 = MXTensor.to_mx(x_fp32, torch.float8_e4m3fn, block_size=size // 2) + + dist_x_fp32 = distribute_tensor(x_fp32, mesh, [Shard(0)]) + dist_x_fp8 = MXTensor.to_mx(dist_x_fp32, torch.float8_e4m3fn, block_size=size // 2) + assert isinstance(dist_x_fp8, DTensor) + + # Verify that the result of to_mx with DTensor matches the slice of the + # result of to_mx without DTensor. This will fail on numeric op mismatches. 
+ local_rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + assert size % world_size == 0, "unsupported" + x_fp8_fp32 = x_fp8.to_dtype(torch.float32) + rows_per_slice = size // world_size + slice_start = local_rank * rows_per_slice + slice_end = (local_rank + 1) * rows_per_slice + x_fp8_fp32_slice = x_fp8_fp32[slice_start:slice_end] + torch.testing.assert_close( + x_fp8_fp32_slice, dist_x_fp8.to_local().to_dtype(torch.float32), atol=0, rtol=0 + ) + + +def _test_mxfp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16): + config = MXLinearConfig.from_recipe_name("mxfp8_emulated") + # TODO(future PR): assert that the K dim must be divisible by block size, + # today this is silently incorrect if block_size is greater than K + config.block_size = 16 + _test_lowp_mlp_tensor_parallelism_base( + mesh, config, size, compile=False, allgather_in_lowp=False + ) + + # TODO(future PR): compile + + +if __name__ == "__main__": + device_mesh = setup_distributed() + tests = [ + _test_dtensor_cast_to_mxfp8, + # TODO(next PR): enable this (current PR got too large, so splitting) + # _test_mxfp8_mlp_tensor_parallelism_eager, + ] + + for test in tqdm(tests, desc="Running tests"): + try: + test(device_mesh) + except Exception as e: + print(f"Test {test.__name__} failed with error: {e}") + raise e + + torch.distributed.destroy_process_group() diff --git a/test/prototype/mx_formats/test_dtensor.sh b/test/prototype/mx_formats/test_dtensor.sh new file mode 100755 index 0000000000..3fc26f6bca --- /dev/null +++ b/test/prototype/mx_formats/test_dtensor.sh @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +#!/bin/bash + +# terminate script on first error +set -e + +if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then + echo "Skipping test_dtensor.sh because no CUDA devices are available." 
+ exit +fi + +# integration tests for TP/SP +NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/prototype/mx_formats/test_dtensor.py +# NCCL_DEBUG=WARN torchrun --nproc_per_node 1 test/prototype/mx_formats/test_dtensor.py diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index eacf0ac5df..f96e73a55a 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -1102,15 +1102,12 @@ def _triton_calculate_scale(x, axis): bf16_mbits = 7 bf16_exp_bias = 127 fp32_mbits = 23 - # We use a small epsilon to avoid division by zero - epsilon = 1e-10 # Find the maximum absolute value for each row max_abs = tl.max(x, axis=axis) # Calculate the e8m0 scale by extracting the exponent (floor) # TODO(future PR): support other exponent extraction types (ceil, RNE) - max_abs = max_abs + epsilon max_abs = max_abs.to(tl.bfloat16) max_abs_int16 = max_abs.to(tl.int16, bitcast=True) extracted_pow2 = ((max_abs_int16 >> bf16_mbits) & 0b11111111) - bf16_exp_bias diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 784d3eda6d..1897a8b949 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -21,6 +21,7 @@ from typing import Callable, Dict, Union import torch +from torch.distributed._tensor import DTensor from torchao.prototype.mx_formats.config import MXGemmKernelChoice from torchao.prototype.mx_formats.constants import ( @@ -166,6 +167,8 @@ def to_mx( # calculate the scale in e8m0 format orig_shape = data_hp.shape + # TODO(future PR): fix this line for TP, currently this reshape does not work + # for rank 3 tensor where dim1 is sharded data_hp = data_hp.reshape(-1, block_size) # find max value of the data @@ -174,10 +177,6 @@ def to_mx( # section 6.3. max_abs = torch.amax(torch.abs(data_hp), 1) - # Add an epsilon to prevent the log2 function call for returning -inf - # where the values are zero. - eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype) - # Set X to be the largest power-of-two less than or equal to # max_abs(v), divided by the largest power of two representable # in the element data type, and get the mbits at the same time @@ -233,8 +232,14 @@ def to_mx( ) # Calculate the scale for different modes - max_abs_int32 = (max_abs + eps).view(hp_int_dtype) - extracted_pow2 = ((max_abs_int32 >> hp_mbits) & 0b11111111) - hp_exp_bias + max_abs_int32 = max_abs.view(hp_int_dtype) + # the `>>` seems to be silently incorrect (result is the same as the first + # operand) if the input is a DTensor. If we use `torch.bitwise_right_shift` + # instead, it works. Same for `<<`. + # TODO(before land): file an issue in pytorch/pytorch about this + extracted_pow2 = ( + (torch.bitwise_right_shift(max_abs_int32, hp_mbits)) & 0b11111111 + ) - hp_exp_bias if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN): scale_e8m0_unbiased = extracted_pow2 - target_max_pow2 @@ -266,9 +271,9 @@ def to_mx( ) # For now, calculate the scale in floating point. - scale_fp32 = (scale_e8m0_biased.to(torch.int32) << MBITS_F32).view( - torch.float32 - ) + scale_fp32 = ( + torch.bitwise_left_shift(scale_e8m0_biased.to(torch.int32), MBITS_F32) + ).view(torch.float32) # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the # float32 denormal range. For now, manually adjust the fp scale. 
This is @@ -597,6 +602,28 @@ def to_mx( scale_e8m0_biased, data_lp = to_mx( data_hp, elem_dtype, block_size, scaling_mode, pack_fp6 ) + if isinstance(scale_e8m0_biased, DTensor): + assert isinstance(data_lp, DTensor), "unsupported" + local_scale_e8m0_biased = scale_e8m0_biased.to_local() + local_data_lp = data_lp.to_local() + inner_mx_tensor = MXTensor( + local_scale_e8m0_biased, + local_data_lp, + elem_dtype, + block_size, + data_hp.dtype, + use_fp4_custom_triton_dequant_kernel, + gemm_kernel_choice, + pack_fp6, + ) + return DTensor.from_local( + inner_mx_tensor, + data_lp.device_mesh, + data_lp.placements, + run_check=False, + shape=data_lp.size(), + stride=data_lp.stride(), + ) return MXTensor( scale_e8m0_biased, data_lp, diff --git a/torchao/testing/training/dtensor_utils.py b/torchao/testing/training/dtensor_utils.py index 7ac0360363..815ee20969 100644 --- a/torchao/testing/training/dtensor_utils.py +++ b/torchao/testing/training/dtensor_utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. import copy +from typing import Union import torch import torch.nn as nn @@ -24,6 +25,8 @@ Float8RowwiseParallel, PrepareFloat8ModuleInput, ) +from torchao.prototype.mx_formats.config import MXLinearConfig +from torchao.quantization import quantize_ class FeedForward(nn.Module): @@ -36,7 +39,9 @@ def __init__(self): self.out_proj = nn.Linear(32, 16, bias=False) def forward(self, x): - return self.out_proj(F.silu(self.w1(x)) * self.w2(x)) + x = F.silu(self.w1(x)) * self.w2(x) + x = self.out_proj(x) + return x class ToyModel(nn.Module): @@ -50,20 +55,26 @@ def forward(self, x): def _test_lowp_mlp_tensor_parallelism_base( mesh: DeviceMesh, - config: Float8LinearConfig, + config: Union[Float8LinearConfig, MXLinearConfig], size=16, compile: bool = False, allgather_in_lowp: bool = False, ): device = mesh.device_type + # TODO(future): remove this once float8 training works with `quantize_` API + convert_model_func = convert_to_float8_training + if isinstance(config, MXLinearConfig): + convert_model_func = quantize_ + toy_model = ToyModel().to(device) - toy_model_fp8 = convert_to_float8_training(toy_model, config=config) + toy_model_fp8 = copy.deepcopy(toy_model) + convert_model_func(toy_model_fp8, config=config) tp_model = copy.deepcopy(toy_model) - tp_model = convert_to_float8_training(tp_model, config=config) + convert_model_func(tp_model, config=config) sp_model = copy.deepcopy(toy_model) - sp_model = convert_to_float8_training(sp_model, config=config) + convert_model_func(sp_model, config=config) # For tensorwise scaling, enable float8 all_gather. # For rowwise scaling, keep high precision all_gather. 
Motivation for @@ -108,7 +119,7 @@ def _test_lowp_mlp_tensor_parallelism_base( # prepare_input_cls with specific submodule fqn sp_model2 = copy.deepcopy(toy_model) - sp_model2 = convert_to_float8_training(sp_model2, config=config) + convert_model_func(sp_model2, config=config) if not allgather_in_lowp: prepare_input = prepare_input_cls( From 75e6fe72d45dcf833cbb5dbeb8a62fcb3efd886c Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 20 Jun 2025 19:45:05 -0700 Subject: [PATCH 05/12] Update [ghstack-poisoned] --- collect_env.py | 697 ++++++++++++++++++++++ torchao/prototype/mx_formats/mx_tensor.py | 8 +- 2 files changed, 701 insertions(+), 4 deletions(-) create mode 100644 collect_env.py diff --git a/collect_env.py b/collect_env.py new file mode 100644 index 0000000000..4270522e6c --- /dev/null +++ b/collect_env.py @@ -0,0 +1,697 @@ +# mypy: allow-untyped-defs + +# Unlike the rest of the PyTorch this file must be python2 compliant. +# This script outputs relevant system environment info +# Run it with `python collect_env.py` or `python -m torch.utils.collect_env` +import datetime +import json +import locale +import re +import subprocess +import sys +import os +from collections import namedtuple + + +try: + import torch + TORCH_AVAILABLE = True +except (ImportError, NameError, AttributeError, OSError): + TORCH_AVAILABLE = False + +# System Environment Information +SystemEnv = namedtuple('SystemEnv', [ + 'torch_version', + 'is_debug_build', + 'cuda_compiled_version', + 'gcc_version', + 'clang_version', + 'cmake_version', + 'os', + 'libc_version', + 'python_version', + 'python_platform', + 'is_cuda_available', + 'cuda_runtime_version', + 'cuda_module_loading', + 'nvidia_driver_version', + 'nvidia_gpu_models', + 'cudnn_version', + 'pip_version', # 'pip' or 'pip3' + 'pip_packages', + 'conda_packages', + 'hip_compiled_version', + 'hip_runtime_version', + 'miopen_runtime_version', + 'caching_allocator_config', + 'is_xnnpack_available', + 'cpu_info', +]) + +COMMON_PATTERNS = [ + "torch", + "numpy", + "triton", + "optree", +] + +NVIDIA_PATTERNS = [ + "cuda-cudart", + "cuda-cupti", + "cuda-libraries", + "cuda-opencl", + "cuda-nvrtc", + "cuda-runtime", + "cublas", + "cudnn", + "cufft", + "curand", + "cusolver", + "cusparse", + "nccl", + "nvjitlink", + "nvtx", +] + +CONDA_PATTERNS = [ + "cudatoolkit", + "soumith", + "mkl", + "magma", +] + +PIP_PATTERNS = [ + "mypy", + "flake8", + "onnx", +] + + +def run(command): + """Return (return-code, stdout, stderr).""" + shell = True if type(command) is str else False + p = subprocess.Popen(command, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=shell) + raw_output, raw_err = p.communicate() + rc = p.returncode + if get_platform() == 'win32': + enc = 'oem' + else: + enc = locale.getpreferredencoding() + output = raw_output.decode(enc) + err = raw_err.decode(enc) + return rc, output.strip(), err.strip() + + +def run_and_read_all(run_lambda, command): + """Run command using run_lambda; reads and returns entire output if rc is 0.""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + return out + + +def run_and_parse_first_match(run_lambda, command, regex): + """Run command using run_lambda, returns the first regex match if it exists.""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + match = re.search(regex, out) + if match is None: + return None + return match.group(1) + +def run_and_return_first_line(run_lambda, command): + """Run command using run_lambda and returns first line if output is not empty.""" + rc, out, _ = 
run_lambda(command) + if rc != 0: + return None + return out.split('\n')[0] + + +def get_conda_packages(run_lambda, patterns=None): + if patterns is None: + patterns = CONDA_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + conda = os.environ.get('CONDA_EXE', 'conda') + out = run_and_read_all(run_lambda, "{} list".format(conda)) + if out is None: + return out + + return "\n".join( + line + for line in out.splitlines() + if not line.startswith("#") + and any(name in line for name in patterns) + ) + +def get_gcc_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)') + +def get_clang_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'clang --version', r'clang version (.*)') + + +def get_cmake_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'cmake --version', r'cmake (.*)') + + +def get_nvidia_driver_version(run_lambda): + if get_platform() == 'darwin': + cmd = 'kextstat | grep -i cuda' + return run_and_parse_first_match(run_lambda, cmd, + r'com[.]nvidia[.]CUDA [(](.*?)[)]') + smi = get_nvidia_smi() + return run_and_parse_first_match(run_lambda, smi, r'Driver Version: (.*?) ') + + +def get_gpu_info(run_lambda): + if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(torch.version, 'hip') and torch.version.hip is not None): + if TORCH_AVAILABLE and torch.cuda.is_available(): + if torch.version.hip is not None: + prop = torch.cuda.get_device_properties(0) + if hasattr(prop, "gcnArchName"): + gcnArch = " ({})".format(prop.gcnArchName) + else: + gcnArch = "NoGCNArchNameOnOldPyTorch" + else: + gcnArch = "" + return torch.cuda.get_device_name(None) + gcnArch + return None + smi = get_nvidia_smi() + uuid_regex = re.compile(r' \(UUID: .+?\)') + rc, out, _ = run_lambda(smi + ' -L') + if rc != 0: + return None + # Anonymize GPUs by removing their UUID + return re.sub(uuid_regex, '', out) + + +def get_running_cuda_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'nvcc --version', r'release .+ V(.*)') + + +def get_cudnn_version(run_lambda): + """Return a list of libcudnn.so; it's hard to tell which one is being used.""" + if get_platform() == 'win32': + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') + cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%") + where_cmd = os.path.join(system_root, 'System32', 'where') + cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) + elif get_platform() == 'darwin': + # CUDA libraries and drivers can be found in /usr/local/cuda/. See + # https://docs.nvidia.com/cuda/archive/9.0/cuda-installation-guide-mac-os-x/index.html#installation + # https://docs.nvidia.com/deeplearning/cudnn/installation/latest/ + # Use CUDNN_LIBRARY when cudnn library is installed elsewhere. 
+ cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*' + else: + cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' + rc, out, _ = run_lambda(cudnn_cmd) + # find will return 1 if there are permission errors or if not found + if len(out) == 0 or (rc != 1 and rc != 0): + l = os.environ.get('CUDNN_LIBRARY') + if l is not None and os.path.isfile(l): + return os.path.realpath(l) + return None + files_set = set() + for fn in out.split('\n'): + fn = os.path.realpath(fn) # eliminate symbolic links + if os.path.isfile(fn): + files_set.add(fn) + if not files_set: + return None + # Alphabetize the result because the order is non-deterministic otherwise + files = sorted(files_set) + if len(files) == 1: + return files[0] + result = '\n'.join(files) + return 'Probably one of the following:\n{}'.format(result) + + +def get_nvidia_smi(): + # Note: nvidia-smi is currently available only on Windows and Linux + smi = 'nvidia-smi' + if get_platform() == 'win32': + system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') + program_files_root = os.environ.get('PROGRAMFILES', 'C:\\Program Files') + legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', 'NVSMI', smi) + new_path = os.path.join(system_root, 'System32', smi) + smis = [new_path, legacy_path] + for candidate_smi in smis: + if os.path.exists(candidate_smi): + smi = '"{}"'.format(candidate_smi) + break + return smi + + +# example outputs of CPU infos +# * linux +# Architecture: x86_64 +# CPU op-mode(s): 32-bit, 64-bit +# Address sizes: 46 bits physical, 48 bits virtual +# Byte Order: Little Endian +# CPU(s): 128 +# On-line CPU(s) list: 0-127 +# Vendor ID: GenuineIntel +# Model name: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz +# CPU family: 6 +# Model: 106 +# Thread(s) per core: 2 +# Core(s) per socket: 32 +# Socket(s): 2 +# Stepping: 6 +# BogoMIPS: 5799.78 +# Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr +# sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl +# xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 +# pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand +# hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced +# fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap +# avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 +# xsaves wbnoinvd ida arat avx512vbmi pku ospke avx512_vbmi2 gfni vaes vpclmulqdq +# avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear flush_l1d arch_capabilities +# Virtualization features: +# Hypervisor vendor: KVM +# Virtualization type: full +# Caches (sum of all): +# L1d: 3 MiB (64 instances) +# L1i: 2 MiB (64 instances) +# L2: 80 MiB (64 instances) +# L3: 108 MiB (2 instances) +# NUMA: +# NUMA node(s): 2 +# NUMA node0 CPU(s): 0-31,64-95 +# NUMA node1 CPU(s): 32-63,96-127 +# Vulnerabilities: +# Itlb multihit: Not affected +# L1tf: Not affected +# Mds: Not affected +# Meltdown: Not affected +# Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown +# Retbleed: Not affected +# Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp +# Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization +# Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence +# Srbds: Not affected +# Tsx async abort: Not affected 
+# * win32 +# Architecture=9 +# CurrentClockSpeed=2900 +# DeviceID=CPU0 +# Family=179 +# L2CacheSize=40960 +# L2CacheSpeed= +# Manufacturer=GenuineIntel +# MaxClockSpeed=2900 +# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz +# ProcessorType=3 +# Revision=27142 +# +# Architecture=9 +# CurrentClockSpeed=2900 +# DeviceID=CPU1 +# Family=179 +# L2CacheSize=40960 +# L2CacheSpeed= +# Manufacturer=GenuineIntel +# MaxClockSpeed=2900 +# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz +# ProcessorType=3 +# Revision=27142 + +def get_cpu_info(run_lambda): + rc, out, err = 0, '', '' + if get_platform() == 'linux': + rc, out, err = run_lambda('lscpu') + elif get_platform() == 'win32': + rc, out, err = run_lambda( + 'powershell.exe "gwmi -Class Win32_Processor | Select-Object -Property Name,Manufacturer,Family,\ + Architecture,ProcessorType,DeviceID,CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision\ + | ConvertTo-Json"' + ) + if rc == 0: + lst = [] + try: + obj = json.loads(out) + if type(obj) is list: + for o in obj: + lst.append("----------------------") + lst.extend([f"{k}: {v}" for (k, v) in o.items()]) + else: + lst.extend([f"{k}: {v}" for (k, v) in obj.items()]) + except ValueError as e: + lst.append(out) + lst.append(str(e)) + out = "\n".join(lst) + elif get_platform() == 'darwin': + rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string") + cpu_info = 'None' + if rc == 0: + cpu_info = out + else: + cpu_info = err + return cpu_info + + +def get_platform(): + if sys.platform.startswith('linux'): + return 'linux' + elif sys.platform.startswith('win32'): + return 'win32' + elif sys.platform.startswith('cygwin'): + return 'cygwin' + elif sys.platform.startswith('darwin'): + return 'darwin' + else: + return sys.platform + + +def get_mac_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', r'(.*)') + + +def get_windows_version(run_lambda): + ret = run_and_read_all( + run_lambda, + 'powershell.exe "gwmi -Class Win32_OperatingSystem | Select-Object -Property Caption,\ + OSArchitecture,Version | ConvertTo-Json"', + ) + try: + obj = json.loads(ret) + ret = f'{obj["Caption"]} ({obj["Version"]} {obj["OSArchitecture"]})' + except ValueError as e: + ret += f"\n{str(e)}" + return ret + + +def get_lsb_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'lsb_release -a', r'Description:\t(.*)') + + +def check_release_file(run_lambda): + return run_and_parse_first_match(run_lambda, 'cat /etc/*-release', + r'PRETTY_NAME="(.*)"') + + +def get_os(run_lambda): + from platform import machine + platform = get_platform() + + if platform == 'win32' or platform == 'cygwin': + return get_windows_version(run_lambda) + + if platform == 'darwin': + version = get_mac_version(run_lambda) + if version is None: + return None + return 'macOS {} ({})'.format(version, machine()) + + if platform == 'linux': + # Ubuntu/Debian based + desc = get_lsb_version(run_lambda) + if desc is not None: + return '{} ({})'.format(desc, machine()) + + # Try reading /etc/*-release + desc = check_release_file(run_lambda) + if desc is not None: + return '{} ({})'.format(desc, machine()) + + return '{} ({})'.format(platform, machine()) + + # Unknown platform + return platform + + +def get_python_platform(): + import platform + return platform.platform() + + +def get_libc_version(): + import platform + if get_platform() != 'linux': + return 'N/A' + return '-'.join(platform.libc_ver()) + + +def get_pip_packages(run_lambda, patterns=None): + """Return `pip list` output. 
Note: will also find conda-installed pytorch and numpy packages.""" + if patterns is None: + patterns = PIP_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS + + pip_version = 'pip3' if sys.version_info.major == 3 else 'pip' + + os.environ['PIP_DISABLE_PIP_VERSION_CHECK'] = '1' + # People generally have pip as `pip` or `pip3` + # But here it is invoked as `python -mpip` + out = run_and_read_all(run_lambda, [sys.executable, '-mpip', 'list', '--format=freeze']) + if out is None: + return pip_version, out + + filtered_out = '\n'.join( + line + for line in out.splitlines() + if any(name in line for name in patterns) + ) + + return pip_version, filtered_out + + +def get_cachingallocator_config(): + ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '') + if not ca_config: + ca_config = os.environ.get('PYTORCH_HIP_ALLOC_CONF', '') + return ca_config + + +def get_cuda_module_loading_config(): + if TORCH_AVAILABLE and torch.cuda.is_available(): + torch.cuda.init() + config = os.environ.get('CUDA_MODULE_LOADING', '') + return config + else: + return "N/A" + + +def is_xnnpack_available(): + if TORCH_AVAILABLE: + import torch.backends.xnnpack + return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined] + else: + return "N/A" + +def get_env_info(): + """ + Collects environment information to aid in debugging. + + The returned environment information contains details on torch version, is debug build + or not, cuda compiled version, gcc version, clang version, cmake version, operating + system, libc version, python version, python platform, CUDA availability, CUDA + runtime version, CUDA module loading config, GPU model and configuration, Nvidia + driver version, cuDNN version, pip version and versions of relevant pip and + conda packages, HIP runtime version, MIOpen runtime version, + Caching allocator config, XNNPACK availability and CPU information. + + Returns: + SystemEnv (namedtuple): A tuple containining various environment details + and system information. 
+ """ + run_lambda = run + pip_version, pip_list_output = get_pip_packages(run_lambda) + + if TORCH_AVAILABLE: + version_str = torch.__version__ + debug_mode_str = str(torch.version.debug) + cuda_available_str = str(torch.cuda.is_available()) + cuda_version_str = torch.version.cuda + if not hasattr(torch.version, 'hip') or torch.version.hip is None: # cuda version + hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' + else: # HIP version + def get_version_or_na(cfg, prefix): + _lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s] + return _lst[0] if _lst else 'N/A' + + cfg = torch._C._show_config().split('\n') + hip_runtime_version = get_version_or_na(cfg, 'HIP Runtime') + miopen_runtime_version = get_version_or_na(cfg, 'MIOpen') + cuda_version_str = 'N/A' + hip_compiled_version = torch.version.hip + else: + version_str = debug_mode_str = cuda_available_str = cuda_version_str = 'N/A' + hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' + + sys_version = sys.version.replace("\n", " ") + + conda_packages = get_conda_packages(run_lambda) + + return SystemEnv( + torch_version=version_str, + is_debug_build=debug_mode_str, + python_version='{} ({}-bit runtime)'.format(sys_version, sys.maxsize.bit_length() + 1), + python_platform=get_python_platform(), + is_cuda_available=cuda_available_str, + cuda_compiled_version=cuda_version_str, + cuda_runtime_version=get_running_cuda_version(run_lambda), + cuda_module_loading=get_cuda_module_loading_config(), + nvidia_gpu_models=get_gpu_info(run_lambda), + nvidia_driver_version=get_nvidia_driver_version(run_lambda), + cudnn_version=get_cudnn_version(run_lambda), + hip_compiled_version=hip_compiled_version, + hip_runtime_version=hip_runtime_version, + miopen_runtime_version=miopen_runtime_version, + pip_version=pip_version, + pip_packages=pip_list_output, + conda_packages=conda_packages, + os=get_os(run_lambda), + libc_version=get_libc_version(), + gcc_version=get_gcc_version(run_lambda), + clang_version=get_clang_version(run_lambda), + cmake_version=get_cmake_version(run_lambda), + caching_allocator_config=get_cachingallocator_config(), + is_xnnpack_available=is_xnnpack_available(), + cpu_info=get_cpu_info(run_lambda), + ) + +env_info_fmt = """ +PyTorch version: {torch_version} +Is debug build: {is_debug_build} +CUDA used to build PyTorch: {cuda_compiled_version} +ROCM used to build PyTorch: {hip_compiled_version} + +OS: {os} +GCC version: {gcc_version} +Clang version: {clang_version} +CMake version: {cmake_version} +Libc version: {libc_version} + +Python version: {python_version} +Python platform: {python_platform} +Is CUDA available: {is_cuda_available} +CUDA runtime version: {cuda_runtime_version} +CUDA_MODULE_LOADING set to: {cuda_module_loading} +GPU models and configuration: {nvidia_gpu_models} +Nvidia driver version: {nvidia_driver_version} +cuDNN version: {cudnn_version} +HIP runtime version: {hip_runtime_version} +MIOpen runtime version: {miopen_runtime_version} +Is XNNPACK available: {is_xnnpack_available} + +CPU: +{cpu_info} + +Versions of relevant libraries: +{pip_packages} +{conda_packages} +""".strip() + + +def pretty_str(envinfo): + def replace_nones(dct, replacement='Could not collect'): + for key in dct.keys(): + if dct[key] is not None: + continue + dct[key] = replacement + return dct + + def replace_bools(dct, true='Yes', false='No'): + for key in dct.keys(): + if dct[key] is True: + dct[key] = true + elif dct[key] is False: + dct[key] = false + return dct + + def prepend(text, 
tag='[prepend]'): + lines = text.split('\n') + updated_lines = [tag + line for line in lines] + return '\n'.join(updated_lines) + + def replace_if_empty(text, replacement='No relevant packages'): + if text is not None and len(text) == 0: + return replacement + return text + + def maybe_start_on_next_line(string): + # If `string` is multiline, prepend a \n to it. + if string is not None and len(string.split('\n')) > 1: + return '\n{}\n'.format(string) + return string + + mutable_dict = envinfo._asdict() + + # If nvidia_gpu_models is multiline, start on the next line + mutable_dict['nvidia_gpu_models'] = \ + maybe_start_on_next_line(envinfo.nvidia_gpu_models) + + # If the machine doesn't have CUDA, report some fields as 'No CUDA' + dynamic_cuda_fields = [ + 'cuda_runtime_version', + 'nvidia_gpu_models', + 'nvidia_driver_version', + ] + all_cuda_fields = dynamic_cuda_fields + ['cudnn_version'] + all_dynamic_cuda_fields_missing = all( + mutable_dict[field] is None for field in dynamic_cuda_fields) + if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing: + for field in all_cuda_fields: + mutable_dict[field] = 'No CUDA' + if envinfo.cuda_compiled_version is None: + mutable_dict['cuda_compiled_version'] = 'None' + + # Replace True with Yes, False with No + mutable_dict = replace_bools(mutable_dict) + + # Replace all None objects with 'Could not collect' + mutable_dict = replace_nones(mutable_dict) + + # If either of these are '', replace with 'No relevant packages' + mutable_dict['pip_packages'] = replace_if_empty(mutable_dict['pip_packages']) + mutable_dict['conda_packages'] = replace_if_empty(mutable_dict['conda_packages']) + + # Tag conda and pip packages with a prefix + # If they were previously None, they'll show up as ie '[conda] Could not collect' + if mutable_dict['pip_packages']: + mutable_dict['pip_packages'] = prepend(mutable_dict['pip_packages'], + '[{}] '.format(envinfo.pip_version)) + if mutable_dict['conda_packages']: + mutable_dict['conda_packages'] = prepend(mutable_dict['conda_packages'], + '[conda] ') + mutable_dict['cpu_info'] = envinfo.cpu_info + return env_info_fmt.format(**mutable_dict) + + +def get_pretty_env_info(): + """ + Returns a pretty string of environment information. + + This function retrieves environment information by calling the `get_env_info` function + and then formats the information into a human-readable string. The retrieved environment + information is listed in the document of `get_env_info`. + This function is used in `python collect_env.py` that should be executed when reporting a bug. + + Returns: + str: A pretty string of the environment information. 
+ """ + return pretty_str(get_env_info()) + + +def main(): + print("Collecting environment information...") + output = get_pretty_env_info() + print(output) + + if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(torch.utils, '_crash_handler'): + minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR + if sys.platform == "linux" and os.path.exists(minidump_dir): + dumps = [os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)] + latest = max(dumps, key=os.path.getctime) + ctime = os.path.getctime(latest) + creation_time = datetime.datetime.fromtimestamp(ctime).strftime('%Y-%m-%d %H:%M:%S') + msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \ + "if this is related to your bug please include it when you file a report ***" + print(msg, file=sys.stderr) + + + +if __name__ == '__main__': + main() diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 1897a8b949..ef9ae42fcd 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -233,10 +233,8 @@ def to_mx( # Calculate the scale for different modes max_abs_int32 = max_abs.view(hp_int_dtype) - # the `>>` seems to be silently incorrect (result is the same as the first - # operand) if the input is a DTensor. If we use `torch.bitwise_right_shift` - # instead, it works. Same for `<<`. - # TODO(before land): file an issue in pytorch/pytorch about this + # For now, use `torch.bitwise_right_shift` instead of `>>` to support DTensor + # See https://github.com/pytorch/pytorch/issues/156533. extracted_pow2 = ( (torch.bitwise_right_shift(max_abs_int32, hp_mbits)) & 0b11111111 ) - hp_exp_bias @@ -271,6 +269,8 @@ def to_mx( ) # For now, calculate the scale in floating point. + # For now, use `torch.bitwise_left_shift` instead of `<<` to support DTensor + # See https://github.com/pytorch/pytorch/issues/156533. scale_fp32 = ( torch.bitwise_left_shift(scale_e8m0_biased.to(torch.int32), MBITS_F32) ).view(torch.float32) From 8bf42da5b43e905b7b4e020b1047d58efe9d1539 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 20 Jun 2025 19:45:44 -0700 Subject: [PATCH 06/12] Update [ghstack-poisoned] --- collect_env.py | 697 ------------------------------------------------- 1 file changed, 697 deletions(-) delete mode 100644 collect_env.py diff --git a/collect_env.py b/collect_env.py deleted file mode 100644 index 4270522e6c..0000000000 --- a/collect_env.py +++ /dev/null @@ -1,697 +0,0 @@ -# mypy: allow-untyped-defs - -# Unlike the rest of the PyTorch this file must be python2 compliant. 
-# This script outputs relevant system environment info -# Run it with `python collect_env.py` or `python -m torch.utils.collect_env` -import datetime -import json -import locale -import re -import subprocess -import sys -import os -from collections import namedtuple - - -try: - import torch - TORCH_AVAILABLE = True -except (ImportError, NameError, AttributeError, OSError): - TORCH_AVAILABLE = False - -# System Environment Information -SystemEnv = namedtuple('SystemEnv', [ - 'torch_version', - 'is_debug_build', - 'cuda_compiled_version', - 'gcc_version', - 'clang_version', - 'cmake_version', - 'os', - 'libc_version', - 'python_version', - 'python_platform', - 'is_cuda_available', - 'cuda_runtime_version', - 'cuda_module_loading', - 'nvidia_driver_version', - 'nvidia_gpu_models', - 'cudnn_version', - 'pip_version', # 'pip' or 'pip3' - 'pip_packages', - 'conda_packages', - 'hip_compiled_version', - 'hip_runtime_version', - 'miopen_runtime_version', - 'caching_allocator_config', - 'is_xnnpack_available', - 'cpu_info', -]) - -COMMON_PATTERNS = [ - "torch", - "numpy", - "triton", - "optree", -] - -NVIDIA_PATTERNS = [ - "cuda-cudart", - "cuda-cupti", - "cuda-libraries", - "cuda-opencl", - "cuda-nvrtc", - "cuda-runtime", - "cublas", - "cudnn", - "cufft", - "curand", - "cusolver", - "cusparse", - "nccl", - "nvjitlink", - "nvtx", -] - -CONDA_PATTERNS = [ - "cudatoolkit", - "soumith", - "mkl", - "magma", -] - -PIP_PATTERNS = [ - "mypy", - "flake8", - "onnx", -] - - -def run(command): - """Return (return-code, stdout, stderr).""" - shell = True if type(command) is str else False - p = subprocess.Popen(command, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, shell=shell) - raw_output, raw_err = p.communicate() - rc = p.returncode - if get_platform() == 'win32': - enc = 'oem' - else: - enc = locale.getpreferredencoding() - output = raw_output.decode(enc) - err = raw_err.decode(enc) - return rc, output.strip(), err.strip() - - -def run_and_read_all(run_lambda, command): - """Run command using run_lambda; reads and returns entire output if rc is 0.""" - rc, out, _ = run_lambda(command) - if rc != 0: - return None - return out - - -def run_and_parse_first_match(run_lambda, command, regex): - """Run command using run_lambda, returns the first regex match if it exists.""" - rc, out, _ = run_lambda(command) - if rc != 0: - return None - match = re.search(regex, out) - if match is None: - return None - return match.group(1) - -def run_and_return_first_line(run_lambda, command): - """Run command using run_lambda and returns first line if output is not empty.""" - rc, out, _ = run_lambda(command) - if rc != 0: - return None - return out.split('\n')[0] - - -def get_conda_packages(run_lambda, patterns=None): - if patterns is None: - patterns = CONDA_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS - conda = os.environ.get('CONDA_EXE', 'conda') - out = run_and_read_all(run_lambda, "{} list".format(conda)) - if out is None: - return out - - return "\n".join( - line - for line in out.splitlines() - if not line.startswith("#") - and any(name in line for name in patterns) - ) - -def get_gcc_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)') - -def get_clang_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'clang --version', r'clang version (.*)') - - -def get_cmake_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'cmake --version', r'cmake (.*)') - - -def get_nvidia_driver_version(run_lambda): - if get_platform() == 'darwin': - cmd = 
'kextstat | grep -i cuda' - return run_and_parse_first_match(run_lambda, cmd, - r'com[.]nvidia[.]CUDA [(](.*?)[)]') - smi = get_nvidia_smi() - return run_and_parse_first_match(run_lambda, smi, r'Driver Version: (.*?) ') - - -def get_gpu_info(run_lambda): - if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(torch.version, 'hip') and torch.version.hip is not None): - if TORCH_AVAILABLE and torch.cuda.is_available(): - if torch.version.hip is not None: - prop = torch.cuda.get_device_properties(0) - if hasattr(prop, "gcnArchName"): - gcnArch = " ({})".format(prop.gcnArchName) - else: - gcnArch = "NoGCNArchNameOnOldPyTorch" - else: - gcnArch = "" - return torch.cuda.get_device_name(None) + gcnArch - return None - smi = get_nvidia_smi() - uuid_regex = re.compile(r' \(UUID: .+?\)') - rc, out, _ = run_lambda(smi + ' -L') - if rc != 0: - return None - # Anonymize GPUs by removing their UUID - return re.sub(uuid_regex, '', out) - - -def get_running_cuda_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'nvcc --version', r'release .+ V(.*)') - - -def get_cudnn_version(run_lambda): - """Return a list of libcudnn.so; it's hard to tell which one is being used.""" - if get_platform() == 'win32': - system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') - cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%") - where_cmd = os.path.join(system_root, 'System32', 'where') - cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) - elif get_platform() == 'darwin': - # CUDA libraries and drivers can be found in /usr/local/cuda/. See - # https://docs.nvidia.com/cuda/archive/9.0/cuda-installation-guide-mac-os-x/index.html#installation - # https://docs.nvidia.com/deeplearning/cudnn/installation/latest/ - # Use CUDNN_LIBRARY when cudnn library is installed elsewhere. 
- cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*' - else: - cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' - rc, out, _ = run_lambda(cudnn_cmd) - # find will return 1 if there are permission errors or if not found - if len(out) == 0 or (rc != 1 and rc != 0): - l = os.environ.get('CUDNN_LIBRARY') - if l is not None and os.path.isfile(l): - return os.path.realpath(l) - return None - files_set = set() - for fn in out.split('\n'): - fn = os.path.realpath(fn) # eliminate symbolic links - if os.path.isfile(fn): - files_set.add(fn) - if not files_set: - return None - # Alphabetize the result because the order is non-deterministic otherwise - files = sorted(files_set) - if len(files) == 1: - return files[0] - result = '\n'.join(files) - return 'Probably one of the following:\n{}'.format(result) - - -def get_nvidia_smi(): - # Note: nvidia-smi is currently available only on Windows and Linux - smi = 'nvidia-smi' - if get_platform() == 'win32': - system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') - program_files_root = os.environ.get('PROGRAMFILES', 'C:\\Program Files') - legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', 'NVSMI', smi) - new_path = os.path.join(system_root, 'System32', smi) - smis = [new_path, legacy_path] - for candidate_smi in smis: - if os.path.exists(candidate_smi): - smi = '"{}"'.format(candidate_smi) - break - return smi - - -# example outputs of CPU infos -# * linux -# Architecture: x86_64 -# CPU op-mode(s): 32-bit, 64-bit -# Address sizes: 46 bits physical, 48 bits virtual -# Byte Order: Little Endian -# CPU(s): 128 -# On-line CPU(s) list: 0-127 -# Vendor ID: GenuineIntel -# Model name: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz -# CPU family: 6 -# Model: 106 -# Thread(s) per core: 2 -# Core(s) per socket: 32 -# Socket(s): 2 -# Stepping: 6 -# BogoMIPS: 5799.78 -# Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr -# sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl -# xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 -# pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand -# hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced -# fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap -# avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 -# xsaves wbnoinvd ida arat avx512vbmi pku ospke avx512_vbmi2 gfni vaes vpclmulqdq -# avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear flush_l1d arch_capabilities -# Virtualization features: -# Hypervisor vendor: KVM -# Virtualization type: full -# Caches (sum of all): -# L1d: 3 MiB (64 instances) -# L1i: 2 MiB (64 instances) -# L2: 80 MiB (64 instances) -# L3: 108 MiB (2 instances) -# NUMA: -# NUMA node(s): 2 -# NUMA node0 CPU(s): 0-31,64-95 -# NUMA node1 CPU(s): 32-63,96-127 -# Vulnerabilities: -# Itlb multihit: Not affected -# L1tf: Not affected -# Mds: Not affected -# Meltdown: Not affected -# Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown -# Retbleed: Not affected -# Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp -# Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization -# Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence -# Srbds: Not affected -# Tsx async abort: Not affected 
-# * win32 -# Architecture=9 -# CurrentClockSpeed=2900 -# DeviceID=CPU0 -# Family=179 -# L2CacheSize=40960 -# L2CacheSpeed= -# Manufacturer=GenuineIntel -# MaxClockSpeed=2900 -# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz -# ProcessorType=3 -# Revision=27142 -# -# Architecture=9 -# CurrentClockSpeed=2900 -# DeviceID=CPU1 -# Family=179 -# L2CacheSize=40960 -# L2CacheSpeed= -# Manufacturer=GenuineIntel -# MaxClockSpeed=2900 -# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz -# ProcessorType=3 -# Revision=27142 - -def get_cpu_info(run_lambda): - rc, out, err = 0, '', '' - if get_platform() == 'linux': - rc, out, err = run_lambda('lscpu') - elif get_platform() == 'win32': - rc, out, err = run_lambda( - 'powershell.exe "gwmi -Class Win32_Processor | Select-Object -Property Name,Manufacturer,Family,\ - Architecture,ProcessorType,DeviceID,CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision\ - | ConvertTo-Json"' - ) - if rc == 0: - lst = [] - try: - obj = json.loads(out) - if type(obj) is list: - for o in obj: - lst.append("----------------------") - lst.extend([f"{k}: {v}" for (k, v) in o.items()]) - else: - lst.extend([f"{k}: {v}" for (k, v) in obj.items()]) - except ValueError as e: - lst.append(out) - lst.append(str(e)) - out = "\n".join(lst) - elif get_platform() == 'darwin': - rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string") - cpu_info = 'None' - if rc == 0: - cpu_info = out - else: - cpu_info = err - return cpu_info - - -def get_platform(): - if sys.platform.startswith('linux'): - return 'linux' - elif sys.platform.startswith('win32'): - return 'win32' - elif sys.platform.startswith('cygwin'): - return 'cygwin' - elif sys.platform.startswith('darwin'): - return 'darwin' - else: - return sys.platform - - -def get_mac_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', r'(.*)') - - -def get_windows_version(run_lambda): - ret = run_and_read_all( - run_lambda, - 'powershell.exe "gwmi -Class Win32_OperatingSystem | Select-Object -Property Caption,\ - OSArchitecture,Version | ConvertTo-Json"', - ) - try: - obj = json.loads(ret) - ret = f'{obj["Caption"]} ({obj["Version"]} {obj["OSArchitecture"]})' - except ValueError as e: - ret += f"\n{str(e)}" - return ret - - -def get_lsb_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'lsb_release -a', r'Description:\t(.*)') - - -def check_release_file(run_lambda): - return run_and_parse_first_match(run_lambda, 'cat /etc/*-release', - r'PRETTY_NAME="(.*)"') - - -def get_os(run_lambda): - from platform import machine - platform = get_platform() - - if platform == 'win32' or platform == 'cygwin': - return get_windows_version(run_lambda) - - if platform == 'darwin': - version = get_mac_version(run_lambda) - if version is None: - return None - return 'macOS {} ({})'.format(version, machine()) - - if platform == 'linux': - # Ubuntu/Debian based - desc = get_lsb_version(run_lambda) - if desc is not None: - return '{} ({})'.format(desc, machine()) - - # Try reading /etc/*-release - desc = check_release_file(run_lambda) - if desc is not None: - return '{} ({})'.format(desc, machine()) - - return '{} ({})'.format(platform, machine()) - - # Unknown platform - return platform - - -def get_python_platform(): - import platform - return platform.platform() - - -def get_libc_version(): - import platform - if get_platform() != 'linux': - return 'N/A' - return '-'.join(platform.libc_ver()) - - -def get_pip_packages(run_lambda, patterns=None): - """Return `pip list` output. 
Note: will also find conda-installed pytorch and numpy packages.""" - if patterns is None: - patterns = PIP_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS - - pip_version = 'pip3' if sys.version_info.major == 3 else 'pip' - - os.environ['PIP_DISABLE_PIP_VERSION_CHECK'] = '1' - # People generally have pip as `pip` or `pip3` - # But here it is invoked as `python -mpip` - out = run_and_read_all(run_lambda, [sys.executable, '-mpip', 'list', '--format=freeze']) - if out is None: - return pip_version, out - - filtered_out = '\n'.join( - line - for line in out.splitlines() - if any(name in line for name in patterns) - ) - - return pip_version, filtered_out - - -def get_cachingallocator_config(): - ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '') - if not ca_config: - ca_config = os.environ.get('PYTORCH_HIP_ALLOC_CONF', '') - return ca_config - - -def get_cuda_module_loading_config(): - if TORCH_AVAILABLE and torch.cuda.is_available(): - torch.cuda.init() - config = os.environ.get('CUDA_MODULE_LOADING', '') - return config - else: - return "N/A" - - -def is_xnnpack_available(): - if TORCH_AVAILABLE: - import torch.backends.xnnpack - return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined] - else: - return "N/A" - -def get_env_info(): - """ - Collects environment information to aid in debugging. - - The returned environment information contains details on torch version, is debug build - or not, cuda compiled version, gcc version, clang version, cmake version, operating - system, libc version, python version, python platform, CUDA availability, CUDA - runtime version, CUDA module loading config, GPU model and configuration, Nvidia - driver version, cuDNN version, pip version and versions of relevant pip and - conda packages, HIP runtime version, MIOpen runtime version, - Caching allocator config, XNNPACK availability and CPU information. - - Returns: - SystemEnv (namedtuple): A tuple containining various environment details - and system information. 
- """ - run_lambda = run - pip_version, pip_list_output = get_pip_packages(run_lambda) - - if TORCH_AVAILABLE: - version_str = torch.__version__ - debug_mode_str = str(torch.version.debug) - cuda_available_str = str(torch.cuda.is_available()) - cuda_version_str = torch.version.cuda - if not hasattr(torch.version, 'hip') or torch.version.hip is None: # cuda version - hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' - else: # HIP version - def get_version_or_na(cfg, prefix): - _lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s] - return _lst[0] if _lst else 'N/A' - - cfg = torch._C._show_config().split('\n') - hip_runtime_version = get_version_or_na(cfg, 'HIP Runtime') - miopen_runtime_version = get_version_or_na(cfg, 'MIOpen') - cuda_version_str = 'N/A' - hip_compiled_version = torch.version.hip - else: - version_str = debug_mode_str = cuda_available_str = cuda_version_str = 'N/A' - hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' - - sys_version = sys.version.replace("\n", " ") - - conda_packages = get_conda_packages(run_lambda) - - return SystemEnv( - torch_version=version_str, - is_debug_build=debug_mode_str, - python_version='{} ({}-bit runtime)'.format(sys_version, sys.maxsize.bit_length() + 1), - python_platform=get_python_platform(), - is_cuda_available=cuda_available_str, - cuda_compiled_version=cuda_version_str, - cuda_runtime_version=get_running_cuda_version(run_lambda), - cuda_module_loading=get_cuda_module_loading_config(), - nvidia_gpu_models=get_gpu_info(run_lambda), - nvidia_driver_version=get_nvidia_driver_version(run_lambda), - cudnn_version=get_cudnn_version(run_lambda), - hip_compiled_version=hip_compiled_version, - hip_runtime_version=hip_runtime_version, - miopen_runtime_version=miopen_runtime_version, - pip_version=pip_version, - pip_packages=pip_list_output, - conda_packages=conda_packages, - os=get_os(run_lambda), - libc_version=get_libc_version(), - gcc_version=get_gcc_version(run_lambda), - clang_version=get_clang_version(run_lambda), - cmake_version=get_cmake_version(run_lambda), - caching_allocator_config=get_cachingallocator_config(), - is_xnnpack_available=is_xnnpack_available(), - cpu_info=get_cpu_info(run_lambda), - ) - -env_info_fmt = """ -PyTorch version: {torch_version} -Is debug build: {is_debug_build} -CUDA used to build PyTorch: {cuda_compiled_version} -ROCM used to build PyTorch: {hip_compiled_version} - -OS: {os} -GCC version: {gcc_version} -Clang version: {clang_version} -CMake version: {cmake_version} -Libc version: {libc_version} - -Python version: {python_version} -Python platform: {python_platform} -Is CUDA available: {is_cuda_available} -CUDA runtime version: {cuda_runtime_version} -CUDA_MODULE_LOADING set to: {cuda_module_loading} -GPU models and configuration: {nvidia_gpu_models} -Nvidia driver version: {nvidia_driver_version} -cuDNN version: {cudnn_version} -HIP runtime version: {hip_runtime_version} -MIOpen runtime version: {miopen_runtime_version} -Is XNNPACK available: {is_xnnpack_available} - -CPU: -{cpu_info} - -Versions of relevant libraries: -{pip_packages} -{conda_packages} -""".strip() - - -def pretty_str(envinfo): - def replace_nones(dct, replacement='Could not collect'): - for key in dct.keys(): - if dct[key] is not None: - continue - dct[key] = replacement - return dct - - def replace_bools(dct, true='Yes', false='No'): - for key in dct.keys(): - if dct[key] is True: - dct[key] = true - elif dct[key] is False: - dct[key] = false - return dct - - def prepend(text, 
tag='[prepend]'): - lines = text.split('\n') - updated_lines = [tag + line for line in lines] - return '\n'.join(updated_lines) - - def replace_if_empty(text, replacement='No relevant packages'): - if text is not None and len(text) == 0: - return replacement - return text - - def maybe_start_on_next_line(string): - # If `string` is multiline, prepend a \n to it. - if string is not None and len(string.split('\n')) > 1: - return '\n{}\n'.format(string) - return string - - mutable_dict = envinfo._asdict() - - # If nvidia_gpu_models is multiline, start on the next line - mutable_dict['nvidia_gpu_models'] = \ - maybe_start_on_next_line(envinfo.nvidia_gpu_models) - - # If the machine doesn't have CUDA, report some fields as 'No CUDA' - dynamic_cuda_fields = [ - 'cuda_runtime_version', - 'nvidia_gpu_models', - 'nvidia_driver_version', - ] - all_cuda_fields = dynamic_cuda_fields + ['cudnn_version'] - all_dynamic_cuda_fields_missing = all( - mutable_dict[field] is None for field in dynamic_cuda_fields) - if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing: - for field in all_cuda_fields: - mutable_dict[field] = 'No CUDA' - if envinfo.cuda_compiled_version is None: - mutable_dict['cuda_compiled_version'] = 'None' - - # Replace True with Yes, False with No - mutable_dict = replace_bools(mutable_dict) - - # Replace all None objects with 'Could not collect' - mutable_dict = replace_nones(mutable_dict) - - # If either of these are '', replace with 'No relevant packages' - mutable_dict['pip_packages'] = replace_if_empty(mutable_dict['pip_packages']) - mutable_dict['conda_packages'] = replace_if_empty(mutable_dict['conda_packages']) - - # Tag conda and pip packages with a prefix - # If they were previously None, they'll show up as ie '[conda] Could not collect' - if mutable_dict['pip_packages']: - mutable_dict['pip_packages'] = prepend(mutable_dict['pip_packages'], - '[{}] '.format(envinfo.pip_version)) - if mutable_dict['conda_packages']: - mutable_dict['conda_packages'] = prepend(mutable_dict['conda_packages'], - '[conda] ') - mutable_dict['cpu_info'] = envinfo.cpu_info - return env_info_fmt.format(**mutable_dict) - - -def get_pretty_env_info(): - """ - Returns a pretty string of environment information. - - This function retrieves environment information by calling the `get_env_info` function - and then formats the information into a human-readable string. The retrieved environment - information is listed in the document of `get_env_info`. - This function is used in `python collect_env.py` that should be executed when reporting a bug. - - Returns: - str: A pretty string of the environment information. 
- """ - return pretty_str(get_env_info()) - - -def main(): - print("Collecting environment information...") - output = get_pretty_env_info() - print(output) - - if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(torch.utils, '_crash_handler'): - minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR - if sys.platform == "linux" and os.path.exists(minidump_dir): - dumps = [os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)] - latest = max(dumps, key=os.path.getctime) - ctime = os.path.getctime(latest) - creation_time = datetime.datetime.fromtimestamp(ctime).strftime('%Y-%m-%d %H:%M:%S') - msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \ - "if this is related to your bug please include it when you file a report ***" - print(msg, file=sys.stderr) - - - -if __name__ == '__main__': - main() From c0080cd49de452fc7d610e8200e88c2703725f91 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 20 Jun 2025 19:46:22 -0700 Subject: [PATCH 07/12] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_dtensor.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/test/prototype/mx_formats/test_dtensor.sh b/test/prototype/mx_formats/test_dtensor.sh index 3fc26f6bca..03531f6059 100755 --- a/test/prototype/mx_formats/test_dtensor.sh +++ b/test/prototype/mx_formats/test_dtensor.sh @@ -15,4 +15,3 @@ fi # integration tests for TP/SP NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/prototype/mx_formats/test_dtensor.py -# NCCL_DEBUG=WARN torchrun --nproc_per_node 1 test/prototype/mx_formats/test_dtensor.py From c6fc48ba2a36d0879c46a03c99a37172f0efb7a2 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Tue, 24 Jun 2025 07:08:22 -0700 Subject: [PATCH 08/12] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_dtensor.py | 12 ++--- test/prototype/mx_formats/test_mx_linear.py | 4 +- test/prototype/mx_formats/test_mx_tensor.py | 2 +- torchao/prototype/mx_formats/kernels.py | 9 ++-- torchao/prototype/mx_formats/mx_tensor.py | 52 +++++++++++---------- torchao/testing/training/dtensor_utils.py | 11 +++-- 6 files changed, 48 insertions(+), 42 deletions(-) diff --git a/test/prototype/mx_formats/test_dtensor.py b/test/prototype/mx_formats/test_dtensor.py index bfc930c579..4aefb3874e 100644 --- a/test/prototype/mx_formats/test_dtensor.py +++ b/test/prototype/mx_formats/test_dtensor.py @@ -68,24 +68,22 @@ def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4): ) -def _test_mxfp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16): +def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16): config = MXLinearConfig.from_recipe_name("mxfp8_emulated") - # TODO(future PR): assert that the K dim must be divisible by block size, - # today this is silently incorrect if block_size is greater than K config.block_size = 16 _test_lowp_mlp_tensor_parallelism_base( mesh, config, size, compile=False, allgather_in_lowp=False ) - - # TODO(future PR): compile + _test_lowp_mlp_tensor_parallelism_base( + mesh, config, size, compile=True, allgather_in_lowp=False + ) if __name__ == "__main__": device_mesh = setup_distributed() tests = [ _test_dtensor_cast_to_mxfp8, - # TODO(next PR): enable this (current PR got too large, so splitting) - # _test_mxfp8_mlp_tensor_parallelism_eager, + _test_mxfp8_mlp_tensor_parallelism, ] for test in tqdm(tests, desc="Running tests"): diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index bfb6742d14..b48b21bbf9 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ 
b/test/prototype/mx_formats/test_mx_linear.py @@ -190,8 +190,8 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn): # TODO(future): enable compile support @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_activation_checkpointing(): - input_shape = (2, 4) - grad_shape = (2, 8) + input_shape = (16, 4) + grad_shape = (16, 8) elem_dtype = torch.float8_e4m3fn m = nn.Sequential( diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 6dfd33f9c7..f0124dd47b 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -72,7 +72,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_hello_world(elem_dtype): - data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16) + data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16) block_size = 4 _test_mx(data, elem_dtype, block_size) diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index f96e73a55a..af059c5970 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -1056,7 +1056,7 @@ def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: # effective mx block size since we're packing 2 fp4 into 1 uint8 packed_mx_block_size = 3 * mx_block_size // 4 - packed_shape = [uint8_data.shape[0], packed_mx_block_size] + packed_shape = [*uint8_data.shape[:-1], packed_mx_block_size] n_mx_blocks = uint8_data.numel() // mx_block_size grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) @@ -1337,7 +1337,10 @@ def triton_to_mxfp8_dim1( # Create scale tensors col_scale = torch.empty( - (n_cols * n_rows // inner_block_size, 1), dtype=torch.uint8, device=x.device + # (n_cols * n_rows // inner_block_size, 1), dtype=torch.uint8, device=x.device + (n_cols, n_rows // inner_block_size, 1), + dtype=torch.uint8, + device=x.device, ) # Calculate grid dimensions based on tile size @@ -1374,7 +1377,7 @@ def triton_to_mxfp8_dim1_reference( scale_e8m0_dim1, x_hp_d1_normalized = to_mx( x_hp_d1, torch.float8_e4m3fn, block_size ) - scale_e8m0_dim1 = scale_e8m0_dim1.unsqueeze(1).view(torch.float8_e8m0fnu) + scale_e8m0_dim1 = scale_e8m0_dim1.view(torch.float8_e8m0fnu) return ( x_hp_d1_normalized.t(), scale_e8m0_dim1, diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index ef9ae42fcd..1a7fca43de 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -25,7 +25,6 @@ from torchao.prototype.mx_formats.config import MXGemmKernelChoice from torchao.prototype.mx_formats.constants import ( - BF16_EXP_BIAS, BLOCK_SIZE_DEFAULT, DTYPE_FP6_E2M3, DTYPE_FP6_E3M2, @@ -62,7 +61,6 @@ # TODO(later): read from somewhere else? 
SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23 -EBITS_BF16, MBITS_BF16 = 8, 7 EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1 EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3 EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2 @@ -137,9 +135,7 @@ def _to_mx_rceil( ) # scale and saturated cast the data elements to max of target dtype - data_lp = torch.clamp( - data_hp * descale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos - ) + data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos) return exponent, data_lp @@ -160,22 +156,33 @@ def to_mx( torch.float, ), f"{data_hp.dtype} is not supported yet" # TODO(future PR): consider supporting padding - assert data_hp.numel() % block_size == 0, "unsupported" + assert data_hp.shape[-1] % block_size == 0, ( + f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}" + ) assert data_hp.is_contiguous(), "unsupported" assert elem_dtype in SUPPORTED_ELEM_DTYPES, "unsupported" - # calculate the scale in e8m0 format - orig_shape = data_hp.shape - # TODO(future PR): fix this line for TP, currently this reshape does not work - # for rank 3 tensor where dim1 is sharded - data_hp = data_hp.reshape(-1, block_size) + data_hp = data_hp.reshape( + *orig_shape[:-1], orig_shape[-1] // block_size, block_size + ) # find max value of the data # Note: this only implements the `minimally supported` version of # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf # section 6.3. - max_abs = torch.amax(torch.abs(data_hp), 1) + max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1) + + # We cast to float32 here because + # in the `max_abs_int32 = max_abs.view(hp_int_dtype)` line below, + # if tensor parallel is enabled then the resulting shape is 2x larger + # than it should be under some conditions, likely because of a bug in + # the `view` op with DTensor and target dtype int16. I reproduce in + # torchtitan but not in a unit test, so not enough info to file a good + # issue in pytorch/pytorch. For now, work around. In the future we should + # debug and fix this properly. 
+ data_hp = data_hp.to(torch.float32) + max_abs = max_abs.to(torch.float32) # Set X to be the largest power-of-two less than or equal to # max_abs(v), divided by the largest power of two representable @@ -206,17 +213,11 @@ def to_mx( if scaling_mode == ScaleCalculationMode.RCEIL: scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos) else: - if data_hp.dtype is torch.float32: - hp_int_dtype = torch.int32 - hp_mbits = MBITS_F32 - hp_ebits = EBITS_F32 - hp_exp_bias = F32_EXP_BIAS - else: - assert data_hp.dtype is torch.bfloat16 - hp_int_dtype = torch.int16 - hp_mbits = MBITS_BF16 - hp_ebits = EBITS_BF16 - hp_exp_bias = BF16_EXP_BIAS + assert data_hp.dtype is torch.float32 + hp_int_dtype = torch.int32 + hp_mbits = MBITS_F32 + hp_ebits = EBITS_F32 + hp_exp_bias = F32_EXP_BIAS # rounding before calculating the largest power of 2 # X = 2^(floor(log2(rounding(max_abs(v)))-max_exp)) @@ -285,7 +286,7 @@ def to_mx( scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL) # scale and saturated cast the data elements to max of target dtype - data_lp = data_hp / scale_fp32.unsqueeze(1) + data_lp = data_hp / scale_fp32 if ( elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2) @@ -511,7 +512,8 @@ def __new__( assert scale_e8m0_bits.dtype == torch.float8_e8m0fnu, ( f"scale_e8m0_bits.dtype must be `torch.float8_e8m0fnu`, got {scale_e8m0_bits.dtype}" ) - assert len(scale_e8m0_bits.shape) == 1, "unsupported" + # TODO new assertion + # assert len(scale_e8m0_bits.shape) == 1, "unsupported" assert data_bits.dtype in ( torch.float8_e4m3fn, torch.float8_e5m2, diff --git a/torchao/testing/training/dtensor_utils.py b/torchao/testing/training/dtensor_utils.py index 815ee20969..7ebf67d53c 100644 --- a/torchao/testing/training/dtensor_utils.py +++ b/torchao/testing/training/dtensor_utils.py @@ -152,15 +152,18 @@ def _test_lowp_mlp_tensor_parallelism_base( sp_model2 = torch.compile(sp_model2) x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False) + go_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False) x_fp32_tp_input = x_fp32.clone() + go_fp32_tp = go_fp32.clone() x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)]) + go_fp32_sp = distribute_tensor(go_fp32.clone(), mesh, [Shard(0)]) tp_out = tp_model(x_fp32_tp_input) - tp_out.sum().backward() + tp_out.backward(go_fp32_tp) sp_out = sp_model(x_fp32_sp_input) - sp_out.sum().backward() + sp_out.backward(go_fp32_sp) global_out = toy_model_fp8(x_fp32) - global_out.sum().backward() + global_out.backward(go_fp32) torch.testing.assert_close(tp_out, global_out) torch.testing.assert_close(sp_out.full_tensor(), global_out) torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad) @@ -169,7 +172,7 @@ def _test_lowp_mlp_tensor_parallelism_base( ) sp_out2 = sp_model2(x_fp32_sp_input) - sp_out2.sum().backward() + sp_out2.backward(go_fp32_sp) torch.testing.assert_close(sp_out2.full_tensor(), global_out) torch.testing.assert_close( tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad From 4cc1531ff444cf60543832433f67973f3c62da29 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Tue, 24 Jun 2025 07:12:53 -0700 Subject: [PATCH 09/12] Update [ghstack-poisoned] --- torchao/prototype/mx_formats/kernels.py | 1 - torchao/prototype/mx_formats/mx_tensor.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index af059c5970..72cbba1802 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ 
b/torchao/prototype/mx_formats/kernels.py @@ -1337,7 +1337,6 @@ def triton_to_mxfp8_dim1( # Create scale tensors col_scale = torch.empty( - # (n_cols * n_rows // inner_block_size, 1), dtype=torch.uint8, device=x.device (n_cols, n_rows // inner_block_size, 1), dtype=torch.uint8, device=x.device, diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 1a7fca43de..e98878af77 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -512,8 +512,6 @@ def __new__( assert scale_e8m0_bits.dtype == torch.float8_e8m0fnu, ( f"scale_e8m0_bits.dtype must be `torch.float8_e8m0fnu`, got {scale_e8m0_bits.dtype}" ) - # TODO new assertion - # assert len(scale_e8m0_bits.shape) == 1, "unsupported" assert data_bits.dtype in ( torch.float8_e4m3fn, torch.float8_e5m2, From aabeb61b9e41abba6e62940ea9c840a88c5376c5 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Tue, 24 Jun 2025 07:31:42 -0700 Subject: [PATCH 10/12] Update [ghstack-poisoned] --- .../mx_formats/{test_dtensor.py => test_mx_dtensor.py} | 0 .../mx_formats/{test_dtensor.sh => test_mx_dtensor.sh} | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename test/prototype/mx_formats/{test_dtensor.py => test_mx_dtensor.py} (100%) rename test/prototype/mx_formats/{test_dtensor.sh => test_mx_dtensor.sh} (95%) diff --git a/test/prototype/mx_formats/test_dtensor.py b/test/prototype/mx_formats/test_mx_dtensor.py similarity index 100% rename from test/prototype/mx_formats/test_dtensor.py rename to test/prototype/mx_formats/test_mx_dtensor.py diff --git a/test/prototype/mx_formats/test_dtensor.sh b/test/prototype/mx_formats/test_mx_dtensor.sh similarity index 95% rename from test/prototype/mx_formats/test_dtensor.sh rename to test/prototype/mx_formats/test_mx_dtensor.sh index 03531f6059..abf9424e3c 100755 --- a/test/prototype/mx_formats/test_dtensor.sh +++ b/test/prototype/mx_formats/test_mx_dtensor.sh @@ -14,4 +14,4 @@ if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; fi # integration tests for TP/SP -NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/prototype/mx_formats/test_dtensor.py +NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/prototype/mx_formats/test_mx_dtensor.py From b2518291d49aca413782b764d191968ffd460a4c Mon Sep 17 00:00:00 2001 From: vasiliy Date: Tue, 24 Jun 2025 12:18:03 -0700 Subject: [PATCH 11/12] Update [ghstack-poisoned] --- test/float8/test_dtensor.py | 4 ++-- test/float8/test_fsdp2_tp.py | 6 +++--- test/prototype/mx_formats/test_mx_dtensor.py | 13 +++++++------ torchao/prototype/mx_formats/kernels.py | 7 ++++++- torchao/testing/training/dtensor_utils.py | 16 ++++++++-------- 5 files changed, 26 insertions(+), 20 deletions(-) diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index 5509eb1cc2..2255d25a6b 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -183,7 +183,7 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): loss.backward() -def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16): +def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=32): tensorwise_config = Float8LinearConfig(emulate=True) _test_lowp_mlp_tensor_parallelism_base( mesh, tensorwise_config, size, compile=False, allgather_in_lowp=True @@ -198,7 +198,7 @@ def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16): ) -def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): +def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, 
size=32): tensorwise_config = Float8LinearConfig(emulate=True) _test_lowp_mlp_tensor_parallelism_base( mesh, tensorwise_config, size, compile=True, allgather_in_lowp=True diff --git a/test/float8/test_fsdp2_tp.py b/test/float8/test_fsdp2_tp.py index 93c7735149..f8449b2474 100644 --- a/test/float8/test_fsdp2_tp.py +++ b/test/float8/test_fsdp2_tp.py @@ -61,7 +61,7 @@ def _test_fp8_mlp_tensor_parallelism_base( enable_fsdp_float8_all_gather=True, ) - toy_model = ToyModel().to(device) + toy_model = ToyModel(size).to(device) tp_model = copy.deepcopy(toy_model) tp_model = convert_to_float8_training(tp_model, config=config) @@ -94,11 +94,11 @@ def _test_fp8_mlp_tensor_parallelism_base( # TODO(future PR): test numerics, and add more cases -def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16): +def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=32): _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False) -def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): +def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=32): _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True) diff --git a/test/prototype/mx_formats/test_mx_dtensor.py b/test/prototype/mx_formats/test_mx_dtensor.py index 4aefb3874e..5669639477 100644 --- a/test/prototype/mx_formats/test_mx_dtensor.py +++ b/test/prototype/mx_formats/test_mx_dtensor.py @@ -68,21 +68,22 @@ def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4): ) -def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=16): +def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=128): config = MXLinearConfig.from_recipe_name("mxfp8_emulated") - config.block_size = 16 + config.block_size = 32 + config.use_fp8_dim1_cast_triton_kernel = True _test_lowp_mlp_tensor_parallelism_base( mesh, config, size, compile=False, allgather_in_lowp=False ) - _test_lowp_mlp_tensor_parallelism_base( - mesh, config, size, compile=True, allgather_in_lowp=False - ) + # _test_lowp_mlp_tensor_parallelism_base( + # mesh, config, size, compile=True, allgather_in_lowp=False + # ) if __name__ == "__main__": device_mesh = setup_distributed() tests = [ - _test_dtensor_cast_to_mxfp8, + # _test_dtensor_cast_to_mxfp8, _test_mxfp8_mlp_tensor_parallelism, ] diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index 72cbba1802..f6957e3db9 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -1315,7 +1315,8 @@ def triton_to_mxfp8_dim1( * `col_scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim1 """ assert x.is_contiguous(), "`x` must be contiguous" - assert x.dtype == torch.bfloat16 + # TODO(before land): maybe gate by FakeTensor below? 
+ # assert x.dtype == torch.bfloat16 assert inner_block_size <= 32 # Get tensor shape @@ -1362,6 +1363,10 @@ def triton_to_mxfp8_dim1( output_col_major.t(), col_scale.view(torch.float8_e8m0fnu), ) + + print('ASDFASDFASDF') + from torchao import triton_to_mxfp8_dim1 + print(triton_to_mxfp8_dim1) def triton_to_mxfp8_dim1_reference( x_hp: torch.Tensor, block_size diff --git a/torchao/testing/training/dtensor_utils.py b/torchao/testing/training/dtensor_utils.py index 7ebf67d53c..67325b22e2 100644 --- a/torchao/testing/training/dtensor_utils.py +++ b/torchao/testing/training/dtensor_utils.py @@ -32,11 +32,11 @@ class FeedForward(nn.Module): """MLP based model""" - def __init__(self): + def __init__(self, size): super(FeedForward, self).__init__() - self.w1 = nn.Linear(16, 32, bias=False) - self.w2 = nn.Linear(16, 32, bias=False) - self.out_proj = nn.Linear(32, 16, bias=False) + self.w1 = nn.Linear(size, size * 2, bias=False) + self.w2 = nn.Linear(size, size * 2, bias=False) + self.out_proj = nn.Linear(size * 2, size, bias=False) def forward(self, x): x = F.silu(self.w1(x)) * self.w2(x) @@ -45,9 +45,9 @@ def forward(self, x): class ToyModel(nn.Module): - def __init__(self): + def __init__(self, size): super(ToyModel, self).__init__() - self.ffn = FeedForward() + self.ffn = FeedForward(size) def forward(self, x): return self.ffn(x) @@ -56,7 +56,7 @@ def forward(self, x): def _test_lowp_mlp_tensor_parallelism_base( mesh: DeviceMesh, config: Union[Float8LinearConfig, MXLinearConfig], - size=16, + size=32, compile: bool = False, allgather_in_lowp: bool = False, ): @@ -67,7 +67,7 @@ def _test_lowp_mlp_tensor_parallelism_base( if isinstance(config, MXLinearConfig): convert_model_func = quantize_ - toy_model = ToyModel().to(device) + toy_model = ToyModel(size).to(device) toy_model_fp8 = copy.deepcopy(toy_model) convert_model_func(toy_model_fp8, config=config) From dc0803c309bc3a903c8df8ca75e268ab3221f194 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 25 Jun 2025 07:38:40 -0700 Subject: [PATCH 12/12] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_dtensor.py | 16 +++- torchao/prototype/mx_formats/kernels.py | 19 ++++- torchao/prototype/mx_formats/mx_linear.py | 80 ++++++++++++-------- torchao/testing/training/dtensor_utils.py | 4 +- 4 files changed, 81 insertions(+), 38 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_dtensor.py b/test/prototype/mx_formats/test_mx_dtensor.py index 5669639477..4f5cce1a2a 100644 --- a/test/prototype/mx_formats/test_mx_dtensor.py +++ b/test/prototype/mx_formats/test_mx_dtensor.py @@ -69,12 +69,25 @@ def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4): def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=128): + config = MXLinearConfig.from_recipe_name("mxfp8_emulated") + config.block_size = 32 + _test_lowp_mlp_tensor_parallelism_base( + mesh, config, size, compile=False, allgather_in_lowp=False + ) + _test_lowp_mlp_tensor_parallelism_base( + mesh, config, size, compile=True, allgather_in_lowp=False + ) + + +def _test_mxfp8_mlp_tensor_parallelism_dim1_triton(mesh: DeviceMesh, size=128): config = MXLinearConfig.from_recipe_name("mxfp8_emulated") config.block_size = 32 config.use_fp8_dim1_cast_triton_kernel = True _test_lowp_mlp_tensor_parallelism_base( mesh, config, size, compile=False, allgather_in_lowp=False ) + # TODO(future PR): enable compile here, currently seeing + # https://www.internalfb.com/phabricator/paste/view/P1851219639 # _test_lowp_mlp_tensor_parallelism_base( # mesh, config, size, compile=True, 
allgather_in_lowp=False # ) @@ -83,8 +96,9 @@ def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=128): if __name__ == "__main__": device_mesh = setup_distributed() tests = [ - # _test_dtensor_cast_to_mxfp8, + _test_dtensor_cast_to_mxfp8, _test_mxfp8_mlp_tensor_parallelism, + _test_mxfp8_mlp_tensor_parallelism_dim1_triton, ] for test in tqdm(tests, desc="Running tests"): diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index f6957e3db9..0265c27a50 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -1363,10 +1363,21 @@ def triton_to_mxfp8_dim1( output_col_major.t(), col_scale.view(torch.float8_e8m0fnu), ) - - print('ASDFASDFASDF') - from torchao import triton_to_mxfp8_dim1 - print(triton_to_mxfp8_dim1) + + # print(torch.ops.torchao.triton_to_mxfp8_dim1.default) + + from torch.distributed.tensor import Replicate, Shard + from torch.distributed.tensor.experimental import register_sharding + + @register_sharding(torch.ops.torchao.triton_to_mxfp8_dim1.default) + def custom_triton_to_mxfp8_dim1_sharding(x, inner_block_size=32): + replicate = ([Replicate(), Replicate()], [Replicate(), None]) + # Note that the data is returned transposed, which is why + # we flip the sharding dim below + shard_dim0 = ([Shard(1), Shard(1)], [Shard(0), None]) + shard_dim1 = ([Shard(0), Shard(0)], [Shard(1), None]) + acceptable_shardings = [replicate, shard_dim0, shard_dim1] + return acceptable_shardings def triton_to_mxfp8_dim1_reference( x_hp: torch.Tensor, block_size diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index 4db029480f..4d2744fd7e 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -12,6 +12,7 @@ import torch import torch.nn.functional as F +from torch.distributed._tensor import DTensor from torchao.prototype.mx_formats.config import ( MXGemmKernelChoice, @@ -25,6 +26,46 @@ ) +def _triton_to_mxfp8_dim1_wrapper( + a, block_size, elem_dtype, hp_dtype, gemm_kernel_choice +): + a_data, a_scale = triton_to_mxfp8_dim1(a, block_size) + if isinstance(a_data, DTensor): + assert isinstance(a_scale, DTensor) + a_data_local = a_data.to_local() + a_scale_local = a_scale.to_local() + inner = MXTensor( + a_scale_local, + a_data_local.t(), + elem_dtype, + block_size, + hp_dtype, + False, + gemm_kernel_choice, + False, + ) + mx_tensor = DTensor.from_local( + inner, + a_data.device_mesh, + a_data.placements, + run_check=False, + shape=a_data.t().size(), + stride=a_data.t().stride(), + ) + else: + mx_tensor = MXTensor( + a_scale, + a_data.t(), + elem_dtype, + block_size, + hp_dtype, + False, + gemm_kernel_choice, + False, + ) + return mx_tensor + + @torch._dynamo.allow_in_graph class mx_mm(torch.autograd.Function): # There are three gemms in a forward + backward of a Linear layer: @@ -95,20 +136,9 @@ def backward(ctx, grad_output_hp: torch.Tensor): ) if use_fp8_dim1_cast_triton_kernel: - weight_mx_dim1_data, weight_mx_dim1_scale = triton_to_mxfp8_dim1( - weight_hp, block_size + weight_mx_dim1 = _triton_to_mxfp8_dim1_wrapper( + weight_hp, block_size, w_elem_dtype, weight_hp.dtype, gemm_kernel_choice ) - weight_mx_dim1 = MXTensor( - weight_mx_dim1_scale.reshape(-1), - weight_mx_dim1_data.t(), - w_elem_dtype, - block_size, - weight_hp.dtype, - False, - gemm_kernel_choice, - False, - ) - else: weight_hp_t_c = weight_hp.t().contiguous() weight_mx_dim1 = MXTensor.to_mx( @@ -124,18 +154,12 @@ def backward(ctx, 
grad_output_hp: torch.Tensor): # input_t @ grad_output = grad_weight if use_fp8_dim1_cast_triton_kernel: - grad_output_mx_dim1_data, grad_output_mx_dim1_scale = triton_to_mxfp8_dim1( - grad_output_hp_r, block_size - ) - grad_output_mx_dim1 = MXTensor( - grad_output_mx_dim1_scale.reshape(-1), - grad_output_mx_dim1_data.t(), - grad_elem_dtype, + grad_output_mx_dim1 = _triton_to_mxfp8_dim1_wrapper( + grad_output_hp_r, block_size, + grad_elem_dtype, grad_output_hp_r.dtype, - False, gemm_kernel_choice, - False, ) else: grad_output_mx_dim1 = MXTensor.to_mx( @@ -146,18 +170,12 @@ def backward(ctx, grad_output_hp: torch.Tensor): ) if use_fp8_dim1_cast_triton_kernel: - input_t_mx_dim0_tmp_data, input_t_mx_dim0_tmp_scale = triton_to_mxfp8_dim1( - input_hp_r, block_size - ) - input_t_mx_dim0_tmp = MXTensor( - input_t_mx_dim0_tmp_scale.reshape(-1), - input_t_mx_dim0_tmp_data.t(), - in_elem_dtype, + input_t_mx_dim0_tmp = _triton_to_mxfp8_dim1_wrapper( + input_hp_r, block_size, + in_elem_dtype, input_hp_r.dtype, - False, gemm_kernel_choice, - False, ) input_t_mx_dim0 = input_t_mx_dim0_tmp.t() else: diff --git a/torchao/testing/training/dtensor_utils.py b/torchao/testing/training/dtensor_utils.py index 67325b22e2..1e5490e8e1 100644 --- a/torchao/testing/training/dtensor_utils.py +++ b/torchao/testing/training/dtensor_utils.py @@ -151,8 +151,8 @@ def _test_lowp_mlp_tensor_parallelism_base( sp_model = torch.compile(sp_model) sp_model2 = torch.compile(sp_model2) - x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False) - go_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False) + x_fp32 = torch.rand(1, size * 2, size, device=device, requires_grad=False) + go_fp32 = torch.rand(1, size * 2, size, device=device, requires_grad=False) x_fp32_tp_input = x_fp32.clone() go_fp32_tp = go_fp32.clone() x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
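
For reference, the per-block power-of-two scale computation that these patches move `to_mx` onto can be sketched in isolation as follows. This is a minimal sketch, assuming a float32 input and floor (power-of-two) scaling only; the helper name `blockwise_pow2_scale_sketch` is hypothetical, and it omits the rounding modes, the clamp to `F32_MIN_NORMAL`, and the target-dtype exponent adjustments that the real `to_mx` performs. The constants and the `torch.bitwise_right_shift` / `torch.bitwise_left_shift` workaround (instead of `>>` / `<<`, see https://github.com/pytorch/pytorch/issues/156533) mirror the patch.

import torch

MBITS_F32 = 23      # mantissa bits of float32
F32_EXP_BIAS = 127  # exponent bias of float32

def blockwise_pow2_scale_sketch(data_hp: torch.Tensor, block_size: int) -> torch.Tensor:
    # Reshape so the scaling block is the last dimension; leading dimensions
    # (and any DTensor sharding on them) are left untouched.
    orig_shape = data_hp.shape
    assert orig_shape[-1] % block_size == 0
    data_hp = data_hp.reshape(
        *orig_shape[:-1], orig_shape[-1] // block_size, block_size
    ).to(torch.float32)

    # Per-block amax, keeping the reduced dim so later ops broadcast.
    max_abs = torch.amax(torch.abs(data_hp), dim=-1, keepdim=True)

    # Extract the biased float32 exponent by reinterpreting the bit pattern
    # as int32; bitwise_right_shift is used instead of `>>` so this also
    # behaves correctly on DTensor inputs.
    max_abs_int32 = max_abs.view(torch.int32)
    extracted_pow2 = (
        torch.bitwise_right_shift(max_abs_int32, MBITS_F32) & 0b11111111
    ) - F32_EXP_BIAS

    # Rebuild a float32 power-of-two scale by shifting the biased exponent
    # back into the exponent field (again avoiding `<<`).
    scale_e8m0_biased = (extracted_pow2 + F32_EXP_BIAS).to(torch.int32)
    scale_fp32 = torch.bitwise_left_shift(scale_e8m0_biased, MBITS_F32).view(
        torch.float32
    )
    return scale_fp32  # shape: (*leading_dims, n_blocks, 1)

# quick check: one scale per block of 32 elements along the last dim
x = torch.randn(8, 64)
print(blockwise_pow2_scale_sketch(x, block_size=32).shape)  # torch.Size([8, 2, 1])

Keeping the block on the last dimension is what makes the cast composable with tensor parallelism: a DTensor sharded on a leading dimension keeps the same placements through the reshape, the amax, and the elementwise scaling, which is why the old `reshape(-1, block_size)` is dropped in these patches.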