Commit a83636d

Revert "Fixing aliasing behavior for slice in AQT TensorCoreTiledLayout" (#2175)
Revert "Fixing aliasing behavior for slice in AQT TensorCoreTiledLayout (#2174)" This reverts commit 95119bb.
1 parent 95119bb commit a83636d
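
For context (an editor's note, not part of the commit): the reverted change concerned whether slicing an AQT tensor aliases its source, i.e. whether the slice shares storage rather than copying, the way narrow does for plain tensors. A minimal plain-PyTorch sketch of that aliasing behavior (illustrative names; nothing here is from the repository):

    import torch

    # narrow() returns a view: the slice aliases the source tensor's storage,
    # so no data is copied and writes through the view show up in the source.
    w = torch.zeros(1024, 1024)
    w_rows = w.narrow(0, 0, 512)  # first 512 rows, zero-copy view

    assert w_rows.data_ptr() == w.data_ptr()  # same underlying storage

    w_rows.fill_(1.0)
    assert w[0, 0].item() == 1.0  # mutation through the view reaches the source

The test deleted below asserted exactly this kind of data_ptr equality on the quantized tensor's packed buffers.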

File tree

2 files changed: +21 −79 lines changed

test/dtypes/test_affine_quantized.py

Lines changed: 0 additions & 34 deletions

@@ -387,40 +387,6 @@ def test_matmul(self, device, dtype):
         # make sure it runs
         torch.matmul(x, w.t())
 
-    @common_utils.parametrize("device", ["cuda"])
-    @common_utils.parametrize("dtype", [torch.bfloat16])
-    @skip_if_no_cuda()
-    @skip_if_rocm("ROCm enablement in progress")
-    def test_slice_and_copy_int4wo(self, device, dtype):
-        l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
-        l.weight = torch.nn.Parameter(
-            torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
-        )
-        quantize_(l, Int4WeightOnlyConfig())
-        param = l.weight
-        param_data = param.data
-        param_data = param_data.narrow(0, 0, 512)
-        assert (
-            param.data.tensor_impl.packed_weight.data_ptr()
-            == param_data.tensor_impl.packed_weight.data_ptr()
-        )
-        assert (
-            param.data.tensor_impl.scale_and_zero.data_ptr()
-            == param_data.tensor_impl.scale_and_zero.data_ptr()
-        )
-        assert param.data.dequantize()[0][0] == 0
-
-        # dummy_l has random input (shouldn't be 0)
-        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
-        quantize_(dummy_l, Int4WeightOnlyConfig())
-        quantized = dummy_l.weight
-        quantized = quantized.narrow(0, 0, 512)
-
-        param_data.copy_(quantized)
-
-        # making sure param.data is updated
-        assert param.data.dequantize()[0][0] != 0
-
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantized)
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)
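
The deleted test exercised the aliasing contract end to end: narrow the quantized weight, assert the packed buffers still alias the parent, then copy_ into the slice and confirm the parent observes the write. A hedged sketch of the same slice-then-copy pattern on plain tensors (names are illustrative, not from the repository):

    import torch

    # Because narrow() aliases, copy_ into the slice mutates the
    # parent parameter's storage in place.
    param = torch.nn.Parameter(torch.zeros(1024, 1024))
    param_data = param.data.narrow(0, 0, 512)  # aliases the first 512 rows

    src = torch.randn(512, 1024)
    param_data.copy_(src)                      # in-place write through the view

    assert torch.equal(param.data[:512], src)  # the parent saw the update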

torchao/dtypes/uintx/tensor_core_tiled_layout.py

Lines changed: 21 additions & 45 deletions

@@ -350,54 +350,30 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
 
         if func is aten.slice.Tensor:
             self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
-            n_by_8, k_by_inner_tiles, _, _ = self.packed_weight.shape
-            sz_dim1, sz_dim0, _ = self.scale_and_zero.shape
-            data_len = self.shape[dim]
-            assert dim in [0, 1], (
-                f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
-            )
-
-            if dim == 0:
-                pw_len = n_by_8
-                sz_len = sz_dim0
+            if dim in [0, 1]:
+                int_data, scale, zero_point = self.get_plain()
+                data_len = int_data.shape[dim]
+                scale_len = scale.shape[dim]
+                ratio = data_len / scale_len
+                start_scale = int(start / ratio)
+                end_scale = int(end / ratio)
+
+                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
+                scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
+                zero_point = aten.slice.Tensor(
+                    zero_point, dim, start_scale, end_scale, step
+                )
+                # this is to handle padding
+                int_data, scale, zero_point = self._layout.post_process(
+                    int_data, scale, zero_point, self.block_size
+                )
+                sliced = self.from_plain(int_data, scale, zero_point, self._layout)
+                return return_and_correct_aliasing(func, args, kwargs, sliced)
             else:
-                pw_len = k_by_inner_tiles
-                sz_len = sz_dim1
-
-            if pw_len == 0 or sz_len == 0:
-                return return_and_correct_aliasing(
-                    func,
-                    args,
-                    kwargs,
-                    TensorCoreTiledAQTTensorImpl(
-                        self.packed_weight,
-                        self.scale_and_zero,
-                        self.transposed,
-                        self._layout,
-                    ),
+                raise NotImplementedError(
+                    f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
                 )
 
-            pw_ratio = data_len / pw_len
-            start_pw = int(start / pw_ratio)
-            end_pw = int(end / pw_ratio)
-
-            sz_ratio = data_len / sz_len
-            start_sz = int(start / sz_ratio)
-            end_sz = int(end / sz_ratio)
-
-            packed_weight = aten.slice(self.packed_weight, dim, start_pw, end_pw, step)
-            scale_and_zero = aten.slice(
-                self.scale_and_zero, 1 - dim, start_sz, end_sz, step
-            )
-            return return_and_correct_aliasing(
-                func,
-                args,
-                kwargs,
-                TensorCoreTiledAQTTensorImpl(
-                    packed_weight, scale_and_zero, self.transposed, self._layout
-                ),
-            )
-
         raise NotImplementedError(
             f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
         )
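
The restored branch maps slice indices on the unpacked int data to indices on the per-group scales through a simple length ratio. A worked sketch of that arithmetic, assuming an illustrative 1024×1024 int4 weight with group size 128 (these shapes are assumptions, not taken from the diff):

    # Slicing along dim=1 (the grouped in-feature dim) of a 1024x1024 weight
    # quantized with group size 128: there is one scale per 128 elements.
    data_len = 1024               # int_data.shape[1]
    scale_len = 1024 // 128       # scale.shape[1] == 8
    ratio = data_len / scale_len  # 128.0 int-data elements per scale entry

    start, end = 0, 512           # slice the first half of the columns
    start_scale = int(start / ratio)  # 0
    end_scale = int(end / ratio)      # 4 -> keep the first 4 of 8 scale groups

Because this path round-trips through get_plain()/from_plain(), the sliced result is repacked into fresh buffers and no longer aliases the original packed weight, which is presumably why the aliasing test was removed together with the revert.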
