 try:
     import gemlite
-    from gemlite.core import GemLiteLinearTriton
 except:
     gemlite = None
@@ -51,18 +50,6 @@ def _same_metadata(
     )


-def scale_activations_no_scaling(x):
-    return x, None
-
-
-def scale_activations_int8(x):
-    x_shape = x.shape
-    out_x = x.view(-1, x.shape[-1])
-    scaled_x = torch.abs(out_x).amax(axis=1, keepdim=True) / 127
-    out_x = torch.round(out_x / scaled_x).to(dtype=torch.int8)
-    return out_x.view(x_shape), scaled_x
-
-
 def get_gemlite_quant_kwargs(bit_width, group_size, dtype):
     from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

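Note: the removed scale_activations_int8 helper implemented symmetric per-row int8 activation quantization (absmax / 127, round, cast). A minimal standalone sketch of that computation, kept only as context for the removed code (function name hypothetical); after this change activation handling is driven by gemlite via the meta_args passed to forward_functional below:

    import torch

    def quantize_activations_int8_rowwise(x: torch.Tensor):
        # Symmetric per-row quantization: divide each row by its absmax / 127,
        # round to the nearest integer, and cast to int8. Mirrors the removed
        # scale_activations_int8 helper.
        x_2d = x.view(-1, x.shape[-1])
        scale = torch.abs(x_2d).amax(dim=1, keepdim=True) / 127
        q = torch.round(x_2d / scale).to(torch.int8)
        return q.view(x.shape), scale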
@@ -93,8 +80,6 @@ def get_gemlite_aqt_kwargs(
     weight,
     group_size=64,
     bit_width=4,
-    packing_bitwidth=32,
-    contiguous=None,
     use_hqq=True,
 ):
     if gemlite is None:
@@ -106,12 +91,7 @@ def get_gemlite_aqt_kwargs(
         4,
         8,
     ], f"gemlite only works with bit_width 4,8 but got {bit_width}"
-    assert packing_bitwidth in [
-        8,
-        16,
-        32,
-        None,
-    ], f"gemlite needs packing_bitwidth in [8, 16, 32] but got {packing_bitwidth}"
+
     assert weight.dtype in [torch.float16, torch.bfloat16], (
         f"gemlite only works with dtype torch.float16 or torch.bfloat16 but got {weight.dtype}"
     )
@@ -127,8 +107,6 @@ def get_gemlite_aqt_kwargs(
     aqt_kwargs["_layout"] = GemlitePackedLayout(
         group_size=group_size,
         bit_width=bit_width,
-        packing_bitwidth=packing_bitwidth,
-        contiguous=contiguous,
     )
     aqt_kwargs["use_hqq"] = use_hqq
     return aqt_kwargs
@@ -138,8 +116,6 @@ def get_gemlite_aqt_kwargs(
 class GemlitePackedLayout(Layout):
     group_size: Optional[int] = 64
     bit_width: int = 4
-    packing_bitwidth: int = None
-    contiguous: bool = None


 @register_layout(GemlitePackedLayout)
@@ -216,13 +192,18 @@ def from_plain(
         group_size, bit_width = _layout.group_size, _layout.bit_width
         out_features, in_features = int_data.shape

-        gemlite_linear = gemlite.helper.A16Wn(device=int_data.device).from_weights(
-            int_data, scale, zero_point, bit_width, group_size, bias=None
-        )
+        if bit_width == 8 and group_size == in_features:
+            gemlite_linear = gemlite.helper.A16W8(device=int_data.device).from_weights(
+                int_data, scales=scale, bias=None
+            )
+        else:
+            gemlite_linear = gemlite.helper.A16Wn(device=int_data.device).from_weights(
+                int_data, scale, zero_point, bit_width, group_size, bias=None
+            )

         gemlite_kwargs = {
+            "in_features": in_features,
             "out_features": out_features,
-            "scaled_activations": gemlite_linear.scaled_activations,
             "meta_args": gemlite_linear.get_meta_args(),
         }
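Note: a minimal sketch of the packing dispatch added in from_plain, assuming the gemlite.helper API used above (A16W8 takes only per-channel scales, while A16Wn takes scale, zero_point, bit_width, and group_size); the wrapper name is hypothetical:

    import gemlite

    def pack_with_gemlite(int_data, scale, zero_point, bit_width, group_size):
        out_features, in_features = int_data.shape
        # 8-bit weights with a single group spanning the whole input dimension
        # are per-channel quantized, so the dedicated A16W8 helper is used;
        # everything else goes through the generic grouped A16Wn helper.
        if bit_width == 8 and group_size == in_features:
            gemlite_linear = gemlite.helper.A16W8(device=int_data.device).from_weights(
                int_data, scales=scale, bias=None
            )
        else:
            gemlite_linear = gemlite.helper.A16Wn(device=int_data.device).from_weights(
                int_data, scale, zero_point, bit_width, group_size, bias=None
            )
        # meta_args is what the functional forward below consumes.
        return gemlite_linear, gemlite_linear.get_meta_args()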
@@ -253,20 +234,17 @@ def _apply_fn_to_data(self, fn):

     def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         device = self.packed_weight.device
-        elements_per_sample = self._layout.packing_bitwidth // self._layout.bit_width
-        in_features = (
-            self.packed_weight.numel() * elements_per_sample
-        ) // self.gemlite_kwargs["out_features"]
         int_data = (
             gemlite.bitpack.unpack_over_rows(
                 self.packed_weight.cuda(),
                 W_nbits=self._layout.bit_width,
-                num_output_rows=in_features,
+                num_output_rows=self.gemlite_kwargs["out_features"],
                 dtype=torch.uint8,
             )
             .t()
             .contiguous()
         ).to(device)
+
         scale = self.scale.t().contiguous()
         zero_point = self.zero_point.t().contiguous()
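Note: the removed lines reconstructed in_features from the size of the packed buffer; since GemlitePackedLayout no longer carries packing_bitwidth, that arithmetic is gone. A short sketch of what it computed, kept only as context (function name hypothetical): with packing_bitwidth=32 and bit_width=4, each packed word holds 32 // 4 = 8 quantized values.

    def infer_in_features(packed_numel, out_features, packing_bitwidth, bit_width):
        # Each packed element stores packing_bitwidth // bit_width quantized
        # values, so the unpacked element count is packed_numel * that ratio,
        # and dividing by out_features recovers in_features.
        elements_per_sample = packing_bitwidth // bit_width
        return (packed_numel * elements_per_sample) // out_features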
@@ -353,42 +331,21 @@ def block_size(self):
         return (1, self._layout.group_size)


-# logic taken from gemlite's core.py
-def _matmul_type_fn(batch_size: int, bit_width: int) -> str:
-    if batch_size > 64:
-        return "GEMM"
-    elif batch_size > 1:
-        return "GEMM_SPLITK"
-    else:
-        return gemlite.core.get_default_gemv(bit_width)
-
-
 def _linear_fp_act_int4_weight_gemlite_impl(input_tensor, weight_tensor, bias=None):
     if hasattr(weight_tensor, "tensor_impl"):
         weight_impl = weight_tensor.tensor_impl
     else:
         weight_impl = weight_tensor

-    batch_size = input_tensor.view(-1, input_tensor.shape[-1]).shape[0]
-    matmul_type = _matmul_type_fn(batch_size, weight_impl._layout.bit_width)
-
-    if weight_impl.gemlite_kwargs["scaled_activations"]:
-        scale_activations = scale_activations_int8
-    else:
-        scale_activations = scale_activations_no_scaling
-
-    return GemLiteLinearTriton.forward_functional(
+    return gemlite.core.forward_functional(
         x=input_tensor,
         bias=bias,
-        matmul_type=matmul_type,
-        out_features=weight_impl.gemlite_kwargs["out_features"],
-        scale_activations=scale_activations,
-        meta_args=weight_impl.gemlite_kwargs["meta_args"],
         tensor_args=(
             weight_impl.packed_weight,
             weight_impl.scale,
             weight_impl.zero_point,
         ),
+        meta_args=weight_impl.gemlite_kwargs["meta_args"],
     )

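Note: with this hunk the batch-size-based matmul-type selection and the explicit activation-scaling choice are no longer made in torchao; the forward reduces to one functional call into gemlite, which presumably resolves the kernel from meta_args (the removed _matmul_type_fn was itself copied from gemlite's core.py). A minimal sketch of the resulting call shape (wrapper name hypothetical):

    import gemlite

    def gemlite_forward(x, packed_weight, scale, zero_point, meta_args, bias=None):
        # Hand the packed tensors plus gemlite's own meta_args straight back to
        # gemlite.core.forward_functional, as the new implementation does.
        return gemlite.core.forward_functional(
            x=x,
            bias=bias,
            tensor_args=(packed_weight, scale, zero_point),
            meta_args=meta_args,
        )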