
Commit d128e59

[DataType] BF16 Support (#3158)
Allows BF16 as a model data type in group quantization and adds BF16 quantization presets. Corresponding TVM PR: apache/tvm#17670
1 parent bbd5e9c commit d128e59
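As a usage sketch (not part of this commit), the new presets become selectable by name from the quantization registry. This assumes the preset table is exported as QUANTIZATION from mlc_llm.quantization, matching the dict edited in quantization.py below:

    # Usage sketch: look up the BF16 presets added by this commit.
    from mlc_llm.quantization import QUANTIZATION

    no_quant = QUANTIZATION["q0bf16"]       # unquantized bfloat16 weights
    group_quant = QUANTIZATION["q4bf16_1"]  # int4 group quantization, bfloat16 model dtype
    print(no_quant.model_dtype, group_quant.model_dtype)  # bfloat16 bfloat16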

3 files changed: +30 −2 lines

ci/task/test_unittest.sh

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ if [[ -n ${MLC_CI_SETUP_DEPS:-} ]]; then
     echo "MLC_CI_SETUP_DEPS=1 start setup deps.."
     # Install dependency
     pip install --force-reinstall wheels/*.whl
+    pip install "ml_dtypes>=0.5.1" --no-binary ml_dtypes
     pip install --quiet pytest
     pip install --pre -U --no-index -f https://mlc.ai/wheels mlc-ai-nightly-cu123
     export LD_LIBRARY_PATH=/usr/local/cuda/compat/:$LD_LIBRARY_PATH
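For context on the new dependency: ml_dtypes supplies a NumPy-compatible bfloat16 scalar type, which host-side tests need in order to build and compare BF16 arrays. A minimal illustrative sketch, not taken from this commit:

    # ml_dtypes registers a bfloat16 dtype usable directly from NumPy.
    import numpy as np
    import ml_dtypes

    x = np.array([1.0, 2.5, -3.75], dtype=ml_dtypes.bfloat16)
    print(x.dtype)               # bfloat16
    print(x.astype(np.float32))  # [ 1.    2.5  -3.75]; these values are exactly representable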

python/mlc_llm/quantization/group_quantization.py

Lines changed: 2 additions & 2 deletions
@@ -33,7 +33,7 @@ class GroupQuantize: # pylint: disable=too-many-instance-attributes
     group_size: int
     quantize_dtype: Literal["int3", "int4", "int8"]
     storage_dtype: Literal["uint32"]
-    model_dtype: Literal["float16", "float32"]
+    model_dtype: Literal["float16", "float32", "bfloat16"]
     linear_weight_layout: Literal["KN", "NK"]
     quantize_embedding: bool = True
     quantize_final_fc: bool = True
@@ -50,7 +50,7 @@ def __post_init__(self):
         model_dtype = DataType(self.model_dtype)
         assert quantize_dtype.type_code == DataTypeCode.INT
         assert storage_dtype.type_code == DataTypeCode.UINT
-        assert model_dtype.type_code == DataTypeCode.FLOAT
+        assert model_dtype.type_code in (DataTypeCode.FLOAT, DataTypeCode.BFLOAT)
         if storage_dtype.bits < quantize_dtype.bits:
             raise ValueError("Storage unit should be greater or equal to quantized element")
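The relaxed assertion is needed because TVM reports bfloat16 under its own type code (DataTypeCode.BFLOAT) rather than DataTypeCode.FLOAT. A minimal sketch of the check, assuming tvm exposes DataType and DataTypeCode as this module imports them:

    from tvm import DataType, DataTypeCode

    for name in ("float16", "float32", "bfloat16"):
        dtype = DataType(name)
        # bfloat16 carries DataTypeCode.BFLOAT, so the old FLOAT-only
        # assert rejected it; the new membership check accepts both codes.
        assert dtype.type_code in (DataTypeCode.FLOAT, DataTypeCode.BFLOAT)
        print(name, dtype.bits)  # float16 and bfloat16 are both 16 bits wide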

python/mlc_llm/quantization/quantization.py

Lines changed: 27 additions & 0 deletions
@@ -33,6 +33,11 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArray]
         kind="no-quant",
         model_dtype="float16",
     ),
+    "q0bf16": NoQuantize(
+        name="q0bf16",
+        kind="no-quant",
+        model_dtype="bfloat16",
+    ),
     "q0f32": NoQuantize(
         name="q0f32",
         kind="no-quant",
@@ -82,6 +87,28 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArray]
         quantize_embedding=True,
         quantize_final_fc=True,
     ),
+    "q4bf16_0": GroupQuantize(
+        name="q4bf16_0",
+        kind="group-quant",
+        group_size=32,
+        quantize_dtype="int4",
+        storage_dtype="uint32",
+        model_dtype="bfloat16",
+        linear_weight_layout="KN",
+        quantize_embedding=True,
+        quantize_final_fc=True,
+    ),
+    "q4bf16_1": GroupQuantize(
+        name="q4bf16_1",
+        kind="group-quant",
+        group_size=32,
+        quantize_dtype="int4",
+        storage_dtype="uint32",
+        model_dtype="bfloat16",
+        linear_weight_layout="NK",
+        quantize_embedding=True,
+        quantize_final_fc=True,
+    ),
     "q4f32_1": GroupQuantize(
         name="q4f32_1",
         kind="group-quant",
