Commit 9770efa

context : allow cache-less context for embeddings
ggml-ci
1 parent 1d36b36

File tree

3 files changed: +15 -4 lines changed

examples/embedding/embedding.cpp

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
         }
     } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
         // decoder-only model
-        if (llama_decode(ctx, batch) < 0) {
+        if (llama_encode(ctx, batch) < 0) {
             LOG_ERR("%s : failed to decode\n", __func__);
         }
     }
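
Note: once the batch has gone through llama_encode, the pooled result is read back through the usual embedding getters, as the surrounding example code does. A minimal sketch of that retrieval, assuming a pooling type other than NONE; the helper name and logging are illustrative, not part of this commit, and the getter names follow recent llama.h:

    #include "llama.h"
    #include <cstdio>

    // sketch: read back the pooled embedding for one sequence after llama_encode
    static const float * get_pooled_embd(llama_context * ctx, llama_seq_id seq_id) {
        const int n_embd = llama_model_n_embd(llama_get_model(ctx));

        // with pooling enabled, this is a vector of n_embd floats for the sequence
        const float * embd = llama_get_embeddings_seq(ctx, seq_id);
        if (embd == nullptr) {
            fprintf(stderr, "no embedding for sequence %d (n_embd = %d)\n", seq_id, n_embd);
        }
        return embd;
    }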

src/llama-context.cpp

Lines changed: 8 additions & 3 deletions
@@ -253,7 +253,7 @@ llama_context::llama_context(
     }
 
     // reserve worst-case graph
-    if (!hparams.vocab_only) {
+    if (!hparams.vocab_only && memory) {
         const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
@@ -763,12 +763,12 @@ int llama_context::encode(llama_batch & inp_batch) {
     ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
     GGML_ASSERT(backend_embd != nullptr);
 
-    GGML_ASSERT(embd != nullptr);
-
     switch (cparams.pooling_type) {
         case LLAMA_POOLING_TYPE_NONE:
             {
                 // extract token embeddings
+                GGML_ASSERT(embd != nullptr);
+
                 GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
                 ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
             } break;
@@ -840,6 +840,11 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -1;
     }
 
+    if (!memory) {
+        LLAMA_LOG_WARN("%s: cannot decode batches with this context\n", __func__);
+        return -1;
+    }
+
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
     // temporary allocate memory for the input batch if needed
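
The new check at the top of llama_context::decode means a context whose memory module is empty (no KV cache was created) refuses to decode and returns -1, so such batches have to go through llama_encode instead. A small caller-side sketch of that behaviour, assuming llama.h is included and the usual return convention (0 on success, negative on failure); the helper is illustrative, not part of this commit:

    #include "llama.h"

    // sketch: process a batch on a context that may or may not have a KV cache
    static int process_batch(llama_context * ctx, llama_batch & batch) {
        const int ret = llama_decode(ctx, batch);
        if (ret == 0) {
            return 0;
        }
        // if decode is rejected (e.g. the context was created without a KV cache),
        // fall back to llama_encode, which does not need one
        return llama_encode(ctx, batch);
    }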

src/llama-model.cpp

Lines changed: 6 additions & 0 deletions
@@ -12833,6 +12833,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;
 
     switch (arch) {
+        case LLM_ARCH_BERT:
+        case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_NOMIC_BERT:
+            {
+                res = nullptr;
+            } break;
         case LLM_ARCH_MAMBA:
         case LLM_ARCH_RWKV6:
         case LLM_ARCH_RWKV6QWEN2:
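
Returning nullptr from create_memory for the BERT-family architectures is what makes their contexts cache-less: the context's memory member stays empty, which skips the worst-case graph reservation and trips the new decode guard above. A sketch of creating such an embeddings-only context from the public API, assuming an already loaded model handle; the parameter values are illustrative and the init function name follows recent llama.h:

    #include "llama.h"

    // sketch: embeddings-only context for a BERT-family model; after this commit
    // no KV cache is allocated for these architectures
    static llama_context * make_embd_ctx(llama_model * model) {
        llama_context_params cparams = llama_context_default_params();
        cparams.embeddings   = true;                    // return embeddings instead of logits
        cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN; // pooled per-sequence embeddings

        return llama_init_from_model(model, cparams);
    }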
