Commit ea04d03

Add fp8_attention example and unit test
1 parent 8f5068c commit ea04d03

File tree

3 files changed, +362 -0 lines changed:
examples/fp8_attention.py
test/test_examples.expected
test/test_examples.py


examples/fp8_attention.py

Lines changed: 233 additions & 0 deletions
@@ -0,0 +1,233 @@
from __future__ import annotations

import math

import torch

import helion
import helion.language as hl


@helion.kernel(static_shapes=True)
def fp8_attention_kernel(
    q: torch.Tensor,  # [batch*heads, seq, dim]
    k: torch.Tensor,  # [batch*heads, seq, dim]
    v: torch.Tensor,  # [batch*heads, dim, seq] - pre-transposed
) -> torch.Tensor:
    batch_heads = q.size(0)
    seq_len = q.size(1)
    head_dim = q.size(2)

    # Output tensor
    out = torch.empty(
        [batch_heads, seq_len, head_dim], dtype=torch.float32, device=q.device
    )

    # Scale factor for attention
    sm_scale = 1.0 / math.sqrt(float(head_dim))
    # Triton kernel multiplies sm_scale by 1.44269504 (1/log(2)) for exp2
    sm_scale = sm_scale * 1.44269504

    # Process each batch*head in parallel
    for bh in hl.grid(batch_heads):
        # Process each query position
        for tile_m in hl.tile(seq_len):
            # Initialize for online softmax
            m_i = hl.full([tile_m], float("-inf"), dtype=torch.float32)
            l_i = hl.full([tile_m], 0.0, dtype=torch.float32)
            acc = hl.zeros([tile_m, head_dim], dtype=torch.float32)

            # Load query tile - keep in FP8
            q_tile = q[bh, tile_m, :]  # [tile_m, dim]

            # Compute attention scores for all keys
            for tile_n in hl.tile(seq_len):
                # Load key tile and transpose for Q @ K^T
                k_tile = k[bh, tile_n, :]  # [tile_n, dim] - keep in FP8
                k_tile_t = k_tile.transpose(0, 1)  # [dim, tile_n]

                # Compute Q @ K^T with FP8 inputs, result in FP32
                qk = torch.matmul(q_tile, k_tile_t).to(
                    torch.float32
                )  # [tile_m, tile_n]

                # Scale QK scores first
                qk_scaled = qk * sm_scale  # [tile_m, tile_n]

                # Compute max of scaled scores
                qk_max = torch.amax(qk_scaled, dim=-1)  # [tile_m]

                # Update global max
                m_new = torch.maximum(m_i, qk_max)

                # Shift by max for numerical stability
                qk_shifted = qk_scaled - m_new[:, None]

                # Use exp2 to match Triton kernel's implementation
                # Note: Triton kernel already multiplies sm_scale by 1.44269504
                p = torch.exp2(qk_shifted)  # [tile_m, tile_n]

                # Sum of exponentials for this block
                l_ij = torch.sum(p, dim=-1)  # [tile_m]

                # Update accumulators with correction factor
                # Correction factor for previous blocks
                alpha = torch.exp2(m_i - m_new)
                l_i = l_i * alpha + l_ij
                acc = acc * alpha[:, None]

                # Load values - V is [dim, seq]
                v_tile = v[bh, :, tile_n]  # [dim, tile_n] - keep in FP8

                # Convert p to FP8 for FP8 GEMM
                p_fp8 = p.to(v.dtype)  # Convert to same FP8 type as V

                # Accumulate attention @ V with FP8 GEMM
                v_t = v_tile.transpose(0, 1)  # [tile_n, dim]
                pv = torch.matmul(p_fp8, v_t).to(torch.float32)  # [tile_m, dim]
                acc = acc + pv

                # Update max tracker
                m_i = m_new

            # Final normalization
            acc = acc / l_i[:, None]
            out[bh, tile_m, :] = acc

    return out


def prepare_fp8_attention_inputs(
    q: torch.Tensor,  # [batch, heads, seq, dim]
    k: torch.Tensor,  # [batch, heads, seq, dim]
    v: torch.Tensor,  # [batch, heads, seq, dim]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[int, int, int, int]]:
    """
    Common preprocessing for FP8 attention implementations.

    Returns:
        q_reshaped_fp8: [batch*heads, seq, dim] - in FP8 e5m2
        k_reshaped_fp8: [batch*heads, seq, dim] - in FP8 e5m2
        v_transposed_fp8: [batch*heads, dim, seq] - in FP8 e5m2
        shape: (batch, heads, seq_len, head_dim)
    """
    batch, heads, seq_len, head_dim = q.shape

    # Reshape to [batch*heads, seq, dim]
    q_reshaped = q.reshape(batch * heads, seq_len, head_dim)
    k_reshaped = k.reshape(batch * heads, seq_len, head_dim)

    # Transpose V to [batch, heads, dim, seq] then reshape
    v_transposed = v.permute(0, 1, 3, 2).reshape(batch * heads, head_dim, seq_len)

    # Convert to FP8 e5m2
    q_reshaped_fp8 = q_reshaped.to(torch.float8_e5m2)
    k_reshaped_fp8 = k_reshaped.to(torch.float8_e5m2)
    v_transposed_fp8 = v_transposed.to(torch.float8_e5m2)

    return (
        q_reshaped_fp8,
        k_reshaped_fp8,
        v_transposed_fp8,
        (batch, heads, seq_len, head_dim),
    )


def fp8_attention_tritonbench(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
    """Wrapper for TritonBench compatibility."""
    # Common preprocessing with FP8 conversion
    q_fp8, k_fp8, v_fp8, shape = prepare_fp8_attention_inputs(q, k, v)
    batch, heads, seq_len, head_dim = shape

    # Call the fused kernel
    out_fused = fp8_attention_kernel(q_fp8, k_fp8, v_fp8)

    # Reshape back and convert to FP16
    out = out_fused.reshape(batch, heads, seq_len, head_dim)
    return out.to(torch.float16)


def fp8_attention_pytorch(
    q: torch.Tensor,  # [batch, heads, seq, dim]
    k: torch.Tensor,  # [batch, heads, seq, dim]
    v: torch.Tensor,  # [batch, heads, seq, dim]
) -> torch.Tensor:
    """
    Baseline PyTorch implementation of FP8 attention using FP8 e5m2.
    """
    # Get preprocessed inputs with FP8 conversion
    q_fp8, k_fp8, v_fp8, shape = prepare_fp8_attention_inputs(q, k, v)
    batch, heads, seq_len, head_dim = shape

    sm_scale = 1.0 / math.sqrt(float(head_dim))

    outputs = []

    for i in range(batch * heads):
        q_i = q_fp8[i]  # [seq, dim] - already FP8
        k_i = k_fp8[i]  # [seq, dim] - already FP8
        v_i = v_fp8[i]  # [dim, seq] - pre-transposed, already FP8

        # For Q @ K^T, we need K^T to be column-major
        kt_fp8 = k_i.t()  # column-major [dim, seq]

        # Q @ K^T - dequantize and use regular matmul since e5m2 not supported by _scaled_mm
        q_deq = q_i.to(torch.float32)
        kt_deq = kt_fp8.to(torch.float32)
        qk = torch.matmul(q_deq, kt_deq)

        # Compute max before scaling
        qk_max = torch.amax(qk, dim=-1, keepdim=True)

        # Scale and shift in one operation, then use exp2
        qk_scaled_shifted = qk * sm_scale - qk_max * sm_scale
        p = torch.exp2(qk_scaled_shifted * 1.44269504)

        # Normalize
        p_norm = p / p.sum(dim=-1, keepdim=True)

        # Step 2: Attention @ V using FP8
        # P is [seq, seq], V is [dim, seq]
        # We want P @ V^T = [seq, seq] @ [seq, dim] = [seq, dim]
        p_fp8 = p_norm.to(torch.float8_e5m2)  # row-major [seq, seq]

        # v_i is [dim, seq], already FP8
        vt_fp8 = v_i.t()  # column-major [seq, dim]

        # P @ V^T - dequantize and use regular matmul since e5m2 not supported by torch._scaled_mm
        p_deq = p_fp8.to(torch.float32)
        vt_deq = vt_fp8.to(torch.float32)
        out_i = torch.matmul(p_deq, vt_deq)

        outputs.append(out_i)

    # Stack and reshape back
    out_stacked = torch.stack(outputs, dim=0)  # [batch*heads, seq, dim]
    out = out_stacked.reshape(batch, heads, seq_len, head_dim)

    return out.to(torch.float16)


def check(batch: int, heads: int, seq_len: int, head_dim: int) -> None:
    torch.manual_seed(42)
    q = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
    v = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")

    from helion._testing import run_example

    run_example(
        fp8_attention_tritonbench, fp8_attention_pytorch, (q, k, v), atol=0.1, rtol=0.1
    )


def main() -> None:
    check(1, 2, 128, 64)
    check(2, 4, 256, 64)
    check(4, 8, 512, 128)


if __name__ == "__main__":
    main()
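
Two identities worth noting in the example above: exp(x) = exp2(x * log2(e)), which is why sm_scale is pre-multiplied by 1.44269504 (log2(e)) so the inner loop can use the cheaper exp2, and the flash-attention-style online softmax, where the running sums are rescaled by alpha = exp2(m_old - m_new) whenever a new block raises the row maximum. A minimal sketch checking both in plain float32 PyTorch (shapes and the two-block split are illustrative, not taken from this commit):

import torch

torch.manual_seed(0)
scores = torch.randn(4, 6, dtype=torch.float32)  # illustrative [tile_m, seq] logits
values = torch.randn(6, dtype=torch.float32)     # one value per key position

# exp via exp2: 1.44269504 is log2(e), so exp(x) == exp2(x * log2(e))
assert torch.allclose(torch.exp(scores), torch.exp2(scores * 1.44269504))

# Online softmax over two key blocks vs. a direct softmax over the full row
m = torch.full((4,), float("-inf"))  # running row maxima
l = torch.zeros(4)                   # running sums of exponentials
acc = torch.zeros(4)                 # running weighted-value accumulator
for blk, v_blk in zip(scores.split(3, dim=-1), values.split(3)):
    m_new = torch.maximum(m, blk.amax(dim=-1))
    p = torch.exp(blk - m_new[:, None])
    alpha = torch.exp(m - m_new)     # correction for previously accumulated blocks
    l = l * alpha + p.sum(dim=-1)
    acc = acc * alpha + p @ v_blk
    m = m_new

direct = torch.softmax(scores, dim=-1) @ values
assert torch.allclose(acc / l, direct, atol=1e-5)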

test/test_examples.expected

Lines changed: 88 additions & 0 deletions
@@ -535,6 +535,94 @@ def _embedding_make_precompiler(x: torch.Tensor, weight: torch.Tensor):
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_embedding_kernel)(x_flat, weight, out, x_flat.size(0), out.stride(0), out.stride(1), weight.stride(0), weight.stride(1), x_flat.stride(0), embedding_dim, _BLOCK_SIZE_1, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestExamples.test_fp8_attention)
from __future__ import annotations

import math
import torch
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_compat import libdevice

@triton.jit
def _fp8_attention_kernel_kernel(q, k, v, out, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0
    indices_5 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
    for offset_4 in tl.range(0, 256, _BLOCK_SIZE_1):
        indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
        m_i = tl.full([_BLOCK_SIZE_1], float('-inf'), tl.float32)
        l_i = tl.full([_BLOCK_SIZE_1], 0.0, tl.float32)
        acc = tl.full([_BLOCK_SIZE_1, 64], 0.0, tl.float32)
        q_tile = tl.load(q + (offset_0 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), None)
        for offset_2 in tl.range(0, 256, _BLOCK_SIZE_3):
            indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
            q_tile_copy = q_tile
            m_i_copy = m_i
            l_i_copy = l_i
            acc_copy = acc
            q_tile_copy_0 = q_tile_copy
            m_i_copy_0 = m_i_copy
            l_i_copy_0 = l_i_copy
            acc_copy_0 = acc_copy
            k_tile = tl.load(k + (offset_0 * 16384 + indices_2[:, None] * 64 + indices_5[None, :] * 1), None)
            k_tile_t = tl.permute(k_tile, [1, 0])
            mm = tl.dot(q_tile_copy_0, k_tile_t)
            v_0 = mm.to(tl.float32)
            v_1 = 0.18033688
            v_2 = v_0 * v_1
            qk_max = tl.max(v_2, 1)
            v_3 = triton_helpers.maximum(m_i_copy_0, qk_max)
            subscript = v_3[:, None]
            v_4 = v_2 - subscript
            v_5 = libdevice.exp2(v_4)
            l_ij = tl.sum(v_5, 1)
            v_6 = m_i_copy_0 - v_3
            v_7 = libdevice.exp2(v_6)
            v_8 = l_i_copy_0 * v_7
            l_i = v_8 + l_ij
            subscript_1 = v_7[:, None]
            v_10 = acc_copy_0 * subscript_1
            v_tile = tl.load(v + (offset_0 * 16384 + indices_5[:, None] * 1 + indices_2[None, :] * 64), None)
            v_11 = v_5.to(tl.float8e5)
            v_t = tl.permute(v_tile, [1, 0])
            mm_1 = tl.dot(v_11, v_t)
            v_12 = mm_1.to(tl.float32)
            acc = v_10 + v_12
            m_i = v_3
        subscript_2 = l_i[:, None]
        v_14 = acc / subscript_2
        tl.store(out + (offset_0 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), v_14, None)

def fp8_attention_kernel(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """FP8 attention kernel processing batch*heads in parallel."""
    batch_heads = q.size(0)
    seq_len = q.size(1)
    head_dim = q.size(2)
    out = torch.empty([batch_heads, seq_len, head_dim], dtype=torch.float32, device=q.device)
    sm_scale = 1.0 / math.sqrt(float(head_dim))
    sm_scale = sm_scale * 1.44269504
    _RDIM_SIZE_2 = 64
    _BLOCK_SIZE_1 = 64
    _BLOCK_SIZE_3 = 64
    _fp8_attention_kernel_kernel[8,](q, k, v, out, _RDIM_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
    return out

def _fp8_attention_kernel_make_precompiler(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """FP8 attention kernel processing batch*heads in parallel."""
    batch_heads = q.size(0)
    seq_len = q.size(1)
    head_dim = q.size(2)
    out = torch.empty([batch_heads, seq_len, head_dim], dtype=torch.float32, device=q.device)
    sm_scale = 1.0 / math.sqrt(float(head_dim))
    sm_scale = sm_scale * 1.44269504
    _RDIM_SIZE_2 = 64
    _BLOCK_SIZE_1 = 64
    _BLOCK_SIZE_3 = 64
    from helion.runtime.precompile_shim import make_precompiler
    return make_precompiler(_fp8_attention_kernel_kernel)(q, k, v, out, _RDIM_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestExamples.test_fp8_gemm)
from __future__ import annotations
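
The literal constants in the generated Triton code above follow from the shapes used by the test (batch=2, heads=4, seq_len=256, head_dim=64). A small sketch of the arithmetic, for reference only:

import math

batch, heads, seq_len, head_dim = 2, 4, 256, 64
assert batch * heads == 8                   # launch grid: _fp8_attention_kernel_kernel[8,]
assert seq_len * head_dim == 16384          # per-(batch*head) offset in tl.load / tl.store
assert abs(1.44269504 / math.sqrt(head_dim) - 0.18033688) < 1e-9  # v_1 = sm_scale * log2(e)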

test/test_examples.py

Lines changed: 41 additions & 0 deletions
@@ -523,6 +523,47 @@ def test_segment_reduction(self):
            )
        )

    @unittest.skipIf(
        not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9,
        "FP8 requires GPU with compute capability >= 9.0 (e.g., H100)",
    )
    def test_fp8_attention(self):
        batch = 2
        heads = 4
        seq_len = 256
        head_dim = 64

        # Create FP16 tensors
        q = torch.randn(
            batch, heads, seq_len, head_dim, dtype=torch.float16, device=DEVICE
        )
        k = torch.randn(
            batch, heads, seq_len, head_dim, dtype=torch.float16, device=DEVICE
        )
        v = torch.randn(
            batch, heads, seq_len, head_dim, dtype=torch.float16, device=DEVICE
        )

        # Import the module
        mod = import_path(EXAMPLES_DIR / "fp8_attention.py")

        # Prepare FP8 inputs using the module's preprocessing function
        q_fp8, k_fp8, v_fp8, shape = mod.prepare_fp8_attention_inputs(q, k, v)
        args = (q_fp8, k_fp8, v_fp8)

        # Get expected output from kernel
        expected = mod.fp8_attention_kernel(*args)

        self.assertExpectedJournal(
            check_example(
                "fp8_attention",
                args,
                expected,
                fn_name="fp8_attention_kernel",
                block_sizes=[64, 64],
            )
        )

if __name__ == "__main__":
    unittest.main()
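
For reference, a minimal sketch of what prepare_fp8_attention_inputs hands to the kernel for the test's shapes. The import path is an assumption, and the conversion requires a PyTorch build with float8_e5m2 support and a CUDA device:

import torch

from examples.fp8_attention import prepare_fp8_attention_inputs  # path assumed importable

q = torch.randn(2, 4, 256, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 4, 256, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 4, 256, 64, dtype=torch.float16, device="cuda")

q_fp8, k_fp8, v_fp8, shape = prepare_fp8_attention_inputs(q, k, v)
print(q_fp8.shape, q_fp8.dtype)  # torch.Size([8, 256, 64]) torch.float8_e5m2
print(v_fp8.shape, v_fp8.dtype)  # torch.Size([8, 64, 256]) torch.float8_e5m2 (pre-transposed)
print(shape)                     # (2, 4, 256, 64)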
