Skip to content

Commit c97d955

Browse files
committed
cuInit before querying compute capability
1 parent 498e81a commit c97d955

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

jaxlib/cuda/versions_helpers.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ size_t CudnnGetVersion() {
8686
}
8787
int CudaComputeCapability(int device) {
8888
int major, minor;
89+
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuInit(0)));
8990
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute(
9091
&major, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device)));
9192
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute(
@@ -102,4 +103,4 @@ int CudaDeviceCount() {
102103
}
103104

104105

105-
} // namespace jax::cuda
106+
} // namespace jax::cuda

0 commit comments

Comments
 (0)