Skip to content

Commit ce83746

Browse files
r-barnes authored and facebook-github-bot committed
Error check some CUDA API calls (#1626)
Summary: Pull Request resolved: #1626 Reviewed By: sryap Differential Revision: D43787029 fbshipit-source-id: 87e07acf39010d489366d3e4ea10b9a33dec1fd5
1 parent 7c7aee0 commit ce83746

File tree

4 files changed

+19
-16
lines changed

4 files changed

+19
-16
lines changed

fbgemm_gpu/include/fbgemm_gpu/bench_utils.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ void flush_cache(int cache_size_mb = 40, bool do_write = false) {
3131
CUDA_CHECK(
3232
cudaMemcpy(d_flush, flush.data(), cache_size, cudaMemcpyHostToDevice));
3333
flush_gpu<<<cache_size / 512, 512>>>(d_flush, d_flush2, do_write);
34-
cudaFree(d_flush);
35-
cudaFree(d_flush2);
34+
CUDA_CHECK(cudaFree(d_flush));
35+
CUDA_CHECK(cudaFree(d_flush2));
3636
CUDA_CHECK(cudaDeviceSynchronize());
3737
CUDA_CHECK(cudaGetLastError());
3838
}

fbgemm_gpu/src/jagged_tensor_ops.cu

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -651,10 +651,10 @@ bool jagged_dense_dense_elementwise_jagged_output_matches_opt(
651651

652652
int max_shared_bytes;
653653
#ifndef __HIP_PLATFORM_HCC__
654-
cudaDeviceGetAttribute(
654+
C10_CUDA_CHECK(cudaDeviceGetAttribute(
655655
&max_shared_bytes,
656656
cudaDevAttrMaxSharedMemoryPerBlockOptin,
657-
y_0_reshaped.get_device());
657+
y_0_reshaped.get_device()));
658658
#else
659659
// MI100 has 64 KB local memory (shared memory) per workgroup
660660
max_shared_bytes = 64 << 10;
@@ -769,10 +769,10 @@ void jagged_dense_elementwise_jagged_output_opt_(
769769
if (dynamic_smem_size > cur_max_shared_bytes) {
770770
int max_shared_bytes;
771771
#ifndef __HIP_PLATFORM_HCC__
772-
cudaDeviceGetAttribute(
772+
C10_CUDA_CHECK(cudaDeviceGetAttribute(
773773
&max_shared_bytes,
774774
cudaDevAttrMaxSharedMemoryPerBlockOptin,
775-
y_reshaped.get_device());
775+
y_reshaped.get_device()));
776776
#else
777777
// MI100 has 64 KB local memory (shared memory) per workgroup
778778
max_shared_bytes = 64 << 10;
@@ -788,11 +788,11 @@ void jagged_dense_elementwise_jagged_output_opt_(
788788
#endif
789789
int used_shared_bytes = used_shared_kb << 10;
790790
#ifndef __HIP_PLATFORM_HCC__
791-
cudaFuncSetAttribute(
791+
C10_CUDA_CHECK(cudaFuncSetAttribute(
792792
jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
793793
index_t>,
794794
cudaFuncAttributeMaxDynamicSharedMemorySize,
795-
used_shared_bytes); // V100: 64 KB; A100: 96 KB.
795+
used_shared_bytes)); // V100: 64 KB; A100: 96 KB.
796796
#endif
797797
C10_CUDA_KERNEL_LAUNCH_CHECK();
798798
TORCH_CHECK(dynamic_smem_size <= used_shared_bytes);
@@ -973,10 +973,10 @@ void jagged_dense_dense_elementwise_jagged_output_opt_(
973973
if (dynamic_smem_size > cur_max_shared_bytes) {
974974
int max_shared_bytes;
975975
#ifndef __HIP_PLATFORM_HCC__
976-
cudaDeviceGetAttribute(
976+
C10_CUDA_CHECK(cudaDeviceGetAttribute(
977977
&max_shared_bytes,
978978
cudaDevAttrMaxSharedMemoryPerBlockOptin,
979-
y_0_reshaped.get_device());
979+
y_0_reshaped.get_device()));
980980
#else
981981
// MI100 has 64 KB local memory (shared memory) per workgroup
982982
max_shared_bytes = 64 << 10;
@@ -992,11 +992,11 @@ void jagged_dense_dense_elementwise_jagged_output_opt_(
992992
#endif
993993
int used_shared_bytes = used_shared_kb << 10;
994994
#ifndef __HIP_PLATFORM_HCC__
995-
cudaFuncSetAttribute(
995+
C10_CUDA_CHECK(cudaFuncSetAttribute(
996996
jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
997997
index_t>,
998998
cudaFuncAttributeMaxDynamicSharedMemorySize,
999-
used_shared_bytes); // V100: 64 KB; A100: 96 KB.
999+
used_shared_bytes)); // V100: 64 KB; A100: 96 KB.
10001000
#endif
10011001
C10_CUDA_KERNEL_LAUNCH_CHECK();
10021002
TORCH_CHECK(dynamic_smem_size <= used_shared_bytes);

fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,11 @@ void init_p2p_access() {
274274
for (const auto j : c10::irange(at::cuda::getNumGPUs())) {
275275
if (i != j) {
276276
at::cuda::CUDAGuard g(i);
277-
const auto err = cudaDeviceEnablePeerAccess(j, 0);
277+
const auto err =
278+
C10_CUDA_ERROR_HANDLED(cudaDeviceEnablePeerAccess(j, 0));
278279
if (err == cudaErrorPeerAccessAlreadyEnabled) {
279280
// ignore and clear the error if access was already enabled
280-
cudaGetLastError();
281+
C10_CUDA_CLEAR_ERROR();
281282
} else {
282283
AT_CUDA_CHECK(err);
283284
}

fbgemm_gpu/src/topology_utils.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include <ATen/cuda/CUDAContext.h>
99
#include <c10/core/Device.h>
10+
#include <c10/cuda/CUDAException.h>
1011
#include <algorithm>
1112

1213
#include "fbgemm_gpu/topology_utils.h"
@@ -131,14 +132,15 @@ AdjacencyMatrix<Links> get_nvlink_matrix() {
131132
&pci_info.busId[NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE],
132133
pci_bus_id.data());
133134
int32_t node = 0;
134-
auto err = cudaDeviceGetByPCIBusId(&node, pci_bus_id.data());
135+
auto err = C10_CUDA_ERROR_HANDLED(
136+
cudaDeviceGetByPCIBusId(&node, pci_bus_id.data()));
135137
if (err == cudaSuccess) {
136138
pci_bus_ids.insert({pci_bus_id, node});
137139
cuda_device_to_nvml_device.insert({node, i});
138140
} else {
139141
// flush the last error - this can occur when e.g. we set
140142
// CUDA_VISIBLE_DEVICES to a subset of the available GPUs in the system.
141-
cudaGetLastError();
143+
C10_CUDA_CLEAR_ERROR();
142144
}
143145
}
144146

0 commit comments

Comments (0)