@@ -14857,55 +14857,55 @@ static int llama_decode_internal(
         // Re-build graph only if graph caching is not possible
         if(!ggml_use_cached_graph(lctx.sched) || n_has_changed_since_last_token) {
 
-        gf = llama_build_graph(lctx, u_batch, false);
-
-        // Set whether GGML graph caching is in use within GGML module, based on
-        // whether caching was activated here during the previous token
-        ggml_set_cached_graph(lctx.sched,lctx.cached_graph.is_active);
-
-        // Disable future graph caching in presence of env var,
-        // if there are multiple devices, if batch size is greater than 1,
-        // or if nsplits is not 2.
-        // TO DO enable graph caching for these cases
-        bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
-            || (llama_get_device_count(model) > 1)
-            || (ggml_backend_sched_get_n_splits(lctx.sched) != 2);
-        for (int i = 0 ; i < gf->n_nodes; i++) {
-            if (gf->nodes[i]->op == GGML_OP_ADD && gf->nodes[i]->src[1] && gf->nodes[i]->src[1]->ne[1] > 1) {
-                disable_cached_ggml_graph = true;
-                break;
+            gf = llama_build_graph(lctx, u_batch, false);
+
+            // Set whether GGML graph caching is in use within GGML module, based on
+            // whether caching was activated here during the previous token
+            ggml_set_cached_graph(lctx.sched,lctx.cached_graph.is_active);
+
+            // Disable future graph caching in presence of env var,
+            // if there are multiple devices, if batch size is greater than 1,
+            // or if nsplits is not 2.
+            // TO DO enable graph caching for these cases
+            bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
+                || (llama_get_device_count(model) > 1)
+                || (ggml_backend_sched_get_n_splits(lctx.sched) != 2);
+            for (int i = 0 ; i < gf->n_nodes; i++) {
+                if (gf->nodes[i]->op == GGML_OP_ADD && gf->nodes[i]->src[1] && gf->nodes[i]->src[1]->ne[1] > 1) {
+                    disable_cached_ggml_graph = true;
+                    break;
+                }
             }
-        }
 
-        // Set whether graph caching should be used for future tokens
-        lctx.cached_graph.is_active=!disable_cached_ggml_graph;
-
-        // the output is always the last tensor in the graph
-        res = gf->nodes[gf->n_nodes - 1];
-        embd = gf->nodes[gf->n_nodes - 2];
-        if (lctx.n_outputs == 0) {
-            // no output
-            res = nullptr;
-            embd = nullptr;
-        } else if (cparams.embeddings) {
-            res = nullptr; // do not extract logits for embedding case
-            embd = nullptr;
-            for (int i = gf->n_nodes - 1; i >= 0; --i) {
-                if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
-                    embd = gf->nodes[i];
-                    break;
+            // Set whether graph caching should be used for future tokens
+            lctx.cached_graph.is_active=!disable_cached_ggml_graph;
+
+            // the output is always the last tensor in the graph
+            res = gf->nodes[gf->n_nodes - 1];
+            embd = gf->nodes[gf->n_nodes - 2];
+            if (lctx.n_outputs == 0) {
+                // no output
+                res = nullptr;
+                embd = nullptr;
+            } else if (cparams.embeddings) {
+                res = nullptr; // do not extract logits for embedding case
+                embd = nullptr;
+                for (int i = gf->n_nodes - 1; i >= 0; --i) {
+                    if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
+                        embd = gf->nodes[i];
+                        break;
+                    }
                 }
+                GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
+            } else {
+                embd = nullptr; // do not extract embeddings when not needed
+                GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
             }
-            GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
-        } else {
-            embd = nullptr; // do not extract embeddings when not needed
-            GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
-        }
-        lctx.cached_graph.res = res;
-        lctx.cached_graph.embd = embd;
-        // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+            lctx.cached_graph.res = res;
+            lctx.cached_graph.embd = embd;
+            // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
-        ggml_backend_sched_alloc_graph(lctx.sched, gf);
+            ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
         }
         else {
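
This hunk appears to be indentation-only: every removed line reappears one level deeper inside the `if(!ggml_use_cached_graph(...))` block, so the graph-caching logic itself is unchanged. For readers following the PR, below is a minimal standalone C++ sketch of the pattern this gate implements: rebuild the compute graph only when something relevant may have changed since the last token, otherwise reuse the cached one. All names in the sketch (`Graph`, `CachedGraph`, `build_graph`, `get_graph`, `shape_changed`) are hypothetical stand-ins, not ggml or llama.cpp APIs; only the `GGML_DISABLE_GRAPH_CACHING` env var is taken from the diff, and the real code adds further disable conditions (multi-device, split count, batched add nodes) that are collapsed here to keep the example self-contained.

```cpp
// Illustrative sketch of the graph-caching gate in the hunk above.
// NOTE: Graph, CachedGraph, build_graph and get_graph are hypothetical
// stand-ins for this example; they are not ggml or llama.cpp APIs.
#include <cstdio>
#include <cstdlib>
#include <memory>

struct Graph {
    int n_nodes = 0; // stand-in for a real compute graph
};

struct CachedGraph {
    bool is_active = false;       // caching enabled for future tokens?
    std::unique_ptr<Graph> graph; // last built graph, reused while valid
};

static std::unique_ptr<Graph> build_graph(int n_tokens) {
    // Expensive in the real code path; the whole point of the gate
    // is to skip this per-token rebuild whenever possible.
    auto g = std::make_unique<Graph>();
    g->n_nodes = 100 + n_tokens; // pretend node count depends on the batch
    return g;
}

static Graph * get_graph(CachedGraph & cache, int n_tokens, bool shape_changed) {
    // Rebuild only if caching is off or something relevant changed, mirroring
    // if(!ggml_use_cached_graph(lctx.sched) || n_has_changed_since_last_token)
    if (!cache.is_active || shape_changed) {
        cache.graph = build_graph(n_tokens);

        // Disable caching when unsupported conditions hold; the real code
        // also checks device count, split count and batched add nodes.
        bool disable = (std::getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
            || (n_tokens > 1); // stand-in for the batch-size check
        cache.is_active = !disable;
    }
    return cache.graph.get();
}

int main() {
    CachedGraph cache;
    // First token builds the graph; later tokens reuse it if nothing changed.
    for (int token = 0; token < 3; ++token) {
        Graph * g = get_graph(cache, /*n_tokens=*/1, /*shape_changed=*/token == 0);
        std::printf("token %d: graph with %d nodes (cache %s)\n",
                    token, g->n_nodes, cache.is_active ? "active" : "inactive");
    }
    return 0;
}
```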