@@ -1497,6 +1497,131 @@ def _attn_bwd_dkdv(
    start_m,
    num_steps,  #
    MASK: tl.constexpr,
+):
+    offs_m = start_m + tl.arange(0, BLOCK_M1)
+    offs_n = start_n + tl.arange(0, BLOCK_N1)
+    offs_k = tl.arange(0, HEAD_DIM)
+    qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
+    do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
+    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
+    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
+    curr_m = start_m
+    step_m = BLOCK_M1
+    for blk_idx in range(num_steps):
+        qT = tl.load(qT_ptrs)
+        # Load m before computing qk to reduce pipeline stall.
+        offs_m = curr_m + tl.arange(0, BLOCK_M1)
+        m = tl.load(M + offs_m)
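+        # m is the row-wise log-sum-exp (base 2) saved by the forward pass, so
+        # exp2(qkT - m) below directly yields normalized attention probabilities.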
+        qkT = tl.dot(k, qT)
+        # dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
+        pT = tl.math.exp2(qkT - m[None, :])
+        # Autoregressive masking.
+        if MASK:
+            mask = offs_m[None, :] >= offs_n[:, None]
+            pT = tl.where(mask, pT, 0.0)
+        do = tl.load(do_ptrs)
+        # Compute dV.
+        ppT = pT
+        ppT = ppT.to(tl.bfloat16)
+        dv += tl.dot(ppT, do)
+        # D (= delta) is pre-divided by ds_scale.
+        Di = tl.load(D + offs_m)
+        # Compute dP and dS.
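+        # Softmax backward: dS = P * (dP - delta), where delta_i = sum_j P_ij * dP_ij
+        # = rowsum(dO_i * O_i) is precomputed and passed in via D.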
+        dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
+        dsT = pT * (dpT - Di[None, :])
+        dsT = dsT.to(tl.bfloat16)
+        dk += tl.dot(dsT, tl.trans(qT))
+        # Increment pointers.
+        curr_m += step_m
+        qT_ptrs += step_m * stride_tok
+        do_ptrs += step_m * stride_tok
+    return dk, dv
+
+
+# the main inner-loop logic for computing dQ
+@triton.jit
+def _attn_bwd_dq(
+    dq,
+    q,
+    K,
+    V,  #
+    do,
+    m,
+    D,
+    # shared by Q/K/V/DO.
+    stride_tok,
+    stride_d,  #
+    H,
+    N_CTX,  #
+    BLOCK_M2: tl.constexpr,  #
+    BLOCK_N2: tl.constexpr,  #
+    HEAD_DIM: tl.constexpr,
+    # Filled in by the wrapper.
+    start_m,
+    start_n,
+    num_steps,  #
+    MASK: tl.constexpr,
+):
+    offs_m = start_m + tl.arange(0, BLOCK_M2)
+    offs_n = start_n + tl.arange(0, BLOCK_N2)
+    offs_k = tl.arange(0, HEAD_DIM)
+    kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
+    vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
+    # D (= delta) is pre-divided by ds_scale.
+    Di = tl.load(D + offs_m)
+    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
+    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
+    curr_n = start_n
+    step_n = BLOCK_N2
+    for blk_idx in range(num_steps):
+        kT = tl.load(kT_ptrs)
+        vT = tl.load(vT_ptrs)
+        qk = tl.dot(q, kT)
+        p = tl.math.exp2(qk - m)
+        # Autoregressive masking.
+        if MASK:
+            offs_n = curr_n + tl.arange(0, BLOCK_N2)
+            mask = offs_m[:, None] >= offs_n[None, :]
+            p = tl.where(mask, p, 0.0)
+        # Compute dP and dS.
+        dp = tl.dot(do, vT).to(tl.float32)
+        ds = p * (dp - Di[:, None])
+        ds = ds.to(tl.bfloat16)
+        # Compute dQ.
+        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
+        dq += tl.dot(ds, tl.trans(kT))
+        # Increment pointers.
+        curr_n += step_n
+        kT_ptrs += step_n * stride_tok
+        vT_ptrs += step_n * stride_tok
+    return dq
+
+
+# The main inner-loop logic for computing dK and dV (warp-specialized variant).
+@triton.jit
+def _attn_bwd_dkdv_ws(
+    dk,
+    dv,  #
+    Q,
+    k,
+    v,
+    sm_scale,  #
+    DO,  #
+    M,
+    D,  #
+    # shared by Q/K/V/DO.
+    stride_tok,
+    stride_d,  #
+    H,
+    N_CTX,
+    BLOCK_M1: tl.constexpr,  #
+    BLOCK_N1: tl.constexpr,  #
+    HEAD_DIM: tl.constexpr,  #
+    # Filled in by the wrapper.
+    start_n,
+    start_m,
+    num_steps,  #
+    MASK: tl.constexpr,
):
    offs_m = start_m + tl.arange(0, BLOCK_M1)
    offs_n = start_n + tl.arange(0, BLOCK_N1)
@@ -1546,7 +1671,7 @@ def _attn_bwd_dkdv(

# the main inner-loop logic for computing dQ
@triton.jit
-def _attn_bwd_dq(
+def _attn_bwd_dq_ws(
    dq,
    q,
    K,
@@ -1734,9 +1859,8 @@ def _attn_bwd_compute(
    dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)

    # load K and V: they stay in SRAM throughout the inner loop.
-    with tl.async_task([0]):
-        k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
-        v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
+    k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
+    v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)

    num_steps = BLOCK_N1 // MASK_BLOCK_M1

@@ -1805,9 +1929,8 @@ def _attn_bwd_compute(
    MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
    offs_m = start_m + tl.arange(0, BLOCK_M2)

-    with tl.async_task([0]):
-        q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
-        do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
+    q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
+    do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)

    m = tl.load(M + offs_m)
@@ -1868,6 +1991,199 @@ def _attn_bwd_compute(
        tl.store(dq_ptrs, dq)


+@triton.jit
+def _attn_bwd_compute_ws(
+    Q,
+    K,
+    V,
+    sm_scale,  #
+    DO,  #
+    DQ,
+    DK,
+    DV,  #
+    M,
+    D,
+    # shared by Q/K/V/DO.
+    stride_z,
+    stride_h,
+    stride_tok,
+    stride_d,  #
+    H,
+    N_CTX,  #
+    BLOCK_M1: tl.constexpr,  #
+    BLOCK_N1: tl.constexpr,  #
+    BLOCK_M2: tl.constexpr,  #
+    BLOCK_N2: tl.constexpr,  #
+    BLK_SLICE_FACTOR: tl.constexpr,  #
+    HEAD_DIM: tl.constexpr,
+):
+    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)
+
+    bhid = tl.program_id(2)
+    off_chz = (bhid * N_CTX).to(tl.int64)
+    adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
+    pid = tl.program_id(0)
+
+    # offset pointers for batch/head
+    Q += adj
+    K += adj
+    V += adj
+    DO += adj
+    DQ += adj
+    DK += adj
+    DV += adj
+    M += off_chz
+    D += off_chz
+
+    # load scales
+    offs_k = tl.arange(0, HEAD_DIM)
+
+    start_n = pid * BLOCK_N1
+    start_m = start_n
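+    # Causal attention: the K/V block starting at start_n only receives gradient
+    # from query rows m >= n, so the sweep over M starts on the diagonal.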
+
+    MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
+    offs_n = start_n + tl.arange(0, BLOCK_N1)
+
+    dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
+    dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
+
+    # load K and V: they stay in SRAM throughout the inner loop.
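+    # tl.async_task splits the work across warp groups in this warp-specialized
+    # variant: the group in [0] issues the global loads, while the groups in
+    # [1, 2] handle the stores further down.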
+    with tl.async_task([0]):
+        k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
+        v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
+
+    num_steps = BLOCK_N1 // MASK_BLOCK_M1
+
+    dk, dv = _attn_bwd_dkdv_ws(
+        dk,
+        dv,  #
+        Q,
+        k,
+        v,
+        sm_scale,  #
+        DO,  #
+        M,
+        D,  #
+        stride_tok,
+        stride_d,  #
+        H,
+        N_CTX,  #
+        MASK_BLOCK_M1,
+        BLOCK_N1,
+        HEAD_DIM,  #
+        start_n,
+        start_m,
+        num_steps,  #
+        MASK=True,  #
+    )
+
+    start_m += num_steps * MASK_BLOCK_M1
+    num_steps = (N_CTX - start_m) // BLOCK_M1
+
+    # Compute dK and dV for non-masked blocks.
+    dk, dv = _attn_bwd_dkdv_ws(  #
+        dk,
+        dv,  #
+        Q,
+        k,
+        v,
+        sm_scale,  #
+        DO,  #
+        M,
+        D,  #
+        stride_tok,
+        stride_d,  #
+        H,
+        N_CTX,  #
+        BLOCK_M1,
+        BLOCK_N1,
+        HEAD_DIM,  #
+        start_n,
+        start_m,
+        num_steps,  #
+        MASK=False,  #
+    )
+
+    with tl.async_task([1, 2]):
+        dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
+        tl.store(dv_ptrs, dv)
+
+        # Write back dK.
+        dk *= sm_scale
+        dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
+        tl.store(dk_ptrs, dk)
+
+    # THIS BLOCK DOES DQ:
+    start_m = pid * BLOCK_M2
+    end_n = start_m + BLOCK_M2
+
+    MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
+    offs_m = start_m + tl.arange(0, BLOCK_M2)
+
+    with tl.async_task([0]):
+        q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
+        do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
+    dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
+
+    m = tl.load(M + offs_m)
+    m = m[:, None]
+
+    # Compute dQ for masked (diagonal) blocks.
+    # NOTE: This code scans each row of QK^T backward (from right to left,
+    # but inside each call to _attn_bwd_dq, from left to right), but that's
+    # not due to anything important. I just wanted to reuse the loop
+    # structure for dK & dV above as much as possible.
+    num_steps = BLOCK_M2 // MASK_BLOCK_N2
+    dq = _attn_bwd_dq_ws(
+        dq,
+        q,
+        K,
+        V,  #
+        do,
+        m,
+        D,  #
+        stride_tok,
+        stride_d,  #
+        H,
+        N_CTX,  #
+        BLOCK_M2,
+        MASK_BLOCK_N2,
+        HEAD_DIM,  #
+        start_m,
+        end_n - num_steps * MASK_BLOCK_N2,
+        num_steps,  #
+        MASK=True,  #
+    )
+    end_n -= num_steps * MASK_BLOCK_N2
+    # stage 2
+    num_steps = end_n // BLOCK_N2
+    dq = _attn_bwd_dq_ws(
+        dq,
+        q,
+        K,
+        V,  #
+        do,
+        m,
+        D,  #
+        stride_tok,
+        stride_d,  #
+        H,
+        N_CTX,  #
+        BLOCK_M2,
+        BLOCK_N2,
+        HEAD_DIM,  #
+        start_m,
+        end_n - num_steps * BLOCK_N2,
+        num_steps,  #
+        MASK=False,  #
+    )
+    # Write back dQ.
+    with tl.async_task([1, 2]):
+        dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
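+        # De-scale dq (see the NOTE in _attn_bwd_dq_ws): kT carries an extra
+        # log2(e) factor from the exp2 formulation, so multiply by LN2 = ln(2).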
+        dq *= LN2
+        tl.store(dq_ptrs, dq)
+
+
@triton.autotune(list(filter(keep2, configsBwd)), key=["N_CTX"])
@triton.jit
def _attn_bwd(
@@ -1949,7 +2265,7 @@ def _attn_bwd_ws(
    BLK_SLICE_FACTOR: tl.constexpr,  #
    HEAD_DIM: tl.constexpr,
):
-    _attn_bwd_compute(
+    _attn_bwd_compute_ws(
        Q,
        K,
        V,