Create a CUDA context #406
@@ -0,0 +1,146 @@
using IRTools: isexpr, IR, @dynamo, postwalk

using IRTools: meta, Pipe, finish, Variable, self
using MacroTools: @forward

import Base.Broadcast.broadcasted
import Base.Broadcast.materialize
import Base.Broadcast.Broadcasted

# TODO use a WeakKeyDict
struct CUDACtx
  array_bank::IdDict{Array,CuArray}
end

CUDACtx() = CUDACtx(IdDict{Array,CuArray}())

# Display fns for debugging, remove before committing
function Base.summary(io::IO, cx::CUDACtx)
  print(io, "IR Context for CUDA ")
  summary(io, cx.array_bank)
end

function Base.show(io::IO, cx::CUDACtx)
  print(io, "IR Context for CUDA ")
  display(cx.array_bank)
end

@forward CUDACtx.array_bank Base.getindex, Base.iterate,
        Base.setindex!, Base.empty!,
        Base.length, Base.get!,
        Base.first, Base.last, Base.haskey

function _resize!(a::Array, sz::NTuple{<:Any,Integer})
  ccall(:jl_array_grow_end, Cvoid, (Any, UInt), a, prod(sz))
  ptr = convert(Ptr{Csize_t}, pointer_from_objref(a))
  for i = 1:length(sz)
    unsafe_store!(ptr+8*(i+2), sz[i])
  end
  return a
end

function refill!(a::Array, b::CuArray)
  _resize!(a, size(b))
  copy!(a, b)
end

function cache(cx, x::CuArray{T,N})::Array{T,N} where {T,N}
  cpu = Array{T,N}(undef, ntuple(_->0,N))
  cx[cpu] = x
  return cpu
end
cache(cx, f) = f

for f in (:+, :-, :*, :/)
  @eval function (cx::CUDACtx)(::typeof($f), a::AbstractArray, b::AbstractArray)
    ga = get_cached(cx, a)
    gb = get_cached(cx, b)
    cache(cx, $f(ga, gb))
  end
end

function get_cached(cx::CUDACtx, arr::Array{T,N})::CuArray{T,N} where {T,N}
  get!(cx, arr, CuArray(arr))
end
get_cached(cx::CUDACtx, x) = x

function (cx::CUDACtx)(::typeof(broadcasted), f, args...)
  gargs = map(x -> get_cached(cx, x), args)
  broadcasted(f, gargs...) |> x -> cache(cx, x)
end

function (cx::CUDACtx)(::typeof(broadcast), f, args...)
  gargs = map(x -> get_cached(cx, x), args)
  broadcast(f, gargs...) |> x -> cache(cx, x)
end

function wrap_cuize(f)
  @eval function (cx::CUDACtx)(::typeof($f), args...)
    gargs = map(x -> get_cached(cx, x), args)
    cache(cx, $f(gargs...))
  end
end

wrap_cuize.((sum, similar, materialize))

function (cx::CUDACtx)(::typeof(reshape), arr, args...)
  r = reshape(get_cached(cx, arr), args...)
  cache(cx, r)
end

@dynamo function (cx::CUDACtx)(meta...)
  ir = IR(meta...)
  ir == nothing && return

  pr = Pipe(ir)
Review comment: You could replace this code with … (see the sketch after this `@dynamo` block).
  for (v,st) in pr
    isexpr(st.expr, :call) || continue
    ex = st.expr

    pr[v] = Expr(:call, self, ex.args...)
  end
  return finish(pr)
end

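The comment above is cut off; a minimal sketch of one plausible simplification, assuming (this is a guess, not confirmed by the thread) that it refers to IRTools' `recurse!` helper, which performs the same call-to-`self` rewrite as the manual `Pipe` loop:

```julia
# Sketch only: replaces the hand-written Pipe loop with IRTools.recurse!,
# which rewrites every call in the IR into a call through `self`.
import IRTools
using IRTools: @dynamo, IR

@dynamo function (cx::CUDACtx)(meta...)
  ir = IR(meta...)
  ir === nothing && return
  IRTools.recurse!(ir)   # equivalent to `pr[v] = Expr(:call, self, ex.args...)` per call
  return ir
end
```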
""" | ||
Disable `CUDACtx` for a function | ||
""" | ||
function noop_pass(f) | ||
@eval (c::CUDACtx)(::typeof($f), args...) = $f(args...) | ||
end | ||
|
||
noop_pass.((get_cached, NNlib.check_spdf,
          ))

Review comment: Probably best if these are macros. It'd be nice to add an …
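A rough sketch of the macro form suggested in the comment above; the name `@noop` is hypothetical and not part of this PR:

```julia
# Hypothetical macro version of noop_pass: defines the same pass-through
# method at macro-expansion time instead of via runtime @eval.
macro noop(f)
  quote
    (c::CUDACtx)(::typeof($(esc(f))), args...) = $(esc(f))(args...)
  end
end

# Usage sketch:
# @noop get_cached
# @noop NNlib.check_spdf
```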
for f in names(NNlib)
  getfield(NNlib, f) isa Function || continue
  @eval function (cx::CUDACtx)(::typeof($f), args...)
    gargs = map(x -> get_cached(cx, x), args)
    cache(cx, $f(gargs...))
  end
end

Review comment: Better to do this explicitly per function.
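A sketch of the explicit per-function alternative the reviewer asks for, reusing `wrap_cuize` from above; the particular list of NNlib functions is illustrative only:

```julia
# Wrap a hand-picked set of NNlib functions rather than everything in
# names(NNlib); the selection below is an example, not the PR's final list.
wrap_cuize.((NNlib.conv, NNlib.softmax, NNlib.maxpool, NNlib.meanpool))
```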
# Hold all the arrays related to the op
# BitArray and friends would like an AbstractArray construct

""" | ||
Creates a `cuda` context within which we travel | ||
through the entire callstack to find matrix/vector | ||
operations and try to offload them to a GPU. | ||
|
||
Example: | ||
``` | ||
cuda() do | ||
# do something | ||
end | ||
``` | ||
""" | ||
function cuda(f, ctx = CUDACtx()) | ||
out = ctx(f) | ||
for (x, cx) in ctx | ||
length(x) == length(cx) && continue | ||
refill!(x, cx) | ||
end | ||
empty!(ctx) | ||
return out | ||
end |
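A short usage sketch of the API defined above, mirroring the tests below (assumes CuArrays and this file are loaded):

```julia
# Inside the block, calls are intercepted by CUDACtx: Array arguments are
# uploaded via get_cached, and results come back to the caller as plain Arrays.
W, b = rand(5, 5), rand(5)

y = cuda() do
  W * b
end

y isa Array   # true — outputs are CPU arrays again
y ≈ W * b     # true — same result as running on the CPU
```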
@@ -0,0 +1,46 @@
using CuArrays, Test
using CuArrays.NNlib

# Check simple ops work and broadcast
@testset "simple ops" begin
  W = rand(5, 5)
  b = rand(5)
  @test cuda(() -> W*b) ≈ W*b

Review comment: Good to check types here as well, e.g. that the output is still an `Array` (a sketch of such checks follows after this testset).

Review comment: Would it be worthwhile to have a way to switch off emptying the context? I'd like to be able to say if Arrays were in fact also allocated on the GPU; and a crude way might be to check the types in the context dict after the fact.

  a = rand(10)
  b = rand(10)

  r = cuda() do
    a + b
  end
  @test r isa Array

  r = cuda() do
    a .+ b
  end
  @test r isa Array
end

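A sketch of the extra checks the comments above ask for: asserting that the element type survives the round trip, and passing an explicit context. Note that `cuda` currently empties the context before returning, so inspecting what was uploaded would need a (hypothetical) option to keep it alive:

```julia
# Illustrative checks only, not part of the PR.
a, b = rand(10), rand(10)

r = cuda() do
  a .+ b
end
@test r isa Array{Float64}   # still a CPU Array with the original eltype
@test !(r isa CuArray)

# With an explicit context: it is emptied before `cuda` returns, so the
# array_bank is empty afterwards; checking GPU allocation would need more.
ctx = CUDACtx()
cuda(() -> a .+ b, ctx)
@test isempty(ctx.array_bank)
```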
# Check that functions happen
@testset "linear" begin
  linear(x, W, b) = (x * W) .+ b
  w = rand(10, 10)
  b = zeros(10)
  x = rand(10, 10)
  r = cuda() do
    linear(x, w, b)
  end
  @test r isa Array{Float64}
end

# check that NNlib is wrapped correctly
@testset "conv Context" begin
  w = rand(Float32, 3, 3, 3, 16)
  r = rand(Float32, 32, 32, 3, 1)
  c = cuda() do
    conv(r, w)
  end
  g = conv(r, w)
  @test c ≈ g
  @test c isa Array
end
Review comment: `cuda` seems like a pretty generic function to export (both `cu` and `cuda` are bound to confuse users). Why not something that implies its action, e.g. `on_cuda`? Or `@cuda`, along the lines of `@async`?

Review comment: I agree about not exporting this for now. In the longer term, if this is successful it should replace `cu` entirely (alongside all the other APIs, for most users), so a generic name seems appropriate. I think `cuda() do ...` reads right, and provides an obvious space for options (`cuda(device=2) do ...`), but `@cuda` could work well too (especially in that it's a bit nicer for one-liners).
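For reference, the macro spelling floated here could be as small as the sketch below. The name `@on_cuda` is made up for illustration (note that `@cuda` is already CUDAnative's kernel-launch macro):

```julia
# Hypothetical one-liner macro over the do-block API; illustrative only.
macro on_cuda(ex)
  :(cuda(() -> $(esc(ex))))
end

# Usage sketch:
# @on_cuda W * b
```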