
Commit 56a5ff9

[DataType] Rename FP8 dtypes (#3155)
Following the recent renaming in TVM, this commit renames all FP8 dtypes:

* `e4m3_float8` is renamed to `float8_e4m3fn`,
* `e5m2_float8` is renamed to `float8_e5m2`.

This aligns the dtype names with those used in PyTorch and ml_dtypes. Published HuggingFace FP8 model repos may need to be updated as well to match the rename.
1 parent 8d59826 commit 56a5ff9
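
For reference, the old and new names map one to one, and the new names match the attribute names used by PyTorch (`torch.float8_e4m3fn`, `torch.float8_e5m2`) and ml_dtypes. A small illustration of the mapping (the helper below is documentation only, not repo code):

# Old MLC/TVM dtype strings and their renamed equivalents.
FP8_DTYPE_RENAME = {
    "e4m3_float8": "float8_e4m3fn",
    "e5m2_float8": "float8_e5m2",
}

def rename_fp8_dtype(dtype: str) -> str:
    # Dtypes unaffected by the rename pass through unchanged.
    return FP8_DTYPE_RENAME.get(dtype, dtype)

assert rename_fp8_dtype("e4m3_float8") == "float8_e4m3fn"
assert rename_fp8_dtype("float16") == "float16"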

File tree

7 files changed: 30 additions, 28 deletions


docs/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 --find-links https://mlc.ai/wheels
 fastapi
+ml_dtypes>=0.5.1
 mlc-ai-nightly
 openai
 prompt_toolkit
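
The new `ml_dtypes>=0.5.1` requirement provides NumPy-compatible definitions of both FP8 formats under exactly these names. A minimal usage sketch (illustrative, not from the repo):

# Cast a few values into the renamed FP8 formats via ml_dtypes.
import numpy as np
import ml_dtypes

x = np.array([0.5, 1.5, 240.0], dtype=ml_dtypes.float8_e4m3fn)
y = np.array([0.5, 1.5, 240.0], dtype=ml_dtypes.float8_e5m2)
print(x.astype(np.float32))  # e4m3fn keeps 3 mantissa bits
print(y.astype(np.float32))  # e5m2 keeps only 2, so rounding is coarser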

python/mlc_llm/op/cutlass.py

Lines changed: 6 additions & 6 deletions
@@ -50,11 +50,11 @@ def group_gemm(
     out_dtype = out_dtype if out_dtype else x.dtype
     weight_dtype = weight_dtype if weight_dtype else weight.dtype

-    if x.dtype == "e5m2_float8" and weight_dtype == "e5m2_float8" and out_dtype == "float16":
+    if x.dtype == "float8_e5m2" and weight_dtype == "float8_e5m2" and out_dtype == "float16":
         func_name = "cutlass.group_gemm_e5m2_e5m2_fp16"
-    elif x.dtype == "e4m3_float8" and weight_dtype == "e5m2_float8" and out_dtype == "float16":
+    elif x.dtype == "float8_e4m3fn" and weight_dtype == "float8_e5m2" and out_dtype == "float16":
         func_name = "cutlass.group_gemm_e4m3_e5m2_fp16"
-    elif x.dtype == "e4m3_float8" and weight_dtype == "e4m3_float8" and out_dtype == "float16":
+    elif x.dtype == "float8_e4m3fn" and weight_dtype == "float8_e4m3fn" and out_dtype == "float16":
         func_name = "cutlass.group_gemm_e4m3_e4m3_fp16"
     elif x.dtype == "float16" and weight_dtype == "float16" and out_dtype == "float16":
         func_name = "cutlass.group_gemm_fp16_sm90"

@@ -113,11 +113,11 @@ def fp8_gemm(
     out_dtype = out_dtype if out_dtype else x.dtype
     weight_dtype = weight_dtype if weight_dtype else weight.dtype

-    if x.dtype == "e5m2_float8" and weight_dtype == "e5m2_float8" and out_dtype == "float16":
+    if x.dtype == "float8_e5m2" and weight_dtype == "float8_e5m2" and out_dtype == "float16":
         func_name = "cutlass.gemm_e5m2_e5m2_fp16"
-    elif x.dtype == "e4m3_float8" and weight_dtype == "e5m2_float8" and out_dtype == "float16":
+    elif x.dtype == "float8_e4m3fn" and weight_dtype == "float8_e5m2" and out_dtype == "float16":
         func_name = "cutlass.gemm_e5m2_e4m3_fp16"
-    elif x.dtype == "e4m3_float8" and weight_dtype == "e4m3_float8" and out_dtype == "float16":
+    elif x.dtype == "float8_e4m3fn" and weight_dtype == "float8_e4m3fn" and out_dtype == "float16":
         func_name = "cutlass.gemm_e4m3_e4m3_fp16"
     else:
         raise NotImplementedError(
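
Only the dtype strings change here; the CUTLASS kernel names keep their short e4m3/e5m2 suffixes. A standalone sketch of the same dispatch pattern (`pick_fp8_gemm_kernel` is a hypothetical helper, with kernel names taken from the diff above):

# Illustrative dispatch table mirroring the fp8_gemm branches above.
def pick_fp8_gemm_kernel(x_dtype: str, weight_dtype: str, out_dtype: str) -> str:
    table = {
        ("float8_e5m2", "float8_e5m2", "float16"): "cutlass.gemm_e5m2_e5m2_fp16",
        ("float8_e4m3fn", "float8_e5m2", "float16"): "cutlass.gemm_e5m2_e4m3_fp16",
        ("float8_e4m3fn", "float8_e4m3fn", "float16"): "cutlass.gemm_e4m3_e4m3_fp16",
    }
    key = (x_dtype, weight_dtype, out_dtype)
    if key not in table:
        raise NotImplementedError(f"Unsupported FP8 GEMM dtype combination: {key}")
    return table[key]

assert pick_fp8_gemm_kernel("float8_e4m3fn", "float8_e4m3fn", "float16") == "cutlass.gemm_e4m3_e4m3_fp16"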

python/mlc_llm/op/moe_matmul.py

Lines changed: 3 additions & 3 deletions
@@ -182,7 +182,7 @@ def dequantize_float8_gemv(
     w: Tensor,
     scale: Optional[Tensor],
     indptr: Tensor,
-    quantize_dtype: Literal["e5m2_float8", "e4m3_float8"],
+    quantize_dtype: Literal["float8_e5m2", "float8_e4m3fn"],
 ) -> Tensor:
     """GEMV for project-in (e1-e3) or project-out (e2) in MLP but the weight is quantized in
     fp8 e5m2 or e4m3. It needs to be dequantized before the GEMV computation.

@@ -204,8 +204,8 @@ def dequantize_float8_gemv(
         The index pointer tensor of shape (1, experts_per_tok), where `experts_per_tok` is the
         number of activated experts per token.

-    quantize_dtype : Literal["e5m2_float8", "e4m3_float8"]
-        The quantize dtype of the weight tensor, which is either e5m2_float8 or e4m3_float8.
+    quantize_dtype : Literal["float8_e5m2", "float8_e4m3fn"]
+        The quantize dtype of the weight tensor, which is either float8_e5m2 or float8_e4m3fn.
     """
     (x_leading_dim, in_features), model_dtype = x.shape, x.dtype
     (local_experts, out_features, _), storage_dtype = w.shape, w.dtype
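
`Literal` annotations are not enforced at runtime, so callers that build the `quantize_dtype` string dynamically may want an explicit check against the renamed values. A small sketch (the helper is hypothetical, not part of the repo):

# Validate a dynamically constructed quantize_dtype against the new names.
from typing import Literal, get_args

QuantizeDtype = Literal["float8_e5m2", "float8_e4m3fn"]

def check_quantize_dtype(dtype: str) -> str:
    if dtype not in get_args(QuantizeDtype):
        raise ValueError(f"Expected one of {get_args(QuantizeDtype)}, got {dtype!r}")
    return dtype

check_quantize_dtype("float8_e4m3fn")  # passes; "e4m3_float8" would now raise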

python/mlc_llm/quantization/per_tensor_quantization.py

Lines changed: 9 additions & 9 deletions
@@ -32,9 +32,9 @@ class PerTensorQuantize:  # pylint: disable=too-many-instance-attributes

     name: str
     kind: str
-    activation_dtype: Literal["e4m3_float8", "e5m2_float8"]
-    weight_dtype: Literal["e4m3_float8", "e5m2_float8"]
-    storage_dtype: Literal["uint32", "e4m3_float8", "e5m2_float8"]
+    activation_dtype: Literal["float8_e4m3fn", "float8_e5m2"]
+    weight_dtype: Literal["float8_e4m3fn", "float8_e5m2"]
+    storage_dtype: Literal["uint32", "float8_e4m3fn", "float8_e5m2"]
     model_dtype: Literal["float16"]
     quantize_embedding: bool = True
     quantize_final_fc: bool = True

@@ -184,8 +184,8 @@ def quantize_weight(self, weight) -> List[NDArray]:

         def _create_quantize_func() -> IRModule:
             if DataType(self.weight_dtype).type_code in [
-                DataTypeCode.E4M3Float,
-                DataTypeCode.E5M2Float,
+                DataTypeCode.Float8E4M3FN,
+                DataTypeCode.Float8E5M2,
             ]:
                 quantize_func = functools.partial(
                     self.quantize_float8,

@@ -288,8 +288,8 @@ def _dequantize(
         if self.use_scale:
             assert scale is not None
         if DataType(self.weight_dtype).type_code in [
-            DataTypeCode.E4M3Float,
-            DataTypeCode.E5M2Float,
+            DataTypeCode.Float8E4M3FN,
+            DataTypeCode.Float8E5M2,
         ]:
             return self.dequantize_float8(q_weight, scale, self.weight_dtype, out_shape)
         raise NotImplementedError()

@@ -655,8 +655,8 @@ def from_mixtral_experts(
             The per-tensor quantized MixtralExperts layer
         """
         if DataType(config.weight_dtype).type_code in [
-            DataTypeCode.E4M3Float,
-            DataTypeCode.E5M2Float,
+            DataTypeCode.Float8E4M3FN,
+            DataTypeCode.Float8E5M2,
         ]:
             return PerTensorQuantizeMixtralExperts._IMPL["fp8"].from_mixtral_experts(
                 src, config, name
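
Besides the dtype strings, the TVM `DataTypeCode` enum members change too (`E4M3Float`/`E5M2Float` become `Float8E4M3FN`/`Float8E5M2`). A minimal check, assuming a TVM build that already carries the rename (import paths may vary across TVM versions):

# Inspect the renamed FP8 dtype codes; assumes a recent TVM with the rename.
from tvm import DataType, DataTypeCode  # import path assumed; adjust if needed

for name in ("float8_e4m3fn", "float8_e5m2"):
    dtype = DataType(name)
    is_fp8 = dtype.type_code in (DataTypeCode.Float8E4M3FN, DataTypeCode.Float8E5M2)
    print(name, dtype.bits, is_fp8)  # both are 8-bit FP8 types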

python/mlc_llm/quantization/quantization.py

Lines changed: 9 additions & 9 deletions
@@ -122,9 +122,9 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr
     "e5m2_e5m2_f16": PerTensorQuantize(
         name="e5m2_e5m2_f16",
         kind="per-tensor-quant",
-        activation_dtype="e5m2_float8",
-        weight_dtype="e5m2_float8",
-        storage_dtype="e5m2_float8",
+        activation_dtype="float8_e5m2",
+        weight_dtype="float8_e5m2",
+        storage_dtype="float8_e5m2",
         model_dtype="float16",
         quantize_final_fc=False,
         quantize_embedding=False,

@@ -134,9 +134,9 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr
     "e4m3_e4m3_f16": PerTensorQuantize(
         name="e4m3_e4m3_f16",
         kind="per-tensor-quant",
-        activation_dtype="e4m3_float8",
-        weight_dtype="e4m3_float8",
-        storage_dtype="e4m3_float8",
+        activation_dtype="float8_e4m3fn",
+        weight_dtype="float8_e4m3fn",
+        storage_dtype="float8_e4m3fn",
         model_dtype="float16",
         quantize_final_fc=False,
         quantize_embedding=False,

@@ -147,9 +147,9 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr
     "e4m3_e4m3_f16_max_calibrate": PerTensorQuantize(
         name="e4m3_e4m3_f16_max_calibrate",
         kind="per-tensor-quant",
-        activation_dtype="e4m3_float8",
-        weight_dtype="e4m3_float8",
-        storage_dtype="e4m3_float8",
+        activation_dtype="float8_e4m3fn",
+        weight_dtype="float8_e4m3fn",
+        storage_dtype="float8_e4m3fn",
         model_dtype="float16",
         quantize_final_fc=False,
         quantize_embedding=False,
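
The preset names themselves (`e5m2_e5m2_f16`, `e4m3_e4m3_f16`, `e4m3_e4m3_f16_max_calibrate`) are unchanged; only their dtype fields move to the new strings. A hedged look-up sketch (the `QUANTIZATION` registry name is assumed; verify against your mlc_llm version):

# Fetch a per-tensor FP8 preset and confirm its renamed dtype fields.
from mlc_llm.quantization import QUANTIZATION  # registry name assumed

cfg = QUANTIZATION["e4m3_e4m3_f16"]
assert cfg.activation_dtype == "float8_e4m3fn"
assert cfg.weight_dtype == "float8_e4m3fn"
assert cfg.storage_dtype == "float8_e4m3fn"
print(cfg.name, cfg.model_dtype)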

python/mlc_llm/quantization/utils.py

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ def convert_uint_packed_fp8_to_float(  # pylint: disable=too-many-arguments
     out_shape: Optional[Sequence[tir.PrimExpr]] = None,
 ) -> te.Tensor:
     """Unpack a fp8 value from the storage dtype and convert to float."""
-    assert quant_dtype in ["e4m3_float8", "e5m2_float8"]
+    assert quant_dtype in ["float8_e4m3fn", "float8_e5m2"]
     assert DataType(storage_dtype).type_code == DataTypeCode.UINT
     bits = DataType(quant_dtype).bits
     elem_storage_dtype = DataType(f"uint{bits}")
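
The unpack helper derives its per-element storage type from the dtype's bit width, which the rename does not affect. A short sketch of that bookkeeping, assuming a TVM build that parses the new FP8 names:

# Derive per-element storage and packing factor for an FP8 dtype string,
# mirroring convert_uint_packed_fp8_to_float; assumes a recent TVM.
from tvm import DataType

quant_dtype = "float8_e4m3fn"
bits = DataType(quant_dtype).bits   # 8 bits per FP8 element
elem_storage_dtype = f"uint{bits}"  # each element is reinterpreted as uint8
packed_per_word = 32 // bits        # 4 FP8 values fit in one uint32 word
print(elem_storage_dtype, packed_per_word)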

python/setup.py

Lines changed: 1 addition & 0 deletions
@@ -111,6 +111,7 @@ def main():
         "transformers",
         "pandas",
         "datasets",
+        "ml_dtypes>=0.5.1",
         "flashinfer-python==0.2.2",
     ],
     distclass=BinaryDistribution,
