@@ -1131,17 +1131,18 @@ llm_graph_result_ptr llama_context_base::graph_build(
1131
1131
return model.build_graph (
1132
1132
{
1133
1133
/* .ctx =*/ ctx,
1134
- /* .model =*/ model,
1134
+ /* .arch =*/ model.arch ,
1135
+ /* .hparams =*/ model.hparams ,
1135
1136
/* .cparams =*/ cparams,
1136
1137
/* .ubatch =*/ ubatch,
1137
1138
/* .sched =*/ sched.get (),
1138
1139
/* .backend_cpu =*/ backend_cpu,
1139
- /* .backends =*/ backends,
1140
1140
/* .cvec =*/ &cvec,
1141
1141
/* .loras =*/ &loras,
1142
1142
/* .memory =*/ nullptr ,
1143
1143
/* .cross =*/ nullptr ,
1144
1144
/* .n_outputs =*/ n_outputs,
1145
+ /* .cb =*/ graph_get_cb (),
1145
1146
}, gf, gtype);
1146
1147
}
1147
1148
@@ -1172,6 +1173,39 @@ enum ggml_status llama_context_base::graph_compute(
1172
1173
return status;
1173
1174
}
1174
1175
1176
+ llm_graph_cb llama_context_base::graph_get_cb () const {
1177
+ return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
1178
+ if (il >= 0 ) {
1179
+ ggml_format_name (cur, " %s-%d" , name, il);
1180
+ } else {
1181
+ ggml_set_name (cur, name);
1182
+ }
1183
+
1184
+ if (!cparams.offload_kqv ) {
1185
+ if (strcmp (name, " kqv_merged_cont" ) == 0 ) {
1186
+ // all nodes between the KV store and the attention output are run on the CPU
1187
+ ggml_backend_sched_set_tensor_backend (sched.get (), cur, backend_cpu);
1188
+ }
1189
+ }
1190
+
1191
+ // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
1192
+ // FIXME: fix in ggml_backend_sched
1193
+ const bool full_offload = model.params .n_gpu_layers > (int ) model.hparams .n_layer ;
1194
+ if (ubatch.n_tokens < 32 || full_offload) {
1195
+ if (il != -1 && strcmp (name, " norm" ) == 0 ) {
1196
+ const auto & dev_layer = model.dev_layer (il);
1197
+ for (const auto & backend : backends) {
1198
+ if (ggml_backend_get_device (backend.get ()) == dev_layer) {
1199
+ if (ggml_backend_supports_op (backend.get (), cur)) {
1200
+ ggml_backend_sched_set_tensor_backend (sched.get (), cur, backend.get ());
1201
+ }
1202
+ }
1203
+ }
1204
+ }
1205
+ }
1206
+ };
1207
+ }
1208
+
1175
1209
//
1176
1210
// perf
1177
1211
//
@@ -2567,17 +2601,18 @@ llm_graph_result_ptr llama_context_kv_self::graph_build(
2567
2601
return model.build_graph (
2568
2602
{
2569
2603
/* .ctx =*/ ctx,
2570
- /* .model =*/ model,
2604
+ /* .arch =*/ model.arch ,
2605
+ /* .hparams =*/ model.hparams ,
2571
2606
/* .cparams =*/ cparams,
2572
2607
/* .ubatch =*/ ubatch,
2573
2608
/* .sched =*/ sched.get (),
2574
2609
/* .backend_cpu =*/ backend_cpu,
2575
- /* .backends =*/ backends,
2576
2610
/* .cvec =*/ &cvec,
2577
2611
/* .loras =*/ &loras,
2578
2612
/* .memory =*/ kv_self.get (),
2579
2613
/* .cross =*/ nullptr ,
2580
2614
/* .n_outputs =*/ n_outputs,
2615
+ /* .cb =*/ graph_get_cb (),
2581
2616
}, gf, gtype);
2582
2617
}
2583
2618
@@ -3010,17 +3045,18 @@ llm_graph_result_ptr llama_context_recurrent::graph_build(
3010
3045
return model.build_graph (
3011
3046
{
3012
3047
/* .ctx =*/ ctx,
3013
- /* .model =*/ model,
3048
+ /* .arch =*/ model.arch ,
3049
+ /* .hparams =*/ model.hparams ,
3014
3050
/* .cparams =*/ cparams,
3015
3051
/* .ubatch =*/ ubatch,
3016
3052
/* .sched =*/ sched.get (),
3017
3053
/* .backend_cpu =*/ backend_cpu,
3018
- /* .backends =*/ backends,
3019
3054
/* .cvec =*/ &cvec,
3020
3055
/* .loras =*/ &loras,
3021
3056
/* .memory =*/ kv_self.get (),
3022
3057
/* .cross =*/ nullptr ,
3023
3058
/* .n_outputs =*/ n_outputs,
3059
+ /* .cb =*/ graph_get_cb (),
3024
3060
}, gf, gtype);
3025
3061
}
3026
3062
@@ -3227,17 +3263,18 @@ llm_graph_result_ptr llama_context_dec::graph_build(
3227
3263
return model.build_graph (
3228
3264
{
3229
3265
/* .ctx =*/ ctx,
3230
- /* .model =*/ model,
3266
+ /* .arch =*/ model.arch ,
3267
+ /* .hparams =*/ model.hparams ,
3231
3268
/* .cparams =*/ cparams,
3232
3269
/* .ubatch =*/ ubatch,
3233
3270
/* .sched =*/ sched.get (),
3234
3271
/* .backend_cpu =*/ backend_cpu,
3235
- /* .backends =*/ backends,
3236
3272
/* .cvec =*/ &cvec,
3237
3273
/* .loras =*/ &loras,
3238
3274
/* .memory =*/ kv_self.get (),
3239
3275
/* .cross =*/ cross,
3240
3276
/* .n_outputs =*/ n_outputs,
3277
+ /* .cb =*/ graph_get_cb (),
3241
3278
}, gf, gtype);
3242
3279
}
3243
3280
0 commit comments