
Commit 38fee18

Add fp8_attention example and unit test

stack-info: PR: #318, branch: yf225/stack/26

1 parent d884774

3 files changed (+342, -0 lines)

examples/fp8_attention.py

Lines changed: 227 additions & 0 deletions
@@ -0,0 +1,227 @@
from __future__ import annotations

import math
from typing import Callable

import torch

import helion
import helion.language as hl


@helion.kernel(static_shapes=True)
def fp8_attention_kernel(
    q: torch.Tensor,  # [batch*heads, seq, dim]
    k: torch.Tensor,  # [batch*heads, seq, dim]
    v: torch.Tensor,  # [batch*heads, dim, seq] - pre-transposed
) -> torch.Tensor:
    batch_heads = q.size(0)
    seq_len = q.size(1)
    head_dim = q.size(2)

    # Output tensor
    out = torch.empty(
        [batch_heads, seq_len, head_dim], dtype=torch.float32, device=q.device
    )

    # Scale factor for attention
    sm_scale = 1.0 / math.sqrt(float(head_dim))
    # Fold 1.44269504 (log2(e), i.e. 1/ln(2)) into sm_scale so exp2 can be used below,
    # matching the Triton kernel
    sm_scale = sm_scale * 1.44269504

    # Process each batch*head in parallel
    for bh in hl.grid(batch_heads):
        # Process each query position
        for tile_m in hl.tile(seq_len):
            # Initialize for online softmax
            m_i = hl.full([tile_m], float("-inf"), dtype=torch.float32)
            l_i = hl.full([tile_m], 0.0, dtype=torch.float32)
            acc = hl.zeros([tile_m, head_dim], dtype=torch.float32)

            # Load query tile - keep in FP8
            q_tile = q[bh, tile_m, :]  # [tile_m, dim]

            # Compute attention scores for all keys
            for tile_n in hl.tile(seq_len):
                # Load key tile and transpose for Q @ K^T
                k_tile = k[bh, tile_n, :]  # [tile_n, dim] - keep in FP8
                k_tile_t = k_tile.transpose(0, 1)  # [dim, tile_n]

                # Compute Q @ K^T with FP8 inputs, result in FP32
                qk = torch.matmul(q_tile, k_tile_t).to(
                    torch.float32
                )  # [tile_m, tile_n]

                # Scale QK scores first
                qk_scaled = qk * sm_scale  # [tile_m, tile_n]

                # Compute max of scaled scores
                qk_max = torch.amax(qk_scaled, dim=-1)  # [tile_m]

                # Update global max
                m_new = torch.maximum(m_i, qk_max)

                # Shift by max for numerical stability
                qk_shifted = qk_scaled - m_new[:, None]

                # Use exp2 to match the Triton kernel's implementation
                # Note: sm_scale already carries the 1.44269504 factor
                p = torch.exp2(qk_shifted)  # [tile_m, tile_n]

                # Sum of exponentials for this block
                l_ij = torch.sum(p, dim=-1)  # [tile_m]

                # Update accumulators with correction factor
                # Correction factor for previous blocks
                alpha = torch.exp2(m_i - m_new)
                l_i = l_i * alpha + l_ij
                acc = acc * alpha[:, None]

                # Load values - V is [dim, seq]
                v_tile = v[bh, :, tile_n]  # [dim, tile_n] - keep in FP8

                # Convert p to FP8 for FP8 GEMM
                p_fp8 = p.to(v.dtype)  # Convert to same FP8 type as V

                # Accumulate attention @ V with FP8 GEMM
                v_t = v_tile.transpose(0, 1)  # [tile_n, dim]
                pv = torch.matmul(p_fp8, v_t).to(torch.float32)  # [tile_m, dim]
                acc = acc + pv

                # Update max tracker
                m_i = m_new

            # Final normalization
            acc = acc / l_i[:, None]
            out[bh, tile_m, :] = acc

    return out

def preprocess_fp8_attention_inputs(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    q = q.to(torch.float8_e5m2)
    k = k.to(torch.float8_e5m2)
    v = v.permute(0, 1, 3, 2)
    v = v.to(torch.float8_e5m2)
    return q, k, v


def fp8_attention_tritonbench(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> Callable[[], torch.Tensor]:
    q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v)

    # Get shape info
    batch, heads, seq_len, head_dim = q.shape

    # Reshape Q and K to [batch*heads, seq, dim]
    q_reshaped = q_fp8.reshape(batch * heads, seq_len, head_dim)
    k_reshaped = k_fp8.reshape(batch * heads, seq_len, head_dim)

    # V is already transposed, reshape to [batch*heads, dim, seq]
    v_reshaped = v_fp8.reshape(batch * heads, head_dim, seq_len)

    # Return lambda that calls the kernel - preprocessing is done outside
    return (
        lambda: fp8_attention_kernel(q_reshaped, k_reshaped, v_reshaped)
        .reshape(batch, heads, seq_len, head_dim)
        .to(torch.float8_e5m2)
    )

def fp8_attention_pytorch(
    q: torch.Tensor,  # [batch, heads, seq, dim]
    k: torch.Tensor,  # [batch, heads, seq, dim]
    v: torch.Tensor,  # [batch, heads, seq, dim]
) -> torch.Tensor:
    """
    Baseline PyTorch implementation of FP8 attention using FP8 e5m2.
    """
    batch, heads, seq_len, head_dim = q.shape

    q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v)

    # Reshape Q and K to [batch*heads, seq, dim]
    q_fp8 = q_fp8.reshape(batch * heads, seq_len, head_dim)
    k_fp8 = k_fp8.reshape(batch * heads, seq_len, head_dim)

    # V is already transposed by preprocess_fp8_attention_inputs, reshape to [batch*heads, dim, seq]
    v_fp8 = v_fp8.reshape(batch * heads, head_dim, seq_len)

    sm_scale = 1.0 / math.sqrt(float(head_dim))

    outputs = []

    for i in range(batch * heads):
        q_i = q_fp8[i]  # [seq, dim] - already FP8
        k_i = k_fp8[i]  # [seq, dim] - already FP8
        v_i = v_fp8[i]  # [dim, seq] - pre-transposed, already FP8

        # For Q @ K^T, we need K^T to be column-major
        kt_fp8 = k_i.t()  # column-major [dim, seq]

        # Q @ K^T - dequantize and use regular matmul since e5m2 is not supported by _scaled_mm
        q_deq = q_i.to(torch.float32)
        kt_deq = kt_fp8.to(torch.float32)
        qk = torch.matmul(q_deq, kt_deq)

        # Compute max before scaling
        qk_max = torch.amax(qk, dim=-1, keepdim=True)

        # Scale and shift in one operation, then use exp2
        qk_scaled_shifted = qk * sm_scale - qk_max * sm_scale
        p = torch.exp2(qk_scaled_shifted * 1.44269504)

        # Normalize
        p_norm = p / p.sum(dim=-1, keepdim=True)

        # Step 2: Attention @ V using FP8
        # P is [seq, seq], V is [dim, seq]
        # We want P @ V^T = [seq, seq] @ [seq, dim] = [seq, dim]
        p_fp8 = p_norm.to(torch.float8_e5m2)  # row-major [seq, seq]

        # v_i is [dim, seq], already FP8
        vt_fp8 = v_i.t()  # column-major [seq, dim]

        # P @ V^T - dequantize and use regular matmul since e5m2 is not supported by torch._scaled_mm
        p_deq = p_fp8.to(torch.float32)
        vt_deq = vt_fp8.to(torch.float32)
        out_i = torch.matmul(p_deq, vt_deq)

        outputs.append(out_i)

    # Stack and reshape back
    out_stacked = torch.stack(outputs, dim=0)  # [batch*heads, seq, dim]
    out = out_stacked.reshape(batch, heads, seq_len, head_dim)

    return out.to(torch.float16)


def check(batch: int, heads: int, seq_len: int, head_dim: int) -> None:
    torch.manual_seed(42)
    q = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
    v = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")

    from helion._testing import run_example

    helion_fn = fp8_attention_tritonbench(q, k, v)
    run_example(
        lambda q, k, v: helion_fn().to(torch.float16),
        fp8_attention_pytorch,
        (q, k, v),
        atol=0.1,
        rtol=0.1,
    )


def main() -> None:
    check(1, 2, 128, 64)
    check(2, 4, 256, 64)
    check(4, 8, 512, 128)


if __name__ == "__main__":
    main()
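
Note (not part of the diff): the inner tile_n loop above is the standard exp2-based online-softmax recurrence, just with FP8 inputs to the two matmuls. A minimal dense sketch of that recurrence, using plain PyTorch tensors and a made-up per-column value vector as a stand-in for a V column, is:

# Minimal sketch (illustrative, not from the commit): check that the blockwise
# exp2 recurrence reproduces a dense softmax-weighted sum.
import torch

torch.manual_seed(0)
scores = torch.randn(4, 8, 16)          # [rows, key blocks, keys per block]
values = torch.arange(16.0)             # hypothetical stand-in for a V column, shared by every block
LOG2E = 1.44269504                      # log2(e); folded into sm_scale in the kernel

m = torch.full((4,), float("-inf"))     # running max (in the exp2 domain)
l = torch.zeros(4)                      # running sum of exponentials
acc = torch.zeros(4)                    # running weighted sum

for blk in range(scores.size(1)):
    s = scores[:, blk, :] * LOG2E       # pre-scaled scores, as in the kernel
    m_new = torch.maximum(m, s.amax(dim=-1))
    p = torch.exp2(s - m_new[:, None])
    alpha = torch.exp2(m - m_new)       # correction for previously accumulated blocks
    l = l * alpha + p.sum(dim=-1)
    acc = acc * alpha + (p * values).sum(dim=-1)
    m = m_new

dense = torch.softmax(scores.reshape(4, -1), dim=-1) * values.repeat(8)
print(torch.allclose(acc / l, dense.sum(dim=-1), atol=1e-4))  # expected: True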

test/test_examples.expected

Lines changed: 74 additions & 0 deletions
@@ -574,6 +574,80 @@ def embedding(x: torch.Tensor, weight: torch.Tensor, *, _launcher=_default_launc
    _launcher(_embedding_kernel, (x_flat.size(0) * triton.cdiv(embedding_dim, _BLOCK_SIZE_1),), x_flat, weight, out, x_flat.size(0), out.stride(0), out.stride(1), weight.stride(0), weight.stride(1), x_flat.stride(0), embedding_dim, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
    return out.view(*x.size(), embedding_dim)

--- assertExpectedJournal(TestExamples.test_fp8_attention)
from __future__ import annotations

import math
import torch
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers
from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _fp8_attention_kernel_kernel(q, k, v, out, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0
    indices_5 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
    for offset_4 in tl.range(0, 256, _BLOCK_SIZE_1):
        indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
        m_i = tl.full([_BLOCK_SIZE_1], float('-inf'), tl.float32)
        l_i = tl.full([_BLOCK_SIZE_1], 0.0, tl.float32)
        acc = tl.full([_BLOCK_SIZE_1, 64], 0.0, tl.float32)
        q_tile = tl.load(q + (offset_0 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), None)
        for offset_2 in tl.range(0, 256, _BLOCK_SIZE_3):
            indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
            q_tile_copy = q_tile
            m_i_copy = m_i
            l_i_copy = l_i
            acc_copy = acc
            q_tile_copy_0 = q_tile_copy
            m_i_copy_0 = m_i_copy
            l_i_copy_0 = l_i_copy
            acc_copy_0 = acc_copy
            k_tile = tl.load(k + (offset_0 * 16384 + indices_2[:, None] * 64 + indices_5[None, :] * 1), None)
            k_tile_t = tl.permute(k_tile, [1, 0])
            mm = tl.dot(q_tile_copy_0, k_tile_t)
            v_0 = mm.to(tl.float32)
            v_1 = 0.18033688
            v_2 = v_0 * v_1
            qk_max = tl.max(v_2, 1)
            v_3 = triton_helpers.maximum(m_i_copy_0, qk_max)
            subscript = v_3[:, None]
            v_4 = v_2 - subscript
            v_5 = libdevice.exp2(v_4)
            l_ij = tl.sum(v_5, 1)
            v_6 = m_i_copy_0 - v_3
            v_7 = libdevice.exp2(v_6)
            v_8 = l_i_copy_0 * v_7
            l_i = v_8 + l_ij
            subscript_1 = v_7[:, None]
            v_10 = acc_copy_0 * subscript_1
            v_tile = tl.load(v + (offset_0 * 16384 + indices_5[:, None] * 1 + indices_2[None, :] * 64), None)
            v_11 = v_5.to(tl.float8e5)
            v_t = tl.permute(v_tile, [1, 0])
            mm_1 = tl.dot(v_11, v_t)
            v_12 = mm_1.to(tl.float32)
            acc = v_10 + v_12
            m_i = v_3
        subscript_2 = l_i[:, None]
        v_14 = acc / subscript_2
        tl.store(out + (offset_0 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), v_14, None)

def fp8_attention_kernel(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *, _launcher=_default_launcher):
    batch_heads = q.size(0)
    seq_len = q.size(1)
    head_dim = q.size(2)
    out = torch.empty([batch_heads, seq_len, head_dim], dtype=torch.float32, device=q.device)
    sm_scale = 1.0 / math.sqrt(float(head_dim))
    sm_scale = sm_scale * 1.44269504
    _RDIM_SIZE_2 = 64
    _BLOCK_SIZE_1 = 64
    _BLOCK_SIZE_3 = 64
    _launcher(_fp8_attention_kernel_kernel, (8,), q, k, v, out, _RDIM_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
    return out

--- assertExpectedJournal(TestExamples.test_fp8_gemm)
from __future__ import annotations
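
Side note on the generated code above (illustrative, not part of the diff): with the shapes used by the new unit test (batch=2, heads=4, seq_len=256, head_dim=64), the constants in the journal are just the kernel's parameters folded at compile time, e.g. v_1 = 0.18033688 is (1/sqrt(64)) * 1.44269504:

# Quick check of the constants baked into the generated Triton kernel (illustrative only).
import math

head_dim, seq_len, batch, heads = 64, 256, 2, 4              # shapes from the unit test below
print((1.0 / math.sqrt(head_dim)) * 1.44269504)              # 0.18033688... -> v_1
print(batch * heads)                                         # 8 -> launch grid size (8,)
print(seq_len * head_dim)                                    # 16384 -> per-(batch*head) stride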

test/test_examples.py

Lines changed: 41 additions & 0 deletions
@@ -557,6 +557,47 @@ def test_attention_persistent_interleaved_l2_grouping(self):
            )
        )

    @unittest.skipIf(
        not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9,
        "FP8 requires GPU with compute capability >= 9.0 (e.g., H100)",
    )
    def test_fp8_attention(self):
        batch = 2
        heads = 4
        seq_len = 256
        head_dim = 64

        # Create FP16 tensors
        q = torch.randn(
            batch, heads, seq_len, head_dim, dtype=torch.float16, device=DEVICE
        )
        k = torch.randn(
            batch, heads, seq_len, head_dim, dtype=torch.float16, device=DEVICE
        )
        v = torch.randn(
            batch, heads, seq_len, head_dim, dtype=torch.float16, device=DEVICE
        )

        # Import the module
        mod = import_path(EXAMPLES_DIR / "fp8_attention.py")

        # Prepare FP8 inputs using the module's preprocessing function,
        # then reshape to the [batch*heads, ...] layout the kernel expects
        q_fp8, k_fp8, v_fp8 = mod.preprocess_fp8_attention_inputs(q, k, v)
        q_fp8 = q_fp8.reshape(batch * heads, seq_len, head_dim)
        k_fp8 = k_fp8.reshape(batch * heads, seq_len, head_dim)
        v_fp8 = v_fp8.reshape(batch * heads, head_dim, seq_len)
        args = (q_fp8, k_fp8, v_fp8)

        # Get expected output from kernel
        expected = mod.fp8_attention_kernel(*args)

        self.assertExpectedJournal(
            check_example(
                "fp8_attention",
                args,
                expected,
                fn_name="fp8_attention_kernel",
                block_sizes=[64, 64],
            )
        )


if __name__ == "__main__":
    unittest.main()
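
For reference (not part of the diff), the pieces added by this commit can also be exercised directly; a rough usage sketch, assuming an H100-class GPU and that examples/fp8_attention.py is importable from the working directory:

# Rough usage sketch (illustrative only; import path is an assumption).
import torch
from fp8_attention import fp8_attention_pytorch, fp8_attention_tritonbench

q = torch.randn(2, 4, 256, 64, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

helion_fn = fp8_attention_tritonbench(q, k, v)   # FP8 casts + reshapes happen here
out = helion_fn().to(torch.float16)              # [2, 4, 256, 64]
ref = fp8_attention_pytorch(q, k, v)
print((out - ref).abs().max())                   # expected to be small (FP8-level error)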
