Skip to content

Commit 3eceba2

Browse files
committed
make a copy for ws version
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent e01efe6 commit 3eceba2

File tree

1 file changed

+324
-8
lines changed

1 file changed

+324
-8
lines changed

tritonbench/kernels/triton_fused_attention.py

Lines changed: 324 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1497,6 +1497,131 @@ def _attn_bwd_dkdv(
14971497
start_m,
14981498
num_steps, #
14991499
MASK: tl.constexpr,
1500+
):
1501+
offs_m = start_m + tl.arange(0, BLOCK_M1)
1502+
offs_n = start_n + tl.arange(0, BLOCK_N1)
1503+
offs_k = tl.arange(0, HEAD_DIM)
1504+
qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
1505+
do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
1506+
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
1507+
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
1508+
curr_m = start_m
1509+
step_m = BLOCK_M1
1510+
for blk_idx in range(num_steps):
1511+
qT = tl.load(qT_ptrs)
1512+
# Load m before computing qk to reduce pipeline stall.
1513+
offs_m = curr_m + tl.arange(0, BLOCK_M1)
1514+
m = tl.load(M + offs_m)
1515+
qkT = tl.dot(k, qT)
1516+
#dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
1517+
pT = tl.math.exp2(qkT - m[None, :])
1518+
# Autoregressive masking.
1519+
if MASK:
1520+
mask = offs_m[None, :] >= offs_n[:, None]
1521+
pT = tl.where(mask, pT, 0.0)
1522+
do = tl.load(do_ptrs)
1523+
# Compute dV.
1524+
ppT = pT
1525+
ppT = ppT.to(tl.bfloat16)
1526+
dv += tl.dot(ppT, do)
1527+
# D (= delta) is pre-divided by ds_scale.
1528+
Di = tl.load(D + offs_m)
1529+
# Compute dP and dS.
1530+
dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
1531+
dsT = pT * (dpT - Di[None, :])
1532+
dsT = dsT.to(tl.bfloat16)
1533+
dk += tl.dot(dsT, tl.trans(qT))
1534+
# Increment pointers.
1535+
curr_m += step_m
1536+
qT_ptrs += step_m * stride_tok
1537+
do_ptrs += step_m * stride_tok
1538+
return dk, dv
1539+
1540+
1541+
# the main inner-loop logic for computing dQ
1542+
@triton.jit
1543+
def _attn_bwd_dq(
1544+
dq,
1545+
q,
1546+
K,
1547+
V, #
1548+
do,
1549+
m,
1550+
D,
1551+
# shared by Q/K/V/DO.
1552+
stride_tok,
1553+
stride_d, #
1554+
H,
1555+
N_CTX, #
1556+
BLOCK_M2: tl.constexpr, #
1557+
BLOCK_N2: tl.constexpr, #
1558+
HEAD_DIM: tl.constexpr,
1559+
# Filled in by the wrapper.
1560+
start_m,
1561+
start_n,
1562+
num_steps, #
1563+
MASK: tl.constexpr,
1564+
):
1565+
offs_m = start_m + tl.arange(0, BLOCK_M2)
1566+
offs_n = start_n + tl.arange(0, BLOCK_N2)
1567+
offs_k = tl.arange(0, HEAD_DIM)
1568+
kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
1569+
vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
1570+
# D (= delta) is pre-divided by ds_scale.
1571+
Di = tl.load(D + offs_m)
1572+
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
1573+
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
1574+
curr_n = start_n
1575+
step_n = BLOCK_N2
1576+
for blk_idx in range(num_steps):
1577+
kT = tl.load(kT_ptrs)
1578+
vT = tl.load(vT_ptrs)
1579+
qk = tl.dot(q, kT)
1580+
p = tl.math.exp2(qk - m)
1581+
# Autoregressive masking.
1582+
if MASK:
1583+
offs_n = curr_n + tl.arange(0, BLOCK_N2)
1584+
mask = offs_m[:, None] >= offs_n[None, :]
1585+
p = tl.where(mask, p, 0.0)
1586+
# Compute dP and dS.
1587+
dp = tl.dot(do, vT).to(tl.float32)
1588+
ds = p * (dp - Di[:, None])
1589+
ds = ds.to(tl.bfloat16)
1590+
# Compute dQ.
1591+
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
1592+
dq += tl.dot(ds, tl.trans(kT))
1593+
# Increment pointers.
1594+
curr_n += step_n
1595+
kT_ptrs += step_n * stride_tok
1596+
vT_ptrs += step_n * stride_tok
1597+
return dq
1598+
1599+
1600+
# The main inner-loop logic for computing dK and dV.
1601+
@triton.jit
1602+
def _attn_bwd_dkdv_ws(
1603+
dk,
1604+
dv, #
1605+
Q,
1606+
k,
1607+
v,
1608+
sm_scale, #
1609+
DO, #
1610+
M,
1611+
D, #
1612+
# shared by Q/K/V/DO.
1613+
stride_tok,
1614+
stride_d, #
1615+
H,
1616+
N_CTX,
1617+
BLOCK_M1: tl.constexpr, #
1618+
BLOCK_N1: tl.constexpr, #
1619+
HEAD_DIM: tl.constexpr, #
1620+
# Filled in by the wrapper.
1621+
start_n,
1622+
start_m,
1623+
num_steps, #
1624+
MASK: tl.constexpr,
15001625
):
15011626
offs_m = start_m + tl.arange(0, BLOCK_M1)
15021627
offs_n = start_n + tl.arange(0, BLOCK_N1)
@@ -1546,7 +1671,7 @@ def _attn_bwd_dkdv(
15461671

15471672
# the main inner-loop logic for computing dQ
15481673
@triton.jit
1549-
def _attn_bwd_dq(
1674+
def _attn_bwd_dq_ws(
15501675
dq,
15511676
q,
15521677
K,
@@ -1734,9 +1859,8 @@ def _attn_bwd_compute(
17341859
dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
17351860

17361861
# load K and V: they stay in SRAM throughout the inner loop.
1737-
with tl.async_task([0]):
1738-
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
1739-
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
1862+
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
1863+
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
17401864

17411865
num_steps = BLOCK_N1 // MASK_BLOCK_M1
17421866

@@ -1805,9 +1929,8 @@ def _attn_bwd_compute(
18051929
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
18061930
offs_m = start_m + tl.arange(0, BLOCK_M2)
18071931

1808-
with tl.async_task([0]):
1809-
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
1810-
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
1932+
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
1933+
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
18111934
dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
18121935

18131936
m = tl.load(M + offs_m)
@@ -1868,6 +1991,199 @@ def _attn_bwd_compute(
18681991
tl.store(dq_ptrs, dq)
18691992

18701993

1994+
@triton.jit
1995+
def _attn_bwd_compute_ws(
1996+
Q,
1997+
K,
1998+
V,
1999+
sm_scale, #
2000+
DO, #
2001+
DQ,
2002+
DK,
2003+
DV, #
2004+
M,
2005+
D,
2006+
# shared by Q/K/V/DO.
2007+
stride_z,
2008+
stride_h,
2009+
stride_tok,
2010+
stride_d, #
2011+
H,
2012+
N_CTX, #
2013+
BLOCK_M1: tl.constexpr, #
2014+
BLOCK_N1: tl.constexpr, #
2015+
BLOCK_M2: tl.constexpr, #
2016+
BLOCK_N2: tl.constexpr, #
2017+
BLK_SLICE_FACTOR: tl.constexpr, #
2018+
HEAD_DIM: tl.constexpr,
2019+
):
2020+
LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
2021+
2022+
bhid = tl.program_id(2)
2023+
off_chz = (bhid * N_CTX).to(tl.int64)
2024+
adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
2025+
pid = tl.program_id(0)
2026+
2027+
# offset pointers for batch/head
2028+
Q += adj
2029+
K += adj
2030+
V += adj
2031+
DO += adj
2032+
DQ += adj
2033+
DK += adj
2034+
DV += adj
2035+
M += off_chz
2036+
D += off_chz
2037+
2038+
# load scales
2039+
offs_k = tl.arange(0, HEAD_DIM)
2040+
2041+
start_n = pid * BLOCK_N1
2042+
start_m = start_n
2043+
2044+
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
2045+
offs_n = start_n + tl.arange(0, BLOCK_N1)
2046+
2047+
dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
2048+
dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
2049+
2050+
# load K and V: they stay in SRAM throughout the inner loop.
2051+
with tl.async_task([0]):
2052+
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
2053+
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
2054+
2055+
num_steps = BLOCK_N1 // MASK_BLOCK_M1
2056+
2057+
dk, dv = _attn_bwd_dkdv_ws(
2058+
dk,
2059+
dv, #
2060+
Q,
2061+
k,
2062+
v,
2063+
sm_scale, #
2064+
DO, #
2065+
M,
2066+
D, #
2067+
stride_tok,
2068+
stride_d, #
2069+
H,
2070+
N_CTX, #
2071+
MASK_BLOCK_M1,
2072+
BLOCK_N1,
2073+
HEAD_DIM, #
2074+
start_n,
2075+
start_m,
2076+
num_steps, #
2077+
MASK=True, #
2078+
)
2079+
2080+
start_m += num_steps * MASK_BLOCK_M1
2081+
num_steps = (N_CTX - start_m) // BLOCK_M1
2082+
2083+
# Compute dK and dV for non-masked blocks.
2084+
dk, dv = _attn_bwd_dkdv_ws( #
2085+
dk,
2086+
dv, #
2087+
Q,
2088+
k,
2089+
v,
2090+
sm_scale, #
2091+
DO, #
2092+
M,
2093+
D, #
2094+
stride_tok,
2095+
stride_d, #
2096+
H,
2097+
N_CTX, #
2098+
BLOCK_M1,
2099+
BLOCK_N1,
2100+
HEAD_DIM, #
2101+
start_n,
2102+
start_m,
2103+
num_steps, #
2104+
MASK=False, #
2105+
)
2106+
2107+
with tl.async_task([1, 2]):
2108+
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
2109+
tl.store(dv_ptrs, dv)
2110+
2111+
# Write back dK.
2112+
dk *= sm_scale
2113+
dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
2114+
tl.store(dk_ptrs, dk)
2115+
2116+
# THIS BLOCK DOES DQ:
2117+
start_m = pid * BLOCK_M2
2118+
end_n = start_m + BLOCK_M2
2119+
2120+
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
2121+
offs_m = start_m + tl.arange(0, BLOCK_M2)
2122+
2123+
with tl.async_task([0]):
2124+
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
2125+
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
2126+
dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
2127+
2128+
m = tl.load(M + offs_m)
2129+
m = m[:, None]
2130+
2131+
# Compute dQ for masked (diagonal) blocks.
2132+
# NOTE: This code scans each row of QK^T backward (from right to left,
2133+
# but inside each call to _attn_bwd_dq, from left to right), but that's
2134+
# not due to anything important. I just wanted to reuse the loop
2135+
# structure for dK & dV above as much as possible.
2136+
num_steps = BLOCK_M2 // MASK_BLOCK_N2
2137+
dq = _attn_bwd_dq_ws(
2138+
dq,
2139+
q,
2140+
K,
2141+
V, #
2142+
do,
2143+
m,
2144+
D, #
2145+
stride_tok,
2146+
stride_d, #
2147+
H,
2148+
N_CTX, #
2149+
BLOCK_M2,
2150+
MASK_BLOCK_N2,
2151+
HEAD_DIM, #
2152+
start_m,
2153+
end_n - num_steps * MASK_BLOCK_N2,
2154+
num_steps, #
2155+
MASK=True, #
2156+
)
2157+
end_n -= num_steps * MASK_BLOCK_N2
2158+
# stage 2
2159+
num_steps = end_n // BLOCK_N2
2160+
dq = _attn_bwd_dq_ws(
2161+
dq,
2162+
q,
2163+
K,
2164+
V, #
2165+
do,
2166+
m,
2167+
D, #
2168+
stride_tok,
2169+
stride_d, #
2170+
H,
2171+
N_CTX, #
2172+
BLOCK_M2,
2173+
BLOCK_N2,
2174+
HEAD_DIM, #
2175+
start_m,
2176+
end_n - num_steps * BLOCK_N2,
2177+
num_steps, #
2178+
MASK=False, #
2179+
)
2180+
# Write back dQ.
2181+
with tl.async_task([1, 2]):
2182+
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
2183+
dq *= LN2
2184+
tl.store(dq_ptrs, dq)
2185+
2186+
18712187
@triton.autotune(list(filter(keep2, configsBwd)), key=["N_CTX"])
18722188
@triton.jit
18732189
def _attn_bwd(
@@ -1949,7 +2265,7 @@ def _attn_bwd_ws(
19492265
BLK_SLICE_FACTOR: tl.constexpr, #
19502266
HEAD_DIM: tl.constexpr,
19512267
):
1952-
_attn_bwd_compute(
2268+
_attn_bwd_compute_ws(
19532269
Q,
19542270
K,
19552271
V,

0 commit comments

Comments
 (0)