Skip to content

Commit 8df3fbb

Browse files
handle out != None
1 parent 7fdba52 commit 8df3fbb

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

torchao/prototype/moe_training/tensor.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,6 @@ def __new__(
4747
cls,
4848
tensor: torch.Tensor,
4949
):
50-
# logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
5150
return torch.Tensor._make_wrapper_subclass(
5251
cls,
5352
tensor.size(),
@@ -156,8 +155,18 @@ def fsdp_post_all_gather(
156155
(data,) = all_gather_outputs
157156

158157
if out is not None:
158+
assert isinstance(out, ScaledGroupedMMTensor), f"{type(out)}"
159+
if data.dtype == param_dtype:
160+
assert (
161+
data.untyped_storage().data_ptr()
162+
== out._data.untyped_storage().data_ptr()
163+
)
164+
else:
165+
assert out._data.dtype == param_dtype, (
166+
f"{out._data.dtype} {param_dtype}"
167+
)
168+
out._data.copy_(data)
159169
return
160-
161170
output = ScaledGroupedMMTensor(data)
162171
inner_tensors = (data,)
163172
return output, inner_tensors

0 commit comments

Comments
 (0)