
Commit 10ea682

graph : remove llama_model reference
ggml-ci
1 parent a9d5096 commit 10ea682

File tree

5 files changed (+506, −508 lines)


src/llama-context.cpp

Lines changed: 45 additions & 8 deletions
@@ -1131,17 +1131,18 @@ llm_graph_result_ptr llama_context_base::graph_build(
     return model.build_graph(
         {
             /*.ctx         =*/ ctx,
-            /*.model       =*/ model,
+            /*.arch        =*/ model.arch,
+            /*.hparams     =*/ model.hparams,
             /*.cparams     =*/ cparams,
             /*.ubatch      =*/ ubatch,
             /*.sched       =*/ sched.get(),
             /*.backend_cpu =*/ backend_cpu,
-            /*.backends    =*/ backends,
             /*.cvec        =*/ &cvec,
             /*.loras       =*/ &loras,
             /*.memory      =*/ nullptr,
             /*.cross       =*/ nullptr,
             /*.n_outputs   =*/ n_outputs,
+            /*.cb          =*/ graph_get_cb(),
         }, gf, gtype);
 }
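The initializer list above outlines the reworked parameter struct: the llama_model reference and the backends vector are gone, replaced by the model's arch and hparams plus a per-tensor callback. Below is a hypothetical sketch of llm_graph_params after this commit, inferred purely from the designated initializers; the real definition lives in one of the changed files not shown in this excerpt, and the field types are assumptions.

// Hypothetical sketch, not the verbatim definition: field names and order
// follow the /*.field =*/ initializers above, types are inferred.
struct llm_graph_params {
    ggml_context * ctx;

    const llm_arch        arch;    // previously: const llama_model & model
    const llama_hparams & hparams;

    const llama_cparams & cparams;
    const llama_ubatch  & ubatch;

    ggml_backend_sched_t sched;
    ggml_backend_t       backend_cpu; // the backends vector is no longer part of the params

    const llama_adapter_cvec  * cvec;
    const llama_adapter_loras * loras;
    const llama_memory_i      * memory;
    const llama_cross         * cross;

    int32_t n_outputs;

    const llm_graph_cb & cb; // tensor callback, produced by graph_get_cb() below
};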

@@ -1172,6 +1173,39 @@ enum ggml_status llama_context_base::graph_compute(
     return status;
 }
 
+llm_graph_cb llama_context_base::graph_get_cb() const {
+    return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
+        if (il >= 0) {
+            ggml_format_name(cur, "%s-%d", name, il);
+        } else {
+            ggml_set_name(cur, name);
+        }
+
+        if (!cparams.offload_kqv) {
+            if (strcmp(name, "kqv_merged_cont") == 0) {
+                // all nodes between the KV store and the attention output are run on the CPU
+                ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
+            }
+        }
+
+        // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
+        // FIXME: fix in ggml_backend_sched
+        const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer;
+        if (ubatch.n_tokens < 32 || full_offload) {
+            if (il != -1 && strcmp(name, "norm") == 0) {
+                const auto & dev_layer = model.dev_layer(il);
+                for (const auto & backend : backends) {
+                    if (ggml_backend_get_device(backend.get()) == dev_layer) {
+                        if (ggml_backend_supports_op(backend.get(), cur)) {
+                            ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get());
+                        }
+                    }
+                }
+            }
+        }
+    };
+}
+
 //
 // perf
 //
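From the lambda returned above, llm_graph_cb must be invocable with (ubatch, tensor, name, layer index). A self-contained sketch of the presumed alias follows; the actual definition sits in a header not shown in this excerpt.

#include <functional>

struct ggml_tensor;   // defined in ggml.h
struct llama_ubatch;  // defined in llama.cpp's batch code

// Presumed alias, inferred from the signature of the lambda returned by
// graph_get_cb(); il is the layer index, or -1 for tensors outside any layer.
using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;

The graph builder presumably calls this on each node it creates; that is how every context variant gets consistent tensor names plus the two offload exceptions above (kqv_merged_cont pinned to the CPU backend, norm pinned to its layer's device) without build_graph needing the context object itself.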
@@ -2567,17 +2601,18 @@ llm_graph_result_ptr llama_context_kv_self::graph_build(
     return model.build_graph(
         {
             /*.ctx         =*/ ctx,
-            /*.model       =*/ model,
+            /*.arch        =*/ model.arch,
+            /*.hparams     =*/ model.hparams,
             /*.cparams     =*/ cparams,
             /*.ubatch      =*/ ubatch,
             /*.sched       =*/ sched.get(),
             /*.backend_cpu =*/ backend_cpu,
-            /*.backends    =*/ backends,
             /*.cvec        =*/ &cvec,
             /*.loras       =*/ &loras,
             /*.memory      =*/ kv_self.get(),
             /*.cross       =*/ nullptr,
             /*.n_outputs   =*/ n_outputs,
+            /*.cb          =*/ graph_get_cb(),
         }, gf, gtype);
 }

@@ -3010,17 +3045,18 @@ llm_graph_result_ptr llama_context_recurrent::graph_build(
     return model.build_graph(
         {
             /*.ctx         =*/ ctx,
-            /*.model       =*/ model,
+            /*.arch        =*/ model.arch,
+            /*.hparams     =*/ model.hparams,
             /*.cparams     =*/ cparams,
             /*.ubatch      =*/ ubatch,
             /*.sched       =*/ sched.get(),
             /*.backend_cpu =*/ backend_cpu,
-            /*.backends    =*/ backends,
             /*.cvec        =*/ &cvec,
             /*.loras       =*/ &loras,
             /*.memory      =*/ kv_self.get(),
             /*.cross       =*/ nullptr,
             /*.n_outputs   =*/ n_outputs,
+            /*.cb          =*/ graph_get_cb(),
         }, gf, gtype);
 }

@@ -3227,17 +3263,18 @@ llm_graph_result_ptr llama_context_dec::graph_build(
     return model.build_graph(
         {
             /*.ctx         =*/ ctx,
-            /*.model       =*/ model,
+            /*.arch        =*/ model.arch,
+            /*.hparams     =*/ model.hparams,
             /*.cparams     =*/ cparams,
             /*.ubatch      =*/ ubatch,
             /*.sched       =*/ sched.get(),
             /*.backend_cpu =*/ backend_cpu,
-            /*.backends    =*/ backends,
             /*.cvec        =*/ &cvec,
             /*.loras       =*/ &loras,
             /*.memory      =*/ kv_self.get(),
             /*.cross       =*/ cross,
             /*.n_outputs   =*/ n_outputs,
+            /*.cb          =*/ graph_get_cb(),
         }, gf, gtype);
 }

src/llama-context.h

Lines changed: 2 additions & 0 deletions
@@ -273,6 +273,8 @@ class llama_context_base : public llama_context {
             ggml_cgraph * gf,
             bool batched);
 
+    llm_graph_cb graph_get_cb() const;
+
     ggml_context_ptr ctx_compute;
 
 public:
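Declaring graph_get_cb() once in llama_context_base is what lets the four graph_build overloads above (base, kv_self, recurrent, dec) share one callback while each still passes its own memory and cross pointers. A toy, self-contained illustration of that pattern, with invented names rather than llama.cpp API:

// Toy illustration (not llama.cpp API): the base context builds one callback
// that captures its own state, and every derived context reuses it when
// assembling its graph parameters.
#include <cstdio>
#include <functional>

struct context_base {
    using graph_cb = std::function<void(const char * name, int il)>;

    graph_cb get_cb() const {
        // captures `this`, so the callback sees the live context state
        return [this](const char * name, int il) {
            std::printf("tensor %s-%d (offload_kqv=%s)\n", name, il, offload_kqv ? "true" : "false");
        };
    }

    bool offload_kqv = true;
};

struct context_kv_self : context_base {
    void build() { get_cb()("norm", 0); } // derived contexts share the base callback
};

int main() {
    context_kv_self ctx;
    ctx.build(); // prints: tensor norm-0 (offload_kqv=true)
    return 0;
}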
