From 1c9f2fe134052538a3eb9ced19ae47b744d6ca63 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 25 Jun 2025 15:05:32 +0000 Subject: [PATCH 1/2] enable cpu to xpu Signed-off-by: jiqing-feng --- torchao/dtypes/uintx/int4_cpu_layout.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index bf9446d265..452a1a55e9 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -39,6 +39,8 @@ class Int4CPULayout(Layout): pass +from torchao.dtypes.uintx.int4_xpu_layout import Int4XPUAQTTensorImpl + @register_layout(Int4CPULayout) class Int4CPUAQTTensorImpl(AQTTensorImpl): """TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only, @@ -148,10 +150,15 @@ def from_plain( def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) device = kwargs["device"] - if not is_device(torch.device(self.device).type, device): + if self.device.type == "xpu": + from torchao.dtypes import Int4XPULayout + int_data, scale, zero_point = self.get_plain() + return Int4XPUAQTTensorImpl.from_plain(int_data.to(device), scale.to(device), zero_point.to(device), _layout=Int4XPULayout()) + elif not is_device(torch.device(self.device).type, device): raise ValueError( f"Int4CPUAQTTensorImpl does not support conversion from {self.device} to {device}" ) + return self.__class__( self.packed_weight.to(device), self.scale_and_zero.to(device), @@ -241,6 +248,10 @@ def block_size(self): def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros + if self.device.type != "cpu": + self.scale_and_zero = self.scale_and_zero.to("cpu") + self.packed_weight = self.packed_weight.to("cpu") + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) cur_shape = self.shape @@ -249,7 +260,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: eye_shape = original_shape[1] groupsize = int(original_shape[1] / scale.shape[-2]) block_size = (1, groupsize) - device = self.device + device = torch.device("cpu") original_dtype = self.scale_and_zero.dtype target_dtype = torch.int32 quant_min = 0 From affe779683b521c381439dcc266ef40f7d59a8cc Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 25 Jun 2025 15:12:27 +0000 Subject: [PATCH 2/2] fix format Signed-off-by: jiqing-feng --- torchao/dtypes/uintx/int4_cpu_layout.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index 452a1a55e9..834fe04114 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -153,7 +153,8 @@ def to(self, *args, **kwargs): if self.device.type == "xpu": from torchao.dtypes import Int4XPULayout int_data, scale, zero_point = self.get_plain() - return Int4XPUAQTTensorImpl.from_plain(int_data.to(device), scale.to(device), zero_point.to(device), _layout=Int4XPULayout()) + int_data, scale, zero_point = int_data.to(self.device), scale.to(self.device), zero_point.to(self.device) + return Int4XPUAQTTensorImpl.from_plain(int_data, scale, zero_point, _layout=Int4XPULayout()) elif not is_device(torch.device(self.device).type, device): raise ValueError( f"Int4CPUAQTTensorImpl does not support conversion from {self.device} to {device}"