
[moe training] Cast to mixed precision policy param dtype in fsdp_pre_all_gather hook #2455

Merged · 5 commits · Jul 2, 2025

Viewing changes from 2 commits
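
Summary (paraphrased from the PR title; the diff below is a snapshot of the first two commits, so the final placement of the cast may differ): ScaledGroupedMMTensor implements the FSDP2 tensor-subclass all-gather extension hooks, and this change threads the MixedPrecisionPolicy into fsdp_pre_all_gather so the sharded parameter data can be cast to mp_policy.param_dtype. A hedged sketch of what that cast could look like, mirroring the hook signature in the diff (not necessarily the PR's final code):

```python
import torch
from torch import nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import MixedPrecisionPolicy


# Sketch only: an fsdp_pre_all_gather hook that casts the local shard to the
# mixed precision policy's param_dtype, which is what the PR title describes.
def fsdp_pre_all_gather(
    self,
    mesh: DeviceMesh,
    outer_size: torch.Size,
    outer_stride: tuple[int, ...],
    module: nn.Module,
    mp_policy: MixedPrecisionPolicy,
):
    # Cast the sharded data (e.g. fp32) to the policy dtype (e.g. bf16) so the
    # all-gathered parameter already carries the compute dtype.
    all_gather_inputs = (self._data.to(mp_policy.param_dtype),)
    all_gather_metadata = ()
    return all_gather_inputs, all_gather_metadata
```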
5 changes: 4 additions & 1 deletion torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Optional

import torch
@@ -18,6 +19,8 @@
    _is_column_major,
)

logger: logging.Logger = logging.getLogger(__name__)


def _scaled_grouped_mm(
    A: torch.Tensor,
@@ -36,8 +39,8 @@ def _scaled_grouped_mm(
            and in column-major memory layout.
        offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
        out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
        use_triton_for_per_group_scales (bool): Whether to use custom triton kernels to compute per-group scales. Default is True.
    """
    logger.debug("Using scaled_grouped_mm")
    return _Float8GroupedMM.apply(
        A,
        B_t,
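
The changes in scaled_grouped_mm.py above add a module-level logger and a debug message in _scaled_grouped_mm (and drop a stale docstring entry). For anyone reviewing locally, the new message can be surfaced with standard library logging configuration; this snippet is illustrative and not part of the PR:

```python
import logging

# Route log records to stderr, then opt the moe_training modules into DEBUG
# so the new "Using scaled_grouped_mm" message shows up during a run.
logging.basicConfig(level=logging.INFO)
logging.getLogger("torchao.prototype.moe_training").setLevel(logging.DEBUG)
```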
45 changes: 29 additions & 16 deletions torchao/prototype/moe_training/tensor.py
@@ -9,13 +9,15 @@

import torch
import torch.utils._pytree as pytree
from torch import nn
from torch._prims_common import suggest_memory_format
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import MixedPrecisionPolicy

from torchao.prototype.moe_training import _scaled_grouped_mm

logger: logging.Logger = logging.getLogger(__name__)


_ops_to_preserve_subclass = {
    torch.ops.aten.empty_like.default,
    torch.ops.aten.new_zeros.default,
@@ -64,12 +66,14 @@ def __init__(
        tensor: torch.Tensor,
        dtype: torch.dtype,
    ):
        self._data = tensor
        self._data = tensor.to(dtype)
        self._dtype = dtype

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs={}):
        logger.info(f"{func.__name__}, args: {args}, kwargs: {kwargs}")
        logger.debug(
            f"ScaledGroupedMMTensor func: {func.__name__}, args: {args}, kwargs: {kwargs}"
        )
        # override the grouped mm op to use the differentiable _scaled_grouped_mm
        if func.__name__ == cls.grouped_mm_func_name:
            # Use torchao scaled grouped mm with dynamic quant for
@@ -100,17 +104,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
        if func == torch.ops.aten.detach.default:
            return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)

        # unwrap args and kwargs
        dtype: Optional[torch.dtype] = None

        def unwrap(t):
            nonlocal dtype
            if dtype is None:
                dtype = t._dtype
            else:
                assert t._dtype == dtype
            return t._data

        # unwrap args/kwargs
        unwrap = lambda x: x._data if isinstance(x, ScaledGroupedMMTensor) else x
        args, kwargs = pytree.tree_map_only(
            ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
        )
@@ -125,7 +120,7 @@ def unwrap(t):
        # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
        return pytree.tree_map_only(
            torch.Tensor,
            lambda x: ScaledGroupedMMTensor(x, dtype),
            lambda x: ScaledGroupedMMTensor(x, x.dtype),
            out,
        )
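
An aside on the simplified unwrapping above: pytree.tree_map_only already restricts the mapping to leaves of the given type, so the isinstance check inside the new lambda is defensive rather than required. A small self-contained illustration (Wrapper is a hypothetical stand-in, not from the PR):

```python
import torch
import torch.utils._pytree as pytree


# Hypothetical wrapper standing in for ScaledGroupedMMTensor, just to show
# how tree_map_only touches only the wrapped leaves of a nested structure.
class Wrapper:
    def __init__(self, data: torch.Tensor):
        self._data = data


unwrap = lambda x: x._data if isinstance(x, Wrapper) else x
args = (Wrapper(torch.ones(2)), 3, [Wrapper(torch.zeros(2)), "untouched"])

# Only the Wrapper leaves are mapped; the int and str pass through unchanged.
unwrapped = pytree.tree_map_only(Wrapper, unwrap, args)
print(unwrapped)  # (tensor([1., 1.]), 3, [tensor([0., 0.]), 'untouched'])
```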

@@ -142,9 +137,20 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
            flatten_spec["_dtype"],
        )

    def fsdp_pre_all_gather(self, mesh):
    # fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81
    def fsdp_pre_all_gather(
        self,
        mesh: DeviceMesh,
        outer_size: torch.Size,
        outer_stride: tuple[int, ...],
        module: nn.Module,
        mp_policy: MixedPrecisionPolicy,
    ):
        all_gather_inputs = (self._data,)
        all_gather_metadata = ()
        logger.debug(
            f"ScaledGroupedMMTensor fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, param_dtype: {mp_policy.param_dtype}"
        )
        return all_gather_inputs, all_gather_metadata

    def fsdp_post_all_gather(
@@ -156,6 +162,13 @@ def fsdp_post_all_gather(
        out: Optional[torch.Tensor] = None,
    ):
        (data,) = all_gather_outputs
        logger.debug(
            f"ScaledGroupedMMTensor fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}"
        )

        if out is not None:
            return

        output = ScaledGroupedMMTensor(data, param_dtype)
        inner_tensors = (data,)
        return output, inner_tensors
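
For completeness, the caller side that exercises these hooks under FSDP2 looks roughly like the sketch below. This is not part of the PR: it assumes torch.distributed is already initialized (e.g. via torchrun), a CUDA device is available, and the relevant parameters were converted to ScaledGroupedMMTensor beforehand by the moe_training conversion path.

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

# Placeholder model; in the real flow this would be an MoE module whose expert
# weights were already wrapped in ScaledGroupedMMTensor.
model = nn.Sequential(nn.Linear(128, 256), nn.Linear(256, 128)).cuda()

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,  # dtype the all-gathered params should have
    reduce_dtype=torch.float32,  # dtype used for gradient reduction
)
fully_shard(model, mp_policy=mp_policy)

# On forward, FSDP2 calls fsdp_pre_all_gather(...) on each sharded subclass
# parameter with this mp_policy, then fsdp_post_all_gather(...) to re-wrap
# the gathered data in the subclass.
out = model(torch.randn(4, 128, device="cuda"))
```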