Dealing with context shifting #1394
Replies: 1 comment
-
I guess I managed to implement this, inspired by code from someone who implemented StreamingLLM with llama-cpp-python:

```python
def kv_cache_seq_ltrim(
    model: llama_cpp.Llama,
    n_keep: int,
    n_discard: int,
):
    """
    Implementation comes from this GitHub repository:
    https://github.com/Limour-dev/llama-python-streamingllm/blob/main/llama_cpp_python_streamingllm.py

    Args:
        n_keep(int): number of first tokens to keep.
        n_discard(int): number of tokens to discard.
    Returns:
        None
    Schema:
                  n_keep(3) n_keep(3)+n_discard(3)
                  |        |
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]  # Initial state.
        [0, 1, 2, -, -, -, 6, 7, 8, 9]  # kv_cache_seq_rm
        [0, 1, 2, 6, 7, 8, 9]           # kv_cache_seq_shift
    """
    n_tokens = model.n_tokens
    model._ctx.kv_cache_seq_rm(-1, n_keep, n_keep + n_discard)
    model._ctx.kv_cache_seq_shift(0, n_keep + n_discard, n_tokens, -n_discard)
    model.input_ids[n_keep:n_tokens - n_discard] = model.input_ids[n_keep + n_discard:n_tokens]
    model.n_tokens = n_tokens - n_discard
```

And this is the usage:

```python
def push_context(item, auto_shift_kv: bool):
    """
    Args:
        item: new item to push.
        auto_shift_kv(bool): if true, shift the KV cache and input_ids if needed.
    Returns:
        None
    """
    if not auto_shift_kv or len(context) < MAXLEN:
        context.push(item)
        return
    # Context shifting.
    oldest_item = context[0]
    n_sys_ppt_tokens = len(tokenize(system_prompt)) + 1  # +1 for bos token.
    n_oldest_ctx_tokens = len(tokenize(item2str(oldest_item)))
    kv_cache_seq_ltrim(
        model=llama_model,
        n_keep=n_sys_ppt_tokens,
        n_discard=n_oldest_ctx_tokens,
    )
    context.push(item)
```

Here the context is shifted by the length of the oldest message, while the system prompt, which is static across every inference, is kept.
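As a follow-up, here is a minimal sketch (my own, not part of the snippet above) of how the trimmed cache is meant to pay off on the next turn. It reuses the hypothetical names from the code above (`llama_model`, `context`, `system_prompt`, `item2str`, `push_context`) and assumes that `Llama.generate()` with its default `reset=True` keeps the longest token prefix already present in the cache, so only the new suffix gets evaluated; that behavior may differ between llama-cpp-python versions.

```python
# Sketch only: chat_step and max_new_tokens are names I made up for illustration.
def chat_step(user_message: str, max_new_tokens: int = 256) -> str:
    push_context(user_message, auto_shift_kv=True)

    # Rebuild the full prompt: static system prompt + surviving history.
    prompt = system_prompt + "".join(item2str(it) for it in context)
    tokens = llama_model.tokenize(prompt.encode("utf-8"), add_bos=True)

    # Because kv_cache_seq_ltrim kept input_ids consistent with the trimmed cache,
    # these tokens should share a long prefix with what is already cached, so only
    # the new suffix has to be evaluated before sampling starts.
    out_tokens = []
    for tok in llama_model.generate(tokens, temp=0.7):
        if tok == llama_model.token_eos():
            break
        out_tokens.append(tok)
        if len(out_tokens) >= max_new_tokens:
            break
    return llama_model.detokenize(out_tokens).decode("utf-8", errors="ignore")
```

One caveat: this only helps if tokenizing each message separately lines up with tokenizing the concatenated prompt; tokens that merge across chunk boundaries can shrink the reusable prefix.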
-
Dear community.
I want to ask how you are dealing with the context shifting problem when using llama-cpp-python.
First, let's clarify what context shifting is.
As you know, the original llama.cpp and llama-cpp-python hold a KV cache of the most recently evaluated inputs so that the attention computation is not redone for the same prefix.
It works as I expected until the chat history hits the maximum context size.
When that happens, the oldest chunk in the history is discarded.
This is context shifting, which causes a prefix mismatch, and hence the KV cache no longer helps.
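For concreteness, here is a rough illustration (my own sketch, not llama-cpp-python's actual code) of why dropping the oldest chunk breaks cache reuse: only the longest common token prefix between the cached prompt and the new prompt can be reused.

```python
def reusable_prefix_len(cached_tokens: list[int], new_tokens: list[int]) -> int:
    """Count the leading tokens whose KV cache entries can be reused as-is."""
    n = 0
    for a, b in zip(cached_tokens, new_tokens):
        if a != b:
            break
        n += 1
    return n

# Cached prompt:                     [system] [msg1] [msg2] [msg3]
# New prompt after dropping [msg1]:  [system] [msg2] [msg3] [msg4]
# Only [system] matches positionally, so nearly everything is re-evaluated,
# even though [msg2] and [msg3] are still in the cache at shifted positions.
```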
It seems Kobold.cpp, for example, has somehow managed to deal with this problem, as mentioned in this thread:
Reddit thread
(This is where I picked up the term "context shifting".)
I'm currently working on a chat application using llama-cpp-python, and prompt eval time can be critical with large models.
So I want to utilize the KV cache to shorten evaluation.
I read the documentation and found that llama-cpp-python provides some KV cache manipulation APIs, but they are barely documented.
These are the llama-cpp-python KV cache functions I'm still not sure how to use.
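For reference, the functions I mean are the low-level wrappers reachable through the internal `Llama._ctx` object; they mirror the `llama_kv_cache_seq_*` C API of llama.cpp. The list below is my own summary, not official documentation, and the exact names are version dependent (newer llama.cpp releases renamed the shift operation to `seq_add`), so please double-check against your installed version.

```python
# Assumed low-level KV cache helpers on Llama._ctx (version dependent):
#   _ctx.kv_cache_seq_rm(seq_id, p0, p1)            # drop cells with positions in [p0, p1); seq_id=-1 means all sequences
#   _ctx.kv_cache_seq_shift(seq_id, p0, p1, delta)  # shift positions in [p0, p1) by delta (kv_cache_seq_add in newer versions)
#   _ctx.kv_cache_seq_cp(seq_id_src, seq_id_dst, p0, p1)  # copy a position range to another sequence
#   _ctx.kv_cache_seq_keep(seq_id)                  # keep only the given sequence, drop the rest
```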
I want to ask how I can tackle this problem.
Thank you.