diff --git a/Project.toml b/Project.toml
index 59fc9c8c..975ee092 100644
--- a/Project.toml
+++ b/Project.toml
@@ -10,6 +10,7 @@ CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
 CUDAnative = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
+IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
diff --git a/src/CuArrays.jl b/src/CuArrays.jl
index 8c285ce2..781f63d4 100644
--- a/src/CuArrays.jl
+++ b/src/CuArrays.jl
@@ -4,7 +4,7 @@ using CUDAapi, CUDAdrv, CUDAnative
 
 using GPUArrays
 
-export CuArray, CuVector, CuMatrix, CuVecOrMat, cu
+export CuArray, CuVector, CuMatrix, CuVecOrMat, cu, cuda
 
 import LinearAlgebra
 
@@ -81,6 +81,8 @@ include("tensor/CUTENSOR.jl")
 
 include("nnlib.jl")
 
+include("contextual.jl")
+
 include("deprecated.jl")
 
 
diff --git a/src/contextual.jl b/src/contextual.jl
new file mode 100644
index 00000000..f2bc5f62
--- /dev/null
+++ b/src/contextual.jl
@@ -0,0 +1,152 @@
+using IRTools: isexpr, IR, @dynamo, postwalk
+using IRTools: meta, Pipe, finish, Variable, self
+using MacroTools: @forward
+
+import Base.Broadcast.broadcasted
+import Base.Broadcast.materialize
+import Base.Broadcast.Broadcasted
+
+# TODO: use a WeakKeyDict
+struct CUDACtx
+  array_bank::IdDict{Array,CuArray}
+end
+
+CUDACtx() = CUDACtx(IdDict{Array,CuArray}())
+
+# Display fns for debugging; remove before merging
+function Base.summary(io::IO, cx::CUDACtx)
+  print(io, "IR Context for CUDA ")
+  summary(io, cx.array_bank)
+end
+
+function Base.show(io::IO, cx::CUDACtx)
+  print(io, "IR Context for CUDA ")
+  display(cx.array_bank)
+end
+
+@forward CUDACtx.array_bank Base.getindex, Base.iterate,
+  Base.setindex!, Base.empty!,
+  Base.length, Base.get!,
+  Base.first, Base.last, Base.haskey
+
+function _resize!(a::Array, sz::NTuple{<:Any,Integer}) # grow `a` and patch its size fields in place
+  ccall(:jl_array_grow_end, Cvoid, (Any, UInt), a, prod(sz))
+  ptr = convert(Ptr{Csize_t}, pointer_from_objref(a))
+  for i = 1:length(sz)
+    unsafe_store!(ptr+8*(i+2), sz[i])
+  end
+  return a
+end
+
+function refill!(a::Array, b::CuArray)
+  _resize!(a, size(b))
+  copy!(a, b)
+end
+
+function cache(cx, x::CuArray{T,N})::Array{T,N} where {T,N} # empty CPU placeholder keyed to `x` in the bank
+  cpu = Array{T,N}(undef, ntuple(_->0,N))
+  cx[cpu] = x
+  return cpu
+end
+cache(cx, f) = f
+
+# TODO: BitArray and friends would like an AbstractArray construct
+function get_cached(cx::CUDACtx, arr::Array{T,N})::CuArray{T,N} where {T,N}
+  get!(cx, arr, CuArray(arr))
+end
+get_cached(cx::CUDACtx, x) = x
+
+function (cx::CUDACtx)(::typeof(Base._mapreducedim!), f, op, args...)
+  gargs = map(x -> get_cached(cx, x), args)
+  Base._mapreducedim!(f, op, gargs...) |> x -> cache(cx, x)
+end
+
+macro contextual(fs...)
+  ex = Expr[]
+  for f in fs
+    q = quote
+      function (cx::CUDACtx)(::typeof($f), args...)
+        gargs = map(x -> get_cached(cx, x), args)
+        cache(cx, $f(gargs...))
+      end
+    end
+    push!(ex, q)
+  end
+
+  quote
+    $(ex...)
+  end
+end
+
+@contextual (+) (-) (*) (/) sum similar materialize
+
+function (cx::CUDACtx)(::typeof(reshape), arr, args...)
+  r = reshape(get_cached(cx, arr), args...)
+  cache(cx, r)
+end
+
+@dynamo function (cx::CUDACtx)(meta...)
+  ir = IR(meta...)
+  ir == nothing && return
+
+  pr = Pipe(ir)
+  for (v,st) in pr
+    isexpr(st.expr, :call) || continue
+    ex = st.expr
+
+    pr[v] = Expr(:call, self, ex.args...) # recurse: rewrite f(args...) into cx(f, args...)
+
+  end
+  return finish(pr)
+end
+
+"""
+  Disable `CUDACtx` for the given functions
+"""
+macro noop_pass(fs...)
+  ex = [:( (cx::CUDACtx)(::typeof($f), args...) = $f(args...) ) for f in fs]
+
+  quote
+    $(ex...)
+  end
+end
+
+@noop_pass get_cached NNlib.check_spdf
+
+for f in names(NNlib)
+  getfield(NNlib, f) isa Function || continue
+  @eval function (cx::CUDACtx)(::typeof($f), args...)
+    gargs = map(x -> get_cached(cx, x), args)
+    cache(cx, $f(gargs...))
+  end
+end
+
+for f in names(LinearAlgebra)
+  getfield(LinearAlgebra, f) isa Function || continue
+  @eval function (cx::CUDACtx)(::typeof($f), args...)
+    gargs = map(x -> get_cached(cx, x), args)
+    cache(cx, $f(gargs...))
+  end
+end
+
+"""
+  Create a `cuda` context and run `f` inside it: the entire call stack
+  is traversed to find matrix/vector operations, which are offloaded
+  to the GPU where possible.
+
+  Example:
+  ```
+  cuda() do
+    # do something
+  end
+  ```
+"""
+function cuda(f, ctx = CUDACtx())
+  out = ctx(f)
+  for (x, cx) in ctx
+    length(x) == length(cx) && continue
+    refill!(x, cx)
+  end
+  empty!(ctx)
+  return out
+end
diff --git a/test/contextual.jl b/test/contextual.jl
new file mode 100644
index 00000000..7c9c2620
--- /dev/null
+++ b/test/contextual.jl
@@ -0,0 +1,48 @@
+using CuArrays, Test
+using CuArrays.NNlib
+
+# Check that simple ops and broadcasts work
+@testset "simple ops" begin
+  W = rand(5, 5)
+  b = rand(5)
+  op = cuda(() -> W*b)
+  @test op ≈ W*b
+  @test op isa Array
+
+  a = rand(10)
+  b = rand(10)
+
+  r = cuda() do
+    a + b
+  end
+  @test r isa Array
+
+  r = cuda() do
+    a .+ b
+  end
+  @test r isa Array
+end
+
+# Check that calls inside user-defined functions are traversed
+@testset "linear" begin
+  linear(x, W, b) = (x * W) .+ b
+  w = rand(10, 10)
+  b = zeros(10)
+  x = rand(10, 10)
+  r = cuda() do
+    linear(x, w, b)
+  end
+  @test r isa Array
+end
+
+# Check that NNlib functions are wrapped correctly
+@testset "conv Context" begin
+  w = rand(Float32, 3, 3, 3, 16)
+  r = rand(Float32, 32, 32, 3, 1)
+  c = cuda() do
+    conv(r, w)
+  end
+  g = conv(r, w)
+  @test c ≈ g
+  @test c isa Array
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 87355ee7..d889261a 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -27,6 +27,7 @@ include("sparse_solver.jl")
 include("dnn.jl")
 include("tensor.jl")
 include("forwarddiff.jl")
+include("contextual.jl")
 
 CuArrays.memory_status()
 CuArrays.pool_timings()
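
For reference, a minimal usage sketch of the new entry point, mirroring the `cuda` docstring and `test/contextual.jl` above. It is illustrative only, not part of the patch, and assumes this diff is applied and a CUDA-capable device is available.

```julia
using CuArrays

W = rand(5, 5)
b = rand(5)

# Every call reached from inside the do-block is re-dispatched through `CUDACtx`:
# CPU arrays are uploaded on first use via `get_cached`, the operation is intended
# to run on the GPU, and the result is handed back as an ordinary `Array`.
r = cuda() do
  W * b
end

r isa Array   # true: results are surfaced back as plain Arrays
r ≈ W * b     # agrees with the plain CPU computation
```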