From 0a2a04e565959d5fdd02b9f38c9151f0f216701b Mon Sep 17 00:00:00 2001 From: Will Feng Date: Mon, 14 Jul 2025 12:55:04 -0700 Subject: [PATCH] Add fp8_attention example and unit test stack-info: PR: https://github.com/pytorch-labs/helion/pull/318, branch: yf225/stack/26 --- examples/fp8_attention.py | 213 ++++++++++++++++++++++++++++++++++++ test/test_examples.expected | 74 +++++++++++++ test/test_examples.py | 41 +++++++ 3 files changed, 328 insertions(+) create mode 100644 examples/fp8_attention.py diff --git a/examples/fp8_attention.py b/examples/fp8_attention.py new file mode 100644 index 00000000..915a3f92 --- /dev/null +++ b/examples/fp8_attention.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +import math +from typing import Callable + +import torch + +import helion +import helion.language as hl + + +@helion.kernel(static_shapes=True) +def fp8_attention_kernel( + q: torch.Tensor, # [batch*heads, seq, dim] + k: torch.Tensor, # [batch*heads, seq, dim] + v: torch.Tensor, # [batch*heads, dim, seq] - pre-transposed +) -> torch.Tensor: + batch_heads = q.size(0) + seq_len = q.size(1) + head_dim = q.size(2) + + # Output tensor + out = torch.empty( + [batch_heads, seq_len, head_dim], dtype=torch.float32, device=q.device + ) + + # Scale factor for attention + sm_scale = 1.0 / math.sqrt(float(head_dim)) + # Triton kernel multiplies sm_scale by 1.44269504 (1/log(2)) for exp2 + sm_scale = sm_scale * 1.44269504 + + # Process each batch*head in parallel + for bh in hl.grid(batch_heads): + # Process each query position + for tile_m in hl.tile(seq_len): + # Initialize for online softmax + m_i = hl.full([tile_m], float("-inf"), dtype=torch.float32) + l_i = hl.full([tile_m], 0.0, dtype=torch.float32) + acc = hl.zeros([tile_m, head_dim], dtype=torch.float32) + + # Load query tile - keep in FP8 + q_tile = q[bh, tile_m, :] # [tile_m, dim] + + # Compute attention scores for all keys + for tile_n in hl.tile(seq_len): + # Load key tile and transpose for Q @ K^T + k_tile = k[bh, tile_n, :] # [tile_n, dim] - keep in FP8 + k_tile_t = k_tile.transpose(0, 1) # [dim, tile_n] + + # Compute Q @ K^T with FP8 inputs, result in FP32 + qk = torch.matmul(q_tile, k_tile_t).to( + torch.float32 + ) # [tile_m, tile_n] + + # Scale QK scores first + qk_scaled = qk * sm_scale # [tile_m, tile_n] + + # Compute max of scaled scores + qk_max = torch.amax(qk_scaled, dim=-1) # [tile_m] + + # Update global max + m_new = torch.maximum(m_i, qk_max) + + # Shift by max for numerical stability + qk_shifted = qk_scaled - m_new[:, None] + + # Use exp2 to match Triton kernel's implementation + # Note: Triton kernel already multiplies sm_scale by 1.44269504 + p = torch.exp2(qk_shifted) # [tile_m, tile_n] + + # Sum of exponentials for this block + l_ij = torch.sum(p, dim=-1) # [tile_m] + + # Update accumulators with correction factor + # Correction factor for previous blocks + alpha = torch.exp2(m_i - m_new) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + + # Load values - V is [dim, seq] + v_tile = v[bh, :, tile_n] # [dim, tile_n] - keep in FP8 + + # Convert p to FP8 for FP8 GEMM + p_fp8 = p.to(v.dtype) # Convert to same FP8 type as V + + # Accumulate attention @ V with FP8 GEMM + v_t = v_tile.transpose(0, 1) # [tile_n, dim] + pv = torch.matmul(p_fp8, v_t).to(torch.float32) # [tile_m, dim] + acc = acc + pv + + # Update max tracker + m_i = m_new + + # Final normalization + acc = acc / l_i[:, None] + out[bh, tile_m, :] = acc + + return out + + +def preprocess_fp8_attention_inputs( + q: torch.Tensor, k: torch.Tensor, v: 
torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q_fp8 = q.to(torch.float8_e5m2) + k_fp8 = k.to(torch.float8_e5m2) + v = v.permute(0, 1, 3, 2) + v_fp8 = v.to(torch.float8_e5m2) + batch, heads, seq_len, head_dim = q.shape + q_fp8_reshaped = q_fp8.reshape(batch * heads, seq_len, head_dim) + k_fp8_reshaped = k_fp8.reshape(batch * heads, seq_len, head_dim) + v_fp8_reshaped = v_fp8.reshape(batch * heads, head_dim, seq_len) + return q_fp8_reshaped, k_fp8_reshaped, v_fp8_reshaped + + +def fp8_attention_tritonbench( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor +) -> Callable[[], torch.Tensor]: + batch, heads, seq_len, head_dim = q.shape + q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v) + # Return lambda that calls the kernel - preprocessing is done outside + return ( + lambda: fp8_attention_kernel(q_fp8, k_fp8, v_fp8) + .reshape(batch, heads, seq_len, head_dim) + .to(torch.float8_e5m2) + ) + + +def fp8_attention_pytorch( + q: torch.Tensor, # [batch, heads, seq, dim] + k: torch.Tensor, # [batch, heads, seq, dim] + v: torch.Tensor, # [batch, heads, seq, dim] +) -> torch.Tensor: + """ + Baseline PyTorch implementation of FP8 attention using FP8 e5m2. + """ + batch, heads, seq_len, head_dim = q.shape + q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v) + + sm_scale = 1.0 / math.sqrt(float(head_dim)) + + outputs = [] + + for i in range(batch * heads): + q_i = q_fp8[i] # [seq, dim] - already FP8 + k_i = k_fp8[i] # [seq, dim] - already FP8 + v_i = v_fp8[i] # [dim, seq] - pre-transposed, already FP8 + + # For Q @ K^T, we need K^T to be column-major + kt_fp8 = k_i.t() # column-major [dim, seq] + + # Q @ K^T - dequantize and use regular matmul since e5m2 not supported by _scaled_mm + q_deq = q_i.to(torch.float32) + kt_deq = kt_fp8.to(torch.float32) + qk = torch.matmul(q_deq, kt_deq) + + # Compute max before scaling + qk_max = torch.amax(qk, dim=-1, keepdim=True) + + # Scale and shift in one operation, then use exp2 + qk_scaled_shifted = qk * sm_scale - qk_max * sm_scale + p = torch.exp2(qk_scaled_shifted * 1.44269504) + + # Normalize + p_norm = p / p.sum(dim=-1, keepdim=True) + + # Step 2: Attention @ V using FP8 + # P is [seq, seq], V is [dim, seq] + # We want P @ V^T = [seq, seq] @ [seq, dim] = [seq, dim] + p_fp8 = p_norm.to(torch.float8_e5m2) # row-major [seq, seq] + + # v_i is [dim, seq], already FP8 + vt_fp8 = v_i.t() # column-major [seq, dim] + + # P @ V^T - dequantize and use regular matmul since e5m2 not supported by torch._scaled_mm + p_deq = p_fp8.to(torch.float32) + vt_deq = vt_fp8.to(torch.float32) + out_i = torch.matmul(p_deq, vt_deq) + + outputs.append(out_i) + + # Stack and reshape back + out_stacked = torch.stack(outputs, dim=0) # [batch*heads, seq, dim] + out = out_stacked.reshape(batch, heads, seq_len, head_dim) + + return out.to(torch.float16) + + +def check(batch: int, heads: int, seq_len: int, head_dim: int) -> None: + torch.manual_seed(42) + q = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda") + k = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda") + v = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda") + + from helion._testing import run_example + + helion_fn = fp8_attention_tritonbench(q, k, v) + run_example( + lambda q, k, v: helion_fn().to(torch.float16), + fp8_attention_pytorch, + (q, k, v), + atol=0.1, + rtol=0.1, + ) + + +def main() -> None: + check(1, 2, 128, 64) + check(2, 4, 256, 64) + check(4, 8, 512, 128) + + +if __name__ == 
"__main__": + main() diff --git a/test/test_examples.expected b/test/test_examples.expected index cc597ddd..2ef07287 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -574,6 +574,80 @@ def embedding(x: torch.Tensor, weight: torch.Tensor, *, _launcher=_default_launc _launcher(_embedding_kernel, (x_flat.size(0) * triton.cdiv(embedding_dim, _BLOCK_SIZE_1),), x_flat, weight, out, x_flat.size(0), out.stride(0), out.stride(1), weight.stride(0), weight.stride(1), x_flat.stride(0), embedding_dim, _BLOCK_SIZE_1, num_warps=4, num_stages=3) return out.view(*x.size(), embedding_dim) +--- assertExpectedJournal(TestExamples.test_fp8_attention) +from __future__ import annotations + +import math +import torch +import triton +import triton.language as tl +from torch._inductor.runtime import triton_helpers +from torch._inductor.runtime.triton_compat import libdevice +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _fp8_attention_kernel_kernel(q, k, v, out, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + indices_5 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32) + for offset_4 in tl.range(0, 256, _BLOCK_SIZE_1): + indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) + m_i = tl.full([_BLOCK_SIZE_1], float('-inf'), tl.float32) + l_i = tl.full([_BLOCK_SIZE_1], 0.0, tl.float32) + acc = tl.full([_BLOCK_SIZE_1, 64], 0.0, tl.float32) + q_tile = tl.load(q + (offset_0 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), None) + for offset_2 in tl.range(0, 256, _BLOCK_SIZE_3): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) + q_tile_copy = q_tile + m_i_copy = m_i + l_i_copy = l_i + acc_copy = acc + q_tile_copy_0 = q_tile_copy + m_i_copy_0 = m_i_copy + l_i_copy_0 = l_i_copy + acc_copy_0 = acc_copy + k_tile = tl.load(k + (offset_0 * 16384 + indices_2[:, None] * 64 + indices_5[None, :] * 1), None) + k_tile_t = tl.permute(k_tile, [1, 0]) + mm = tl.dot(q_tile_copy_0, k_tile_t) + v_0 = mm.to(tl.float32) + v_1 = 0.18033688 + v_2 = v_0 * v_1 + qk_max = tl.max(v_2, 1) + v_3 = triton_helpers.maximum(m_i_copy_0, qk_max) + subscript = v_3[:, None] + v_4 = v_2 - subscript + v_5 = libdevice.exp2(v_4) + l_ij = tl.sum(v_5, 1) + v_6 = m_i_copy_0 - v_3 + v_7 = libdevice.exp2(v_6) + v_8 = l_i_copy_0 * v_7 + l_i = v_8 + l_ij + subscript_1 = v_7[:, None] + v_10 = acc_copy_0 * subscript_1 + v_tile = tl.load(v + (offset_0 * 16384 + indices_5[:, None] * 1 + indices_2[None, :] * 64), None) + v_11 = v_5.to(tl.float8e5) + v_t = tl.permute(v_tile, [1, 0]) + mm_1 = tl.dot(v_11, v_t) + v_12 = mm_1.to(tl.float32) + acc = v_10 + v_12 + m_i = v_3 + subscript_2 = l_i[:, None] + v_14 = acc / subscript_2 + tl.store(out + (offset_0 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), v_14, None) + +def fp8_attention_kernel(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *, _launcher=_default_launcher): + batch_heads = q.size(0) + seq_len = q.size(1) + head_dim = q.size(2) + out = torch.empty([batch_heads, seq_len, head_dim], dtype=torch.float32, device=q.device) + sm_scale = 1.0 / math.sqrt(float(head_dim)) + sm_scale = sm_scale * 1.44269504 + _RDIM_SIZE_2 = 64 + _BLOCK_SIZE_1 = 64 + _BLOCK_SIZE_3 = 64 + _launcher(_fp8_attention_kernel_kernel, (8,), q, k, v, out, _RDIM_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3) + return out + --- assertExpectedJournal(TestExamples.test_fp8_gemm) from __future__ import annotations diff --git 
a/test/test_examples.py b/test/test_examples.py index 16bf049d..cda17e6f 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -557,6 +557,47 @@ def test_attention_persistent_interleaved_l2_grouping(self): ) ) + @unittest.skipIf( + not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9, + "FP8 requires GPU with compute capability >= 9.0 (e.g., H100)", + ) + def test_fp8_attention(self): + batch = 2 + heads = 4 + seq_len = 256 + head_dim = 64 + + # Create FP16 tensors + q = torch.randn( + batch, heads, seq_len, head_dim, dtype=torch.float16, device=DEVICE + ) + k = torch.randn( + batch, heads, seq_len, head_dim, dtype=torch.float16, device=DEVICE + ) + v = torch.randn( + batch, heads, seq_len, head_dim, dtype=torch.float16, device=DEVICE + ) + + # Import the module + mod = import_path(EXAMPLES_DIR / "fp8_attention.py") + + # Prepare FP8 inputs using the module's preprocessing function + q_fp8, k_fp8, v_fp8 = mod.preprocess_fp8_attention_inputs(q, k, v) + args = (q_fp8, k_fp8, v_fp8) + + # Get expected output from kernel + expected = mod.fp8_attention_kernel(*args) + + self.assertExpectedJournal( + check_example( + "fp8_attention", + args, + expected, + fn_name="fp8_attention_kernel", + block_sizes=[64, 64], + ) + ) + if __name__ == "__main__": unittest.main()
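
For anyone exercising the new example outside of test_examples.py, the sketch below shows the intended call sequence: cast Q/K/V to FP8 e5m2, flatten batch and heads, pre-transpose V, run the Helion kernel, and compare against the PyTorch baseline. This is an illustrative sketch, not part of the patch; it assumes the patch is applied at the repository root and that a CUDA device with FP8 support (e.g. H100) is available, and the shapes are arbitrary.

import importlib.util
import torch

# Load examples/fp8_attention.py by file path (the examples/ directory is not a package).
spec = importlib.util.spec_from_file_location("fp8_attention", "examples/fp8_attention.py")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

# Illustrative shapes: batch=2, heads=4, seq_len=256, head_dim=64.
q = torch.randn(2, 4, 256, 64, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Cast to FP8 e5m2, flatten batch*heads, and pre-transpose V to [dim, seq].
q_fp8, k_fp8, v_fp8 = mod.preprocess_fp8_attention_inputs(q, k, v)

out = mod.fp8_attention_kernel(q_fp8, k_fp8, v_fp8)  # [batch*heads, seq, dim], float32
ref = mod.fp8_attention_pytorch(q, k, v)             # [batch, heads, seq, dim], float16

# Loose tolerances, matching the example's own check(), since inputs are quantized to FP8.
torch.testing.assert_close(
    out.reshape(2, 4, 256, 64).to(torch.float16), ref, atol=0.1, rtol=0.1
)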