Commit 02f061c

set dtype
1 parent ac14d92 commit 02f061c

2 files changed: +31 additions, -6 deletions

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 5 additions & 1 deletion
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from typing import Optional
 
 import torch
@@ -19,6 +20,9 @@
 )
 
 
+logger: logging.Logger = logging.getLogger(__name__)
+
+
 def _scaled_grouped_mm(
     A: torch.Tensor,
     B_t: torch.Tensor,
@@ -36,8 +40,8 @@ def _scaled_grouped_mm(
         and in column-major memory layout.
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
-        use_triton_for_per_group_scales (bool): Whether to use custom triton kernels to compute per-group scales. Default is True.
     """
+    logger.info("Using scaled_grouped_mm")
     return _Float8GroupedMM.apply(
         A,
         B_t,
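
For context, a minimal usage sketch of _scaled_grouped_mm (not part of this commit); the shapes, device, and offset values are illustrative assumptions based on the docstring above.

import torch

from torchao.prototype.moe_training import _scaled_grouped_mm

# Illustrative sizes: 4 experts, 128 routed tokens, dim 256 -> hidden 512.
num_experts, total_tokens, dim, hidden = 4, 128, 256, 512

# A: all routed tokens stacked along dim0, bf16 (assumes a float8-capable GPU).
A = torch.randn(total_tokens, dim, dtype=torch.bfloat16, device="cuda", requires_grad=True)

# B_t: per-expert weights, transposed so each expert's matrix is column-major.
B = torch.randn(num_experts, hidden, dim, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B_t = B.transpose(-2, -1)

# offs: int32 end offsets marking each expert's token group along dim0 of A.
offs = torch.tensor([32, 64, 96, 128], dtype=torch.int32, device="cuda")

# Expected result: a (total_tokens, hidden) bf16 tensor.
out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)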

torchao/prototype/moe_training/tensor.py

Lines changed: 26 additions & 5 deletions
@@ -9,7 +9,11 @@
 
 import torch
 import torch.utils._pytree as pytree
+from torch import nn
 from torch._prims_common import suggest_memory_format
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp import MixedPrecisionPolicy
+from torch.autograd.grad_mode import _unsafe_preserve_version_counter
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 
@@ -69,7 +73,6 @@ def __init__(
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
-        logger.info(f"{func.__name__}, args: {args}, kwargs: {kwargs}")
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for
@@ -142,9 +145,18 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
             flatten_spec["_dtype"],
         )
 
-    def fsdp_pre_all_gather(self, mesh):
-        all_gather_inputs = (self._data,)
+    # fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81
+    def fsdp_pre_all_gather(
+        self,
+        mesh: DeviceMesh,
+        outer_size: torch.Size,
+        outer_stride: tuple[int, ...],
+        module: nn.Module,
+        mp_policy: MixedPrecisionPolicy,
+    ):
+        all_gather_inputs = (self._data.to(mp_policy.param_dtype),)
         all_gather_metadata = ()
+        logger.debug(f"fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, param_dtype: {mp_policy.param_dtype}")
         return all_gather_inputs, all_gather_metadata
 
     def fsdp_post_all_gather(
@@ -156,6 +168,15 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
-        output = ScaledGroupedMMTensor(data, param_dtype)
-        inner_tensors = (data,)
+        logger.debug(f"fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}")
+
+        if out is not None:
+            #with _unsafe_preserve_version_counter(out):
+            with torch.no_grad():
+                out.copy_(data)
+            return
+
+        upcast_data = data.to(param_dtype)
+        output = ScaledGroupedMMTensor(upcast_data, param_dtype)
+        inner_tensors = (upcast_data,)
         return output, inner_tensors
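
For context, a hedged sketch of how these hooks are driven (not part of this commit): FSDP2 invokes fsdp_pre_all_gather with the MixedPrecisionPolicy, so the shard is cast to mp_policy.param_dtype before communication, and fsdp_post_all_gather either copies into the preallocated out buffer or wraps the gathered data. The stand-in module and dtypes below are assumptions; fully_shard and MixedPrecisionPolicy are imported from torch.distributed.fsdp as in the diff above (older PyTorch versions exposed them under torch.distributed._composable.fsdp).

import torch
from torch import nn
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

# Stand-in module; in real use its expert weights would wrap ScaledGroupedMMTensor,
# so FSDP2 would route all-gathers through the hooks defined above.
model = nn.Linear(256, 512)

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # dtype fsdp_pre_all_gather now casts shards to
    reduce_dtype=torch.float32,   # dtype used for gradient reduction
)

# Requires an initialized process group (e.g. launched via torchrun).
fully_shard(model, mp_policy=mp_policy)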
