Commit 0a06448

change use_optimum_format=True and add bias (#1431)
Signed-off-by: Xin He <xin3.he@intel.com>
1 parent 87b3b18 · commit 0a06448

File tree: 9 files changed (+129 −113 lines)

docs/source/quantization_weight_only.md

Lines changed: 8 additions & 7 deletions
@@ -93,18 +93,19 @@ To support low memory inference, Neural Compressor implemented WeightOnlyLinear,
 **Export arguments**
 | export args | default value | comments |
 |:----------:|:-------------:|:-------------------------------------------------------------------:|
-| qweight_config_path | None | If need to export model with fp32_model and json file, set the path of qconfig.json |
+| use_optimum_format | True | Whether to use the popular format used in [Optimum](https://github.com/huggingface/optimum/blob/e0927976d06d163ed09fe5bd80d013e1cfa0c463/docs/source/llm_quantization/usage_guides/quantization.mdx#L5) |
 | sym_full_range | False | Whether to leverage the full compression range under symmetric quantization |
-| compression_dtype | torch.int32 | Data type for compressed dtype, select from [torch.int8\|16\|32\|64] |
-| compression_dim | 1 | 0 means output channel while 1 means input channel |
-| scale_dtype | torch.float32 | Data type for scale and bias |
-| use_hf_format | False | Whether to use the popular format present on HuggingFace hub |
+| compression_dtype | torch.int32 | Data type for compressed dtype, select from [torch.int8\|16\|32\|64]. It's torch.int32 when use_optimum_format=True |
+| compression_dim | 1 | 0 means output channel while 1 means input channel. It's 1 for weight and 0 for zero-point when use_optimum_format=True |
+| scale_dtype | torch.float32 | Data type for scale and bias. It's torch.float16 when use_optimum_format=True |
+| qweight_config_path | None | Set the path of qconfig.json if you want to export the model with a json file |
+| gptq_config_path | None | Set the path of gptq_config.json to export a GPTQ-quantized model with fp32_model and a json file |

-**Note:** HuggingFace format is quite special, the main differences are as follows:
+**Note:** The format used in Optimum is accepted directly by transformers, which makes it easy to use. However, this format is rather special; the main differences are as follows:

 > 1: Compression Dimension: weight = 1, zero = 0 and both are transposed.
 > 2: Zero Point: zero_point -= 1 before compression. zero_point is always required even for sym.
-> 3: Group Index: Use the same number for a group instead of recording channel order.
+> 3: Group Index: Use the same number for a group instead of recording channel order.

 ### **User Code Example**
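Below is a minimal usage sketch (not part of this commit) of how the export arguments above fit together; the toy model and config keys are illustrative assumptions based on the weight-only quantization docs:

```python
import torch
from neural_compressor import PostTrainingQuantConfig, quantization

# Toy model; any torch.nn.Module containing Linear layers would do.
fp32_model = torch.nn.Sequential(torch.nn.Linear(512, 512))

# Illustrative weight-only RTN config; exact keys may vary between releases.
conf = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={".*": {"weight": {"bits": 4, "group_size": 32, "scheme": "asym", "algorithm": "RTN"}}},
)
q_model = quantization.fit(fp32_model, conf)

# use_optimum_format=True is now the default: qweight/qzeros are packed into
# torch.int32, scales are stored as torch.float16, and a bias buffer is kept.
compressed_model = q_model.export_compressed_model(use_optimum_format=True)
torch.save(compressed_model.state_dict(), "quantized_weight.pt")
```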

neural_compressor/adaptor/pytorch.py

Lines changed: 3 additions & 1 deletion
@@ -4582,10 +4582,12 @@ def rtn_quantize(self, model, tune_cfg):
             enable_full_range = self.recipes["rtn_args"].get("enable_full_range", False)
             enable_mse_search = self.recipes["rtn_args"].get("enable_mse_search", False)
             group_dim = self.recipes["rtn_args"].get("group_dim", 1)
+            return_int = self.recipes["rtn_args"].get("return_int", False)
         else:  # pragma: no cover
             enable_full_range = False
             enable_mse_search = False
             group_dim = 1
+            return_int = False
         from .torch_utils.util import fetch_module, set_module
         from .torch_utils.weight_only import rtn_quantize

@@ -4623,7 +4625,7 @@ def rtn_quantize(self, model, tune_cfg):
                 num_bits,
                 group_size,
                 scheme,
-                return_int=False,
+                return_int=return_int,
                 data_type=dtype,
                 enable_full_range=enable_full_range,
                 enable_mse_search=enable_mse_search,
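The adaptor now reads `return_int` from the `rtn_args` recipe rather than hard-coding `False`. A hedged sketch of how that knob might be passed through a config; the recipe keys simply mirror the `.get(...)` calls above, so verify availability in your release:

```python
from neural_compressor import PostTrainingQuantConfig

# Keys mirror self.recipes["rtn_args"].get(...) in the adaptor above; treat this
# as an illustration of the plumbing, not a documented recipe reference.
conf = PostTrainingQuantConfig(
    approach="weight_only",
    recipes={
        "rtn_args": {
            "enable_full_range": False,
            "enable_mse_search": False,
            "group_dim": 1,
            "return_int": True,  # keep quantized modules in packed integer form
        }
    },
)
```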

neural_compressor/adaptor/torch_utils/model_wrapper.py

Lines changed: 41 additions & 39 deletions
@@ -217,10 +217,10 @@ def __init__(
         compression_dim=1,
         g_idx=False,
         device="cpu",
-        use_hf_format=False,
+        use_optimum_format=True,
     ):
         super().__init__()
-        self.use_hf_format = use_hf_format
+        self.use_optimum_format = use_optimum_format
         self.dtype = dtype
         if "int" not in self.dtype:  # for nf4, fp4
             from neural_compressor.adaptor.torch_utils.weight_only import FLOAT_MAPPING, INT_MAPPING

@@ -245,13 +245,13 @@ def __init__(
         dtype_bits_mapping = {torch.int8: 8, torch.int16: 16, torch.int32: 32, torch.int64: 64}
         self.compress_bits = dtype_bits_mapping[compression_dtype]
         self.n_pack = self.compress_bits // self.bits
-        self.compressed_dtype = compression_dtype
-        self.float_type = scale_dtype
         # K is input channel, N is output channel
         assert compression_dim in [0, 1], (
             "Only support 0 or 1 as compression dimension, " + "0 is output channel, 1 is input channel."
         )
-        if self.use_hf_format:
+        if self.use_optimum_format:
+            self.float_type = torch.float16
+            self.compressed_dtype = torch.int32
             self.register_buffer(
                 "scales",
                 torch.zeros(

@@ -276,7 +276,10 @@ def __init__(
                 ).to(device),
             )
             self.qzeros = self.qzeros.T
+            self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device))
         else:
+            self.compressed_dtype = compression_dtype
+            self.float_type = scale_dtype
             self.register_buffer(
                 "scales",
                 torch.zeros(

@@ -316,18 +319,18 @@ def __init__(
                     dtype=self.compressed_dtype,
                 ).to(device),
             )
+            if bias:
+                self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device))
+            else:
+                self.bias = None
         if g_idx:
             self.register_buffer("g_idx", torch.zeros(in_features, dtype=torch.int32).to(device))
         else:
             self.g_idx = None
-        if bias:
-            self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device))
-        else:
-            self.bias = None

     def pack(self, int_weight, scale, zp, bias, g_idx=None):
         int_weight = int_weight.to(self.device)
-        if self.use_hf_format and zp is None:
+        if self.use_optimum_format and zp is None:
             # to avoid overflow
             int_weight = int_weight.type(torch.int32)
             shift_bias = 2 ** (self.bits - 1)

@@ -339,13 +342,13 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
         if g_idx is not None:
             assert hasattr(self, "g_idx"), "g_idx is not set when initializing."
             self.g_idx = g_idx.type(torch.int32).to(self.device)
-            if self.use_hf_format:
+            if self.use_optimum_format:
                 invperm = torch.argsort(self.g_idx)
                 self.g_idx = invperm // self.groupsize
                 self.g_idx = self.g_idx.type(torch.int32).to(self.device)
         assert scale.shape == self.scales.shape, "Scale shape is mismatched."
         self.scales = scale.type(self.float_type).to(self.device)
-        if not self.use_hf_format and self.compression_dim == 0:
+        if not self.use_optimum_format and self.compression_dim == 0:
             int_weight = int_weight.T
             self.qweight = self.qweight.T
         origin_shape = int_weight.shape

@@ -362,14 +365,14 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
                 tmp[:, e] &= mask
                 tmp[:, e] = tmp[:, e] << (self.bits * e)
                 self.qweight[:, j] |= tmp[:, e]
-        if not self.use_hf_format and self.compression_dim == 0:
+        if not self.use_optimum_format and self.compression_dim == 0:
             self.qweight = self.qweight.T

         if zp is not None:
             zp = zp.to(self.device)
-            if self.use_hf_format:
+            if self.use_optimum_format:
                 zp -= 1
-            if self.use_hf_format or self.compression_dim == 0:
+            if self.use_optimum_format or self.compression_dim == 0:
                 zp = zp.T
                 self.qzeros = self.qzeros.T
             assert hasattr(self, "qzeros"), "zp is not set when initializing."

@@ -382,23 +385,19 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
                     tmp[:, e] &= mask
                     tmp[:, e] = tmp[:, e] << (self.bits * e)
                     self.qzeros[:, j] |= tmp[:, e]
-            if self.use_hf_format or self.compression_dim == 0:
+            if self.use_optimum_format or self.compression_dim == 0:
                 self.qzeros = self.qzeros.T
-        if self.use_hf_format:
+        if self.use_optimum_format:
             self.scales = self.scales.T
             self.qweight = self.qweight.T
-            self.g_idx = self.g_idx
             self.qzeros = self.qzeros.T

     def recover(self):
         logger.debug(f"Recovering {self} weight")
-        if self.use_hf_format:
-            # Prevent broken id links of self.scales and self.scales
-            self.scales = self.scales.T
-            self.qweight = self.qweight.T
-            self.g_idx = self.g_idx
-            self.qzeros = self.qzeros.T
-        device = self.scales.device
+        scales = self.scales.T if self.use_optimum_format else self.scales
+        qweight = self.qweight.T if self.use_optimum_format else self.qweight
+
+        device = scales.device
         fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
         if self.g_idx is None:
             # used for recovering fp32_weight

@@ -410,8 +409,7 @@ def recover(self):
             weight_dtype = torch.int8
         # unpack weight
         weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
-        qweight = self.qweight
-        if not self.use_hf_format and self.compression_dim == 0:
+        if not self.use_optimum_format and self.compression_dim == 0:
             weight = weight.T
             qweight = qweight.T
         origin_shape = weight.shape

@@ -427,7 +425,7 @@
                 if weight_dtype == torch.uint8:
                     tmp &= mask  # remove sign bit
                 weight[:, index] = tmp.type(weight_dtype)
-        if not self.use_hf_format and self.compression_dim == 0:
+        if not self.use_optimum_format and self.compression_dim == 0:
             weight = weight.T
         if "int" not in self.dtype:
             new_weight = torch.zeros(self.out_features, self.in_features).to(device)

@@ -437,9 +435,9 @@
         # unpack zero_point
         if hasattr(self, "qzeros"):
             zp_dtype = self.compressed_dtype  # to avoid overflow when weight-zp
-            zp = torch.zeros(self.scales.shape, dtype=zp_dtype).to(device)
-            qzeros = self.qzeros
-            if self.use_hf_format or self.compression_dim == 0:
+            zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device)
+            qzeros = self.qzeros.T if self.use_optimum_format else self.qzeros
+            if self.use_optimum_format or self.compression_dim == 0:
                 zp = zp.T
                 qzeros = qzeros.T
             origin_shape = zp.shape

@@ -454,30 +452,34 @@
                     tmp = tmp >> self.compress_bits - self.bits
                     tmp &= mask
                     zp[:, index] = tmp.type(zp_dtype)
-            if self.use_hf_format or self.compression_dim == 0:
+            if self.use_optimum_format or self.compression_dim == 0:
                 zp = zp.T
-            if self.use_hf_format:
+            if self.use_optimum_format:
                 # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
                 zp += 1
                 zp = torch.where(zp > (2**self.bits - 1), 0, zp)
             # recover fp32 weight with int_weight, scale, and zero_point
             for idx in range(self.in_features):
-                fp32_weight[:, idx] = (weight[:, idx] - zp[:, self.g_idx[idx]]) * self.scales[:, self.g_idx[idx]]
+                fp32_weight[:, idx] = (weight[:, idx] - zp[:, self.g_idx[idx]]) * scales[:, self.g_idx[idx]]
         else:
             # recover fp32 weight with int_weight, scale
             for idx in range(self.in_features):
-                fp32_weight[:, idx] = weight[:, idx] * self.scales[:, self.g_idx[idx]]
+                fp32_weight[:, idx] = weight[:, idx] * scales[:, self.g_idx[idx]]
         return fp32_weight

     def forward(self, input):
+        weight = self.recover()
+        device = self.scales.device
+        if weight.dtype == torch.float16 and device.type == "cpu":
+            weight = weight.float()
+            self.bias = self.bias.float() if self.bias is not None else None
         if level == DEBUG:
             if not hasattr(self, "weight"):
-                self.weight = self.recover()
+                self.weight = weight
             input = input.type(self.weight.dtype)
             logger.debug(f"Calculating {self}")
             return F.linear(input, self.weight, self.bias)
         else:
-            weight = self.recover()
             input = input.type(weight.dtype)
             return F.linear(input, weight, self.bias)

@@ -489,8 +491,8 @@ def extra_repr(self) -> str:
             self.groupsize,
             self.bias is not None,
         )
-        if self.use_hf_format:
-            tmp_str += ", use_hf_format=True"
+        if self.use_optimum_format:
+            tmp_str += ", use_optimum_format=True"
         return tmp_str
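For readers tracing `pack`/`recover`, here is a standalone sketch of the storage arithmetic that `use_optimum_format=True` implies: eight 4-bit values per int32 container packed along the input channel, transposed storage, and zero points shifted down by one. The shapes and names are made up for illustration; only the bit arithmetic mirrors the code above.

```python
import torch

bits, n_pack = 4, 32 // 4              # 8 nibbles per int32 container
mask = (1 << bits) - 1

int_weight = torch.randint(0, 2**bits, (16, 32), dtype=torch.int32)  # (N=16, K=32)
packed = torch.zeros(16, 32 // n_pack, dtype=torch.int32)
for j in range(packed.shape[1]):
    chunk = int_weight[:, j * n_pack:(j + 1) * n_pack]
    for e in range(n_pack):
        # The top nibble lands in the sign bit of the int32 container; that is expected.
        packed[:, j] |= (chunk[:, e] & mask) << (bits * e)
qweight = packed.T                      # Optimum-style buffers are stored transposed

# Zero points follow the same packing but are stored as (zp - 1); recover() adds the 1
# back and wraps 2**bits - 1 around to 0, as the comment in the diff explains.

# Unpacking: shift each nibble back down and mask off any sign extension.
unpacked = torch.zeros_like(int_weight)
for j in range(packed.shape[1]):
    for e in range(n_pack):
        unpacked[:, j * n_pack + e] = (qweight.T[:, j] >> (bits * e)) & mask
assert torch.equal(unpacked, int_weight)
```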

neural_compressor/adaptor/torch_utils/weight_only.py

Lines changed: 2 additions & 2 deletions
@@ -396,7 +396,7 @@ def rtn_quantize(
     compression_dim = kwargs.get("compression_dim", 1)
     scale_dtype = kwargs.get("scale_dtype", torch.float32)
     device = kwargs.get("device", "cpu")
-    use_hf_format = kwargs.get("use_hf_format", False)
+    use_optimum_format = kwargs.get("use_optimum_format", True)
     for name, m in model.named_modules():
         if m.__class__.__name__ not in supported_layers:
             continue

@@ -452,7 +452,7 @@ def rtn_quantize(
                 compression_dim=compression_dim,
                 scale_dtype=scale_dtype,
                 device=device,
-                use_hf_format=use_hf_format,
+                use_optimum_format=use_optimum_format,
             )
             new_module.pack(int_weight, scale, zp, m.bias)
             if name == "":
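`rtn_quantize` now packs with the Optimum layout by default. As background, a conceptual sketch of the per-group asymmetric round-to-nearest step that precedes packing; the function name, rounding policy, and clamping are illustrative assumptions rather than the library's exact implementation:

```python
import torch

def rtn_asym(weight: torch.Tensor, bits: int = 4, group_size: int = 32):
    """Per-group asymmetric round-to-nearest quantization (illustrative)."""
    n, k = weight.shape
    w = weight.reshape(n, k // group_size, group_size)
    w_min = w.amin(dim=-1, keepdim=True)
    w_max = w.amax(dim=-1, keepdim=True)
    scale = (w_max - w_min).clamp(min=1e-5) / (2**bits - 1)
    zp = torch.round(-w_min / scale)
    q = torch.clamp(torch.round(w / scale) + zp, 0, 2**bits - 1)
    dq = (q - zp) * scale  # dequantized ("fake quant") weight
    return q.reshape(n, k), scale.squeeze(-1), zp.squeeze(-1), dq.reshape(n, k)

weight = torch.randn(8, 64)
q, scale, zp, dq = rtn_asym(weight)
print("mean abs quantization error:", (dq - weight).abs().mean().item())
```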

neural_compressor/model/torch_model.py

Lines changed: 5 additions & 5 deletions
@@ -459,7 +459,7 @@ def export_compressed_model(
         scale_dtype=torch.float32,
         gptq_config_path=None,
         device="cpu",
-        use_hf_format=False,
+        use_optimum_format=True,
     ):
         """Convert Linear to WeightOnlyLinear for low memory inference.

@@ -475,7 +475,7 @@ def export_compressed_model(
                 Defaults to torch.float32.
             gptq_config_path (str, optional): Path of gptq_config.json. Defaults to None.
             device (str, optional): choose device for compression. Defaults to cpu.
-            use_hf_format (bool, optional): use the popular huggingface compression format.
+            use_optimum_format (bool, optional): use the popular huggingface compression format.
                 1: compression_dim: weight = 1, zeros = 0 and both are transposed.
                 2: zeros -= 1 before compression. Why we need it?
                 3: g_idx: use same number for one group instead of recording the channel order.

@@ -520,7 +520,7 @@ def export_compressed_model(
                     compression_dim=compression_dim,
                     scale_dtype=scale_dtype,
                     device=device,
-                    use_hf_format=use_hf_format,
+                    use_optimum_format=use_optimum_format,
                 )
                 set_module(self.model, k, new_module)
                 continue

@@ -551,7 +551,7 @@ def export_compressed_model(
                     compression_dim=compression_dim,
                     scale_dtype=scale_dtype,
                     device=device,
-                    use_hf_format=use_hf_format,
+                    use_optimum_format=use_optimum_format,
                 )
                 new_module.pack(int_weight, gptq_scale, gptq_zp, m.bias, gptq_perm)
                 set_module(self.model, k, new_module)

@@ -578,7 +578,7 @@ def export_compressed_model(
                     compression_dim=compression_dim,
                     scale_dtype=scale_dtype,
                     device=device,
-                    use_hf_format=use_hf_format,
+                    use_optimum_format=use_optimum_format,
                 )
                 set_module(self.model, k, mod)
         return self.model
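Continuing the hypothetical `q_model` from the earlier sketch, one way to check what the export produced; only the `WeightOnlyLinear` class name and its buffer names come from the diff above, the rest is illustrative:

```python
compressed_model = q_model.export_compressed_model()  # use_optimum_format=True by default

for name, module in compressed_model.named_modules():
    if module.__class__.__name__ == "WeightOnlyLinear":
        # extra_repr() now reports use_optimum_format=True for such modules.
        print(name, module)
        # Expect transposed int32 qweight/qzeros, float16 scales, and a float16 bias buffer.
        print({k: (tuple(v.shape), v.dtype) for k, v in module.state_dict().items()})
        break
```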
