@@ -9,7 +9,11 @@
 
 import torch
 import torch.utils._pytree as pytree
+from torch import nn
 from torch._prims_common import suggest_memory_format
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp import MixedPrecisionPolicy
+from torch.autograd.grad_mode import _unsafe_preserve_version_counter
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 
@@ -69,7 +73,6 @@ def __init__(
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
-        logger.info(f"{func.__name__}, args: {args}, kwargs: {kwargs}")
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
            # Use torchao scaled grouped mm with dynamic quant for
@@ -142,9 +145,18 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
             flatten_spec["_dtype"],
         )
 
-    def fsdp_pre_all_gather(self, mesh):
-        all_gather_inputs = (self._data,)
+    # fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81
+    def fsdp_pre_all_gather(
+        self,
+        mesh: DeviceMesh,
+        outer_size: torch.Size,
+        outer_stride: tuple[int, ...],
+        module: nn.Module,
+        mp_policy: MixedPrecisionPolicy,
+    ):
+        all_gather_inputs = (self._data.to(mp_policy.param_dtype),)
         all_gather_metadata = ()
+        logger.debug(f"fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, param_dtype: {mp_policy.param_dtype}")
         return all_gather_inputs, all_gather_metadata
 
     def fsdp_post_all_gather(
@@ -156,6 +168,15 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
-        output = ScaledGroupedMMTensor(data, param_dtype)
-        inner_tensors = (data,)
+        logger.debug(f"fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}")
+
+        if out is not None:
+            #with _unsafe_preserve_version_counter(out):
+            with torch.no_grad():
+                out.copy_(data)
+            return
+
+        upcast_data = data.to(param_dtype)
+        output = ScaledGroupedMMTensor(upcast_data, param_dtype)
+        inner_tensors = (upcast_data,)
         return output, inner_tensors
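
For reference, these hooks follow the FSDP2 tensor-subclass extension point: fully_shard calls fsdp_pre_all_gather on the sharded ScaledGroupedMMTensor (which now casts self._data to mp_policy.param_dtype) and fsdp_post_all_gather on the unsharded side, which either copies the gathered data into the preallocated out buffer or wraps it in a new ScaledGroupedMMTensor. A minimal usage sketch under assumptions not in this diff: build_moe_model and the model.layers loop are hypothetical placeholders, and torch.distributed is assumed to be initialized (e.g. via torchrun).

import torch
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

# Hypothetical: a MoE model whose expert weights have already been swapped
# to ScaledGroupedMMTensor by the torchao moe_training conversion step.
model = build_moe_model()

# fp32 sharded params, bf16 all-gathered params: fsdp_pre_all_gather casts
# the local shard to param_dtype before the all-gather, and
# fsdp_post_all_gather reconstructs the unsharded subclass tensor.
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
for layer in model.layers:  # hypothetical per-layer sharding loop
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)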