This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Create a CUDA context #406

Draft

wants to merge 34 commits into base: master

Changes from 25 commits
19 changes: 10 additions & 9 deletions Project.toml
@@ -9,6 +9,7 @@ CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
CUDAnative = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
@@ -18,19 +19,19 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

[compat]
Adapt = "1.0"
CUDAapi = "0.5.3, 0.6, 1.0"
CUDAdrv = "3.0"
CUDAnative = "2.0"
GPUArrays = "0.7.1, 1.0"
NNlib = "0.6"
julia = "1.0"

[extras]
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "FFTW", "ForwardDiff"]

[compat]
julia = "1.0"
CUDAnative = "2.0"
CUDAdrv = "3.0"
CUDAapi = "0.5.3, 0.6, 1.0"
NNlib = "0.6"
GPUArrays = "0.7.1, 1.0"
Adapt = "1.0"
4 changes: 3 additions & 1 deletion src/CuArrays.jl
@@ -4,7 +4,7 @@ using CUDAapi, CUDAdrv, CUDAnative

using GPUArrays

export CuArray, CuVector, CuMatrix, CuVecOrMat, cu
export CuArray, CuVector, CuMatrix, CuVecOrMat, cu, cuda
Member: seems like a pretty generic function to export (both cu and cuda are bound to confuse users). why not something that implies its action, e.g., on_cuda? or @cuda, as with @async?

Collaborator: I agree about not exporting this for now. In the longer term, if this is successful it should replace cu entirely (alongside all the other APIs, for most users), so a generic name seems appropriate.

I think cuda() do ... reads right, and provides an obvious space for options (cuda(device=2) do ...), but @cuda could work well too (especially in that it's a bit nicer for one-liners).
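
For illustration, a rough sketch of what the option-taking form floated here could look like. The `device` keyword and the use of `CUDAnative.device!` are assumptions for the sketch, not part of this PR; it reuses the `array_bank`, `refill!` and `empty!` flow that `cuda(f)` uses in src/context.jl below.

```julia
# Hypothetical sketch only: the PR defines cuda(f) without options.
function cuda(f; device::Int = 0)
    CUDAnative.device!(device)   # pick the GPU before tracing the block
    out = array_bank(f)          # run f inside the CUDACtx
    for (x, cx) in array_bank    # copy results back into the original Arrays
        length(x) == length(cx) && continue
        refill!(x, cx)
    end
    empty!(array_bank)
    return out
end

# usage
cuda(device = 0) do
    rand(5, 5) * rand(5)
end
```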


import LinearAlgebra

@@ -81,6 +81,8 @@ include("dnn/CUDNN.jl")

include("nnlib.jl")

include("context.jl")

include("deprecated.jl")


169 changes: 169 additions & 0 deletions src/context.jl
@@ -0,0 +1,169 @@
using IRTools: isexpr, IR, @dynamo, postwalk
using IRTools: meta, Pipe, finish, Variable, self
using MacroTools: @forward

import Base.Broadcast.broadcasted
import Base.Broadcast.materialize
import Base.Broadcast.Broadcasted

# TODO use a WeakKeyDict
struct CUDACtx
array_bank::IdDict{Array,CuArray}
end

CUDACtx() = CUDACtx(IdDict{Array,CuArray}())

# Display fns for debugging, remove before committing
function Base.summary(io::IO, c::CUDACtx)
print(io, "IR Context for CUDA ")
summary(io, c.array_bank)
end

function Base.show(io::IO, c::CUDACtx)
print(io, "IR Context for CUDA ")
display(c.array_bank)
end

@forward CUDACtx.array_bank Base.getindex, Base.iterate,
Base.setindex!, Base.empty!,
Base.length,
Base.first, Base.last, Base.haskey

function _resize!(a::Array, sz::NTuple{<:Any,Integer})
ccall(:jl_array_grow_end, Cvoid, (Any, UInt), a, prod(sz))
ptr = convert(Ptr{Csize_t},pointer_from_objref(a))
for i = 1:length(sz)
unsafe_store!(ptr+8*(i+2), sz[i])
end
return a
end

function refill!(a::Array, b::CuArray)
_resize!(a, size(b))
copy!(a, b)
end

function cache(cx, x::CuArray{T,N})::Array{T,N} where {T,N}
cpu = Array{T,N}(undef, ntuple(_->0,N))
cx[cpu] = x
return cpu
end
cache(cx, f) = f

for f in (:+, :-, :*, :/)
@eval function (c::CUDACtx)(::typeof($f), a::AbstractArray, b::AbstractArray)
ga = get_cached(array_bank, a)
gb = get_cached(array_bank, b)
cache(array_bank, $f(ga, gb))
end
end

function get_cached(array_bank, arr::Array{T,N})::CuArray{T,N} where {T,N}
haskey(array_bank, arr) ?
array_bank[arr] :
(array_bank[arr] = CuArray(arr))
Collaborator: We could probably write get_cached(cx, x) as cx[x].

Member Author: Thoughts on using get!? I suppose the extra network transfer would be a problem in this case.

Collaborator: You could look at Base.@get!, which avoids evaluating the new value if it's not needed.
end
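
For comparison, a minimal sketch of the `get!`-based variant discussed in this thread. The do-block form of `get!` only constructs the `CuArray` on a cache miss, so already-cached arrays cost a lookup and no transfer; it reaches into the wrapped `IdDict` directly because `get!` is not in the `@forward` list above.

```julia
# Sketch only: behaviourally equivalent to the haskey/setindex! version above.
function get_cached(cx::CUDACtx, arr::Array{T,N})::CuArray{T,N} where {T,N}
    # The closure runs (and uploads the array) only when `arr` is not yet cached.
    get!(cx.array_bank, arr) do
        CuArray(arr)
    end
end
```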

function (c::CUDACtx)(::typeof(broadcasted), f, args...)
Collaborator: For broadcast I think we should leave the broadcast struct alone (so that CuArrays can't leak into the program), and instead do all conversion and computation in materialize.

Member Author: Makes sense, I will dig into it.

gargs = map(x -> get_cached(array_bank, x), args)
broadcasted(f, gargs...) |> x -> cache(array_bank, x)
end
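
A rough sketch of the direction suggested in that thread: leave `broadcasted` alone so the lazy `Broadcasted` tree keeps holding plain `Array`s, and do all conversion and computation in `materialize`. The method shape below is an assumption, not code from this PR, and it only converts top-level arguments (nested `Broadcasted` arguments would need the same treatment recursively).

```julia
# Sketch, reusing the get_cached/cache helpers and the global array_bank from this file.
# Since broadcasted is left untouched, no CuArray can leak out through a lazy Broadcasted.
function (c::CUDACtx)(::typeof(materialize), bc::Broadcasted)
    gargs = map(x -> get_cached(array_bank, x), bc.args)  # upload array arguments
    gbc = broadcasted(bc.f, gargs...)                     # rebuild the lazy broadcast over GPU data
    cache(array_bank, materialize(gbc))                   # fuse, execute, hand back a CPU stand-in
end
```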

function (c::CUDACtx)(::typeof(getproperty), o, s::Symbol)
getproperty(o, s) |> get_cached
end

function (c::CUDACtx)(::typeof(broadcast), f, args...)
gargs = map(x -> get_cached(array_bank, x), args)
broadcast(f, gargs...) |> x -> cache(array_bank, x)
end

function (c::CUDACtx)(::typeof(getfield), o, s::Symbol)
getfield(o, s) |> get_cached
end

function wrap_cuize(f)
@eval function (c::CUDACtx)(::typeof($f), args...)
gargs = map(get_cached, args)
cache(array_bank, $f(gargs...))
end
end

wrap_cuize.((sum, similar, materialize))

function (c::CUDACtx)(::typeof(reshape), arr, args...)
r = reshape(get_cached(arr), args...)
cache(array_bank, r)
end

@dynamo function (c::CUDACtx)(meta...)
meta == nothing && return
ir = IR(meta...)
ir == nothing && return

pr = Pipe(ir)
Collaborator: You could replace this code with IRTools.recurse!(ir) (there's some info in the docs if needed).

for (v,st) in pr
isexpr(st.expr, :call) || continue
ex = st.expr

pr[v] = Expr(:call, self, ex.args...)

end
return finish(pr)
end
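
For reference, a sketch of the `recurse!`-based spelling suggested above; it should behave the same as the hand-written `Pipe` loop, since `recurse!` rewrites every call in the IR to re-enter `self`.

```julia
using IRTools: recurse!

# Sketch of the shorter dynamo the review suggests (not what the PR commits).
@dynamo function (c::CUDACtx)(meta...)
    ir = IR(meta...)
    ir === nothing && return
    recurse!(ir)    # rewrite every :call to go back through the context
    return ir
end
```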

# TODO: remove arbitrary things in one place
get_cached(array_bank, t::Union{Type, UnitRange, Function, Broadcasted, Symbol, Module, Nothing, Missing, Ptr, CuPtr, T}) where {T <: Real} = t
get_cached(array_bank, t::Union{Tuple, NamedTuple}) = map(get_cached, t)

get_cached(array_bank, x::CuArray) = x

function get_cached(x::T) where T
T <: Array && return get_cached(array_bank, x)
isstructtype(T) && return x
get_cached(array_bank, x)
end

"""
Disable `CUDACtx` for a function
"""
function noop_pass(f)
@eval (c::CUDACtx)(::typeof($f), args...) = $f(args...)
end

noop_pass.((get_cached, NNlib.check_spdf,

Collaborator: Probably best if these are macros. It'd be nice to add an @cuda macro or similar for the purpose of overloading CUDACtx.

))
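
One possible macro spelling of `noop_pass`, along the lines of that comment; the `@nocuda` name is invented here and is not part of the PR.

```julia
# Hypothetical sugar only: expands to the same method noop_pass defines via @eval.
macro nocuda(f)
    quote
        (c::CUDACtx)(::typeof($(esc(f))), args...) = $(esc(f))(args...)
    end
end

# usage: opt a single function out of the CUDA context
# @nocuda NNlib.check_spdf
```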

for f in names(NNlib)

Collaborator: Better to do this explicitly per function.

getfield(NNlib, f) isa Function || continue
@eval function (c::CUDACtx)(::typeof($f), args...)
gargs = map(get_cached, args)
cache(array_bank, $f(gargs...))
end
end
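
A sketch of the explicit alternative suggested above, wrapping a hand-picked list instead of everything in `names(NNlib)`; the particular functions listed are just examples.

```julia
# Explicit per-function wrapping via the wrap_cuize helper defined earlier in this file.
wrap_cuize.((NNlib.conv, NNlib.softmax, NNlib.maxpool, NNlib.meanpool))
```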

# Hold all the arrays related to the op
# BitArray and friends would like an AbstractArray construct
const array_bank = CUDACtx()

"""
Create a `cuda` context within which the entire call stack
is traversed to find matrix/vector operations, which are
then offloaded to the GPU where possible.

Example:
```
cuda() do
# do something
end
```
"""
function cuda(f)
out = array_bank(f)
for (x, cx) in array_bank
length(x) == length(cx) && continue
refill!(x, cx)
end
empty!(array_bank)
return out
end
46 changes: 46 additions & 0 deletions test/context.jl
@@ -0,0 +1,46 @@
using CuArrays, Test
using CuArrays.NNlib

# Check simple ops work and broadcast
@testset "simple ops" begin
W = rand(5, 5)
b = rand(5)
@test cuda(() -> W*b) ≈ W*b
Collaborator: Good to check types here as well, e.g. that the output is still an Array.

Member Author: Would it be worthwhile to have a way to switch off emptying the context? I'd like to be able to tell whether Arrays were in fact also allocated on the GPU; a crude way might be to check the types in the context dict after the fact.


a = rand(10)
b = rand(10)

r = cuda() do
a + b
end
@test r isa Array

r = cuda() do
a .+ b
end
@test r isa Array
end

# Check that calls inside user-defined functions are traced
@testset "linear" begin
linear(x, W, b) = (x * W) .+ b
w = rand(10, 10)
b = zeros(10)
x = rand(10,10)
r = cuda() do
linear(x, w, b)
end
@test r isa Array{Float64}
end

# check that NNlib is wrapped correctly
@testset "conv Context" begin
w = rand(Float32, 3, 3, 3, 16)
r = rand(Float32, 32, 32, 3, 1)
g = cuda() do
conv(r, w)
end
c = conv(r, w)
@test c ≈ g
@test g isa Array
end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -26,6 +26,7 @@ include("solver.jl")
include("sparse_solver.jl")
include("dnn.jl")
include("forwarddiff.jl")
include("context.jl")

CuArrays.pool_status()
CuArrays.pool_timings()