
A question on your implementation of decoder phase of llama #79

@wangtianxia-sjtu

Description


Recently I have been studying your code. However, it seems to me that your implementation does not expand the KV cache during the decoding phase. The following code is excerpted from the function `_concatenate_to_cache` in llama.py.

if query.shape[1] == 1:
    mesh = LLaMAConfig.get_jax_mesh(self.config.mesh_dim)
    def fn(cached_key, cached_value, key, value, cur_index):
        assert key.shape[1] == 1 and value.shape[1] == 1, (key.shape, value.shape)
        # Size of the cache shard held by each device along the 'sp' axis.
        sp_size = max_length // mesh.shape['sp']
        axis_index = jax.lax.axis_index('sp')
        # Translate the global decoding position into a local index on this shard.
        cur_index = cur_index - axis_index * sp_size
        # Only the shard that owns cur_index writes the new key/value, in place.
        key, value = jax.lax.cond(
            jnp.logical_and(cur_index >= 0, cur_index < sp_size),
            lambda: (
                cached_key.at[:, cur_index].set(key[:, -1]),
                cached_value.at[:, cur_index].set(value[:, -1]),
            ),
            lambda: (cached_key, cached_value),
        )
        return key, value

In this function, the newly generated key/value is only written into cached_key and cached_value in place during the decoding phase, rather than being appended to them. However, it seems to me that a correct implementation of a KV cache should let the cache grow longer as decoding proceeds.
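For reference, what I had in mind by a growing cache is something like the following sketch (`append_to_cache` is a hypothetical helper I made up for illustration, not a function in your code), where each decoding step concatenates the new key/value along the sequence axis so the cache length increases by one per step:

```python
import jax.numpy as jnp

def append_to_cache(cached_key, cached_value, key, value):
    # Hypothetical growing KV cache: append the new key/value for this
    # decoding step along the sequence axis (axis 1), so the cache
    # shape goes from (batch, t, heads, dim) to (batch, t + 1, heads, dim).
    new_key = jnp.concatenate([cached_key, key], axis=1)
    new_value = jnp.concatenate([cached_value, value], axis=1)
    return new_key, new_value

# Example: a cache holding 4 positions grows to 5 after one step.
cached_key = jnp.zeros((1, 4, 2, 8))
cached_value = jnp.zeros((1, 4, 2, 8))
key = jnp.ones((1, 1, 2, 8))
value = jnp.ones((1, 1, 2, 8))
cached_key, cached_value = append_to_cache(cached_key, cached_value, key, value)
# cached_key.shape is now (1, 5, 2, 8)
```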

Maybe I do not fully understand your code, but I am looking forward to your reply.
