From 1c9f2fe134052538a3eb9ced19ae47b744d6ca63 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Wed, 25 Jun 2025 15:05:32 +0000
Subject: [PATCH 1/5] enable cpu to xpu

Signed-off-by: jiqing-feng
---
 torchao/dtypes/uintx/int4_cpu_layout.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py
index bf9446d265..452a1a55e9 100644
--- a/torchao/dtypes/uintx/int4_cpu_layout.py
+++ b/torchao/dtypes/uintx/int4_cpu_layout.py
@@ -39,6 +39,8 @@ class Int4CPULayout(Layout):
     pass
 
+from torchao.dtypes.uintx.int4_xpu_layout import Int4XPUAQTTensorImpl
+
 @register_layout(Int4CPULayout)
 class Int4CPUAQTTensorImpl(AQTTensorImpl):
     """TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only,
     used by tinygemm kernels `_weight_int4pack_mm_for_cpu`
     It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of
@@ -148,10 +150,15 @@ def from_plain(
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
         device = kwargs["device"]
-        if not is_device(torch.device(self.device).type, device):
+        if self.device.type == "xpu":
+            from torchao.dtypes import Int4XPULayout
+            int_data, scale, zero_point = self.get_plain()
+            return Int4XPUAQTTensorImpl.from_plain(int_data.to(device), scale.to(device), zero_point.to(device), _layout=Int4XPULayout())
+        elif not is_device(torch.device(self.device).type, device):
             raise ValueError(
                 f"Int4CPUAQTTensorImpl does not support conversion from {self.device} to {device}"
             )
+
         return self.__class__(
             self.packed_weight.to(device),
             self.scale_and_zero.to(device),
@@ -241,6 +248,10 @@ def block_size(self):
     def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros
 
+        if self.device.type != "cpu":
+            self.scale_and_zero = self.scale_and_zero.to("cpu")
+            self.packed_weight = self.packed_weight.to("cpu")
+
         scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
 
         cur_shape = self.shape
@@ -249,7 +260,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         eye_shape = original_shape[1]
         groupsize = int(original_shape[1] / scale.shape[-2])
         block_size = (1, groupsize)
-        device = self.device
+        device = torch.device("cpu")
         original_dtype = self.scale_and_zero.dtype
         target_dtype = torch.int32
         quant_min = 0

From affe779683b521c381439dcc266ef40f7d59a8cc Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Wed, 25 Jun 2025 15:12:27 +0000
Subject: [PATCH 2/5] fix format

Signed-off-by: jiqing-feng
---
 torchao/dtypes/uintx/int4_cpu_layout.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py
index 452a1a55e9..834fe04114 100644
--- a/torchao/dtypes/uintx/int4_cpu_layout.py
+++ b/torchao/dtypes/uintx/int4_cpu_layout.py
@@ -153,7 +153,8 @@ def to(self, *args, **kwargs):
         if self.device.type == "xpu":
             from torchao.dtypes import Int4XPULayout
             int_data, scale, zero_point = self.get_plain()
-            return Int4XPUAQTTensorImpl.from_plain(int_data.to(device), scale.to(device), zero_point.to(device), _layout=Int4XPULayout())
+            int_data, scale, zero_point = int_data.to(self.device), scale.to(self.device), zero_point.to(self.device)
+            return Int4XPUAQTTensorImpl.from_plain(int_data, scale, zero_point, _layout=Int4XPULayout())
         elif not is_device(torch.device(self.device).type, device):
             raise ValueError(
                 f"Int4CPUAQTTensorImpl does not support conversion from {self.device} to {device}"
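In practice, patches 1-2 make an Int4CPULayout-quantized module follow `.to("xpu")` instead of raising: the CPU impl unpacks to plain int4 data, moves it, and repacks it as an Int4XPUAQTTensorImpl. A minimal usage sketch, assuming the `quantize_`/`Int4WeightOnlyConfig` flow from torchao's quantization API and an XPU-enabled PyTorch build (both are assumptions about the surrounding API, not part of this diff):

import torch
from torchao.dtypes import Int4CPULayout
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Quantize a linear layer to int4 weight-only with the CPU (tinygemm) layout.
model = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16)
quantize_(model, Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout()))

# Before this series: ValueError ("does not support conversion from cpu to xpu").
# After patches 1-2: the weight is unpacked, moved, and repacked for XPU.
model.to("xpu")
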
From eb56d7893745fef63f99d9bad2e26812b3be759a Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Thu, 24 Jul 2025 16:35:11 +0000
Subject: [PATCH 3/5] enable int4 device convert

Signed-off-by: jiqing-feng
---
 torchao/dtypes/uintx/int4_cpu_layout.py      | 18 +++++-------------
 torchao/dtypes/uintx/int4_xpu_layout.py      | 12 ++++++------
 .../dtypes/uintx/tensor_core_tiled_layout.py | 15 ++++++---------
 torchao/dtypes/utils.py                      | 27 +++++++++++++++++++++++++++
 4 files changed, 44 insertions(+), 28 deletions(-)

diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py
index 46bf28aeb6..46fa6a3bd1 100644
--- a/torchao/dtypes/uintx/int4_cpu_layout.py
+++ b/torchao/dtypes/uintx/int4_cpu_layout.py
@@ -16,7 +16,7 @@
     AffineQuantizedTensor,
     register_layout,
 )
-from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device
+from torchao.dtypes.utils import Int4AQTTensorImpl, Layout, is_device
 from torchao.quantization.quant_primitives import (
     ZeroPointDomain,
     _quantize_affine_tinygemm,
@@ -39,10 +39,8 @@ class Int4CPULayout(Layout):
     pass
 
-from torchao.dtypes.uintx.int4_xpu_layout import Int4XPUAQTTensorImpl
-
 @register_layout(Int4CPULayout)
-class Int4CPUAQTTensorImpl(AQTTensorImpl):
+class Int4CPUAQTTensorImpl(Int4AQTTensorImpl):
     """TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only,
     used by tinygemm kernels `_weight_int4pack_mm_for_cpu`
     It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of
@@ -150,15 +148,9 @@ def from_plain(
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
         device = kwargs["device"]
-        if self.device.type == "xpu":
-            from torchao.dtypes import Int4XPULayout
-            int_data, scale, zero_point = self.get_plain()
-            int_data, scale, zero_point = int_data.to(self.device), scale.to(self.device), zero_point.to(self.device)
-            return Int4XPUAQTTensorImpl.from_plain(int_data, scale, zero_point, _layout=Int4XPULayout())
-        elif not is_device(torch.device(self.device).type, device):
-            raise ValueError(
-                f"{self.__class__.__name__} does not support conversion from {self.device} to {device}"
-            )
+        if torch.device(device).type != "cpu":
+            # Convert CPU tensor implementation to other devices.
+            return super().to(*args, **kwargs)
 
         return self.__class__(
             self.packed_weight.to(device),
diff --git a/torchao/dtypes/uintx/int4_xpu_layout.py b/torchao/dtypes/uintx/int4_xpu_layout.py
index 955a7a8610..9dce8121cd 100644
--- a/torchao/dtypes/uintx/int4_xpu_layout.py
+++ b/torchao/dtypes/uintx/int4_xpu_layout.py
@@ -17,7 +17,7 @@
     AffineQuantizedTensor,
     register_layout,
 )
-from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device
+from torchao.dtypes.utils import Int4AQTTensorImpl, Layout, is_device
 from torchao.quantization.quant_primitives import ZeroPointDomain
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
@@ -158,7 +158,7 @@ class Int4XPULayout(Layout):
 
 
 @register_layout(Int4XPULayout)
-class Int4XPUAQTTensorImpl(AQTTensorImpl):
+class Int4XPUAQTTensorImpl(Int4AQTTensorImpl):
     """
     TensorImpl for int4 XPU layout for affine quantized tensor, this is for int4 only,
     used by tinygemm kernels `_weight_int4pack_mm_xpu` and `_weight_int4pack_mm_with_zeros_and_scales` (TBD)
@@ -281,10 +281,10 @@ def from_plain(
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
         device = kwargs["device"]
-        if not is_device(torch.device(self.device).type, device):
-            raise ValueError(
-                f"Int4XPUAQTTensorImpl does not support conversion from {self.device} to {device}"
-            )
+        if torch.device(device).type != "xpu":
+            # Convert XPU tensor implementation to other devices.
+            return super().to(*args, **kwargs)
+
         return self.__class__(
             self.packed_weight.to(device),
             self.scale_and_zero.to(device) if self.scale_and_zero is not None else None,
diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py
index 591d9a9be1..f216c29337 100644
--- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py
+++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py
@@ -17,7 +17,7 @@
     AffineQuantizedTensor,
     register_layout,
 )
-from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device
+from torchao.dtypes.utils import Int4AQTTensorImpl, Layout, is_device
 from torchao.quantization.quant_primitives import (
     ZeroPointDomain,
     _get_reduction_params,
@@ -190,7 +190,7 @@ def extra_repr(self):
 
 
 @register_layout(TensorCoreTiledLayout)
-class TensorCoreTiledAQTTensorImpl(AQTTensorImpl):
+class TensorCoreTiledAQTTensorImpl(Int4AQTTensorImpl):
     """TensorImpl for tensor_core_tiled layout for affine quantized tensor, this is for int4 only,
     used by tinygemm kernels `_weight_int4pack_mm`
 
@@ -315,13 +315,10 @@ def quant_2d(int_data_2d):
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
         device = kwargs["device"]
-        # tensor core tiled layout supports both cpu and cuda but does not support the conversion
-        # between these two devices, in the future we should not use the same layout for
-        # cpu and cuda device: https://github.com/pytorch/ao/issues/1117
-        if not is_device(torch.device(self.device).type, device):
-            logging.warning(
-                f"TensorCoreTiledAQTTensorImpl does not support conversion from {self.device} to {device}"
-            )
+        if torch.device(device).type != "cuda":
+            # Convert CUDA tensor implementation to other devices.
+            return super().to(*args, **kwargs)
+
         return self.__class__(
             self.packed_weight.to(device),
             self.scale_and_zero.to(device),
diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py
index a07188a18d..abbde72d99 100644
--- a/torchao/dtypes/utils.py
+++ b/torchao/dtypes/utils.py
@@ -133,3 +133,30 @@ def __repr__(self):
         data, scale, zero_point = self.get_plain()
         _layout = self.get_layout()
         return f"{self.__class__.__name__}(data={str(data)}... , scale={str(scale)}... , zero_point={str(zero_point)}... , _layout={_layout})"
+
+
+class Int4AQTTensorImpl(AQTTensorImpl):
+    """
+    Base class for the tensor impl for `AffineQuantizedTensor`. This is for int4 only.
+
+    Note: This is not a user facing API, it's used by AffineQuantizedTensor to construct
+    the underlying implementation of an AQT based on layout.
+    """
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        device = kwargs["device"]
+        if torch.device(device).type == "cuda":
+            from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledLayout, TensorCoreTiledAQTTensorImpl
+            int_data, scale, zero_point = self.get_plain()
+            int_data, scale, zero_point = int_data.to(device), scale.to(device), zero_point.to(device)
+            return TensorCoreTiledAQTTensorImpl.from_plain(int_data, scale, zero_point, _layout=TensorCoreTiledLayout())
+        elif torch.device(device).type == "xpu":
+            from torchao.dtypes.uintx.int4_xpu_layout import Int4XPULayout, Int4XPUAQTTensorImpl
+            int_data, scale, zero_point = self.get_plain()
+            int_data, scale, zero_point = int_data.to(device), scale.to(device), zero_point.to(device)
+            return Int4XPUAQTTensorImpl.from_plain(int_data, scale, zero_point, _layout=Int4XPULayout())
+        elif torch.device(device).type == "cpu":
+            from torchao.dtypes.uintx.int4_cpu_layout import Int4CPULayout, Int4CPUAQTTensorImpl
+            int_data, scale, zero_point = self.get_plain()
+            int_data, scale, zero_point = int_data.to(device), scale.to(device), zero_point.to(device)
+            return Int4CPUAQTTensorImpl.from_plain(int_data, scale, zero_point, _layout=Int4CPULayout())
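
Patch 3 replaces the per-class special cases with one base class: every int4 impl's `to()` now falls back to `Int4AQTTensorImpl.to()`, which round-trips through the plain (unpacked) representation and repacks with the impl that handles the target device type. A self-contained toy sketch of that pattern (names like `PlainInt4Impl` and `IMPLS` are illustrative, not torchao API):

import torch

class PlainInt4Impl:
    """Toy impl that keeps tensors unpacked; the real impls pack per device."""

    def __init__(self, int_data, scale, zero_point):
        self.int_data, self.scale, self.zero_point = int_data, scale, zero_point

    @classmethod
    def from_plain(cls, int_data, scale, zero_point):
        return cls(int_data, scale, zero_point)

    def get_plain(self):
        return self.int_data, self.scale, self.zero_point

    def to(self, device):
        # The pattern the base class centralizes: unpack -> move -> repack
        # with whatever impl is registered for the target device type.
        int_data, scale, zero_point = self.get_plain()
        impl_cls = IMPLS[torch.device(device).type]
        return impl_cls.from_plain(
            int_data.to(device), scale.to(device), zero_point.to(device)
        )

# One entry per backend; torchao maps cpu/cuda/xpu to three different packings.
IMPLS = {"cpu": PlainInt4Impl, "cuda": PlainInt4Impl, "xpu": PlainInt4Impl}
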
+ """ + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs["device"] + if torch.device(device).type == "cuda": + from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledLayout, TensorCoreTiledAQTTensorImpl + int_data, scale, zero_point = self.get_plain() + int_data, scale, zero_point = int_data.to(device), scale.to(device), zero_point.to(device) + return TensorCoreTiledAQTTensorImpl.from_plain(int_data, scale, zero_point, _layout=TensorCoreTiledLayout()) + elif torch.device(device).type == "xpu": + from torchao.dtypes.uintx.int4_xpu_layout import Int4XPULayout, Int4XPUAQTTensorImpl + int_data, scale, zero_point = self.get_plain() + int_data, scale, zero_point = int_data.to(device), scale.to(device), zero_point.to(device) + return Int4XPUAQTTensorImpl.from_plain(int_data, scale, zero_point, _layout=Int4XPULayout()) + elif torch.device(device).type == "cpu": + from torchao.dtypes.uintx.int4_cpu_layout import Int4CPULayout, Int4CPUAQTTensorImpl + int_data, scale, zero_point = self.get_plain() + int_data, scale, zero_point = int_data.to(device), scale.to(device), zero_point.to(device) + return Int4CPUAQTTensorImpl.from_plain(int_data, scale, zero_point, _layout=Int4CPULayout()) From 333223e2de798dca1656bf5acac000c1b9c402e4 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 24 Jul 2025 16:41:26 +0000 Subject: [PATCH 4/5] fix device extract Signed-off-by: jiqing-feng --- torchao/dtypes/uintx/int4_cpu_layout.py | 3 +-- torchao/dtypes/uintx/int4_xpu_layout.py | 3 +-- torchao/dtypes/uintx/tensor_core_tiled_layout.py | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index 46fa6a3bd1..a354748ab8 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -146,8 +146,7 @@ def from_plain( return cls(packed_weight, scale_and_zero, False, _layout) def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - device = kwargs["device"] + device = self._get_to_kwargs(*args, **kwargs)["device"] if torch.device(device).type != "cpu": # Convert CPU tensor implementation to other devices. return super().to(*args, **kwargs) diff --git a/torchao/dtypes/uintx/int4_xpu_layout.py b/torchao/dtypes/uintx/int4_xpu_layout.py index 9dce8121cd..b044090001 100644 --- a/torchao/dtypes/uintx/int4_xpu_layout.py +++ b/torchao/dtypes/uintx/int4_xpu_layout.py @@ -279,8 +279,7 @@ def from_plain( ) def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - device = kwargs["device"] + device = self._get_to_kwargs(*args, **kwargs)["device"] if torch.device(device).type != "xpu": # Convert XPU tensor implementation to other devices. return super().to(*args, **kwargs) diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index f216c29337..1bba667e33 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -313,8 +313,7 @@ def quant_2d(int_data_2d): return cls(packed_weight, scale_and_zero, False, _layout) def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - device = kwargs["device"] + device = self._get_to_kwargs(*args, **kwargs)["device"] if torch.device(device).type != "cuda": # Convert CUDA tensor implementation to other devices. 
From a9a68a3c4fc12e8366fab8dc31b3b7e1fceb6fc6 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Thu, 24 Jul 2025 16:47:12 +0000
Subject: [PATCH 5/5] fix format

Signed-off-by: jiqing-feng
---
 torchao/dtypes/utils.py | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py
index abbde72d99..71ccaa1617 100644
--- a/torchao/dtypes/utils.py
+++ b/torchao/dtypes/utils.py
@@ -143,20 +143,15 @@ class Int4AQTTensorImpl(AQTTensorImpl):
     the underlying implementation of an AQT based on layout.
     """
     def to(self, *args, **kwargs):
-        kwargs = self._get_to_kwargs(*args, **kwargs)
-        device = kwargs["device"]
+        device = self._get_to_kwargs(*args, **kwargs)["device"]
+        int_data, scale, zero_point = self.get_plain()
+        int_data, scale, zero_point = int_data.to(device), scale.to(device), zero_point.to(device)
         if torch.device(device).type == "cuda":
             from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledLayout, TensorCoreTiledAQTTensorImpl
-            int_data, scale, zero_point = self.get_plain()
-            int_data, scale, zero_point = int_data.to(device), scale.to(device), zero_point.to(device)
             return TensorCoreTiledAQTTensorImpl.from_plain(int_data, scale, zero_point, _layout=TensorCoreTiledLayout())
         elif torch.device(device).type == "xpu":
             from torchao.dtypes.uintx.int4_xpu_layout import Int4XPULayout, Int4XPUAQTTensorImpl
-            int_data, scale, zero_point = self.get_plain()
-            int_data, scale, zero_point = int_data.to(device), scale.to(device), zero_point.to(device)
             return Int4XPUAQTTensorImpl.from_plain(int_data, scale, zero_point, _layout=Int4XPULayout())
         elif torch.device(device).type == "cpu":
             from torchao.dtypes.uintx.int4_cpu_layout import Int4CPULayout, Int4CPUAQTTensorImpl
-            int_data, scale, zero_point = self.get_plain()
-            int_data, scale, zero_point = int_data.to(device), scale.to(device), zero_point.to(device)
             return Int4CPUAQTTensorImpl.from_plain(int_data, scale, zero_point, _layout=Int4CPULayout())
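
With all five patches applied, one int4-quantized weight can hop between backends, repacking into each device's native layout on arrival. An end-to-end sketch under the same API assumptions as the earlier examples (availability guards included since each hop needs the corresponding backend):

import torch
from torchao.dtypes import Int4CPULayout
from torchao.quantization import Int4WeightOnlyConfig, quantize_

model = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16)
quantize_(model, Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout()))

if torch.cuda.is_available():
    model.to("cuda")  # repacked as TensorCoreTiledAQTTensorImpl
if hasattr(torch, "xpu") and torch.xpu.is_available():
    model.to("xpu")   # repacked as Int4XPUAQTTensorImpl
model.to("cpu")       # repacked as Int4CPUAQTTensorImpl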