Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 5c5288c

Browse files
authored
Merge pull request #576 from JuliaGPU/tb/cu
Nerf cu
2 parents d8b5ab3 + 9f7ab9e commit 5c5288c

File tree

3 files changed

+17
-10
lines changed

3 files changed

+17
-10
lines changed

src/array.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -233,13 +233,18 @@ Adapt.adapt_storage(::CUDAnative.Adaptor, xs::CuArray{T,N}) where {T,N} =
233233
# We don't convert isbits types in `adapt`, since they are already
234234
# considered GPU-compatible.
235235

236-
Adapt.adapt_storage(::Type{<:CuArray}, xs::AbstractArray) =
236+
Adapt.adapt_storage(::Type{CuArray}, xs::AbstractArray) =
237237
isbits(xs) ? xs : convert(CuArray, xs)
238238

239-
Adapt.adapt_storage(::Type{<:CuArray{T}}, xs::AbstractArray{<:Real}) where T <: AbstractFloat =
239+
# aggressively convert arrays of floats to float32
240+
Adapt.adapt_storage(::Type{CuArray}, xs::AbstractArray{<:AbstractFloat}) =
241+
isbits(xs) ? xs : convert(CuArray{Float32}, xs)
242+
243+
# if an element type is specified, convert to it
244+
Adapt.adapt_storage(::Type{<:CuArray{T}}, xs::AbstractArray) where {T} =
240245
isbits(xs) ? xs : convert(CuArray{T}, xs)
241246

242-
Adapt.adapt_storage(::Type{<:Array}, xs::CuArray) = convert(Array, xs)
247+
Adapt.adapt_storage(::Type{Array}, xs::CuArray) = convert(Array, xs)
243248

244249
Base.collect(x::CuArray{T,N}) where {T,N} = copyto!(Array{T,N}(undef, size(x)), x)
245250

@@ -311,9 +316,7 @@ end
311316

312317
## utilities
313318

314-
cu(xs) = adapt(CuArray{Float32}, xs)
315-
cu(::Type{Array{T,N}}) where {T,N} = CuArray{T,N,Nothing}
316-
cu(::Type{Array{T}}) where {T} = CuArray{T}
319+
cu(xs) = adapt(CuArray, xs)
317320
Base.getindex(::typeof(cu), xs...) = CuArray([xs...])
318321

319322
zeros(T::Type, dims...) = fill!(CuArray{T}(undef, dims...), 0)

src/dnn/rnn.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# GRU: [weight, bias] × [input, hidden] × [reset, update, newmem]
99
# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]
1010

11-
using LinearAlgebra: copy_transpose!
11+
using LinearAlgebra
1212

1313
function params(w::CuVector, input, hidden, n = 1)
1414
slice(offset, shape) = reshape(view(w, offset.+(1:prod(shape))), shape)

test/base.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,10 @@ end
3939
@test cu(1:3) === 1:3
4040
@test Base.elsize(xs) == sizeof(Int)
4141
@test CuArray{Int, 2}(xs) === xs
42-
@test cu(Array{Float64,1}) == CuArray{Float64,1, Nothing}
43-
@test cu(Array{Float64,4}) == CuArray{Float64,4, Nothing}
44-
@test cu(Array{Float64}) == CuArray{Float64}
42+
43+
# test aggressive conversion to Float32, but only for floats
44+
@test cu([1]) isa AbstractArray{Int}
45+
@test cu(Float64[1]) isa AbstractArray{Float32}
4546

4647
@test_throws ArgumentError Base.unsafe_convert(Ptr{Int}, xs)
4748
@test_throws ArgumentError Base.unsafe_convert(Ptr{Float32}, xs)
@@ -241,6 +242,9 @@ end
241242
@test testf((x,y)->copyto!(y, selectdim(x, 2, 1)), ones(2,2,2), zeros(2,2))
242243
## inability to copyto! smaller destination
243244
@test testf((x,y)->copyto!(y, selectdim(x, 2, 1)), ones(2,2,2), zeros(3,3))
245+
246+
# but in conversion of indices (#506)
247+
show(devnull, cu(view(ones(1), [1])))
244248
end
245249

246250
@testset "reshape" begin

0 commit comments

Comments
 (0)