Recently I have been studying your code. However, it seems to me that your implementation does not expand the KV cache during the decoding phase. The following code is excerpted from the function `_concatenate_to_cache` in llama.py:
```python
if query.shape[1] == 1:
    mesh = LLaMAConfig.get_jax_mesh(self.config.mesh_dim)

    def fn(cached_key, cached_value, key, value, cur_index):
        assert key.shape[1] == 1 and value.shape[1] == 1, (key.shape, value.shape)
        sp_size = max_length // mesh.shape['sp']
        axis_index = jax.lax.axis_index('sp')
        cur_index = cur_index - axis_index * sp_size
        key, value = jax.lax.cond(
            jnp.logical_and(cur_index >= 0, cur_index < sp_size),
            lambda: (
                cached_key.at[:, cur_index].set(key[:, -1]),
                cached_value.at[:, cur_index].set(value[:, -1]),
            ),
            lambda: (cached_key, cached_value),
        )
        return key, value
```
In this function, the newly generated key/value for the current decoding step is written into `cached_key` and `cached_value` at the current index, rather than being appended to them, so the cache tensors keep a fixed size. However, it seems to me that a correct KV-cache implementation should make the cache grow longer as decoding proceeds (see the sketch below for what I mean).
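For concreteness, here is a minimal sketch (not code from this repository; the names `update_fixed_cache`, `update_growing_cache`, `max_length`, and `head_dim` are made up for illustration, and sharding is ignored) contrasting the pre-allocated, in-place update style used in the excerpt above with the concatenation-based, growing cache I had in mind:

```python
import jax.numpy as jnp

max_length, head_dim = 8, 4  # hypothetical sizes for illustration

def update_fixed_cache(cached_key, new_key, cur_index):
    # Write the single new key vector into row `cur_index` of a fixed-size buffer;
    # the buffer shape never changes.
    return cached_key.at[cur_index].set(new_key[0])

def update_growing_cache(cached_key, new_key):
    # Append the new key, so the cache length increases by one each step.
    return jnp.concatenate([cached_key, new_key], axis=0)

fixed = jnp.zeros((max_length, head_dim))
growing = jnp.zeros((0, head_dim))
for step in range(3):
    new_key = jnp.full((1, head_dim), float(step + 1))
    fixed = update_fixed_cache(fixed, new_key, step)
    growing = update_growing_cache(growing, new_key)

print(fixed[:3])   # first three rows of the fixed-size buffer
print(growing)     # the grown cache holds the same values
```

If I understand correctly, the two styles end up holding the same values; the fixed-size buffer just keeps the array shape static, which a jit-compiled decode loop in JAX would require, whereas concatenation changes the shape on every step.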
Maybe I do not fully understand your code, but I am looking forward to your reply.