Add fp8_attention example and unit test #318

Open
wants to merge 1 commit into main
233 changes: 233 additions & 0 deletions examples/fp8_attention.py
@@ -0,0 +1,233 @@
from __future__ import annotations

import math

import torch

import helion
import helion.language as hl


@helion.kernel(static_shapes=True)
def fp8_attention_kernel(
q: torch.Tensor, # [batch*heads, seq, dim]
k: torch.Tensor, # [batch*heads, seq, dim]
v: torch.Tensor, # [batch*heads, dim, seq] - pre-transposed
) -> torch.Tensor:
batch_heads = q.size(0)
seq_len = q.size(1)
head_dim = q.size(2)

# Output tensor
out = torch.empty(
[batch_heads, seq_len, head_dim], dtype=torch.float32, device=q.device
)

# Scale factor for attention
sm_scale = 1.0 / math.sqrt(float(head_dim))
# Triton kernel multiplies sm_scale by 1.44269504 (log2(e), i.e. 1/ln(2)) for exp2
sm_scale = sm_scale * 1.44269504
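# Folding log2(e) into the scale lets exp(x) be computed as exp2(x * log2(e))
# later on, which maps to the faster hardware exp2 instruction.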

# Process each batch*head in parallel
for bh in hl.grid(batch_heads):
# Process each query position
for tile_m in hl.tile(seq_len):
# Initialize for online softmax
m_i = hl.full([tile_m], float("-inf"), dtype=torch.float32)
l_i = hl.full([tile_m], 0.0, dtype=torch.float32)
acc = hl.zeros([tile_m, head_dim], dtype=torch.float32)

# Load query tile - keep in FP8
q_tile = q[bh, tile_m, :] # [tile_m, dim]

# Compute attention scores for all keys
for tile_n in hl.tile(seq_len):
# Load key tile and transpose for Q @ K^T
k_tile = k[bh, tile_n, :] # [tile_n, dim] - keep in FP8
k_tile_t = k_tile.transpose(0, 1) # [dim, tile_n]

# Compute Q @ K^T with FP8 inputs, result in FP32
qk = torch.matmul(q_tile, k_tile_t).to(
torch.float32
) # [tile_m, tile_n]

# Scale QK scores first
qk_scaled = qk * sm_scale # [tile_m, tile_n]

# Compute max of scaled scores
qk_max = torch.amax(qk_scaled, dim=-1) # [tile_m]

# Update global max
m_new = torch.maximum(m_i, qk_max)

# Shift by max for numerical stability
qk_shifted = qk_scaled - m_new[:, None]

# Use exp2 to match Triton kernel's implementation
# Note: Triton kernel already multiplies sm_scale by 1.44269504
p = torch.exp2(qk_shifted) # [tile_m, tile_n]

# Sum of exponentials for this block
l_ij = torch.sum(p, dim=-1) # [tile_m]

# Update accumulators with correction factor
# Correction factor for previous blocks
alpha = torch.exp2(m_i - m_new)
l_i = l_i * alpha + l_ij
acc = acc * alpha[:, None]

# Load values - V is [dim, seq]
v_tile = v[bh, :, tile_n] # [dim, tile_n] - keep in FP8

# Convert p to FP8 for FP8 GEMM
p_fp8 = p.to(v.dtype) # Convert to same FP8 type as V
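# float8_e5m2 keeps only 2 mantissa bits, so P is quantized fairly coarsely;
# this is the main reason the example compares against the baseline with a
# loose atol/rtol of 0.1.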

# Accumulate attention @ V with FP8 GEMM
v_t = v_tile.transpose(0, 1) # [tile_n, dim]
pv = torch.matmul(p_fp8, v_t).to(torch.float32) # [tile_m, dim]
acc = acc + pv

# Update max tracker
m_i = m_new

# Final normalization
acc = acc / l_i[:, None]
out[bh, tile_m, :] = acc

return out


def prepare_fp8_attention_inputs(
q: torch.Tensor, # [batch, heads, seq, dim]
k: torch.Tensor, # [batch, heads, seq, dim]
v: torch.Tensor, # [batch, heads, seq, dim]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[int, int, int, int]]:
"""
Common preprocessing for FP8 attention implementations.

Returns:
q_reshaped_fp8: [batch*heads, seq, dim] - in FP8 e5m2
k_reshaped_fp8: [batch*heads, seq, dim] - in FP8 e5m2
v_transposed_fp8: [batch*heads, dim, seq] - in FP8 e5m2
shape: (batch, heads, seq_len, head_dim)
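
Example (illustrative shapes): fp16 inputs of shape [2, 4, 256, 64] become
q/k of shape [8, 256, 64] and v of shape [8, 64, 256], all in float8_e5m2.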
"""
batch, heads, seq_len, head_dim = q.shape

# Reshape to [batch*heads, seq, dim]
q_reshaped = q.reshape(batch * heads, seq_len, head_dim)
k_reshaped = k.reshape(batch * heads, seq_len, head_dim)

# Transpose V to [batch, heads, dim, seq] then reshape
v_transposed = v.permute(0, 1, 3, 2).reshape(batch * heads, head_dim, seq_len)

# Convert to FP8 e5m2
q_reshaped_fp8 = q_reshaped.to(torch.float8_e5m2)
k_reshaped_fp8 = k_reshaped.to(torch.float8_e5m2)
v_transposed_fp8 = v_transposed.to(torch.float8_e5m2)

return (
q_reshaped_fp8,
k_reshaped_fp8,
v_transposed_fp8,
(batch, heads, seq_len, head_dim),
)


def fp8_attention_tritonbench(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
"""Wrapper for TritonBench compatibility."""
# Common preprocessing with FP8 conversion
q_fp8, k_fp8, v_fp8, shape = prepare_fp8_attention_inputs(q, k, v)
batch, heads, seq_len, head_dim = shape
Comment on lines +141 to +142
Contributor:
Do we want to benchmark this? The PyTorch input preparation isn't really Helion work.

# Call the fused kernel
out_fused = fp8_attention_kernel(q_fp8, k_fp8, v_fp8)

# Reshape back and convert to FP16
out = out_fused.reshape(batch, heads, seq_len, head_dim)
return out.to(torch.float16)


def fp8_attention_pytorch(
q: torch.Tensor, # [batch, heads, seq, dim]
k: torch.Tensor, # [batch, heads, seq, dim]
v: torch.Tensor, # [batch, heads, seq, dim]
) -> torch.Tensor:
"""
Baseline PyTorch implementation of FP8 attention using FP8 e5m2.
"""
# Get preprocessed inputs with FP8 conversion
q_fp8, k_fp8, v_fp8, shape = prepare_fp8_attention_inputs(q, k, v)
batch, heads, seq_len, head_dim = shape

sm_scale = 1.0 / math.sqrt(float(head_dim))
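# Unlike the kernel above, the log2(e) factor is applied inside exp2 below
# rather than being folded into sm_scale.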

outputs = []

for i in range(batch * heads):
q_i = q_fp8[i] # [seq, dim] - already FP8
k_i = k_fp8[i] # [seq, dim] - already FP8
v_i = v_fp8[i] # [dim, seq] - pre-transposed, already FP8

# For Q @ K^T, we need K^T to be column-major
kt_fp8 = k_i.t() # column-major [dim, seq]

# Q @ K^T - dequantize and use regular matmul since e5m2 is not supported by torch._scaled_mm
q_deq = q_i.to(torch.float32)
kt_deq = kt_fp8.to(torch.float32)
qk = torch.matmul(q_deq, kt_deq)

# Compute max before scaling
qk_max = torch.amax(qk, dim=-1, keepdim=True)

# Scale and shift in one operation, then use exp2
qk_scaled_shifted = qk * sm_scale - qk_max * sm_scale
p = torch.exp2(qk_scaled_shifted * 1.44269504)

# Normalize
p_norm = p / p.sum(dim=-1, keepdim=True)
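# Note: the baseline normalizes P before the FP8 P @ V matmul, whereas the
# Helion kernel normalizes after accumulation; the two are equivalent up to
# FP8 rounding of P.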

# Step 2: Attention @ V using FP8
# P is [seq, seq], V is [dim, seq]
# We want P @ V^T = [seq, seq] @ [seq, dim] = [seq, dim]
p_fp8 = p_norm.to(torch.float8_e5m2) # row-major [seq, seq]

# v_i is [dim, seq], already FP8
vt_fp8 = v_i.t() # column-major [seq, dim]

# P @ V^T - dequantize and use regular matmul since e5m2 is not supported by torch._scaled_mm
p_deq = p_fp8.to(torch.float32)
vt_deq = vt_fp8.to(torch.float32)
out_i = torch.matmul(p_deq, vt_deq)

outputs.append(out_i)

# Stack and reshape back
out_stacked = torch.stack(outputs, dim=0) # [batch*heads, seq, dim]
out = out_stacked.reshape(batch, heads, seq_len, head_dim)

return out.to(torch.float16)


def check(batch: int, heads: int, seq_len: int, head_dim: int) -> None:
torch.manual_seed(42)
q = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")

from helion._testing import run_example

run_example(
fp8_attention_tritonbench, fp8_attention_pytorch, (q, k, v), atol=0.1, rtol=0.1
)


def main() -> None:
check(1, 2, 128, 64)
check(2, 4, 256, 64)
check(4, 8, 512, 128)


if __name__ == "__main__":
main()
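
For reference, a minimal sketch of the online-softmax recurrence the kernel above relies on, written in plain FP32 PyTorch; the function name, shapes, and tolerance below are illustrative and not part of this PR (the kernel additionally uses exp2 with log2(e) folded into the scale, which is equivalent):

import torch

def online_softmax_matmul(scores: torch.Tensor, values: torch.Tensor, block: int = 4) -> torch.Tensor:
    # scores: [m, n] pre-scaled attention logits; values: [n, d]
    m_i = torch.full((scores.size(0),), float("-inf"))   # running row max
    l_i = torch.zeros(scores.size(0))                    # running softmax denominator
    acc = torch.zeros(scores.size(0), values.size(1))    # running numerator @ values
    for start in range(0, scores.size(1), block):
        s = scores[:, start:start + block]               # one key block at a time
        m_new = torch.maximum(m_i, s.amax(dim=-1))
        p = torch.exp(s - m_new[:, None])
        alpha = torch.exp(m_i - m_new)                   # rescales previous partial sums
        l_i = l_i * alpha + p.sum(dim=-1)
        acc = acc * alpha[:, None] + p @ values[start:start + block]
        m_i = m_new
    return acc / l_i[:, None]

# Agrees with a direct two-pass softmax to within float tolerance:
scores = torch.randn(8, 16)
values = torch.randn(16, 4)
ref = torch.softmax(scores, dim=-1) @ values
assert torch.allclose(online_softmax_matmul(scores, values), ref, atol=1e-5)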
74 changes: 74 additions & 0 deletions test/test_examples.expected
@@ -574,6 +574,80 @@ def embedding(x: torch.Tensor, weight: torch.Tensor, *, _launcher=_default_launc
_launcher(_embedding_kernel, (x_flat.size(0) * triton.cdiv(embedding_dim, _BLOCK_SIZE_1),), x_flat, weight, out, x_flat.size(0), out.stride(0), out.stride(1), weight.stride(0), weight.stride(1), x_flat.stride(0), embedding_dim, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
return out.view(*x.size(), embedding_dim)

--- assertExpectedJournal(TestExamples.test_fp8_attention)
from __future__ import annotations

import math
import torch
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _fp8_attention_kernel_kernel(q, k, v, out, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0
indices_5 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
for offset_4 in tl.range(0, 256, _BLOCK_SIZE_1):
indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
m_i = tl.full([_BLOCK_SIZE_1], float('-inf'), tl.float32)
l_i = tl.full([_BLOCK_SIZE_1], 0.0, tl.float32)
acc = tl.full([_BLOCK_SIZE_1, 64], 0.0, tl.float32)
q_tile = tl.load(q + (offset_0 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), None)
for offset_2 in tl.range(0, 256, _BLOCK_SIZE_3):
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
q_tile_copy = q_tile
m_i_copy = m_i
l_i_copy = l_i
acc_copy = acc
q_tile_copy_0 = q_tile_copy
m_i_copy_0 = m_i_copy
l_i_copy_0 = l_i_copy
acc_copy_0 = acc_copy
k_tile = tl.load(k + (offset_0 * 16384 + indices_2[:, None] * 64 + indices_5[None, :] * 1), None)
k_tile_t = tl.permute(k_tile, [1, 0])
mm = tl.dot(q_tile_copy_0, k_tile_t)
v_0 = mm.to(tl.float32)
v_1 = 0.18033688
v_2 = v_0 * v_1
qk_max = tl.max(v_2, 1)
v_3 = triton_helpers.maximum(m_i_copy_0, qk_max)
subscript = v_3[:, None]
v_4 = v_2 - subscript
v_5 = libdevice.exp2(v_4)
l_ij = tl.sum(v_5, 1)
v_6 = m_i_copy_0 - v_3
v_7 = libdevice.exp2(v_6)
v_8 = l_i_copy_0 * v_7
l_i = v_8 + l_ij
subscript_1 = v_7[:, None]
v_10 = acc_copy_0 * subscript_1
v_tile = tl.load(v + (offset_0 * 16384 + indices_5[:, None] * 1 + indices_2[None, :] * 64), None)
v_11 = v_5.to(tl.float8e5)
v_t = tl.permute(v_tile, [1, 0])
mm_1 = tl.dot(v_11, v_t)
v_12 = mm_1.to(tl.float32)
acc = v_10 + v_12
m_i = v_3
subscript_2 = l_i[:, None]
v_14 = acc / subscript_2
tl.store(out + (offset_0 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), v_14, None)

def fp8_attention_kernel(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *, _launcher=_default_launcher):
batch_heads = q.size(0)
seq_len = q.size(1)
head_dim = q.size(2)
out = torch.empty([batch_heads, seq_len, head_dim], dtype=torch.float32, device=q.device)
sm_scale = 1.0 / math.sqrt(float(head_dim))
sm_scale = sm_scale * 1.44269504
_RDIM_SIZE_2 = 64
_BLOCK_SIZE_1 = 64
_BLOCK_SIZE_3 = 64
_launcher(_fp8_attention_kernel_kernel, (8,), q, k, v, out, _RDIM_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestExamples.test_fp8_gemm)
from __future__ import annotations

41 changes: 41 additions & 0 deletions test/test_examples.py
@@ -557,6 +557,47 @@ def test_attention_persistent_interleaved_l2_grouping(self):
)
)

@unittest.skipIf(
not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9,
"FP8 requires GPU with compute capability >= 9.0 (e.g., H100)",
)
def test_fp8_attention(self):
batch = 2
heads = 4
seq_len = 256
head_dim = 64

# Create FP16 tensors
q = torch.randn(
batch, heads, seq_len, head_dim, dtype=torch.float16, device=DEVICE
)
k = torch.randn(
batch, heads, seq_len, head_dim, dtype=torch.float16, device=DEVICE
)
v = torch.randn(
batch, heads, seq_len, head_dim, dtype=torch.float16, device=DEVICE
)

# Import the module
mod = import_path(EXAMPLES_DIR / "fp8_attention.py")

# Prepare FP8 inputs using the module's preprocessing function
q_fp8, k_fp8, v_fp8, shape = mod.prepare_fp8_attention_inputs(q, k, v)
args = (q_fp8, k_fp8, v_fp8)

# Get expected output from kernel
expected = mod.fp8_attention_kernel(*args)

self.assertExpectedJournal(
check_example(
"fp8_attention",
args,
expected,
fn_name="fp8_attention_kernel",
block_sizes=[64, 64],
)
)


if __name__ == "__main__":
unittest.main()