# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

BFLOAT16_EXP_BIAS = 127
BFLOAT16_MANTISSA_BITS = 7
BFLOAT16_EXP_BITS = 8

FLOAT16_EXP_BIAS = 15
FLOAT16_MANTISSA_BITS = 10
FLOAT16_EXP_BITS = 5

FLOAT8_E8M0_MAX_EXP = 127
FLOAT4_EXP_BIAS = 1
FLOAT4_MANTISSA_BITS = 1

FLOAT16_VAL_TO_ADD = (1 << (FLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1))
FLOAT16_SIGN_EXPONENT_MASK = ((
    (1 << (FLOAT16_EXP_BITS + 1)) - 1) << FLOAT16_MANTISSA_BITS)

BFLOAT16_VAL_TO_ADD = (1 <<
                       (BFLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1))
BFLOAT16_SIGN_EXPONENT_MASK = ((
    (1 << (BFLOAT16_EXP_BITS + 1)) - 1) << BFLOAT16_MANTISSA_BITS)
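# For reference, the derived constants above evaluate to:
#   FLOAT16_VAL_TO_ADD = 0x0100, FLOAT16_SIGN_EXPONENT_MASK = 0xFC00
#   BFLOAT16_VAL_TO_ADD = 0x0020, BFLOAT16_SIGN_EXPONENT_MASK = 0xFF80
# VAL_TO_ADD is half of the spacing of the fp4 mantissa grid expressed in
# fp16/bf16 mantissa bits; the masks keep only the sign and exponent bits.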


def e8m0_to_half(scale, half_dtype: torch.dtype):
    """Convert E8M0 scales (an 8-bit exponent with bias 127, stored as
    uint8) to the requested fp16/bf16 dtype."""
    assert scale.dtype == torch.uint8

    scale_exp = scale.to(torch.int16) - 127

    # This can be implemented with bitwise operations in a proper kernel.
    scale_half = 2.0**(scale_exp.to(torch.float))

    return scale_half.to(half_dtype)

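# Example E8M0 encodings handled by e8m0_to_half: a stored byte of 127
# decodes to 2**0 = 1.0, 121 to 2**-6 and 133 to 2**6; the format has no
# sign or mantissa bits, so every scale is a power of two.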

def upcast_fp4_to_fp16_or_bf16(val, float_dtype: torch.dtype,
                               half_exp_bias: int, half_mantissa_bits: int):
    assert val.dtype == torch.uint8

    unpacked = torch.zeros(*val.shape[:-1],
                           val.shape[-1] * 2,
                           dtype=torch.uint8,
                           device=val.device)
    unpacked[..., 1::2] = (val >> 4) & 0x0F  # Extract high 4 bits.
    unpacked[..., ::2] = val & 0x0F  # Extract low 4 bits.

    # Take each fp4 value, represented as 0b0000xxxx after unpacking,
    # and convert it to the corresponding fp16/bf16 value.
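    # For reference, the fp4 (E2M1) nibble layout is [sign | 2 exp | 1 mant]
    # and the eight non-negative code points decode to:
    #   0b0000 -> 0.0    0b0001 -> 0.5    0b0010 -> 1.0    0b0011 -> 1.5
    #   0b0100 -> 2.0    0b0101 -> 3.0    0b0110 -> 4.0    0b0111 -> 6.0
    # Setting the sign bit (0b1xxx) negates the value.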

    sign = unpacked >> 3

    exp = (unpacked >> 1) & 3
    new_mantissa = unpacked & 1

    # if exp == 0 and new_mantissa == 0:
    #     new_exp = 0
    # else:
    #     new_exp = exp - FLOAT4_EXP_BIAS + half_exp_bias

    # int8_t works with float16, but may overflow with bfloat16.
    new_exp = exp - FLOAT4_EXP_BIAS + half_exp_bias

    # Cast 0b0000 to 0.0 in fp16/bf16.
    new_exp = new_exp * torch.logical_or(exp > 0, new_mantissa > 0)

    # Cast 0b0001 to 0.5 in fp16/bf16 (subnormal in fp4, normal in fp16/bf16).
    new_mantissa = torch.logical_and(new_mantissa, exp > 0)

    new_mantissa = new_mantissa.to(torch.int32)
    new_exp = new_exp.to(torch.int32)
    sign = sign.to(torch.int32)

    qdq_val = (sign << 15) + (new_exp << half_mantissa_bits) + (
        new_mantissa << (half_mantissa_bits - 1))

    assert qdq_val.max() <= 65535
    assert qdq_val.min() >= 0
    qdq_val = qdq_val.to(torch.uint16)

    result = qdq_val.view(float_dtype)

    return result


def dq_mxfp4_torch(x: torch.Tensor, scale: torch.Tensor,
                   float_dtype: torch.dtype) -> torch.Tensor:
    assert x.dtype == torch.uint8
    assert scale.dtype == torch.uint8

    if float_dtype == torch.float16:
        half_exp_bias = FLOAT16_EXP_BIAS
        half_mantissa_bits = FLOAT16_MANTISSA_BITS
    elif float_dtype == torch.bfloat16:
        half_exp_bias = BFLOAT16_EXP_BIAS
        half_mantissa_bits = BFLOAT16_MANTISSA_BITS
    else:
        raise ValueError(
            f"Unsupported float_dtype {float_dtype}, expected torch.float16 "
            "or torch.bfloat16.")

    scale_half = e8m0_to_half(scale, half_dtype=float_dtype)

    x_half = upcast_fp4_to_fp16_or_bf16(x,
                                        float_dtype=float_dtype,
                                        half_exp_bias=half_exp_bias,
                                        half_mantissa_bits=half_mantissa_bits)

    # Apply one E8M0 scale per block of 32 fp4 values.
    x_half = x_half.reshape(*x_half.shape[:-1], -1, 32)
    x_half = x_half * scale_half[..., None]
    x_half = x_half.reshape(*x_half.shape[:-2], -1)

    return x_half

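# Expected layout for dq_mxfp4_torch (illustrative shapes, matching the
# reshapes above): `x` packs two fp4 values per byte along the last
# dimension and `scale` holds one E8M0 byte per 32 unpacked values, e.g.
#   x:     (M, K // 2)  uint8
#   scale: (M, K // 32) uint8
#   out:   (M, K)       float16 / bfloat16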

def fp16_to_fp4_simulate(val, half_mantissa_bits: int, half_exp_bits: int,
                         half_exp_bias: int):
    # Cast an fp16/bf16 input to the nearest value representable in
    # float4_e2m1, i.e. one of [0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
    # -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0].

    float_type = val.dtype

    # Use int16 since "rshift_cuda" is not implemented for 'UInt16'.
    val_view = val.view(torch.int16)

    exp = val_view >> half_mantissa_bits
    exp = exp & ((1 << half_exp_bits) - 1)

    exp = exp.view(torch.uint16).to(torch.int32)

    sign = (val_view >> (half_mantissa_bits + half_exp_bits)) & 1

    # Most significant mantissa bit, which becomes the fp4 mantissa bit.
    mantissa_last = (val_view >> (half_mantissa_bits - 1)) & 1

    exp_unbias = exp - half_exp_bias
    new_exp = exp_unbias + FLOAT4_EXP_BIAS

    # Extra tail bits for values below the fp4 normal range (new_exp <= 0).
    exp_shift = (new_exp <= 0) * (1 - new_exp)

    # Number of mantissa bits dropped by the cast: half_mantissa_bits -
    # FLOAT4_MANTISSA_BITS is 9 for fp16 and 6 for bf16, plus the shift
    # above. Clamp to 16 to prevent overflow on `uint16_t half` for very
    # small values, which are correctly mapped to `round_close`.
    tail_bits = half_mantissa_bits - FLOAT4_MANTISSA_BITS + exp_shift
    tail_bits[tail_bits >= 16] = 16

    mantissa_plus_one = val_view & ((1 << (half_mantissa_bits + 1)) - 1)

    half = 1 << (tail_bits - 1)

    tail = mantissa_plus_one & ((1 << tail_bits) - 1)

    round_close = (tail < half)  # round towards 0
    round_away = (tail > half)  # round away from 0
    tie = tail == half

    # 1. round down
    # if new_exp == 0:  # case [0.5, 0.749999]
    #     new_mantissa = 0
    # elif new_exp < 0:  # case [0, 0.24999]
    #     new_mantissa = 0
    # else:
    #     new_mantissa = mantissa_last

    new_mantissa_close = (new_exp > 0) * mantissa_last
    new_exp_close = exp

    # 2. round up
    # if new_exp <= 0:  # case [0.250001, 0.499999] and [0.75001, 0.99999]
    #     new_mantissa = 0
    #     new_exp += 1
    # elif mantissa_last == 0:
    #     new_mantissa = 1
    # else:
    #     new_mantissa = 0
    #     new_exp += 1

    new_mantissa_away = torch.logical_and(new_exp > 0, mantissa_last == 0)
    new_exp_away = exp + torch.logical_or(new_exp <= 0, mantissa_last == 1)

    # 3. tie (round half to even), e.g.:
    # 0.25 -> 0.  (handled by `exp > (half_exp_bias - 2)`)
    # 0.75 -> 1.
    # 1.25 -> 1.
    # 1.75 -> 2.
    # 2.5  -> 2.
    # 3.5  -> 4.
    # 5.   -> 4.
    new_exp_tie = (exp > (half_exp_bias - 2)) * (exp + (mantissa_last == 1))

    # Gather round up, round down and tie.
    new_exp = round_away * new_exp_away \
        + round_close * new_exp_close \
        + tie * new_exp_tie

    new_mantissa = round_away * new_mantissa_away \
        + round_close * new_mantissa_close

    # Values whose rounded exponent exceeds the fp4 range saturate to
    # mantissa = 1; together with the exponent clamp below, this maps them
    # to +-6.
    new_mantissa = new_mantissa + (new_exp >
                                   (2 + half_exp_bias)) * (new_mantissa == 0)

    # Clamp the exponent to the representable fp4 range; exponents below it
    # are flushed to zero.
    new_exp = (new_exp >= (half_exp_bias - 2)) * torch.clamp(
        new_exp, half_exp_bias - 2, half_exp_bias + 2)

    sign = sign.to(torch.int32)
    new_mantissa = new_mantissa.to(torch.int32)

    qdq_val = (sign << 15) + (new_exp << half_mantissa_bits) + (
        new_mantissa << (half_mantissa_bits - 1))

    assert qdq_val.max() <= 65535
    assert qdq_val.min() >= 0
    assert qdq_val.dtype == torch.int32
    qdq_val = qdq_val.to(torch.uint16)

    result = qdq_val.view(float_type)
    return result

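# Illustrative expected behaviour of fp16_to_fp4_simulate (hand-worked
# values, not taken from a test): an input of [0.7, 0.8, 2.5, 3.5, 5.0]
# maps to [0.5, 1.0, 2.0, 4.0, 4.0]; the last three are the ties documented
# in the comments above.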

def qdq_mxfp4_torch(x: torch.Tensor,
                    scale_calculation_mode: str = "even") -> torch.Tensor:
    half_dtype = x.dtype

    if half_dtype == torch.float16:
        half_mantissa_bits = FLOAT16_MANTISSA_BITS
        half_exp_bits = FLOAT16_EXP_BITS
        half_exp_bias = FLOAT16_EXP_BIAS
        val_to_add = FLOAT16_VAL_TO_ADD
        sign_exponent_mask = FLOAT16_SIGN_EXPONENT_MASK
    elif half_dtype == torch.bfloat16:
        half_mantissa_bits = BFLOAT16_MANTISSA_BITS
        half_exp_bits = BFLOAT16_EXP_BITS
        half_exp_bias = BFLOAT16_EXP_BIAS
        val_to_add = BFLOAT16_VAL_TO_ADD
        sign_exponent_mask = BFLOAT16_SIGN_EXPONENT_MASK
    else:
        raise ValueError(
            f"Unsupported dtype {half_dtype}, expected torch.float16 or "
            "torch.bfloat16.")

    if scale_calculation_mode != "even":
        raise NotImplementedError(
            f"scale_calculation_mode={scale_calculation_mode!r} is not "
            "supported.")

    # Quantize per block of 32 values.
    x = x.reshape(*x.shape[:-1], -1, 32)

    block_max = torch.max(torch.abs(x), dim=-1).values

    block_max = block_max.view(torch.uint16).to(torch.int32)

    # Round the block max by adding half of the fp4 mantissa spacing and
    # keeping only the sign and exponent bits.
    block_max_uint = torch.bitwise_and(block_max + val_to_add,
                                       sign_exponent_mask)

    assert block_max_uint.max() <= 65535
    assert block_max_uint.min() >= 0
    assert block_max_uint.dtype == torch.int32
    block_max_uint = block_max_uint.to(torch.uint16)

    block_max = block_max_uint.view(half_dtype)

    # E8M0 scale exponent, chosen so that the block max lands near the top
    # of the fp4 range (whose largest exponent is 2).
    scale_exp = FLOAT8_E8M0_MAX_EXP + torch.floor(torch.log2(block_max)).to(
        torch.int32) - 2

    scale_exp = torch.clamp(scale_exp, 0, 2 * FLOAT8_E8M0_MAX_EXP)

    scale = 2.0**(scale_exp - FLOAT8_E8M0_MAX_EXP)
    scale = scale.to(half_dtype)

    x = x / scale[..., None]

    x_fp4 = fp16_to_fp4_simulate(x,
                                 half_exp_bits=half_exp_bits,
                                 half_mantissa_bits=half_mantissa_bits,
                                 half_exp_bias=half_exp_bias)

    x_fp4 = x_fp4 * scale[..., None]
    return x_fp4.reshape(*x_fp4.shape[:-2], -1)
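

if __name__ == "__main__":
    # Minimal round-trip sanity check (added for illustration; not part of
    # the quantization utilities above). Requires a PyTorch build with
    # uint16 support for the bit-level views used in this file.
    torch.manual_seed(0)

    x = torch.randn(4, 64, dtype=torch.bfloat16)
    x_qdq = qdq_mxfp4_torch(x)
    # The fake-quantized tensor keeps the input shape, and the error stays
    # at the granularity of the per-block fp4 grid.
    print(x_qdq.shape, (x - x_qdq).abs().max().item())

    # Dequantize packed mxfp4 data: two fp4 values per byte, one E8M0 scale
    # (here 127, i.e. 2**0) per 32 unpacked values.
    packed = torch.randint(0, 256, (4, 32), dtype=torch.uint8)
    scales = torch.full((4, 2), 127, dtype=torch.uint8)
    print(dq_mxfp4_torch(packed, scales, torch.bfloat16).shape)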