@@ -2483,9 +2483,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
2483
2483
2484
2484
bool use_cuda_graph = true ;
2485
2485
bool cuda_graph_update_required = false ;
2486
- // vector of pointers to CUDA cpy kernels, which are required to identify
2487
- // kernel parameters which need updated in the graph for each token
2488
- std::vector<void *> ggml_cuda_cpy_fn_ptrs;
2489
2486
2490
2487
if (cuda_ctx->cuda_graph ->graph == nullptr ) {
2491
2488
if (ggml_cuda_info ().devices [cuda_ctx->device ].cc < CC_AMPERE) {
@@ -2531,7 +2528,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
2531
2528
}
2532
2529
2533
2530
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
2534
- cuda_ctx->cuda_graph ->updated_kernel_arg .clear ();
2535
2531
for (int i = 0 ; i < cgraph->n_nodes ; i++) {
2536
2532
ggml_tensor * node = cgraph->nodes [i];
2537
2533
@@ -2558,16 +2554,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
2558
2554
#endif
2559
2555
}
2560
2556
2561
- if (node->op == GGML_OP_CPY) {
2562
- // store the copy op parameter which changes with each token.
2563
- cuda_ctx->cuda_graph ->updated_kernel_arg .push_back ((char **) &(node->src [1 ]->data ));
2564
- // store a pointer to each copy op CUDA kernel to identify it later
2565
- void * ptr = ggml_cuda_cpy_fn (node->src [0 ], node->src [1 ]);
2566
- if (std::find (ggml_cuda_cpy_fn_ptrs.begin (), ggml_cuda_cpy_fn_ptrs.end (), ptr) == ggml_cuda_cpy_fn_ptrs.end ()) {
2567
- ggml_cuda_cpy_fn_ptrs.push_back (ptr);
2568
- }
2569
- }
2570
-
2571
2557
if (!use_cuda_graph) {
2572
2558
break ;
2573
2559
}
@@ -2657,64 +2643,23 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
2657
2643
CUDA_CHECK (cudaGraphInstantiate (&cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , NULL , NULL , 0 ));
2658
2644
}
2659
2645
2660
- // Perform update to graph (if required for this token), and change copy parameter (required for every token)
2661
-
2662
2646
if (cuda_graph_update_required) {
2663
- // Extract nodes from graph
2664
- // First call with null argument gets number of nodes in graph
2665
- CUDA_CHECK (cudaGraphGetNodes (cuda_ctx->cuda_graph ->graph , nullptr , &cuda_ctx->cuda_graph ->num_nodes ));
2666
- // Subsequent call with non-null argument gets nodes
2667
- cuda_ctx->cuda_graph ->nodes .resize (cuda_ctx->cuda_graph ->num_nodes );
2668
- cuda_ctx->cuda_graph ->params .resize (cuda_ctx->cuda_graph ->num_nodes );
2669
- if (cuda_ctx->cuda_graph ->num_nodes > 0 ) {
2670
- CUDA_CHECK (cudaGraphGetNodes (cuda_ctx->cuda_graph ->graph , cuda_ctx->cuda_graph ->nodes .data (), &cuda_ctx->cuda_graph ->num_nodes ));
2671
-
2672
- // Loop over nodes, and extract kernel parameters from each node
2673
- for (size_t i = 0 ; i < cuda_ctx->cuda_graph ->num_nodes ; i++) {
2674
- cudaGraphNodeType node_type;
2675
- CUDA_CHECK (cudaGraphNodeGetType (cuda_ctx->cuda_graph ->nodes [i], &node_type));
2676
- if (node_type == cudaGraphNodeTypeKernel) {
2677
- cudaError_t stat = cudaGraphKernelNodeGetParams (cuda_ctx->cuda_graph ->nodes [i], &cuda_ctx->cuda_graph ->params [i]); // Get params using runtime
2678
- if (stat == cudaErrorInvalidDeviceFunction) {
2679
- // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
2680
- // We don't need to update blas nodes, so clear error and move on.
2681
- cudaGetLastError ();
2682
- } else {
2683
- GGML_ASSERT (stat == cudaSuccess);
2684
- }
2685
- }
2686
- }
2687
- }
2688
- }
2689
-
2690
- // One of the arguments to the copy kernel is updated for each token, hence we need to
2691
- // replace that argument with the updated value in the CUDA graph
2692
- if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
2693
- int k = 0 ;
2694
- for (size_t i = 0 ; i < cuda_ctx->cuda_graph ->num_nodes ; i++) {
2695
- if (count (ggml_cuda_cpy_fn_ptrs.begin (), ggml_cuda_cpy_fn_ptrs.end (), cuda_ctx->cuda_graph ->params [i].func ) > 0 ) {
2696
- char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph ->updated_kernel_arg .at (k++);
2697
- cuda_ctx->cuda_graph ->params [i].kernelParams [1 ] = updated_kernel_arg_ptr;
2698
- CUDA_CHECK (cudaGraphKernelNodeSetParams (cuda_ctx->cuda_graph ->nodes [i], &cuda_ctx->cuda_graph ->params [i]));
2699
- }
2700
- }
2701
- }
2702
-
2703
- // Update graph executable
2704
- cudaGraphExecUpdateResultInfo result_info;
2705
- cudaError_t stat = cudaGraphExecUpdate (cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , &result_info);
2706
- if (stat == cudaErrorGraphExecUpdateFailure) {
2647
+ // Update graph executable
2648
+ cudaGraphExecUpdateResultInfo result_info;
2649
+ cudaError_t stat = cudaGraphExecUpdate (cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , &result_info);
2650
+ if (stat == cudaErrorGraphExecUpdateFailure) {
2707
2651
#ifndef NDEBUG
2708
- GGML_CUDA_LOG_ERROR (" %s: CUDA graph update failed\n " , __func__);
2652
+ GGML_CUDA_LOG_ERROR (" %s: CUDA graph update failed\n " , __func__);
2709
2653
#endif
2710
- // The pre-existing graph exec cannot be updated due to violated constraints
2711
- // so instead clear error and re-instantiate
2712
- cudaGetLastError ();
2713
- CUDA_CHECK (cudaGraphExecDestroy (cuda_ctx->cuda_graph ->instance ));
2714
- cuda_ctx->cuda_graph ->instance = nullptr ;
2715
- CUDA_CHECK (cudaGraphInstantiate (&cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , NULL , NULL , 0 ));
2716
- } else {
2717
- GGML_ASSERT (stat == cudaSuccess);
2654
+ // The pre-existing graph exec cannot be updated due to violated constraints
2655
+ // so instead clear error and re-instantiate
2656
+ cudaGetLastError ();
2657
+ CUDA_CHECK (cudaGraphExecDestroy (cuda_ctx->cuda_graph ->instance ));
2658
+ cuda_ctx->cuda_graph ->instance = nullptr ;
2659
+ CUDA_CHECK (cudaGraphInstantiate (&cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , NULL , NULL , 0 ));
2660
+ } else {
2661
+ GGML_ASSERT (stat == cudaSuccess);
2662
+ }
2718
2663
}
2719
2664
// Launch graph
2720
2665
CUDA_CHECK (cudaGraphLaunch (cuda_ctx->cuda_graph ->instance , cuda_ctx->stream ()));
0 commit comments