
Commit 3175785

Add fp8_attention example and unit test (#318)
1 parent 14110be commit 3175785

4 files changed: +362, -2 lines


examples/fp8_attention.py

Lines changed: 232 additions & 0 deletions
@@ -0,0 +1,232 @@
from __future__ import annotations

import math
from typing import Callable

import torch

import helion
import helion.language as hl


@helion.kernel(static_shapes=True)
def fp8_attention_kernel(
    q: torch.Tensor,  # [batch*heads, seq, dim]
    k: torch.Tensor,  # [batch*heads, seq, dim]
    v: torch.Tensor,  # [batch*heads, dim, seq] - pre-transposed
    batch: int,
    heads: int,
) -> torch.Tensor:
    batch_heads = q.size(0)
    seq_len = q.size(1)
    head_dim = q.size(2)

    # Output tensor with 4D shape in FP8 format
    out = torch.empty(
        [batch, heads, seq_len, head_dim], dtype=torch.float8_e5m2, device=q.device
    )

    # Scale factor for attention
    sm_scale = 1.0 / math.sqrt(float(head_dim))
    # Triton kernel multiplies sm_scale by 1.44269504 (1/log(2)) for exp2
    sm_scale = sm_scale * 1.44269504

    # Process each batch*head in parallel
    for bh in hl.grid(batch_heads):
        # Calculate batch and head indices
        b = bh // heads
        h = bh % heads

        # Process each query position
        for tile_m in hl.tile(seq_len):
            # Initialize for online softmax
            m_i = hl.full([tile_m], float("-inf"), dtype=torch.float32)
            l_i = hl.full([tile_m], 0.0, dtype=torch.float32)
            acc = hl.zeros([tile_m, head_dim], dtype=torch.float32)

            # Load query tile - keep in FP8
            q_tile = q[bh, tile_m, :]  # [tile_m, dim]

            # Compute attention scores for all keys
            for tile_n in hl.tile(seq_len):
                # Load key tile and transpose for Q @ K^T
                k_tile = k[bh, tile_n, :]  # [tile_n, dim] - keep in FP8
                k_tile_t = k_tile.transpose(0, 1)  # [dim, tile_n]

                # Compute Q @ K^T with FP8 inputs, result in FP32
                qk = torch.matmul(q_tile, k_tile_t).to(
                    torch.float32
                )  # [tile_m, tile_n]

                # Scale QK scores first
                qk_scaled = qk * sm_scale  # [tile_m, tile_n]

                # Compute max of scaled scores
                qk_max = torch.amax(qk_scaled, dim=-1)  # [tile_m]

                # Update global max
                m_new = torch.maximum(m_i, qk_max)

                # Shift by max for numerical stability
                qk_shifted = qk_scaled - m_new[:, None]

                # Use exp2 to match Triton kernel's implementation
                # Note: Triton kernel already multiplies sm_scale by 1.44269504
                p = torch.exp2(qk_shifted)  # [tile_m, tile_n]

                # Sum of exponentials for this block
                l_ij = torch.sum(p, dim=-1)  # [tile_m]

                # Update accumulators with correction factor
                # Correction factor for previous blocks
                alpha = torch.exp2(m_i - m_new)
                l_i = l_i * alpha + l_ij
                acc = acc * alpha[:, None]

                # Load values - V is [dim, seq]
                v_tile = v[bh, :, tile_n]  # [dim, tile_n] - keep in FP8

                # Convert p to FP8 for FP8 GEMM
                p_fp8 = p.to(v.dtype)  # Convert to same FP8 type as V

                # Accumulate attention @ V with FP8 GEMM
                v_t = v_tile.transpose(0, 1)  # [tile_n, dim]
                pv = torch.matmul(p_fp8, v_t).to(torch.float32)  # [tile_m, dim]
                acc = acc + pv

                # Update max tracker
                m_i = m_new

            # Final normalization
            acc = acc / l_i[:, None]
            # Convert to FP8 before writing to output
            out[b, h, tile_m, :] = acc.to(torch.float8_e5m2)

    return out


def preprocess_fp8_attention_inputs(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    q_fp8 = q.to(torch.float8_e5m2)
    k_fp8 = k.to(torch.float8_e5m2)
    v = v.permute(0, 1, 3, 2)
    v_fp8 = v.to(torch.float8_e5m2)
    batch, heads, seq_len, head_dim = q.shape
    q_fp8_reshaped = q_fp8.reshape(batch * heads, seq_len, head_dim)
    k_fp8_reshaped = k_fp8.reshape(batch * heads, seq_len, head_dim)
    v_fp8_reshaped = v_fp8.reshape(batch * heads, head_dim, seq_len)
    return q_fp8_reshaped, k_fp8_reshaped, v_fp8_reshaped


def fp8_attention_tritonbench(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> Callable[[], torch.Tensor]:
    batch, heads, seq_len, head_dim = q.shape
    q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v)
    # Return lambda that calls the kernel - preprocessing is done outside.
    # This matches the tritonbench kernel timing measurement setup.
    return lambda: fp8_attention_kernel(q_fp8, k_fp8, v_fp8, batch, heads)


def _fp8_attention_pytorch_impl(
    q_fp8: torch.Tensor,
    k_fp8: torch.Tensor,
    v_fp8: torch.Tensor,
    batch: int,
    heads: int,
    seq_len: int,
    head_dim: int,
) -> torch.Tensor:
    sm_scale = 1.0 / math.sqrt(float(head_dim))

    outputs = []

    for i in range(batch * heads):
        q_i = q_fp8[i]  # [seq, dim] - already FP8
        k_i = k_fp8[i]  # [seq, dim] - already FP8
        v_i = v_fp8[i]  # [dim, seq] - pre-transposed, already FP8

        # For Q @ K^T, we need K^T to be column-major
        kt_fp8 = k_i.t()  # column-major [dim, seq]

        # Q @ K^T - dequantize and use regular matmul since e5m2 not supported by _scaled_mm
        q_deq = q_i.to(torch.float32)
        kt_deq = kt_fp8.to(torch.float32)
        qk = torch.matmul(q_deq, kt_deq)

        # Compute max before scaling
        qk_max = torch.amax(qk, dim=-1, keepdim=True)

        # Scale and shift in one operation, then use exp2
        qk_scaled_shifted = qk * sm_scale - qk_max * sm_scale
        p = torch.exp2(qk_scaled_shifted * 1.44269504)

        # Normalize
        p_norm = p / p.sum(dim=-1, keepdim=True)

        # Step 2: Attention @ V using FP8
        # P is [seq, seq], V is [dim, seq]
        # We want P @ V^T = [seq, seq] @ [seq, dim] = [seq, dim]
        p_fp8 = p_norm.to(torch.float8_e5m2)  # row-major [seq, seq]

        # v_i is [dim, seq], already FP8
        vt_fp8 = v_i.t()  # column-major [seq, dim]

        # P @ V^T - dequantize and use regular matmul since e5m2 not supported by torch._scaled_mm
        p_deq = p_fp8.to(torch.float32)
        vt_deq = vt_fp8.to(torch.float32)
        out_i = torch.matmul(p_deq, vt_deq)
        out_i = out_i.to(torch.float8_e5m2)  # convert back to FP8

        outputs.append(out_i)

    # Stack and reshape back
    out_stacked = torch.stack(outputs, dim=0)  # [batch*heads, seq, dim]
    return out_stacked.reshape(batch, heads, seq_len, head_dim)


def fp8_attention_pytorch(
    q: torch.Tensor,  # [batch, heads, seq, dim]
    k: torch.Tensor,  # [batch, heads, seq, dim]
    v: torch.Tensor,  # [batch, heads, seq, dim]
) -> Callable[[], torch.Tensor]:
    """
    Baseline PyTorch implementation of FP8 attention using FP8 e5m2.
    """
    batch, heads, seq_len, head_dim = q.shape
    q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v)
    # Return lambda that calls the kernel - preprocessing is done outside.
    # This matches the Helion kernel timing measurement setup.
    return lambda: _fp8_attention_pytorch_impl(
        q_fp8, k_fp8, v_fp8, batch, heads, seq_len, head_dim
    )


def check(batch: int, heads: int, seq_len: int, head_dim: int) -> None:
    torch.manual_seed(42)
    q = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
    v = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")

    from helion._testing import run_example

    helion_fn = fp8_attention_tritonbench(q, k, v)
    pytorch_fn = fp8_attention_pytorch(q, k, v)
    run_example(
        helion_fn,
        pytorch_fn,
        (),
        atol=0.1,
        rtol=0.1,
    )


def main() -> None:
    check(1, 2, 128, 64)
    check(2, 4, 256, 64)
    check(4, 8, 512, 128)


if __name__ == "__main__":
    main()
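
Note on the kernel: the m_i / l_i / alpha bookkeeping above is the standard online-softmax rescaling (the kernel folds log2(e) into sm_scale so it can use exp2). Below is a minimal PyTorch sketch of the same blockwise update, written with exp for readability; blockwise_softmax and its arguments are illustrative names, not part of the example.

import torch

def blockwise_softmax(scores: torch.Tensor, block: int) -> torch.Tensor:
    # Same running-max / running-sum update as fp8_attention_kernel,
    # but materializing the probabilities instead of accumulating P @ V.
    m_i = torch.full(scores.shape[:1], float("-inf"))
    l_i = torch.zeros(scores.shape[:1])
    out = torch.zeros_like(scores)
    for start in range(0, scores.size(1), block):
        blk = scores[:, start:start + block]
        m_new = torch.maximum(m_i, blk.amax(dim=1))
        alpha = torch.exp(m_i - m_new)        # correction factor for earlier blocks
        out[:, :start] *= alpha[:, None]      # rescale what was already written
        p = torch.exp(blk - m_new[:, None])
        out[:, start:start + block] = p
        l_i = l_i * alpha + p.sum(dim=1)
        m_i = m_new
    return out / l_i[:, None]                 # final normalization

scores = torch.randn(4, 256)
assert torch.allclose(blockwise_softmax(scores, 64), torch.softmax(scores, dim=1), atol=1e-6)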

helion/_testing.py

Lines changed: 12 additions & 2 deletions
@@ -98,7 +98,12 @@ def run_example(
    for name, func in {**kernels, **baselines}.items():
        if name != first_baseline_name:
            print(f"Testing {name} correctness...", file=sys.stderr)
-           torch.testing.assert_close(func(*args), expected, rtol=rtol, atol=atol)
+           torch.testing.assert_close(
+               func(*args).to(torch.float32),
+               expected.to(torch.float32),
+               rtol=rtol,
+               atol=atol,
+           )

    # Benchmark all functions
    all_times = {
@@ -145,7 +150,12 @@ def check_example(
        args,
        **kwargs,
    )
-   skip_accuracy or torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-2)  # pyright: ignore[reportUnusedExpression]
+   skip_accuracy or torch.testing.assert_close(
+       result.to(torch.float32),  # pyright: ignore[reportAttributeAccessIssue]
+       expected.to(torch.float32),
+       atol=1e-1,
+       rtol=1e-2,
+   )  # pyright: ignore[reportUnusedExpression]
    return code

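Note on the float32 upcast: the comparison helpers now cast both result and expected to float32 before torch.testing.assert_close, presumably because the elementwise ops that assert_close relies on are not available for the float8 dtypes, and because FP8 outputs only need to match within coarse tolerances anyway. A minimal sketch of the pattern (hypothetical tensors, assuming a PyTorch build where float8 results must be upcast before comparison):

import torch

a = torch.randn(4, 4).to(torch.float8_e5m2)
b = a.clone()
# Compare in float32 with loose tolerances, mirroring the change above.
torch.testing.assert_close(a.to(torch.float32), b.to(torch.float32), atol=1e-1, rtol=1e-2)
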
test/test_examples.expected

Lines changed: 77 additions & 0 deletions
@@ -574,6 +574,83 @@ def embedding(x: torch.Tensor, weight: torch.Tensor, *, _launcher=_default_launc
    _launcher(_embedding_kernel, (x_flat.size(0) * triton.cdiv(embedding_dim, _BLOCK_SIZE_1),), x_flat, weight, out, x_flat.size(0), out.stride(0), out.stride(1), weight.stride(0), weight.stride(1), x_flat.stride(0), embedding_dim, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
    return out.view(*x.size(), embedding_dim)

+--- assertExpectedJournal(TestExamples.test_fp8_attention)
+from __future__ import annotations
+
+import math
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime import triton_helpers
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _fp8_attention_kernel_kernel(q, k, v, out, out_stride_0, heads, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0
+    indices_5 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    for offset_4 in tl.range(0, 256, _BLOCK_SIZE_1):
+        indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        m_i = tl.full([_BLOCK_SIZE_1], float('-inf'), tl.float32)
+        l_i = tl.full([_BLOCK_SIZE_1], 0.0, tl.float32)
+        acc = tl.full([_BLOCK_SIZE_1, 64], 0.0, tl.float32)
+        q_tile = tl.load(q + (offset_0 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), None)
+        for offset_2 in tl.range(0, 256, _BLOCK_SIZE_3):
+            indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+            q_tile_copy = q_tile
+            m_i_copy = m_i
+            l_i_copy = l_i
+            acc_copy = acc
+            q_tile_copy_0 = q_tile_copy
+            m_i_copy_0 = m_i_copy
+            l_i_copy_0 = l_i_copy
+            acc_copy_0 = acc_copy
+            k_tile = tl.load(k + (offset_0 * 16384 + indices_2[:, None] * 64 + indices_5[None, :] * 1), None)
+            k_tile_t = tl.permute(k_tile, [1, 0])
+            mm = tl.dot(q_tile_copy_0, k_tile_t)
+            v_0 = mm.to(tl.float32)
+            v_1 = 0.18033688
+            v_2 = v_0 * v_1
+            qk_max = tl.max(v_2, 1)
+            v_3 = triton_helpers.maximum(m_i_copy_0, qk_max)
+            subscript = v_3[:, None]
+            v_4 = v_2 - subscript
+            v_5 = libdevice.exp2(v_4)
+            l_ij = tl.sum(v_5, 1)
+            v_6 = m_i_copy_0 - v_3
+            v_7 = libdevice.exp2(v_6)
+            v_8 = l_i_copy_0 * v_7
+            l_i = v_8 + l_ij
+            subscript_1 = v_7[:, None]
+            v_10 = acc_copy_0 * subscript_1
+            v_tile = tl.load(v + (offset_0 * 16384 + indices_5[:, None] * 1 + indices_2[None, :] * 64), None)
+            v_11 = v_5.to(tl.float8e5)
+            v_t = tl.permute(v_tile, [1, 0])
+            mm_1 = tl.dot(v_11, v_t)
+            v_12 = mm_1.to(tl.float32)
+            acc = v_10 + v_12
+            m_i = v_3
+        subscript_2 = l_i[:, None]
+        v_14 = acc / subscript_2
+        v_15 = v_14.to(tl.float8e5)
+        symnode_0 = triton_helpers.div_floor_integer(offset_0, heads)
+        symnode_1 = triton_helpers.remainder_integer(offset_0, heads)
+        tl.store(out + (symnode_0 * out_stride_0 + symnode_1 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), v_15, None)
+
+def fp8_attention_kernel(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, batch: int, heads: int, *, _launcher=_default_launcher):
+    batch_heads = q.size(0)
+    seq_len = q.size(1)
+    head_dim = q.size(2)
+    out = torch.empty([batch, heads, seq_len, head_dim], dtype=torch.float8_e5m2, device=q.device)
+    sm_scale = 1.0 / math.sqrt(float(head_dim))
+    sm_scale = sm_scale * 1.44269504
+    _RDIM_SIZE_2 = 64
+    _BLOCK_SIZE_1 = 64
+    _BLOCK_SIZE_3 = 64
+    _launcher(_fp8_attention_kernel_kernel, (8,), q, k, v, out, out.stride(0), heads, _RDIM_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    return out
+
--- assertExpectedJournal(TestExamples.test_fp8_gemm)
from __future__ import annotations
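
A few constants in the generated kernel follow directly from the unit test below: the launch grid of (8,) is batch * heads = 2 * 4, the stride 16384 is seq_len * head_dim = 256 * 64, and the hard-coded v_1 = 0.18033688 is the folded softmax scale sm_scale * log2(e) for head_dim = 64. A quick arithmetic check:

import math
assert abs(1.44269504 / math.sqrt(64.0) - 0.18033688) < 1e-9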

test/test_examples.py

Lines changed: 41 additions & 0 deletions
@@ -557,6 +557,47 @@ def test_attention_persistent_interleaved_l2_grouping(self):
            )
        )

+    @unittest.skipIf(
+        not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9,
+        "FP8 requires GPU with compute capability >= 9.0 (e.g., H100)",
+    )
+    def test_fp8_attention(self):
+        batch = 2
+        heads = 4
+        seq_len = 256
+        head_dim = 64
+
+        # Create FP16 tensors
+        q = torch.randn(
+            batch, heads, seq_len, head_dim, dtype=torch.float16, device=DEVICE
+        )
+        k = torch.randn(
+            batch, heads, seq_len, head_dim, dtype=torch.float16, device=DEVICE
+        )
+        v = torch.randn(
+            batch, heads, seq_len, head_dim, dtype=torch.float16, device=DEVICE
+        )
+
+        # Import the module
+        mod = import_path(EXAMPLES_DIR / "fp8_attention.py")
+
+        # Prepare FP8 inputs using the module's preprocessing function
+        q_fp8, k_fp8, v_fp8 = mod.preprocess_fp8_attention_inputs(q, k, v)
+        args = (q_fp8, k_fp8, v_fp8, batch, heads)
+
+        # Get expected output from kernel
+        expected = mod.fp8_attention_kernel(*args)
+
+        self.assertExpectedJournal(
+            check_example(
+                "fp8_attention",
+                args,
+                expected,
+                fn_name="fp8_attention_kernel",
+                block_sizes=[64, 64],
+            )
+        )
+

if __name__ == "__main__":
    unittest.main()
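
To reproduce the expected journal locally, something like the following should work (assuming the repo's tests are run under pytest on an H100-class GPU, which is an assumption, not part of this commit):

python -m pytest test/test_examples.py -k test_fp8_attention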
