Skip to content

Commit 56fc498

Browse files
committed
Revert " Simplify and improve CUDA graphs through use of indirect copy pointers ggml-org#9017"
This reverts commit 1dea402e4cb8f64737aa49ba98bc9647656e4d26.
1 parent 6a0119a commit 56fc498

File tree

5 files changed

+135
-106
lines changed

5 files changed

+135
-106
lines changed

ggml/include/ggml-backend.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,9 +238,6 @@ extern "C" {
238238
// Set whether or not to use GGML graph caching
239239
GGML_API void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value);
240240

241-
// Copy K and V cache pointers to backend
242-
GGML_API void ggml_backend_copy_kv_cache_ptrs(const int64_t n_layer, const int64_t kv_self_head, struct ggml_tensor ** kv_kl, struct ggml_tensor ** kv_vl, const int64_t n_embd_k_gqa, const int64_t n_embd_v_gqa, const bool flash_attn);
243-
244241
#ifdef __cplusplus
245242
}
246243
#endif

ggml/src/ggml-cuda.cu

Lines changed: 69 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2483,6 +2483,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
24832483

24842484
bool use_cuda_graph = true;
24852485
bool cuda_graph_update_required = false;
2486+
// vector of pointers to CUDA cpy kernels, which are required to identify
2487+
// kernel parameters which need to be updated in the graph for each token
2488+
std::vector<void *> ggml_cuda_cpy_fn_ptrs;
24862489

24872490
if (cuda_ctx->cuda_graph->graph == nullptr) {
24882491
if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
@@ -2528,6 +2531,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
25282531
}
25292532

25302533
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
2534+
cuda_ctx->cuda_graph->updated_kernel_arg.clear();
25312535
for (int i = 0; i < cgraph->n_nodes; i++) {
25322536
ggml_tensor * node = cgraph->nodes[i];
25332537

@@ -2554,6 +2558,16 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
25542558
#endif
25552559
}
25562560

2561+
if (node->op == GGML_OP_CPY) {
2562+
// store the copy op parameter which changes with each token.
2563+
cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
2564+
// store a pointer to each copy op CUDA kernel to identify it later
2565+
void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
2566+
if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
2567+
ggml_cuda_cpy_fn_ptrs.push_back(ptr);
2568+
}
2569+
}
2570+
25572571
if (!use_cuda_graph) {
25582572
break;
25592573
}
@@ -2643,23 +2657,64 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
26432657
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
26442658
}
26452659

2660+
// Perform update to graph (if required for this token), and change copy parameter (required for every token)
2661+
26462662
if (cuda_graph_update_required) {
2647-
// Update graph executable
2648-
cudaGraphExecUpdateResultInfo result_info;
2649-
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
2650-
if (stat == cudaErrorGraphExecUpdateFailure) {
2663+
// Extract nodes from graph
2664+
// First call with null argument gets number of nodes in graph
2665+
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
2666+
// Subsequent call with non-null argument gets nodes
2667+
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
2668+
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
2669+
if (cuda_ctx->cuda_graph->num_nodes > 0) {
2670+
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
2671+
2672+
// Loop over nodes, and extract kernel parameters from each node
2673+
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
2674+
cudaGraphNodeType node_type;
2675+
CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
2676+
if (node_type == cudaGraphNodeTypeKernel) {
2677+
cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
2678+
if (stat == cudaErrorInvalidDeviceFunction) {
2679+
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
2680+
// We don't need to update blas nodes, so clear error and move on.
2681+
cudaGetLastError();
2682+
} else {
2683+
GGML_ASSERT(stat == cudaSuccess);
2684+
}
2685+
}
2686+
}
2687+
}
2688+
}
2689+
2690+
// One of the arguments to the copy kernel is updated for each token, hence we need to
2691+
// replace that argument with the updated value in the CUDA graph
2692+
if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
2693+
int k = 0;
2694+
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
2695+
if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
2696+
char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
2697+
cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
2698+
CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
2699+
}
2700+
}
2701+
}
2702+
2703+
// Update graph executable
2704+
cudaGraphExecUpdateResultInfo result_info;
2705+
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
2706+
if (stat == cudaErrorGraphExecUpdateFailure) {
26512707
#ifndef NDEBUG
2652-
GGML_CUDA_LOG_ERROR("%s: CUDA graph update failed\n", __func__);
2708+
GGML_CUDA_LOG_ERROR("%s: CUDA graph update failed\n", __func__);
26532709
#endif
2654-
// The pre-existing graph exec cannot be updated due to violated constraints
2655-
// so instead clear error and re-instantiate
2656-
cudaGetLastError();
2657-
CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
2658-
cuda_ctx->cuda_graph->instance = nullptr;
2659-
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
2660-
} else {
2661-
GGML_ASSERT(stat == cudaSuccess);
2662-
}
2710+
// The pre-existing graph exec cannot be updated due to violated constraints
2711+
// so instead clear error and re-instantiate
2712+
cudaGetLastError();
2713+
CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
2714+
cuda_ctx->cuda_graph->instance = nullptr;
2715+
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
2716+
} else {
2717+
GGML_ASSERT(stat == cudaSuccess);
26632718
}
26642719
// Launch graph
26652720
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));

ggml/src/ggml-cuda/common.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,11 +583,15 @@ struct ggml_cuda_graph {
583583
}
584584
cudaGraph_t graph = nullptr;
585585
cudaGraphExec_t instance = nullptr;
586+
size_t num_nodes = 0;
587+
std::vector<cudaGraphNode_t> nodes;
588+
std::vector<cudaKernelNodeParams> params;
586589
bool disable_due_to_gpu_arch = false;
587590
bool disable_due_to_too_many_updates = false;
588591
bool disable_due_to_failed_graph_capture = false;
589592
int number_consecutive_updates = 0;
590593
std::vector<ggml_graph_node_properties> ggml_graph_properties;
594+
std::vector<char **> updated_kernel_arg;
591595
#endif
592596
};
593597

0 commit comments

Comments (0)