Commit d781bed

Fix Gemma3n not executed as CUDA_GRAPH on NVGPUs ggml-org#14741
Author: Olivier Simons
1 parent: 5c75115

File tree

1 file changed (+13, -6 lines)

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 13 additions & 6 deletions
@@ -2849,6 +2849,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 
     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
     cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
+    std::uint8_t batch_size_counter = 0;
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
@@ -2872,12 +2873,18 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
         }
 
         if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
-            // disable CUDA graphs for batch size > 1 for now.
-            // Changes in batch size or context size can cause changes to the grid size of some kernels.
-            use_cuda_graph = false;
-#ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
-#endif
+            // disable CUDA graphs for batch size > 1 for now. The heuristic here allows CUDA graphs
+            // to be used for Gemma3n, which uses a single matrix-matrix addition as part of `project_per_layer_input`,
+            // while still detecting batched execution for all graphs with >1 GGML_OP_ADD nodes. See also
+            // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773
+            // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
+            ++batch_size_counter;
+            if (batch_size_counter > 1) {
+                use_cuda_graph = false;
+#ifndef NDEBUG
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to repeated batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+            }
         }
 
         if (node->op == GGML_OP_MULTI_ADD && node->ne[1] > 1) {
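
Read on its own, the change tolerates exactly one batched GGML_OP_ADD per graph (Gemma3n's single matrix-matrix addition in `project_per_layer_input`) and only falls back to non-graph execution once a second one is seen. A minimal standalone sketch of that heuristic follows; `Node`, `is_batched_add` and `eligible_for_cuda_graph` are illustrative stand-ins, not part of the ggml API.

#include <cstdint>
#include <vector>

// Illustrative stand-in for a graph node; not the real ggml type.
struct Node {
    bool is_batched_add;   // a GGML_OP_ADD whose second source has ne[1] > 1
};

// True if the graph stays eligible for CUDA graph capture under the
// "at most one batched ADD" heuristic used by this commit.
static bool eligible_for_cuda_graph(const std::vector<Node> & nodes) {
    std::uint8_t batch_size_counter = 0;
    for (const Node & node : nodes) {
        if (node.is_batched_add) {
            // One batched ADD (e.g. Gemma3n's project_per_layer_input) is
            // tolerated; a second one indicates genuinely batched execution.
            if (++batch_size_counter > 1) {
                return false;
            }
        }
    }
    return true;
}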
