 
 __all__ = [
     "FP8_DTYPE",
+    "FP8_E4M3_DATA",
+    "FP4_E2M1_DATA",
     "QuantizationType",
     "QuantizationStrategy",
     "QuantizationArgs",
     "round_to_quantized_type",
     "ActivationOrdering",
 ]
 
+
+class FloatArgs:
+    exponent: int
+    mantissa: int
+    bits: int
+    max: float
+    min: float
+    dtype: Optional[torch.dtype] = None
+
+
+class FP4_E2M1_DATA(FloatArgs):
+    exponent = 2
+    mantissa = 1
+    bits = 4
+    max = 6.0
+    min = -6.0
+
+    @staticmethod
+    def cast_to_fp4(x):
+        # round each magnitude to the nearest representable E2M1 value
+        # {0, 0.5, 1, 1.5, 2, 3, 4, 6}; midpoints resolve round-half-to-even
+        sign = torch.sign(x)
+        x = torch.abs(x)
+        x[(x >= 0.0) & (x <= 0.25)] = 0.0
+        x[(x > 0.25) & (x < 0.75)] = 0.5
+        x[(x >= 0.75) & (x <= 1.25)] = 1.0
+        x[(x > 1.25) & (x < 1.75)] = 1.5
+        x[(x >= 1.75) & (x <= 2.5)] = 2.0
+        x[(x > 2.5) & (x < 3.5)] = 3.0
+        x[(x >= 3.5) & (x <= 5.0)] = 4.0
+        x[x > 5.0] = 6.0
+        return x * sign
+
+
+class FP8_E4M3_DATA(FloatArgs):
+    exponent = 4
+    mantissa = 3
+    bits = 8
+    max = torch.finfo(torch.float8_e4m3fn).max
+    min = torch.finfo(torch.float8_e4m3fn).min
+    dtype = torch.float8_e4m3fn
+
+
+# TODO: Remove soon in favour of a more descriptive FloatArgs
 FP8_DTYPE = torch.float8_e4m3fn
 
 
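For reference, a minimal sketch of how the new FP4 helper rounds. The input values are illustrative and the import path is assumed from the identifiers in this file, not stated in the diff:

import torch
from compressed_tensors.quantization.quant_args import FP4_E2M1_DATA  # assumed path

x = torch.tensor([-7.0, -2.5, -0.3, 0.25, 0.75, 1.6, 3.4, 5.1])
print(FP4_E2M1_DATA.cast_to_fp4(x))
# tensor([-6.0000, -2.0000, -0.5000, 0.0000, 1.0000, 1.5000, 3.0000, 6.0000])

Note that cast_to_fp4 works on a copy (torch.abs allocates a new tensor), so the caller's input is left untouched, and ties at the midpoints (0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0) round toward the neighbour with an even mantissa bit.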
@@ -234,7 +278,10 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
 
     def pytorch_dtype(self) -> torch.dtype:
         if self.type == QuantizationType.FLOAT:
-            return FP8_DTYPE
+            if self.num_bits == 8:
+                return FP8_E4M3_DATA.dtype
+            else:
+                raise NotImplementedError("Only num_bits in (8) are supported")
         elif self.type == QuantizationType.INT:
             if self.num_bits <= 8:
                 return torch.int8
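The new branch in pytorch_dtype can be exercised like this (a sketch; the constructor keywords are assumed from the fields QuantizationArgs already exposes):

args = QuantizationArgs(num_bits=8, type=QuantizationType.FLOAT)
args.pytorch_dtype()  # torch.float8_e4m3fn, via FP8_E4M3_DATA.dtype

args = QuantizationArgs(num_bits=4, type=QuantizationType.FLOAT)
args.pytorch_dtype()  # raises NotImplementedError; FP4 has no torch dtype (FP4_E2M1_DATA.dtype is None)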
@@ -263,7 +310,12 @@ def round_to_quantized_type(
     """
     original_dtype = tensor.dtype
     if args.type == QuantizationType.FLOAT:
-        rounded = tensor.to(FP8_DTYPE)
+        if args.num_bits == 8:
+            rounded = tensor.to(FP8_E4M3_DATA.dtype)
+        elif args.num_bits == 4:
+            rounded = FP4_E2M1_DATA.cast_to_fp4(tensor)
+        else:
+            raise NotImplementedError("Only num_bits in (4, 8) are supported")
     elif args.type == QuantizationType.INT:
         rounded = torch.round(tensor)
     else:
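And a sketch of round_to_quantized_type with the new FP4 path, assuming (per the function tail outside this hunk) that the result is cast back to original_dtype:

w = torch.tensor([0.3, 2.5, 4.8], dtype=torch.bfloat16)
fp4 = QuantizationArgs(num_bits=4, type=QuantizationType.FLOAT)
round_to_quantized_type(w, fp4)
# tensor([0.5000, 2.0000, 4.0000], dtype=torch.bfloat16), i.e. values snapped to the E2M1 grid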