Create a CUDA context #406
@@ -0,0 +1,152 @@
using IRTools: isexpr, IR, @dynamo, postwalk
using IRTools: meta, Pipe, finish, Variable, self
using MacroTools: @forward

import Base.Broadcast.broadcasted
import Base.Broadcast.materialize
import Base.Broadcast.Broadcasted

Review comment: These imports are redundant now.

# TODO: use a WeakKeyDict
struct CUDACtx
  array_bank::IdDict{Array,CuArray}
end

CUDACtx() = CUDACtx(IdDict{Array,CuArray}())

# Display fns for debugging; remove before committing
function Base.summary(io::IO, cx::CUDACtx)
  print(io, "IR Context for CUDA ")
  summary(io, cx.array_bank)
end

function Base.show(io::IO, cx::CUDACtx)
  print(io, "IR Context for CUDA ")
  display(cx.array_bank)
end

@forward CUDACtx.array_bank Base.getindex, Base.iterate,
  Base.setindex!, Base.empty!,
  Base.length, Base.get!,
  Base.first, Base.last, Base.haskey

function _resize!(a::Array, sz::NTuple{<:Any,Integer})
  # Grow the underlying buffer, then overwrite the dims fields in the
  # jl_array header directly (offsets assume a 64-bit build).
  ccall(:jl_array_grow_end, Cvoid, (Any, UInt), a, prod(sz))
  ptr = convert(Ptr{Csize_t}, pointer_from_objref(a))
  for i = 1:length(sz)
    unsafe_store!(ptr + 8*(i+2), sz[i])
  end
  return a
end

function refill!(a::Array, b::CuArray)
  _resize!(a, size(b))
  copy!(a, b)
end

function cache(cx, x::CuArray{T,N})::Array{T,N} where {T,N}
  cpu = Array{T,N}(undef, ntuple(_->0,N))
  cx[cpu] = x
  return cpu
end
cache(cx, f) = f

Review comment: In cases like […], this does seem to make things work (including the backwards pass on Zygote with Flux models, but it's hitting some bad code paths currently).

# TODO: BitArray and friends would like an AbstractArray construct
function get_cached(cx::CUDACtx, arr::Array{T,N})::CuArray{T,N} where {T,N}
  get!(cx, arr, CuArray(arr))
end
get_cached(cx::CUDACtx, x) = x

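To make the caching round trip concrete, here is a minimal sketch using the definitions above (assumes CuArrays is loaded and a GPU is available; the variable names are illustrative only):

```julia
cx = CUDACtx()
x = rand(Float32, 4)
gx = get_cached(cx, x)      # first call uploads x; the bank now maps x => gx
gx === get_cached(cx, x)    # true: repeat lookups reuse the same CuArray
stub = cache(cx, gx .+ 1f0) # zero-length Array{Float32,1} standing in for the GPU result
length(stub) == 0           # true until `cuda` calls refill! on the way out
```
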
function (cx::CUDACtx)(::typeof(Base._mapreducedim!), f, op, args...)
  gargs = map(x -> get_cached(cx, x), args)
  Base._mapreducedim!(f, op, gargs...) |> x -> cache(cx, x)
end

macro contextual(fs...)
  ex = Expr[]
  for f in fs
    q = quote
      function (cx::CUDACtx)(::typeof($f), args...)
        gargs = map(x -> get_cached(cx, x), args)
        cache(cx, $f(gargs...))
      end
    end
    push!(ex, q)
  end

  quote
    $(ex...)
  end
end

@contextual (+) (-) (*) (/) sum similar materialize

Review comment: I think we should set these things up to explicitly call whatever lower-level bindings we have; it should show what it would look like if we got rid of […].

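For reference, each `@contextual` entry expands to a method of this shape (the macro's output for `sum`, written out by hand):

```julia
function (cx::CUDACtx)(::typeof(sum), args...)
  gargs = map(x -> get_cached(cx, x), args)
  cache(cx, sum(gargs...))
end
```
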
function (cx::CUDACtx)(::typeof(reshape), arr, args...)
  r = reshape(get_cached(cx, arr), args...)
  cache(cx, r)
end

@dynamo function (cx::CUDACtx)(meta...)
  ir = IR(meta...)
  ir === nothing && return

  pr = Pipe(ir)
  for (v, st) in pr
    isexpr(st.expr, :call) || continue
    ex = st.expr
    # Rewrite every call f(args...) to cx(f, args...), so the context
    # recurses through the callstack until a CUDACtx method matches.
    pr[v] = Expr(:call, self, ex.args...)
  end
  return finish(pr)
end

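Conceptually, evaluating a function under the context then behaves as if every call site were wrapped; an illustration (not actual generated code):

```julia
f(W, x) = W * x

# Calling `CUDACtx()(f, W, x)` evaluates `f` as if its body had been
# rewritten to `cx(*, W, x)`, which dispatches to the @contextual method
# for `*` above: arguments are uploaded via get_cached, and the result
# comes back as a zero-length cache stub.
```
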
""" | ||
Disable `CUDACtx` for a function | ||
""" | ||
macro noop_pass(fs...) | ||
ex = [:( (cx::CUDACtx)(::typeof($f), args...) = $f(args...) ) for f in fs] | ||
|
||
quote | ||
$(ex...) | ||
end | ||
end | ||
|
||
@noop_pass get_cached NNlib.check_spdf | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need a noop for |
||
|
||
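For reference, each `@noop_pass` entry pins a method that calls straight through; for `get_cached` the generated method is equivalent to:

```julia
(cx::CUDACtx)(::typeof(get_cached), args...) = get_cached(args...)
```
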
# Wrap every exported NNlib function in the context.
for f in names(NNlib)
  getfield(NNlib, f) isa Function || continue
  @eval function (cx::CUDACtx)(::typeof($f), args...)
    gargs = map(x -> get_cached(cx, x), args)
    cache(cx, $f(gargs...))
  end
end

# Likewise for LinearAlgebra.
for f in names(LinearAlgebra)
  getfield(LinearAlgebra, f) isa Function || continue
  @eval function (cx::CUDACtx)(::typeof($f), args...)
    gargs = map(x -> get_cached(cx, x), args)
    cache(cx, $f(gargs...))
  end
end

""" | ||
Creates a `cuda` context within which we travel | ||
through the entire callstack to find matrix/vector | ||
operations and try to offload them to a GPU. | ||
|
||
Example: | ||
``` | ||
cuda() do | ||
# do something | ||
end | ||
``` | ||
""" | ||
function cuda(f, ctx = CUDACtx()) | ||
out = ctx(f) | ||
for (x, cx) in ctx | ||
length(x) == length(cx) && continue | ||
refill!(x, cx) | ||
end | ||
empty!(ctx) | ||
return out | ||
end |
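
A minimal usage sketch (assumes CuArrays is loaded and a GPU is available; names are illustrative):

```julia
W = rand(Float32, 128, 128)
x = rand(Float32, 128)

y = cuda() do
  W * x      # W and x are uploaded once and cached; `*` runs on the GPU
end

y isa Array  # true: results are copied back into plain Arrays
```
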
@@ -0,0 +1,48 @@
using CuArrays, Test
using CuArrays.NNlib

# Check that simple ops work, including broadcast
@testset "simple ops" begin
  W = rand(5, 5)
  b = rand(5)
  op = cuda(() -> W*b)
  @test op ≈ W*b
  @test op isa Array

  a = rand(10)
  b = rand(10)

  r = cuda() do
    a + b
  end
  @test r isa Array

  r = cuda() do
    a .+ b
  end
  @test r isa Array
end

# Check that user-defined functions are traversed
@testset "linear" begin
  linear(x, W, b) = (x * W) .+ b
  w = rand(10, 10)
  b = zeros(10)
  x = rand(10, 10)
  r = cuda() do
    linear(x, w, b)
  end
  @test r isa Array{Float64}
end

# Check that NNlib is wrapped correctly
@testset "conv Context" begin
  w = rand(Float32, 3, 3, 3, 16)
  r = rand(Float32, 32, 32, 3, 1)
  c = cuda() do
    conv(r, w)
  end
  g = conv(r, w)
  @test c ≈ g
  @test c isa Array
end

Review comment: Seems like a pretty generic function to export (both `cu` and `cuda` are bound to confuse users). Why not something that implies its action, e.g. `on_cuda`? Or `@cuda`, as with `@async`?

Review comment (reply): I agree about not exporting this for now. In the longer term, if this is successful it should replace `cu` entirely (alongside all the other APIs, for most users), so a generic name seems appropriate. I think `cuda() do ...` reads right, and provides an obvious space for options (`cuda(device=2) do ...`), but `@cuda` could work well too (especially in that it's a bit nicer for one-liners).