@@ -1751,6 +1751,306 @@ kernel void kernel_ssm_scan_f32(
 }
 }
 
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
+// WIP --- ghart
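+// One threadgroup per (inner element i1, head ir, sequence i3); the threads of a
+// group cooperate over d_state. For each token of the sequence:
+//   state = state*exp(softplus(dt)*A) + B*(x*softplus(dt)),  y = sum over d_state of (state*C)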
+kernel void kernel_ssm_scan_f32_group_GHART(
+        device const void * src0,
+        device const void * src1,
+        device const void * src2,
+        device const void * src3,
+        device const void * src4,
+        device const void * src5,
+        device const void * src6,
+        device       float * dst,
+        threadgroup float * shared [[threadgroup(0)]],
+        constant ggml_metal_kargs_ssm_scan & args,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        uint3    ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i1 = tgpig.x;
+    const int64_t ir = tgpig.y; // current head
+    const int64_t i3 = tgpig.z; // current seq
+
+    const uint64_t nb00 = sizeof(float);
+    const uint64_t nb10 = sizeof(float);
+    const uint64_t nb20 = sizeof(float);
+
+    const int64_t nc  = args.d_state;
+    const int64_t nr  = args.d_inner;
+    const int64_t nh  = args.n_head;
+    const int64_t ng  = args.n_group;
+    const int64_t n_t = args.n_seq_tokens;
+
+    const int64_t s_off = args.s_off;
+
+    device const int32_t * ids = (device const int32_t *) src6;
+
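+    // s0: input state of this head, selected per sequence through the ids mapping;
+    // s:  updated state, written into dst at byte offset s_off, past the y output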
+    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);
+
+    for (int64_t i2 = 0; i2 < n_t; ++i2) {
+        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
+        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
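+        // each head uses the B/C of its group: ir & (ng - 1) == ir % ng
+        // (n_group is assumed to be a power of 2)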
+        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+
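+        // dt is discretized with softplus; above 20.0f, softplus(x) ~= x, so the
+        // exp is skipped to avoid overflow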
+        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+        const float x_dt = x[0] * dt_soft_plus;
+        const float dA   = exp(dt_soft_plus * A[0]);
+
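+        // NOTE: /* ... /*/ ... // */ below is a comment toggle: as written, the
+        // first branch (simd_sum reduction) is disabled and the second branch runs;
+        // changing the opening /* to //* flips which branch is active.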
+        /*
+
+        if (sgitg == 0) {
+            shared[tiisg] = 0.0f;
+        }
+
+        float sumf = 0;
+
+        for (int64_t i0 = tpitg.x; i0 < nc; i0 += ntg.x) {
+            const int64_t i = i0 + i1*nc;
+            const float state = (s0[i] * dA) + (B[i0] * x_dt);
+            sumf += state * C[i0];
+            s[i] = state;
+        }
+
+        sumf = simd_sum(sumf);
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            shared[sgitg] = sumf;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sumf = shared[tiisg];
+        sumf = simd_sum(sumf);
+
+        if (tpitg.x == 0) {
+            y[0] = sumf;
+        }
+
+        /*/
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // if (sgitg == 0) {
+        //     shared[tiisg] = 0.0f;
+        // }
+
+        // float sumf = 0;
+
+        // Assuming num threads == d_state
+        // for (int64_t i0 = 0; i0 < nc; ++i0) {
+        const int64_t i = tpitg.x + i1*nc;
+        const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt);
+        shared[tpitg.x] = state * C[tpitg.x];
+        // sumf += state * C[tpitg.x];
+        s[i] = state;
+        // }
+
+        // sumf = simd_sum(sumf);
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // sumf = shared[tiisg];
+        // sumf = simd_sum(sumf);
+
+        // GG: vvv this sum is a big bottleneck!
+
+        float sumf = 0.0f;
+        for (int64_t i0 = 0; i0 < nc; ++i0) {
+            sumf += shared[i0];
+        }
+
+        y[0] = sumf;
+
+        // */
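+
+        // Untested sketch of a two-level replacement for the serial sum above
+        // (assumes ntg.x == nc and a simdgroup width of 32):
+        //
+        //   float sumf = simd_sum(shared[tpitg.x]);
+        //   threadgroup_barrier(mem_flags::mem_threadgroup);
+        //   if (tiisg == 0) {
+        //       shared[sgitg] = sumf; // one partial per simdgroup
+        //   }
+        //   threadgroup_barrier(mem_flags::mem_threadgroup);
+        //   sumf = simd_sum(tiisg < ntg.x/32 ? shared[tiisg] : 0.0f);
+        //   if (tpitg.x == 0) {
+        //       y[0] = sumf;
+        //   }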
+
+        // recurse: the state written for this token is the input state of the next one
+        s0 = s;
+    }
+
+    // ----------------------------------
+
+    // //DEBUG
+    // const int64_t splitH = 16;
+    // const int64_t d_state = 128;
+    // // const int64_t d_state = args.d_state;
+    // const int64_t WARP_SIZE = ntg.x;
+
+    // const int64_t d_head  = args.d_inner;
+    // const int64_t n_head  = args.n_head;
+    // const int64_t n_group = args.n_group;
+    // const int64_t n_tok   = args.n_seq_tokens;
+
+    // const int64_t head_idx = (tgpig.x * splitH) / d_head;
+    // const int64_t head_off = ((tgpig.x * splitH) % d_head) * sizeof(float);
+    // const int64_t seq_idx  = tgpig.y;
+
+    // const int64_t group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
+
+    // device const int32_t * ids = (device const int32_t *) src6;
+
+    // device const float * s0_block = (device const float *) ((device const char *) src0 + ids[seq_idx] * args.nb03 + head_idx * args.nb02 + head_off * d_state);
+    // device const float * x_block  = (device const float *) ((device const char *) src1 + (seq_idx * args.nb13) + tgpig.x * splitH * sizeof(float));
+    // device const float * dt_block = (device const float *) ((device const char *) src2 + (seq_idx * args.nb22) + head_idx * sizeof(float));
+    // device const float * A_block  = (device const float *) ((device const char *) src3 + head_idx * args.nb31);
+    // device const float * B_block  = (device const float *) ((device const char *) src4 + (seq_idx * args.nb43) + (group_off));
+    // device const float * C_block  = (device const float *) ((device const char *) src5 + (seq_idx * args.nb53) + (group_off));
+    // device       float * y_block  = dst + (seq_idx * n_tok * n_head * d_head) + tgpig.x * splitH;
+    // device       float * s_block  = (device float *) ((device char *) dst + args.s_off + seq_idx * args.nb03 + head_idx * args.nb02 + head_off * d_state);
+
+    // // strides across n_seq_tokens
+    // const int stride_x  = args.nb12 / sizeof(float);
+    // const int stride_dt = args.nb21 / sizeof(float);
+    // const int stride_B  = args.nb42 / sizeof(float);
+    // const int stride_C  = args.nb52 / sizeof(float);
+    // const int stride_y  = n_head * d_head;
+
+    // float state[splitH];
+    // // for the parallel accumulation
+
+    // //DEBUG -- TODO! No parallelism on accumulation
+    // float stateC[splitH * d_state];
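+    // (NOTE: splitH * d_state floats of thread-private storage is likely far too
+    // much per thread; the CUDA kernel this appears to be ported from (see the
+    // commented-out __syncthreads() calls below) keeps stateC in shared memory,
+    // which would be a threadgroup array here.)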
+
+    // ----------------------------------
+
+    // // #pragma unroll
+    // for (int j = 0; j < splitH; j++) {
+    //     state[j] = s0_block[j * d_state + tpitg.x];
+    // }
+
+    // for (int64_t i = 0; i < n_tok; i++) {
+    //     // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
+    //     // TODO: only calculate B and C once per head group
+    //     // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
+    //     float dt_soft_plus = dt_block[i * stride_dt];
+    //     if (dt_soft_plus <= 20.0f) {
+    //         dt_soft_plus = log(1.0f + exp(dt_soft_plus));
+    //     }
+    //     const float dA = exp(dt_soft_plus * A_block[0]);
+    //     const float B  = B_block[i * stride_B + tpitg.x];
+    //     const float C  = C_block[i * stride_C + tpitg.x];
+
+    //     // across d_head
+    //     // #pragma unroll
+    //     for (int j = 0; j < splitH; j++) {
+    //         const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;
+
+    //         state[j] = (state[j] * dA) + (B * x_dt);
+
+    //         stateC[j * d_state + tpitg.x] = state[j] * C;
+    //     }
+
+    //     //DEBUG
+    //     // __syncthreads();
+
+    //     // parallel accumulation for stateC
+    //     // TODO: simplify
+    //     {
+    //         static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
+    //         static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");
+
+    //         // reduce until w matches the warp size
+    //         // TODO: does this work even when the physical warp size is 64?
+    //         // #pragma unroll
+    //         for (int w = d_state; w > WARP_SIZE; w >>= 1) {
+    //             // (assuming there are d_state threads)
+    //             // #pragma unroll
+    //             for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
+    //                 // TODO: check for bank conflicts
+    //                 const int k = (tpitg.x % (w >> 1)) + (d_state * (tpitg.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
+    //                 stateC[k] += stateC[k + (w >> 1)];
+    //             }
+    //             //DEBUG
+    //             // __syncthreads();
+    //         }
+
+    //         // static_assert(splitH >= d_state / WARP_SIZE);
+
+    //         // #pragma unroll
+    //         for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
+    //             float y = stateC[(tpitg.x % WARP_SIZE) + d_state * (tpitg.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
+    //             //DEBUG
+    //             // y = warp_reduce_sum(y);
+
+    //             // store the above accumulations
+    //             if (tpitg.x % WARP_SIZE == 0) {
+    //                 const int k = tpitg.x / WARP_SIZE + j * (d_state / WARP_SIZE);
+    //                 y_block[i * stride_y + k] = y;
+    //             }
+    //         }
+    //     }
+    // }
+
+    // // write back the state
+    // // #pragma unroll
+    // for (int j = 0; j < splitH; j++) {
+    //     s_block[j * d_state + tpitg.x] = state[j];
+    // }
+
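+    // vvv appears to be the earlier sequential per-element version, kept for reference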
+    // const int64_t i1 = tgpig.x;
+    // const int64_t ir = tgpig.y; // current head
+    // const int64_t i3 = tgpig.z; // current seq
+
+    // const uint64_t nb00 = sizeof(float);
+    // const uint64_t nb10 = sizeof(float);
+    // const uint64_t nb20 = sizeof(float);
+
+    // const int64_t nc  = args.d_state;
+    // const int64_t nr  = args.d_inner;
+    // const int64_t nh  = args.n_head;
+    // const int64_t ng  = args.n_group;
+    // const int64_t n_t = args.n_seq_tokens;
+
+    // const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+
+    // device const int32_t * ids = (device const int32_t *) src6;
+
+    // device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    // device       float * s  = (device       float *) ((device       char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
+
+    // for (int64_t i2 = 0; i2 < n_t; ++i2) {
+    //     device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
+    //     device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
+    //     device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
+    //     device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
+    //     device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
+    //     device       float * y  = (device       float *) ((device       char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+
+    //     const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+    //     const float x_dt = x[0] * dt_soft_plus;
+    //     const float dA = exp(dt_soft_plus * A[0]);
+    //     float sumf = 0.0f;
+
+    //     for (int64_t i0 = 0; i0 < nc; ++i0) {
+    //         const int64_t i = i0 + i1*nc;
+    //         const float state = (s0[i] * dA) + (B[i0] * x_dt);
+    //         sumf += state * C[i0];
+    //         s[i] = state;
+    //     }
+
+    //     y[0] = sumf;
+
+    //     // recurse
+    //     s0 = s;
+    // }
+}
+
 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
 // TODO: optimize (e.g. by parallelizing over d_state)
 kernel void kernel_ssm_scan_f32_group(