This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 87733ac: merge with master
1 parent: 034a52c
32 files changed, +569 and -196 lines

.github/FUNDING.yml (1 addition, 0 deletions)

```diff
@@ -0,0 +1 @@
+custom: https://numfocus.salsalabs.org/donate-to-julia/index.html
```

.gitlab-ci.yml (0 additions, 1 deletion)

```diff
@@ -1,6 +1,5 @@
 variables:
   CI_IMAGE_TAG: 'cuda'
-  CI_DEV_PKGS: 'CUDAapi GPUArrays CUDAnative NNlib CUDAdrv'
   JULIA_NUM_THREADS: '4'
 
 include:
```

Project.toml (4 additions, 4 deletions)

```diff
@@ -1,6 +1,6 @@
 name = "CuArrays"
 uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
-version = "1.0.2"
+version = "2.0.0"
 
 [deps]
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -31,7 +31,7 @@ test = ["Test", "FFTW", "ForwardDiff"]
 julia = "1.0"
 CUDAnative = "2.0"
 CUDAdrv = "3.0"
-CUDAapi = "0.5.3, 0.6"
-NNlib = "0.5, 0.6"
-GPUArrays = "0.7"
+CUDAapi = "0.5.3, 0.6, 1.0"
+NNlib = "0.6"
+GPUArrays = "0.7.1"
 Adapt = "0.4"
```
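This bumps CuArrays to a new major version and adjusts the compat bounds (CUDAapi gains 1.0, NNlib drops 0.5, GPUArrays requires at least 0.7.1). A hedged sketch of picking up the release in a downstream environment, using only the standard Pkg API:

```julia
# Hypothetical downstream upgrade; assumes CuArrays 2.0.0 is available in the registry.
using Pkg
Pkg.add(PackageSpec(name = "CuArrays", version = v"2.0.0"))
Pkg.status()  # confirm which version resolved
```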

README.md (5 additions, 4 deletions)

```diff
@@ -23,10 +23,11 @@ arrays:
 
 ## Installation
 
-CuArrays should work **out-of-the-box** on Julia 1.0. You only need to have a
-proper set-up of CUDA, meaning the rest of the Julia CUDA stack should work
-(notably CUDAapi.jl, CUDAdrv.jl and CUDAnative.jl). If you encounter any issues
-with CuArrays.jl, please make sure those other packages are working as expected.
+CuArrays should work **out-of-the-box** on stable releases of Julia 1.x. You
+only need to have a proper set-up of CUDA, meaning the rest of the Julia CUDA
+stack should work (notably CUDAapi.jl, CUDAdrv.jl and CUDAnative.jl). If you
+encounter any issues with CuArrays.jl, please make sure those other packages are
+working as expected.
 
 Some parts of CuArrays.jl depend on **optional libraries**, such as
 [cuDNN](https://developer.nvidia.com/cudnn). The build process should notify
```
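Given the out-of-the-box claim above, a minimal smoke test is the quickest sanity check; this sketch assumes CUDA and the rest of the Julia CUDA stack are installed and working:

```julia
# Minimal smoke test; assumes a functional CUDA installation.
using CuArrays
a = CuArrays.ones(Float32, 2, 2)  # allocate and fill directly on the GPU
Array(a)                          # copy back to the host; expect a 2×2 matrix of 1.0f0
```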

deps/build.jl (12 additions, 2 deletions)

```diff
@@ -41,11 +41,21 @@ function main()
 
     toolkit = find_toolkit()
 
-    for name in ("cublas", "cusparse", "cusolver", "cufft", "curand", "cudnn")
+    # required libraries that are part of the CUDA toolkit
+    for name in ("cublas", "cusparse", "cusolver", "cufft", "curand")
         lib = Symbol("lib$name")
         config[lib] = find_cuda_library(name, toolkit)
         if config[lib] == nothing
-            build_warning("Could not find library '$name'")
+            build_error("Could not find library '$name' (it should be part of the CUDA toolkit)")
+        end
+    end
+
+    # optional libraries
+    for name in ("cudnn", )
+        lib = Symbol("lib$name")
+        config[lib] = find_cuda_library(name, toolkit)
+        if config[lib] == nothing
+            build_warning("Could not find optional library '$name'")
         end
     end
 
```
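With this split, a missing toolkit library now aborts the build, while a missing cuDNN only produces a warning and disables the corresponding functionality. If the optional library is installed later, rebuilding picks it up; nothing CuArrays-specific is assumed here beyond the standard Pkg workflow:

```julia
# After installing the missing optional library (e.g. cuDNN), re-run the build step.
using Pkg
Pkg.build("CuArrays")
```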

docs/src/tutorials/intro.jl (4 additions, 4 deletions)

```diff
@@ -114,8 +114,8 @@ using BenchmarkTools
 
 using CuArrays
 
-x_d = cufill(1.0f0, N)  # a vector stored on the GPU filled with 1.0 (Float32)
-y_d = cufill(2.0f0, N)  # a vector stored on the GPU filled with 2.0
+x_d = CuArrays.fill(1.0f0, N)  # a vector stored on the GPU filled with 1.0 (Float32)
+y_d = CuArrays.fill(2.0f0, N)  # a vector stored on the GPU filled with 2.0
 
 # Here the `d` means "device," in contrast with "host". Now let's do the increment:
 
@@ -220,8 +220,8 @@ CUDAdrv.@profile bench_gpu1!(y_d, x_d)
 
 # You can see that 100% of the time was spent in `ptxcall_gpu_add1__1`, the name of the
 # kernel that `CUDAnative` assigned when compiling `gpu_add1!` for these inputs. (Had you
-# created arrays of multiple data types, e.g., `xu_d = cufill(0x01, N)`, you might have
-# also seen `ptxcall_gpu_add1__2` and so on. Like the rest of Julia, you can define a
+# created arrays of multiple data types, e.g., `xu_d = CuArrays.fill(0x01, N)`, you might
+# have also seen `ptxcall_gpu_add1__2` and so on. Like the rest of Julia, you can define a
 # single method and it will be specialized at compile time for the particular data types
 # you're using.)
 
```
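The tutorial switches from the removed `cufill` export to the non-exported, module-qualified constructors. A short sketch of the renamed API; `N` is just an illustrative length:

```julia
using CuArrays

N = 2^10
x = CuArrays.fill(1.0f0, N)      # replaces cufill
z = CuArrays.zeros(Float32, N)   # replaces cuzeros
o = CuArrays.ones(N)             # replaces cuones; element type defaults to Float32
```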

src/CuArrays.jl (14 additions, 14 deletions)

```diff
@@ -1,12 +1,10 @@
-__precompile__()
-
 module CuArrays
 
 using CUDAdrv, CUDAnative
 
 using GPUArrays
 
-export CuArray, CuVector, CuMatrix, CuVecOrMat, cu, cuzeros, cuones, cufill
+export CuArray, CuVector, CuMatrix, CuVecOrMat, cu
 
 import LinearAlgebra, SpecialFunctions
 
@@ -45,12 +43,12 @@ include("gpuarray_interface.jl")
 # of CuArrays and/or CUDAnative only use a single context), so keep track of the active one.
 const active_context = Ref{CuContext}()
 
-libcublas !== nothing && include("blas/CUBLAS.jl")
-libcusparse !== nothing && include("sparse/CUSPARSE.jl")
-libcusolver !== nothing && include("solver/CUSOLVER.jl")
-libcufft !== nothing && include("fft/CUFFT.jl")
-libcurand !== nothing && include("rand/CURAND.jl")
-libcudnn !== nothing && include("dnn/CUDNN.jl")
+include("blas/CUBLAS.jl")
+include("sparse/CUSPARSE.jl")
+include("solver/CUSOLVER.jl")
+include("fft/CUFFT.jl")
+include("rand/CURAND.jl")
+libcudnn !== nothing && include("dnn/CUDNN.jl")
 
 include("nnlib.jl")
 
@@ -84,11 +82,13 @@ function __init__()
         active_context[] = ctx
 
         # wipe the active handles
-        isdefined(CuArrays, :CUBLAS) && (CUBLAS._handle[] = C_NULL; CUBLAS._xt_handle[] = C_NULL)
-        isdefined(CuArrays, :CUSOLVER) && (CUSOLVER._dense_handle[] = C_NULL; CUSOLVER._sparse_handle[] = C_NULL)
-        isdefined(CuArrays, :CUSPARSE) && (CUSPARSE._handle[] = C_NULL)
-        isdefined(CuArrays, :CURAND) && (CURAND._generator[] = nothing)
-        isdefined(CuArrays, :CUDNN) && (CUDNN._handle[] = C_NULL)
+        CUBLAS._handle[] = C_NULL
+        CUBLAS._xt_handle[] = C_NULL
+        CUSOLVER._dense_handle[] = C_NULL
+        CUSOLVER._sparse_handle[] = C_NULL
+        CUSPARSE._handle[] = C_NULL
+        CURAND._generator[] = nothing
+        isdefined(CuArrays, :CUDNN) && (CUDNN._handle[] = C_NULL)
     end
     push!(CUDAnative.device!_listeners, callback)
 
```
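The handle reset runs from a listener that fires when CUDAnative switches devices, so the next library call lazily re-creates its handle for the new context. A hedged usage sketch, assuming at least one CUDA device and the CUBLAS-backed `norm` from src/blas/highlevel.jl:

```julia
using CUDAnative, CuArrays, LinearAlgebra

CUDAnative.device!(0)         # selecting a device fires the listener and wipes stale handles
x = CuArrays.fill(1f0, 1024)
norm(x)                       # the next CUBLAS call re-creates the handle for this context
```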

src/array.jl (63 additions, 6 deletions)

```diff
@@ -214,12 +214,12 @@ end
 cu(xs) = adapt(CuArray{Float32}, xs)
 Base.getindex(::typeof(cu), xs...) = CuArray([xs...])
 
-cuzeros(T::Type, dims...) = fill!(CuArray{T}(undef, dims...), 0)
-cuones(T::Type, dims...) = fill!(CuArray{T}(undef, dims...), 1)
-cuzeros(dims...) = cuzeros(Float32, dims...)
-cuones(dims...) = cuones(Float32, dims...)
-cufill(v, dims...) = fill!(CuArray{typeof(v)}(undef, dims...), v)
-cufill(v, dims::Dims) = fill!(CuArray{typeof(v)}(undef, dims...), v)
+zeros(T::Type, dims...) = fill!(CuArray{T}(undef, dims...), 0)
+ones(T::Type, dims...) = fill!(CuArray{T}(undef, dims...), 1)
+zeros(dims...) = CuArrays.zeros(Float32, dims...)
+ones(dims...) = CuArrays.ones(Float32, dims...)
+fill(v, dims...) = fill!(CuArray{typeof(v)}(undef, dims...), v)
+fill(v, dims::Dims) = fill!(CuArray{typeof(v)}(undef, dims...), v)
 
 # optimized implementation of `fill!` for types that are directly supported by memset
 const MemsetTypes = Dict(1=>UInt8, 2=>UInt16, 4=>UInt32)
@@ -270,3 +270,60 @@ function LinearAlgebra.triu!(A::CuMatrix{T}, d::Integer = 0) where T
     @cuda blocks=blk threads=thr kernel!(A, d)
     return A
 end
+
+
+## reversing
+
+function _reverse(input::CuVector{T}, output::CuVector{T}) where {T}
+    @assert length(input) == length(output)
+
+    nthreads = 256
+    nblocks = ceil(Int, length(input) / nthreads)
+    shmem = nthreads * sizeof(T)
+
+    function kernel(input::CuDeviceVector{T}, output::CuDeviceVector{T}) where {T}
+        shared = @cuDynamicSharedMem(T, blockDim().x)
+
+        # load one element per thread from device memory and buffer it in reversed order
+
+        offset_in = blockDim().x * (blockIdx().x - 1)
+        index_in = offset_in + threadIdx().x
+
+        if index_in <= length(input)
+            index_shared = blockDim().x - threadIdx().x + 1
+            @inbounds shared[index_shared] = input[index_in]
+        end
+
+        sync_threads()
+
+        # write back in forward order, but to the reversed block offset as before
+
+        offset_out = length(output) - blockDim().x * blockIdx().x
+        index_out = offset_out + threadIdx().x
+
+        if 1 <= index_out <= length(output)
+            index_shared = threadIdx().x
+            @inbounds output[index_out] = shared[index_shared]
+        end
+
+        return
+    end
+
+    @cuda threads=nthreads blocks=nblocks shmem=shmem kernel(input, output)
+
+    return
+end
+
+function Base.reverse!(v::CuVector, start=1, stop=length(v))
+    v′ = view(v, start:stop)
+    _reverse(v′, v′)
+    return v
+end
+
+function Base.reverse(v::CuVector, start=1, stop=length(v))
+    v′ = similar(v)
+    start > 1 && copyto!(v′, 1, v, 1, start-1)
+    _reverse(view(v, start:stop), view(v′, start:stop))
+    stop < length(v) && copyto!(v′, stop+1, v, stop+1)
+    return v′
+end
```
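A usage sketch for the new device-side reversal; the input below is illustrative and a working CUDA setup is assumed. Each block stages its elements in dynamic shared memory in reversed order and writes them back at the mirrored block offset, which is why the launch requests `shmem = nthreads * sizeof(T)`.

```julia
using CuArrays

v = cu(Float32.(1:10))   # 10-element CuVector{Float32}
w = reverse(v)           # out-of-place; dispatches to the Base.reverse method above
reverse!(v)              # in-place variant
Array(w)                 # 10.0f0, 9.0f0, ..., 1.0f0
```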

src/blas/highlevel.jl (2 additions, 2 deletions)

```diff
@@ -30,7 +30,7 @@ function LinearAlgebra.BLAS.dotc(DX::CuArray{T}, DY::CuArray{T}) where T<:Union{
 end
 
 function LinearAlgebra.BLAS.dot(DX::CuArray{T}, DY::CuArray{T}) where T<:Union{ComplexF32,ComplexF64}
-    dotc(DX, DY)
+    BLAS.dotc(DX, DY)
 end
 
 function LinearAlgebra.BLAS.dotu(DX::CuArray{T}, DY::CuArray{T}) where T<:Union{ComplexF32,ComplexF64}
@@ -43,7 +43,7 @@ LinearAlgebra.norm(x::CublasArray) = nrm2(x)
 LinearAlgebra.BLAS.asum(x::CublasArray) = asum(length(x), x, 1)
 
 function LinearAlgebra.axpy!(alpha::Number, x::CuArray{T}, y::CuArray{T}) where T<:CublasFloat
-    length(x)==length(y) || throw(DimensionMismatch(""))
+    length(x)==length(y) || throw(DimensionMismatch("axpy arguments have lengths $(length(x)) and $(length(y))"))
     axpy!(length(x), convert(T,alpha), x, 1, y, 1)
 end
 
```
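A hedged sketch of the two affected level-1 routines; the inputs are illustrative:

```julia
using CuArrays, LinearAlgebra

x = CuArray(ComplexF32.(1:4))
y = CuArray(ComplexF32.(5:8))
LinearAlgebra.BLAS.dot(x, y)   # forwards to BLAS.dotc for complex element types

a = cu(ones(4)); b = cu(zeros(4))
axpy!(2f0, a, b)               # b .+= 2 .* a; length mismatches now raise a descriptive DimensionMismatch
```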

src/blas/wrappers.jl (4 additions, 4 deletions)

```diff
@@ -1473,7 +1473,7 @@ for (fname, elty) in
         unsafe_free!(Aptrs)
 
         if !Pivot
-            pivotArray = CuArray(zeros(Cint, (n, length(A))))
+            pivotArray = CuArrays.zeros(Cint, (n, length(A)))
         end
         pivotArray, info, A
     end
@@ -1513,7 +1513,7 @@ for (fname, elty) in
         ldc = max(1,stride(C[1],2))
         Aptrs = device_batch(A)
         Cptrs = device_batch(C)
-        info = CuArray(zeros(Cint,length(A)))
+        info = CuArrays.zeros(Cint,length(A))
         $fname(handle(), n, Aptrs, lda, pivotArray, Cptrs, ldc, info, length(A))
         unsafe_free!(Cptrs)
         unsafe_free!(Aptrs)
@@ -1552,7 +1552,7 @@ for (fname, elty) in
         ldc = max(1,stride(C[1],2))
         Aptrs = device_batch(A)
         Cptrs = device_batch(C)
-        info = CuArray(zeros(Cint,length(A)))
+        info = CuArrays.zeros(Cint,length(A))
         $fname(handle(), n, Aptrs, lda, Cptrs, ldc, info, length(A))
         unsafe_free!(Cptrs)
         unsafe_free!(Aptrs)
@@ -1638,7 +1638,7 @@ for (fname, elty) in
         Aptrs = device_batch(A)
         Cptrs = device_batch(C)
         info = zero(Cint)
-        infoarray = CuArray(zeros(Cint, length(A)))
+        infoarray = CuArrays.zeros(Cint, length(A))
         $fname(handle(), cutrans, m, n, nrhs, Aptrs, lda, Cptrs, ldc, [info], infoarray, length(A))
         unsafe_free!(Cptrs)
         unsafe_free!(Aptrs)
```
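These call sites now allocate their integer buffers directly on the device through the typed `CuArrays.zeros` introduced in src/array.jl, rather than filling a host `Array` and uploading it. A small illustrative comparison:

```julia
using CuArrays

info_new = CuArrays.zeros(Cint, 8)   # allocated and zero-filled on the device
info_old = CuArray(zeros(Cint, 8))   # previous pattern: host allocation, then upload
```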
