@@ -1,15 +1,19 @@
-import unittest
-import functools
 import copy
-import torch
-import torchao
-import os
+import functools
+import unittest

+import torch
+from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
 from torch.testing._internal import common_utils
-from torchao.dtypes import AffineQuantizedTensor
-from torchao.dtypes import to_affine_quantized_intx
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    with_comms,
+)
+
+import torchao
+from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx
+from torchao.quantization import int8_weight_only, quantize_
 from torchao.quantization.quant_primitives import MappingType
-from torchao.quantization import quantize_, int8_weight_only
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_6

 """
@@ -36,10 +40,9 @@ class MyTestCase(TorchAOBasicTestCase):
 unittest.main()
 """

+
 # copied from https://github.com/pytorch/pytorch/blob/941d094dd1b507dacf06ddc6ed3485a9537e09b7/test/inductor/test_torchinductor.py#L11389
-def copy_tests(
-    my_cls, other_cls, suffix, test_failures=None, xfail_prop=None
-):  # noqa: B902
+def copy_tests(my_cls, other_cls, suffix, test_failures=None, xfail_prop=None):  # noqa: B902
     for name, value in my_cls.__dict__.items():
         if name.startswith("test_"):
             # You cannot copy functions in Python, so we use closures here to
@@ -70,7 +73,6 @@ def new_test(self, value=value):
             setattr(other_cls, f"{name}_{suffix}", new_test)


-
 class TorchAOBasicTestCase(common_utils.TestCase):
     COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
     COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
@@ -90,17 +92,21 @@ def test_flatten_unflatten(self):
         hp_tensor = torch.randn(4, 128)
         lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
         tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
-        tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict}
+        tensor_data_dict = {
+            name: getattr(lp_tensor, name) for name in tensor_data_name_dict
+        }
         outer_size = lp_tensor.size()
         outer_stride = lp_tensor.stride()
-        reconstructed = self.TENSOR_SUBCLASS.__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride)
+        reconstructed = self.TENSOR_SUBCLASS.__tensor_unflatten__(
+            tensor_data_dict, tensor_attributes, outer_size, outer_stride
+        )
         self.assertEqual(lp_tensor.dequantize(), reconstructed.dequantize())

     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
     def test_hp_tensor_device_dtype(self, device, dtype):
         hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
-        lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
+        self.FACTORY_FN(hp_tensor, **self.kwargs)

     @common_utils.parametrize("device1", COMMON_DEVICES)
     @common_utils.parametrize("device2", COMMON_DEVICES)
@@ -141,7 +147,10 @@ def test_linear(self, device, dtype):
         hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype)
         hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor)
         lp_res = torch.nn.functional.linear(hp_act_tensor, lp_tensor)
-        self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)
+        self.assertGreater(
+            torchao.quantization.utils.compute_error(hp_res, lp_res),
+            self.LINEAR_MIN_SQNR,
+        )


 class TorchAOCompileTestCase(common_utils.TestCase):
@@ -165,6 +174,7 @@ class TorchAOCompileTestCase(common_utils.TestCase):
     def test_input_output_tensor_subclass(self, device, dtype):
         hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
         lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
+
         def f(tensor):
             return tensor

@@ -179,6 +189,7 @@ def f(tensor):
     def test_input_tensor_subclass(self, device, dtype):
         hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
         lp_tensor = self.FACTORY_FN(hp_tensor, **self.kwargs)
+
         def f(tensor):
             return tensor.dequantize()

@@ -192,6 +203,7 @@ def f(tensor):
     @common_utils.parametrize("dtype", COMMON_DTYPES)
     def test_output_tensor_subclass(self, device, dtype):
         hp_tensor = torch.randn(4, 128, device=device, dtype=dtype)
+
         def f(hp_tensor):
             return self.FACTORY_FN(hp_tensor, **self.kwargs)

@@ -201,7 +213,12 @@ def f(hp_tensor):
         self.assertTrue(isinstance(f(hp_tensor), self.TENSOR_SUBCLASS))
         # bfloat16 seems to result in much larger numerical differences
         if dtype != torch.bfloat16:
-            self.assertGreater(torchao.quantization.utils.compute_error(ref.dequantize(), compiled.dequantize()), self.COMPILE_MIN_SQNR)
+            self.assertGreater(
+                torchao.quantization.utils.compute_error(
+                    ref.dequantize(), compiled.dequantize()
+                ),
+                self.COMPILE_MIN_SQNR,
+            )

     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
@@ -211,22 +228,18 @@ def test_linear_compile(self, device, dtype):

         hp_act_tensor = torch.randn(32, 128, device=device, dtype=dtype)
         hp_res = torch.nn.functional.linear(hp_act_tensor, hp_tensor)
-        l = torch.nn.Linear(128, 4, bias=False, device=device, dtype=dtype)
-        l.weight = torch.nn.Parameter(lp_tensor)
-        lp_res = torch.compile(l)(hp_act_tensor)
-        self.assertGreater(torchao.quantization.utils.compute_error(hp_res, lp_res), self.LINEAR_MIN_SQNR)
+        linear = torch.nn.Linear(128, 4, bias=False, device=device, dtype=dtype)
+        linear.weight = torch.nn.Parameter(lp_tensor)
+        lp_res = torch.compile(linear)(hp_act_tensor)
+        self.assertGreater(
+            torchao.quantization.utils.compute_error(hp_res, lp_res),
+            self.LINEAR_MIN_SQNR,
+        )

-import torch.distributed as dist
-from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh
-from torch.testing._internal.distributed._tensor.common_dtensor import (
-    DTensorTestBase,
-    with_comms,
-    NUM_DEVICES,
-)

 class TorchAOTensorParallelTestCase(DTensorTestBase):
-    """Basic test case for tensor subclasses
-    """
+    """Basic test case for tensor subclasses"""
+
     COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

     TENSOR_SUBCLASS = AffineQuantizedTensor
@@ -247,9 +260,7 @@ def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
         # Construct DTensor from local shard
         dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
         # Replace parameter in module
-        m.linear.weight = torch.nn.Parameter(
-            dtensor, requires_grad=False
-        )
+        m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False)
         return m

     @staticmethod
@@ -266,9 +277,7 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
         # Construct DTensor from local shard
         dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)])
         # Replace parameter in module
-        m.linear.weight = torch.nn.Parameter(
-            dtensor, requires_grad=False
-        )
+        m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False)
         return m

     def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
@@ -289,7 +298,9 @@ def test_tp(self, dtype):
         class M(torch.nn.Module):
             def __init__(self, in_features, out_features, **kwargs) -> None:
                 super().__init__(**kwargs)
-                self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda")
+                self.linear = torch.nn.Linear(
+                    in_features, out_features, bias=False, device="cuda"
+                )

             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 return self.linear(x)
@@ -301,12 +312,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         proj_up = M(1024, 2048).to(device).to(dtype)
         proj_dn = M(2048, 1024).to(device).to(dtype)
         example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
-        y = proj_dn(proj_up(example_input))
+        proj_dn(proj_up(example_input))

         # Quantize the model
         up_quant = self.quantize(proj_up)
         dn_quant = self.quantize(proj_dn)
-        y_q = dn_quant(up_quant(example_input))
+        dn_quant(up_quant(example_input))

         mesh = self.build_device_mesh()
         mesh.device_type = "cuda"
@@ -316,11 +327,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         dn_dist = self.rowwise_shard(dn_quant, mesh)

         # We need to turn inputs into DTensor form as well -- just a format change
-        input_dtensor = DTensor.from_local(
-            example_input, mesh, [Replicate()]
-        )
+        input_dtensor = DTensor.from_local(example_input, mesh, [Replicate()])

-        y_d = dn_dist(up_dist(input_dtensor))
+        dn_dist(up_dist(input_dtensor))

         if not TORCH_VERSION_AT_LEAST_2_6:
             # Need torch 2.6 to support compiled tensor parallelism
@@ -329,7 +338,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         up_compiled = torch.compile(up_dist)
         y_up = up_compiled(input_dtensor)
         dn_compiled = torch.compile(dn_dist)
-        y_dn = dn_compiled(y_up)
+        dn_compiled(y_up)
+

 common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
 common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)
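
For context, the module docstring referenced in the second hunk (class MyTestCase(TorchAOBasicTestCase)) shows how these base classes are meant to be consumed. Below is a minimal sketch of such a subclass, assuming the class attributes exercised by the tests above (TENSOR_SUBCLASS, FACTORY_FN, kwargs, LINEAR_MIN_SQNR, COMPILE_MIN_SQNR); the torchao.testing.utils import path, the factory arguments, and the threshold values are illustrative assumptions, not taken from this diff.

import unittest

import torch
from torch.testing._internal import common_utils

from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx
from torchao.quantization.quant_primitives import MappingType
from torchao.testing.utils import TorchAOBasicTestCase  # assumed import path


class MyTestCase(TorchAOBasicTestCase):
    # Tensor subclass under test and the factory invoked as FACTORY_FN(hp_tensor, **kwargs)
    TENSOR_SUBCLASS = AffineQuantizedTensor
    FACTORY_FN = to_affine_quantized_intx
    kwargs = {
        "mapping_type": MappingType.SYMMETRIC,  # illustrative quantization settings
        "block_size": (1, 128),
        "target_dtype": torch.int8,
    }
    # SQNR thresholds checked by test_linear and the compile tests; placeholder values
    LINEAR_MIN_SQNR = 40
    COMPILE_MIN_SQNR = 50


# Expand the device/dtype parametrizations into concrete test methods, then run.
common_utils.instantiate_parametrized_tests(MyTestCase)

if __name__ == "__main__":
    unittest.main()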