 16 |  16 | """
 17 |  17 |
 18 |  18 | import torch
    |  19 | +import torch.nn.functional as F
    |  20 | +import torch.nn as nn
 19 |  21 | from .dynamic_quant import (
 20 |  22 |     DynamicallyPerAxisQuantizedLinear,
 21 |  23 | )

 28 |  30 | from .weight_only import (
 29 |  31 |     WeightOnlyInt8QuantLinear,
 30 |  32 | )
 31 |     | -from typing import Dict
    |  33 | +from .quant_primitives import (
    |  34 | +    get_group_qparams_symmetric,
    |  35 | +    per_token_dynamic_quant,
    |  36 | +)
    |  37 | +from typing import Dict, Tuple
 32 |  38 |
 33 |  39 | __all__ = [
 34 |  40 |     "apply_weight_only_int8_quant",
@@ -382,7 +388,7 @@ def quantize(
382 | 388 |             self.pad_calibration_inputs,
383 | 389 |         )
384 | 390 |         model = self._convert_for_runtime(model)
385 |     | -        model.load_state_dict(state_dict)
    | 391 | +        model.load_state_dict(state_dict, strict=False)
386 | 392 |         return model
387 | 393 |
388 | 394 |
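
A note on the strict=False change: a plausible reading is that the converted runtime modules and the captured state dict no longer match key-for-key (for instance the separate scales/zeros entries introduced further down), so a strict load would raise on missing or unexpected keys. The sketch below only demonstrates the mechanism, using a hypothetical module that is not part of this file:

import torch
import torch.nn as nn

# Hypothetical module that registers an extra buffer the state dict lacks.
class TinyQuantLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("weight", torch.zeros(4, 4, dtype=torch.int8))
        self.register_buffer("scales", torch.ones(4))

m = TinyQuantLinear()
state_dict = {"weight": torch.ones(4, 4, dtype=torch.int8)}  # no "scales" key

# strict=True (the default) raises RuntimeError for the missing "scales";
# strict=False loads the overlapping keys and reports the rest instead.
result = m.load_state_dict(state_dict, strict=False)
print(result.missing_keys)     # ['scales']
print(result.unexpected_keys)  # []
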
@@ -465,11 +471,49 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
465 | 471 |             self.precision,
466 | 472 |         )
467 | 473 |
    | 474 | +from math import gcd
    | 475 | +from functools import reduce
    | 476 | +
    | 477 | +
    | 478 | +def find_multiple(n: int, *args: Tuple[int]) -> int:
    | 479 | +    # TODO: this change is reverted right now in gpt-fast
    | 480 | +    k: int = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))  # type: ignore[9]
    | 481 | +    if n % k == 0:
    | 482 | +        return n
    | 483 | +    return n + k - (n % k)
    | 484 | +
468 | 485 |
469 | 486 | def _check_linear_int4_k(k, group_size=1):
470 | 487 |     return k % group_size == 0
471 | 488 |
472 | 489 |
    | 490 | +def _calc_padded_size_linear_int4(k, groupsize=1):
    | 491 | +    return find_multiple(k, groupsize)
    | 492 | +
    | 493 | +
    | 494 | +def pack_scales_and_zeros(scales, zeros, precision=torch.float32):
    | 495 | +    assert scales.shape == zeros.shape
    | 496 | +    assert scales.dtype == precision
    | 497 | +    assert zeros.dtype == precision
    | 498 | +    return (
    | 499 | +        torch.cat(
    | 500 | +            [
    | 501 | +                scales.reshape(scales.size(0), scales.size(1), 1),
    | 502 | +                zeros.reshape(zeros.size(0), zeros.size(1), 1),
    | 503 | +            ],
    | 504 | +            2,
    | 505 | +        )
    | 506 | +        .transpose(0, 1)
    | 507 | +        .contiguous()
    | 508 | +    )
    | 509 | +
    | 510 | +
    | 511 | +def unpack_scales_and_zeros(scales_and_zeros):
    | 512 | +    assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
    | 513 | +    assert scales_and_zeros.dtype == torch.float
    | 514 | +    return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
    | 515 | +
    | 516 | +
473 | 517 | def replace_linear_8da4w(
474 | 518 |     module,
475 | 519 |     group_size,
@@ -554,25 +598,27 @@ def __init__(
554 | 598 |         ]
555 | 599 |         # skip unless padding_allowed=True or its correctly sized
556 | 600 |         self.skip_layer_func = lambda linear_weight: not (
557 |     | -            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles)
    | 601 | +            _check_linear_int4_k(linear_weight.shape[-1], groupsize)
558 | 602 |             or padding_allowed
559 | 603 |         )
560 | 604 |
561 | 605 |         # we need to do the padding here, both for q and the qparams if necessary
562 | 606 |         def make_names_and_values_dict_func(q, qparams):
563 | 607 |             k = q.shape[1]
564 |     | -            new_k = _calc_padded_size_linear_int4(k, groupsize, inner_k_tiles)
    | 608 | +            new_k = _calc_padded_size_linear_int4(k, groupsize)
565 | 609 |             # how much we need to pad the weight
566 | 610 |             delta_k = new_k - q.shape[1]
567 | 611 |             final_q = F.pad(q, pad=(0, delta_k))
568 |     | -            scales_and_zeros = pack_scales_and_zeros(*qparams, precision=self.precision)
    | 612 | +            scales = qparams[0].to(self.precision)
    | 613 | +            zeros = qparams[1].to(self.precision)
    | 614 | +            # scales_and_zeros = pack_scales_and_zeros(*qparams, precision=self.precision)
569 | 615 |             # how many new groups we need for padded weight
570 |     | -            delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
    | 616 | +            # delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
571 | 617 |             # TODO: split scales and zero_points
572 |     | -            final_s_and_z = F.pad(
573 |     | -                scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1
574 |     | -            )
575 |     | -            return {"weight": final_q, "scales_and_zeros": final_s_and_z}
    | 618 | +            # final_s_and_z = F.pad(
    | 619 | +            #     scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1
    | 620 | +            # )
    | 621 | +            return {"weight": final_q, "scales": scales, "zeros": zeros}
576 | 622 |
577 | 623 |         self.make_names_and_values_dict_func = make_names_and_values_dict_func
578 | 624 |         super().__init__()