Add triton_tutorial_flash_v2_on_host_tma_ws_oss_blackwell #279

Open · wants to merge 1 commit into base: main
171 changes: 171 additions & 0 deletions tritonbench/kernels/triton_fused_attention.py
@@ -1952,6 +1952,7 @@ def _attn_fwd_tma_ws_persistent( # Q, V, desc_k, desc_v, sm_scale, M, Out, #
]


# on-device TMA
@triton.autotune(list(filter(keep, configsCutlassBlackwell)), key=["N_CTX"])
@triton.jit
def _attn_fwd_tma_ws_persistent_with_dp( # Q, V, desc_k, desc_v, sm_scale, M, Out, #
@@ -2110,6 +2111,146 @@ def _attn_fwd_tma_ws_persistent_with_dp( # Q, V, desc_k, desc_v, sm_scale, M, O
tile_idx += num_progs


@triton.jit
def _attn_fwd_subtile(q, k, offs_m, start_n, offs_n, qk_scale, l_i, m_i, acc, v, dtype: tl.constexpr, STAGE: tl.constexpr):
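# One online-softmax step for one half of the query block: compute the score
# tile, update the running max/sum, rescale the accumulator, and add p @ v.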
qk = tl.dot(q, k)
if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
else:
m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
qk = qk * qk_scale - m_ij[:, None]
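# exp2 is used instead of exp because the caller folds log2(e) into qk_scale.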
p = tl.math.exp2(qk)
# -- compute correction factor
alpha = tl.math.exp2(m_i - m_ij)
l_ij = tl.sum(p, 1)

# -- update output accumulator --
BM: tl.constexpr = acc.shape[0]
BN: tl.constexpr = acc.shape[1]

acc0, acc1 = acc.reshape([BM, 2, BN//2]).permute(0, 2, 1).split()
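# acc0/acc1 are the two column-wise [BM, BN // 2] halves of the accumulator;
# rescaling them separately and re-joining below keeps the correction on
# smaller subtiles (presumably the same subtiling trick as the upstream
# Blackwell tutorial).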
acc0 = acc0 * alpha[:, None]
acc1 = acc1 * alpha[:, None]
acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])

# prepare p and v for the dot
p = p.to(dtype)
# note that non-transposed V for FP8 is only supported on Blackwell
acc = tl.dot(p, v, acc)
# update m_i and l_i
# place this at the end of the loop to reduce register pressure
l_i = l_i * alpha + l_ij
m_i = m_ij

return l_i, m_i, acc


@triton.jit
def _attn_fwd_inner_oss_dp(acc0, acc1, l_i0, l_i1, m_i0, m_i1, q0, q1, #
desc_k, desc_v, #
offset_y, dtype: tl.constexpr, start_m, qk_scale, #
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, #
STAGE: tl.constexpr, offs_m0: tl.constexpr, offs_m1: tl.constexpr, #
offs_n: tl.constexpr, #
N_CTX: tl.constexpr, warp_specialize: tl.constexpr):
# range of values handled by this stage
if STAGE == 1:
lo, hi = 0, start_m * BLOCK_M
elif STAGE == 2:
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
lo = tl.multiple_of(lo, BLOCK_M)
# causal = False
else:
lo, hi = 0, N_CTX
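# STAGE selects the K/V range for this pass: 1 = key/value blocks entirely
# before the query block (no masking needed), 2 = the diagonal block (masked
# in _attn_fwd_subtile), otherwise the full sequence for the non-causal case.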
offsetkv_y = offset_y + lo

# loop over k, v and update accumulator
for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize, disallow_acc_multi_buffer=True):
start_n = tl.multiple_of(start_n, BLOCK_N)

k = desc_k.load([offsetkv_y, 0]).T
v = desc_v.load([offsetkv_y, 0])
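# K and V are fetched through the on-host TMA descriptors; K is transposed so
# that tl.dot(q, k) yields a [BLOCK_M // 2, BLOCK_N] tile of scores.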

l_i0, m_i0, acc0 = _attn_fwd_subtile(q0, k, offs_m0, start_n, offs_n, qk_scale, l_i0, m_i0, acc0, v, dtype, STAGE)
l_i1, m_i1, acc1 = _attn_fwd_subtile(q1, k, offs_m1, start_n, offs_n, qk_scale, l_i1, m_i1, acc1, v, dtype, STAGE)

offsetkv_y += BLOCK_N

return acc0, acc1, l_i0, l_i1, m_i0, m_i1


#@triton.autotune(configs=list(filter(keep_tma, configs_tma_dp)),
# key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"])
@triton.jit
def _attn_fwd_tma_oss_dp(sm_scale, M, #
Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX, #
HEAD_DIM: tl.constexpr, #
BLOCK_M: tl.constexpr, #
BLOCK_N: tl.constexpr, #
FP8_OUTPUT: tl.constexpr, #
STAGE: tl.constexpr, #
warp_specialize: tl.constexpr, #
ENABLE_TMA: tl.constexpr,
):
dtype = tl.float8e5 if FP8_OUTPUT else tl.bfloat16
tl.static_assert(BLOCK_N <= HEAD_DIM)
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H

offset_y = off_z + off_h * N_CTX
qo_offset_y = offset_y + start_m * BLOCK_M
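# Data-partitioned layout: the BLOCK_M query rows are split into two halves,
# each with its own running max (m_i), normalizer (l_i), and accumulator.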
# initialize offsets
offs_m0 = start_m * BLOCK_M + tl.arange(0, BLOCK_M//2)
offs_m1 = start_m * BLOCK_M + tl.arange(BLOCK_M//2, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)

m_i0 = tl.zeros([BLOCK_M//2], dtype=tl.float32) - float("inf")
l_i0 = tl.zeros([BLOCK_M//2], dtype=tl.float32) + 1.0
acc0 = tl.zeros([BLOCK_M//2, HEAD_DIM], dtype=tl.float32)

m_i1 = tl.zeros([BLOCK_M//2], dtype=tl.float32) - float("inf")
l_i1 = tl.zeros([BLOCK_M//2], dtype=tl.float32) + 1.0
acc1 = tl.zeros([BLOCK_M//2, HEAD_DIM], dtype=tl.float32)

qk_scale = sm_scale
qk_scale *= 1.44269504 # log2(e) = 1/ln(2); lets the inner loop use exp2

q0 = desc_q.load([qo_offset_y, 0])
q1 = desc_q.load([qo_offset_y + BLOCK_M//2, 0])

if STAGE & 1:
acc0, acc1, l_i0, l_i1, m_i0, m_i1 = _attn_fwd_inner_oss_dp(acc0, acc1, l_i0, l_i1, m_i0, m_i1, q0, q1, #
desc_k, desc_v, #
offset_y, dtype, start_m, qk_scale, #
BLOCK_M, HEAD_DIM, BLOCK_N, #
4 - STAGE, offs_m0, offs_m1, offs_n, N_CTX, #
warp_specialize)
if STAGE & 2:
acc0, acc1, l_i0, l_i1, m_i0, m_i1 = _attn_fwd_inner_oss_dp(acc0, acc1, l_i0, l_i1, m_i0, m_i1, q0, q1, #
desc_k, desc_v, #
offset_y, dtype, start_m, qk_scale, #
BLOCK_M, HEAD_DIM, BLOCK_N, #
2, offs_m0, offs_m1, offs_n, N_CTX, #
warp_specialize)

m_i0 += tl.math.log2(l_i0)
acc0 = acc0 / l_i0[:, None]
m_ptrs0 = M + off_hz * N_CTX + offs_m0
tl.store(m_ptrs0, m_i0)
desc_o.store([qo_offset_y, 0], acc0.to(dtype))

m_i1 += tl.math.log2(l_i1)
acc1 = acc1 / l_i1[:, None]
m_ptrs1 = M + off_hz * N_CTX + offs_m1
tl.store(m_ptrs1, m_i1)
desc_o.store([qo_offset_y+BLOCK_M//2, 0], acc1.to(dtype))


@triton.jit
def _attn_bwd_preprocess(
O,
@@ -2472,6 +2613,7 @@ def forward(ctx, q, k, v, causal, sm_scale, baseVariant):

# no autotune with fixed BLOCK_N
if HAS_TMA_DESC is True and torch.version.hip is None:
# Legacy on-host grid-constant TMA
desc_helper = TmaAutoTuneHelper()
desc_helper.init_tma_descriptor("k")
desc_helper.init_tma_descriptor("v")
@@ -2628,6 +2770,17 @@ def grid_tma_persistent(META):
desc_v = desc_helper.get_tma_descriptor_kernel_param("v")
desc_o = desc_helper.get_tma_descriptor_kernel_param("o")

# For variants using the new on-host TMA (TensorDescriptor) API
if baseVariant == "on_host_tma_ws_oss":
from triton.tools.tensor_descriptor import TensorDescriptor
y_dim = q.shape[0] * q.shape[1] * q.shape[2]
BLOCK_M = 256
BLOCK_N = 128
desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[BLOCK_M//2, HEAD_DIM_K])
desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[BLOCK_N, HEAD_DIM_K])
desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[BLOCK_N, HEAD_DIM_K])
desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[BLOCK_M//2, HEAD_DIM_K])
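# The descriptors view q/k/v/o as 2-D [y_dim, HEAD_DIM_K] tensors with
# (Z, H, N_CTX) flattened into y_dim; their block shapes must stay in sync
# with the BLOCK_M/BLOCK_N passed to the kernel launch below.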

M = torch.empty(
(q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
)
@@ -2810,6 +2963,24 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
ENABLE_WS=True,
**extra_kern_args,
)
elif baseVariant == "on_host_tma_ws_oss":
BLOCK_M = 256
BLOCK_N = 128
_attn_fwd_tma_oss_dp[grid_tma](
sm_scale, M, #
q.shape[0], q.shape[1], #
desc_q, desc_k, desc_v, desc_o, #
N_CTX=q.shape[2], #
HEAD_DIM=HEAD_DIM_K, #
FP8_OUTPUT=q.dtype == torch.float8_e5m2, #
STAGE=stage, #
warp_specialize=True, #
ENABLE_TMA=True,
BLOCK_N=BLOCK_N, BLOCK_M=BLOCK_M, #
num_warps=4,
num_stages=2,
#maxnreg=64,
**extra_kern_args)

ctx.save_for_backward(q, k, v, o, M)
ctx.grid = grid_tma
24 changes: 24 additions & 0 deletions tritonbench/operators/blackwell_attentions/operator.py
@@ -282,6 +282,7 @@ def causal_mask(b, h, q_idx, kv_idx):

return lambda: flex_attention(q, k, v, block_mask=block_mask)

# uses Meta's warp specialization + on-device TMA + persistent scheduling
@register_benchmark(enabled=False)
def triton_tutorial_flash_v2_tma_ws_persistent_blackwell(
self,
@@ -293,6 +294,29 @@ def triton_tutorial_flash_v2_tma_ws_persistent_blackwell(
q, k, v, self.causal, self.sm_scale, "tma_ws_persistent_blackwell"
)

@register_benchmark(enabled=False)
def triton_tutorial_flash_v2_blackwell(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
) -> Callable:
return lambda: triton_tutorial_FA2_opt(
q, k, v, self.causal, self.sm_scale, "base_opt"
)

# uses OSS warp specialization + on-host TMA
@register_benchmark(enabled=False)
def triton_tutorial_flash_v2_on_host_tma_ws_oss_blackwell(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
) -> Callable:
return lambda: triton_tutorial_FA2_opt(
q, k, v, self.causal, self.sm_scale, "on_host_tma_ws_oss"
)

# Only works with triton main, forward only.
@register_benchmark(enabled=False)
def gluon_blackwell_tutorial_fwd(
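For reference, a minimal sketch of exercising the new variant directly. This is illustrative only: the import path and tensor shapes are assumptions, while the call signature and the "on_host_tma_ws_oss" variant string come from the diff above.

    import torch
    # Assumed import location of the helper wrapped by the benchmark above.
    from tritonbench.operators.blackwell_attentions.operator import triton_tutorial_FA2_opt

    # (Z, H, N_CTX, HEAD_DIM); N_CTX is a multiple of BLOCK_M = 256, HEAD_DIM = 128.
    q = torch.randn(4, 32, 4096, 128, device="cuda", dtype=torch.bfloat16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    out = triton_tutorial_FA2_opt(q, k, v, True, 128 ** -0.5, "on_host_tma_ws_oss")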