
Commit 4413f75

Autoquant
Summary: there is currently an issue where, for models with multiple linear layers, dynamic quantization gives very slow results on the later linear layers; it is unclear why.

Test Plan: python test/test.py -k "autoquant"

<class 'torchao.quantization.autoquant.DefaultLinear'> (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840]))
187.4432 0
AUTOTUNE addmm(65536x3840, 65536x1280, 1280x3840)
  bias_addmm 2.9764 ms 100.0%
  triton_mm_1 3.6858 ms 80.8%
  triton_mm_2 3.7502 ms 79.4%
  addmm 3.7887 ms 78.6%
  triton_mm_3 4.1547 ms 71.6%
  triton_mm_4 4.2022 ms 70.8%
  triton_mm_0 4.7970 ms 62.0%
  triton_mm_8 4.9596 ms 60.0%
  triton_mm_7 5.4343 ms 54.8%
  triton_mm_10 6.9352 ms 42.9%
SingleProcess AUTOTUNE takes 5.6320 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f98800eb760>
  f(*args, **kwargs)  3.08 ms  1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.autoquant.DefaultLinear'> 3.07677136734128
1311.548416 0
<class 'torchao.quantization.subclass.Int8WeightOnlyQuantizedLinearWeight'> (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840]))
1311.548416 0
AUTOTUNE mixed_mm(65536x1280, 1280x3840)
  fallback_mixed_mm 2.5089 ms 100.0%
  triton_mm_13 6.4153 ms 39.1%
  triton_mm_14 6.6832 ms 37.5%
  triton_mm_12 7.0896 ms 35.4%
  triton_mm_16 7.5022 ms 33.4%
  triton_mm_15 7.8426 ms 32.0%
  triton_mm_19 9.5269 ms 26.3%
  triton_mm_20 11.2033 ms 22.4%
  triton_mm_17 13.1675 ms 19.1%
  triton_mm_18 13.8004 ms 18.2%
SingleProcess AUTOTUNE takes 2.4977 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f986ff12050>
  f(*args, **kwargs)  3.68 ms  1 measurement, 20 runs, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f986ff27b80>
  f(*args, **kwargs)  3.10 ms  1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.subclass.Int8WeightOnlyQuantizedLinearWeight'> 3.6846738075837493 3.1023880932480097
2144.447488 25
<class 'torchao.quantization.subclass.Int8DynamicallyQuantizedLinearWeight'> (torch.Size([65536, 1280]), torch.Size([3840, 1280]), torch.Size([3840]))
2144.447488 25
AUTOTUNE int_mm(65536x1280, 1280x3840, 65536x3840)
  triton_mm_43 2.0319 ms 100.0%
  triton_mm_35 2.8135 ms 72.2%
  triton_mm_42 3.1552 ms 64.4%
  triton_mm_36 3.1754 ms 64.0%
  triton_mm_44 3.3460 ms 60.7%
  triton_mm_41 3.4036 ms 59.7%
  triton_mm_37 3.5030 ms 58.0%
  triton_mm_34 3.6553 ms 55.6%
  triton_mm_38 3.9232 ms 51.8%
  triton_mm_40 9.1934 ms 22.1%
SingleProcess AUTOTUNE takes 8.1948 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9892843f40>
  f(*args, **kwargs)  3.13 ms  1 measurement, 20 runs, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f986cfd33a0>
  f(*args, **kwargs)  2.21 ms  1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.subclass.Int8DynamicallyQuantizedLinearWeight'> 3.1286065466701984 2.210085652768612
2144.447488 22
<class 'torchao.quantization.autoquant.DefaultLinear'> (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280]))
2144.447488 22
AUTOTUNE addmm(65536x1280, 65536x3840, 3840x1280)
  bias_addmm 2.7966 ms 100.0%
  addmm 3.0447 ms 91.9%
  triton_mm_57 3.5612 ms 78.5%
  triton_mm_58 3.6919 ms 75.7%
  triton_mm_59 4.1908 ms 66.7%
  triton_mm_60 4.2350 ms 66.0%
  triton_mm_56 4.7210 ms 59.2%
  triton_mm_64 4.9001 ms 57.1%
  triton_mm_63 5.5218 ms 50.6%
  triton_mm_66 7.1417 ms 39.2%
SingleProcess AUTOTUNE takes 6.3734 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9888dd2b30>
  f(*args, **kwargs)  3.33 ms  1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.autoquant.DefaultLinear'> 3.329739556647837
2228.913664 39
<class 'torchao.quantization.subclass.Int8WeightOnlyQuantizedLinearWeight'> (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280]))
2228.913664 39
AUTOTUNE mixed_mm(65536x3840, 3840x1280)
  fallback_mixed_mm 2.3987 ms 100.0%
  triton_mm_70 6.9153 ms 34.7%
  triton_mm_72 7.1634 ms 33.5%
  triton_mm_69 7.3164 ms 32.8%
  triton_mm_68 7.5070 ms 32.0%
  triton_mm_71 7.5631 ms 31.7%
  triton_mm_76 10.7759 ms 22.3%
  triton_mm_75 11.0692 ms 21.7%
  triton_mm_73 12.8898 ms 18.6%
  triton_mm_77 13.3715 ms 17.9%
SingleProcess AUTOTUNE takes 6.2342 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9880133fd0>
  f(*args, **kwargs)  3.48 ms  1 measurement, 20 runs, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f988175b610>
  f(*args, **kwargs)  3.22 ms  1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.subclass.Int8WeightOnlyQuantizedLinearWeight'> 3.4762858413159847 3.2240213360637426
2228.913664 38
<class 'torchao.quantization.subclass.Int8DynamicallyQuantizedLinearWeight'> (torch.Size([65536, 3840]), torch.Size([1280, 3840]), torch.Size([1280]))
2228.913664 38
AUTOTUNE int_mm(65536x3840, 3840x1280, 65536x1280)
  triton_mm_99 1.4307 ms 100.0%
  triton_mm_100 1.9041 ms 75.1%
  triton_mm_91 2.6079 ms 54.9%
  triton_mm_98 2.6363 ms 54.3%
  triton_mm_92 2.6691 ms 53.6%
  triton_mm_93 3.0178 ms 47.4%
  triton_mm_97 3.0233 ms 47.3%
  triton_mm_94 3.1872 ms 44.9%
  triton_mm_90 3.6072 ms 39.7%
  triton_mm_96 8.4695 ms 16.9%
SingleProcess AUTOTUNE takes 8.1095 seconds
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9881782f80>
  f(*args, **kwargs)  145.38 ms  1 measurement, 20 runs, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9892843f70>
  f(*args, **kwargs)  143.98 ms  1 measurement, 20 runs, 1 thread
<class 'torchao.quantization.subclass.Int8DynamicallyQuantizedLinearWeight'> 145.37517526187003 143.98446583654732
2230.364672 79

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 67787a1
Pull Request resolved: #38
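For reference, the new TestAutoQuant.test_auto_quant case exercises the API roughly as follows (a minimal sketch distilled from the diff below; the inductor flags are the ones the test sets):

import torch
from torchao.quantization.quant_api import do_autoquant

# Inductor settings used by the new test; they enable the mixed_mm and
# fused int_mm paths that the autoquant candidates are benchmarked against.
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.use_mixed_mm = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.coordinate_descent_tuning = True

model = torch.nn.Sequential(
    torch.nn.Linear(1280, 3840),
    torch.nn.ReLU(),
    torch.nn.Linear(3840, 1280),
).to("cuda").to(torch.bfloat16)
example_input = torch.randn(65536, 1280, device="cuda", dtype=torch.bfloat16)

# Runs the model once, benchmarks every candidate weight subclass per linear
# shape, then swaps the fastest option into each layer.
do_autoquant(model, example_input)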
1 parent 969038f commit 4413f75

File tree

6 files changed: +344 lines, -2 lines


test/test.py

Lines changed: 16 additions & 0 deletions

@@ -24,6 +24,7 @@
     change_linear_weights_to_int8_woqtensors,
     change_linear_weights_to_int4_woqtensors,
     _replace_with_custom_fn_if_matches_filter,
+    do_autoquant
 )
 from torchao.quantization.quant_primitives import (
     dequantize_per_channel,
@@ -1195,6 +1196,21 @@ def test_on_dummy_distilbert(self):
         print("sqnr_pt_quant", sqnr_pt_quant)
         self.assertTrue(sqnr_sq >= 8.0)
 
+class TestAutoQuant(unittest.TestCase):
+    def test_auto_quant(self):
+        model = torch.nn.Sequential(
+            # torch.nn.Linear(5120,1280),
+            # torch.nn.ReLU(),
+            torch.nn.Linear(1280,3840),
+            torch.nn.ReLU(),
+            torch.nn.Linear(3840,1280),
+        ).to("cuda").to(torch.bfloat16)
+        example_input = torch.randn(65536,1280, device="cuda", dtype=torch.bfloat16)
+        torch._inductor.config.epilogue_fusion = False
+        torch._inductor.config.use_mixed_mm = True
+        torch._inductor.config.force_fuse_int_mm_with_mul = True
+        torch._inductor.config.coordinate_descent_tuning = True
+        do_autoquant(model, example_input)
 
 if __name__ == "__main__":
     unittest.main()

test/test_autoquant.py

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# mypy: ignore-errors
+import copy
+import unittest
+
+import torch
+import torch.nn as nn
+from torchao.quantization.quant_api import (
+    change_linears_to_autoquantizable,
+    change_autoquantizable_to_quantized
+)
+from torchao.quantization.quant_api import do_autoquant
+from torch._dynamo import config
+torch.manual_seed(0)
+config.cache_size_limit = 100
+
+
+class AutoquantTests(unittest.TestCase):
+    def test_autoquant_e2e(self):
+        model = torch.nn.Sequential(torch.nn.Linear(32,32), torch.nn.ReLU(), torch.nn.Linear(32,32)).cuda().to(torch.bfloat16)
+        print(model, model[0].weight)
+        example_input = torch.randn((1,64,32), dtype=torch.bfloat16, device="cuda")
+        out = model(example_input)
+        print(out.sum())
+        do_autoquant(model, example_input)
+        print(model, model[0].weight)
+        print(model(example_input).sum())
+
+if __name__ == "__main__":
+    unittest.main()

torchao/quantization/autoquant.py

Lines changed: 200 additions & 0 deletions

@@ -0,0 +1,200 @@
+import torch
+
+from .subclass import ( # noqa
+    Int8DynamicallyQuantizedLinearWeight,
+    Int8WeightOnlyQuantizedLinearWeight,
+    QuantizedLinearWeightBase,
+)
+from torch.utils._python_dispatch import return_and_correct_aliasing
+from .utils import benchmark
+
+aten = torch.ops.aten
+
+AUTOQUANT_CACHE = {}
+
+def check_cache(shape, cls):
+    if shape in AUTOQUANT_CACHE:
+        return AUTOQUANT_CACHE[shape].get(cls, None)
+    else:
+        return None
+
+def update_cache(shape, cls, res):
+    if not shape in AUTOQUANT_CACHE:
+        AUTOQUANT_CACHE[shape] = {}
+    AUTOQUANT_CACHE[shape][cls] = res
+
+class AutoQuantizableLinearWeight(torch.Tensor):
+    """
+    When run, finds the best type of quantization for this tensor and swaps itself with that.
+    """
+    @staticmethod
+    def __new__(cls, weight, qtensor_class_list, *args, **kwargs):
+        kwargs["device"] = weight.device
+        kwargs["layout"] = (
+            kwargs.get("layout") if kwargs.get("layout", False) else weight.layout
+        )
+        kwargs["dtype"] = (
+            kwargs.get("dtype") if kwargs.get("dtype", False) else weight.dtype
+        )
+        kwargs["requires_grad"] = False
+        shape = kwargs.pop("shape", weight.shape)
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(self, weight, qtensor_class_list, *args, **kwargs):
+        self.weight = weight
+        self.qtensor_class_list = qtensor_class_list
+        self.cache_shape = None
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}(data={self.weight}, shape={self.shape}, "
+            f"device={self.device}, dtype={self.dtype}, qtensor_class_list={self.qtensor_class_list})"
+        )
+
+    @staticmethod
+    def tune_autoquant(act_mat, w_autoquant, bias):
+        orig_shape = act_mat.shape
+        act_mat = act_mat.reshape(-1, act_mat.shape[-1])
+        cache_shape = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape)
+        w_autoquant.cache_shape = cache_shape
+        for cur_cls in w_autoquant.qtensor_class_list:
+            if check_cache(cache_shape, cur_cls) is None:
+                with torch.no_grad():
+                    print(cur_cls, cache_shape)
+                    print(torch.cuda.max_memory_allocated()/1e6, torch.cuda.memory_usage())
+                    res = cur_cls._autoquant_test(act_mat.clone(), w_autoquant.weight.clone(), None if bias is None else bias.clone())
+                    update_cache(cache_shape, cur_cls, res)
+                    print(torch.cuda.max_memory_allocated()/1e6, torch.cuda.memory_usage())
+        y = torch.mm(act_mat, w_autoquant.weight.t())
+        y = y.reshape(*orig_shape[:-1], y.shape[-1])
+        if bias is not None:
+            y += bias
+        return y
+
+    def to_quantized(self):
+        if self.cache_shape is None or self.cache_shape not in AUTOQUANT_CACHE:
+            raise RuntimeError("must run module normally to find best quantization option")
+        best_time = torch.inf
+        best_cls = None
+        for cur_cls in self.qtensor_class_list:
+            cls_res = AUTOQUANT_CACHE[self.cache_shape].get(cur_cls, torch.inf)
+            if best_time >= cls_res:
+                best_time = cls_res
+                best_cls = cur_cls
+        # need to handle random cls args/kwargs?
+        self = best_cls.from_float(self.weight)
+        return self
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.weight), self.qtensor_class_list, dtype=self.dtype
+        )
+
+    def __tensor_flatten__(self):
+        return ["weight"], [self.qtensor_class_list, self.dtype, self.shape]
+
+    @classmethod
+    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
+        weight = tensor_data_dict["weight"]
+        qtensor_class_list, dtype, shape = tensor_attributes[0]
+        return cls(weight, qtensor_class_list, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride)
+
+    @classmethod
+    def from_float(cls, weight, qtensor_class_list):
+        return cls(weight, qtensor_class_list)
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is torch.nn.functional.linear:
+            mat1, w_autoquant, bias = (
+                args[0],
+                args[1],
+                args[2] if len(args) > 2 else None
+            )
+            return cls.tune_autoquant(mat1, w_autoquant, bias)
+
+        try:
+            with torch._C.DisableTorchFunctionSubclass():
+                return func(*args, **kwargs)
+        except:
+            print(f"ERR: subclass doesn't implement {func}")
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))
+
+
+class DefaultLinear(torch.Tensor):
+    """
+    A class to be used in concert with AutoQuantizableLinearWeight to provide a
+    default/non-quantized option. Only implements the bare minimum needed to work with the
+    AutoQuantizableLinearWeight class using the same interfaces that would normally be
+    used by QTensor subclasses but for a default linear op instead.
+    """
+    def __init__(self):
+        super().__init__()
+
+    @classmethod
+    def _autoquant_test(cls, act_mat, weight, bias):
+        w_qtensor = cls.from_float(weight)
+        q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
+        with torch.no_grad():
+            res = benchmark(q_c_op, act_mat, w_qtensor, bias)
+        print(cls, res)
+        return res
+
+    @staticmethod
+    def _quantized_op(act_mat, w_qtensor, bias):
+        return torch.nn.functional.linear(act_mat, w_qtensor, bias)
+
+    @classmethod
+    def from_float(cls, weight):
+        return weight
+
+DEFAULT_CLASS_LIST = [
+    DefaultLinear,
+    Int8WeightOnlyQuantizedLinearWeight,
+    Int8DynamicallyQuantizedLinearWeight,
+]
+
+if False:
+    # def _get_to_kwargs(self, *args, **kwargs):
+    #     device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
+    #     device = self.device if device is None else device
+    #     dtype = self.dtype if dtype is None else dtype
+    #     memory_format = (
+    #         memory_format if memory_format is not None else torch.preserve_format
+    #     )
+    #     kwargs = {
+    #         "device": device,
+    #         "dtype": dtype,
+    #         "memory_format": memory_format,
+    #     }
+    #     return kwargs
+
+    # def to(self, *args, **kwargs):
+    #     kwargs = self._get_to_kwargs(*args, **kwargs)
+    #     return self.__class__(
+    #         self.int_data.to(kwargs["device"]),
+    #         self.q_scales.to(kwargs["device"]),
+    #         self.transposed,
+    #         self.shape,
+    #         **kwargs,
+    #     )
+
+    # def _apply_fn_to_data(self, fn):
+    #     return self.__class__(
+    #         fn(self.int_data), fn(self.q_scales), self.transposed, self.shape, dtype=self.dtype
+    #     )
+
+    # def _change_shape(self, shape):
+    #     return self.__class__(
+    #         self.int_data, self.q_scales, self.transposed, shape, dtype=self.dtype
+    #     )
+
+    # def half(self):
+    #     return self.to(torch.float16)
+    pass
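The DefaultLinear docstring above spells out the interface a candidate class needs in order to take part in autoquant: from_float, _quantized_op, and _autoquant_test. As a rough sketch of how an extra candidate could be plugged in (the class and list names below are hypothetical, not part of this commit):

import torch
from torchao.quantization.autoquant import DEFAULT_CLASS_LIST
from torchao.quantization.utils import benchmark  # same helper autoquant.py imports

class HypotheticalLinearWeight(torch.Tensor):
    # Illustrative candidate: mirrors DefaultLinear, so it benchmarks the plain
    # high-precision linear; a real candidate would quantize the weight in from_float.

    @classmethod
    def from_float(cls, weight):
        return weight

    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        return torch.nn.functional.linear(act_mat, w_qtensor, bias)

    @classmethod
    def _autoquant_test(cls, act_mat, weight, bias):
        # The returned timing is cached per (activation, weight, bias) shape and
        # compared in AutoQuantizableLinearWeight.to_quantized().
        w_qtensor = cls.from_float(weight)
        q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
        with torch.no_grad():
            return benchmark(q_c_op, act_mat, w_qtensor, bias)

# Hypothetical extended list, passed to do_autoquant via qtensor_class_list=...
MY_CLASS_LIST = DEFAULT_CLASS_LIST + [HypotheticalLinearWeight]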

torchao/quantization/quant_api.py

Lines changed: 37 additions & 1 deletion

@@ -28,6 +28,7 @@
 from .weight_only import (
     WeightOnlyInt8QuantLinear,
 )
+from .autoquant import AutoQuantizableLinearWeight, DEFAULT_CLASS_LIST
 
 __all__ = [
     "apply_weight_only_int8_quant",
@@ -95,9 +96,11 @@ def apply_dynamic_quant(model, filter_fn=None):
 
 
 def _get_subclass_inserter(cls, **kwargs):
+    method = kwargs.pop("method", "from_float")
     def insert_subclass(lin):
         lin.weight = torch.nn.Parameter(
-            cls.from_float(lin.weight, **kwargs), requires_grad=False
+            # cls.from_float(...)
+            getattr(cls, method)(lin.weight, **kwargs), requires_grad=False
         )
         return lin
 
@@ -153,6 +156,39 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs):
         filter_fn,
     )
 
+
+def change_linears_to_autoquantizable(model, **kwargs):
+    filter_fn = kwargs.pop("filter_fn", _is_linear)
+    _replace_with_custom_fn_if_matches_filter(
+        model,
+        _get_subclass_inserter(AutoQuantizableLinearWeight, **kwargs),
+        filter_fn if filter_fn is not None else _is_linear,
+    )
+
+def change_autoquantizable_to_quantized(model, **kwargs):
+    filter_fn = kwargs.pop(
+        "filter_fn",
+        lambda mod, *args:
+            _is_linear(mod, *args) and
+            isinstance(mod.weight, AutoQuantizableLinearWeight)
+    )
+    _replace_with_custom_fn_if_matches_filter(
+        model,
+        _get_subclass_inserter(
+            AutoQuantizableLinearWeight, method="to_quantized", **kwargs
+        ),
+        filter_fn,
+    )
+
+@torch.no_grad()
+def do_autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=_is_linear):
+    change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list)
+    if not isinstance(example_input, (tuple, list)):
+        example_input = [example_input]
+    model(*example_input)
+    change_autoquantizable_to_quantized(model)
+    return model
+
 def swap_conv2d_1x1_to_linear(model, filter_fn=None):
     """
     Changes all conv2d 1x1 modules to equivalent linear modules so that they can then be quantized.
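do_autoquant is a thin wrapper around the two helpers added above; a sketch of the equivalent manual flow (assuming a CUDA model and a single example input, as in the tests):

import torch
from torchao.quantization.autoquant import DEFAULT_CLASS_LIST
from torchao.quantization.quant_api import (
    change_linears_to_autoquantizable,
    change_autoquantizable_to_quantized,
)

@torch.no_grad()
def manual_autoquant(model, example_input):
    # 1) Wrap every linear weight in an AutoQuantizableLinearWeight that knows
    #    which candidate classes to try.
    change_linears_to_autoquantizable(model, qtensor_class_list=DEFAULT_CLASS_LIST)
    # 2) Run the model once; each wrapper benchmarks its candidates for the
    #    observed (activation, weight, bias) shapes and caches the timings.
    model(example_input)
    # 3) Swap each wrapper for the fastest candidate found in the cache.
    change_autoquantizable_to_quantized(model)
    return model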

torchao/quantization/subclass.py

Lines changed: 43 additions & 1 deletion

@@ -13,8 +13,11 @@
     groupwise_affine_quantize_tensor,
     quant_int8_dynamic_per_token_linear,
     unpack_tinygemm_scales_and_zeros,
+    quantize_activation_per_token_absmax,
+    quant_int8_per_token_matmul,
+    safe_int_mm,
 )
-from .utils import find_multiple
+from .utils import find_multiple, benchmark
 import warnings
 
 
@@ -197,6 +200,25 @@ def _quantized_op(act_mat, w_qtensor, bias):
             act_mat, w_qtensor.int_data, w_qtensor.q_scales, bias, act_mat.dtype
         )
 
+    @classmethod
+    def _autoquant_test(cls, act_mat, weight, bias):
+        w_qtensor = cls.from_float(weight)
+        q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
+        with torch.no_grad():
+            res = benchmark(q_c_op, act_mat, w_qtensor, bias)
+        x_vals_int8, x_scales = quantize_activation_per_token_absmax(
+            act_mat.reshape(-1, act_mat.shape[-1])
+        )
+        quantized_matmul = (
+            lambda x_vals_int8, x_scales, w_vals_int8:
+                safe_int_mm(x_vals_int8, w_vals_int8) * x_scales
+        )
+        q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune")
+        with torch.no_grad():
+            res2 = benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data)
+        print(cls, res, res2)
+        return (res + res2) / 2
+
     def dequantize(self, dtype=None):
         """
         Obtain the dequantized version of the quantized tensor subclass
@@ -292,6 +314,26 @@ def _quantized_op(act_mat, w_qtensor, bias):
         y += bias
         return y.to(orig_dtype)
 
+    @classmethod
+    def _autoquant_test(cls, act_mat, weight, bias):
+        w_qtensor = cls.from_float(weight)
+        q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
+        with torch.no_grad():
+            res = benchmark(q_c_op, act_mat, w_qtensor, bias)
+
+        quantized_matmul = (
+            lambda act_mat, w_vals_int8:
+                torch.mm(act_mat, w_vals_int8.to(act_mat.dtype))
+        )
+        q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune")
+        with torch.no_grad():
+            res2 = benchmark(
+                q_c_matmul,
+                act_mat.reshape(-1, act_mat.shape[-1]),
+                w_qtensor.int_data,
+            )
+        print(cls, res, res2)
+        return (res + res2) / 2
 
 class Int4WeightOnlyQuantizedLinearWeight(QuantizedLinearWeightBase):
     """
