@@ -2483,6 +2483,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
2483
2483
2484
2484
bool use_cuda_graph = true ;
2485
2485
bool cuda_graph_update_required = false ;
2486
+ // vector of pointers to CUDA cpy kernels, which are required to identify
2487
+ // kernel parameters which need to be updated in the graph for each token
2488
+ std::vector<void *> ggml_cuda_cpy_fn_ptrs;
2486
2489
2487
2490
if (cuda_ctx->cuda_graph ->graph == nullptr ) {
2488
2491
if (ggml_cuda_info ().devices [cuda_ctx->device ].cc < CC_AMPERE) {
@@ -2528,6 +2531,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
2528
2531
}
2529
2532
2530
2533
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
2534
+ cuda_ctx->cuda_graph ->updated_kernel_arg .clear ();
2531
2535
for (int i = 0 ; i < cgraph->n_nodes ; i++) {
2532
2536
ggml_tensor * node = cgraph->nodes [i];
2533
2537
@@ -2554,6 +2558,16 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
2554
2558
#endif
2555
2559
}
2556
2560
2561
+ if (node->op == GGML_OP_CPY) {
2562
+ // store the copy op parameter which changes with each token.
2563
+ cuda_ctx->cuda_graph ->updated_kernel_arg .push_back ((char **) &(node->src [1 ]->data ));
2564
+ // store a pointer to each copy op CUDA kernel to identify it later
2565
+ void * ptr = ggml_cuda_cpy_fn (node->src [0 ], node->src [1 ]);
2566
+ if (std::find (ggml_cuda_cpy_fn_ptrs.begin (), ggml_cuda_cpy_fn_ptrs.end (), ptr) == ggml_cuda_cpy_fn_ptrs.end ()) {
2567
+ ggml_cuda_cpy_fn_ptrs.push_back (ptr);
2568
+ }
2569
+ }
2570
+
2557
2571
if (!use_cuda_graph) {
2558
2572
break ;
2559
2573
}
@@ -2643,23 +2657,64 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
2643
2657
CUDA_CHECK (cudaGraphInstantiate (&cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , NULL , NULL , 0 ));
2644
2658
}
2645
2659
2660
+ // Perform update to graph (if required for this token), and change copy parameter (required for every token)
2661
+
2646
2662
if (cuda_graph_update_required) {
2647
- // Update graph executable
2648
- cudaGraphExecUpdateResultInfo result_info;
2649
- cudaError_t stat = cudaGraphExecUpdate (cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , &result_info);
2650
- if (stat == cudaErrorGraphExecUpdateFailure) {
2663
+ // Extract nodes from graph
2664
+ // First call with null argument gets number of nodes in graph
2665
+ CUDA_CHECK (cudaGraphGetNodes (cuda_ctx->cuda_graph ->graph , nullptr , &cuda_ctx->cuda_graph ->num_nodes ));
2666
+ // Subsequent call with non-null argument gets nodes
2667
+ cuda_ctx->cuda_graph ->nodes .resize (cuda_ctx->cuda_graph ->num_nodes );
2668
+ cuda_ctx->cuda_graph ->params .resize (cuda_ctx->cuda_graph ->num_nodes );
2669
+ if (cuda_ctx->cuda_graph ->num_nodes > 0 ) {
2670
+ CUDA_CHECK (cudaGraphGetNodes (cuda_ctx->cuda_graph ->graph , cuda_ctx->cuda_graph ->nodes .data (), &cuda_ctx->cuda_graph ->num_nodes ));
2671
+
2672
+ // Loop over nodes, and extract kernel parameters from each node
2673
+ for (size_t i = 0 ; i < cuda_ctx->cuda_graph ->num_nodes ; i++) {
2674
+ cudaGraphNodeType node_type;
2675
+ CUDA_CHECK (cudaGraphNodeGetType (cuda_ctx->cuda_graph ->nodes [i], &node_type));
2676
+ if (node_type == cudaGraphNodeTypeKernel) {
2677
+ cudaError_t stat = cudaGraphKernelNodeGetParams (cuda_ctx->cuda_graph ->nodes [i], &cuda_ctx->cuda_graph ->params [i]); // Get params using runtime
2678
+ if (stat == cudaErrorInvalidDeviceFunction) {
2679
+ // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
2680
+ // We don't need to update blas nodes, so clear error and move on.
2681
+ cudaGetLastError ();
2682
+ } else {
2683
+ GGML_ASSERT (stat == cudaSuccess);
2684
+ }
2685
+ }
2686
+ }
2687
+ }
2688
+ }
2689
+
2690
+ // One of the arguments to the copy kernel is updated for each token, hence we need to
2691
+ // replace that argument with the updated value in the CUDA graph
2692
+ if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
2693
+ int k = 0 ;
2694
+ for (size_t i = 0 ; i < cuda_ctx->cuda_graph ->num_nodes ; i++) {
2695
+ if (count (ggml_cuda_cpy_fn_ptrs.begin (), ggml_cuda_cpy_fn_ptrs.end (), cuda_ctx->cuda_graph ->params [i].func ) > 0 ) {
2696
+ char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph ->updated_kernel_arg .at (k++);
2697
+ cuda_ctx->cuda_graph ->params [i].kernelParams [1 ] = updated_kernel_arg_ptr;
2698
+ CUDA_CHECK (cudaGraphKernelNodeSetParams (cuda_ctx->cuda_graph ->nodes [i], &cuda_ctx->cuda_graph ->params [i]));
2699
+ }
2700
+ }
2701
+ }
2702
+
2703
+ // Update graph executable
2704
+ cudaGraphExecUpdateResultInfo result_info;
2705
+ cudaError_t stat = cudaGraphExecUpdate (cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , &result_info);
2706
+ if (stat == cudaErrorGraphExecUpdateFailure) {
2651
2707
#ifndef NDEBUG
2652
- GGML_CUDA_LOG_ERROR (" %s: CUDA graph update failed\n " , __func__);
2708
+ GGML_CUDA_LOG_ERROR (" %s: CUDA graph update failed\n " , __func__);
2653
2709
#endif
2654
- // The pre-existing graph exec cannot be updated due to violated constraints
2655
- // so instead clear error and re-instantiate
2656
- cudaGetLastError ();
2657
- CUDA_CHECK (cudaGraphExecDestroy (cuda_ctx->cuda_graph ->instance ));
2658
- cuda_ctx->cuda_graph ->instance = nullptr ;
2659
- CUDA_CHECK (cudaGraphInstantiate (&cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , NULL , NULL , 0 ));
2660
- } else {
2661
- GGML_ASSERT (stat == cudaSuccess);
2662
- }
2710
+ // The pre-existing graph exec cannot be updated due to violated constraints
2711
+ // so instead clear error and re-instantiate
2712
+ cudaGetLastError ();
2713
+ CUDA_CHECK (cudaGraphExecDestroy (cuda_ctx->cuda_graph ->instance ));
2714
+ cuda_ctx->cuda_graph ->instance = nullptr ;
2715
+ CUDA_CHECK (cudaGraphInstantiate (&cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , NULL , NULL , 0 ));
2716
+ } else {
2717
+ GGML_ASSERT (stat == cudaSuccess);
2663
2718
}
2664
2719
// Launch graph
2665
2720
CUDA_CHECK (cudaGraphLaunch (cuda_ctx->cuda_graph ->instance , cuda_ctx->stream ()));
0 commit comments