File tree Expand file tree Collapse file tree 1 file changed +11
-2
lines changed
torchao/prototype/moe_training Expand file tree Collapse file tree 1 file changed +11
-2
lines changed Original file line number Diff line number Diff line change @@ -47,7 +47,6 @@ def __new__(
47
47
cls ,
48
48
tensor : torch .Tensor ,
49
49
):
50
- # logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
51
50
return torch .Tensor ._make_wrapper_subclass (
52
51
cls ,
53
52
tensor .size (),
@@ -156,8 +155,18 @@ def fsdp_post_all_gather(
156
155
(data ,) = all_gather_outputs
157
156
158
157
if out is not None :
158
+ assert isinstance (out , ScaledGroupedMMTensor ), f"{ type (out )} "
159
+ if data .dtype == param_dtype :
160
+ assert (
161
+ data .untyped_storage ().data_ptr ()
162
+ == out ._data .untyped_storage ().data_ptr ()
163
+ )
164
+ else :
165
+ assert out ._data .dtype == param_dtype , (
166
+ f"{ out ._data .dtype } { param_dtype } "
167
+ )
168
+ out ._data .copy_ (data )
159
169
return
160
-
161
170
output = ScaledGroupedMMTensor (data )
162
171
inner_tensors = (data ,)
163
172
return output , inner_tensors
You can’t perform that action at this time.
0 commit comments