Commit 7d0b5eb

add model_opt compressor
1 parent a2cbcf6 commit 7d0b5eb

5 files changed: +224 -13 lines changed

src/compressed_tensors/compressors/compress_to_fp4.py
+49 -13

@@ -1,25 +1,61 @@
-import torch
-import numpy
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy
+import torch
+
 
 x = torch.Tensor(
-    [[-0.5000, -6.0000, -0.5000, -1.5000, -1.0000, 6.0000, -0.0000, -0.0000],
-    [-1.0000, -6.0000, -0.5000, -0.0000, 0.5000, 0.5000, -0.0000, -0.0000],
-    [-3.0000, -6.0000, -0.5000, -2.0000, -0.5000, -1.5000, -0.0000, -0.0000],
-    [ 1.5000, 6.0000, -0.0000, -0.5000, 1.0000, 1.0000, -0.0000, 0.0000]]
+    [
+        [-0.5000, -6.0000, -0.5000, -1.5000, -1.0000, 6.0000, -0.0000, -0.0000],
+        [-1.0000, -6.0000, -0.5000, -0.0000, 0.5000, 0.5000, -0.0000, -0.0000],
+        [-3.0000, -6.0000, -0.5000, -2.0000, -0.5000, -1.5000, -0.0000, -0.0000],
+        [1.5000, 6.0000, -0.0000, -0.5000, 1.0000, 1.0000, -0.0000, 0.0000],
+    ]
 )
 
 m, n = x.shape
 
-FLOAT_TO_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
+FLOAT_TO_E2M1 = [
+    0.0,
+    0.5,
+    1.0,
+    1.5,
+    2.0,
+    3.0,
+    4.0,
+    6.0,
+    0.0,
+    -0.5,
+    -1.0,
+    -1.5,
+    -2.0,
+    -3.0,
+    -4.0,
+    -6.0,
+]
 conversion_dict = {}
 
-# Dictionary between fp4 value and index
+# Dictionary between fp4 value and index
 for i in range(len(FLOAT_TO_E2M1)):
-    conversion_dict[FLOAT_TO_E2M1[i]] = i
+    conversion_dict[FLOAT_TO_E2M1[i]] = i
 
 
 x_numpy = x.to("cpu").numpy()
-x_index = numpy.array([[conversion_dict[i] for i in row] for row in x_numpy], dtype=numpy.uint8)
+x_index = numpy.array(
+    [[conversion_dict[i] for i in row] for row in x_numpy], dtype=numpy.uint8
+)
 x_index_bits = numpy.unpackbits(x_index)
 
 packed_shape = numpy.zeros([x_index_bits.shape[0] // 2], numpy.uint8)
@@ -32,9 +68,9 @@
     subset = x_index_bits[start:end]
     subset_a = subset[4:8]
     subset_b = subset[12:16]
-    packed_shape[i+4:i+8] = subset_a
-    packed_shape[i:i+4] = subset_b
-    start = end
+    packed_shape[i + 4 : i + 8] = subset_a
+    packed_shape[i : i + 4] = subset_b
+    start = end
     end = start + 16
     i += 8
 
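As I read the packing loop above, each E2M1 value is replaced by its 4-bit index into FLOAT_TO_E2M1, and two adjacent values then share one byte, with the second value's index in the high nibble. A minimal standalone sketch of that layout (my illustration, not part of the commit):

FLOAT_TO_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                 -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

first, second = -0.5, 6.0              # two adjacent fp4 values
lo = FLOAT_TO_E2M1.index(first)        # 9
hi = FLOAT_TO_E2M1.index(second)       # 7
packed = (hi << 4) | lo                # second value lands in the high nibble
assert packed == 0x79

# unpacking the byte recovers both values
assert FLOAT_TO_E2M1[packed & 0xF] == first
assert FLOAT_TO_E2M1[(packed >> 4) & 0xF] == second
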
src/compressed_tensors/compressors/quantized_compressors/__init__.py
+1

@@ -14,5 +14,6 @@
 # flake8: noqa
 
 from .base import *
+from .modelopt_quantized import *
 from .naive_quantized import *
 from .pack_quantized import *

src/compressed_tensors/compressors/quantized_compressors/base.py
+4

@@ -113,6 +113,9 @@ def compress(
         scale = model_state.get(merge_names(prefix, "weight_scale"), None)
         zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
         g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
+        global_scale = model_state.get(
+            merge_names(prefix, "weight_global_scale"), None
+        )
         if scale is not None:
             # weight is quantized, compress it
             if isinstance(names_to_scheme[prefix], tuple):
@@ -125,6 +128,7 @@
                     scale=scale,
                     zero_point=zp,
                     g_idx=g_idx,
+                    global_scale=global_scale,
                     quantization_args=quant_args,
                     device="cpu",
                 )
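
For context, a small hypothetical sketch of the state-dict keys the updated compress() now looks up per layer prefix; the prefix, shapes, and the assumption that merge_names simply dot-joins its arguments are mine, not taken from the commit:

import torch

prefix = "model.layers.0.mlp.down_proj"                      # hypothetical layer prefix
model_state = {
    f"{prefix}.weight": torch.randn(128, 256),
    f"{prefix}.weight_scale": torch.rand(128, 16),           # e.g. one scale per weight group
    f"{prefix}.weight_global_scale": torch.tensor([1.0]),    # new: single per-tensor scale
}

# mirrors the added lookup; stays None for formats that have no global scale
global_scale = model_state.get(f"{prefix}.weight_global_scale", None)
assert global_scale is not None
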
src/compressed_tensors/compressors/quantized_compressors/modelopt_quantized.py
+169

@@ -0,0 +1,169 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Dict, Optional, Tuple
+
+import numpy
+import torch
+from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.compressors.quantized_compressors.base import (
+    BaseQuantizationCompressor,
+)
+from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization import QuantizationArgs
+from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
+from torch import Tensor
+
+
+FLOAT_TO_E2M1 = [
+    0.0,
+    0.5,
+    1.0,
+    1.5,
+    2.0,
+    3.0,
+    4.0,
+    6.0,
+    -0.0,
+    -0.5,
+    -1.0,
+    -1.5,
+    -2.0,
+    -3.0,
+    -4.0,
+    -6.0,
+]
+conversion_dict = {
+    0.0: 0,
+    0.5: 1,
+    1.0: 2,
+    1.5: 3,
+    2.0: 4,
+    3.0: 5,
+    4.0: 6,
+    6.0: 7,
+    -0.0: 8,
+    -0.5: 9,
+    -1.0: 10,
+    -1.5: 11,
+    -2.0: 12,
+    -3.0: 13,
+    -4.0: 14,
+    -6.0: 15,
+}
+
+
+@BaseCompressor.register(name=CompressionFormat.modelopt_quantized.value)
+class ModelOptCompressor(BaseQuantizationCompressor):
+    """
+    Implements naive compression for quantized models. Weight of each
+    quantized layer is converted from its original float type to the closest Pytorch
+    type to the type specified by the layer's QuantizationArgs.
+    """
+
+    @property
+    def compression_param_names(self) -> Tuple[str]:
+        """
+        Returns a tuple of compression parameter names introduced by
+        the compressor during compression
+        """
+        return (
+            "weight_packed",
+            "weight_scale",
+            "weight_zero_point",
+            "weight_global_scale",
+        )
+
+    def compress_weight(
+        self,
+        weight: Tensor,
+        scale: Tensor,
+        global_scale: Tensor,
+        quantization_args: QuantizationArgs,
+        device: Optional[torch.device] = None,
+        zero_point: Optional[torch.Tensor] = None,
+        g_idx: Optional[torch.Tensor] = None,
+    ) -> Dict[str, torch.Tensor]:
+
+        quantized_weight = quantize(
+            x=weight,
+            scale=scale,
+            global_scale=global_scale,
+            zero_point=zero_point,
+            args=quantization_args,
+        )
+        compressed_dict = {}
+        weight_packed = pack_fp4_to_uint8(quantized_weight)
+        compressed_dict["weight_packed"] = weight_packed
+        return compressed_dict
+
+    def decompress_weight(
+        self,
+        compressed_data: Dict[str, Tensor],
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> torch.Tensor:
+
+        weight = compressed_data["weight_packed"]
+        scale = compressed_data["weight_scale"]
+        global_scale = compressed_data["weight_global_scale"]
+        m, n = weight.shape
+        unpacked = unpack_fp4_from_uint8(weight, m, n * 2)
+        decompressed_weight = dequantize(
+            x_q=unpacked, scale=scale, global_scale=global_scale
+        )
+
+        return decompressed_weight
+
+
+def pack_fp4_to_uint8(x: torch.Tensor):
+    m, n = x.shape
+
+    # convert to bits
+    x_array = x.cpu().to(torch.float32).numpy()
+    x_index = numpy.array(
+        [[conversion_dict[i] for i in row] for row in x_array], dtype=numpy.uint8
+    )
+    x_index_bits = numpy.unpackbits(x_index)
+
+    # unpack
+    packed_shape = numpy.zeros([x_index_bits.shape[0] // 2], numpy.uint8)
+    start = 0
+    end = 16
+    i = 0
+
+    # janky bit manipulation
+    while end < len(x_index_bits):
+        packed_shape[i + 4 : i + 8] = x_index_bits[start:end][4:8]
+        packed_shape[i : i + 4] = x_index_bits[start:end][12:16]
+        start = end
+        end = start + 16
+        i += 8
+
+    # pack
+    packed = numpy.packbits(packed_shape)
+    packed = torch.from_numpy(packed).to(torch.uint8)
+    # reshape
+    packed = packed.reshape(m, n // 2)
+    return packed
+
+
+# from vLLM
+def unpack_fp4_from_uint8(x: torch.Tensor, m: int, n: int):
+    v_2nd = x & 0xF
+    v_1st = (x >> 4) & 0xF
+    c = torch.stack((v_2nd, v_1st), dim=-1)
+    out = torch.tensor([FLOAT_TO_E2M1[x] for x in c.flatten()])
+    out = out.reshape(m, n).to(torch.float32)
+    return out
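
A rough vectorized reference for the byte layout that pack_fp4_to_uint8 / unpack_fp4_from_uint8 appear to implement: a (m, n) tensor of E2M1 values packs to a (m, n // 2) uint8 tensor, with the second value of each pair in the high nibble. This is my own sketch under those assumptions, not the module's code:

import torch

E2M1 = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)

def pack_ref(x: torch.Tensor) -> torch.Tensor:
    # map each value to its 4-bit index in the E2M1 table
    idx = (x.unsqueeze(-1) == E2M1).int().argmax(dim=-1)
    lo, hi = idx[..., 0::2], idx[..., 1::2]       # adjacent columns share one byte
    return ((hi << 4) | lo).to(torch.uint8)

def unpack_ref(packed: torch.Tensor) -> torch.Tensor:
    lo, hi = packed & 0xF, (packed >> 4) & 0xF    # low nibble holds the first value
    idx = torch.stack((lo, hi), dim=-1).reshape(packed.shape[0], -1)
    return E2M1[idx.long()]

w = torch.tensor([[0.5, -6.0, 1.5, -0.5],
                  [2.0, 3.0, -4.0, 6.0]])         # values already on the E2M1 grid

packed = pack_ref(w)
assert packed.shape == (2, 2)                     # n // 2 bytes per row
assert torch.equal(unpack_ref(packed), w)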

src/compressed_tensors/config/base.py
+1

@@ -32,6 +32,7 @@ class CompressionFormat(Enum):
     naive_quantized = "naive-quantized"
     pack_quantized = "pack-quantized"
     marlin_24 = "marlin-24"
+    modelopt_quantized = "modelopt-quantized"
 
 
 @unique
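
The new enum member's value is the same string the @BaseCompressor.register(...) decorator in the new compressor module uses; a quick hedged check of that wiring (assuming the package imports shown in the diff are available):

from compressed_tensors.config import CompressionFormat

fmt = CompressionFormat("modelopt-quantized")
assert fmt is CompressionFormat.modelopt_quantized
assert fmt.value == "modelopt-quantized"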
