@@ -25,7 +25,6 @@
 except:
     gemlite = None
 
-
 aten = torch.ops.aten
 
 
@@ -35,7 +34,12 @@ def _same_metadata(
 ) -> bool:
     kwargs_match = len(self.gemlite_kwargs) == len(src.gemlite_kwargs)
     for k, v in self.gemlite_kwargs.items():
-        if k != "scale_activations":
+        if k in [
+            "in_features",
+            "out_features",
+            "packing_bitwidth",
+            "elements_per_sample",
+        ]:
             kwargs_match = kwargs_match and (v == src.gemlite_kwargs[k])
 
     return (
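
For reference, a minimal standalone sketch (not torchao code) of what the updated `_same_metadata` check compares: only the shape/packing entries are tested for equality, while entries such as `meta_args` are deliberately skipped. The helper and sample dicts below are hypothetical; only the key list mirrors the diff.

```python
# Illustrative sketch of the comparison above.
COMPARED_KEYS = ["in_features", "out_features", "packing_bitwidth", "elements_per_sample"]


def kwargs_match(a: dict, b: dict) -> bool:
    # Same length, and equality on the compared keys only.
    return len(a) == len(b) and all(a[k] == b[k] for k in COMPARED_KEYS)


lhs = {
    "in_features": 4096,
    "out_features": 4096,
    "packing_bitwidth": 32,
    "elements_per_sample": 8,
    "meta_args": object(),
}
rhs = dict(lhs, meta_args=object())  # different meta_args, same shapes/packing
assert kwargs_match(lhs, rhs)  # meta_args is intentionally ignored
```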
@@ -80,6 +84,7 @@ def get_gemlite_aqt_kwargs(
     weight,
     group_size=64,
     bit_width=4,
+    packing_bitwidth=None,
     use_hqq=True,
 ):
     if gemlite is None:
@@ -99,6 +104,9 @@ def get_gemlite_aqt_kwargs(
     assert group_size is None or bit_width != 8, (
         "gemlite only works with group_size=None for bit_width=8"
     )
+    assert packing_bitwidth in [8, 16, 32, None], (
+        f"Invalid packing bitwidth, got {packing_bitwidth}"
+    )
 
     out_features, in_features = weight.shape
     group_size = in_features if group_size is None else group_size
@@ -107,15 +115,17 @@ def get_gemlite_aqt_kwargs(
     aqt_kwargs["_layout"] = GemlitePackedLayout(
         group_size=group_size,
         bit_width=bit_width,
+        packing_bitwidth=packing_bitwidth,
     )
     aqt_kwargs["use_hqq"] = use_hqq
     return aqt_kwargs
 
 
 @dataclass(frozen=True)
 class GemlitePackedLayout(Layout):
-    group_size: Optional[int] = 64
+    group_size: Optional[int] = 128
     bit_width: int = 4
+    packing_bitwidth: Optional[int] = None
 
 
 @register_layout(GemlitePackedLayout)
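
A hedged usage sketch of the new parameter: callers can now thread `packing_bitwidth` through `get_gemlite_aqt_kwargs` into `GemlitePackedLayout`, and anything other than 8, 16, 32, or None trips the new assert. The module path and CUDA device below are assumptions, not part of the diff.

```python
# Sketch, assuming a torchao build that includes the changes in this diff.
import torch

from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs  # assumed module path

weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

aqt_kwargs = get_gemlite_aqt_kwargs(
    weight,
    group_size=64,
    bit_width=4,
    packing_bitwidth=32,  # new argument; must be one of 8, 16, 32, or None
    use_hqq=True,
)
layout = aqt_kwargs["_layout"]
print(layout.group_size, layout.bit_width, layout.packing_bitwidth)  # 64 4 32
```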
@@ -191,24 +201,36 @@ def from_plain(
 
         group_size, bit_width = _layout.group_size, _layout.bit_width
         out_features, in_features = int_data.shape
+        packing_bitwidth = _layout.packing_bitwidth
 
         if bit_width == 8 and group_size == in_features:
             gemlite_linear = gemlite.helper.A16W8(device=int_data.device).from_weights(
                 int_data, scales=scale, bias=None
             )
         else:
-            gemlite_linear = gemlite.helper.A16Wn(device=int_data.device).from_weights(
+            gemlite_linear = gemlite.helper.A16Wn(
+                device=int_data.device, packing_bitwidth=packing_bitwidth
+            ).from_weights(
                 int_data, scale, zero_point, bit_width, group_size, bias=None
             )
 
+        meta_args = gemlite_linear.get_meta_args()
         gemlite_kwargs = {
             "in_features": in_features,
             "out_features": out_features,
-            "meta_args": gemlite_linear.get_meta_args(),
+            "packing_bitwidth": packing_bitwidth,
+            "data_contiguous": gemlite_linear.data_contiguous,
+            "elements_per_sample": gemlite_linear.elements_per_sample,
+            "W_group_mode": gemlite_linear.W_group_mode,
+            "meta_args": meta_args,
         }
 
         packed_weight, scale, zero_point = gemlite_linear.get_tensor_args()
         packed_weight = packed_weight.to(device)
+        if zero_point is None:
+            zero_point = torch.tensor(
+                [[]], device=packed_weight.device, dtype=torch.int32
+            )
 
         return cls(packed_weight, scale, zero_point, gemlite_kwargs, _layout)
 
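
The `[[]]` placeholder above stands in for a missing zero point under symmetric quantization, so later code can always flatten, slice, and copy a real tensor and branch on `numel()` instead of `None`. A minimal sketch of the convention:

```python
# Sketch of the placeholder convention: a symmetric-quant weight produces no
# zero_point, so an empty (1, 0) int32 tensor stands in for it.
import torch

zero_point = None  # symmetric quantization: no zero point returned
if zero_point is None:
    zero_point = torch.tensor([[]], dtype=torch.int32)

print(zero_point.shape, zero_point.numel())  # torch.Size([1, 0]) 0
# Downstream checks such as `zero_point.numel() > 0` (see the slicing branch
# further down) then work without special-casing None.
```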
@@ -235,18 +257,39 @@ def _apply_fn_to_data(self, fn):
     def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         device = self.packed_weight.device
         int_data = (
-            gemlite.bitpack.unpack_over_rows(
-                self.packed_weight.cuda(),
-                W_nbits=self._layout.bit_width,
-                num_output_rows=self.gemlite_kwargs["out_features"],
-                dtype=torch.uint8,
+            (
+                gemlite.bitpack.unpack_over_rows(
+                    self.packed_weight.cuda(),
+                    W_nbits=self._layout.bit_width,
+                    num_output_rows=self.gemlite_kwargs["in_features"],
+                    dtype=torch.uint8,
+                )
             )
+            .to(device)
             .t()
-            .contiguous()
-        ).to(device)
+        )
+
+        # Preserve col-row major layout
+        if self.gemlite_kwargs["data_contiguous"]:
+            int_data = int_data.contiguous()
+
+        # Handle FMA mode: W_q * s + z -> (W_q - z) * s
+        if self.gemlite_kwargs["W_group_mode"] == 4:
+            scale_min_val = 1e-8
+            scale = self.scale.clone().float()
+            scale[torch.logical_and(scale >= 0, scale.abs() <= scale_min_val)] = (
+                scale_min_val
+            )
+            scale[
+                torch.logical_and(scale < 0, scale.abs() <= scale_min_val)
+            ] = -scale_min_val
+            zero_point = (-self.zero_point.float() / scale).clamp_(-100, 100)
+            zero_point = zero_point.to(self.scale.dtype)
+        else:
+            zero_point = self.zero_point
 
         scale = self.scale.t().contiguous()
-        zero_point = self.zero_point.t().contiguous()
+        zero_point = zero_point.t().contiguous()
 
         return int_data, scale, zero_point
 
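
The FMA-mode branch converts gemlite's `W_q * s + z` parameterization into the `(W_q - z') * s` form that `get_plain` returns; equating the two gives `z' = -z / s`, which is what the code computes after clamping the scale away from zero and clipping the result to ±100. A small numeric check of that identity (a sketch independent of gemlite):

```python
# Sanity check of the rewrite used in the W_group_mode == 4 branch above:
#   W_q * s + z  ==  (W_q - z_new) * s   with   z_new = -z / s
import torch

W_q = torch.randint(0, 16, (4, 8)).float()  # 4-bit codes
s = torch.rand(4, 1) + 0.1                  # scales kept away from zero
z = torch.randn(4, 1)                       # FMA-style additive zero point

z_new = -z / s                              # conversion performed in get_plain
lhs = W_q * s + z                           # FMA form
rhs = (W_q - z_new) * s                     # conventional affine form

assert torch.allclose(lhs, rhs, atol=1e-6)
```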
@@ -274,30 +317,47 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             assert step == 1, "Only step == 1 is supported in slicing right now"
 
             if dim in [0, 1]:
-                int_data, scale, zero_point = self.get_plain()
-                data_len = int_data.shape[dim]
+                # data in self is transposed, meaning forward() performs x @ W_deq not x @ W_deq.T
+                dim = 1 - dim
+                packed_weight = self.packed_weight
+                scale = self.scale
+                zero_point = self.zero_point
+
+                gemlite_kwargs = self.gemlite_kwargs.copy()
+                orig_shape = [
+                    gemlite_kwargs["in_features"],
+                    gemlite_kwargs["out_features"],
+                ]
+                elements_per_sample = gemlite_kwargs["elements_per_sample"]
+                data_len = orig_shape[dim]
                 scale_len = scale.shape[dim]
                 ratio = data_len / scale_len
                 start_scale = int(start / ratio)
                 end_scale = int(end / ratio)
 
-                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
+                # For packing only the K dimension. This should be flipped for N-dim packing.
+                div = elements_per_sample if dim == 0 else 1
+                packed_weight = aten.slice.Tensor(
+                    packed_weight, dim, start // div, end // div, step
+                )
+
+                # Update in_features/out_features
+                gemlite_kwargs["in_features"] = (
+                    packed_weight.shape[0] * elements_per_sample
+                )
+                gemlite_kwargs["out_features"] = packed_weight.shape[1]
+
                 scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
                 if zero_point is not None and zero_point.numel() > 0:
                     zero_point = aten.slice.Tensor(
                         zero_point, dim, start_scale, end_scale, step
                     )
                 else:
                     zero_point = None
-                # this is to handle padding
-                int_data, scale, zero_point = self._layout.post_process(
-                    int_data, scale, zero_point, self.block_size
-                )
-
-                sliced = self.from_plain(
-                    int_data, scale, zero_point, self._layout
-                )  # Will be transposed again
 
+                sliced = GemliteAQTTensorImpl(
+                    packed_weight, scale, zero_point, gemlite_kwargs, self._layout
+                )
                 return return_and_correct_aliasing(func, args, kwargs, sliced)
 
             else:
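
Because the weight stays bit-packed while it is sliced, indices along the packed K dimension are divided by `elements_per_sample`, and `in_features` is then recomputed from the packed shape. A hedged arithmetic example; the 4-bit-into-32-bit packing (and therefore 8 elements per sample) is an assumption for illustration:

```python
# Index arithmetic mirroring the slicing branch above (plain Python sketch).
packing_bitwidth, bit_width = 32, 4
elements_per_sample = packing_bitwidth // bit_width  # 8 quantized values per packed entry (assumed)

start, end = 0, 2048       # requested slice along the unpacked K dimension (dim == 0 after the flip)
div = elements_per_sample  # only the K dimension is packed, so divide its indices

packed_start, packed_end = start // div, end // div
print(packed_start, packed_end)  # 0 256 -> 2048 unpacked rows map to 256 packed rows

# in_features is recomputed from the sliced packed shape, as in the diff:
new_in_features = (packed_end - packed_start) * elements_per_sample
assert new_in_features == 2048
```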
@@ -308,10 +368,24 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         elif func is aten.copy_.default:
             self = args[0]
             src = args[1]
+
+            # Handle zero_point = None with symmetric quant
+            if self.zero_point is None:
+                self.zero_point = torch.tensor(
+                    [[]], device=self.packed_weight.device, dtype=torch.int32
+                )
+
+            if src.zero_point is None:
+                src.zero_point = torch.tensor(
+                    [[]], device=src.packed_weight.device, dtype=torch.int32
+                )
+
             if _same_metadata(self, src):
                 self_tensors = self.__tensor_flatten__()[0]
                 for tensor_name in self_tensors:
                     getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+                for key in self.gemlite_kwargs:
+                    self.gemlite_kwargs[key] = src.gemlite_kwargs[key]
                 return
             raise ValueError(
                 f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"