20
20
LayoutType ,
21
21
PlainLayoutType ,
22
22
)
23
- from torchao .utils import TorchAOBaseTensor , _register_layout_cls , _get_layout_tensor_constructor
23
+ from torchao .utils import TorchAOBaseTensor
24
24
25
25
aten = torch .ops .aten
26
26
@@ -191,12 +191,8 @@ def _apply_fn_to_data(self, fn):
191
191
# LayoutType and Layout Tensor Subclass Registration #
192
192
######################################################
193
193
194
- def register_layout_cls (layout_type_class : type (LayoutType )):
195
- return _register_layout_cls (MyDTypeTensor , layout_type_class )
196
-
197
- def get_layout_tensor_constructor (layout_type_class : type (LayoutType )):
198
- return _get_layout_tensor_constructor (MyDTypeTensor , layout_type_class )
199
-
194
+ register_layout_cls = MyDTypeTensor .register_layout_cls
195
+ get_layout_tensor_constructor = MyDTypeTensor .get_layout_tensor_constructor
200
196
201
197
@register_layout_cls (PlainLayoutType )
202
198
class PlainMyDTypeLayout (MyDTypeLayout ):
@@ -343,12 +339,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
343
339
344
340
for _ in range (NUM_WARMUPS ):
345
341
m (* example_inputs )
346
- print ("before quantization:" , benchmark_model (m , NUM_RUNS , example_inputs [ 0 ] ))
342
+ print ("before quantization:" , benchmark_model (m , NUM_RUNS , example_inputs ))
347
343
348
344
compiled = torch .compile (m , mode = "max-autotune" )
349
345
for _ in range (NUM_WARMUPS ):
350
346
compiled (* example_inputs )
351
- print ("after compile:" , benchmark_model (compiled , NUM_RUNS , example_inputs [ 0 ] ))
347
+ print ("after compile:" , benchmark_model (compiled , NUM_RUNS , example_inputs ))
352
348
353
349
# convert weights to quantized weights
354
350
m .linear .weight = torch .nn .Parameter (
@@ -358,7 +354,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
358
354
for _ in range (NUM_WARMUPS ):
359
355
m (* example_inputs )
360
356
361
- print ("after quantization:" , benchmark_model (m , NUM_RUNS , example_inputs [ 0 ] ))
357
+ print ("after quantization:" , benchmark_model (m , NUM_RUNS , example_inputs ))
362
358
363
359
m = torch .compile (m , mode = "max-autotune" )
364
360
@@ -367,4 +363,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
367
363
368
364
# NOTE: currently there is no speedup because we just dequantize the weight in the _quantized_linear op
369
365
# we plan to add custom op example in the future and that will help us to get speedup
370
- print ("after quantization and compile:" , benchmark_model (m , NUM_RUNS , example_inputs [ 0 ] ))
366
+ print ("after quantization and compile:" , benchmark_model (m , NUM_RUNS , example_inputs ))
0 commit comments