diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index bf9446d265..834fe04114 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -39,6 +39,8 @@ class Int4CPULayout(Layout): pass +from torchao.dtypes.uintx.int4_xpu_layout import Int4XPUAQTTensorImpl + @register_layout(Int4CPULayout) class Int4CPUAQTTensorImpl(AQTTensorImpl): """TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only, @@ -148,10 +150,16 @@ def from_plain( def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) device = kwargs["device"] - if not is_device(torch.device(self.device).type, device): + if self.device.type == "xpu": + from torchao.dtypes import Int4XPULayout + int_data, scale, zero_point = self.get_plain() + int_data, scale, zero_point = int_data.to(self.device), scale.to(self.device), zero_point.to(self.device) + return Int4XPUAQTTensorImpl.from_plain(int_data, scale, zero_point, _layout=Int4XPULayout()) + elif not is_device(torch.device(self.device).type, device): raise ValueError( f"Int4CPUAQTTensorImpl does not support conversion from {self.device} to {device}" ) + return self.__class__( self.packed_weight.to(device), self.scale_and_zero.to(device), @@ -241,6 +249,10 @@ def block_size(self): def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros + if self.device.type != "cpu": + self.scale_and_zero = self.scale_and_zero.to("cpu") + self.packed_weight = self.packed_weight.to("cpu") + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) cur_shape = self.shape @@ -249,7 +261,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: eye_shape = original_shape[1] groupsize = int(original_shape[1] / scale.shape[-2]) block_size = (1, groupsize) - device = self.device + device = torch.device("cpu") original_dtype = self.scale_and_zero.dtype target_dtype = torch.int32 quant_min = 0