Skip to content

Commit f82071d

Browse files
authored
Fix tutorial example for my dtype tensor subclass (#865)
* Fix tutorial example for my dtype tensor subclass. Summary: att. Test Plan: python tutorials/developer_api_guide/my_dtype_tensor_subclass.py. Reviewers: Subscribers: Tasks: Tags: * fix function not defined error
1 parent 85d03de commit f82071d

File tree

1 file changed

+7
-11
lines changed

1 file changed

+7
-11
lines changed

tutorials/developer_api_guide/my_dtype_tensor_subclass.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
LayoutType,
2121
PlainLayoutType,
2222
)
23-
from torchao.utils import TorchAOBaseTensor, _register_layout_cls, _get_layout_tensor_constructor
23+
from torchao.utils import TorchAOBaseTensor
2424

2525
aten = torch.ops.aten
2626

@@ -191,12 +191,8 @@ def _apply_fn_to_data(self, fn):
191191
# LayoutType and Layout Tensor Subclass Registration #
192192
######################################################
193193

194-
def register_layout_cls(layout_type_class: type(LayoutType)):
195-
return _register_layout_cls(MyDTypeTensor, layout_type_class)
196-
197-
def get_layout_tensor_constructor(layout_type_class: type(LayoutType)):
198-
return _get_layout_tensor_constructor(MyDTypeTensor, layout_type_class)
199-
194+
register_layout_cls = MyDTypeTensor.register_layout_cls
195+
get_layout_tensor_constructor = MyDTypeTensor.get_layout_tensor_constructor
200196

201197
@register_layout_cls(PlainLayoutType)
202198
class PlainMyDTypeLayout(MyDTypeLayout):
@@ -343,12 +339,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
343339

344340
for _ in range(NUM_WARMUPS):
345341
m(*example_inputs)
346-
print("before quantization:", benchmark_model(m, NUM_RUNS, example_inputs[0]))
342+
print("before quantization:", benchmark_model(m, NUM_RUNS, example_inputs))
347343

348344
compiled = torch.compile(m, mode="max-autotune")
349345
for _ in range(NUM_WARMUPS):
350346
compiled(*example_inputs)
351-
print("after compile:", benchmark_model(compiled, NUM_RUNS, example_inputs[0]))
347+
print("after compile:", benchmark_model(compiled, NUM_RUNS, example_inputs))
352348

353349
# convert weights to quantized weights
354350
m.linear.weight = torch.nn.Parameter(
@@ -358,7 +354,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
358354
for _ in range(NUM_WARMUPS):
359355
m(*example_inputs)
360356

361-
print("after quantization:", benchmark_model(m, NUM_RUNS, example_inputs[0]))
357+
print("after quantization:", benchmark_model(m, NUM_RUNS, example_inputs))
362358

363359
m = torch.compile(m, mode="max-autotune")
364360

@@ -367,4 +363,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
367363

368364
# NOTE: currently there is no speedup because we just dequantize the weight in the _quantized_linear op
369365
# we plan to add custom op example in the future and that will help us to get speedup
370-
print("after quantization and compile:", benchmark_model(m, NUM_RUNS, example_inputs[0]))
366+
print("after quantization and compile:", benchmark_model(m, NUM_RUNS, example_inputs))

0 commit comments

Comments
 (0)