diff --git a/fbgemm_gpu/experimental/gen_ai/gen_ai/two_simplical_attn/two_simplical_attn_bwd.py b/fbgemm_gpu/experimental/gen_ai/gen_ai/two_simplical_attn/two_simplical_attn_bwd.py new file mode 100644 index 0000000000..0105bd1bef --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/gen_ai/two_simplical_attn/two_simplical_attn_bwd.py @@ -0,0 +1,652 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-unsafe + +import os + +import triton # @manual # @manual + +import triton.language as tl # @manual # @manual + +from triton import Config + +TWO_SIMPLICAL_AUTOTUNE = os.getenv("TWO_SIMPLICAL_AUTOTUNE", "0") == "1" + + +def get_configs(): + if TWO_SIMPLICAL_AUTOTUNE: + return [ + Config( + { + "BLOCK_SIZE_Q": BLOCK_SIZE_Q, + "BLOCK_SIZE_KV": BLOCK_SIZE_KV, + "num_stages": num_stages, + }, + num_warps=num_warps, + ) + for BLOCK_SIZE_Q in [32, 64, 128, 256] + for BLOCK_SIZE_KV in [32, 64, 128, 256] + for num_stages in [1, 2, 3, 4] + for num_warps in [4, 8] + ] + return [ + Config( + { + "BLOCK_SIZE_Q": BLOCK_SIZE_Q, + "BLOCK_SIZE_KV": BLOCK_SIZE_KV, + "num_stages": num_stages, + }, + num_warps=num_warps, + ) + for BLOCK_SIZE_Q in [32] + for BLOCK_SIZE_KV in [64] + for num_stages in [3] + for num_warps in [4] + ] + + +# kv1 bwd kernel with tiles [kv1, q, 1] +@triton.autotune( + configs=get_configs(), + key=[ + "HEAD_DIM", + "seq_len", + "w1", + "w2", + "is_flipped", + ], +) +@triton.jit +def two_simplical_attn_bwd_kv1_kernel( + Q_ptr, # [b, s, k, h] + K1_ptr, # [b, s, k, h] + K2_ptr, # [b, s, k, h] + V1_ptr, # [b, s, k, h] + V2_ptr, # [b, s, k, h] + dO_ptr, # [b, s, k, h] + M_ptr, # [b, k, s] + D_ptr, # [b, k, s] + dQ_ptr, # [b, s, k, h] + dK1_ptr, # [b, s, k, h] + dV1_ptr, # [b, s, k, h] + bs, + seq_len, + num_heads, + w1, # Q[i]: KV1(i-w1,i] + w2, # Q[i]: KV2(i-w2,i] + q_stride_b, + q_stride_s, + q_stride_k, + q_stride_h, + k1_stride_b, + k1_stride_s, + k1_stride_k, + k1_stride_h, + k2_stride_b, + k2_stride_s, + k2_stride_k, + k2_stride_h, + v1_stride_b, + v1_stride_s, + v1_stride_k, + v1_stride_h, + v2_stride_b, + v2_stride_s, + v2_stride_k, + v2_stride_h, + dO_stride_b, + dO_stride_s, + dO_stride_k, + dO_stride_h, + m_stride_b, + m_stride_k, + m_stride_s, + d_stride_b, + d_stride_k, + d_stride_s, + dq_stride_b, + dq_stride_s, + dq_stride_k, + dq_stride_h, + dk1_stride_b, + dk1_stride_s, + dk1_stride_k, + dk1_stride_h, + dv1_stride_b, + dv1_stride_s, + dv1_stride_k, + dv1_stride_h, + BLOCK_SIZE_Q: tl.constexpr, + BLOCK_SIZE_KV: tl.constexpr, + HEAD_DIM: tl.constexpr, + SM_SCALE: tl.constexpr, +): + data_dtype = tl.bfloat16 + compute_dtype = tl.float32 + gemm_dtype = tl.bfloat16 + + kv1_start = tl.program_id(0) * BLOCK_SIZE_KV + kv1_end = kv1_start + BLOCK_SIZE_KV + bk = tl.program_id(1) + offs_b = bk // num_heads + offs_k = bk % num_heads + + qkv_offs_bk = offs_b * q_stride_b + offs_k * q_stride_k + Q_ptr += qkv_offs_bk + K1_ptr += qkv_offs_bk + K2_ptr += qkv_offs_bk + V1_ptr += qkv_offs_bk + V2_ptr += qkv_offs_bk + + dO_ptr += offs_b * dO_stride_b + offs_k * dO_stride_k + M_ptr += offs_b * m_stride_b + offs_k * m_stride_k + D_ptr += offs_b * d_stride_b + offs_k * d_stride_k + dK1_ptr += offs_b * dk1_stride_b + offs_k * dk1_stride_k + dV1_ptr += offs_b * dv1_stride_b + offs_k * dv1_stride_k + + softmax_scale = tl.cast(SM_SCALE, gemm_dtype) + qkv_offs_h = tl.arange(0, HEAD_DIM) + + kv1_offs_s = kv1_start + tl.arange(0, BLOCK_SIZE_KV) + + k1_offs = kv1_offs_s[:, None] * k1_stride_s + qkv_offs_h[None, :] * k1_stride_h + kv1_mask = kv1_offs_s[:, None] < seq_len + 
k1_tile = tl.load(K1_ptr + k1_offs, mask=kv1_mask).to( + gemm_dtype + ) # [BLOCK_SIZE_KV, HEAD_DIM] + v1_offs = kv1_offs_s[:, None] * v1_stride_s + qkv_offs_h[None, :] * v1_stride_h + v1_tile = tl.load(V1_ptr + v1_offs, mask=kv1_mask).to( + gemm_dtype + ) # [BLOCK_SIZE_KV, HEAD_DIM] + dv1 = tl.zeros((BLOCK_SIZE_KV, HEAD_DIM), compute_dtype) + dk1 = tl.zeros((BLOCK_SIZE_KV, HEAD_DIM), compute_dtype) + # for kv2_idx in tl.range(0, seq_len): + # kv1 - w2 < kv2 <= kv1 + w1 + for kv2_idx in tl.range( + tl.maximum(0, kv1_start - w2), tl.minimum(seq_len, kv1_end + w1) + ): + k2_offs = kv2_idx * k2_stride_s + qkv_offs_h * k2_stride_h + k2_tile = (tl.load(K2_ptr + k2_offs).to(gemm_dtype))[None, :] # [1, HEAD_DIM] + v2_offs = kv2_idx * v2_stride_s + qkv_offs_h * v2_stride_h + v2_tile = (tl.load(V2_ptr + v2_offs).to(gemm_dtype))[None, :] # [1, HEAD_DIM] + k1k2 = k1_tile * k2_tile # [BLOCK_SIZE_KV, HEAD_DIM] + v1v2 = v1_tile * v2_tile # [BLOCK_SIZE_KV, HEAD_DIM] + k1k2_scaled = k1k2 * softmax_scale + # kv1 <= q < kv1 + w1 + # kv2 <= q < kv2 + w2 + q_start = tl.maximum(kv1_start, kv2_idx) + q_end = tl.minimum(seq_len, tl.minimum(kv1_end + w1, kv2_idx + w2)) + # FIXME: Triton kernel compilation fails when pipelining is enabled in this specific case: P1828952934. + # So the pipelining is disabled for now. + for q_idx in tl.range(q_start, q_end, BLOCK_SIZE_Q, num_stages=1): + # Load qt, m, d, dO + q_offs_s = q_idx + tl.arange(0, BLOCK_SIZE_Q) + q_offs = q_offs_s[None, :] * q_stride_s + qkv_offs_h[:, None] * q_stride_h + q_mask_s = q_offs_s < seq_len + qt_tile = tl.load( + Q_ptr + q_offs, mask=q_mask_s[None, :] + ) # [HEAD_DIM, BLOCK_SIZE_Q] + m_offs = q_offs_s * m_stride_s + m_tile = tl.load(M_ptr + m_offs, mask=q_mask_s)[ + None, : + ] # [1, BLOCK_SIZE_Q] + d_offs = q_offs_s * d_stride_s + d_tile = tl.load(D_ptr + d_offs, mask=q_mask_s)[ + None, : + ] # [1, BLOCK_SIZE_Q] + dO_offs = ( + q_offs_s[:, None] * dO_stride_s + qkv_offs_h[None, :] * dO_stride_h + ) + dO_tile = tl.load( + dO_ptr + dO_offs, mask=q_mask_s[:, None] + ) # [BLOCK_SIZE_Q, HEAD_DIM] + + # Compute dv1. + # [KV, D] @ [D, Q] => [KV, Q] + qkkT = tl.dot( + k1k2_scaled, qt_tile, out_dtype=tl.float32 + ) # [BLOCK_SIZE_KV, BLOCK_SIZE_Q] + + # Mask qkkT to -inf. 
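+            # Valid (kv1, q) pairs must lie inside both sliding windows:
+            #   q - w1 < kv1 <= q  and  q - w2 < kv2_idx <= q.
+            # Tiles with no valid pair at all are skipped via the `valid` check below.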
+ kv1_local_mask = ((q_offs_s[None, :] - w1) < kv1_offs_s[:, None]) & ( + kv1_offs_s[:, None] <= q_offs_s[None, :] + ) + kv2_local_mask = ((q_offs_s - w2) < kv2_idx) & (kv2_idx <= q_offs_s) + local_mask = ( + kv1_local_mask & kv2_local_mask[None, :] + ) # [BLOCK_SIZE_KV, BLOCK_SIZE_Q] + + valid = tl.sum(local_mask) > 0 + if valid: + pT = tl.exp(qkkT - m_tile) # [BLOCK_SIZE_KV, BLOCK_SIZE_Q] + pT = tl.where(local_mask, pT, 0.0) + # dv1[kv1, d] = p[kv1, q] @ dO[q, d] * v2[kv2, d] + dOv2 = dO_tile * v2_tile # [BLOCK_SIZE_Q, HEAD_DIM] + dv1 += tl.dot( + pT.to(gemm_dtype), dOv2.to(gemm_dtype), out_dtype=tl.float32 + ) # [BLOCK_SIZE_KV, HEAD_DIM] + + # dpT[kv1, q] = v1v2[kv1, d] @ dO.T[d, q] + dpT = tl.dot( + v1v2, tl.trans(dO_tile.to(gemm_dtype)), out_dtype=tl.float32 + ) # [BLOCK_SIZE_KV, BLOCK_SIZE_Q] + dsT = tl.fma(pT, dpT, -pT * d_tile) # [BLOCK_SIZE_KV, BLOCK_SIZE_Q] + # dsT = tl.where(local_mask, dsT, 0.0) + dsT_scaled = dsT.to(gemm_dtype) * softmax_scale + # qk2[q, d] = qt.T[q, d] * k2[1, d] + + qt_tile_T = tl.trans(qt_tile) * k2_tile + + dk1 += tl.dot(dsT_scaled, qt_tile_T, out_dtype=tl.float32) + dv1_offs = kv1_offs_s[:, None] * dv1_stride_s + qkv_offs_h[None, :] * dv1_stride_h + dk1_offs = kv1_offs_s[:, None] * dk1_stride_s + qkv_offs_h[None, :] * dk1_stride_h + tl.store(dV1_ptr + dv1_offs, dv1.to(data_dtype), mask=kv1_mask) + tl.store(dK1_ptr + dk1_offs, dk1.to(data_dtype), mask=kv1_mask) + + +def two_simplical_attn_bwd_kv1_triton( + q, + k1, + k2, + v1, + v2, + dO, + m, + d, + w1, + w2, + dk1, + dv1, + dq, + sm_scale, +): + """Helper function to get bwd dk1 and dv1.""" + bs, seq_len, num_heads, head_dim = q.shape + sm_scale *= 1.0 + + # if w2 > w1: + # # NOTE: The total number of inner loop iterations is: + # # Total iterations = (w1 + BLOCK_SIZE_KV) * (w2 / BLOCK_SIZE_Q) + # # = (w1 * w2) / BLOCK_SIZE_Q + (BLOCK_SIZE_KV * w2) / BLOCK_SIZE_Q + # # When w1 != w2, assigning the smaller window size to w2 minimizes total iterations. + # # This is because w2 appears in both terms, while w1 only affects the first term. + # w1, w2 = w2, w1 + # k1, k2 = k2, k1 + # v1, v2 = v2, v1 + + grid = lambda args: (triton.cdiv(seq_len, args["BLOCK_SIZE_KV"]), bs * num_heads) # noqa: E731 + two_simplical_attn_bwd_kv1_kernel[grid]( + q, + k1, + k2, + v1, + v2, + dO, + m, + d, + dq, + dk1, + dv1, + bs, + seq_len, + num_heads, + w1, + w2, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k1.stride(0), + k1.stride(1), + k1.stride(2), + k1.stride(3), + k2.stride(0), + k2.stride(1), + k2.stride(2), + k2.stride(3), + v1.stride(0), + v1.stride(1), + v1.stride(2), + v1.stride(3), + v2.stride(0), + v2.stride(1), + v2.stride(2), + v2.stride(3), + dO.stride(0), + dO.stride(1), + dO.stride(2), + dO.stride(3), + m.stride(0), + m.stride(1), + m.stride(2), + d.stride(0), + d.stride(1), + d.stride(2), + dq.stride(0) if dq is not None else 0, + dq.stride(1) if dq is not None else 0, + dq.stride(2) if dq is not None else 0, + dq.stride(3) if dq is not None else 0, + dk1.stride(0), + dk1.stride(1), + dk1.stride(2), + dk1.stride(3), + dv1.stride(0), + dv1.stride(1), + dv1.stride(2), + dv1.stride(3), + HEAD_DIM=head_dim, + SM_SCALE=sm_scale, + ) + + +# "Single" pass kv2q kernel without atomics. +# Only works for tile_KV2 == 2 * tile_q == w2 +# Outer loop over q % 2 == 0. And go over KV2[q_start - w2, q_end] +# Second outer looper q % 2 == 1. Go over KV2[q_start - w2, q_end] and inplace add. 
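+# With BLOCK_SIZE_KV2 == BLOCK_SIZE_Q + w2, the dK2/dV2 rows touched by one program,
+# [q_start - w2, q_start + BLOCK_SIZE_Q), are disjoint from those of every other
+# program in the same pass (q_start values are BLOCK_SIZE_KV2 apart), so plain
+# stores are race-free. The second launch covers the remaining (odd) q tiles and
+# accumulates into dK2/dV2 with load-add-store; dQ is written exactly once across
+# the two passes, so it never needs accumulation.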
+@triton.autotune( + configs=[ + Config( + { + "BLOCK_SIZE_Q": 32, + "BLOCK_SIZE_KV2": 64, + "num_stages": 1, + }, + num_warps=4, + ) + ], + key=["HEAD_DIM"], +) +@triton.jit +def two_simplical_attn_bwd_kv2q_kernel( + Q_ptr, # [b, s, k, h] + K1_ptr, # [b, s, k, h] + K2_ptr, # [b, s, k, h] + V1_ptr, # [b, s, k, h] + V2_ptr, # [b, s, k, h] + dO_ptr, # [b, s, k, h] + M_ptr, # [b, k, s] + D_ptr, # [b, k, s] + dQ_ptr, # [b, s, k, h] + dK2_ptr, # [b, s, k, h] + dV2_ptr, # [b, s, k, h] + bs, + seq_len, + num_heads, + head_dim, + w1, # Q[i]: KV1(i-w1,i] + w2, # Q[i]: KV2(i-w2,i] + q_stride_b, + q_stride_s, + q_stride_k, + q_stride_h, + k1_stride_b, + k1_stride_s, + k1_stride_k, + k1_stride_h, + k2_stride_b, + k2_stride_s, + k2_stride_k, + k2_stride_h, + v1_stride_b, + v1_stride_s, + v1_stride_k, + v1_stride_h, + v2_stride_b, + v2_stride_s, + v2_stride_k, + v2_stride_h, + dO_stride_b, + dO_stride_s, + dO_stride_k, + dO_stride_h, + m_stride_b, + m_stride_k, + m_stride_s, + d_stride_b, + d_stride_k, + d_stride_s, + dq_stride_b, + dq_stride_s, + dq_stride_k, + dq_stride_h, + dk2_stride_b, + dk2_stride_s, + dk2_stride_k, + dk2_stride_h, + dv2_stride_b, + dv2_stride_s, + dv2_stride_k, + dv2_stride_h, + BLOCK_SIZE_Q: tl.constexpr, + BLOCK_SIZE_KV2: tl.constexpr, + HEAD_DIM: tl.constexpr, + SM_SCALE: tl.constexpr, + num_stages: tl.constexpr, + IS_SECOND_PASS: tl.constexpr, +): + assert BLOCK_SIZE_KV2 == BLOCK_SIZE_Q + w2 + compute_dtype = tl.float32 + gemm_dtype = tl.bfloat16 + + # First pass does even tiles, second pass does odd tiles. + q_start = tl.program_id(0) * BLOCK_SIZE_KV2 + if IS_SECOND_PASS: + q_start += BLOCK_SIZE_Q + q_end = q_start + BLOCK_SIZE_Q + kv2_start = q_start - w2 + + bk = tl.program_id(1) + offs_b = bk // num_heads + offs_k = bk % num_heads + + qkv_offs_bk = offs_b * q_stride_b + offs_k * q_stride_k + Q_ptr += qkv_offs_bk + K1_ptr += qkv_offs_bk + K2_ptr += qkv_offs_bk + V1_ptr += qkv_offs_bk + V2_ptr += qkv_offs_bk + + dO_ptr += offs_b * dO_stride_b + offs_k * dO_stride_k + M_ptr += offs_b * m_stride_b + offs_k * m_stride_k + D_ptr += offs_b * d_stride_b + offs_k * d_stride_k + dQ_ptr += offs_b * dq_stride_b + offs_k * dq_stride_k + dK2_ptr += offs_b * dk2_stride_b + offs_k * dk2_stride_k + dV2_ptr += offs_b * dv2_stride_b + offs_k * dv2_stride_k + + softmax_scale = tl.cast(SM_SCALE, gemm_dtype) + qkv_offs_h = tl.arange(0, HEAD_DIM) + qkv_mask_h = qkv_offs_h < head_dim + + q_offs_s = q_start + tl.arange(0, BLOCK_SIZE_Q) + kv2_offs_s = kv2_start + tl.arange(0, BLOCK_SIZE_KV2) + q_offs = q_offs_s[:, None] * q_stride_s + qkv_offs_h[None, :] * q_stride_h + kv2_offs = kv2_offs_s[:, None] * k2_stride_s + qkv_offs_h[None, :] * k2_stride_h + m_offs = q_offs_s * m_stride_s + d_offs = q_offs_s * d_stride_s + dO_offs = q_offs_s[:, None] * dO_stride_s + qkv_offs_h[None, :] * dO_stride_h + q_mask_s = q_offs_s < seq_len + q_mask = q_mask_s[:, None] & qkv_mask_h[None, :] + kv2_mask_s = 0 <= kv2_offs_s and kv2_offs_s < seq_len + kv2_mask = kv2_mask_s[:, None] & qkv_mask_h[None, :] + + q_tile = tl.load(Q_ptr + q_offs, mask=q_mask).to( + compute_dtype + ) # [BLOCK_SIZE_Q, HEAD_DIM] + k2_tile = tl.load(K2_ptr + kv2_offs, mask=kv2_mask).to( + gemm_dtype + ) # [KV2, HEAD_DIM] + v2_tile = tl.load(V2_ptr + kv2_offs, mask=kv2_mask).to( + gemm_dtype + ) # [KV2, HEAD_DIM] + m_tile = tl.load(M_ptr + m_offs, mask=q_mask_s).to(gemm_dtype) # [BLOCK_SIZE_Q] + d_tile = tl.load(D_ptr + d_offs, mask=q_mask_s).to(gemm_dtype) # [BLOCK_SIZE_Q] + dO_tile = tl.load(dO_ptr + dO_offs, mask=q_mask).to( + gemm_dtype + ) # 
[BLOCK_SIZE_Q, HEAD_DIM] + + dq = tl.zeros((BLOCK_SIZE_Q, HEAD_DIM), tl.float32) + # dqT = tl.zeros((HEAD_DIM, BLOCK_SIZE_Q), tl.float32) + dk2 = tl.zeros((BLOCK_SIZE_KV2, HEAD_DIM), tl.float32) + dv2 = tl.zeros((BLOCK_SIZE_KV2, HEAD_DIM), tl.float32) + + kv1_start = tl.maximum(0, q_start - w1) + kv1_end = tl.minimum(seq_len, q_end) + for kv1_idx in tl.range(kv1_start, kv1_end, num_stages=num_stages): + k1_offs = kv1_idx * k1_stride_s + qkv_offs_h * k1_stride_h + v1_offs = kv1_idx * v1_stride_s + qkv_offs_h * v1_stride_h + k1_tile = tl.load(K1_ptr + k1_offs, mask=qkv_mask_h).to( + compute_dtype + ) # [HEAD_DIM] + + v1_tile = tl.load(V1_ptr + v1_offs, mask=qkv_mask_h).to( + compute_dtype + ) # [HEAD_DIM] + + qk1_s = q_tile * (k1_tile[None, :] * softmax_scale) # [Q, D] + qk1_s = qk1_s.to(gemm_dtype) + # k2[KV, Q] @ qk1_s.T[Q, D] => [KV2, Q] + qkkT = tl.dot(k2_tile, qk1_s.T, out_dtype=tl.float32) # [KV2, Q] + + qkT_mask = kv2_mask_s[:, None] & q_mask_s[None, :] + kv1_local_mask = ((q_offs_s[None, :] - w1) < kv1_idx) & ( + kv1_idx <= q_offs_s[None, :] + ) # [KV2, Q] + kv2_local_mask = ((q_offs_s[None, :] - w2) < kv2_offs_s[:, None]) & ( + kv2_offs_s[:, None] <= q_offs_s[None, :] + ) # [KV2, Q] + local_mask = kv1_local_mask & kv2_local_mask # [BLOCK_SIZE_KV, BLOCK_SIZE_Q] + qkT_mask &= kv1_local_mask & kv2_local_mask + + pT = tl.exp(qkkT - m_tile[None, :]) # [KV2, Q] + pT = tl.where(qkT_mask, pT, 0.0) + + qkkT = tl.where(local_mask, qkkT, -1.0e38) + + dOv1 = dO_tile * v1_tile[None, :] # [Q, D] + dOv1 = dOv1.to(gemm_dtype) + # pT[KV2, Q] @ dOv1[Q, D] => [KV2, D] + dv2 += tl.dot(pT.to(gemm_dtype), dOv1, out_dtype=tl.float32) + + # v2[KV2, D] @ dOv1.T[D, Q] => dpT[KV2, Q] + dpT = tl.dot(v2_tile, dOv1.T, out_dtype=tl.float32) + dsT = pT * (dpT - d_tile[None, :]) # [KV2, Q] + dsT = tl.where(qkT_mask, dsT, 0.0) + dsT = dsT.to(gemm_dtype) # [KV2, Q] + + # dsT[KV2, Q] @ qk1[Q, D] => dk2[KV2, D] + dk2 += tl.dot(dsT, qk1_s, out_dtype=tl.float32) + + k1k2 = k1_tile[None, :] * k2_tile # [KV2, D] + k1k2 = k1k2.to(gemm_dtype) + # k1k2T = k1_tile[:, None] * k2_tile.T + # k1k2T = k1k2T.to(gemm_dtype) + + # Normal. + # dsT.T[Q, KV2] @ [KV2, D] => dq[Q, D] + dq += tl.dot(dsT.T, k1k2) # * softmax scale at the end. + + # End. update gradients. + if IS_SECOND_PASS: + # load, add. + prev_dk2 = tl.load(dK2_ptr + kv2_offs, kv2_mask) + prev_dv2 = tl.load(dV2_ptr + kv2_offs, kv2_mask) + dk2 += prev_dk2 + dv2 += prev_dv2 + + dq *= softmax_scale + tl.store(dK2_ptr + kv2_offs, dk2, kv2_mask) + tl.store(dV2_ptr + kv2_offs, dv2, kv2_mask) + tl.store(dQ_ptr + q_offs, dq, q_mask) + + +def two_simplical_attn_bwd_kv2q_triton( + q, + k1, + k2, + v1, + v2, + dO, + m, + d, + w1, + w2, + dk2, + dv2, + dq, + sm_scale, +): + bs, seq_len, num_heads, head_dim = q.shape + sm_scale *= 1.0 # TODO math.log2(math.exp(1)) is faster. 
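+    # The kernel below is launched twice: the first pass covers even q tiles and
+    # overwrites dK2/dV2, the second pass covers odd q tiles and accumulates into
+    # dK2/dV2 in place, so the two launches must stay ordered.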
+ + # assert w2 == 32 + # TODO replace with grid assert, w2 + BLOCK_SIZE_Q == BLOCK_SIZE_KV2 + def grid(args): + return (triton.cdiv(seq_len, args["BLOCK_SIZE_KV2"]), bs * num_heads) + + for is_second_pass in [False, True]: + two_simplical_attn_bwd_kv2q_kernel[grid]( + q, + k1, + k2, + v1, + v2, + dO, + m, + d, + dq, + dk2, + dv2, + bs, + seq_len, + num_heads, + head_dim, + w1, + w2, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k1.stride(0), + k1.stride(1), + k1.stride(2), + k1.stride(3), + k2.stride(0), + k2.stride(1), + k2.stride(2), + k2.stride(3), + v1.stride(0), + v1.stride(1), + v1.stride(2), + v1.stride(3), + v2.stride(0), + v2.stride(1), + v2.stride(2), + v2.stride(3), + dO.stride(0), + dO.stride(1), + dO.stride(2), + dO.stride(3), + m.stride(0), + m.stride(1), + m.stride(2), + d.stride(0), + d.stride(1), + d.stride(2), + dq.stride(0), + dq.stride(1), + dq.stride(2), + dq.stride(3), + dk2.stride(0), + dk2.stride(1), + dk2.stride(2), + dk2.stride(3), + dv2.stride(0), + dv2.stride(1), + dv2.stride(2), + dv2.stride(3), + # BLOCK_SIZE_Q=block_size_q, + # BLOCK_SIZE_KV=block_size_kv, + HEAD_DIM=triton.next_power_of_2(head_dim), + SM_SCALE=sm_scale, + IS_SECOND_PASS=is_second_pass, + ) diff --git a/fbgemm_gpu/experimental/gen_ai/gen_ai/two_simplical_attn/two_simplical_attn_fwd.py b/fbgemm_gpu/experimental/gen_ai/gen_ai/two_simplical_attn/two_simplical_attn_fwd.py new file mode 100644 index 0000000000..cd07ef1235 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/gen_ai/two_simplical_attn/two_simplical_attn_fwd.py @@ -0,0 +1,369 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +import os + +import torch +import triton +import triton.language as tl +from triton import Config + +TWO_SIMPLICAL_AUTOTUNE = os.getenv("TWO_SIMPLICAL_AUTOTUNE", "0") == "1" + + +def get_configs(): + if TWO_SIMPLICAL_AUTOTUNE: + return [ + Config( + { + "BLOCK_SIZE_KV": BLOCK_SIZE_KV, + "num_stages_0": num_stages_0, + "num_stages_1": num_stages_1, + }, + num_warps=num_warps, + ) + for BLOCK_SIZE_KV in [32, 64, 128, 256] + for num_stages_0 in [8, 16, 32] + for num_stages_1 in [1, 2, 3, 4] + for num_warps in [4, 8] + ] + return [ + Config( + { + "BLOCK_SIZE_KV": BLOCK_SIZE_KV, + "num_stages_0": num_stages_0, + "num_stages_1": num_stages_1, + }, + num_warps=num_warps, + ) + for BLOCK_SIZE_KV in [64] + for num_stages_0 in [16] + for num_stages_1 in [1] + for num_warps in [4] + ] + + +@triton.jit +def _gqa_pack_fwd_inner( + K2_ptr, + V2_ptr, + k2_stride_s, + k2_stride_h, + v2_stride_s, + v2_stride_h, + qkv_offs_h, + q_idx, + kv2_idx, + qk1, + v1_tile, + acc, + m_i, + l_i, + gemm_dtype: tl.constexpr, + w2: tl.constexpr, + BLOCK_SIZE_KV: tl.constexpr, + IS_MASK_LOAD: tl.constexpr, +): + # [BLOCK_SIZE_KV, HEAD_DIM] + kv2_offs_s = kv2_idx + tl.arange(0, BLOCK_SIZE_KV) + k2_offs = kv2_offs_s[None, :] * k2_stride_s + qkv_offs_h[:, None] * k2_stride_h + v2_offs = kv2_offs_s[:, None] * v2_stride_s + qkv_offs_h[None, :] * v2_stride_h + + if IS_MASK_LOAD: + kv2_mask_s = q_idx - w2 < kv2_offs_s and kv2_offs_s <= q_idx + k2t_tile = tl.load( + K2_ptr + k2_offs, mask=kv2_mask_s[None, :] + ) # [HEAD_DIM, BLOCK_SIZE_KV] + v2_tile = tl.load( + V2_ptr + v2_offs, mask=kv2_mask_s[:, None] + ) # [BLOCK_SIZE_KV, HEAD_DIM] + else: + kv2_mask_s = None + k2t_tile = tl.load(K2_ptr + k2_offs) # [HEAD_DIM, BLOCK_SIZE_KV] + v2_tile = tl.load(V2_ptr + v2_offs) # [BLOCK_SIZE_KV, HEAD_DIM] + + # k2 @ qk1.T: [kv2, d] @ [d, q] -> [kv2, q] + # qkkT [kv2, q] + qk = tl.dot( + qk1, # * softmax_scale, + 
k2t_tile, + input_precision="tf32", # INPUT_PRECISION, + out_dtype=tl.float32, + ) # [BLOCK_SIZE_Q, BLOCK_SIZE_KV] + + # Mask for q_idx - w1 < kv1_idx <= q_idx + # and q_idx - w2 < kv2_offs_s <= q_idx + + if IS_MASK_LOAD: + qk_mask = kv2_mask_s[None, :] + # TODO Triton nan's with -inf, but float max is probably sufficient. + qk += tl.where(qk_mask, 0, -1.0e6) + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + v12_tile = v1_tile * v2_tile # [BLOCK_SIZE_KV, HEAD_DIM] + + # v12T[d, kv2] @ pT[kv2, q]: accT[d, q] + acc += tl.dot( + p.to(gemm_dtype), + v12_tile, + input_precision="ieee", # INPUT_PRECISION, + out_dtype=tl.float32, + ) + + m_i = m_ij + + return acc, m_i, l_i + + +# Without TMA. +@triton.autotune( + configs=get_configs(), + key=["HEAD_DIM", "w1", "w2", "seq_len"], +) +@triton.jit +def _gqa_pack_fwd_kernel( + Q_ptr, # [b, s, k, h] + K1_ptr, # [b, s, 1, h] + K2_ptr, # [b, s, 1, h] + V1_ptr, # [b, s, 1, h] + V2_ptr, # [b, s, 1, h] + O_ptr, # [b, s, k, h] + M_ptr, # [b, k, s] + bs, + seq_len, + num_heads, + w1: tl.constexpr, + w2: tl.constexpr, + q_stride_b, + q_stride_s, + q_stride_k, + q_stride_h, + k1_stride_b, + k1_stride_s, + k1_stride_k, + k1_stride_h, + k2_stride_b, + k2_stride_s, + k2_stride_k, + k2_stride_h, + v1_stride_b, + v1_stride_s, + v1_stride_k, + v1_stride_h, + v2_stride_b, + v2_stride_s, + v2_stride_k, + v2_stride_h, + out_stride_b, + out_stride_s, + out_stride_k, + out_stride_h, + m_stride_b, + m_stride_k, + m_stride_s, + BLOCK_SIZE_Q: tl.constexpr, + BLOCK_SIZE_KV: tl.constexpr, + HEAD_DIM: tl.constexpr, + INPUT_PRECISION: tl.constexpr, + SM_SCALE: tl.constexpr, + K2_BIAS: tl.constexpr, + V2_BIAS: tl.constexpr, + num_stages_0: tl.constexpr, + num_stages_1: tl.constexpr, +): + """GQA two simplical attention fwd kernel. Assume TP=num_kv so kernel is called per kv_group.""" + data_dtype = tl.bfloat16 + compute_dtype = tl.float32 + gemm_dtype = tl.bfloat16 + + q_idx = tl.program_id(0) + offs_b = tl.program_id(1) + + q_offs_b = offs_b * q_stride_b + kv_offs_b = offs_b * k1_stride_b + + Q_ptr += q_offs_b + q_idx * q_stride_s + K1_ptr += kv_offs_b + K2_ptr += kv_offs_b + V1_ptr += kv_offs_b + V2_ptr += kv_offs_b + O_ptr += offs_b * out_stride_b + q_idx * out_stride_s + M_ptr += offs_b * m_stride_b + q_idx * m_stride_s + + m_i = tl.zeros((BLOCK_SIZE_Q,), dtype=compute_dtype) - float("inf") + # TODO why does triton impl initialize this as 1, but paper uses 0. 
+ l_i = tl.zeros((BLOCK_SIZE_Q,), dtype=compute_dtype) + acc = tl.zeros((BLOCK_SIZE_Q, HEAD_DIM), dtype=compute_dtype) + + qkv_offs_h = tl.arange(0, HEAD_DIM) + q_offs_k = tl.arange(0, BLOCK_SIZE_Q) + q_offs = q_offs_k[:, None] * q_stride_k + qkv_offs_h[None, :] * q_stride_h + + q_tile = tl.load(Q_ptr + q_offs) # [BLOCK_SIZE_Q, HEAD_DIM] + softmax_scale = tl.cast(SM_SCALE, gemm_dtype) + + q_tile = q_tile * softmax_scale + + kv2_start = tl.maximum(0, q_idx - w2 + 1) + kv2_end = tl.minimum(seq_len, q_idx + 1) + has_mask_load = (kv2_end - kv2_start) % BLOCK_SIZE_KV > 0 + num_n_trips = tl.cdiv(kv2_end - kv2_start, BLOCK_SIZE_KV) + num_n_trips_inner = num_n_trips - 1 if has_mask_load else num_n_trips + + kv1_start = tl.maximum(0, q_idx - w1 + 1) + kv1_end = tl.minimum(seq_len, q_idx + 1) + for kv1_idx in tl.range(kv1_start, kv1_end, num_stages=num_stages_0): + k1_offs = kv1_idx * k1_stride_s + qkv_offs_h * k1_stride_h + k1_tile = tl.load(K1_ptr + k1_offs)[None, :] # [1, HEAD_DIM] + qk1_tile = q_tile * k1_tile # [BLOCK_SIZE_Q, HEAD_DIM] + + v1_offs = kv1_idx * v1_stride_s + qkv_offs_h * v1_stride_h + v1_tile = tl.load(V1_ptr + v1_offs)[None,] # [1, HEAD_DIM] + + for _inner_idx in tl.range( + num_n_trips_inner, + num_stages=num_stages_1, + ): + kv2_idx = kv2_start + _inner_idx * BLOCK_SIZE_KV + acc, m_i, l_i = _gqa_pack_fwd_inner( + K2_ptr, + V2_ptr, + k2_stride_s, + k2_stride_h, + v2_stride_s, + v2_stride_h, + qkv_offs_h, + q_idx, + kv2_idx, + qk1_tile, + v1_tile, + acc, + m_i, + l_i, + gemm_dtype, + w2, + BLOCK_SIZE_KV, + IS_MASK_LOAD=False, + ) + if has_mask_load: + kv2_idx = kv2_start + num_n_trips_inner * BLOCK_SIZE_KV + acc, m_i, l_i = _gqa_pack_fwd_inner( + K2_ptr, + V2_ptr, + k2_stride_s, + k2_stride_h, + v2_stride_s, + v2_stride_h, + qkv_offs_h, + q_idx, + kv2_idx, + qk1_tile, + v1_tile, + acc, + m_i, + l_i, + gemm_dtype, + w2, + BLOCK_SIZE_KV, + IS_MASK_LOAD=True, + ) + + acc = acc / l_i[:, None] + acc = acc.to(data_dtype) + out_offs = q_offs_k[:, None] * out_stride_k + qkv_offs_h[None, :] * out_stride_h + tl.store(O_ptr + out_offs, acc) + + m = m_i + tl.log(l_i) + + m_offs = q_offs_k * m_stride_k + tl.store(M_ptr + m_offs, m) + + +def two_simplical_attn_fwd(q, k1, k2, v1, v2, w1, w2, k2_bias=None, v2_bias=None): + """Two simplical attention kernel with GQA packing. + L = q @ k1 X k2 + P = softmax(L, axis=[-1, -2]) + O = P @ v1 X v2 + """ + bs, seq_len, num_heads, head_dim = q.shape + _, seq_len1, _, _ = k1.shape + _, seq_len2, _, _ = k2.shape + assert ( + seq_len == seq_len1 and seq_len1 == seq_len2 + ), "input seq lens must match, sliding window is done within kernel" + assert w1 > 0 and w2 > 0, "block local windows must be positive" + output = torch.zeros_like(q, memory_format=torch.contiguous_format).to( + torch.bfloat16 + ) + m = torch.zeros((bs, num_heads, seq_len), dtype=torch.float32, device=q.device) + # INPUT_PRECISION = "ieee" + INPUT_PRECISION = "tf32" + # e^x = 2^(x * log2(e)), so we multiply x by log2(e) to use faster exp2 in kernel. + sm_scale = 1.44269504 # math.log2(math.exp(1)) + sm_scale *= head_dim**-0.5 + if not k2_bias: + k2_bias = 1.0 / head_dim + if not v2_bias: + v2_bias = 1.0 + + # NOTE: to optimize performance, we always make sure w1 is the smaller window size, when + # w1 and w2 are not equal. 
+ if w1 > w2: + k1, k2 = k2, k1 + v1, v2 = v2, v1 + w1, w2 = w2, w1 + + grid = lambda args: (seq_len, bs) # noqa: E731 + + _gqa_pack_fwd_kernel[grid]( + q, + k1, + k2, + v1, + v2, + output, + m, + bs, + seq_len, + num_heads, + w1, + w2, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k1.stride(0), + k1.stride(1), + k1.stride(2), + k1.stride(3), + k2.stride(0), + k2.stride(1), + k2.stride(2), + k2.stride(3), + v1.stride(0), + v1.stride(1), + v1.stride(2), + v1.stride(3), + v2.stride(0), + v2.stride(1), + v2.stride(2), + v2.stride(3), + output.stride(0), + output.stride(1), + output.stride(2), + output.stride(3), + m.stride(0), + m.stride(1), + m.stride(2), + HEAD_DIM=head_dim, + INPUT_PRECISION=INPUT_PRECISION, + SM_SCALE=sm_scale, + K2_BIAS=k2_bias, + V2_BIAS=v2_bias, + BLOCK_SIZE_Q=num_heads, + ) + return output, m
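Usage sketch (illustrative only, not part of the patch): the import paths, the delta computation, and the sm_scale value below are assumptions, since the patch does not include the autograd wrapper that normally supplies them. All tensors follow the [bs, seq_len, num_heads, head_dim] layout used by the kernels.

    import torch
    # Import paths are illustrative; adjust to wherever these modules live in fbgemm_gpu.
    from two_simplical_attn_fwd import two_simplical_attn_fwd
    from two_simplical_attn_bwd import (
        two_simplical_attn_bwd_kv1_triton,
        two_simplical_attn_bwd_kv2q_triton,
    )

    bs, seq_len, num_heads, head_dim = 2, 4096, 8, 128
    # w2 == 32 matches the pinned BLOCK_SIZE_KV2 == BLOCK_SIZE_Q + w2 requirement of the kv2q kernel.
    w1, w2 = 512, 32

    def rand_bshk():
        return torch.randn(
            bs, seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16
        )

    q, k1, k2, v1, v2, dO = [rand_bshk() for _ in range(6)]

    # Forward: returns the output and the softmax statistics m of shape [bs, num_heads, seq_len].
    out, m = two_simplical_attn_fwd(q, k1, k2, v1, v2, w1, w2)

    # Assumed FlashAttention-style delta term, rowsum(dO * O), laid out as [bs, num_heads, seq_len].
    delta = (dO.float() * out.float()).sum(-1).transpose(1, 2).contiguous()

    sm_scale = head_dim**-0.5  # assumed; the patch does not show the caller's convention
    dq = torch.zeros_like(q)
    dk1, dv1 = torch.zeros_like(k1), torch.zeros_like(v1)
    dk2, dv2 = torch.zeros_like(k2), torch.zeros_like(v2)

    # dK1/dV1 tiles over kv1; dq is not produced here, so None is accepted.
    two_simplical_attn_bwd_kv1_triton(q, k1, k2, v1, v2, dO, m, delta, w1, w2, dk1, dv1, None, sm_scale)
    # dK2/dV2 and dQ via the two-pass kv2/q kernel.
    two_simplical_attn_bwd_kv2q_triton(q, k1, k2, v1, v2, dO, m, delta, w1, w2, dk2, dv2, dq, sm_scale)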