@@ -227,31 +227,42 @@ def _process_quantization(
227
227
perm = torch .argsort (g_idx )
228
228
x = safe_permute (x , perm , dim = 1 )
229
229
230
- # TODO: experiment with vectorizing for loop for performance
231
- end = 0
232
- for index , group_count in enumerate (group_sizes ):
233
- sc = scale [:, index ].view (- 1 , 1 )
234
- zp = zero_point [:, index ].view (- 1 , 1 ) if zero_point is not None else None
235
-
236
- start = end
237
- end = start + group_count
238
- if do_quantize :
239
- output [:, start :end ] = _quantize (
240
- x = x [:, start :end ],
241
- scale = sc ,
242
- zero_point = zp ,
243
- q_min = q_min ,
244
- q_max = q_max ,
245
- args = args ,
246
- dtype = dtype ,
247
- global_scale = global_scale ,
248
- )
230
+ x = torch .reshape (
231
+ x ,
232
+ (
233
+ x .shape [0 ],
234
+ ceil (x .shape [1 ] / group_size ),
235
+ group_size ,
236
+ ),
237
+ )
249
238
250
- if do_dequantize :
251
- input = output [:, start :end ] if do_quantize else x [:, start :end ]
252
- output [:, start :end ] = _dequantize (
253
- x_q = input , scale = sc , zero_point = zp , global_scale = global_scale
254
- )
239
+ if do_quantize :
240
+ output = _quantize (
241
+ x = x ,
242
+ scale = scale .unsqueeze (- 1 ),
243
+ zero_point = zero_point .unsqueeze (- 1 ) if zero_point is not None else None ,
244
+ dtype = dtype ,
245
+ global_scale = global_scale ,
246
+ q_min = q_min ,
247
+ q_max = q_max ,
248
+ args = args ,
249
+ )
250
+
251
+ if do_dequantize :
252
+ input = output if do_quantize else x
253
+ output = _dequantize (
254
+ x_q = input ,
255
+ scale = scale .unsqueeze (- 1 ),
256
+ zero_point = zero_point .unsqueeze (- 1 ) if zero_point is not None else None ,
257
+ global_scale = global_scale ,
258
+ )
259
+
260
+ output = torch .reshape (
261
+ output ,
262
+ (output .shape [0 ], output .shape [1 ] * output .shape [2 ]),
263
+ )
264
+
265
+ output = output .to (output_dtype )
255
266
256
267
if not is_column_order :
257
268
output = safe_permute (output , torch .argsort (perm ), dim = 1 )
0 commit comments