Skip to content

Commit ccd8e56

Browse files
Nexesenex and agray3 committed
Simplify and improve CUDA graphs through use of indirect copy pointers ggml-org#9017
Co-Authored-By: agray3 <10851179+agray3@users.noreply.github.com>
1 parent f2204bf commit ccd8e56

File tree

5 files changed

+106
-135
lines changed

5 files changed

+106
-135
lines changed

ggml/include/ggml-backend.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,9 @@ extern "C" {
238238
// Set whether or not to use GGML graph caching
239239
GGML_API void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value);
240240

241+
// Copy K and V cache pointers to backend
242+
GGML_API void ggml_backend_copy_kv_cache_ptrs(const int64_t n_layer, const int64_t kv_self_head, struct ggml_tensor ** kv_kl, struct ggml_tensor ** kv_vl, const int64_t n_embd_k_gqa, const int64_t n_embd_v_gqa, const bool flash_attn);
243+
241244
#ifdef __cplusplus
242245
}
243246
#endif

ggml/src/ggml-cuda.cu

Lines changed: 14 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2483,9 +2483,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
24832483

24842484
bool use_cuda_graph = true;
24852485
bool cuda_graph_update_required = false;
2486-
// vector of pointers to CUDA cpy kernels, which are required to identify
2487-
// kernel parameters which need updated in the graph for each token
2488-
std::vector<void *> ggml_cuda_cpy_fn_ptrs;
24892486

24902487
if (cuda_ctx->cuda_graph->graph == nullptr) {
24912488
if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
@@ -2531,7 +2528,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
25312528
}
25322529

25332530
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
2534-
cuda_ctx->cuda_graph->updated_kernel_arg.clear();
25352531
for (int i = 0; i < cgraph->n_nodes; i++) {
25362532
ggml_tensor * node = cgraph->nodes[i];
25372533

@@ -2558,16 +2554,6 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
25582554
#endif
25592555
}
25602556

2561-
if (node->op == GGML_OP_CPY) {
2562-
// store the copy op parameter which changes with each token.
2563-
cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
2564-
// store a pointer to each copy op CUDA kernel to identify it later
2565-
void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
2566-
if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
2567-
ggml_cuda_cpy_fn_ptrs.push_back(ptr);
2568-
}
2569-
}
2570-
25712557
if (!use_cuda_graph) {
25722558
break;
25732559
}
@@ -2657,64 +2643,23 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
26572643
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
26582644
}
26592645

2660-
// Perform update to graph (if required for this token), and change copy parameter (required for every token)
2661-
26622646
if (cuda_graph_update_required) {
2663-
// Extract nodes from graph
2664-
// First call with null argument gets number of nodes in graph
2665-
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
2666-
// Subsequent call with non-null argument gets nodes
2667-
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
2668-
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
2669-
if (cuda_ctx->cuda_graph->num_nodes > 0) {
2670-
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
2671-
2672-
// Loop over nodes, and extract kernel parameters from each node
2673-
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
2674-
cudaGraphNodeType node_type;
2675-
CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
2676-
if (node_type == cudaGraphNodeTypeKernel) {
2677-
cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
2678-
if (stat == cudaErrorInvalidDeviceFunction) {
2679-
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
2680-
// We don't need to update blas nodes, so clear error and move on.
2681-
cudaGetLastError();
2682-
} else {
2683-
GGML_ASSERT(stat == cudaSuccess);
2684-
}
2685-
}
2686-
}
2687-
}
2688-
}
2689-
2690-
// One of the arguments to the copy kernel is updated for each token, hence we need to
2691-
// replace that argument with the updated value in the CUDA graph
2692-
if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
2693-
int k = 0;
2694-
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
2695-
if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
2696-
char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
2697-
cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
2698-
CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
2699-
}
2700-
}
2701-
}
2702-
2703-
// Update graph executable
2704-
cudaGraphExecUpdateResultInfo result_info;
2705-
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
2706-
if (stat == cudaErrorGraphExecUpdateFailure) {
2647+
// Update graph executable
2648+
cudaGraphExecUpdateResultInfo result_info;
2649+
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
2650+
if (stat == cudaErrorGraphExecUpdateFailure) {
27072651
#ifndef NDEBUG
2708-
GGML_CUDA_LOG_ERROR("%s: CUDA graph update failed\n", __func__);
2652+
GGML_CUDA_LOG_ERROR("%s: CUDA graph update failed\n", __func__);
27092653
#endif
2710-
// The pre-existing graph exec cannot be updated due to violated constraints
2711-
// so instead clear error and re-instantiate
2712-
cudaGetLastError();
2713-
CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
2714-
cuda_ctx->cuda_graph->instance = nullptr;
2715-
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
2716-
} else {
2717-
GGML_ASSERT(stat == cudaSuccess);
2654+
// The pre-existing graph exec cannot be updated due to violated constraints
2655+
// so instead clear error and re-instantiate
2656+
cudaGetLastError();
2657+
CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
2658+
cuda_ctx->cuda_graph->instance = nullptr;
2659+
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
2660+
} else {
2661+
GGML_ASSERT(stat == cudaSuccess);
2662+
}
27182663
}
27192664
// Launch graph
27202665
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));

ggml/src/ggml-cuda/common.cuh

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -583,15 +583,11 @@ struct ggml_cuda_graph {
583583
}
584584
cudaGraph_t graph = nullptr;
585585
cudaGraphExec_t instance = nullptr;
586-
size_t num_nodes = 0;
587-
std::vector<cudaGraphNode_t> nodes;
588-
std::vector<cudaKernelNodeParams> params;
589586
bool disable_due_to_gpu_arch = false;
590587
bool disable_due_to_too_many_updates = false;
591588
bool disable_due_to_failed_graph_capture = false;
592589
int number_consecutive_updates = 0;
593590
std::vector<ggml_graph_node_properties> ggml_graph_properties;
594-
std::vector<char **> updated_kernel_arg;
595591
#endif
596592
};
597593

0 commit comments

Comments (0)