|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import math |
| 4 | + |
| 5 | +import torch |
| 6 | + |
| 7 | +import helion |
| 8 | +import helion.language as hl |
| 9 | + |
| 10 | + |
| 11 | +@helion.kernel(static_shapes=True) |
| 12 | +def fp8_attention_kernel( |
| 13 | + q: torch.Tensor, # [batch*heads, seq, dim] |
| 14 | + k: torch.Tensor, # [batch*heads, seq, dim] |
| 15 | + v: torch.Tensor, # [batch*heads, dim, seq] - pre-transposed |
| 16 | +) -> torch.Tensor: |
| 17 | + batch_heads = q.size(0) |
| 18 | + seq_len = q.size(1) |
| 19 | + head_dim = q.size(2) |
| 20 | + |
| 21 | + # Output tensor |
| 22 | + out = torch.empty( |
| 23 | + [batch_heads, seq_len, head_dim], dtype=torch.float32, device=q.device |
| 24 | + ) |
| 25 | + |
| 26 | + # Scale factor for attention |
| 27 | + sm_scale = 1.0 / math.sqrt(float(head_dim)) |
| 28 | + # Triton kernel multiplies sm_scale by 1.44269504 (1/log(2)) for exp2 |
| 29 | + sm_scale = sm_scale * 1.44269504 |
| 30 | + |
| 31 | + # Process each batch*head in parallel |
| 32 | + for bh in hl.grid(batch_heads): |
| 33 | + # Process each query position |
| 34 | + for tile_m in hl.tile(seq_len): |
| 35 | + # Initialize for online softmax |
| 36 | + m_i = hl.full([tile_m], float("-inf"), dtype=torch.float32) |
| 37 | + l_i = hl.full([tile_m], 0.0, dtype=torch.float32) |
| 38 | + acc = hl.zeros([tile_m, head_dim], dtype=torch.float32) |
| 39 | + |
| 40 | + # Load query tile - keep in FP8 |
| 41 | + q_tile = q[bh, tile_m, :] # [tile_m, dim] |
| 42 | + |
| 43 | + # Compute attention scores for all keys |
| 44 | + for tile_n in hl.tile(seq_len): |
| 45 | + # Load key tile and transpose for Q @ K^T |
| 46 | + k_tile = k[bh, tile_n, :] # [tile_n, dim] - keep in FP8 |
| 47 | + k_tile_t = k_tile.transpose(0, 1) # [dim, tile_n] |
| 48 | + |
| 49 | + # Compute Q @ K^T with FP8 inputs, result in FP32 |
| 50 | + qk = torch.matmul(q_tile, k_tile_t).to( |
| 51 | + torch.float32 |
| 52 | + ) # [tile_m, tile_n] |
| 53 | + |
| 54 | + # Scale QK scores first |
| 55 | + qk_scaled = qk * sm_scale # [tile_m, tile_n] |
| 56 | + |
| 57 | + # Compute max of scaled scores |
| 58 | + qk_max = torch.amax(qk_scaled, dim=-1) # [tile_m] |
| 59 | + |
| 60 | + # Update global max |
| 61 | + m_new = torch.maximum(m_i, qk_max) |
| 62 | + |
| 63 | + # Shift by max for numerical stability |
| 64 | + qk_shifted = qk_scaled - m_new[:, None] |
| 65 | + |
| 66 | + # Use exp2 to match Triton kernel's implementation |
| 67 | + # Note: Triton kernel already multiplies sm_scale by 1.44269504 |
| 68 | + p = torch.exp2(qk_shifted) # [tile_m, tile_n] |
| 69 | + |
| 70 | + # Sum of exponentials for this block |
| 71 | + l_ij = torch.sum(p, dim=-1) # [tile_m] |
| 72 | + |
| 73 | + # Update accumulators with correction factor |
| 74 | + # Correction factor for previous blocks |
| 75 | + alpha = torch.exp2(m_i - m_new) |
| 76 | + l_i = l_i * alpha + l_ij |
| 77 | + acc = acc * alpha[:, None] |
| 78 | + |
| 79 | + # Load values - V is [dim, seq] |
| 80 | + v_tile = v[bh, :, tile_n] # [dim, tile_n] - keep in FP8 |
| 81 | + |
| 82 | + # Convert p to FP8 for FP8 GEMM |
| 83 | + p_fp8 = p.to(v.dtype) # Convert to same FP8 type as V |
| 84 | + |
| 85 | + # Accumulate attention @ V with FP8 GEMM |
| 86 | + v_t = v_tile.transpose(0, 1) # [tile_n, dim] |
| 87 | + pv = torch.matmul(p_fp8, v_t).to(torch.float32) # [tile_m, dim] |
| 88 | + acc = acc + pv |
| 89 | + |
| 90 | + # Update max tracker |
| 91 | + m_i = m_new |
| 92 | + |
| 93 | + # Final normalization |
| 94 | + acc = acc / l_i[:, None] |
| 95 | + out[bh, tile_m, :] = acc |
| 96 | + |
| 97 | + return out |
| 98 | + |
| 99 | + |
| 100 | +def prepare_fp8_attention_inputs( |
| 101 | + q: torch.Tensor, # [batch, heads, seq, dim] |
| 102 | + k: torch.Tensor, # [batch, heads, seq, dim] |
| 103 | + v: torch.Tensor, # [batch, heads, seq, dim] |
| 104 | +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[int, int, int, int]]: |
| 105 | + """ |
| 106 | + Common preprocessing for FP8 attention implementations. |
| 107 | +
|
| 108 | + Returns: |
| 109 | + q_reshaped_fp8: [batch*heads, seq, dim] - in FP8 e5m2 |
| 110 | + k_reshaped_fp8: [batch*heads, seq, dim] - in FP8 e5m2 |
| 111 | + v_transposed_fp8: [batch*heads, dim, seq] - in FP8 e5m2 |
| 112 | + shape: (batch, heads, seq_len, head_dim) |
| 113 | + """ |
| 114 | + batch, heads, seq_len, head_dim = q.shape |
| 115 | + |
| 116 | + # Reshape to [batch*heads, seq, dim] |
| 117 | + q_reshaped = q.reshape(batch * heads, seq_len, head_dim) |
| 118 | + k_reshaped = k.reshape(batch * heads, seq_len, head_dim) |
| 119 | + |
| 120 | + # Transpose V to [batch, heads, dim, seq] then reshape |
| 121 | + v_transposed = v.permute(0, 1, 3, 2).reshape(batch * heads, head_dim, seq_len) |
| 122 | + |
| 123 | + # Convert to FP8 e5m2 |
| 124 | + q_reshaped_fp8 = q_reshaped.to(torch.float8_e5m2) |
| 125 | + k_reshaped_fp8 = k_reshaped.to(torch.float8_e5m2) |
| 126 | + v_transposed_fp8 = v_transposed.to(torch.float8_e5m2) |
| 127 | + |
| 128 | + return ( |
| 129 | + q_reshaped_fp8, |
| 130 | + k_reshaped_fp8, |
| 131 | + v_transposed_fp8, |
| 132 | + (batch, heads, seq_len, head_dim), |
| 133 | + ) |
| 134 | + |
| 135 | + |
| 136 | +def fp8_attention_tritonbench( |
| 137 | + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor |
| 138 | +) -> torch.Tensor: |
| 139 | + """Wrapper for TritonBench compatibility.""" |
| 140 | + # Common preprocessing with FP8 conversion |
| 141 | + q_fp8, k_fp8, v_fp8, shape = prepare_fp8_attention_inputs(q, k, v) |
| 142 | + batch, heads, seq_len, head_dim = shape |
| 143 | + |
| 144 | + # Call the fused kernel |
| 145 | + out_fused = fp8_attention_kernel(q_fp8, k_fp8, v_fp8) |
| 146 | + |
| 147 | + # Reshape back and convert to FP16 |
| 148 | + out = out_fused.reshape(batch, heads, seq_len, head_dim) |
| 149 | + return out.to(torch.float16) |
| 150 | + |
| 151 | + |
| 152 | +def fp8_attention_pytorch( |
| 153 | + q: torch.Tensor, # [batch, heads, seq, dim] |
| 154 | + k: torch.Tensor, # [batch, heads, seq, dim] |
| 155 | + v: torch.Tensor, # [batch, heads, seq, dim] |
| 156 | +) -> torch.Tensor: |
| 157 | + """ |
| 158 | + Baseline PyTorch implementation of FP8 attention using FP8 e5m2. |
| 159 | + """ |
| 160 | + # Get preprocessed inputs with FP8 conversion |
| 161 | + q_fp8, k_fp8, v_fp8, shape = prepare_fp8_attention_inputs(q, k, v) |
| 162 | + batch, heads, seq_len, head_dim = shape |
| 163 | + |
| 164 | + sm_scale = 1.0 / math.sqrt(float(head_dim)) |
| 165 | + |
| 166 | + outputs = [] |
| 167 | + |
| 168 | + for i in range(batch * heads): |
| 169 | + q_i = q_fp8[i] # [seq, dim] - already FP8 |
| 170 | + k_i = k_fp8[i] # [seq, dim] - already FP8 |
| 171 | + v_i = v_fp8[i] # [dim, seq] - pre-transposed, already FP8 |
| 172 | + |
| 173 | + # For Q @ K^T, we need K^T to be column-major |
| 174 | + kt_fp8 = k_i.t() # column-major [dim, seq] |
| 175 | + |
| 176 | + # Q @ K^T - dequantize and use regular matmul since e5m2 not supported by _scaled_mm |
| 177 | + q_deq = q_i.to(torch.float32) |
| 178 | + kt_deq = kt_fp8.to(torch.float32) |
| 179 | + qk = torch.matmul(q_deq, kt_deq) |
| 180 | + |
| 181 | + # Compute max before scaling |
| 182 | + qk_max = torch.amax(qk, dim=-1, keepdim=True) |
| 183 | + |
| 184 | + # Scale and shift in one operation, then use exp2 |
| 185 | + qk_scaled_shifted = qk * sm_scale - qk_max * sm_scale |
| 186 | + p = torch.exp2(qk_scaled_shifted * 1.44269504) |
| 187 | + |
| 188 | + # Normalize |
| 189 | + p_norm = p / p.sum(dim=-1, keepdim=True) |
| 190 | + |
| 191 | + # Step 2: Attention @ V using FP8 |
| 192 | + # P is [seq, seq], V is [dim, seq] |
| 193 | + # We want P @ V^T = [seq, seq] @ [seq, dim] = [seq, dim] |
| 194 | + p_fp8 = p_norm.to(torch.float8_e5m2) # row-major [seq, seq] |
| 195 | + |
| 196 | + # v_i is [dim, seq], already FP8 |
| 197 | + vt_fp8 = v_i.t() # column-major [seq, dim] |
| 198 | + |
| 199 | + # P @ V^T - dequantize and use regular matmul since e5m2 not supported by torch._scaled_mm |
| 200 | + p_deq = p_fp8.to(torch.float32) |
| 201 | + vt_deq = vt_fp8.to(torch.float32) |
| 202 | + out_i = torch.matmul(p_deq, vt_deq) |
| 203 | + |
| 204 | + outputs.append(out_i) |
| 205 | + |
| 206 | + # Stack and reshape back |
| 207 | + out_stacked = torch.stack(outputs, dim=0) # [batch*heads, seq, dim] |
| 208 | + out = out_stacked.reshape(batch, heads, seq_len, head_dim) |
| 209 | + |
| 210 | + return out.to(torch.float16) |
| 211 | + |
| 212 | + |
| 213 | +def check(batch: int, heads: int, seq_len: int, head_dim: int) -> None: |
| 214 | + torch.manual_seed(42) |
| 215 | + q = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda") |
| 216 | + k = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda") |
| 217 | + v = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda") |
| 218 | + |
| 219 | + from helion._testing import run_example |
| 220 | + |
| 221 | + run_example( |
| 222 | + fp8_attention_tritonbench, fp8_attention_pytorch, (q, k, v), atol=0.1, rtol=0.1 |
| 223 | + ) |
| 224 | + |
| 225 | + |
| 226 | +def main() -> None: |
| 227 | + check(1, 2, 128, 64) |
| 228 | + check(2, 4, 256, 64) |
| 229 | + check(4, 8, 512, 128) |
| 230 | + |
| 231 | + |
| 232 | +if __name__ == "__main__": |
| 233 | + main() |
0 commit comments