@@ -106,6 +106,22 @@ function toolkit()
106
106
end
107
107
108
108
# workaround @artifact_str eagerness on unsupported platforms by passing a variable
109
+ function generic_artifact (id)
110
+ dir = try
111
+ @artifact_str (id)
112
+ catch ex
113
+ @debug " Could not load artifact '$id '" exception= (ex,catch_backtrace ())
114
+ return nothing
115
+ end
116
+
117
+ # sometimes artifact downloads fail (e.g. JuliaGPU/CUDA.jl#1003)
118
+ if isempty (readdir (dir))
119
+ error (""" The artifact at $dir is empty.
120
+ This is probably caused by a failed download. Remove the directory and try again.""" )
121
+ end
122
+
123
+ return dir
124
+ end
109
125
function cuda_artifact (id, cuda:: VersionNumber )
110
126
platform = Base. BinaryPlatforms. HostPlatform ()
111
127
platform. tags[" cuda" ] = " $(cuda. major) .$(cuda. minor) "
@@ -604,7 +620,7 @@ function libcutensormg(; throw_error::Bool=true)
604
620
# CUTENSORMg additionally depends on CUDARt
605
621
libcudart ()
606
622
607
- if CUDA . CUTENSOR. version () < v " 1.4"
623
+ if CUTENSOR. version () < v " 1.4"
608
624
nothing
609
625
else
610
626
find_cutensor (toolkit (), " cutensorMg" , v " 1" )
@@ -682,6 +698,11 @@ function find_nccl(cuda::LocalToolkit, name, version)
682
698
return path
683
699
end
684
700
701
+
702
+ #
703
+ # CUQUANTUM
704
+ #
705
+
685
706
export libcutensornet, has_cutensornet, libcustatevec, has_custatevec
686
707
687
708
const __libcutensornet = Ref {Union{String,Nothing}} ()
712
733
has_custatevec () = libcustatevec (throw_error= false ) != = nothing
713
734
714
735
function find_cutensornet (cuda:: ArtifactToolkit , name, version)
715
- artifact_dir = cuda_artifact (" cuQuantum" , v " 0.1.3 " )
736
+ artifact_dir = generic_artifact (" cuQuantum" )
716
737
if artifact_dir === nothing
717
738
return nothing
718
739
end
@@ -757,3 +778,36 @@ function find_custatevec(cuda::LocalToolkit, name, version)
757
778
return path
758
779
end
759
780
781
+
782
+ #
783
+ # Utilities
784
+ #
785
+
786
+ export download_artifacts
787
+
788
+ """
789
+ download_artifacts()
790
+
791
+ Downloads the artifacts you will need to run CUDA.jl. This can be used to pre-populate the
792
+ artifacts directory from, e.g., a container build script.
793
+
794
+ If you want this function to not require a CUDA driver (which wouldn't be available from
795
+ said container build environment) be sure to set the `JULIA_CUDA_VERSION` environment
796
+ variable to an appropriate CUDA release number. This environment variable should then also
797
+ be set at run-time, and should be compatible with the NVIDIA driver that will be available
798
+ in that environment.
799
+
800
+ !!! warning
801
+
802
+ This function is a temporary hack, and will be removed once CUDA.jl uses JLLs for
803
+ downloading and installing artifacts.
804
+ """
805
+ function download_artifacts ()
806
+ toolkit = find_artifact_cuda ()
807
+ @assert nothing != = cuda_artifact (" CUDNN" , toolkit. release)
808
+ @assert nothing != = cuda_artifact (" CUTENSOR" , toolkit. release)
809
+ @assert nothing != = cuda_artifact (" NCCL" , toolkit. release)
810
+
811
+ @assert nothing != = generic_artifact (" CUDA_compat" )
812
+ @assert nothing != = generic_artifact (" cuQuantum" )
813
+ end
0 commit comments