Commit 95119bb

Fixing aliasing behavior for slice in AQT TensorCoreTiledLayout (#2174)
* Fixing aliasing behavior for slice in AQT TensorCoreTiledLayout

Summary:
The slice op is supposed to preserve aliasing (the output of slice should alias the input), but this is not true for TensorCoreTiledLayout (used by int4wo) and some others like gemlite. The reason is that we currently unpack, pad, and repack, which creates new tensors. This PR fixes it by slicing the packed inner tensors directly, specifically packed_weight and scale_and_zero in TensorCoreTiledLayout.

Test Plan:
python test/dtypes/test_affine_quantized.py -k test_slice_and_copy_int4wo

Reviewers:
Subscribers:
Tasks:
Tags:

* simplify code
* add check for data_ptr
* format
* avoid div by zero
* format
1 parent aa95bb4 commit 95119bb
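
For background: slicing a regular torch.Tensor returns a view that shares storage with its source, which is the contract this commit restores for the quantized tensor subclass. A minimal plain-PyTorch sketch of that invariant (illustrative only, not torchao-specific code):

import torch

# slice/narrow on a plain tensor returns a view: same storage, no copy
x = torch.zeros(4, 4)
y = x.narrow(0, 0, 2)  # first two rows, a view into x

assert x.data_ptr() == y.data_ptr()  # both start at the same buffer address

y.fill_(1.0)                  # writing through the view...
assert x[0, 0].item() == 1.0  # ...is visible in the source tensor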

File tree

2 files changed: +79, -21 lines

test/dtypes/test_affine_quantized.py

Lines changed: 34 additions & 0 deletions
@@ -387,6 +387,40 @@ def test_matmul(self, device, dtype):
         # make sure it runs
         torch.matmul(x, w.t())
 
+
+    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("dtype", [torch.bfloat16])
+    @skip_if_no_cuda()
+    @skip_if_rocm("ROCm enablement in progress")
+    def test_slice_and_copy_int4wo(self, device, dtype):
+        l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+        l.weight = torch.nn.Parameter(
+            torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
+        )
+        quantize_(l, Int4WeightOnlyConfig())
+        param = l.weight
+        param_data = param.data
+        param_data = param_data.narrow(0, 0, 512)
+        assert (
+            param.data.tensor_impl.packed_weight.data_ptr()
+            == param_data.tensor_impl.packed_weight.data_ptr()
+        )
+        assert (
+            param.data.tensor_impl.scale_and_zero.data_ptr()
+            == param_data.tensor_impl.scale_and_zero.data_ptr()
+        )
+        assert param.data.dequantize()[0][0] == 0
+
+        # dummy_l has random input (shouldn't be 0)
+        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+        quantize_(dummy_l, Int4WeightOnlyConfig())
+        quantized = dummy_l.weight
+        quantized = quantized.narrow(0, 0, 512)
+
+        param_data.copy_(quantized)
+
+        # making sure param.data is updated
+        assert param.data.dequantize()[0][0] != 0
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantized)
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)
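
For context on the copy_ half of the test: because slice now returns a view, an in-place copy through the narrowed parameter must land in the original weight's storage. A plain-tensor analogue of that behavior (stand-in tensors, not the quantized path):

import torch

# stand-ins for the quantized weight and the dummy layer's sliced weight
w = torch.zeros(1024, 1024)
w_slice = w.narrow(0, 0, 512)   # view over the first 512 rows

src = torch.randn(512, 1024)
w_slice.copy_(src)              # in-place copy through the view

assert torch.equal(w[:512], src)  # the original storage saw the update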

torchao/dtypes/uintx/tensor_core_tiled_layout.py

Lines changed: 45 additions & 21 deletions
@@ -350,30 +350,54 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
 
         if func is aten.slice.Tensor:
             self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
-            if dim in [0, 1]:
-                int_data, scale, zero_point = self.get_plain()
-                data_len = int_data.shape[dim]
-                scale_len = scale.shape[dim]
-                ratio = data_len / scale_len
-                start_scale = int(start / ratio)
-                end_scale = int(end / ratio)
-
-                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
-                scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
-                zero_point = aten.slice.Tensor(
-                    zero_point, dim, start_scale, end_scale, step
-                )
-                # this is to handle padding
-                int_data, scale, zero_point = self._layout.post_process(
-                    int_data, scale, zero_point, self.block_size
-                )
-                sliced = self.from_plain(int_data, scale, zero_point, self._layout)
-                return return_and_correct_aliasing(func, args, kwargs, sliced)
+            n_by_8, k_by_inner_tiles, _, _ = self.packed_weight.shape
+            sz_dim1, sz_dim0, _ = self.scale_and_zero.shape
+            data_len = self.shape[dim]
+            assert dim in [0, 1], (
+                f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
+            )
+
+            if dim == 0:
+                pw_len = n_by_8
+                sz_len = sz_dim0
             else:
-                raise NotImplementedError(
-                    f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
+                pw_len = k_by_inner_tiles
+                sz_len = sz_dim1
+
+            if pw_len == 0 or sz_len == 0:
+                return return_and_correct_aliasing(
+                    func,
+                    args,
+                    kwargs,
+                    TensorCoreTiledAQTTensorImpl(
+                        self.packed_weight,
+                        self.scale_and_zero,
+                        self.transposed,
+                        self._layout,
+                    ),
                 )
 
+            pw_ratio = data_len / pw_len
+            start_pw = int(start / pw_ratio)
+            end_pw = int(end / pw_ratio)
+
+            sz_ratio = data_len / sz_len
+            start_sz = int(start / sz_ratio)
+            end_sz = int(end / sz_ratio)
+
+            packed_weight = aten.slice(self.packed_weight, dim, start_pw, end_pw, step)
+            scale_and_zero = aten.slice(
+                self.scale_and_zero, 1 - dim, start_sz, end_sz, step
+            )
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                TensorCoreTiledAQTTensorImpl(
+                    packed_weight, scale_and_zero, self.transposed, self._layout
+                ),
+            )
+
         raise NotImplementedError(
             f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
         )
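
To make the ratio arithmetic above concrete, here is a standalone sketch; the shapes (1024x1024 weight, n_by_8 = 128 packed rows, group size 128) are illustrative assumptions matching the test, not values read from the packing kernel. Note the diff slices scale_and_zero along 1 - dim because its layout stores the two weight dimensions in the opposite order, as the unpacking `sz_dim1, sz_dim0, _ = self.scale_and_zero.shape` shows:

def map_slice(start: int, end: int, data_len: int, packed_len: int) -> tuple[int, int]:
    # one packed element covers data_len / packed_len logical elements
    ratio = data_len / packed_len
    return int(start / ratio), int(end / ratio)

# packed_weight: slicing logical rows [0, 512) of a 1024-row weight with
# n_by_8 = 1024 // 8 = 128 packed rows keeps packed rows [0, 64)
assert map_slice(0, 512, data_len=1024, packed_len=128) == (0, 64)

# scale_and_zero: assuming group size 128, there are 1024 // 128 = 8 groups
# along in_features, so slicing columns [0, 512) keeps groups [0, 4)
assert map_slice(0, 512, data_len=1024, packed_len=8) == (0, 4)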
