Skip to content

Commit cc1bf0a

Browse files
giantcowLegNeato
authored andcommitted
build(cudnn-sys): Add CUDNN_INCLUDE_DIR (Rust-GPU#213)
* build(cudnn-sys): Add CUDNN_INCLUDE_DIR Enables users to specify a non-standard cuDNN install path. This seems to be needed for the newer editions of the CUDA toolkit, as cuDNN isn't included by default (at least in the Fedora repo's, you have to install from a tarball)
1 parent f95e9b3 commit cc1bf0a

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

crates/cudnn-sys/build/cudnn_sdk.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
use std::env;
12
use std::error;
23
use std::fs;
34
use std::path;
5+
use std::path::Path;
46

57
/// Represents the cuDNN SDK installation.
68
#[derive(Debug, Clone)]
@@ -57,17 +59,25 @@ impl CudnnSdk {
5759
}
5860

5961
fn find_cudnn_include_dir() -> Result<path::PathBuf, Box<dyn error::Error>> {
62+
let cudnn_include_dir = env::var_os("CUDNN_INCLUDE_DIR");
63+
6064
#[cfg(not(target_os = "windows"))]
6165
const CUDNN_DEFAULT_PATHS: &[&str] = &["/usr/include", "/usr/local/include"];
6266
#[cfg(target_os = "windows")]
6367
const CUDNN_DEFAULT_PATHS: &[&str] = &[
6468
"C:/Program Files/NVIDIA/CUDNN/v9.x/include",
6569
"C:/Program Files/NVIDIA/CUDNN/v8.x/include",
6670
];
67-
CUDNN_DEFAULT_PATHS
71+
72+
let mut cudnn_paths: Vec<&Path> = CUDNN_DEFAULT_PATHS.iter().map(Path::new).collect();
73+
if let Some(override_path) = &cudnn_include_dir {
74+
cudnn_paths.push(Path::new(override_path));
75+
}
76+
77+
cudnn_paths
6878
.iter()
69-
.find(|s| Self::is_cudnn_include_path(s))
70-
.map(path::PathBuf::from)
79+
.find(|p| Self::is_cudnn_include_path(p))
80+
.map(|p| p.to_path_buf())
7181
.ok_or("Cannot find cuDNN include directory.".into())
7282
}
7383

0 commit comments

Comments
 (0)