@@ -156,48 +156,109 @@ def from_plain(
         zero_point: Optional[torch.Tensor],
         layout: Layout,
         bias: Optional[torch.Tensor] = None,
+        *,
+        validate_inputs: bool = True,
     ):
         assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
-        assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
         assert layout.target in [
             t for t, _ in _TARGET_AND_STR
         ], f"Unexpected target: {layout.target}"
+        assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
+
+        if layout.target != Target.ATEN:
+            _check_torchao_ops_loaded()
+        else:
+            assert (
+                TORCH_VERSION_AT_LEAST_2_6
+            ), "ATEN target requires torch version > 2.6.0"
+            assert (
+                torch.backends.kleidiai.is_available()
+            ), "ATEN target requires torch.backends.kleidiai.is_available()"
+            assert layout.bit_width == 4, "ATEN target only supports torch.int4"
+            assert not layout.has_weight_zeros, "ATEN target does not support zeros"
+
+        data_dtype = getattr(torch, f"int{layout.bit_width}")
+        qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[data_dtype]
 
         int_types = [torch.int8, torch.int16, torch.int32, torch.int64]
 
+        # Check int_data
+        assert int_data.device == torch.device("cpu")
+        assert int_data.dtype in int_types
         n, k = int_data.shape
-        assert int_data.dtype in int_types, f"int_data.dtype must be {int_types}"
         assert k % layout.group_size == 0, "k must be divisible by group_size"
+        if validate_inputs:
+            assert int_data.min().item() >= qmin
+            assert int_data.max().item() <= qmax
         int_data = int_data.to(torch.int8)
 
-        assert scale.dtype == torch.float32, "scale must be float32"
+        # Check scale
+        assert scale.device == torch.device("cpu")
+        if scale.dtype != torch.float32:
+            logging.info(f"scale has dtype {scale.dtype}, converting to torch.float32")
+            scale = scale.to(torch.float32)
+        n_, _ = scale.shape
+        assert n_ == n
         assert (
             scale.numel() * layout.group_size == int_data.numel()
         ), "must have 1 scale per group"
-
-        assert (zero_point is not None) == (
-            layout.has_weight_zeros
-        ), "zero_point being None must be consistent with layout.has_weight_zeros"
-        if zero_point is not None:
+        if validate_inputs:
+            assert scale.min().item() > 0
+            # Some targets round scales to bfloat16, give warning if scales are at higher precision
+            scale_is_rounded_to_bf16 = torch.allclose(
+                scale, scale.to(torch.bfloat16).to(torch.float32)
+            )
+            if not scale_is_rounded_to_bf16:
+                if layout.target == Target.ATEN and (layout.group_size < k):
+                    logging.warning(
+                        "When using Target.ATEN with group_size < k, scales will be rounded to bfloat16"
+                    )
+                if layout.target in [Target.AUTO, Target.KLEIDIAI]:
+                    logging.warning(
+                        "When using [Target.AUTO, Target.KLEIDIAI], scales will be rounded to bfloat16"
+                    )
+
+        # Check zero_point
+        if zero_point is None:
             assert (
-                zero_point.dtype in int_types
-            ), f"zero_point.dtype must be {int_types}"
+                not layout.has_weight_zeros
+            ), "zero_point must be provided if has_weight_zeros=True"
+        else:
+            assert zero_point.device == torch.device("cpu")
+            assert zero_point.shape == scale.shape
+            assert zero_point.dtype in int_types
             assert (
                 zero_point.numel() * layout.group_size == int_data.numel()
             ), "must have 1 zero_point per group"
+            if validate_inputs:
+                zero_point_min = zero_point.min().item()
+                zero_point_max = zero_point.max().item()
+                assert zero_point_min >= qmin
+                assert zero_point_max <= qmax
+                has_weight_zeros = True
+                if zero_point_min == 0 and zero_point_max == 0:
+                    has_weight_zeros = False
+                assert (
+                    has_weight_zeros == layout.has_weight_zeros
+                ), "zero_point being all zeros must be consistent with layout.has_weight_zeros"
             zero_point = zero_point.to(torch.int8)
 
-        assert (bias is not None) == (
-            layout.has_bias
+        # Check bias
+        has_bias = bias is not None
+        assert (
+            has_bias == layout.has_bias
         ), "bias being None must be consistent with layout.has_bias"
-        if bias is not None:
-            assert bias.dtype == torch.float32, "bias.dtype must be float32"
-            assert bias.shape == (n,), "bias must have shape n"
+        if has_bias:
+            assert bias.device == torch.device("cpu")
+            if bias.dtype != torch.float32:
+                logging.info(
+                    f"bias has dtype {bias.dtype}, converting to torch.float32"
+                )
+                bias = bias.to(torch.float32)
+            assert bias.shape == (n,)
 
+        # Construct packed_weight
         if layout.target == Target.ATEN:
-            assert (
-                TORCH_VERSION_AT_LEAST_2_6
-            ), "aten target is requires torch version > 2.6.0"
             int_data = int_data.add(8)
             int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
 
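The checks gated by the new `validate_inputs` keyword can be reproduced outside torchao. The following is a minimal sketch, not torchao code: `check_intx_inputs` is a hypothetical helper that mirrors the qmin/qmax range check on `int_data`, the positive-scale check, and the bfloat16 round-trip test from the hunk above, assuming symmetric signed bounds for the given bit width (e.g. [-8, 7] for int4).

```python
import torch


# Hypothetical helper, illustration only: mirrors the validate_inputs checks above.
def check_intx_inputs(int_data: torch.Tensor, scale: torch.Tensor, bit_width: int) -> bool:
    # Assumed symmetric signed bounds, e.g. bit_width=4 -> qmin=-8, qmax=7.
    qmin, qmax = -(1 << (bit_width - 1)), (1 << (bit_width - 1)) - 1
    assert int_data.min().item() >= qmin and int_data.max().item() <= qmax
    assert scale.min().item() > 0
    # True when scales survive a round trip through bfloat16 unchanged,
    # i.e. nothing is lost if a target rounds them to bfloat16.
    return torch.allclose(scale, scale.to(torch.bfloat16).to(torch.float32))


int_data = torch.randint(-8, 8, (4, 64), dtype=torch.int8)  # n=4, k=64
scale = torch.rand(4, 2, dtype=torch.float32) + 0.01        # one scale per group of 32
print(check_intx_inputs(int_data, scale, bit_width=4))
```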
@@ -213,12 +274,11 @@ def from_plain(
         args = [
             int_data,
             scale.reshape(-1),
-            zero_point.reshape(-1) if zero_point is not None else None,
+            zero_point.reshape(-1) if layout.has_weight_zeros else None,
             layout.group_size,
             bias,
             target_to_str(layout.target) if layout.target != Target.AUTO else None,
         ]
-
         packed_weight = getattr(
             torch.ops.torchao,
             f"_pack_8bit_act_{layout.bit_width}bit_weight",
@@ -358,79 +418,35 @@ def make_packed_linear_int8_dynamic_activation_intx_weight_tensor(
     assert TORCH_VERSION_AT_LEAST_2_6, "Using PackedLinearInt8DynamicActivationIntxWeightLayout requires torch version > 2.6.0"
 
     layout = PackedLinearInt8DynamicActivationIntxWeightLayout(target=target)
-    if layout.target != Target.ATEN:
-        _check_torchao_ops_loaded()
-    else:
-        assert (
-            torch.backends.kleidiai.is_available()
-        ), "ATEN target requires torch.backends.kleidiai.is_available()"
-        assert data_dtype == torch.int4, "ATEN target only supports torch.int4"
-        assert zero_point is None, "ATEN target does not support zeros"
 
-    assert data_dtype in [getattr(torch, f"int{x}") for x in range(1, 9)]
-    qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[data_dtype]
     bit_width = _DTYPE_TO_BIT_WIDTH[data_dtype]
+    qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[data_dtype]
 
-    int_types = [torch.int8, torch.int16, torch.int32, torch.int64]
-
-    # Check int_data
-    assert int_data.device == torch.device("cpu")
-    assert int_data.dtype in int_types
     n, k = int_data.shape
-    if validate_inputs:
-        assert int_data.min().item() >= qmin
-        assert int_data.max().item() <= qmax
-
-    # Check scale
-    assert scale.device == torch.device("cpu")
-    if scale.dtype != torch.float32:
-        logging.info(f"scale has dtype {scale.dtype}, converting to torch.float32")
-        scale = scale.to(torch.float32)
     n_, groups_per_k = scale.shape
-    assert n_ == n
     assert k % groups_per_k == 0
     group_size = k // groups_per_k
-    if validate_inputs:
-        assert scale.min().item() > 0
 
-    if validate_inputs:
-        # Some targets round scales to bfloat16, give warning if scales are at higher precision
-        scale_is_rounded_to_bf16 = torch.allclose(
-            scale, scale.to(torch.bfloat16).to(torch.float32)
-        )
-        if not scale_is_rounded_to_bf16:
-            if layout.target == Target.ATEN and (group_size < k):
-                logging.warning(
-                    "When using Target.ATEN with group_size < k, scales will be rounded to bfloat16"
-                )
-            if layout.target in [Target.AUTO, Target.KLEIDIAI]:
-                logging.warning(
-                    "When using [Target.AUTO, Target.KLEIDIAI], scales will be rounded to bfloat16"
-                )
-
-    # Check zero_point
-    has_weight_zeros = zero_point is not None
-    if has_weight_zeros:
-        assert zero_point.device == torch.device("cpu")
-        assert zero_point.shape == scale.shape
-        assert zero_point.dtype in int_types
-        if validate_inputs:
-            assert zero_point.min().item() >= qmin
-            assert zero_point.max().item() <= qmax
+    has_weight_zeros = True
+    if zero_point is None:
+        has_weight_zeros = False
+    else:
+        zero_point_min = zero_point.min().item()
+        zero_point_max = zero_point.max().item()
+        if zero_point_min == 0 and zero_point_max == 0:
+            has_weight_zeros = False
 
-    # Check bias
     has_bias = bias is not None
-    if has_bias:
-        assert bias.device == torch.device("cpu")
-        if bias.dtype != torch.float32:
-            logging.info(f"bias has dtype {bias.dtype}, converting to torch.float32")
-            bias = bias.to(torch.float32)
-        assert bias.shape == (n,)
 
     layout.set_params(bit_width, group_size, has_weight_zeros, has_bias)
     assert layout.has_params_set()
     tensor_impl = PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl.from_plain(
-        int_data, scale, zero_point, layout, bias
+        int_data,
+        scale,
+        zero_point,
+        layout,
+        bias,
+        validate_inputs=validate_inputs,
     )
 
     return AffineQuantizedTensor(
@@ -439,7 +455,5 @@ def make_packed_linear_int8_dynamic_activation_intx_weight_tensor(
         shape=int_data.shape,
         quant_min=qmin,
         quant_max=qmax,
-        zero_point_domain=ZeroPointDomain.INT
-        if has_weight_zeros
-        else ZeroPointDomain.NONE,
+        zero_point_domain=ZeroPointDomain.INT,
     )