
Commit 7b51bed

graph : move KV cache build functions to llama_context impl

ggml-ci

1 parent a2644b0 commit 7b51bed

6 files changed: +317, -350 lines

src/llama-context.cpp

Lines changed: 271 additions & 34 deletions
@@ -6,6 +6,7 @@
 #include "llama-model.h"
 #include "llama-kv-cache.h"
 
+#include <cassert>
 #include <cstring>
 #include <stdexcept>
 #include <cinttypes>
@@ -1748,6 +1749,271 @@ const llama_kv_cache * llama_context_kv_self::get_kv_self() const {
     return kv_self.get();
 }
 
+ggml_tensor * llama_context_kv_self::build_rope_shift(
+        ggml_context * ctx0,
+        ggml_tensor * cur,
+        ggml_tensor * shift,
+        ggml_tensor * factors,
+        ggml_backend_buffer * bbuf) const {
+    const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
+    const auto & freq_base  = cparams.rope_freq_base;
+    const auto & freq_scale = cparams.rope_freq_scale;
+
+    const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
+    const auto & yarn_attn_factor = cparams.yarn_attn_factor;
+    const auto & yarn_beta_fast   = cparams.yarn_beta_fast;
+    const auto & yarn_beta_slow   = cparams.yarn_beta_slow;
+
+    const auto & hparams = model.hparams;
+
+    const auto & n_rot     = hparams.n_rot;
+    const auto & rope_type = hparams.rope_type;
+
+    ggml_tensor * tmp;
+
+    if (ggml_is_quantized(cur->type)) {
+        // dequantize to f32 -> RoPE -> quantize back
+        tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32);
+
+        if (bbuf) {
+            for (const auto & backend : backends) {
+                // Figure out which backend KV cache belongs to
+                if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) {
+                    ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
+                    break;
+                }
+            }
+        }
+
+        tmp = ggml_rope_ext_inplace(ctx0, tmp,
+                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
+
+        tmp = ggml_cpy(ctx0, tmp, cur);
+    } else {
+        // we rotate only the first n_rot dimensions
+        tmp = ggml_rope_ext_inplace(ctx0, cur,
+                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
+    }
+
+    return tmp;
+}
+
+class llm_graph_input_k_shift : public llm_graph_input_i {
+public:
+    llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+    virtual ~llm_graph_input_k_shift() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * k_shift; // I32 [kv_size]
+
+    const llama_kv_cache_unified * kv_self;
+};
+
+void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
+
+    if (k_shift) {
+        assert(ggml_backend_buffer_is_host(k_shift->buffer));
+
+        int32_t * data = (int32_t *) k_shift->data;
+
+        for (uint32_t i = 0; i < kv_self->size; ++i) {
+            data[i] = kv_self->cells[i].delta;
+        }
+    }
+}
+
+llm_graph_result_ptr llama_context_kv_self::build_kv_self_shift(
+        ggml_context * ctx0,
+        ggml_cgraph * gf) const {
+    auto res = std::make_unique<llm_graph_result>();
+
+    const auto & hparams = model.hparams;
+
+    const auto & n_layer = hparams.n_layer;
+
+    const auto & n_embd_head_k = hparams.n_embd_head_k;
+    //const auto & n_embd_head_v = hparams.n_embd_head_v;
+
+    //GGML_ASSERT(kv_self->size == n_ctx);
+
+    auto inp = std::make_shared<llm_graph_input_k_shift>(kv_self.get());
+
+    inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
+    ggml_set_input(inp->k_shift);
+
+    res->add_input(inp);
+
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        const int64_t n_head_kv    = hparams.n_head_kv(il);
+        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+
+        ggml_tensor * rope_factors = model.build_rope_factors(n_ctx_per_seq(), il);
+
+        ggml_tensor * k =
+            ggml_view_3d(ctx0, kv_self->k_l[il],
+                n_embd_head_k, n_head_kv, kv_self->size,
+                ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
+                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+                0);
+
+        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, kv_self->k_l[il]->buffer);
+
+        ggml_build_forward_expand(gf, cur);
+    }
+
+    return res;
+}
+
+llm_graph_result_ptr llama_context_kv_self::build_kv_self_defrag(
+        ggml_context * ctx0,
+        ggml_cgraph * gf) const {
+    auto res = std::make_unique<llm_graph_result>();
+
+    const auto & hparams = model.hparams;
+
+    const auto & ids = kv_self->defrag_info.ids;
+
+#if 0
+    // CPU defrag
+    //
+    // TODO: optimizations are possible:
+    //       - multiple threads
+    //       - avoid copying to the host memory when already there
+    //
+    // likely not worth the effort, as we have ggml_graph based defrag
+    //
+
+    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+
+    const uint32_t kv_size = size;
+
+    std::vector<uint8_t> buf_k;
+    std::vector<uint8_t> buf_v;
+
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        const size_t k_size     = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
+
+        const size_t v_size_el = ggml_type_size(v_l[il]->type);
+        const size_t v_size    = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
+
+        buf_k.resize(k_size);
+        buf_v.resize(v_size);
+
+        ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
+        ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
+
+        // batch move [i, i+nm) to [id, id+nm)
+        // note: cells can move only to a lower index
+        for (uint32_t i = 0; i < n_kv; ++i) {
+            const uint32_t id = ids[i];
+
+            if (i == id || id == n_kv) {
+                continue;
+            }
+
+            uint32_t nm = 1;
+
+            while (i + nm < n_kv && ids[i + nm] == id + nm) {
+                nm++;
+            }
+
+            // move keys
+            {
+                const int64_t os =  i*k_size_row;
+                const int64_t od = id*k_size_row;
+
+                memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
+            }
+
+            // move values (note: they are transposed)
+            {
+                const int64_t os =  i;
+                const int64_t od = id;
+
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
+                }
+            }
+
+            i += nm - 1;
+        }
+
+        ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
+        ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
+    }
+#else
+    for (uint32_t i = 0; i < ids.size(); ++i) {
+        const uint32_t id = ids[i];
+
+        if (i == id || id == ids.size()) {
+            continue;
+        }
+
+        uint32_t nm = 1;
+
+        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
+            nm++;
+        }
+
+        for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
+            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+            ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
+                    n_embd_k_gqa, nm,
+                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
+
+            ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
+                    n_embd_k_gqa, nm,
+                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
+
+            ggml_tensor * view_v_src;
+            ggml_tensor * view_v_dst;
+
+            if (cparams.flash_attn) {
+                // NOTE: the V cache is not transposed when using flash attention
+                view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
+                        n_embd_v_gqa, nm,
+                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
+                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
+
+                view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
+                        n_embd_v_gqa, nm,
+                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
+                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
+            } else {
+                view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
+                        nm, n_embd_v_gqa,
+                        ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
+                        ggml_row_size(kv_self->v_l[il]->type, i));
+
+                view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
+                        nm, n_embd_v_gqa,
+                        ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
+                        ggml_row_size(kv_self->v_l[il]->type, id));
+            }
+
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
+        }
+
+        i += nm - 1;
+    }
+
+    //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
+#endif
+
+    return res;
+}
+
 void llama_context_kv_self::kv_self_update() {
     auto & kv = kv_self;
 
@@ -1766,21 +2032,7 @@ void llama_context_kv_self::kv_self_update() {
 
             auto * gf = graph_init();
 
-            auto res = model.build_graph_k_shift(
-                    {
-                        /*.ctx         =*/ ctx_compute.get(),
-                        /*.model       =*/ model,
-                        /*.cparams     =*/ cparams,
-                        /*.ubatch      =*/ {},
-                        /*.sched       =*/ sched.get(),
-                        /*.backend_cpu =*/ backend_cpu,
-                        /*.backends    =*/ backends,
-                        /*.cvec        =*/ nullptr,
-                        /*.loras       =*/ nullptr,
-                        /*.memory      =*/ kv_self.get(),
-                        /*.cross       =*/ nullptr,
-                        /*.n_outputs   =*/ 0,
-                    }, gf);
+            auto res = build_kv_self_shift(ctx_compute.get(), gf);
 
             ggml_backend_sched_alloc_graph(sched.get(), gf);
 
@@ -1809,33 +2061,18 @@ void llama_context_kv_self::kv_self_update() {
 
             auto * gf = graph_init();
 
-            model.build_graph_kv_self_defrag(
-                    {
-                        /*.ctx         =*/ ctx_compute.get(),
-                        /*.model       =*/ model,
-                        /*.cparams     =*/ cparams,
-                        /*.ubatch      =*/ {},
-                        /*.sched       =*/ sched.get(),
-                        /*.backend_cpu =*/ backend_cpu,
-                        /*.backends    =*/ backends,
-                        /*.cvec        =*/ nullptr,
-                        /*.loras       =*/ nullptr,
-                        /*.memory      =*/ nullptr,
-                        /*.cross       =*/ nullptr,
-                        /*.n_outputs   =*/ 0,
-                    }, gf);
+            auto res = build_kv_self_defrag(ctx_compute.get(), gf);
 
             ggml_backend_sched_alloc_graph(sched.get(), gf);
 
-            // no input
-            //input_set({});
+            res->set_inputs(nullptr);
 
             graph_compute(gf, false);
+
+            need_reserve = true;
         }
 
         kv->do_defrag = false;
-
-        need_reserve = true;
     }
 
     // reserve a worst case graph if needed
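
A note on the K-shift path added above: each KV cache cell carries a delta that accumulates whenever its position is shifted, and llm_graph_input_k_shift::set_input simply copies cells[i].delta into the I32 k_shift tensor that the RoPE-shift graph consumes. The sketch below illustrates that bookkeeping with a hypothetical kv_cell struct and helper functions (not the real llama.cpp types); it only shows where the deltas come from and how the input buffer is filled.

#include <cassert>
#include <cstdint>
#include <vector>

// hypothetical stand-in for a KV cache cell: absolute position plus the
// accumulated shift that still has to be applied to the cached K data
struct kv_cell {
    int32_t pos   = -1;
    int32_t delta = 0;
};

// record a shift of all cells whose position lies in [p0, p1) by d positions;
// the cached keys are only rotated later, when the K-shift graph runs
static void shift_cells(std::vector<kv_cell> & cells, int32_t p0, int32_t p1, int32_t d) {
    for (auto & cell : cells) {
        if (cell.pos >= p0 && cell.pos < p1) {
            cell.pos   += d;
            cell.delta += d;
        }
    }
}

// what llm_graph_input_k_shift::set_input does in spirit: one I32 entry per cell
static std::vector<int32_t> make_k_shift(const std::vector<kv_cell> & cells) {
    std::vector<int32_t> k_shift(cells.size());
    for (size_t i = 0; i < cells.size(); ++i) {
        k_shift[i] = cells[i].delta;
    }
    return k_shift;
}

int main() {
    std::vector<kv_cell> cells(8);
    for (int32_t i = 0; i < 8; ++i) {
        cells[i].pos = i;
    }

    // shift the positions in [4, 8) down by 4, e.g. after older tokens were discarded
    shift_cells(cells, 4, 8, -4);

    const std::vector<int32_t> k_shift = make_k_shift(cells);
    assert(k_shift[0] == 0);
    assert(k_shift[7] == -4);

    return 0;
}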

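The defrag graph builder above also merges consecutive cell moves into a single batched copy: whenever ids[i + nm] == ids[i] + nm, the run length nm grows and one ggml_cpy over nm rows replaces nm separate copies. Below is a minimal standalone sketch of just that grouping step; the kv_move struct, the plan_moves helper, and the example ids are illustrative assumptions, not code from the commit.

#include <cstdint>
#include <cstdio>
#include <vector>

// illustrative only: one batched move of `len` consecutive cells from `src` to `dst`
struct kv_move {
    uint32_t src;
    uint32_t dst;
    uint32_t len;
};

// mirrors the grouping in build_kv_self_defrag: ids[i] is the target cell of cell i
static std::vector<kv_move> plan_moves(const std::vector<uint32_t> & ids) {
    std::vector<kv_move> moves;

    for (uint32_t i = 0; i < ids.size(); ++i) {
        const uint32_t id = ids[i];

        // ids[i] == i: cell stays put; ids[i] == ids.size(): cell is not moved
        if (i == id || id == ids.size()) {
            continue;
        }

        // extend the run while the following cells move by the same offset
        uint32_t nm = 1;
        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
            nm++;
        }

        moves.push_back({ i, id, nm });

        i += nm - 1;
    }

    return moves;
}

int main() {
    // hypothetical defrag plan: cells 0..2 stay, cells 3 and 7 are unused, cells 4..6 compact down to 3..5
    const std::vector<uint32_t> ids = { 0, 1, 2, 8, 3, 4, 5, 8 };

    for (const kv_move & m : plan_moves(ids)) {
        std::printf("move cells [%u, %u) -> [%u, %u)\n", m.src, m.src + m.len, m.dst, m.dst + m.len);
    }
    // prints: move cells [4, 7) -> [3, 6)

    return 0;
}
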
src/llama-context.h

Lines changed: 18 additions & 0 deletions

@@ -442,6 +442,24 @@ class llama_context_kv_self : public llama_context_base {
             ggml_cgraph * gf,
             const llama_ubatch & ubatch) override;
 
+    // used by kv_self_update()
+private:
+    ggml_tensor * build_rope_shift(
+            ggml_context * ctx0,
+            ggml_tensor * cur,
+            ggml_tensor * shift,
+            ggml_tensor * factors,
+            ggml_backend_buffer * bbuf) const;
+
+    llm_graph_result_ptr build_kv_self_shift(
+            ggml_context * ctx0,
+            ggml_cgraph * gf) const;
+
+    llm_graph_result_ptr build_kv_self_defrag(
+            ggml_context * ctx0,
+            ggml_cgraph * gf) const;
+
+protected:
     //
     // state save/load
     //
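
These declarations pair with the kv_self_update() changes above: each builder returns an llm_graph_result_ptr that carries the graph inputs it created, and the caller populates them via res->set_inputs(nullptr) before computing the graph. The sketch below models that result/input pattern with simplified stand-in types (mock_input, mock_result, mock_k_shift, and build_shift_graph are hypothetical, not the real llm_graph classes).

#include <cstdint>
#include <cstdio>
#include <memory>
#include <vector>

// simplified stand-in for llm_graph_input_i: something that fills its data
// right before the graph is computed
struct mock_input {
    virtual ~mock_input() = default;
    virtual void set_input() = 0;
};

// simplified stand-in for llm_graph_result: owns the inputs created while building a graph
struct mock_result {
    std::vector<std::shared_ptr<mock_input>> inputs;

    void add_input(std::shared_ptr<mock_input> inp) { inputs.push_back(std::move(inp)); }

    // counterpart of res->set_inputs(nullptr): populate every input before compute
    void set_inputs() {
        for (auto & inp : inputs) {
            inp->set_input();
        }
    }
};

// stand-in for llm_graph_input_k_shift: copies per-cell deltas into the graph input buffer
struct mock_k_shift : mock_input {
    std::vector<int32_t> deltas;  // source (the cache cells)
    std::vector<int32_t> k_shift; // destination (the I32 graph input)

    void set_input() override {
        k_shift = deltas;
    }
};

// stand-in for build_kv_self_shift(): build the graph, attach its inputs, return the result
static std::unique_ptr<mock_result> build_shift_graph(const std::vector<int32_t> & deltas) {
    auto res = std::make_unique<mock_result>();

    auto inp = std::make_shared<mock_k_shift>();
    inp->deltas = deltas;
    res->add_input(inp);

    return res;
}

int main() {
    auto res = build_shift_graph({ 0, 0, -2, -2 });

    // mirrors kv_self_update(): allocate the graph, then set inputs, then compute
    res->set_inputs();

    auto * inp = static_cast<mock_k_shift *>(res->inputs[0].get());
    std::printf("k_shift[2] = %d\n", inp->k_shift[2]); // prints -2
    return 0;
}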

Comments (0)