Skip to content

Commit 5360aad

Browse files
debug
1 parent 41a7890 commit 5360aad

File tree

2 files changed

+6
-10
lines changed

2 files changed

+6
-10
lines changed

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def _scaled_grouped_mm(
4040
offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
4141
out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
4242
"""
43-
logger.debug("Using scaled_grouped_mm")
43+
logger.info("Using scaled_grouped_mm")
4444
return _Float8GroupedMM.apply(
4545
A,
4646
B_t,

torchao/prototype/moe_training/tensor.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __new__(
4848
tensor: torch.Tensor,
4949
dtype: torch.dtype,
5050
):
51+
logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
5152
return torch.Tensor._make_wrapper_subclass(
5253
cls,
5354
tensor.size(),
@@ -66,14 +67,13 @@ def __init__(
6667
tensor: torch.Tensor,
6768
dtype: torch.dtype,
6869
):
70+
logger.info(f"ScaledGroupedMMTensor __init__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
6971
self._data = tensor.to(dtype)
7072
self._dtype = dtype
7173

7274
@classmethod
7375
def __torch_function__(cls, func, types, args, kwargs={}):
74-
logger.debug(
75-
f"ScaledGroupedMMTensor func: {func.__name__}, args: {args}, kwargs: {kwargs}"
76-
)
76+
logger.info(f"ScaledGroupedMMTensor func: {func.__name__}, args: {args}, kwargs: {kwargs}")
7777
# override the grouped mm op to use the differentiable _scaled_grouped_mm
7878
if func.__name__ == cls.grouped_mm_func_name:
7979
# Use torchao scaled grouped mm with dynamic quant for
@@ -148,9 +148,7 @@ def fsdp_pre_all_gather(
148148
):
149149
all_gather_inputs = (self._data,)
150150
all_gather_metadata = ()
151-
logger.debug(
152-
f"ScaledGroupedMMTensor fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, param_dtype: {mp_policy.param_dtype}"
153-
)
151+
#logger.info(f"ScaledGroupedMMTensor fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, self._data.shape={self._data.shape}, param_dtype: {mp_policy.param_dtype}")
154152
return all_gather_inputs, all_gather_metadata
155153

156154
def fsdp_post_all_gather(
@@ -162,9 +160,7 @@ def fsdp_post_all_gather(
162160
out: Optional[torch.Tensor] = None,
163161
):
164162
(data,) = all_gather_outputs
165-
logger.debug(
166-
f"ScaledGroupedMMTensor fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}"
167-
)
163+
#logger.info(f"ScaledGroupedMMTensor fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}")
168164

169165
if out is not None:
170166
return

0 commit comments

Comments (0)