@@ -250,3 +250,165 @@ def test_actorder_reload_match(actorder, tmp_path, mock_per_group_calibration):
250
250
assert torch .equal (fake_quant_dummy , reconstructed_dense ["dummy.weight" ])
251
251
252
252
shutil .rmtree (tmp_path )
253
+
254
+
255
@pytest.mark.parametrize(
    "num_bits,values,expected_values",
    [
        (4, torch.tensor([[1]]), torch.tensor([[9]], dtype=torch.int32)),
        (8, torch.tensor([[1]]), torch.tensor([[129]], dtype=torch.int32)),
        # 0000 0000 0000 0000 1100 1011 1010 1001
        (4, torch.tensor([[1, 2, 3, 4]]), torch.tensor([[52137]], dtype=torch.int32)),
        # 0111 0110 0101 0100 0011 0010 0001 0000
        (
            4,
            torch.tensor([[-8, -7, -6, -5, -4, -3, -2, -1]]),
            torch.tensor([[1985229328]], dtype=torch.int32),
        ),
        # 10000100 10000011 10000010 10000001
        (
            8,
            torch.tensor([[1, 2, 3, 4]]),
            torch.tensor([[-2071756159]], dtype=torch.int32),
        ),
        # 00000011 00000010 00000001 00000000
        (
            8,
            torch.tensor([[-128, -127, -126, -125]]),
            torch.tensor([[50462976]], dtype=torch.int32),
        ),
        (
            4,
            torch.tensor([[-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4]]),
            torch.tensor([[1985229328, 52137]], dtype=torch.int32),
        ),
        (
            4,
            torch.tensor(
                [
                    [-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4, -8, -8, -8, -8],
                    [1, 2, 3, 4, -8, -8, -8, -8, -8, -7, -6, -5, -4, -3, -2, -1],
                ]
            ),
            torch.tensor(
                [[1985229328, 52137], [52137, 1985229328]], dtype=torch.int32
            ),
        ),
        (
            8,
            torch.tensor(
                [
                    [1, 2, 3, 4],
                    [-128, -127, -126, -125],
                ]
            ),
            torch.tensor([[-2071756159], [50462976]], dtype=torch.int32),
        ),
        (
            8,
            torch.tensor(
                [
                    [1, 2, 3, 4, -128, -127, -126, -125],
                    [-128, -127, -126, -125, 1, 2, 3, 4],
                ]
            ),
            torch.tensor(
                [[-2071756159, 50462976], [50462976, -2071756159]], dtype=torch.int32
            ),
        ),
    ],
)
def test_pack_to_int32(num_bits, values, expected_values):
    """Check that pack_to_int32 packs int8 quantized values into int32 words.

    Each case supplies the raw (signed) quantized values and the exact int32
    words expected after packing at the given bit width; the binary comments
    above the cases show the expected bit layout (lowest value in the lowest
    bits).
    """
    packed = pack_to_int32(values.to(torch.int8), num_bits)
    # Exact bit-for-bit match, and the container dtype must stay int32.
    assert torch.equal(packed, expected_values)
    assert packed.dtype == expected_values.dtype
332
+
333
+
334
@pytest.mark.parametrize(
    "num_bits,values,expected_tensor",
    [
        (
            4,
            torch.tensor([[9]], dtype=torch.int32),
            torch.tensor([[1]], dtype=torch.int8),
        ),
        (
            8,
            torch.tensor([[129]], dtype=torch.int32),
            torch.tensor([[1]], dtype=torch.int8),
        ),
        (
            4,
            torch.tensor([[52137]], dtype=torch.int32),
            torch.tensor([[1, 2, 3, 4]], dtype=torch.int8),
        ),
        (
            4,
            torch.tensor([[1985229328]], dtype=torch.int32),
            torch.tensor([[-8, -7, -6, -5, -4, -3, -2, -1]], dtype=torch.int8),
        ),
        (
            8,
            torch.tensor([[-2071756159]], dtype=torch.int32),
            torch.tensor([[1, 2, 3, 4]], dtype=torch.int8),
        ),
        (
            8,
            torch.tensor([[50462976]], dtype=torch.int32),
            torch.tensor([[-128, -127, -126, -125]], dtype=torch.int8),
        ),
        (
            4,
            torch.tensor([[1985229328, 52137]], dtype=torch.int32),
            torch.tensor(
                [[-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4]], dtype=torch.int8
            ),
        ),
        (
            4,
            torch.tensor([[1985229328, 52137], [52137, 1985229328]], dtype=torch.int32),
            torch.tensor(
                [
                    [-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4, -8, -8, -8, -8],
                    [1, 2, 3, 4, -8, -8, -8, -8, -8, -7, -6, -5, -4, -3, -2, -1],
                ],
                dtype=torch.int8,
            ),
        ),
        (
            8,
            torch.tensor([[-2071756159], [50462976]], dtype=torch.int32),
            torch.tensor(
                [
                    [1, 2, 3, 4],
                    [-128, -127, -126, -125],
                ],
                dtype=torch.int8,
            ),
        ),
        (
            8,
            torch.tensor(
                [[-2071756159, 50462976], [50462976, -2071756159]], dtype=torch.int32
            ),
            torch.tensor(
                [
                    [1, 2, 3, 4, -128, -127, -126, -125],
                    [-128, -127, -126, -125, 1, 2, 3, 4],
                ],
                dtype=torch.int8,
            ),
        ),
    ],
)
def test_unpack_from_int32(num_bits, values, expected_tensor):
    """Check that unpack_from_int32 inverts pack_to_int32.

    Each case is the mirror image of a test_pack_to_int32 case: the packed
    int32 words go in, and the original signed int8 quantized values (at the
    requested shape) must come back out.
    """
    unpacked_tensor = unpack_from_int32(values, num_bits, expected_tensor.shape)
    # BUG FIX: the original asserted torch.equal(unpacked_tensor, unpacked_tensor)
    # and compared unpacked_tensor.dtype to itself — both tautologies that pass
    # unconditionally. Compare against the expected tensor and its dtype instead.
    assert torch.equal(unpacked_tensor, expected_tensor)
    assert unpacked_tensor.dtype == expected_tensor.dtype
0 commit comments