
Commit 90c8cbd

Fix inference_mode (#885)
Summary: Fixes #875

Test Plan: tested locally with tutorials/quantize_vit/run_vit_b_quant.py using:

```
with torch.inference_mode():
    benchmark_model(model, 20, inputs)
```

but the issue could not be reproduced in unit tests.

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 3fa38aa commit 90c8cbd
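Why registering `aten.linear.default` helps, as far as the diff shows: `torch.nn.functional.linear` is a composite op that is normally decomposed before reaching a tensor subclass's `__torch_dispatch__`, but under `torch.inference_mode()` it can arrive whole as `aten.linear.default`, so a dispatch table keyed only on `torch.nn.functional.linear` misses it. A minimal probe to observe this, assuming a recent PyTorch; the exact op names printed can vary across versions:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

# Log every aten op that reaches the Python dispatch layer.
class LogOps(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)
        return func(*args, **(kwargs or {}))

x, w = torch.randn(2, 4), torch.randn(3, 4)

with LogOps():
    torch.nn.functional.linear(x, w)   # typically already decomposed (e.g. aten.t, aten.mm)

with torch.inference_mode(), LogOps():
    torch.nn.functional.linear(x, w)   # can arrive un-decomposed as aten.linear.default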

File tree

2 files changed (+2, -2 lines)


torchao/dtypes/affine_quantized_tensor.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1483,7 +1483,7 @@ def _register_aqt_quantized_linear_dispatches():
 
 _register_aqt_quantized_linear_dispatches()
 
-@implements(torch.nn.functional.linear)
+@implements([torch.nn.functional.linear, aten.linear.default])
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor, bias = (
         args[0],
```
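Both files use torchao's `implements` decorator to map a handler onto torch ops; the fix passes a list so one handler covers both entry points. A rough sketch of that registration pattern, with a hypothetical `_DISPATCH` table standing in for torchao's actual bookkeeping:

```python
# Illustrative sketch of a list-accepting `implements` decorator; torchao's
# real version hangs the table off the tensor subclass, not a module global.
_DISPATCH = {}

def implements(torch_ops):
    # Accept a single op or a list of ops.
    if not isinstance(torch_ops, (list, tuple)):
        torch_ops = [torch_ops]

    def decorator(fn):
        for op in torch_ops:
            _DISPATCH[op] = fn  # register the same handler for every listed op
        return fn

    return decorator
```

With both `torch.nn.functional.linear` and `aten.linear.default` keyed to the same handler, the quantized path is found whether the call is intercepted at the `__torch_function__` level or arrives un-decomposed at `__torch_dispatch__`.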

torchao/quantization/linear_activation_quantized_tensor.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -91,7 +91,7 @@ def to(self, *args, **kwargs):
 
 implements = LinearActivationQuantizedTensor.implements
 
-@implements(torch.nn.functional.linear)
+@implements([torch.nn.functional.linear, aten.linear.default])
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor, bias = (
         args[0],
```
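A quick end-to-end check in the spirit of the test plan, using a toy model instead of the ViT tutorial; the `quantize_` / `int8_weight_only` names reflect torchao's API around the time of this commit and are assumptions here, not part of the change:

```python
import torch
from torchao.quantization import quantize_, int8_weight_only

# Toy stand-in for the ViT benchmark: a quantized linear layer that,
# after this fix, should dispatch correctly under inference_mode.
model = torch.nn.Sequential(torch.nn.Linear(128, 128)).eval()
quantize_(model, int8_weight_only())

x = torch.randn(1, 128)
with torch.inference_mode():
    out = model(x)  # previously failed to find a quantized-linear dispatch
print(out.shape)
```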
