diff --git a/tritonbench/kernels/triton_fused_attention.py b/tritonbench/kernels/triton_fused_attention.py
index d49f6856..16d069d5 100644
--- a/tritonbench/kernels/triton_fused_attention.py
+++ b/tritonbench/kernels/triton_fused_attention.py
@@ -1952,6 +1952,7 @@ def _attn_fwd_tma_ws_persistent(  # Q, V, desc_k, desc_v, sm_scale, M, Out,  #
 ]


+# on-device TMA
 @triton.autotune(list(filter(keep, configsCutlassBlackwell)), key=["N_CTX"])
 @triton.jit
 def _attn_fwd_tma_ws_persistent_with_dp(  # Q, V, desc_k, desc_v, sm_scale, M, Out,  #
@@ -2110,6 +2111,146 @@ def _attn_fwd_tma_ws_persistent_with_dp(  # Q, V, desc_k, desc_v, sm_scale, M, O
         tile_idx += num_progs


+@triton.jit
+def _attn_fwd_subtile(q, k, offs_m, start_n, offs_n, qk_scale, l_i, m_i, acc, v, dtype: tl.constexpr, STAGE: tl.constexpr):
+    qk = tl.dot(q, k)
+    if STAGE == 2:
+        mask = offs_m[:, None] >= (start_n + offs_n[None, :])
+        qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
+        m_ij = tl.maximum(m_i, tl.max(qk, 1))
+        qk -= m_ij[:, None]
+    else:
+        m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
+        qk = qk * qk_scale - m_ij[:, None]
+    p = tl.math.exp2(qk)
+    # -- compute correction factor
+    alpha = tl.math.exp2(m_i - m_ij)
+    l_ij = tl.sum(p, 1)
+
+    # -- update output accumulator --
+    BM: tl.constexpr = acc.shape[0]
+    BN: tl.constexpr = acc.shape[1]
+
+    acc0, acc1 = acc.reshape([BM, 2, BN//2]).permute(0, 2, 1).split()
+    acc0 = acc0 * alpha[:, None]
+    acc1 = acc1 * alpha[:, None]
+    acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])
+
+    # prepare p and v for the dot
+    p = p.to(dtype)
+    # note that this non transposed v for FP8 is only supported on Blackwell
+    acc = tl.dot(p, v, acc)
+    # update m_i and l_i
+    # place this at the end of the loop to reduce register pressure
+    l_i = l_i * alpha + l_ij
+    m_i = m_ij
+
+    return l_i, m_i, acc
+
+
+@triton.jit
+def _attn_fwd_inner_oss_dp(acc0, acc1, l_i0, l_i1, m_i0, m_i1, q0, q1,  #
+                           desc_k, desc_v,  #
+                           offset_y, dtype: tl.constexpr, start_m, qk_scale,  #
+                           BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  #
+                           STAGE: tl.constexpr, offs_m0: tl.constexpr, offs_m1: tl.constexpr,  #
+                           offs_n: tl.constexpr,  #
+                           N_CTX: tl.constexpr, warp_specialize: tl.constexpr):
+    # range of values handled by this stage
+    if STAGE == 1:
+        lo, hi = 0, start_m * BLOCK_M
+    elif STAGE == 2:
+        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
+        lo = tl.multiple_of(lo, BLOCK_M)
+    # causal = False
+    else:
+        lo, hi = 0, N_CTX
+    offsetkv_y = offset_y + lo
+
+    # loop over k, v and update accumulator
+    for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize, disallow_acc_multi_buffer=True):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+
+        k = desc_k.load([offsetkv_y, 0]).T
+        v = desc_v.load([offsetkv_y, 0])
+
+        l_i0, m_i0, acc0 = _attn_fwd_subtile(q0, k, offs_m0, start_n, offs_n, qk_scale, l_i0, m_i0, acc0, v, dtype, STAGE)
+        l_i1, m_i1, acc1 = _attn_fwd_subtile(q1, k, offs_m1, start_n, offs_n, qk_scale, l_i1, m_i1, acc1, v, dtype, STAGE)
+
+        offsetkv_y += BLOCK_N
+
+    return acc0, acc1, l_i0, l_i1, m_i0, m_i1
+
+
+#@triton.autotune(configs=list(filter(keep_tma, configs_tma_dp)),
+#                 key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"])
+@triton.jit
+def _attn_fwd_tma_oss_dp(sm_scale, M,  #
+                         Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX,  #
+                         HEAD_DIM: tl.constexpr,  #
+                         BLOCK_M: tl.constexpr,  #
+                         BLOCK_N: tl.constexpr,  #
+                         FP8_OUTPUT: tl.constexpr,  #
+                         STAGE: tl.constexpr,  #
+                         warp_specialize: tl.constexpr,  #
+                         ENABLE_TMA: tl.constexpr,
+                         ):
+    dtype = tl.float8e5 if FP8_OUTPUT else tl.bfloat16
+    tl.static_assert(BLOCK_N <= HEAD_DIM)
+    start_m = tl.program_id(0)
+    off_hz = tl.program_id(1)
+    off_z = off_hz // H
+    off_h = off_hz % H
+
+    offset_y = off_z + off_h * N_CTX
+    qo_offset_y = offset_y + start_m * BLOCK_M
+    # initialize offsets
+    offs_m0 = start_m * BLOCK_M + tl.arange(0, BLOCK_M//2)
+    offs_m1 = start_m * BLOCK_M + tl.arange(BLOCK_M//2, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N)
+
+    m_i0 = tl.zeros([BLOCK_M//2], dtype=tl.float32) - float("inf")
+    l_i0 = tl.zeros([BLOCK_M//2], dtype=tl.float32) + 1.0
+    acc0 = tl.zeros([BLOCK_M//2, HEAD_DIM], dtype=tl.float32)
+
+    m_i1 = tl.zeros([BLOCK_M//2], dtype=tl.float32) - float("inf")
+    l_i1 = tl.zeros([BLOCK_M//2], dtype=tl.float32) + 1.0
+    acc1 = tl.zeros([BLOCK_M//2, HEAD_DIM], dtype=tl.float32)
+
+    qk_scale = sm_scale
+    qk_scale *= 1.44269504  # 1/log(2)
+
+    q0 = desc_q.load([qo_offset_y, 0])
+    q1 = desc_q.load([qo_offset_y + BLOCK_M//2, 0])
+
+    if STAGE & 1:
+        acc0, acc1, l_i0, l_i1, m_i0, m_i1 = _attn_fwd_inner_oss_dp(acc0, acc1, l_i0, l_i1, m_i0, m_i1, q0, q1,  #
+                                                                    desc_k, desc_v,  #
+                                                                    offset_y, dtype, start_m, qk_scale,  #
+                                                                    BLOCK_M, HEAD_DIM, BLOCK_N,  #
+                                                                    4 - STAGE, offs_m0, offs_m1, offs_n, N_CTX,  #
+                                                                    warp_specialize)
+    if STAGE & 2:
+        acc0, acc1, l_i0, l_i1, m_i0, m_i1 = _attn_fwd_inner_oss_dp(acc0, acc1, l_i0, l_i1, m_i0, m_i1, q0, q1,  #
+                                                                    desc_k, desc_v,  #
+                                                                    offset_y, dtype, start_m, qk_scale,  #
+                                                                    BLOCK_M, HEAD_DIM, BLOCK_N,  #
+                                                                    2, offs_m0, offs_m1, offs_n, N_CTX,  #
+                                                                    warp_specialize)
+
+    m_i0 += tl.math.log2(l_i0)
+    acc0 = acc0 / l_i0[:, None]
+    m_ptrs0 = M + off_hz * N_CTX + offs_m0
+    tl.store(m_ptrs0, m_i0)
+    desc_o.store([qo_offset_y, 0], acc0.to(dtype))
+
+    m_i1 += tl.math.log2(l_i1)
+    acc1 = acc1 / l_i1[:, None]
+    m_ptrs1 = M + off_hz * N_CTX + offs_m1
+    tl.store(m_ptrs1, m_i1)
+    desc_o.store([qo_offset_y+BLOCK_M//2, 0], acc1.to(dtype))
+
+
 @triton.jit
 def _attn_bwd_preprocess(
     O,
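Note, not part of the patch: _attn_fwd_subtile above performs one online-softmax update per (BLOCK_M//2, BLOCK_N) subtile in the exp2 domain. Below is a minimal PyTorch reference of the non-causal update, assuming qk_scale already carries the log2(e) factor; subtile_step_reference is a hypothetical name used only for illustration.

import torch

def subtile_step_reference(q, k, v, m_i, l_i, acc, qk_scale):
    qk = (q @ k.transpose(-1, -2)) * qk_scale          # [BM/2, BN] scores in the exp2 domain
    m_ij = torch.maximum(m_i, qk.max(dim=-1).values)   # updated running row maximum
    p = torch.exp2(qk - m_ij[:, None])                 # unnormalized probabilities
    alpha = torch.exp2(m_i - m_ij)                     # correction factor for the old accumulator
    acc = acc * alpha[:, None] + p @ v                 # rescale old contribution, add new one
    l_i = l_i * alpha + p.sum(dim=-1)                  # running softmax denominator
    return m_ij, l_i, acc

The kernel applies the same alpha rescale on two accumulator halves (acc0/acc1 via reshape/permute/split) instead of a single acc * alpha; the arithmetic is identical to the reference above.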
@@ -2472,6 +2613,7 @@ def forward(ctx, q, k, v, causal, sm_scale, baseVariant):

         # no autotune with fixed BLOCK_N
         if HAS_TMA_DESC is True and torch.version.hip is None:
+            # Legacy on-host grid constant TMA
             desc_helper = TmaAutoTuneHelper()
             desc_helper.init_tma_descriptor("k")
             desc_helper.init_tma_descriptor("v")
@@ -2628,6 +2770,17 @@ def grid_tma_persistent(META):
             desc_v = desc_helper.get_tma_descriptor_kernel_param("v")
             desc_o = desc_helper.get_tma_descriptor_kernel_param("o")

+        # For variants using new on-host TMA
+        if baseVariant == "on_host_tma_ws_oss":
+            from triton.tools.tensor_descriptor import TensorDescriptor
+            y_dim = q.shape[0] * q.shape[1] * q.shape[2]
+            BLOCK_M = 256
+            BLOCK_N = 128
+            desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[BLOCK_M//2, HEAD_DIM_K])
+            desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[BLOCK_N, HEAD_DIM_K])
+            desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[BLOCK_N, HEAD_DIM_K])
+            desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[BLOCK_M//2, HEAD_DIM_K])
+
         M = torch.empty(
             (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
         )
@@ -2810,6 +2963,24 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
                 ENABLE_WS=True,
                 **extra_kern_args,
             )
+        elif baseVariant == "on_host_tma_ws_oss":
+            BLOCK_M = 256
+            BLOCK_N = 128
+            _attn_fwd_tma_oss_dp[grid_tma](
+                sm_scale, M,  #
+                q.shape[0], q.shape[1],  #
+                desc_q, desc_k, desc_v, desc_o,  #
+                N_CTX=q.shape[2],  #
+                HEAD_DIM=HEAD_DIM_K,  #
+                FP8_OUTPUT=q.dtype == torch.float8_e5m2,  #
+                STAGE=stage,  #
+                warp_specialize=True,  #
+                ENABLE_TMA=True,
+                BLOCK_N=BLOCK_N, BLOCK_M=BLOCK_M,  #
+                num_warps=4,
+                num_stages=2,
+                #maxnreg=64,
+                **extra_kern_args)

         ctx.save_for_backward(q, k, v, o, M)
         ctx.grid = grid_tma
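Note, not part of the patch: the on_host_tma_ws_oss branch above describes q, k, v and o to TMA as a flattened 2D [Z*H*N_CTX, HEAD_DIM_K] view, with half-height blocks for q and o because the kernel loads q0/q1 and stores acc0/acc1 as separate half tiles. A standalone sketch of that setup follows; make_oss_descriptors is a hypothetical helper and its arguments simply mirror the patch.

import torch
from triton.tools.tensor_descriptor import TensorDescriptor

def make_oss_descriptors(q, k, v, o, HEAD_DIM_K, BLOCK_M=256, BLOCK_N=128):
    # q, k, v, o are [Z, H, N_CTX, HEAD_DIM_K]; TMA addresses them as a 2D
    # [Z*H*N_CTX, HEAD_DIM_K] view with row stride HEAD_DIM_K.
    y_dim = q.shape[0] * q.shape[1] * q.shape[2]
    shape, strides = [y_dim, HEAD_DIM_K], [HEAD_DIM_K, 1]
    # BLOCK_M//2 rows for q/o: the kernel loads q0/q1 and stores acc0/acc1 per half tile.
    desc_q = TensorDescriptor(q, shape=shape, strides=strides, block_shape=[BLOCK_M // 2, HEAD_DIM_K])
    desc_k = TensorDescriptor(k, shape=shape, strides=strides, block_shape=[BLOCK_N, HEAD_DIM_K])
    desc_v = TensorDescriptor(v, shape=shape, strides=strides, block_shape=[BLOCK_N, HEAD_DIM_K])
    desc_o = TensorDescriptor(o, shape=shape, strides=strides, block_shape=[BLOCK_M // 2, HEAD_DIM_K])
    return desc_q, desc_k, desc_v, desc_o

The q/o block height of BLOCK_M//2 must stay in sync with the desc_q.load and desc_o.store calls in _attn_fwd_tma_oss_dp, which step by BLOCK_M//2 rows between the two half tiles.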
diff --git a/tritonbench/operators/blackwell_attentions/operator.py b/tritonbench/operators/blackwell_attentions/operator.py
index ed3710f2..4d123dff 100644
--- a/tritonbench/operators/blackwell_attentions/operator.py
+++ b/tritonbench/operators/blackwell_attentions/operator.py
@@ -282,6 +282,7 @@ def causal_mask(b, h, q_idx, kv_idx):

         return lambda: flex_attention(q, k, v, block_mask=block_mask)

+    # use Meta's warpspec + on device TMA + persistent
     @register_benchmark(enabled=False)
     def triton_tutorial_flash_v2_tma_ws_persistent_blackwell(
         self,
@@ -293,6 +294,29 @@ def triton_tutorial_flash_v2_tma_ws_persistent_blackwell(
             q, k, v, self.causal, self.sm_scale, "tma_ws_persistent_blackwell"
         )

+    @register_benchmark(enabled=False)
+    def triton_tutorial_flash_v2_blackwell(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+    ) -> Callable:
+        return lambda: triton_tutorial_FA2_opt(
+            q, k, v, self.causal, self.sm_scale, "base_opt"
+        )
+
+    # use OSS warpspec + on host TMA
+    @register_benchmark(enabled=False)
+    def triton_tutorial_flash_v2_on_host_tma_ws_oss_blackwell(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+    ) -> Callable:
+        return lambda: triton_tutorial_FA2_opt(
+            q, k, v, self.causal, self.sm_scale, "on_host_tma_ws_oss"
+        )
+
     # Only works with triton main, forward only.
     @register_benchmark(enabled=False)
     def gluon_blackwell_tutorial_fwd(
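Note, not part of the patch: the two registrations above route through triton_tutorial_FA2_opt with the baseVariant strings handled in forward(). A rough sketch of exercising the on-host TMA variant outside the benchmark harness; the import alias is assumed to match what operator.py already uses, and the shapes/dtype are illustrative only.

import torch
# Assumed import alias, mirroring how the operator imports the kernel wrapper.
from tritonbench.kernels.triton_fused_attention import attention_opt as triton_tutorial_FA2_opt

q = torch.randn(4, 8, 4096, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
sm_scale = 1.0 / (128 ** 0.5)

# causal=False; "on_host_tma_ws_oss" selects the new branch added in _attention.forward.
out = triton_tutorial_FA2_opt(q, k, v, False, sm_scale, "on_host_tma_ws_oss")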