Commit 4a4925f

Revert "Add support for copy_ for plain layout and tensor core tiled layout" (#1803)

This reverts commit 79e3366.

1 parent 3219318 · commit 4a4925f
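For context, a minimal sketch of the workflow the reverted commit had enabled, distilled from the deleted test below. It assumes torchao's public quantize_ API, with int8_weight_only as an illustrative stand-in for the configs the test actually parametrized over:

import torch
from torchao.quantization import int8_weight_only, quantize_

# Two identically shaped linears, quantized the same way.
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, int8_weight_only())
quantize_(linear2, int8_weight_only())

# With the reverted commit applied, this in-place copy between two
# matching quantized weights succeeded (and raised ValueError on a
# metadata mismatch); after the revert, copy_ is no longer implemented
# for these layouts.
linear2.weight.copy_(linear.weight)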

File tree

7 files changed: +1, -210 lines changed

test/dtypes/test_affine_quantized.py

Lines changed: 0 additions & 47 deletions
@@ -209,53 +209,6 @@ def test_print_quantized_module(self, apply_quant):
         ql = apply_quant(linear)
         assert "AffineQuantizedTensor" in str(ql)
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @common_utils.parametrize(
-        "apply_quant", get_quantization_functions(False, True, "cuda", False)
-    )
-    def test_copy_(self, apply_quant):
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-
-        if isinstance(apply_quant, AOBaseConfig):
-            quantize_(linear, apply_quant)
-            ql = linear
-            quantize_(linear2, apply_quant)
-            ql2 = linear2
-        else:
-            ql = apply_quant(linear)
-            ql2 = apply_quant(linear2)
-
-        example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
-        output = ql(example_input)
-        ql2.weight.copy_(ql.weight)
-        ql2.bias = ql.bias
-        output2 = ql2(example_input)
-        self.assertEqual(output, output2)
-
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @common_utils.parametrize(
-        "apply_quant", get_quantization_functions(False, True, "cuda", False)
-    )
-    def test_copy__mismatch_metadata(self, apply_quant):
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        linear2 = torch.nn.Linear(128, 512, dtype=torch.bfloat16, device="cuda")
-
-        if isinstance(apply_quant, AOBaseConfig):
-            quantize_(linear, apply_quant)
-            ql = linear
-            quantize_(linear2, apply_quant)
-            ql2 = linear2
-        else:
-            ql = apply_quant(linear)
-            ql2 = apply_quant(linear2)
-
-        # copy should fail due to shape mismatch
-        with self.assertRaisesRegex(
-            ValueError, "Not supported args for copy_ due to metadata mistach:"
-        ):
-            ql2.weight.copy_(ql.weight)
-
 
 class TestAffineQuantizedBasic(TestCase):
     COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])

torchao/dtypes/affine_quantized_tensor_ops.py

Lines changed: 0 additions & 35 deletions
@@ -97,27 +97,6 @@ def deregister_aqt_quantized_linear_dispatch(dispatch_condition):
     )
 
 
-def _same_metadata(self: AffineQuantizedTensor, src: AffineQuantizedTensor):
-    return (
-        isinstance(self, AffineQuantizedTensor)
-        and isinstance(src, AffineQuantizedTensor)
-        and all(
-            [
-                getattr(self, attr) == getattr(src, attr)
-                for attr in [
-                    "block_size",
-                    "shape",
-                    "quant_min",
-                    "quant_max",
-                    "zero_point_domain",
-                    "dtype",
-                ]
-            ]
-        )
-        and type(self.tensor_impl) == type(src.tensor_impl)
-    )
-
-
 class QuantizedLinearNotImplementedError(NotImplementedError):
     """Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table"""
 

@@ -352,20 +331,6 @@ def _(func, types, args, kwargs):
     )
 
 
-@implements(aten.copy_.default)
-def _(func, types, args, kwargs):
-    self = args[0]
-    src = args[1]
-    if _same_metadata(self, src):
-        self_tensors = self.__tensor_flatten__()[0]
-        for tensor_name in self_tensors:
-            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
-        return
-    raise ValueError(
-        f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
-    )
-
-
 @implements(aten.t.default)
 def _(func, types, args, kwargs):
     block_size = args[0].block_size
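Each removed copy_ block in this commit, here and in the layout files below, repeats one pattern: check that destination and source metadata match, then copy every inner tensor in place so the destination tensor object keeps its identity. A self-contained sketch of that pattern follows; SimpleImpl and tensor_names are hypothetical stand-ins for the tensor-subclass machinery (__tensor_flatten__) used above:

import torch


class SimpleImpl:
    # Hypothetical stand-in for a tensor impl that holds inner tensors;
    # tensor_names() plays the role of __tensor_flatten__()[0] above.
    def __init__(self, int_data: torch.Tensor, scale: torch.Tensor):
        self.int_data = int_data
        self.scale = scale

    def tensor_names(self):
        return ["int_data", "scale"]


def _same_metadata(self: SimpleImpl, src: SimpleImpl) -> bool:
    # Mirrors the removed helpers: same types and same inner-tensor shapes.
    return (
        isinstance(self, SimpleImpl)
        and isinstance(src, SimpleImpl)
        and self.int_data.shape == src.int_data.shape
        and self.scale.shape == src.scale.shape
    )


def copy_impl(self: SimpleImpl, src: SimpleImpl) -> None:
    # The guarded in-place copy the revert removes: copy each inner tensor
    # so the destination (and anything aliasing it) is updated in place.
    if _same_metadata(self, src):
        for name in self.tensor_names():
            getattr(self, name).copy_(getattr(src, name))
        return
    raise ValueError(f"Not supported args for copy_ due to metadata mismatch: {self, src}")


dst = SimpleImpl(torch.zeros(4, 4, dtype=torch.int8), torch.ones(4))
src = SimpleImpl(torch.ones(4, 4, dtype=torch.int8), torch.full((4,), 0.5))
copy_impl(dst, src)
assert torch.equal(dst.int_data, src.int_data) and torch.equal(dst.scale, src.scale)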

torchao/dtypes/floatx/float8_layout.py

Lines changed: 0 additions & 23 deletions
@@ -23,18 +23,6 @@
 aten = torch.ops.aten
 
 
-def _same_metadata(self: "Float8AQTTensorImpl", src: "Float8AQTTensorImpl") -> bool:
-    return (
-        isinstance(self, Float8AQTTensorImpl)
-        and isinstance(src, Float8AQTTensorImpl)
-        and self.shape == src.shape
-        and self.float8_data.shape == src.float8_data.shape
-        and self.scale.shape == src.scale.shape
-        and self.transposed == src.transposed
-        and type(self._layout) == type(src._layout)
-    )
-
-
 @dataclass(frozen=True)
 class Float8Layout(Layout):
     """Represents the layout configuration for Float8 affine quantized tensors.

@@ -138,17 +126,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             """
             args[0].transposed = not args[0].transposed
             return return_and_correct_aliasing(func, args, kwargs, args[0])
-        elif func is aten.copy_.default:
-            self = args[0]
-            src = args[1]
-            if _same_metadata(self, src):
-                self_tensors = self.__tensor_flatten__()[0]
-                for tensor_name in self_tensors:
-                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
-                return
-            raise ValueError(
-                f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
-            )
         elif func is aten.slice.Tensor:
             self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
             if dim == 0:

torchao/dtypes/uintx/cutlass_int4_packed_layout.py

Lines changed: 0 additions & 23 deletions
@@ -28,17 +28,6 @@ def _aqt_is_int4(aqt):
     )
 
 
-def _same_metadata(self: "Int4PackedTensorImpl", src: "Int4PackedTensorImpl") -> bool:
-    return (
-        isinstance(self, Int4PackedTensorImpl)
-        and isinstance(src, Int4PackedTensorImpl)
-        and self.shape == src.shape
-        and self.int_data.shape == src.int_data.shape
-        and self.scale.shape == src.scale.shape
-        and type(self._layout) == type(src._layout)
-    )
-
-
 @dataclass(frozen=True)
 class CutlassInt4PackedLayout(Layout):
     """Layout class for int4 packed layout for affine quantized tensor, for cutlass kernel."""

@@ -88,18 +77,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
             )
 
-        elif func is aten.copy_.default:
-            self = args[0]
-            src = args[1]
-            if _same_metadata(self, src):
-                self_tensors = self.__tensor_flatten__()[0]
-                for tensor_name in self_tensors:
-                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
-                return
-            raise ValueError(
-                f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
-            )
-
         raise NotImplementedError(
             f"Int4PackedTensorImpl dispatch: attempting to run {func}, this is not supported"
         )

torchao/dtypes/uintx/plain_layout.py

Lines changed: 1 addition & 30 deletions
@@ -22,23 +22,6 @@
 aten = torch.ops.aten
 
 
-def _same_metadata(self: "PlainAQTTensorImpl", src: "PlainAQTTensorImpl") -> bool:
-    return (
-        isinstance(self, PlainAQTTensorImpl)
-        and isinstance(src, PlainAQTTensorImpl)
-        and self.shape == src.shape
-        and self.int_data.shape == src.int_data.shape
-        and self.scale.shape == src.scale.shape
-        and (self.zero_point is None and src.zero_point is None)
-        or (
-            self.zero_point is not None
-            and src.zero_point is not None
-            and self.zero_point.shape == src.zero_point.shape
-        )
-        and type(self._layout) == type(src._layout)
-    )
-
-
 @register_layout(PlainLayout)
 class PlainAQTTensorImpl(AQTTensorImpl):
     """

@@ -125,23 +108,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
             )
 
-        elif func is aten.clone.default:
+        if func is aten.clone.default:
             return return_and_correct_aliasing(
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
             )
 
-        elif func is aten.copy_.default:
-            self = args[0]
-            src = args[1]
-            if _same_metadata(self, src):
-                self_tensors = self.__tensor_flatten__()[0]
-                for tensor_name in self_tensors:
-                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
-                return
-            raise ValueError(
-                f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
-            )
-
         elif func is aten.t.default:
             tensor = args[0]
             new = tensor.__class__(

torchao/dtypes/uintx/tensor_core_tiled_layout.py

Lines changed: 0 additions & 26 deletions
@@ -32,20 +32,6 @@ def _aqt_is_tensor_core_tile_uint4(aqt):
     )
 
 
-def _same_metadata(
-    self: "TensorCoreTiledAQTTensorImpl", src: "TensorCoreTiledAQTTensorImpl"
-) -> bool:
-    return (
-        isinstance(self, TensorCoreTiledAQTTensorImpl)
-        and isinstance(src, TensorCoreTiledAQTTensorImpl)
-        and self.shape == src.shape
-        and self.packed_weight.shape == src.packed_weight.shape
-        and self.scale_and_zero.shape == src.scale_and_zero.shape
-        and self.transposed == src.transposed
-        and type(self._layout) == type(src._layout)
-    )
-
-
 def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias):
     return (
         # input is native bfloat16 tensor

@@ -304,18 +290,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
             )
 
-        if func is aten.copy_.default:
-            self = args[0]
-            src = args[1]
-            if _same_metadata(self, src):
-                self_tensors = self.__tensor_flatten__()[0]
-                for tensor_name in self_tensors:
-                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
-                return
-            raise ValueError(
-                f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
-            )
-
         if func is aten.t.default:
             """we don't need to repack the weight and just rely on external
             shape being changed and record the status of transpose/no-transpose

torchao/quantization/linear_activation_quantized_tensor.py

Lines changed: 0 additions & 26 deletions
@@ -112,18 +112,6 @@ def to(self, *args, **kwargs):
     )
 
 
-def _same_metadata(
-    self: LinearActivationQuantizedTensor, src: LinearActivationQuantizedTensor
-):
-    return (
-        isinstance(self, LinearActivationQuantizedTensor)
-        and isinstance(src, LinearActivationQuantizedTensor)
-        and self.shape == src.shape
-        and self.input_quant_func == src.input_quant_func
-        and self.quant_kwargs == src.quant_kwargs
-    )
-
-
 implements = LinearActivationQuantizedTensor.implements
 
 

@@ -203,20 +191,6 @@ def _(func, types, args, kwargs):
     )
 
 
-@implements(aten.copy_.default)
-def _(func, types, args, kwargs):
-    self = args[0]
-    src = args[1]
-    if _same_metadata(self, src):
-        self_tensors = self.__tensor_flatten__()[0]
-        for tensor_name in self_tensors:
-            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
-        return
-    raise ValueError(
-        f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
-    )
-
-
 @implements(aten.t.default)
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
