Commit 5f5c3b7

context : allow cache-less context for embeddings
ggml-ci
1 parent dec80ac commit 5f5c3b7

3 files changed: 15 additions, 4 deletions

examples/embedding/embedding.cpp

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
         }
     } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
         // decoder-only model
-        if (llama_decode(ctx, batch) < 0) {
+        if (llama_encode(ctx, batch) < 0) {
             LOG_ERR("%s : failed to decode\n", __func__);
         }
     }
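With this change the embedding example routes decoder-only models through llama_encode() as well, which is what allows embeddings to be produced on a context that has no KV cache. The sketch below shows that flow end to end. It is an illustration rather than code from the commit: it assumes the current llama.h API (llama_model_load_from_file, llama_init_from_model, llama_get_embeddings_seq, etc.) and mean pooling, and names or defaults may differ between releases.

// Sketch: compute one pooled sentence embedding through the encode path.
// Assumes a GGUF embedding model (e.g. a BERT-family model); error handling trimmed.
#include "llama.h"

#include <cstdio>
#include <string>
#include <vector>

int main(int argc, char ** argv) {
    if (argc < 3) {
        fprintf(stderr, "usage: %s <model.gguf> <text>\n", argv[0]);
        return 1;
    }

    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file(argv[1], mparams);
    if (model == nullptr) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                    // we only want embeddings
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN; // one vector per sequence
    llama_context * ctx = llama_init_from_model(model, cparams);

    // tokenize the input text
    const llama_vocab * vocab = llama_model_get_vocab(model);
    const std::string text = argv[2];
    std::vector<llama_token> tokens(text.size() + 2);
    const int n_tokens = llama_tokenize(vocab, text.c_str(), (int32_t) text.size(),
                                        tokens.data(), (int32_t) tokens.size(),
                                        /*add_special*/ true, /*parse_special*/ false);
    if (n_tokens < 0) {
        fprintf(stderr, "tokenization buffer too small\n");
        return 1;
    }
    tokens.resize(n_tokens);

    // encode - the KV cache is not needed for this
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());
    if (llama_encode(ctx, batch) < 0) {
        fprintf(stderr, "encode failed\n");
        return 1;
    }

    // read back the pooled embedding for sequence 0
    const int     n_embd = llama_model_n_embd(model);
    const float * embd   = llama_get_embeddings_seq(ctx, 0);
    for (int i = 0; i < n_embd && i < 8; ++i) {
        printf("%8.4f ", embd[i]);
    }
    printf("...\n");

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}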

src/llama-context.cpp

Lines changed: 8 additions & 3 deletions
@@ -253,7 +253,7 @@ llama_context::llama_context(
     }
 
     // reserve worst-case graph
-    if (!hparams.vocab_only) {
+    if (!hparams.vocab_only && memory) {
         const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
@@ -760,12 +760,12 @@ int llama_context::encode(llama_batch & inp_batch) {
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
         GGML_ASSERT(backend_embd != nullptr);
 
-        GGML_ASSERT(embd != nullptr);
-
         switch (cparams.pooling_type) {
             case LLAMA_POOLING_TYPE_NONE:
                 {
                     // extract token embeddings
+                    GGML_ASSERT(embd != nullptr);
+
                     GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
                     ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
                 } break;
@@ -837,6 +837,11 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -1;
     }
 
+    if (!memory) {
+        LLAMA_LOG_WARN("%s: cannot decode batches with this context\n", __func__);
+        return -1;
+    }
+
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
     // temporary allocate memory for the input batch if needed
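The llama-context.cpp changes define how a cache-less context behaves: the worst-case graph reservation is skipped when no memory module exists, the token-embedding assert now only applies to the non-pooled path, and llama_context::decode() rejects batches outright when the context has no KV cache. Below is a hypothetical caller-side sketch (not part of the commit; the helper names are made up) of what that contract looks like from the public API.

// Hypothetical helpers illustrating the caller-side contract after this commit.
#include "llama.h"

#include <cstdio>

// Embeddings never need the KV cache, so they always go through the
// encode path - it works on both cache-less and regular contexts.
static bool run_embedding_batch(llama_context * ctx, llama_batch & batch) {
    return llama_encode(ctx, batch) == 0;
}

// Generation needs the KV cache. On a context created without a memory
// module, llama_decode() now logs "cannot decode batches with this context"
// and returns -1, so generation should not be attempted on such a context.
static bool run_generation_batch(llama_context * ctx, llama_batch & batch) {
    const int32_t ret = llama_decode(ctx, batch);
    if (ret < 0) {
        fprintf(stderr, "decode rejected (ret = %d) - is this a cache-less context?\n", ret);
        return false;
    }
    return ret == 0; // ret > 0 means the batch could not be placed, retry with a smaller one
}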

src/llama-model.cpp

Lines changed: 6 additions & 0 deletions
@@ -12790,6 +12790,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
 
     switch (arch) {
+        case LLM_ARCH_BERT:
+        case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_NOMIC_BERT:
+            {
+                res = nullptr;
+            } break;
         case LLM_ARCH_MAMBA:
         case LLM_ARCH_RWKV6:
         case LLM_ARCH_RWKV6QWEN2:
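The llama-model.cpp hunk is where the cache-less behavior originates: for the BERT-family architectures, create_memory() produces no memory module at all, and the constructor and decode guards above handle that case. A simplified standalone sketch of the dispatch pattern follows; the types and enum values are stand-ins for illustration, not the real llama.cpp classes.

// Simplified stand-in for the dispatch in llama_model::create_memory():
// the architecture decides which memory implementation, if any, backs the context.
#include <memory>

enum class arch_t { bert, jina_bert_v2, nomic_bert, mamba, rwkv6, llama };

struct memory_i              { virtual ~memory_i() = default; };
struct kv_cache              : memory_i {}; // regular attention KV cache
struct recurrent_state_cache : memory_i {}; // state for recurrent models (Mamba/RWKV)

static std::unique_ptr<memory_i> create_memory(arch_t arch) {
    switch (arch) {
        case arch_t::bert:
        case arch_t::jina_bert_v2:
        case arch_t::nomic_bert:
            // encode-only embedding models: no memory module -> cache-less context
            return nullptr;
        case arch_t::mamba:
        case arch_t::rwkv6:
            return std::make_unique<recurrent_state_cache>();
        default:
            return std::make_unique<kv_cache>();
    }
}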
