@@ -48,6 +48,7 @@ def __new__(
         tensor: torch.Tensor,
         dtype: torch.dtype,
     ):
+        logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
@@ -66,14 +67,13 @@ def __init__(
         tensor: torch.Tensor,
         dtype: torch.dtype,
     ):
+        logger.info(f"ScaledGroupedMMTensor __init__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
         self._data = tensor.to(dtype)
         self._dtype = dtype
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
-        logger.debug(
-            f"ScaledGroupedMMTensor func: {func.__name__}, args: {args}, kwargs: {kwargs}"
-        )
+        logger.info(f"ScaledGroupedMMTensor func: {func.__name__}, args: {args}, kwargs: {kwargs}")
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for
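Note on the mechanism: `__torch_function__` is the standard extension point that lets a `torch.Tensor` wrapper subclass intercept chosen ops, which is how the hunk above reroutes the grouped mm to the differentiable `_scaled_grouped_mm`. A minimal, self-contained sketch of the same pattern, using `torch.mm` as a stand-in op (illustrative only, not the torchao implementation):

import torch

class InterceptingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.mm:
            # Swap in custom handling for one op, analogous to the
            # grouped_mm_func_name check in ScaledGroupedMMTensor.
            print(f"intercepted {func.__name__}")
        # Defer to the default dispatch, which also runs the op itself.
        return super().__torch_function__(func, types, args, kwargs)

# Usage: t = torch.randn(2, 2).as_subclass(InterceptingTensor); torch.mm(t, t)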
@@ -148,9 +148,7 @@ def fsdp_pre_all_gather(
     ):
         all_gather_inputs = (self._data,)
         all_gather_metadata = ()
-        logger.debug(
-            f"ScaledGroupedMMTensor fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, param_dtype: {mp_policy.param_dtype}"
-        )
+        # logger.info(f"ScaledGroupedMMTensor fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, self._data.shape={self._data.shape}, param_dtype: {mp_policy.param_dtype}")
         return all_gather_inputs, all_gather_metadata
 
     def fsdp_post_all_gather(
@@ -162,9 +160,7 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
-        logger.debug(
-            f"ScaledGroupedMMTensor fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}"
-        )
+        # logger.info(f"ScaledGroupedMMTensor fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}")
 
         if out is not None:
             return
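The net effect of the patch is to promote the subclass-construction and dispatch logs from debug to info level, while commenting out the two FSDP all-gather logs, which would otherwise fire on every collective. A minimal sketch of surfacing the new messages, assuming the module uses the stock `logging.getLogger(__name__)` pattern; the import path is hypothetical, adjust it to wherever `ScaledGroupedMMTensor` actually lives:

import logging
import torch

logging.basicConfig(level=logging.INFO)  # surface the new logger.info() messages

# Hypothetical import path, for illustration only.
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

# Constructing the subclass emits the __new__/__init__ messages added above.
w = ScaledGroupedMMTensor(torch.randn(8, 16), torch.bfloat16)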