From 02f061cf186a59a18223543316cd75642a9dcfc0 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Fri, 27 Jun 2025 12:28:25 -0700
Subject: [PATCH 1/5] set dtype

---
 .../moe_training/scaled_grouped_mm.py     |  6 +++-
 torchao/prototype/moe_training/tensor.py  | 31 ++++++++++++++++---
 2 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index 29adffd831..56eb7b0d64 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from typing import Optional
 
 import torch
@@ -19,6 +20,9 @@
 )
 
 
+logger: logging.Logger = logging.getLogger(__name__)
+
+
 def _scaled_grouped_mm(
     A: torch.Tensor,
     B_t: torch.Tensor,
@@ -36,8 +40,8 @@ def _scaled_grouped_mm(
         and in column-major memory layout.
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
-        use_triton_for_per_group_scales (bool): Whether to use custom triton kernels to compute per-group scales. Default is True.
     """
+    logger.info("Using scaled_grouped_mm")
     return _Float8GroupedMM.apply(
         A,
         B_t,
diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py
index b41527a4ae..437c7c7fd6 100644
--- a/torchao/prototype/moe_training/tensor.py
+++ b/torchao/prototype/moe_training/tensor.py
@@ -9,7 +9,11 @@
 
 import torch
 import torch.utils._pytree as pytree
+from torch import nn
 from torch._prims_common import suggest_memory_format
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp import MixedPrecisionPolicy
+from torch.autograd.grad_mode import _unsafe_preserve_version_counter
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 
@@ -69,7 +73,6 @@ def __init__(
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
-        logger.info(f"{func.__name__}, args: {args}, kwargs: {kwargs}")
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for
@@ -142,9 +145,18 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
             flatten_spec["_dtype"],
         )
 
-    def fsdp_pre_all_gather(self, mesh):
-        all_gather_inputs = (self._data,)
+    # fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81
+    def fsdp_pre_all_gather(
+        self,
+        mesh: DeviceMesh,
+        outer_size: torch.Size,
+        outer_stride: tuple[int, ...],
+        module: nn.Module,
+        mp_policy: MixedPrecisionPolicy,
+    ):
+        all_gather_inputs = (self._data.to(mp_policy.param_dtype),)
         all_gather_metadata = ()
+        logger.debug(f"fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, param_dtype: {mp_policy.param_dtype}")
         return all_gather_inputs, all_gather_metadata
 
     def fsdp_post_all_gather(
@@ -156,6 +168,15 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
-        output = ScaledGroupedMMTensor(data, param_dtype)
-        inner_tensors = (data,)
+        logger.debug(f"fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}")
+
+        if out is not None:
+            #with _unsafe_preserve_version_counter(out):
+            with torch.no_grad():
+                out.copy_(data)
+            return
+
+        upcast_data = data.to(param_dtype)
+        output = ScaledGroupedMMTensor(upcast_data, param_dtype)
+        inner_tensors = (upcast_data,)
         return output, inner_tensors

From 41a7890391ed2c180d9fd8994261f9fe90daf9c8 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Mon, 30 Jun 2025 19:21:57 -0700
Subject: [PATCH 2/5] fix dtype bug and add logging

---
 .../moe_training/scaled_grouped_mm.py     |  3 +-
 torchao/prototype/moe_training/tensor.py  | 40 ++++++++----------
 2 files changed, 17 insertions(+), 26 deletions(-)

diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index 56eb7b0d64..fcc18cff94 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -19,7 +19,6 @@
     _is_column_major,
 )
 
-
 logger: logging.Logger = logging.getLogger(__name__)
 
 
@@ -41,7 +40,7 @@ def _scaled_grouped_mm(
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
     """
-    logger.info("Using scaled_grouped_mm")
+    logger.debug("Using scaled_grouped_mm")
     return _Float8GroupedMM.apply(
         A,
         B_t,
diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py
index 437c7c7fd6..b0d6fbf8ca 100644
--- a/torchao/prototype/moe_training/tensor.py
+++ b/torchao/prototype/moe_training/tensor.py
@@ -13,13 +13,11 @@
 from torch._prims_common import suggest_memory_format
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import MixedPrecisionPolicy
-from torch.autograd.grad_mode import _unsafe_preserve_version_counter
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 
 logger: logging.Logger = logging.getLogger(__name__)
 
-
 _ops_to_preserve_subclass = {
     torch.ops.aten.empty_like.default,
     torch.ops.aten.new_zeros.default,
@@ -68,11 +66,14 @@ def __init__(
         tensor: torch.Tensor,
         dtype: torch.dtype,
     ):
-        self._data = tensor
+        self._data = tensor.to(dtype)
         self._dtype = dtype
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
+        logger.debug(
+            f"ScaledGroupedMMTensor func: {func.__name__}, args: {args}, kwargs: {kwargs}"
+        )
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for
@@ -103,17 +104,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
         if func == torch.ops.aten.detach.default:
             return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)
 
-        # unwrap args and kwargs
-        dtype: Optional[torch.dtype] = None
-
-        def unwrap(t):
-            nonlocal dtype
-            if dtype is None:
-                dtype = t._dtype
-            else:
-                assert t._dtype == dtype
-            return t._data
-
+        # unwrap args/kwargs
+        unwrap = lambda x: x._data if isinstance(x, ScaledGroupedMMTensor) else x
         args, kwargs = pytree.tree_map_only(
             ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
         )
@@ -128,7 +120,7 @@ def unwrap(t):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x, dtype),
+            lambda x: ScaledGroupedMMTensor(x, x.dtype),
             out,
         )
 
@@ -154,9 +146,11 @@ def fsdp_pre_all_gather(
         module: nn.Module,
         mp_policy: MixedPrecisionPolicy,
     ):
-        all_gather_inputs = (self._data.to(mp_policy.param_dtype),)
+        all_gather_inputs = (self._data,)
         all_gather_metadata = ()
-        logger.debug(f"fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, param_dtype: {mp_policy.param_dtype}")
+        logger.debug(
+            f"ScaledGroupedMMTensor fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, param_dtype: {mp_policy.param_dtype}"
+        )
         return all_gather_inputs, all_gather_metadata
 
     def fsdp_post_all_gather(
@@ -168,15 +162,13 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
-        logger.debug(f"fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}")
+        logger.debug(
+            f"ScaledGroupedMMTensor fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}"
+        )
 
         if out is not None:
-            #with _unsafe_preserve_version_counter(out):
-            with torch.no_grad():
-                out.copy_(data)
             return
 
-        upcast_data = data.to(param_dtype)
-        output = ScaledGroupedMMTensor(upcast_data, param_dtype)
-        inner_tensors = (upcast_data,)
+        output = ScaledGroupedMMTensor(data, param_dtype)
+        inner_tensors = (data,)
         return output, inner_tensors

From 5360aad605387b05b96bf21d44e8936a68a94033 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Tue, 1 Jul 2025 13:08:20 -0700
Subject: [PATCH 3/5] debug

---
 .../prototype/moe_training/scaled_grouped_mm.py |  2 +-
 torchao/prototype/moe_training/tensor.py        | 14 +++++---------
 2 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index fcc18cff94..7672451dc6 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -40,7 +40,7 @@ def _scaled_grouped_mm(
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
""" - logger.debug("Using scaled_grouped_mm") + logger.info("Using scaled_grouped_mm") return _Float8GroupedMM.apply( A, B_t, diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py index b0d6fbf8ca..d0521f89a0 100644 --- a/torchao/prototype/moe_training/tensor.py +++ b/torchao/prototype/moe_training/tensor.py @@ -48,6 +48,7 @@ def __new__( tensor: torch.Tensor, dtype: torch.dtype, ): + logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}") return torch.Tensor._make_wrapper_subclass( cls, tensor.size(), @@ -66,14 +67,13 @@ def __init__( tensor: torch.Tensor, dtype: torch.dtype, ): + logger.info(f"ScaledGroupedMMTensor __init__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}") self._data = tensor.to(dtype) self._dtype = dtype @classmethod def __torch_function__(cls, func, types, args, kwargs={}): - logger.debug( - f"ScaledGroupedMMTensor func: {func.__name__}, args: {args}, kwargs: {kwargs}" - ) + logger.info(f"ScaledGroupedMMTensor func: {func.__name__}, args: {args}, kwargs: {kwargs}") # override the grouped mm op to use the differentiable _scaled_grouped_mm if func.__name__ == cls.grouped_mm_func_name: # Use torchao scaled grouped mm with dynamic quant for @@ -148,9 +148,7 @@ def fsdp_pre_all_gather( ): all_gather_inputs = (self._data,) all_gather_metadata = () - logger.debug( - f"ScaledGroupedMMTensor fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, param_dtype: {mp_policy.param_dtype}" - ) + #logger.info(f"ScaledGroupedMMTensor fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, self._data.shape={self._data.shape}, param_dtype: {mp_policy.param_dtype}") return all_gather_inputs, all_gather_metadata def fsdp_post_all_gather( @@ -162,9 +160,7 @@ def fsdp_post_all_gather( out: Optional[torch.Tensor] = None, ): (data,) = all_gather_outputs - logger.debug( - f"ScaledGroupedMMTensor fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}" - ) + #logger.info(f"ScaledGroupedMMTensor fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}") if out is not None: return From 7fdba521a886c1a6078c32e3fc658654bdfd9bcc Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 1 Jul 2025 13:46:55 -0700 Subject: [PATCH 4/5] don't have dtype param --- .../moe_training/conversion_utils.py | 4 +-- torchao/prototype/moe_training/tensor.py | 27 +++++++------------ 2 files changed, 12 insertions(+), 19 deletions(-) diff --git a/torchao/prototype/moe_training/conversion_utils.py b/torchao/prototype/moe_training/conversion_utils.py index 72056e68b3..2da8186f2d 100644 --- a/torchao/prototype/moe_training/conversion_utils.py +++ b/torchao/prototype/moe_training/conversion_utils.py @@ -84,7 +84,7 @@ def _swap_params( f"Does not support a root nn.Parameter with children: {module}" ) if not isinstance(module.data, ScaledGroupedMMTensor): - new_data = ScaledGroupedMMTensor(module.data, module.data.dtype) + new_data = ScaledGroupedMMTensor(module.data) return nn.Parameter(new_data, requires_grad=module.requires_grad) return module @@ -110,7 +110,7 @@ def post_order_traversal( for param_name, param in module.named_parameters(recurse=False): if not isinstance(param.data, ScaledGroupedMMTensor): new_param = nn.Parameter( - ScaledGroupedMMTensor(param.data, param.data.dtype), + ScaledGroupedMMTensor(param.data), requires_grad=param.requires_grad, ) setattr(module, param_name, new_param) diff --git a/torchao/prototype/moe_training/tensor.py 
index d0521f89a0..0c5b7ace7b 100644
--- a/torchao/prototype/moe_training/tensor.py
+++ b/torchao/prototype/moe_training/tensor.py
@@ -46,16 +46,15 @@ class ScaledGroupedMMTensor(torch.Tensor):
     def __new__(
         cls,
         tensor: torch.Tensor,
-        dtype: torch.dtype,
     ):
-        logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
+        # logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
             strides=tensor.stride(),
             storage_offset=tensor.storage_offset(),
             memory_format=suggest_memory_format(tensor),
-            dtype=dtype,
+            dtype=tensor.dtype,
             layout=tensor.layout,
             device=tensor.device,
             pin_memory=tensor.is_pinned(),
@@ -65,15 +64,11 @@ def __new__(
     def __init__(
         self,
        tensor: torch.Tensor,
-        dtype: torch.dtype,
     ):
-        logger.info(f"ScaledGroupedMMTensor __init__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
-        self._data = tensor.to(dtype)
-        self._dtype = dtype
+        self._data = tensor
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
-        logger.info(f"ScaledGroupedMMTensor func: {func.__name__}, args: {args}, kwargs: {kwargs}")
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for
@@ -102,7 +97,7 @@ def __torch_function__(cls, func, types, args, kwargs={}):
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # detach is special case
         if func == torch.ops.aten.detach.default:
-            return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)
+            return ScaledGroupedMMTensor(args[0]._data)
 
         # unwrap args/kwargs
         unwrap = lambda x: x._data if isinstance(x, ScaledGroupedMMTensor) else x
@@ -120,21 +115,20 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x, x.dtype),
+            lambda x: ScaledGroupedMMTensor(x),
             out,
         )
 
     def __repr__(self):
-        return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"
+        return f"ScaledGroupedMMTensor(data={self._data})"
 
     def __tensor_flatten__(self):
-        return ["_data"], {"_dtype": self._dtype}
+        return ["_data"]
 
     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         return ScaledGroupedMMTensor(
             inner_tensors["_data"],
-            flatten_spec["_dtype"],
         )
 
     # fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81
@@ -146,9 +140,9 @@ def fsdp_pre_all_gather(
         module: nn.Module,
         mp_policy: MixedPrecisionPolicy,
     ):
-        all_gather_inputs = (self._data,)
+        # cast to mixed precision dtype prior to all-gather
+        all_gather_inputs = (self._data.to(mp_policy.param_dtype),)
         all_gather_metadata = ()
-        #logger.info(f"ScaledGroupedMMTensor fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, self._data.shape={self._data.shape}, param_dtype: {mp_policy.param_dtype}")
         return all_gather_inputs, all_gather_metadata
 
     def fsdp_post_all_gather(
@@ -160,11 +154,10 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
-        #logger.info(f"ScaledGroupedMMTensor fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}")
 
         if out is not None:
             return
 
-        output = ScaledGroupedMMTensor(data, param_dtype)
+        output = ScaledGroupedMMTensor(data)
         inner_tensors = (data,)
         return output, inner_tensors

From bb9626ebb97b6542e452580bdef74f688a2691f7 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Tue, 1 Jul 2025 14:24:14 -0700
Subject: [PATCH 5/5] handle out != None

---
 .../prototype/moe_training/scaled_grouped_mm.py |  2 +-
 torchao/prototype/moe_training/tensor.py        | 16 +++++++++++++++-
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index 7672451dc6..5a08074d5d 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -40,7 +40,7 @@ def _scaled_grouped_mm(
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
     """
-    logger.info("Using scaled_grouped_mm")
+    # logger.info("Using scaled_grouped_mm")
     return _Float8GroupedMM.apply(
         A,
         B_t,
diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py
index 0c5b7ace7b..ddcc84f515 100644
--- a/torchao/prototype/moe_training/tensor.py
+++ b/torchao/prototype/moe_training/tensor.py
@@ -47,7 +47,6 @@ def __new__(
         cls,
         tensor: torch.Tensor,
     ):
-        # logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
@@ -155,9 +154,24 @@ def fsdp_post_all_gather(
     ):
         (data,) = all_gather_outputs
 
+        # For training step 1+, out=unsharded param, so we need to copy data to `out`
+        # if `self._data` and `out` do not share the same storage.
+        # Otherwise, if they do share the same storage, we can just return directly.
         if out is not None:
+            assert isinstance(out, ScaledGroupedMMTensor), f"{type(out)}"
+            if data.dtype == param_dtype:
+                assert (
+                    data.untyped_storage().data_ptr()
+                    == out._data.untyped_storage().data_ptr()
+                )
+            else:
+                assert out._data.dtype == param_dtype, (
+                    f"{out._data.dtype} {param_dtype}"
+                )
+                out._data.copy_(data)
             return
 
+        # For training step 0, out=None, so we need to return a new ScaledGroupedMMTensor.
        output = ScaledGroupedMMTensor(data)
         inner_tensors = (data,)
         return output, inner_tensors
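
Notes (editor's sketches, illustrative only -- not part of the patch series above):

1) Calling the differentiable float8 grouped GEMM directly. The docstring in
_scaled_grouped_mm describes the expected layouts: A is 2D row-major, B_t is 3D
and column-major, and offs is an int32 tensor of group boundaries along dim0 of
A. The shapes and offsets below are hypothetical; the import path is the one
the patches themselves use. Requires a CUDA build with float8 support.

import torch

from torchao.prototype.moe_training import _scaled_grouped_mm

# 6 tokens routed to 3 experts as contiguous row-groups of [2, 1, 3] in A;
# offs holds the cumulative group boundaries as int32.
A = torch.randn(6, 512, dtype=torch.bfloat16, device="cuda", requires_grad=True)
# Build B contiguous as (experts, out_dim, in_dim), then transpose so B_t is
# (experts, in_dim, out_dim) with column-major memory layout, per the docstring.
B = torch.randn(3, 1024, 512, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B_t = B.transpose(-2, -1)
offs = torch.tensor([2, 3, 6], dtype=torch.int32, device="cuda")

out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)  # (6, 1024)
out.sum().backward()  # differentiable: _Float8GroupedMM defines fwd and bwd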
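2) The wrap/unwrap contract the later patches converge on. __torch_dispatch__
unwraps ScaledGroupedMMTensor arguments to their inner ._data tensor, runs the
op, and re-wraps outputs only for ops in _ops_to_preserve_subclass; detach is
special-cased. A small sketch using only behavior visible in the diffs:

import torch

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

t = ScaledGroupedMMTensor(torch.randn(4, 4))
d = t.detach()           # special case: stays a ScaledGroupedMMTensor
e = torch.empty_like(t)  # aten.empty_like is in _ops_to_preserve_subclass,
                         # so the output is re-wrapped via pytree.tree_map_only
assert isinstance(d, ScaledGroupedMMTensor)
assert isinstance(e, ScaledGroupedMMTensor)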
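3) How the FSDP2 hooks are exercised end to end. The sharded parameter holds
fp32 data; with MixedPrecisionPolicy(param_dtype=torch.bfloat16),
fsdp_pre_all_gather casts the shard to bf16 before communication, and
fsdp_post_all_gather either wraps the gathered buffer (step 0, out=None) or
copies into the reused unsharded buffer (step 1+). fully_shard and
MixedPrecisionPolicy are the public torch.distributed.fsdp APIs; the toy
module, shapes, and routing are made up, and torch._grouped_mm is the private
grouped-GEMM op that grouped_mm_func_name refers to (recent PyTorch nightlies
only). Assumes torch.distributed is already initialized, e.g. via torchrun.

import torch
from torch import nn
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor


class Experts(nn.Module):
    def __init__(self, num_experts: int, dim: int, hidden: int):
        super().__init__()
        self.w = nn.Parameter(torch.randn(num_experts, dim, hidden))

    def forward(self, x: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
        # Because self.w wraps a ScaledGroupedMMTensor, __torch_function__
        # reroutes this call to torchao's differentiable _scaled_grouped_mm.
        return torch._grouped_mm(x, self.w, offs=offs)


model = Experts(num_experts=8, dim=512, hidden=1024).cuda()
# Swap the raw parameter for the subclass, as _swap_params does after patch 4.
model.w = nn.Parameter(ScaledGroupedMMTensor(model.w.data), requires_grad=True)
fully_shard(model, mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16))

Dropping the dtype parameter in patch 4 means the wrapper always reports its
inner tensor's dtype, so the bf16/fp32 distinction lives entirely at the FSDP
all-gather boundary rather than inside the subclass.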