
Commit ab3792e

Allow different dtype for scales_and_zeros (#1923)
Summary: Pull Request resolved: #1923

D59410096 tried to allow a different dtype for scales and zeros, but at https://www.internalfb.com/code/fbsource/fbcode/pytorch/ao/torchao/dtypes/uintx/tensor_core_tiled_layout.py?lines=262, where `pack_tinygemm_scales_and_zeros` is called, no dtype argument is passed. As a result, if the scales and zeros are not `torch.bfloat16`, the call fails in `guard_dtype_size`. This diff passes the dtype of the scales and zeros into the call to avoid the issue.

Reviewed By: andrewor14

Differential Revision: D71079504
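For context, a minimal, self-contained sketch of the failure mode. The bodies of `guard_dtype_size` and `pack_tinygemm_scales_and_zeros` below are simplified stand-ins for the torchao helpers, not the exact library code; only the shape of the bug (a `dtype` parameter that defaults to `torch.bfloat16` and a guard that rejects mismatched inputs) is taken from the summary above:

import torch

def guard_dtype_size(tensor, name, dtype=None, size=None):
    # Simplified guard: reject tensors whose dtype or size differs from the expected ones.
    if dtype is not None and tensor.dtype != dtype:
        raise ValueError(f"Expected {name} to have dtype {dtype}, got {tensor.dtype}")
    if size is not None and tensor.size() != size:
        raise ValueError(f"Expected {name} to have size {size}, got {tensor.size()}")

def pack_tinygemm_scales_and_zeros(scales, zeros, dtype=torch.bfloat16):
    # Both inputs are checked against `dtype`, which defaults to bfloat16.
    guard_dtype_size(scales, "scales", dtype=dtype, size=zeros.size())
    guard_dtype_size(zeros, "zeros", dtype=dtype)
    # Pack scales and zeros together into a single tensor.
    return torch.cat([scales.unsqueeze(-1), zeros.unsqueeze(-1)], dim=-1).transpose(0, 1).contiguous()

scale = torch.ones(8, 4, dtype=torch.float16)
zero_point = torch.zeros(8, 4, dtype=torch.float16)

# Before this fix: the dtype argument is omitted, defaults to bfloat16, and
# float16 scales/zeros raise a ValueError inside guard_dtype_size:
#   pack_tinygemm_scales_and_zeros(scale, zero_point)
# After this fix: the caller forwards the actual dtype, so the guard passes.
packed = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype)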
1 parent 386e219 commit ab3792e

File tree

1 file changed (+1, −1)


torchao/dtypes/uintx/tensor_core_tiled_layout.py

Lines changed: 1 addition & 1 deletion
@@ -264,7 +264,7 @@ def from_plain(
         zero_point = zero_point.reshape(int_data.shape[0], -1)
         from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
 
-        scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
+        scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype)
         return cls(packed_weight, scale_and_zero, False, _layout)
 
     def to(self, *args, **kwargs):
