#include "llama-model.h"
#include "llama-kv-cache.h"

+ #include <cassert>
#include <cstring>
#include <stdexcept>
#include <cinttypes>
@@ -1748,6 +1749,271 @@ const llama_kv_cache * llama_context_kv_self::get_kv_self() const {
    return kv_self.get();
}

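+ // helper for the K-shift graph: applies the RoPE position shift to a view of the
+ // K cache; quantized caches are dequantized to F32, rotated, and copied back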
+ ggml_tensor * llama_context_kv_self::build_rope_shift(
+         ggml_context * ctx0,
+         ggml_tensor * cur,
+         ggml_tensor * shift,
+         ggml_tensor * factors,
+         ggml_backend_buffer * bbuf) const {
+     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
+     const auto & freq_base  = cparams.rope_freq_base;
+     const auto & freq_scale = cparams.rope_freq_scale;
+
+     const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
+     const auto & yarn_attn_factor = cparams.yarn_attn_factor;
+     const auto & yarn_beta_fast   = cparams.yarn_beta_fast;
+     const auto & yarn_beta_slow   = cparams.yarn_beta_slow;
+
+     const auto & hparams = model.hparams;
+
+     const auto & n_rot     = hparams.n_rot;
+     const auto & rope_type = hparams.rope_type;
+
+     ggml_tensor * tmp;
+
+     if (ggml_is_quantized(cur->type)) {
+         // dequantize to f32 -> RoPE -> quantize back
+         tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32);
+
+         if (bbuf) {
+             for (const auto & backend : backends) {
+                 // Figure out which backend KV cache belongs to
+                 if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) {
+                     ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
+                     break;
+                 }
+             }
+         }
+
+         tmp = ggml_rope_ext_inplace(ctx0, tmp,
+                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                 yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
+
+         tmp = ggml_cpy(ctx0, tmp, cur);
+     } else {
+         // we rotate only the first n_rot dimensions
+         tmp = ggml_rope_ext_inplace(ctx0, cur,
+                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                 yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
+     }
+
+     return tmp;
+ }
+
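+ // graph input carrying the per-cell position delta of the unified KV cache
+ // (one int32 per cell), read from kv_self->cells[i].delta in set_input()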
+ class llm_graph_input_k_shift : public llm_graph_input_i {
+ public:
+     llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+     virtual ~llm_graph_input_k_shift() = default;
+
+     void set_input(const llama_ubatch * ubatch) override;
+
+     ggml_tensor * k_shift; // I32 [kv_size]
+
+     const llama_kv_cache_unified * kv_self;
+ };
+
+ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
+     GGML_UNUSED(ubatch);
+
+     if (k_shift) {
+         assert(ggml_backend_buffer_is_host(k_shift->buffer));
+
+         int32_t * data = (int32_t *) k_shift->data;
+
+         for (uint32_t i = 0; i < kv_self->size; ++i) {
+             data[i] = kv_self->cells[i].delta;
+         }
+     }
+ }
+
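+ // build a graph that applies the accumulated K-shift to every layer of the KV cache.
+ // the caller is expected to allocate it, set its inputs and compute it, roughly
+ // (mirroring the call sites in kv_self_update() below):
+ //
+ //   auto * gf  = graph_init();
+ //   auto   res = build_kv_self_shift(ctx_compute.get(), gf);
+ //   ggml_backend_sched_alloc_graph(sched.get(), gf);
+ //   res->set_inputs(nullptr);
+ //   graph_compute(gf, false);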
+ llm_graph_result_ptr llama_context_kv_self::build_kv_self_shift(
+         ggml_context * ctx0,
+         ggml_cgraph * gf) const {
+     auto res = std::make_unique<llm_graph_result>();
+
+     const auto & hparams = model.hparams;
+
+     const auto & n_layer = hparams.n_layer;
+
+     const auto & n_embd_head_k = hparams.n_embd_head_k;
+     // const auto & n_embd_head_v = hparams.n_embd_head_v;
+
+     // GGML_ASSERT(kv_self->size == n_ctx);
+
+     auto inp = std::make_shared<llm_graph_input_k_shift>(kv_self.get());
+
+     inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
+     ggml_set_input(inp->k_shift);
+
+     res->add_input(inp);
+
+     for (uint32_t il = 0; il < n_layer; ++il) {
+         const int64_t n_head_kv    = hparams.n_head_kv(il);
+         const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+
+         ggml_tensor * rope_factors = model.build_rope_factors(n_ctx_per_seq(), il);
+
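+         // view the K cache of layer il as [n_embd_head_k, n_head_kv, kv_size] so the
+         // shift is applied in place to every head of every cell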
+         ggml_tensor * k =
+             ggml_view_3d(ctx0, kv_self->k_l[il],
+                 n_embd_head_k, n_head_kv, kv_self->size,
+                 ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
+                 ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+                 0);
+
+         ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, kv_self->k_l[il]->buffer);
+
+         ggml_build_forward_expand(gf, cur);
+     }
+
+     return res;
+ }
+
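+ // build a graph that defragments the KV cache by copying ranges of cells towards
+ // lower indices, following the moves recorded in kv_self->defrag_info.ids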
+ llm_graph_result_ptr llama_context_kv_self::build_kv_self_defrag(
+         ggml_context * ctx0,
+         ggml_cgraph * gf) const {
+     auto res = std::make_unique<llm_graph_result>();
+
+     const auto & hparams = model.hparams;
+
+     const auto & ids = kv_self->defrag_info.ids;
+
+ #if 0
+     // CPU defrag
+     //
+     // TODO: optimizations are possible:
+     //       - multiple threads
+     //       - avoid copying to the host memory when already there
+     //
+     // likely not worth the effort, as we have ggml_graph based defrag
+     //
+
+     const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+     const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+
+     const uint32_t kv_size = size;
+
+     std::vector<uint8_t> buf_k;
+     std::vector<uint8_t> buf_v;
+
+     for (uint32_t il = 0; il < n_layer; ++il) {
+         const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+         const size_t k_size     = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
+
+         const size_t v_size_el = ggml_type_size(v_l[il]->type);
+         const size_t v_size    = ggml_row_size(v_l[il]->type, n_embd_v_gqa*kv_size);
+
+         buf_k.resize(k_size);
+         buf_v.resize(v_size);
+
+         ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
+         ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
+
+         // batch move [i, i+nm) to [id, id+nm)
+         // note: cells can move only to a lower index
+         for (uint32_t i = 0; i < n_kv; ++i) {
+             const uint32_t id = ids[i];
+
+             if (i == id || id == n_kv) {
+                 continue;
+             }
+
+             uint32_t nm = 1;
+
+             while (i + nm < n_kv && ids[i + nm] == id + nm) {
+                 nm++;
+             }
+
+             // move keys
+             {
+                 const int64_t os =  i*k_size_row;
+                 const int64_t od = id*k_size_row;
+
+                 memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
+             }
+
+             // move values (note: they are transposed)
+             {
+                 const int64_t os =  i;
+                 const int64_t od = id;
+
+                 for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                     memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
+                 }
+             }
+
+             i += nm - 1;
+         }
+
+         ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
+         ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
+     }
+ #else
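+     // graph-based defrag: batch contiguous moves [i, i+nm) -> [id, id+nm) and emit
+     // one ggml_cpy per layer for the K and V views of each batch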
+     for (uint32_t i = 0; i < ids.size(); ++i) {
+         const uint32_t id = ids[i];
+
+         if (i == id || id == ids.size()) {
+             continue;
+         }
+
+         uint32_t nm = 1;
+
+         while (i + nm < ids.size() && ids[i + nm] == id + nm) {
+             nm++;
+         }
+
+         for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
+             const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+             const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+             ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
+                     n_embd_k_gqa, nm,
+                     ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+                     ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
+
+             ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
+                     n_embd_k_gqa, nm,
+                     ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+                     ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
+
+             ggml_tensor * view_v_src;
+             ggml_tensor * view_v_dst;
+
+             if (cparams.flash_attn) {
+                 // NOTE: the V cache is not transposed when using flash attention
+                 view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
+                         n_embd_v_gqa, nm,
+                         ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
+                         ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
+
+                 view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
+                         n_embd_v_gqa, nm,
+                         ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
+                         ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
+             } else {
+                 view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
+                         nm, n_embd_v_gqa,
+                         ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
+                         ggml_row_size(kv_self->v_l[il]->type, i));
+
+                 view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
+                         nm, n_embd_v_gqa,
+                         ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
+                         ggml_row_size(kv_self->v_l[il]->type, id));
+             }
+
+             ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
+             ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
+         }
+
+         i += nm - 1;
+     }
+
+     // LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
+ #endif
+
+     return res;
+ }
+
void llama_context_kv_self::kv_self_update() {
    auto & kv = kv_self;

@@ -1766,21 +2032,7 @@ void llama_context_kv_self::kv_self_update() {

            auto * gf = graph_init();

-             auto res = model.build_graph_k_shift(
-                     {
-                         /*.ctx         =*/ ctx_compute.get(),
-                         /*.model       =*/ model,
-                         /*.cparams     =*/ cparams,
-                         /*.ubatch      =*/ {},
-                         /*.sched       =*/ sched.get(),
-                         /*.backend_cpu =*/ backend_cpu,
-                         /*.backends    =*/ backends,
-                         /*.cvec        =*/ nullptr,
-                         /*.loras       =*/ nullptr,
-                         /*.memory      =*/ kv_self.get(),
-                         /*.cross       =*/ nullptr,
-                         /*.n_outputs   =*/ 0,
-                     }, gf);
+             auto res = build_kv_self_shift(ctx_compute.get(), gf);

            ggml_backend_sched_alloc_graph(sched.get(), gf);

@@ -1809,33 +2061,18 @@ void llama_context_kv_self::kv_self_update() {

            auto * gf = graph_init();

-             model.build_graph_kv_self_defrag(
-                     {
-                         /*.ctx         =*/ ctx_compute.get(),
-                         /*.model       =*/ model,
-                         /*.cparams     =*/ cparams,
-                         /*.ubatch      =*/ {},
-                         /*.sched       =*/ sched.get(),
-                         /*.backend_cpu =*/ backend_cpu,
-                         /*.backends    =*/ backends,
-                         /*.cvec        =*/ nullptr,
-                         /*.loras       =*/ nullptr,
-                         /*.memory      =*/ nullptr,
-                         /*.cross       =*/ nullptr,
-                         /*.n_outputs   =*/ 0,
-                     }, gf);
+             auto res = build_kv_self_defrag(ctx_compute.get(), gf);

            ggml_backend_sched_alloc_graph(sched.get(), gf);

-             // no input
-             // input_set({});
+             res->set_inputs(nullptr);

            graph_compute(gf, false);
+
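+             // note: the layout of the KV cache changed, so a new worst-case graph
+             // reservation is triggered below via need_reserve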
+             need_reserve = true;
        }

        kv->do_defrag = false;
-
-         need_reserve = true;
    }

    // reserve a worst case graph if needed