
Commit 151bd03

Needed for save and load of gemma 1b (#1903)
stack-info: PR: #1903, branch: drisspg/stack/43
1 parent 38b1f45 commit 151bd03


3 files changed: +12 -3


test/dtypes/test_affine_quantized.py

Lines changed: 9 additions & 0 deletions
@@ -7,6 +7,7 @@
 import unittest
 
 import torch
+import torch.nn as nn
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -16,6 +17,7 @@
 from torchao.core.config import AOBaseConfig
 from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
 from torchao.quantization import (
+    Int8DynamicActivationInt8WeightConfig,
     float8_weight_only,
     int4_dynamic_activation_int4_weight,
     int4_weight_only,
@@ -298,6 +300,13 @@ def test_flatten_unflatten(self, device, dtype):
         reconstruct_res = ql(*example_inputs)
         self.assertEqual(reconstruct_res, ref)
 
+    @common_utils.parametrize("device", COMMON_DEVICES)
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    def test_alias(self, device, dtype):
+        dummy = nn.Linear(128, 256, dtype=dtype, device=device)
+        quantize_(dummy, Int8DynamicActivationInt8WeightConfig())
+        _ = dummy.weight[...]
+
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantized)
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)
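
Note: the new test_alias case exercises the aliasing path because indexing a tensor with a bare Ellipsis (weight[...]) dispatches to aten.alias.default, the op this commit adds handling for. A minimal sketch of that dispatch using only stock PyTorch (LogAliasMode is an illustrative name, not torchao API):

import torch
from torch.utils._python_dispatch import TorchDispatchMode

# Log every hit of aten.alias.default while forwarding all ops unchanged.
class LogAliasMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if func is torch.ops.aten.alias.default:
            print("dispatched:", func)
        return func(*args, **(kwargs or {}))

with LogAliasMode():
    w = torch.randn(4, 4)
    _ = w[...]  # indexing with a bare Ellipsis is expected to hit aten.alias.default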

torchao/dtypes/affine_quantized_tensor_ops.py

Lines changed: 1 addition & 1 deletion
@@ -341,7 +341,7 @@ def _(func, types, args, kwargs):
     return func(input_tensor, weight_tensor)
 
 
-@implements(aten.detach.default)
+@implements([aten.detach.default, aten.alias.default])
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
         func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
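
With aten.alias.default now registered for AffineQuantizedTensor, alias calls on the quantized weight are routed like detach instead of falling through as an unimplemented op. Per the commit title, this is what checkpoint save/load of gemma 1b needs; below is a hedged sketch of such a round trip, reusing the config from the test above. The model, file name, and the weights_only=False / assign=True flags are illustrative choices, not part of this PR.

import torch
import torch.nn as nn
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_

# Quantize a standalone linear the same way the new test does.
model = nn.Linear(128, 256)
quantize_(model, Int8DynamicActivationInt8WeightConfig())

# Round-trip the quantized state dict through torch.save/torch.load.
torch.save(model.state_dict(), "quantized_linear.pt")
state = torch.load("quantized_linear.pt", weights_only=False)
model.load_state_dict(state, assign=True)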

torchao/quantization/linear_activation_quantized_tensor.py

Lines changed: 2 additions & 2 deletions
@@ -184,10 +184,10 @@ def _(func, types, args, kwargs):
     return func(qtensor, original_weight_tensor)
 
 
-@implements(aten.detach.default)
+@implements([aten.detach.default, aten.alias.default])
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
-        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+        func, args, kwargs, args[0]._apply_fn_to_data(func)
     )
 
 
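
In this file the handler body also switches from _apply_fn_to_data(torch.detach) to _apply_fn_to_data(func), so whichever op actually arrived (detach or alias) is the one applied to the wrapped tensors. A standalone sketch of the same idiom on a toy wrapper subclass (ToyWrapper is illustrative, not a torchao class):

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

class ToyWrapper(torch.Tensor):
    @staticmethod
    def __new__(cls, inner):
        # Wrapper subclass holding a plain tensor as its payload.
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device
        )

    def __init__(self, inner):
        self.inner = inner

    def __tensor_flatten__(self):
        # Inner tensor attribute names plus (here empty) extra metadata.
        return ["inner"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        return ToyWrapper(inner_tensors["inner"])

    def _apply_fn_to_data(self, fn):
        # Rebuild the wrapper around fn applied to the inner tensor.
        return ToyWrapper(fn(self.inner))

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in (torch.ops.aten.detach.default, torch.ops.aten.alias.default):
            # One handler for both ops: apply the incoming op itself to the data.
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(func)
            )
        raise NotImplementedError(f"{func} is not handled in this sketch")

w = ToyWrapper(torch.randn(2, 2))
_ = w.detach()  # aten.detach.default
_ = w[...]      # aten.alias.default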