Skip to content

Commit b4ab8e4

Browse files
committed
WIP: Very temporary attempt at a version of the SSM scan kernel that parallelizes over d_state
This is extremely hacky! It also doesn't have any performance benefits yet Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent ad969d7 commit b4ab8e4

File tree

2 files changed

+305
-2
lines changed

2 files changed

+305
-2
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,7 +1211,8 @@ @implementation GGMLMetalClass
12111211
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
12121212
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
12131213
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
1214-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
1214+
// GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
1215+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group_GHART, true);
12151216
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
12161217
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
12171218
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
@@ -2986,6 +2987,7 @@ static bool ggml_metal_encode_node(
29862987
/*.n_group =*/ n_group,
29872988
/*.n_seq_tokens =*/ n_seq_tokens,
29882989
/*.n_seqs =*/ n_seqs,
2990+
/*.s_off =*/ ggml_nelements(src1) * sizeof(float),
29892991
/*.nb01 =*/ nb01,
29902992
/*.nb02 =*/ nb02,
29912993
/*.nb03 =*/ nb03,
@@ -3016,7 +3018,8 @@ static bool ggml_metal_encode_node(
30163018

30173019
if (ne30 == 1) {
30183020
// Mamba-2
3019-
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
3021+
[encoder setThreadgroupMemoryLength:d_state*sizeof(float) atIndex:0];
3022+
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
30203023
} else {
30213024
GGML_ASSERT(d_inner == 1);
30223025
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1751,6 +1751,306 @@ kernel void kernel_ssm_scan_f32(
17511751
}
17521752
}
17531753

1754+
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
1755+
// WIP--- ghart
1756+
kernel void kernel_ssm_scan_f32_group_GHART(
1757+
device const void * src0,
1758+
device const void * src1,
1759+
device const void * src2,
1760+
device const void * src3,
1761+
device const void * src4,
1762+
device const void * src5,
1763+
device const void * src6,
1764+
device float * dst,
1765+
threadgroup float * shared [[threadgroup(0)]],
1766+
constant ggml_metal_kargs_ssm_scan & args,
1767+
uint3 tgpig[[threadgroup_position_in_grid]],
1768+
uint3 tpitg[[thread_position_in_threadgroup]],
1769+
ushort sgitg[[simdgroup_index_in_threadgroup]],
1770+
ushort tiisg[[thread_index_in_simdgroup]],
1771+
uint3 ntg[[threads_per_threadgroup]]) {
1772+
1773+
const int64_t i1 = tgpig.x;
1774+
const int64_t ir = tgpig.y; // current head
1775+
const int64_t i3 = tgpig.z; // current seq
1776+
1777+
const uint64_t nb00 = sizeof(float);
1778+
const uint64_t nb10 = sizeof(float);
1779+
const uint64_t nb20 = sizeof(float);
1780+
1781+
const int64_t nc = args.d_state;
1782+
const int64_t nr = args.d_inner;
1783+
const int64_t nh = args.n_head;
1784+
const int64_t ng = args.n_group;
1785+
const int64_t n_t = args.n_seq_tokens;
1786+
1787+
const int64_t s_off = args.s_off;
1788+
1789+
device const int32_t * ids = (device const int32_t *) src6;
1790+
1791+
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
1792+
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
1793+
1794+
for (int64_t i2 = 0; i2 < n_t; ++i2) {
1795+
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
1796+
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
1797+
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
1798+
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
1799+
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
1800+
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
1801+
1802+
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
1803+
const float x_dt = x[0] * dt_soft_plus;
1804+
const float dA = exp(dt_soft_plus * A[0]);
1805+
1806+
/*
1807+
1808+
if (sgitg == 0) {
1809+
shared[tiisg] = 0.0f;
1810+
}
1811+
1812+
float sumf = 0;
1813+
1814+
for (int64_t i0 = tpitg.x; i0 < nc; i0 += ntg.x) {
1815+
const int64_t i = i0 + i1*nc;
1816+
const float state = (s0[i] * dA) + (B[i0] * x_dt);
1817+
sumf += state * C[i0];
1818+
s[i] = state;
1819+
}
1820+
1821+
sumf = simd_sum(sumf);
1822+
1823+
threadgroup_barrier(mem_flags::mem_threadgroup);
1824+
1825+
if (sgitg == 0) {
1826+
shared[sgitg] = sumf;
1827+
}
1828+
1829+
threadgroup_barrier(mem_flags::mem_threadgroup);
1830+
1831+
sumf = shared[tiisg];
1832+
sumf = simd_sum(sumf);
1833+
1834+
if (tpitg.x == 0) {
1835+
y[0] = sumf;
1836+
}
1837+
1838+
/*/
1839+
threadgroup_barrier(mem_flags::mem_threadgroup);
1840+
1841+
// if (sgitg == 0) {
1842+
// shared[tiisg] = 0.0f;
1843+
// }
1844+
1845+
// float sumf = 0;
1846+
1847+
// Assuming num threads == d_state
1848+
// for (int64_t i0 = 0; i0 < nc; ++i0) {
1849+
const int64_t i = tpitg.x + i1*nc;
1850+
const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt);
1851+
shared[tpitg.x] = state * C[tpitg.x];
1852+
// sumf += state * C[tpitg.x];
1853+
s[i] = state;
1854+
// }
1855+
1856+
// sumf = simd_sum(sumf);
1857+
1858+
threadgroup_barrier(mem_flags::mem_threadgroup);
1859+
1860+
// sumf = shared[tiisg];
1861+
// sumf = simd_sum(sumf);
1862+
1863+
// GG: vvv this sum is a big bottleneck!
1864+
1865+
float sumf = 0.0f;
1866+
for (int64_t i0 = 0; i0 < nc; ++i0) {
1867+
sumf += shared[i0];
1868+
}
1869+
1870+
y[0] = sumf;
1871+
1872+
//*/
1873+
1874+
// recurse
1875+
s0 = s;
1876+
}
1877+
1878+
//----------------------------------
1879+
1880+
// //DEBUG
1881+
// const int64_t splitH = 16;
1882+
// const int64_t d_state = 128;
1883+
// // const int64_t d_state = args.d_state;
1884+
// const int64_t WARP_SIZE = ntg.x;
1885+
1886+
// const int64_t d_head = args.d_inner;
1887+
// const int64_t n_head = args.n_head;
1888+
// const int64_t n_group = args.n_group;
1889+
// const int64_t n_tok = args.n_seq_tokens;
1890+
1891+
// const int64_t head_idx = (tgpig.x * splitH) / d_head;
1892+
// const int64_t head_off = ((tgpig.x * splitH) % d_head) * sizeof(float);
1893+
// const int64_t seq_idx = tgpig.y;
1894+
1895+
// const int64_t group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
1896+
1897+
// device const int32_t * ids = (device const int32_t *) src6;
1898+
1899+
// device const float * s0_block = (device const float *) ((device const char *) src0 + ids[seq_idx] * args.nb03 + head_idx * args.nb02 + head_off * d_state);
1900+
// device const float * x_block = (device const float *) ((device const char *) src1 + (seq_idx * args.nb13) + tgpig.x * splitH * sizeof(float));
1901+
// device const float * dt_block = (device const float *) ((device const char *) src2 + (seq_idx * args.nb22) + head_idx * sizeof(float));
1902+
// device const float * A_block = (device const float *) ((device const char *) src3 + head_idx * args.nb31);
1903+
// device const float * B_block = (device const float *) ((device const char *) src4 + (seq_idx * args.nb43) + (group_off));
1904+
// device const float * C_block = (device const float *) ((device const char *) src5 + (seq_idx * args.nb53) + (group_off));
1905+
// device float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + tgpig.x * splitH;
1906+
// device float * s_block = (device float *) ((device char *) dst + args.s_off + seq_idx * args.nb03 + head_idx * args.nb02 + head_off * d_state);
1907+
1908+
// // strides across n_seq_tokens
1909+
// const int stride_x = args.nb12 / sizeof(float);
1910+
// const int stride_dt = args.nb21 / sizeof(float);
1911+
// const int stride_B = args.nb42 / sizeof(float);
1912+
// const int stride_C = args.nb52 / sizeof(float);
1913+
// const int stride_y = n_head * d_head;
1914+
1915+
// float state[splitH];
1916+
// // for the parallel accumulation
1917+
1918+
// //DEBUG -- TODO! No parallelism on accumulation
1919+
// float stateC[splitH * d_state];
1920+
1921+
1922+
1923+
1924+
//----------------------------------
1925+
1926+
1927+
1928+
// // #pragma unroll
1929+
// for (int j = 0; j < splitH; j++) {
1930+
// state[j] = s0_block[j * d_state + tpitg.x];
1931+
// }
1932+
1933+
1934+
// for (int64_t i = 0; i < n_tok; i++) {
1935+
// // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
1936+
// // TODO: only calculate B and C once per head group
1937+
// // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
1938+
// float dt_soft_plus = dt_block[i * stride_dt];
1939+
// if (dt_soft_plus <= 20.0f) {
1940+
// dt_soft_plus = log(exp(dt_soft_plus));
1941+
// }
1942+
// const float dA = exp(dt_soft_plus * A_block[0]);
1943+
// const float B = B_block[i * stride_B + tpitg.x];
1944+
// const float C = C_block[i * stride_C + tpitg.x];
1945+
1946+
// // across d_head
1947+
// // #pragma unroll
1948+
// for (int j = 0; j < splitH; j++) {
1949+
// const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;
1950+
1951+
// state[j] = (state[j] * dA) + (B * x_dt);
1952+
1953+
// stateC[j * d_state + tpitg.x] = state[j] * C;
1954+
// }
1955+
1956+
// //DEBUG
1957+
// // __syncthreads();
1958+
1959+
// // parallel accumulation for stateC
1960+
// // TODO: simplify
1961+
// {
1962+
// static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
1963+
// static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");
1964+
1965+
// // reduce until w matches the warp size
1966+
// // TODO: does this work even when the physical warp size is 64?
1967+
// // #pragma unroll
1968+
// for (int w = d_state; w > WARP_SIZE; w >>= 1) {
1969+
// // (assuming there are d_state threads)
1970+
// // #pragma unroll
1971+
// for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
1972+
// // TODO: check for bank conflicts
1973+
// const int k = (tpitg.x % (w >> 1)) + (d_state * (tpitg.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
1974+
// stateC[k] += stateC[k + (w >> 1)];
1975+
1976+
// }
1977+
// //DEBUG
1978+
// // __syncthreads();
1979+
// }
1980+
1981+
// // static_assert(splitH >= d_state / WARP_SIZE);
1982+
1983+
// // #pragma unroll
1984+
// for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
1985+
// float y = stateC[(tpitg.x % WARP_SIZE) + d_state * (tpitg.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
1986+
// //DEBUG
1987+
// // y = warp_reduce_sum(y);
1988+
1989+
// // store the above accumulations
1990+
// if (tpitg.x % WARP_SIZE == 0) {
1991+
// const int k = tpitg.x / WARP_SIZE + j * (d_state / WARP_SIZE);
1992+
// y_block[i * stride_y + k] = y;
1993+
// }
1994+
// }
1995+
// }
1996+
// }
1997+
1998+
// // write back the state
1999+
// // #pragma unroll
2000+
// for (int j = 0; j < splitH; j++) {
2001+
// s_block[j * d_state + tpitg.x] = state[j];
2002+
// }
2003+
2004+
2005+
2006+
// const int64_t i1 = tgpig.x;
2007+
// const int64_t ir = tgpig.y; // current head
2008+
// const int64_t i3 = tgpig.z; // current seq
2009+
2010+
// const uint64_t nb00 = sizeof(float);
2011+
// const uint64_t nb10 = sizeof(float);
2012+
// const uint64_t nb20 = sizeof(float);
2013+
2014+
// const int64_t nc = args.d_state;
2015+
// const int64_t nr = args.d_inner;
2016+
// const int64_t nh = args.n_head;
2017+
// const int64_t ng = args.n_group;
2018+
// const int64_t n_t = args.n_seq_tokens;
2019+
2020+
// const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
2021+
2022+
// device const int32_t * ids = (device const int32_t *) src6;
2023+
2024+
// device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
2025+
// device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
2026+
2027+
// for (int64_t i2 = 0; i2 < n_t; ++i2) {
2028+
// device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
2029+
// device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
2030+
// device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
2031+
// device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
2032+
// device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
2033+
// device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
2034+
2035+
// const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
2036+
// const float x_dt = x[0] * dt_soft_plus;
2037+
// const float dA = exp(dt_soft_plus * A[0]);
2038+
// float sumf = 0.0f;
2039+
2040+
// for (int64_t i0 = 0; i0 < nc; ++i0) {
2041+
// const int64_t i = i0 + i1*nc;
2042+
// const float state = (s0[i] * dA) + (B[i0] * x_dt);
2043+
// sumf += state * C[i0];
2044+
// s[i] = state;
2045+
// }
2046+
2047+
// y[0] = sumf;
2048+
2049+
// // recurse
2050+
// s0 = s;
2051+
// }
2052+
}
2053+
17542054
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
17552055
// TODO: optimize (e.g. by parallelizing over d_state)
17562056
kernel void kernel_ssm_scan_f32_group(

0 commit comments

Comments
 (0)