@@ -2849,6 +2849,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
2849
2849
2850
2850
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
2851
2851
cuda_ctx->cuda_graph ->cpy_dest_ptrs .clear ();
2852
+ std::uint8_t batch_size_counter = 0 ;
2852
2853
2853
2854
for (int i = 0 ; i < cgraph->n_nodes ; i++) {
2854
2855
ggml_tensor * node = cgraph->nodes [i];
@@ -2872,12 +2873,18 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
2872
2873
}
2873
2874
2874
2875
if (node->op == GGML_OP_ADD && node->src [1 ] && node->src [1 ]->ne [1 ] > 1 ) {
2875
- // disable CUDA graphs for batch size > 1 for now.
2876
- // Changes in batch size or context size can cause changes to the grid size of some kernels.
2877
- use_cuda_graph = false ;
2878
- #ifndef NDEBUG
2879
- GGML_LOG_DEBUG (" %s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n " , __func__, node->name , node->ne [0 ], node->ne [1 ], node->ne [2 ], node->ne [3 ]);
2880
- #endif
2876
+ // disable CUDA graphs for batch size > 1 for now. The heuristic here allows to use CUDA graphs
2877
+ // for Gemma3n, which uses a single Matrix-Matrix Addition as part of `project_per_layer_input`, while detecting
2878
+ // batched execution for all graphs with >1 GGML_OP_ADD nodes. See also
2879
+ // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
2880
+ // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
2881
+ ++batch_size_counter;
2882
+ if (batch_size_counter > 1 ) {
2883
+ use_cuda_graph = false ;
2884
+ #ifndef NDEBUG
2885
+ GGML_LOG_DEBUG (" %s: disabling CUDA graphs due to repeated batch size > 1 [%s] [%ld %ld %ld %ld]\n " , __func__, node->name , node->ne [0 ], node->ne [1 ], node->ne [2 ], node->ne [3 ]);
2886
+ #endif
2887
+ }
2881
2888
}
2882
2889
2883
2890
if (node->op == GGML_OP_MULTI_ADD && node->ne [1 ] > 1 ) {
0 commit comments